Source code for rectools.models.nn.transformers.torch_backbone
# Copyright 2025 MTS (Mobile Telesystems)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import typing as tp
import torch
from ..item_net import ItemNetBase
from .net_blocks import PositionalEncodingBase, TransformerLayersBase
from .similarity import SimilarityModuleBase
[docs]class TransformerBackboneBase(torch.nn.Module):
"""Base class for transformer torch backbone."""
def __init__(
self,
n_heads: int,
dropout_rate: float,
item_model: ItemNetBase,
pos_encoding_layer: PositionalEncodingBase,
transformer_layers: TransformerLayersBase,
similarity_module: SimilarityModuleBase,
use_causal_attn: bool = True,
use_key_padding_mask: bool = False,
**kwargs: tp.Any,
) -> None:
"""
Initialize transformer torch backbone.
Parameters
----------
n_heads : int
Number of attention heads.
dropout_rate : float
Probability of a hidden unit to be zeroed.
item_model : ItemNetBase
Network for item embeddings.
pos_encoding_layer : PositionalEncodingBase
Positional encoding layer.
transformer_layers : TransformerLayersBase
Transformer layers.
similarity_module : SimilarityModuleBase
Similarity module.
use_causal_attn : bool, default True
If ``True``, causal mask is used in multi-head self-attention.
use_key_padding_mask : bool, default False
If ``True``, key padding mask is used in multi-head self-attention.
**kwargs : Any
Additional keyword arguments for future extensions.
"""
super().__init__()
self.item_model = item_model
self.pos_encoding_layer = pos_encoding_layer
self.emb_dropout = torch.nn.Dropout(dropout_rate)
self.transformer_layers = transformer_layers
self.similarity_module = similarity_module
self.use_causal_attn = use_causal_attn
self.use_key_padding_mask = use_key_padding_mask
self.n_heads = n_heads
[docs] def encode_sessions(self, batch: tp.Dict[str, torch.Tensor], item_embs: torch.Tensor) -> torch.Tensor:
"""
Pass user history through item embeddings.
Add positional encoding.
Pass history through transformer blocks.
Parameters
----------
batch : Dict[str, torch.Tensor]
Dictionary containing user sessions data.
item_embs : torch.Tensor
Item embeddings.
Returns
-------
torch.Tensor. [batch_size, session_max_len, n_factors]
Encoded session embeddings.
"""
raise NotImplementedError()
[docs] def forward(
self,
batch: tp.Dict[str, torch.Tensor], # batch["x"]: [batch_size, session_max_len]
candidate_item_ids: tp.Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Forward pass to get item and session embeddings.
Get item embeddings.
Pass user sessions through transformer blocks.
Parameters
----------
batch : Dict[str, torch.Tensor]
Dictionary containing user sessions data, with "x" key containing session tensor.
candidate_item_ids : optional(torch.Tensor), default ``None``
Defined item ids for similarity calculation.
Returns
-------
torch.Tensor
"""
raise NotImplementedError()
[docs]class TransformerTorchBackbone(TransformerBackboneBase):
"""
Torch model for encoding user sessions based on transformer architecture.
Parameters
----------
n_heads : int
Number of attention heads.
dropout_rate : float
Probability of a hidden unit to be zeroed.
item_model : ItemNetBase
Network for item embeddings.
pos_encoding_layer : PositionalEncodingBase
Positional encoding layer.
transformer_layers : TransformerLayersBase
Transformer layers.
similarity_module : SimilarityModuleBase
Similarity module.
use_causal_attn : bool, default True
If ``True``, causal mask is used in multi-head self-attention.
use_key_padding_mask : bool, default False
If ``True``, key padding mask is used in multi-head self-attention.
**kwargs : Any
Additional keyword arguments for future extensions.
"""
def __init__(
self,
n_heads: int,
dropout_rate: float,
item_model: ItemNetBase,
pos_encoding_layer: PositionalEncodingBase,
transformer_layers: TransformerLayersBase,
similarity_module: SimilarityModuleBase,
use_causal_attn: bool = True,
use_key_padding_mask: bool = False,
**kwargs: tp.Any,
) -> None:
super().__init__(
n_heads=n_heads,
dropout_rate=dropout_rate,
item_model=item_model,
pos_encoding_layer=pos_encoding_layer,
transformer_layers=transformer_layers,
similarity_module=similarity_module,
use_causal_attn=use_causal_attn,
use_key_padding_mask=use_key_padding_mask,
**kwargs,
)
@staticmethod
def _convert_mask_to_float(mask: torch.Tensor, query: torch.Tensor) -> torch.Tensor:
return torch.zeros_like(mask, dtype=query.dtype).masked_fill_(mask, float("-inf"))
def _merge_masks(
self, attn_mask: torch.Tensor, key_padding_mask: torch.Tensor, query: torch.Tensor
) -> torch.Tensor:
"""
Merge `attn_mask` and `key_padding_mask` as a new `attn_mask`.
Both masks are expanded to shape ``(batch_size * n_heads, session_max_len, session_max_len)``
and combined with logical ``or``.
Diagonal elements in last two dimensions are set equal to ``0``.
This prevents nan values in gradients for pytorch < 2.5.0 when both masks are present in forward pass of
`torch.nn.MultiheadAttention` (https://github.com/pytorch/pytorch/issues/41508).
Parameters
----------
attn_mask: torch.Tensor. [session_max_len, session_max_len]
Boolean causal attention mask.
key_padding_mask: torch.Tensor. [batch_size, session_max_len]
Boolean padding mask.
query: torch.Tensor
Query tensor used to acquire correct shapes and dtype for new `attn_mask`.
Returns
-------
torch.Tensor. [batch_size * n_heads, session_max_len, session_max_len]
Merged mask to use as new `attn_mask` with zeroed diagonal elements in last 2 dimensions.
"""
batch_size, seq_len, _ = query.shape
key_padding_mask_expanded = self._convert_mask_to_float( # [batch_size, session_max_len]
key_padding_mask, query
).view(
batch_size, 1, seq_len
) # [batch_size, 1, session_max_len]
attn_mask_expanded = (
self._convert_mask_to_float(attn_mask, query) # [session_max_len, session_max_len]
.view(1, seq_len, seq_len)
.expand(batch_size, -1, -1)
) # [batch_size, session_max_len, session_max_len]
merged_mask = attn_mask_expanded + key_padding_mask_expanded
res = (
merged_mask.view(batch_size, 1, seq_len, seq_len)
.expand(-1, self.n_heads, -1, -1)
.reshape(-1, seq_len, seq_len)
) # [batch_size * n_heads, session_max_len, session_max_len]
torch.diagonal(res, dim1=1, dim2=2).zero_()
return res
[docs] def encode_sessions(self, batch: tp.Dict[str, torch.Tensor], item_embs: torch.Tensor) -> torch.Tensor:
"""
Pass user history through item embeddings.
Add positional encoding.
Pass history through transformer blocks.
Parameters
----------
batch : Dict[str, torch.Tensor]
Dictionary containing user sessions data.
item_embs : torch.Tensor
Item embeddings.
Returns
-------
torch.Tensor. [batch_size, session_max_len, n_factors]
Encoded session embeddings.
"""
sessions = batch["x"] # [batch_size, session_max_len]
session_max_len = sessions.shape[1]
attn_mask = None
key_padding_mask = None
timeline_mask = (sessions != 0).unsqueeze(-1) # [batch_size, session_max_len, 1]
seqs = item_embs[sessions] # [batch_size, session_max_len, n_factors]
seqs = self.pos_encoding_layer(seqs)
seqs = self.emb_dropout(seqs)
if self.use_causal_attn:
attn_mask = ~torch.tril(
torch.ones((session_max_len, session_max_len), dtype=torch.bool, device=sessions.device)
)
if self.use_key_padding_mask:
key_padding_mask = sessions == 0
if attn_mask is not None: # merge masks to prevent nan gradients for torch < 2.5.0
attn_mask = self._merge_masks(attn_mask, key_padding_mask, seqs)
key_padding_mask = None
seqs = self.transformer_layers(seqs, timeline_mask, attn_mask, key_padding_mask, batch=batch)
return seqs
[docs] def forward(
self,
batch: tp.Dict[str, torch.Tensor], # batch["x"]: [batch_size, session_max_len]
candidate_item_ids: tp.Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Forward pass to get item and session embeddings.
Get item embeddings.
Pass user sessions through transformer blocks.
Parameters
----------
batch : Dict[str, torch.Tensor]
Dictionary containing user sessions data, with "x" key containing session tensor.
candidate_item_ids : optional(torch.Tensor), default ``None``
Defined item ids for similarity calculation.
Returns
-------
torch.Tensor
"""
item_embs = self.item_model.get_all_embeddings() # [n_items + n_item_extra_tokens, n_factors]
session_embs = self.encode_sessions(batch, item_embs) # [batch_size, session_max_len, n_factors]
logits = self.similarity_module(session_embs, item_embs, candidate_item_ids)
return logits