SASRecTransformerLayer

class rectools.models.nn.transformers.sasrec.SASRecTransformerLayer(n_factors: int, n_heads: int, dropout_rate: float)

Bases: Module

Exactly the transformer block architecture of the SASRec authors, but implemented with PyTorch Multi-Head Attention.

Parameters
  • n_factors (int) – Latent embeddings size.

  • n_heads (int) – Number of attention heads.

  • dropout_rate (float) – Probability of a hidden unit being zeroed.
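A minimal, hypothetical sketch of such a block in PyTorch, parameterised by the three arguments above. It assumes the layout commonly used in SASRec reference implementations (layer-normalised queries, raw keys and values, residual connections taken from the normalised inputs, and a position-wise feed-forward network) and batch-first inputs of shape (batch, seq_len, n_factors). The class name SASRecStyleBlock and these layout details are assumptions for illustration, not the rectools source.

```python
from typing import Optional

import torch
from torch import nn


class SASRecStyleBlock(nn.Module):
    """Hypothetical sketch of a SASRec-style transformer block (not the rectools source)."""

    def __init__(self, n_factors: int, n_heads: int, dropout_rate: float) -> None:
        super().__init__()
        # Assumption: inputs are batch-first, shaped (batch, seq_len, n_factors)
        self.attn = nn.MultiheadAttention(
            n_factors, n_heads, dropout=dropout_rate, batch_first=True
        )
        self.attn_norm = nn.LayerNorm(n_factors)
        self.ff_norm = nn.LayerNorm(n_factors)
        # Position-wise feed-forward network: the same two-layer MLP at every position
        self.ff = nn.Sequential(
            nn.Linear(n_factors, n_factors),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(n_factors, n_factors),
            nn.Dropout(dropout_rate),
        )

    def forward(
        self,
        seqs: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        key_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Queries are layer-normalised; keys and values are the raw inputs
        q = self.attn_norm(seqs)
        attn_out, _ = self.attn(
            q,
            seqs,
            seqs,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )
        # Residual connection around attention, added to the normalised queries
        seqs = q + attn_out
        # Second sub-layer: layer norm, feed-forward, residual from the normalised input
        ff_in = self.ff_norm(seqs)
        return ff_in + self.ff(ff_in)
```

Note that torch.nn.MultiheadAttention requires n_factors to be divisible by n_heads. A kernel-size-1 convolution over the sequence is equivalent to the per-position Linear layers used in this sketch.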

Methods

forward(seqs, attn_mask, key_padding_mask)

Forward pass through the transformer block.


forward(seqs: Tensor, attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]) → Tensor

Forward pass through the transformer block.

Parameters
  • seqs (torch.Tensor) – User sequences of item embeddings.

  • attn_mask (torch.Tensor, optional) – Optional mask passed as attn_mask to the forward pass of the multi-head attention.

  • key_padding_mask (torch.Tensor, optional) – Optional mask passed as key_padding_mask to the forward pass of the multi-head attention.

Returns

User sequences passed through the transformer block.

Return type

torch.Tensor
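Both masks are interpreted by torch.nn.MultiheadAttention: in a boolean attn_mask of shape (seq_len, seq_len), True marks a query-key pair that is not allowed to attend, and in a boolean key_padding_mask of shape (batch, seq_len), True marks a key to ignore. A minimal sketch of building a causal attn_mask and a key_padding_mask, assuming (hypothetically) right-padded item-id sequences with padding id 0 and batch-first inputs. The helper build_masks is illustrative and not part of rectools.

```python
import torch
from torch import nn


def build_masks(item_ids: torch.Tensor, padding_id: int = 0):
    """Hypothetical helper: causal attn_mask and key_padding_mask for item-id sequences."""
    seq_len = item_ids.shape[1]
    # True above the diagonal: position i may not attend to any later position j > i
    attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
    # True where the key is padding, so those positions are ignored as keys
    key_padding_mask = item_ids == padding_id
    return attn_mask, key_padding_mask


torch.manual_seed(0)
# Assumption: sequences are right-padded with item id 0
item_ids = torch.tensor([[5, 3, 7, 0], [2, 9, 0, 0]])
embeddings = nn.Embedding(10, 8, padding_idx=0)
seqs = embeddings(item_ids)  # shape (batch, seq_len, n_factors) = (2, 4, 8)
attn_mask, key_padding_mask = build_masks(item_ids)

mha = nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True)
mha.eval()
with torch.no_grad():
    # weights are averaged over heads: shape (batch, seq_len, seq_len)
    out, weights = mha(
        seqs, seqs, seqs, attn_mask=attn_mask, key_padding_mask=key_padding_mask
    )
```

Right padding is chosen here only so that every query row keeps at least one visible key; with left padding and a causal mask, padded query rows have all their keys masked, which can produce NaN outputs for those rows.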