TransformerBackboneBase
- class rectools.models.nn.transformers.torch_backbone.TransformerBackboneBase(n_heads: int, dropout_rate: float, item_model: ItemNetBase, pos_encoding_layer: PositionalEncodingBase, transformer_layers: TransformerLayersBase, similarity_module: SimilarityModuleBase, use_causal_attn: bool = True, use_key_padding_mask: bool = False, **kwargs: Any) [source]
Bases: torch.nn.Module
Base class for transformer torch backbone.
Methods
encode_sessions(batch, item_embs) – Pass user history through item embeddings.
forward(batch[, candidate_item_ids]) – Forward pass to get item and session embeddings.
- Parameters
n_heads (int)
dropout_rate (float)
item_model (ItemNetBase)
pos_encoding_layer (PositionalEncodingBase)
transformer_layers (TransformerLayersBase)
similarity_module (SimilarityModuleBase)
use_causal_attn (bool)
use_key_padding_mask (bool)
kwargs (Any)
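The two boolean flags control which attention masks the backbone applies. As a rough illustration only, the sketch below shows what a causal mask and a key padding mask typically look like when passed to plain torch.nn.MultiheadAttention. The toy sizes, the use of item id 0 as padding, and right-padding of sessions are assumptions made for this example, not details taken from RecTools.

```python
import torch

# Hypothetical toy sizes, chosen for illustration only.
batch_size, session_max_len, n_factors, n_heads = 2, 4, 8, 2

# Sessions of item ids; id 0 is assumed to mark padding (right-padded in this toy).
sessions = torch.tensor([[3, 5, 0, 0],
                         [1, 2, 3, 4]])

# Analogue of use_causal_attn: True above the diagonal forbids attending to later positions.
causal_mask = torch.triu(
    torch.ones(session_max_len, session_max_len, dtype=torch.bool), diagonal=1
)

# Analogue of use_key_padding_mask: True marks padded positions ignored as keys.
key_padding_mask = sessions == 0

attn = torch.nn.MultiheadAttention(n_factors, n_heads, dropout=0.1, batch_first=True)
attn.eval()  # turn off dropout so the example is deterministic
x = torch.randn(batch_size, session_max_len, n_factors)
out, weights = attn(x, x, x, attn_mask=causal_mask, key_padding_mask=key_padding_mask)
print(out.shape)  # torch.Size([2, 4, 8])
```

With both masks active, each position attends only to earlier, non-padded positions, so the head-averaged attention weights on the padded keys of the first session are exactly zero.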
- encode_sessions(batch: Dict[str, Tensor], item_embs: Tensor) → Tensor [source]
Pass user history through item embeddings, add positional encoding, and pass the resulting sequence through the transformer blocks.
- Parameters
batch (Dict[str, torch.Tensor]) – Dictionary containing user sessions data.
item_embs (torch.Tensor) – Item embeddings.
- Returns
Encoded session embeddings.
- Return type
torch.Tensor of shape [batch_size, session_max_len, n_factors]
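The three steps of encode_sessions can be sketched with standard PyTorch modules. This is a hypothetical stand-in, not the RecTools implementation: it assumes the session tensor is stored under the "x" key (as described for forward), uses learnable positional embeddings in place of a PositionalEncodingBase subclass, and uses nn.TransformerEncoder in place of TransformerLayersBase.

```python
from typing import Dict

import torch
from torch import nn


class ToySessionEncoder(nn.Module):
    """Hypothetical stand-in for the encode_sessions steps, not the RecTools implementation."""

    def __init__(self, n_factors: int, n_heads: int, session_max_len: int, n_blocks: int = 1):
        super().__init__()
        # Learnable positional embeddings: one simple choice of positional encoding.
        self.pos_emb = nn.Embedding(session_max_len, n_factors)
        layer = nn.TransformerEncoderLayer(
            n_factors, n_heads, dim_feedforward=4 * n_factors, dropout=0.0, batch_first=True
        )
        self.blocks = nn.TransformerEncoder(layer, num_layers=n_blocks)

    def encode_sessions(self, batch: Dict[str, torch.Tensor], item_embs: torch.Tensor) -> torch.Tensor:
        sessions = batch["x"]  # [batch_size, session_max_len] item ids ("x" key is assumed)
        seq_len = sessions.size(1)
        # 1. Look up the embedding of every item in every session.
        seqs = item_embs[sessions]  # [batch_size, session_max_len, n_factors]
        # 2. Add positional encoding (broadcast over the batch dimension).
        positions = torch.arange(seq_len, device=sessions.device)
        seqs = seqs + self.pos_emb(positions)
        # 3. Run the transformer blocks with a causal mask.
        causal = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=sessions.device), diagonal=1
        )
        return self.blocks(seqs, mask=causal)


torch.manual_seed(0)
n_items, n_factors = 10, 8
encoder = ToySessionEncoder(n_factors=n_factors, n_heads=2, session_max_len=4)
item_embs = torch.randn(n_items, n_factors)
batch = {"x": torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7]])}
session_embs = encoder.encode_sessions(batch, item_embs)
print(session_embs.shape)  # torch.Size([2, 4, 8])
```

Because of the causal mask, changing the last item of a session leaves the embeddings of earlier positions unchanged, which is what next-item prediction over a session relies on.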
- forward(batch: Dict[str, Tensor], candidate_item_ids: Optional[Tensor] = None) → Tensor [source]
Forward pass to get item and session embeddings: compute item embeddings, then pass user sessions through the transformer blocks.
- Parameters
batch (Dict[str, torch.Tensor]) – Dictionary containing user sessions data, with “x” key containing session tensor.
candidate_item_ids (optional(torch.Tensor), default None) – Ids of the candidate items to use in the similarity calculation.
- Return type
torch.Tensor
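The similarity_module parameter suggests that forward combines session embeddings with item embeddings into scores, optionally restricted to candidate_item_ids. The sketch below assumes plain dot-product similarity and a 1-D tensor of candidate ids; both are assumptions for illustration, and the configured SimilarityModuleBase subclass may behave differently.

```python
from typing import Optional

import torch


def dot_product_logits(
    session_embs: torch.Tensor,
    item_embs: torch.Tensor,
    candidate_item_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Hypothetical dot-product scoring; the configured similarity_module may differ."""
    if candidate_item_ids is not None:
        # Keep only the candidate items (assumed here to be a 1-D tensor of ids).
        item_embs = item_embs[candidate_item_ids]  # [n_candidates, n_factors]
    # [batch_size, session_max_len, n_factors] @ [n_factors, n_items]
    # -> [batch_size, session_max_len, n_items]
    return session_embs @ item_embs.T


torch.manual_seed(0)
session_embs = torch.randn(2, 4, 8)  # stand-in for encode_sessions output
item_embs = torch.randn(10, 8)       # stand-in for item_model output
all_logits = dot_product_logits(session_embs, item_embs)
cand_logits = dot_product_logits(session_embs, item_embs, torch.tensor([3, 7]))
print(all_logits.shape, cand_logits.shape)  # torch.Size([2, 4, 10]) torch.Size([2, 4, 2])
```

Restricting to candidates is equivalent to scoring all items and then selecting the candidate columns, but it avoids computing scores that would be discarded.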