TransformerBackboneBase

class rectools.models.nn.transformers.torch_backbone.TransformerBackboneBase(n_heads: int, dropout_rate: float, item_model: ItemNetBase, pos_encoding_layer: PositionalEncodingBase, transformer_layers: TransformerLayersBase, similarity_module: SimilarityModuleBase, use_causal_attn: bool = True, use_key_padding_mask: bool = False, **kwargs: Any)

Bases: Module

Base class for the transformer torch backbone.

Methods

encode_sessions(batch, item_embs)

Pass user history through item embeddings.

forward(batch[, candidate_item_ids])

Forward pass to get item and session embeddings.

encode_sessions(batch: Dict[str, Tensor], item_embs: Tensor) → Tensor

Encode user sessions in three steps: pass the user history through the item embeddings, add positional encoding, and pass the result through the transformer blocks.

Parameters
  • batch (Dict[str, torch.Tensor]) – Dictionary containing user sessions data.

  • item_embs (torch.Tensor) – Item embeddings.

Returns

Encoded session embeddings.

Return type

torch.Tensor of shape [batch_size, session_max_len, n_factors]
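The three steps described for encode_sessions can be sketched with plain PyTorch modules. This is a minimal illustration, not the RecTools implementation: the function name encode_sessions_sketch, the learnable positional embedding, the use of nn.TransformerEncoder, the toy sizes, and the convention that id 0 is padding are all assumptions made for the example.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Toy sizes, chosen only for this sketch.
n_items, n_factors, session_max_len = 10, 8, 5

# Stand-in for the item embeddings that the item model would provide.
item_embs = torch.randn(n_items, n_factors)

# Stand-in batch: "x" holds item ids per session (0 used as padding here).
batch = {"x": torch.tensor([[0, 0, 3, 1, 4], [2, 5, 6, 7, 8]])}

# Hypothetical stand-ins for the positional encoding and transformer layers.
pos_emb = nn.Embedding(session_max_len, n_factors)
encoder_layer = nn.TransformerEncoderLayer(
    d_model=n_factors, nhead=2, dim_feedforward=16, dropout=0.0, batch_first=True
)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=1)


def encode_sessions_sketch(batch, item_embs):
    # 1. Pass user history through item embeddings: [B, L] -> [B, L, F].
    seqs = item_embs[batch["x"]]
    # 2. Add positional encoding (broadcast over the batch dimension).
    positions = torch.arange(seqs.shape[1])
    seqs = seqs + pos_emb(positions)
    # 3. Pass through transformer blocks with a causal mask:
    #    True above the diagonal means "position i may not attend to j > i".
    length = seqs.shape[1]
    causal_mask = torch.triu(torch.ones(length, length, dtype=torch.bool), diagonal=1)
    return encoder(seqs, mask=causal_mask)


session_embs = encode_sessions_sketch(batch, item_embs)
print(session_embs.shape)  # [batch_size, session_max_len, n_factors]
```

The causal mask mirrors what use_causal_attn=True suggests; key padding masking is left out to keep the sketch short.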

forward(batch: Dict[str, Tensor], candidate_item_ids: Optional[Tensor] = None) → Tensor

Forward pass to get item and session embeddings: get the item embeddings, then pass the user sessions through the transformer blocks.

Parameters
  • batch (Dict[str, torch.Tensor]) – Dictionary containing user sessions data; the "x" key holds the session tensor.

  • candidate_item_ids (Optional[torch.Tensor], default None) – Item ids to use for the similarity calculation.

Return type

torch.Tensor
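How candidate_item_ids might restrict the similarity calculation can be sketched as follows. This is only an illustration under stated assumptions: the function dot_similarity_sketch is hypothetical, dot-product scoring is one possible choice of similarity, and candidate_item_ids is assumed to be a 1-D tensor of item ids. The actual behaviour depends on the similarity_module passed to the backbone.

```python
import torch

torch.manual_seed(0)


def dot_similarity_sketch(session_embs, item_embs, candidate_item_ids=None):
    # When candidates are given, score only those rows of the item embeddings.
    if candidate_item_ids is not None:
        item_embs = item_embs[candidate_item_ids]
    # [B, L, F] @ [F, N] -> [B, L, N]
    return session_embs @ item_embs.T


# Toy inputs standing in for encoded sessions and item embeddings.
session_embs = torch.randn(2, 5, 8)
item_embs = torch.randn(10, 8)

all_scores = dot_similarity_sketch(session_embs, item_embs)
candidates = torch.tensor([1, 3, 4])
cand_scores = dot_similarity_sketch(session_embs, item_embs, candidates)

print(all_scores.shape, cand_scores.shape)
```

In this sketch, passing candidates yields the same scores as selecting the corresponding columns from the full score tensor, while avoiding the computation for all other items.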