# Source code for rectools.models.nn.transformers.net_blocks
# Copyright 2025 MTS (Mobile Telesystems)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import typing as tp
import torch
from torch import nn
class PointWiseFeedForward(nn.Module):
    """
    Position-wise two-layer feed-forward network for transformer blocks.

    The input goes through ``linear -> activation -> dropout -> linear``.
    This implementation is the one used by SASRec authors.

    Parameters
    ----------
    n_factors : int
        Latent embeddings size (input and output width of the network).
    n_factors_ff : int
        How many hidden units to use between the two linear layers.
    dropout_rate : float
        Probability of a hidden unit to be zeroed.
    activation : torch.nn.Module
        Activation function module applied after the first linear layer.
    bias : bool, default ``True``
        If ``True``, add bias to linear layers.
    """

    def __init__(
        self, n_factors: int, n_factors_ff: int, dropout_rate: float, activation: torch.nn.Module, bias: bool = True
    ) -> None:
        super().__init__()
        # Attribute names and creation order are kept as-is: they define the state-dict keys.
        self.ff_linear_1 = nn.Linear(in_features=n_factors, out_features=n_factors_ff, bias=bias)
        self.ff_dropout_1 = nn.Dropout(p=dropout_rate)
        self.ff_activation = activation
        self.ff_linear_2 = nn.Linear(in_features=n_factors_ff, out_features=n_factors, bias=bias)

    def forward(self, seqs: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        Parameters
        ----------
        seqs : torch.Tensor
            User sequences of item embeddings.

        Returns
        -------
        torch.Tensor
            User sequence that passed through all layers.
        """
        hidden = self.ff_linear_1(seqs)
        hidden = self.ff_activation(hidden)
        # Dropout is applied to the activated hidden units, before the output projection.
        hidden = self.ff_dropout_1(hidden)
        return self.ff_linear_2(hidden)
class SwigluFeedForward(nn.Module):
    """
    Gated (SwiGLU) feed-forward network for transformer blocks.

    A SiLU-activated projection of the input is multiplied element-wise by a second
    linear projection of the same input; the product goes through dropout and is
    projected back to the embedding size.
    This implementation is based on FuXi and LLama SwigLU https://arxiv.org/pdf/2502.03036,
    LiGR https://arxiv.org/pdf/2502.03417

    Parameters
    ----------
    n_factors : int
        Latent embeddings size (input and output width of the network).
    n_factors_ff : int
        How many hidden units to use in the network.
    dropout_rate : float
        Probability of a hidden unit to be zeroed.
    bias : bool, default ``True``
        If ``True``, add bias to linear layers.
    """

    def __init__(self, n_factors: int, n_factors_ff: int, dropout_rate: float, bias: bool = True) -> None:
        super().__init__()
        # Attribute names and creation order are kept as-is: they define the state-dict keys.
        self.ff_linear_1 = nn.Linear(in_features=n_factors, out_features=n_factors_ff, bias=bias)  # gate branch
        self.ff_dropout_1 = nn.Dropout(p=dropout_rate)
        self.ff_activation = nn.SiLU()
        self.ff_linear_2 = nn.Linear(in_features=n_factors_ff, out_features=n_factors, bias=bias)  # output projection
        self.ff_linear_3 = nn.Linear(in_features=n_factors, out_features=n_factors_ff, bias=bias)  # value branch

    def forward(self, seqs: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        Parameters
        ----------
        seqs : torch.Tensor
            User sequences of item embeddings.

        Returns
        -------
        torch.Tensor
            User sequence that passed through all layers.
        """
        gate = self.ff_activation(self.ff_linear_1(seqs))
        value = self.ff_linear_3(seqs)
        gated = self.ff_dropout_1(gate * value)
        return self.ff_linear_2(gated)
def init_feed_forward(
    n_factors: int, ff_factors_multiplier: int, dropout_rate: float, ff_activation: str, bias: bool = True
) -> nn.Module:
    """
    Build a feed-forward network for one of the supported activations: "swiglu", "relu", "gelu".

    Parameters
    ----------
    n_factors : int
        Latent embeddings size.
    ff_factors_multiplier : int
        Multiplier applied to `n_factors` to get the number of hidden units in the network.
    dropout_rate : float
        Probability of a hidden unit to be zeroed.
    ff_activation : {"swiglu", "relu", "gelu"}
        Activation function to use.
    bias : bool, default ``True``
        If ``True``, add bias to linear layers.

    Returns
    -------
    nn.Module
        ``SwigluFeedForward`` for "swiglu", otherwise ``PointWiseFeedForward`` with the chosen activation.

    Raises
    ------
    ValueError
        If `ff_activation` is not one of the supported names.
    """
    # The gated variant has its own module; the other two share the point-wise module.
    if ff_activation == "swiglu":
        return SwigluFeedForward(n_factors, n_factors * ff_factors_multiplier, dropout_rate, bias=bias)

    activation: nn.Module
    if ff_activation == "gelu":
        activation = torch.nn.GELU()
    elif ff_activation == "relu":
        activation = torch.nn.ReLU()
    else:
        raise ValueError(f"Unsupported ff_activation: {ff_activation}")

    return PointWiseFeedForward(
        n_factors,
        n_factors * ff_factors_multiplier,
        dropout_rate,
        activation=activation,
        bias=bias,
    )
class TransformerLayersBase(nn.Module):
    """
    Base class for transformer layers.

    Subclasses are expected to override `forward`; the implementation here only raises.
    """

    def forward(
        self,
        seqs: torch.Tensor,
        timeline_mask: torch.Tensor,
        attn_mask: tp.Optional[torch.Tensor],
        key_padding_mask: tp.Optional[torch.Tensor],
        **kwargs: tp.Any,
    ) -> torch.Tensor:
        """
        Forward pass through transformer blocks.

        Parameters
        ----------
        seqs : torch.Tensor
            User sequences of item embeddings.
        timeline_mask : torch.Tensor
            Mask indicating padding elements.
        attn_mask : torch.Tensor, optional
            Optional mask to use in forward pass of multi-head attention as `attn_mask`.
        key_padding_mask : torch.Tensor, optional
            Optional mask to use in forward pass of multi-head attention as `key_padding_mask`.
        **kwargs : Any
            Extra keyword arguments a concrete implementation may accept.

        Returns
        -------
        torch.Tensor
            User sequences passed through transformer layers.

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError()
class PreLNTransformerLayer(nn.Module):
    """
    Single Pre-LN transformer block: layer normalisation precedes each sub-layer.

    See "On Layer Normalization in the Transformer Architecture"
    https://arxiv.org/pdf/2002.04745

    Parameters
    ----------
    n_factors : int
        Latent embeddings size.
    n_heads : int
        Number of attention heads.
    dropout_rate : float
        Probability of a hidden unit to be zeroed.
    ff_factors_multiplier : int, default 4
        Feed-forward layers latent embedding size multiplier.
    """

    def __init__(
        self,
        n_factors: int,
        n_heads: int,
        dropout_rate: float,
        ff_factors_multiplier: int = 4,
    ) -> None:
        super().__init__()
        # Sub-modules are registered in the original order so state-dict keys and ordering stay the same.
        self.multi_head_attn = nn.MultiheadAttention(
            embed_dim=n_factors, num_heads=n_heads, dropout=dropout_rate, batch_first=True
        )
        self.layer_norm_1 = nn.LayerNorm(n_factors)
        self.dropout_1 = nn.Dropout(p=dropout_rate)
        self.layer_norm_2 = nn.LayerNorm(n_factors)
        self.feed_forward = PointWiseFeedForward(
            n_factors,
            n_factors * ff_factors_multiplier,
            dropout_rate,
            activation=nn.GELU(),
        )
        self.dropout_2 = nn.Dropout(p=dropout_rate)
        self.dropout_3 = nn.Dropout(p=dropout_rate)

    def forward(
        self,
        seqs: torch.Tensor,
        attn_mask: tp.Optional[torch.Tensor],
        key_padding_mask: tp.Optional[torch.Tensor],
    ) -> torch.Tensor:
        """
        Forward pass through transformer block.

        Parameters
        ----------
        seqs : torch.Tensor
            User sequences of item embeddings.
        attn_mask : torch.Tensor, optional
            Optional mask to use in forward pass of multi-head attention as `attn_mask`.
        key_padding_mask : torch.Tensor, optional
            Optional mask to use in forward pass of multi-head attention as `key_padding_mask`.

        Returns
        -------
        torch.Tensor
            User sequences passed through the block.
        """
        # Self-attention sub-layer: normalise, attend (query = key = value), residual add.
        normed = self.layer_norm_1(seqs)
        attended = self.multi_head_attn(
            normed,
            normed,
            normed,
            key_padding_mask=key_padding_mask,
            need_weights=False,
            attn_mask=attn_mask,
        )[0]
        residual = seqs + self.dropout_1(attended)

        # Feed-forward sub-layer: normalise, transform, residual add.
        transformed = self.feed_forward(self.layer_norm_2(residual))
        residual = residual + self.dropout_2(transformed)

        # A final dropout is applied to the whole block output, after the residual sum.
        return self.dropout_3(residual)
class PreLNTransformerLayers(TransformerLayersBase):
    """
    Stack of Pre-LN transformer blocks applied one after another.

    Parameters
    ----------
    n_blocks : int
        Number of transformer blocks.
    n_factors : int
        Latent embeddings size.
    n_heads : int
        Number of attention heads.
    dropout_rate : float
        Probability of a hidden unit to be zeroed.
    ff_factors_multiplier : int, default 4
        Feed-forward layers latent embedding size multiplier.
    **kwargs : Any
        Accepted for signature compatibility; not used by this class.
    """

    def __init__(
        self,
        n_blocks: int,
        n_factors: int,
        n_heads: int,
        dropout_rate: float,
        ff_factors_multiplier: int = 4,
        **kwargs: tp.Any,
    ) -> None:
        super().__init__()
        self.n_blocks = n_blocks
        self.transformer_blocks = nn.ModuleList(
            [
                PreLNTransformerLayer(
                    n_factors,
                    n_heads,
                    dropout_rate,
                    ff_factors_multiplier,
                )
                for _ in range(self.n_blocks)
            ]
        )

    def forward(
        self,
        seqs: torch.Tensor,
        timeline_mask: torch.Tensor,
        attn_mask: tp.Optional[torch.Tensor],
        key_padding_mask: tp.Optional[torch.Tensor],
        **kwargs: tp.Any,
    ) -> torch.Tensor:
        """
        Forward pass through transformer blocks.

        Parameters
        ----------
        seqs : torch.Tensor
            User sequences of item embeddings.
        timeline_mask : torch.Tensor
            Mask indicating padding elements. Not used by this implementation.
        attn_mask : torch.Tensor, optional
            Optional mask to use in forward pass of multi-head attention as `attn_mask`.
        key_padding_mask : torch.Tensor, optional
            Optional mask to use in forward pass of multi-head attention as `key_padding_mask`.
        **kwargs : Any
            Accepted for signature compatibility; not used by this implementation.

        Returns
        -------
        torch.Tensor
            User sequences passed through transformer layers.
        """
        # Iterate the ModuleList directly instead of indexing it by position;
        # the same masks are passed to every block.
        for block in self.transformer_blocks:
            seqs = block(seqs, attn_mask, key_padding_mask)
        return seqs
class PositionalEncodingBase(nn.Module):
    """
    Base class for positional encoding.

    Subclasses are expected to override `forward`; the implementation here only raises.
    """

    def forward(self, sessions: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        Parameters
        ----------
        sessions : torch.Tensor
            User sessions to encode.

        Returns
        -------
        torch.Tensor
            Encoded user sessions.

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError()
class LearnableInversePositionalEncoding(PositionalEncodingBase):
    """
    Class to introduce learnable positional embeddings.

    Positions are counted from the end of the sequence: the last element gets
    position ``0`` and the first gets ``session_max_len - 1``.

    Parameters
    ----------
    use_pos_emb : bool
        If ``True``, learnable positional encoding will be added to session item embeddings.
    session_max_len : int
        Maximum length of user sequence.
    n_factors : int
        Latent embeddings size.
    use_scale_factor : bool, default ``False``
        If ``True``, multiply session embeddings by the square root of the embedding
        dimension before positional embeddings are added.
    **kwargs : Any
        Accepted for signature compatibility; not used by this class.
    """

    def __init__(
        self,
        use_pos_emb: bool,
        session_max_len: int,
        n_factors: int,
        use_scale_factor: bool = False,
        **kwargs: tp.Any,
    ) -> None:
        super().__init__()
        self.pos_emb = torch.nn.Embedding(session_max_len, n_factors) if use_pos_emb else None
        self.use_scale_factor = use_scale_factor

    def forward(self, sessions: torch.Tensor) -> torch.Tensor:
        """
        Forward pass to optionally scale sessions and add learnable positional encoding.

        The input tensor is never modified in place.

        Parameters
        ----------
        sessions : torch.Tensor
            User sessions as a 3-dimensional tensor, unpacked as
            ``(batch_size, session_max_len, n_factors)``.

        Returns
        -------
        torch.Tensor
            Encoded user sessions with added positional encoding if `use_pos_emb` is ``True``.
        """
        batch_size, session_max_len, n_factors = sessions.shape

        if self.use_scale_factor:
            sessions = sessions * (n_factors**0.5)

        if self.pos_emb is not None:
            # Inverse positions are appropriate for variable length sequences across different batches
            # They are equal to absolute positions for fixed sequence length across different batches
            # The index is created directly on the input's device to avoid a host-to-device copy.
            positions = torch.tile(
                torch.arange(session_max_len - 1, -1, -1, device=sessions.device), (batch_size, 1)
            )  # [batch_size, session_max_len]
            # Out-of-place add: an in-place `+=` would silently overwrite the caller's tensor.
            sessions = sessions + self.pos_emb(positions)

        return sessions