TransformerLayersBase

class rectools.models.nn.transformers.net_blocks.TransformerLayersBase(*args: Any, **kwargs: Any)

Bases: Module

Base class for transformer layers.
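The following is a minimal sketch of how a concrete layer stack might subclass this base class. It is not the rectools implementation. The local TransformerLayersBase stand-in only mirrors the documented forward signature, so the sketch runs without rectools installed. SimpleTransformerLayers, its constructor arguments, and the float (batch_size, session_max_len, 1) shape of timeline_mask are hypothetical, chosen for illustration.

```python
import typing as tp

import torch
from torch import nn


class TransformerLayersBase(nn.Module):
    """Local stand-in mirroring the documented forward signature (assumption)."""

    def forward(
        self,
        seqs: torch.Tensor,
        timeline_mask: torch.Tensor,
        attn_mask: tp.Optional[torch.Tensor],
        key_padding_mask: tp.Optional[torch.Tensor],
        **kwargs: tp.Any,
    ) -> torch.Tensor:
        raise NotImplementedError()


class SimpleTransformerLayers(TransformerLayersBase):
    """Hypothetical subclass: a stack of pre-LayerNorm self-attention and feed-forward blocks."""

    def __init__(self, n_blocks: int, n_factors: int, n_heads: int, dropout_rate: float = 0.0) -> None:
        super().__init__()
        self.attn_norms = nn.ModuleList([nn.LayerNorm(n_factors) for _ in range(n_blocks)])
        self.attns = nn.ModuleList(
            [
                nn.MultiheadAttention(n_factors, n_heads, dropout=dropout_rate, batch_first=True)
                for _ in range(n_blocks)
            ]
        )
        self.ff_norms = nn.ModuleList([nn.LayerNorm(n_factors) for _ in range(n_blocks)])
        self.ffs = nn.ModuleList(
            [
                nn.Sequential(nn.Linear(n_factors, n_factors), nn.GELU(), nn.Linear(n_factors, n_factors))
                for _ in range(n_blocks)
            ]
        )

    def forward(
        self,
        seqs: torch.Tensor,  # (batch_size, session_max_len, n_factors)
        timeline_mask: torch.Tensor,  # assumed (batch_size, session_max_len, 1): 1.0 = item, 0.0 = padding
        attn_mask: tp.Optional[torch.Tensor],
        key_padding_mask: tp.Optional[torch.Tensor],
        **kwargs: tp.Any,
    ) -> torch.Tensor:
        for attn_norm, attn, ff_norm, ff in zip(self.attn_norms, self.attns, self.ff_norms, self.ffs):
            q = attn_norm(seqs)
            attn_out, _ = attn(
                q, q, q, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False
            )
            seqs = seqs + attn_out  # residual connection around attention
            seqs = seqs + ff(ff_norm(seqs))  # residual connection around feed-forward
            seqs = seqs * timeline_mask  # zero out padding positions after every block
        return seqs  # same shape as the input seqs


# Usage: two sequences of length 4; the second is left-padded with two padding positions.
torch.manual_seed(0)
layers = SimpleTransformerLayers(n_blocks=2, n_factors=8, n_heads=2)
seqs = torch.randn(2, 4, 8)
key_padding_mask = torch.tensor([[False, False, False, False], [True, True, False, False]])  # True = ignore
timeline_mask = (~key_padding_mask).unsqueeze(-1).float()
out = layers(seqs, timeline_mask, attn_mask=None, key_padding_mask=key_padding_mask)
print(out.shape)  # torch.Size([2, 4, 8])
```

Multiplying by timeline_mask after each block keeps the outputs at padding positions equal to zero. A concrete subclass may handle padding differently.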

Methods

forward(seqs, timeline_mask, attn_mask, ...)

Forward pass through transformer blocks.


Parameters
  • args (Any) –

  • kwargs (Any) –

forward(seqs: Tensor, timeline_mask: Tensor, attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], **kwargs: Any) → Tensor

Forward pass through transformer blocks.

Parameters
  • seqs (torch.Tensor) – User sequences of item embeddings.

  • timeline_mask (torch.Tensor) – Mask indicating padding elements.

  • attn_mask (torch.Tensor, optional) – Mask passed as attn_mask to multi-head attention in the forward pass.

  • key_padding_mask (torch.Tensor, optional) – Mask passed as key_padding_mask to multi-head attention in the forward pass.

  • kwargs (Any) – Additional keyword arguments.
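The sketch below shows one way these masks could be derived from a batch of item-id sequences and how torch.nn.MultiheadAttention consumes attn_mask and key_padding_mask. It assumes left padding with item id 0 and a float timeline_mask with a trailing singleton dimension. Both are illustrative conventions, not necessarily the ones rectools uses. The boolean masks follow PyTorch's convention that True marks positions attention must not use. The causal mask and the padding mask are each applied on their own here.

```python
import torch
from torch import nn

# Hypothetical batch of item-id sequences, left-padded with id 0 (assumed padding convention).
sessions = torch.tensor([[3, 7, 2, 5], [0, 0, 4, 9]])
batch_size, session_max_len = sessions.shape

# timeline_mask: 1.0 at real items, 0.0 at padding; the trailing dim broadcasts over embeddings.
timeline_mask = (sessions != 0).unsqueeze(-1).float()  # (batch_size, session_max_len, 1)

# key_padding_mask: True marks keys that attention must ignore (PyTorch convention).
key_padding_mask = sessions == 0  # (batch_size, session_max_len)

# attn_mask: causal mask; True above the diagonal blocks attention to future positions.
attn_mask = torch.triu(torch.ones(session_max_len, session_max_len, dtype=torch.bool), diagonal=1)

n_factors = 8
mha = nn.MultiheadAttention(n_factors, num_heads=2, batch_first=True)
seqs = torch.randn(batch_size, session_max_len, n_factors) * timeline_mask

# Bidirectional attention: only padded keys are hidden.
bidir_out, _ = mha(seqs, seqs, seqs, key_padding_mask=key_padding_mask, need_weights=False)
# Unidirectional attention: only future positions are hidden.
causal_out, _ = mha(seqs, seqs, seqs, attn_mask=attn_mask, need_weights=False)
print(bidir_out.shape, causal_out.shape)  # torch.Size([2, 4, 8]) torch.Size([2, 4, 8])
```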

Returns

User sequences passed through transformer layers.

Return type

torch.Tensor