Source code for rectools.models.nn.transformers.sasrec

#  Copyright 2025 MTS (Mobile Telesystems)
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.

import typing as tp
from typing import Dict

import numpy as np
import torch
from torch import nn

from ..item_net import (
    CatFeaturesItemNet,
    IdEmbeddingsItemNet,
    ItemNetBase,
    ItemNetConstructorBase,
    SumOfEmbeddingsConstructor,
)
from .base import (
    TrainerCallable,
    TransformerDataPreparatorType,
    TransformerLayersType,
    TransformerLightningModule,
    TransformerLightningModuleBase,
    TransformerModelBase,
    TransformerModelConfig,
    ValMaskCallable,
)
from .data_preparator import BatchElement, InitKwargs, TransformerDataPreparatorBase
from .negative_sampler import CatalogUniformSampler, TransformerNegativeSamplerBase
from .net_blocks import (
    LearnableInversePositionalEncoding,
    PointWiseFeedForward,
    PositionalEncodingBase,
    TransformerLayersBase,
)
from .similarity import DistanceSimilarityModule, SimilarityModuleBase
from .torch_backbone import TransformerBackboneBase, TransformerTorchBackbone


[docs]class SASRecDataPreparator(TransformerDataPreparatorBase): """Data preparator for SASRecModel. Parameters ---------- session_max_len : int Maximum length of user sequence. batch_size : int How many samples per batch to load. dataloader_num_workers : int Number of loader worker processes. item_extra_tokens : Sequence(Hashable) Which element to use for sequence padding. shuffle_train : bool, default True If ``True``, reshuffles data at each epoch. train_min_user_interactions : int, default 2 Minimum length of user sequence. Cannot be less than 2. get_val_mask_func : Callable, default None Function to get validation mask. n_negatives : optional(int), default ``None`` Number of negatives for BCE, gBCE and sampled_softmax losses. negative_sampler: optional(TransformerNegativeSamplerBase), default ``None`` Negative sampler. get_val_mask_func_kwargs: optional(InitKwargs), default ``None`` Additional arguments for the get_val_mask_func. Make sure all dict values have JSON serializable types. extra_cols: list(str) | None, default ``None`` Additional columns from dataset to keep beside of Columns.Inreractions add_unix_ts: bool, default ``False`` Add extra column ``unix_ts`` contains Column.Datetime converted to seconds from the beginning of the epoch """ train_session_max_len_addition: int = 1 def _collate_fn_train( self, batch: tp.List[BatchElement], ) -> Dict[str, torch.Tensor]: """ Truncate each session from right to keep `session_max_len` items. Do left padding until `session_max_len` is reached. Split to `x`, `y`, and `yw`. """ batch_size = len(batch) x = np.zeros((batch_size, self.session_max_len)) y = np.zeros((batch_size, self.session_max_len)) yw = np.zeros((batch_size, self.session_max_len)) for i, (ses, ses_weights, _) in enumerate(batch): x[i, -len(ses) + 1 :] = ses[:-1] # ses: [session_len] -> x[i]: [session_max_len] y[i, -len(ses) + 1 :] = ses[1:] # ses: [session_len] -> y[i]: [session_max_len] yw[i, -len(ses) + 1 :] = ses_weights[1:] # ses_weights: [session_len] -> yw[i]: [session_max_len] batch_dict = {"x": torch.LongTensor(x), "y": torch.LongTensor(y), "yw": torch.FloatTensor(yw)} if self.negative_sampler is not None: batch_dict["negatives"] = self.negative_sampler.get_negatives( batch_dict, lowest_id=self.n_item_extra_tokens, highest_id=self.item_id_map.size ) if self.add_unix_ts: t = np.zeros((batch_size, self.session_max_len + 1)) # +1 target item timestamp for i, (ses, _, extras) in enumerate(batch): t[i, -len(ses) :] = extras["unix_ts"] len_to_pad = self.session_max_len + 1 - len(ses) if len_to_pad > 0: t[i, :len_to_pad] = t[i, len_to_pad] batch_dict.update({"unix_ts": torch.LongTensor(t)}) return batch_dict def _collate_fn_val(self, batch: tp.List[BatchElement]) -> Dict[str, torch.Tensor]: batch_size = len(batch) x = np.zeros((batch_size, self.session_max_len)) y = np.zeros((batch_size, 1)) # Only leave-one-strategy is supported for losses yw = np.zeros((batch_size, 1)) # Only leave-one-strategy is supported for losses for i, (ses, ses_weights, _) in enumerate(batch): input_session = [ses[idx] for idx, weight in enumerate(ses_weights) if weight == 0] # take only first target for leave-one-strategy target_idx = [idx for idx, weight in enumerate(ses_weights) if weight != 0][0] # ses: [session_len] -> x[i]: [session_max_len] x[i, -len(input_session) :] = input_session[-self.session_max_len :] y[i, -1:] = ses[target_idx] # y[i]: [1] yw[i, -1:] = ses_weights[target_idx] # yw[i]: [1] batch_dict = {"x": torch.LongTensor(x), "y": torch.LongTensor(y), "yw": torch.FloatTensor(yw)} if self.negative_sampler is not None: batch_dict["negatives"] = self.negative_sampler.get_negatives( batch_dict, lowest_id=self.n_item_extra_tokens, highest_id=self.item_id_map.size, session_len_limit=1 ) if self.add_unix_ts: t = np.zeros((batch_size, self.session_max_len + 1)) # +1 target item timestamp for i, (ses, _, extras) in enumerate(batch): t[i, -len(ses) + 1 :] = extras["unix_ts"][1:] len_to_pad = self.session_max_len + 2 - len(ses) if len_to_pad > 0: t[i, :len_to_pad] = t[i, len_to_pad] batch_dict.update({"unix_ts": torch.LongTensor(t)}) return batch_dict def _collate_fn_recommend(self, batch: tp.List[BatchElement]) -> Dict[str, torch.Tensor]: """Right truncation, left padding to session_max_len""" batch_size = len(batch) x = np.zeros((batch_size, self.session_max_len)) if self.add_unix_ts: t = np.zeros((batch_size, self.session_max_len + 1)) for i, (ses, _, extras) in enumerate(batch): ses = ses[:-1] # drop dummy item x[i, -len(ses) :] = ses[-self.session_max_len :] t[i, -(len(ses) + 1) :] = extras["unix_ts"][-(self.session_max_len + 1) :] len_to_pad = self.session_max_len - len(ses) if len_to_pad > 0: t[i, :len_to_pad] = t[i, len_to_pad] return {"x": torch.LongTensor(x), "unix_ts": torch.LongTensor(t)} for i, (ses, _, _) in enumerate(batch): x[i, -len(ses) :] = ses[-self.session_max_len :] return {"x": torch.LongTensor(x)}
[docs]class SASRecTransformerLayer(nn.Module): """ Exactly SASRec author's transformer block architecture but with pytorch Multi-Head Attention realisation. Parameters ---------- n_factors : int Latent embeddings size. n_heads : int Number of attention heads. dropout_rate : float Probability of a hidden unit to be zeroed. """ def __init__( self, n_factors: int, n_heads: int, dropout_rate: float, ): super().__init__() # important: original architecture had another version of MHA self.multi_head_attn = torch.nn.MultiheadAttention(n_factors, n_heads, dropout_rate, batch_first=True) self.q_layer_norm = nn.LayerNorm(n_factors) self.ff_layer_norm = nn.LayerNorm(n_factors) self.feed_forward = PointWiseFeedForward(n_factors, n_factors, dropout_rate, torch.nn.ReLU()) self.dropout = torch.nn.Dropout(dropout_rate)
[docs] def forward( self, seqs: torch.Tensor, attn_mask: tp.Optional[torch.Tensor], key_padding_mask: tp.Optional[torch.Tensor], ) -> torch.Tensor: """ Forward pass through transformer block. Parameters ---------- seqs : torch.Tensor User sequences of item embeddings. attn_mask : torch.Tensor, optional Optional mask to use in forward pass of multi-head attention as `attn_mask`. key_padding_mask : torch.Tensor, optional Optional mask to use in forward pass of multi-head attention as `key_padding_mask`. Returns ------- torch.Tensor User sequences passed through transformer layers. """ q = self.q_layer_norm(seqs) mha_output, _ = self.multi_head_attn( q, seqs, seqs, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False ) seqs = q + mha_output ff_input = self.ff_layer_norm(seqs) seqs = self.feed_forward(ff_input) seqs = self.dropout(seqs) seqs += ff_input return seqs
[docs]class SASRecTransformerLayers(TransformerLayersBase): """ SASRec transformer blocks. Parameters ---------- n_blocks : int Number of transformer blocks. n_factors : int Latent embeddings size. n_heads : int Number of attention heads. dropout_rate : float Probability of a hidden unit to be zeroed. """ def __init__( self, n_blocks: int, n_factors: int, n_heads: int, dropout_rate: float, **kwargs: tp.Any, ): super().__init__() self.n_blocks = n_blocks self.transformer_blocks = nn.ModuleList( [ SASRecTransformerLayer( n_factors, n_heads, dropout_rate, ) for _ in range(self.n_blocks) ] ) self.last_layernorm = torch.nn.LayerNorm(n_factors, eps=1e-8)
[docs] def forward( self, seqs: torch.Tensor, timeline_mask: torch.Tensor, attn_mask: tp.Optional[torch.Tensor], key_padding_mask: tp.Optional[torch.Tensor], **kwargs: tp.Any, ) -> torch.Tensor: """ Forward pass through transformer blocks. Parameters ---------- seqs : torch.Tensor User sequences of item embeddings. timeline_mask : torch.Tensor Mask indicating padding elements. attn_mask : torch.Tensor, optional Optional mask to use in forward pass of multi-head attention as `attn_mask`. key_padding_mask : torch.Tensor, optional Optional mask to use in forward pass of multi-head attention as `key_padding_mask`. Returns ------- torch.Tensor User sequences passed through transformer layers. """ for i in range(self.n_blocks): seqs *= timeline_mask # [batch_size, session_max_len, n_factors] seqs = self.transformer_blocks[i](seqs, attn_mask, key_padding_mask) seqs *= timeline_mask seqs = self.last_layernorm(seqs) return seqs
[docs]class SASRecModelConfig(TransformerModelConfig): """SASRecModel config.""" data_preparator_type: TransformerDataPreparatorType = SASRecDataPreparator transformer_layers_type: TransformerLayersType = SASRecTransformerLayers use_causal_attn: bool = True
[docs]class SASRecModel(TransformerModelBase[SASRecModelConfig]): """ SASRec model: transformer-based sequential model with unidirectional attention mechanism and "Shifted Sequence" training objective. Our implementation covers multiple loss functions and a variable number of negatives for them. References ---------- Transformers tutorial: https://rectools.readthedocs.io/en/stable/examples/tutorials/transformers_tutorial.html Advanced training guide: https://rectools.readthedocs.io/en/stable/examples/tutorials/transformers_advanced_training_guide.html Public benchmark: https://github.com/blondered/bert4rec_repro Original SASRec paper: https://arxiv.org/abs/1808.09781 gBCE loss and gSASRec paper: https://arxiv.org/pdf/2308.07192 Parameters ---------- n_blocks : int, default 2 Number of transformer blocks. n_heads : int, default 4 Number of attention heads. n_factors : int, default 256 Latent embeddings size. dropout_rate : float, default 0.2 Probability of a hidden unit to be zeroed. session_max_len : int, default 100 Maximum length of user sequence. train_min_user_interactions : int, default 2 Minimum number of interactions user should have to be used for training. Should be greater than 1. loss : {"softmax", "BCE", "gBCE", "sampled_softmax"}, default "softmax" Loss function. n_negatives : int, default 1 Number of negatives for BCE, gBCE and sampled_softmax losses. gbce_t : float, default 0.2 Calibration parameter for gBCE loss. lr : float, default 0.001 Learning rate. batch_size : int, default 128 How many samples per batch to load. epochs : int, default 3 Exact number of training epochs. Will be omitted if `get_trainer_func` is specified. deterministic : bool, default ``False`` `deterministic` flag passed to lightning trainer during initialization. Use `pytorch_lightning.seed_everything` together with this parameter to fix the random seed. Will be omitted if `get_trainer_func` is specified. verbose : int, default 0 Verbosity level. Enables progress bar, model summary and logging in default lightning trainer when set to a positive integer. Will be omitted if `get_trainer_func` is specified. dataloader_num_workers : int, default 0 Number of loader worker processes. use_pos_emb : bool, default ``True`` If ``True``, learnable positional encoding will be added to session item embeddings. use_key_padding_mask : bool, default ``False`` If ``True``, key_padding_mask will be added in Multi-head Attention. use_causal_attn : bool, default ``True`` If ``True``, causal mask will be added as attn_mask in Multi-head Attention. Please note that default SASRec training task ("Shifted Sequence") does not work without causal masking. Set this parameter to ``False`` only when you change the training task with custom `data_preparator_type` or if you are absolutely sure of what you are doing. item_net_block_types : sequence of `type(ItemNetBase)`, default `(IdEmbeddingsItemNet, CatFeaturesItemNet)` Type of network returning item embeddings. (IdEmbeddingsItemNet,) - item embeddings based on ids. (CatFeaturesItemNet,) - item embeddings based on categorical features. (IdEmbeddingsItemNet, CatFeaturesItemNet) - item embeddings based on ids and categorical features. item_net_constructor_type : type(ItemNetConstructorBase), default `SumOfEmbeddingsConstructor` Type of item net blocks aggregation constructor. pos_encoding_type : type(PositionalEncodingBase), default `LearnableInversePositionalEncoding` Type of positional encoding. transformer_layers_type : type(TransformerLayersBase), default `SasRecTransformerLayers` Type of transformer layers architecture. data_preparator_type : type(TransformerDataPreparatorBase), default `SasRecDataPreparator` Type of data preparator used for dataset processing and dataloader creation. lightning_module_type : type(TransformerLightningModuleBase), default `TransformerLightningModule` Type of lightning module defining training procedure. negative_sampler_type: type(TransformerNegativeSamplerBase), default `CatalogUniformSampler` Type of negative sampler. similarity_module_type : type(SimilarityModuleBase), default `DistanceSimilarityModule` Type of similarity module. backbone_type : type(TransformerBackboneBase), default `TransformerTorchBackbone` Type of torch backbone. get_val_mask_func : Callable, default ``None`` Function to get validation mask. get_trainer_func : Callable, default ``None`` Function for get custom lightning trainer. If `get_trainer_func` is None, default trainer will be created based on `epochs`, `deterministic` and `verbose` argument values. Model will be trained for the exact number of epochs. Checkpointing will be disabled. If you want to assign custom trainer after model is initialized, you can manually assign new value to model `_trainer` attribute. recommend_batch_size : int, default 256 How many samples per batch to load during `recommend`. If you want to change this parameter after model is initialized, you can manually assign new value to model `recommend_batch_size` attribute. recommend_torch_device : {"cpu", "cuda", "cuda:0", ...}, default ``None`` String representation for `torch.device` used for model inference. When set to ``None``, "cuda" will be used if it is available, "cpu" otherwise. If you want to change this parameter after model is initialized, you can manually assign new value to model `recommend_torch_device` attribute. get_val_mask_func_kwargs: optional(InitKwargs), default ``None`` Additional keyword arguments for the get_val_mask_func. Make sure all dict values have JSON serializable types. get_trainer_func_kwargs: optional(InitKwargs), default ``None`` Additional keyword arguments for the get_trainer_func. Make sure all dict values have JSON serializable types. data_preparator_kwargs: optional(dict), default ``None`` Additional keyword arguments to pass during `data_preparator_type` initialization. Make sure all dict values have JSON serializable types. transformer_layers_kwargs: optional(dict), default ``None`` Additional keyword arguments to pass during `transformer_layers_type` initialization. Make sure all dict values have JSON serializable types. item_net_constructor_kwargs optional(dict), default ``None`` Additional keyword arguments to pass during `item_net_constructor_type` initialization. Make sure all dict values have JSON serializable types. pos_encoding_kwargs: optional(dict), default ``None`` Additional keyword arguments to pass during `pos_encoding_type` initialization. Make sure all dict values have JSON serializable types. lightning_module_kwargs: optional(dict), default ``None`` Additional keyword arguments to pass during `lightning_module_type` initialization. Make sure all dict values have JSON serializable types. negative_sampler_kwargs: optional(dict), default ``None`` Additional keyword arguments to pass during `negative_sampler_type` initialization. Make sure all dict values have JSON serializable types. similarity_module_kwargs: optional(dict), default ``None`` Additional keyword arguments to pass during `similarity_module_type` initialization. Make sure all dict values have JSON serializable types. backbone_kwargs: optional(dict), default ``None`` Additional keyword arguments to pass during `backbone_type` initialization. Make sure all dict values have JSON serializable types. """ config_class = SASRecModelConfig def __init__( # pylint: disable=too-many-arguments, too-many-locals self, n_blocks: int = 2, n_heads: int = 4, n_factors: int = 256, dropout_rate: float = 0.2, session_max_len: int = 100, train_min_user_interactions: int = 2, loss: str = "softmax", n_negatives: int = 1, gbce_t: float = 0.2, lr: float = 0.001, batch_size: int = 128, epochs: int = 3, deterministic: bool = False, verbose: int = 0, dataloader_num_workers: int = 0, use_pos_emb: bool = True, use_key_padding_mask: bool = False, use_causal_attn: bool = True, item_net_block_types: tp.Sequence[tp.Type[ItemNetBase]] = (IdEmbeddingsItemNet, CatFeaturesItemNet), item_net_constructor_type: tp.Type[ItemNetConstructorBase] = SumOfEmbeddingsConstructor, pos_encoding_type: tp.Type[PositionalEncodingBase] = LearnableInversePositionalEncoding, transformer_layers_type: tp.Type[TransformerLayersBase] = SASRecTransformerLayers, # SASRec authors net data_preparator_type: tp.Type[TransformerDataPreparatorBase] = SASRecDataPreparator, lightning_module_type: tp.Type[TransformerLightningModuleBase] = TransformerLightningModule, negative_sampler_type: tp.Type[TransformerNegativeSamplerBase] = CatalogUniformSampler, similarity_module_type: tp.Type[SimilarityModuleBase] = DistanceSimilarityModule, backbone_type: tp.Type[TransformerBackboneBase] = TransformerTorchBackbone, get_val_mask_func: tp.Optional[ValMaskCallable] = None, get_trainer_func: tp.Optional[TrainerCallable] = None, get_val_mask_func_kwargs: tp.Optional[InitKwargs] = None, get_trainer_func_kwargs: tp.Optional[InitKwargs] = None, recommend_batch_size: int = 256, recommend_torch_device: tp.Optional[str] = None, recommend_use_torch_ranking: bool = True, recommend_n_threads: int = 0, data_preparator_kwargs: tp.Optional[InitKwargs] = None, transformer_layers_kwargs: tp.Optional[InitKwargs] = None, item_net_constructor_kwargs: tp.Optional[InitKwargs] = None, pos_encoding_kwargs: tp.Optional[InitKwargs] = None, lightning_module_kwargs: tp.Optional[InitKwargs] = None, negative_sampler_kwargs: tp.Optional[InitKwargs] = None, similarity_module_kwargs: tp.Optional[InitKwargs] = None, backbone_kwargs: tp.Optional[InitKwargs] = None, ): super().__init__( transformer_layers_type=transformer_layers_type, data_preparator_type=data_preparator_type, n_blocks=n_blocks, n_heads=n_heads, n_factors=n_factors, use_pos_emb=use_pos_emb, use_causal_attn=use_causal_attn, use_key_padding_mask=use_key_padding_mask, dropout_rate=dropout_rate, session_max_len=session_max_len, dataloader_num_workers=dataloader_num_workers, batch_size=batch_size, loss=loss, n_negatives=n_negatives, gbce_t=gbce_t, lr=lr, epochs=epochs, verbose=verbose, deterministic=deterministic, recommend_batch_size=recommend_batch_size, recommend_torch_device=recommend_torch_device, recommend_n_threads=recommend_n_threads, recommend_use_torch_ranking=recommend_use_torch_ranking, train_min_user_interactions=train_min_user_interactions, similarity_module_type=similarity_module_type, item_net_block_types=item_net_block_types, item_net_constructor_type=item_net_constructor_type, pos_encoding_type=pos_encoding_type, lightning_module_type=lightning_module_type, negative_sampler_type=negative_sampler_type, backbone_type=backbone_type, get_val_mask_func=get_val_mask_func, get_trainer_func=get_trainer_func, get_val_mask_func_kwargs=get_val_mask_func_kwargs, get_trainer_func_kwargs=get_trainer_func_kwargs, data_preparator_kwargs=data_preparator_kwargs, transformer_layers_kwargs=transformer_layers_kwargs, item_net_constructor_kwargs=item_net_constructor_kwargs, pos_encoding_kwargs=pos_encoding_kwargs, lightning_module_kwargs=lightning_module_kwargs, negative_sampler_kwargs=negative_sampler_kwargs, similarity_module_kwargs=similarity_module_kwargs, backbone_kwargs=backbone_kwargs, )