TransformerLayersBase
- class rectools.models.nn.transformers.net_blocks.TransformerLayersBase(*args: Any, **kwargs: Any)
Bases: Module
Base class for transformer layers.
Methods
forward(seqs, timeline_mask, attn_mask, ...) – Forward pass through transformer blocks.
- Parameters
args (Any) –
kwargs (Any) –
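A minimal sketch of a concrete subclass follows. It assumes a subclass only needs to implement forward with the signature documented for this class. To stay runnable without rectools installed, it defines a local stand-in base class; in a real project the base would be imported from rectools.models.nn.transformers.net_blocks instead. PreLNTransformerLayers, its constructor parameters (n_blocks, n_factors, n_heads, dropout_rate) and the pre-LayerNorm block layout are hypothetical illustrations, not the library's own implementation.

```python
from typing import Any, Optional

import torch
from torch import nn


class TransformerLayersBase(nn.Module):
    """Local stand-in mirroring the documented forward signature (assumption, not the rectools source).

    In real code: from rectools.models.nn.transformers.net_blocks import TransformerLayersBase
    """

    def forward(
        self,
        seqs: torch.Tensor,
        timeline_mask: torch.Tensor,
        attn_mask: Optional[torch.Tensor],
        key_padding_mask: Optional[torch.Tensor],
        **kwargs: Any,
    ) -> torch.Tensor:
        raise NotImplementedError()


class PreLNTransformerLayers(TransformerLayersBase):
    """Hypothetical subclass: n_blocks of pre-LayerNorm self-attention followed by a feed-forward net."""

    def __init__(self, n_blocks: int, n_factors: int, n_heads: int, dropout_rate: float = 0.0) -> None:
        super().__init__()
        self.attn_norms = nn.ModuleList([nn.LayerNorm(n_factors) for _ in range(n_blocks)])
        self.attns = nn.ModuleList(
            [nn.MultiheadAttention(n_factors, n_heads, dropout=dropout_rate, batch_first=True) for _ in range(n_blocks)]
        )
        self.ff_norms = nn.ModuleList([nn.LayerNorm(n_factors) for _ in range(n_blocks)])
        self.ffs = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(n_factors, n_factors), nn.GELU(), nn.Dropout(dropout_rate), nn.Linear(n_factors, n_factors)
                )
                for _ in range(n_blocks)
            ]
        )

    def forward(
        self,
        seqs: torch.Tensor,
        timeline_mask: torch.Tensor,
        attn_mask: Optional[torch.Tensor],
        key_padding_mask: Optional[torch.Tensor],
        **kwargs: Any,
    ) -> torch.Tensor:
        for attn_norm, attn, ff_norm, ff in zip(self.attn_norms, self.attns, self.ff_norms, self.ffs):
            q = attn_norm(seqs)
            # Both optional masks are handed unchanged to multi-head attention
            attn_out, _ = attn(q, q, q, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False)
            seqs = seqs + attn_out
            seqs = seqs + ff(ff_norm(seqs))
            # timeline_mask is 0.0 at padding positions, so this zeroes them after every block
            seqs = seqs * timeline_mask
        return seqs


torch.manual_seed(0)
layers = PreLNTransformerLayers(n_blocks=2, n_factors=8, n_heads=2)
seqs = torch.randn(2, 4, 8)  # assumed shape: (batch_size, session_len, n_factors)
# (batch_size, session_len, 1): first session is left-padded with two empty positions
timeline_mask = torch.tensor([[0, 0, 1, 1], [1, 1, 1, 1]], dtype=torch.float).unsqueeze(-1)
# True above the diagonal forbids attending to future positions
causal_mask = torch.triu(torch.ones(4, 4, dtype=torch.bool), diagonal=1)
out = layers(seqs, timeline_mask, attn_mask=causal_mask, key_padding_mask=None)
```

The output keeps the input shape, which matches the documented return value: user sequences passed through the transformer layers.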
- forward(seqs: Tensor, timeline_mask: Tensor, attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], **kwargs: Any) → Tensor
Forward pass through transformer blocks.
- Parameters
seqs (torch.Tensor) – User sequences of item embeddings.
timeline_mask (torch.Tensor) – Mask indicating padding elements.
attn_mask (torch.Tensor, optional) – Mask passed to multi-head attention as its attn_mask argument in the forward pass.
key_padding_mask (torch.Tensor, optional) – Mask passed to multi-head attention as its key_padding_mask argument in the forward pass.
kwargs (Any) –
- Returns
User sequences passed through transformer layers.
- Return type
torch.Tensor
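The three masks accepted by forward can be illustrated with a short sketch that builds them from padded item-id sessions and shows where each one is consumed. It assumes item id 0 marks padding, left-padded sessions, and tensors shaped (batch_size, session_len, n_factors); build_masks and attention_step are hypothetical helpers, and torch.nn.MultiheadAttention stands in for whatever attention a concrete subclass uses.

```python
from typing import Optional, Tuple

import torch
from torch import nn

PAD_ID = 0  # assumption: item id 0 is reserved for padding


def build_masks(
    sessions: torch.Tensor, causal: bool
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
    """Hypothetical helper: derive timeline_mask, attn_mask and key_padding_mask from padded item ids."""
    # (batch_size, session_len, 1): 1.0 for real items, 0.0 for padding
    timeline_mask = (sessions != PAD_ID).unsqueeze(-1).float()
    if causal:
        session_len = sessions.shape[1]
        # (session_len, session_len): True above the diagonal blocks attention to future positions
        attn_mask = torch.triu(torch.ones(session_len, session_len, dtype=torch.bool), diagonal=1)
        return timeline_mask, attn_mask, None
    # (batch_size, session_len): True marks padding keys that every query ignores
    key_padding_mask = sessions == PAD_ID
    return timeline_mask, None, key_padding_mask


def attention_step(
    mha: nn.MultiheadAttention,
    seqs: torch.Tensor,
    timeline_mask: torch.Tensor,
    attn_mask: Optional[torch.Tensor],
    key_padding_mask: Optional[torch.Tensor],
) -> torch.Tensor:
    """One residual self-attention step showing where each mask is used."""
    out, _ = mha(seqs, seqs, seqs, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False)
    # Zero the padding positions after the residual connection
    return (seqs + out) * timeline_mask


torch.manual_seed(0)
sessions = torch.tensor([[0, 0, 3, 5], [1, 2, 4, 6]])  # left-padded item ids
seqs = torch.randn(2, 4, 8)  # (batch_size, session_len, n_factors)
mha = nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True)

out_causal = attention_step(mha, seqs, *build_masks(sessions, causal=True))
out_bidir = attention_step(mha, seqs, *build_masks(sessions, causal=False))
```

In this sketch a causal attn_mask gives left-to-right attention, while key_padding_mask alone gives bidirectional attention over real items. Combining a causal attn_mask with a key_padding_mask on left-padded sessions can leave a padding query with every key masked, which may produce NaN from the softmax; the sketch therefore uses one mask per mode and relies on timeline_mask to zero the padding positions.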