STULayers

class rectools.models.nn.transformers.hstu.STULayers(n_blocks: int, n_factors: int, n_heads: int, linear_hidden_dim: int, attention_dim: int, session_max_len: int, relative_time_attention: bool, relative_pos_attention: bool, attn_dropout_rate: float = 0.0, dropout_rate: float = 0.2, epsilon: float = 1e-06, **kwargs: Any)

Bases: TransformerLayersBase

STULayers transformer blocks.

Parameters
  • n_blocks (int) – Number of stacked STU blocks.

  • n_factors (int) – Latent embeddings size.

  • n_heads (int) – Number of attention heads.

  • linear_hidden_dim (int) – Size of the U and V projections.

  • attention_dim (int) – Size of the Q and K projections.

  • session_max_len (int) – Maximum length to which user sequences are padded or truncated.

  • relative_time_attention (bool) – Whether to use relative time attention.

  • relative_pos_attention (bool) – Whether to use relative positional attention.

  • attn_dropout_rate (float, default 0.0) – Probability of an attention unit to be zeroed.

  • dropout_rate (float, default 0.2) – Probability of a hidden unit to be zeroed.

  • epsilon (float, default 1e-6) – A value passed to LayerNorm for numerical stability.

  • kwargs (Any) –
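
The session_max_len parameter above implies that every user sequence is brought to one fixed length before it reaches the STU blocks. The following is a minimal sketch of that step in plain Python, not the library's own preprocessing. It assumes left padding, a padding id of 0, and that the most recent items are kept on truncation; none of these is stated on this page, and fit_to_session_len is a hypothetical helper name.

```python
from typing import List

PAD_ID = 0  # assumption: 0 marks padding; this page does not specify the padding id


def fit_to_session_len(item_ids: List[int], session_max_len: int) -> List[int]:
    """Return a list of exactly session_max_len ids (hypothetical helper).

    Assumptions: the most recent items are kept when truncating,
    and padding is added on the left.
    """
    if session_max_len <= 0:
        raise ValueError("session_max_len must be positive")
    # Keep at most the last session_max_len interactions.
    recent = item_ids[-session_max_len:]
    # Fill the remaining positions on the left with the padding id.
    return [PAD_ID] * (session_max_len - len(recent)) + recent


short_session = fit_to_session_len([5, 6, 7], session_max_len=5)
long_session = fit_to_session_len([1, 2, 3, 4, 5, 6], session_max_len=4)
```

Under these assumptions a short history gains leading padding, while a long history loses its oldest items.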

Methods

forward(seqs, timeline_mask, attn_mask, ...)

Forward pass through STU blocks.

forward(seqs: Tensor, timeline_mask: Tensor, attn_mask: Tensor, key_padding_mask: Optional[Tensor], batch: Dict[str, Tensor], **kwargs: Any) → Tensor

Forward pass through STU blocks.

Parameters
  • seqs (torch.Tensor) – User sequences of item embeddings.

  • timeline_mask (torch.Tensor) – Mask indicating padding elements.

  • attn_mask (torch.Tensor, optional) – Mask to use in forward pass of multi-head attention as attn_mask.

  • key_padding_mask (torch.Tensor, optional) – Mask to use in forward pass of multi-head attention as key_padding_mask.

  • batch (Dict[str, torch.Tensor]) – May contain payload information, in particular sequence timestamps.

  • kwargs (Any) –

Returns

User sequences passed through transformer layers.

Return type

torch.Tensor
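
To make the two masks taken by forward more concrete, here is a rough sketch of how such masks could be derived for one padded sequence, using plain Python lists instead of tensors. It is an illustration under assumptions, not the library's code: the padding id of 0, the 1.0/0.0 encoding of timeline_mask, and the helper names are all hypothetical. The causal mask follows the boolean convention of torch.nn.MultiheadAttention, where True marks a position that may not be attended to; whether STULayers expects exactly this convention is not confirmed by this page.

```python
from typing import List


def build_timeline_mask(padded_ids: List[int], pad_id: int = 0) -> List[float]:
    """1.0 for real items, 0.0 for padding (assumed encoding, hypothetical helper)."""
    return [0.0 if item_id == pad_id else 1.0 for item_id in padded_ids]


def build_causal_mask(session_max_len: int) -> List[List[bool]]:
    """Square mask where entry [i][j] is True when query i must not see key j.

    Position i may attend only to positions j <= i, so every j > i is blocked.
    """
    return [
        [j > i for j in range(session_max_len)]
        for i in range(session_max_len)
    ]


padded = [0, 0, 5, 6]  # hypothetical left-padded sequence of item ids
timeline_mask = build_timeline_mask(padded)
attn_mask = build_causal_mask(len(padded))
```

In practice these would be tensors with a leading batch dimension; the lists here only show which positions are masked.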