RelativeAttentionBias

class rectools.models.nn.transformers.hstu.RelativeAttentionBias(session_max_len: int, relative_time_attention: bool, relative_pos_attention: bool, num_buckets: int = 128)[source]

Bases: Module

Computes relative time and positional attention biases for STU.

Parameters
  • session_max_len (int) – Maximum sequence length for user interactions (padded/truncated)

  • relative_time_attention (bool) – Whether to compute relative time attention from timestamps

  • relative_pos_attention (bool) – Whether to compute relative positional attention

  • num_buckets (int) – Number of buckets for quantizing timestamp differences
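To illustrate what quantizing timestamp differences into `num_buckets` buckets can look like, here is a minimal NumPy sketch. It assumes log-scale bucketing. The function name `bucketize_time_diffs` and the `scale` constant are hypothetical and are not taken from this class, whose exact bucketing rule is not stated in this reference.

```python
import numpy as np

def bucketize_time_diffs(diffs, num_buckets=128, scale=0.301):
    """Map timestamp differences to integer bucket ids in [0, num_buckets - 1].

    Assumption: log-scale bucketing, so short gaps get fine-grained buckets
    and long gaps share coarse ones. `scale` is an illustrative constant.
    """
    # Clamp magnitudes to at least 1 so the logarithm is defined and non-negative.
    magnitude = np.maximum(np.abs(diffs), 1)
    raw = np.log(magnitude) / scale
    # Truncate to integers and keep ids inside the bucket table.
    return np.clip(raw, 0, num_buckets - 1).astype(np.int64)

# Gaps of 0 s, 1 s, 1 min, 1 h and 1 day (assuming timestamps in seconds).
diffs = np.array([0, 1, 60, 3600, 86400])
buckets = bucketize_time_diffs(diffs)
```

Under this scheme larger gaps never receive a smaller bucket id, and every id can index a table with `num_buckets` entries.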

Methods

forward(batch)

Compute relative attention biases.

forward_pos_attention()

Compute and return the relative positional attention bias matrix.

forward_time_attention(all_timestamps)

Compute the relative time attention bias from user interaction timestamps, including the target item timestamp.

forward(batch: Dict[str, Tensor]) → Tensor[source]

Compute relative attention biases.

Parameters

batch (Dict[str, torch.Tensor]) – Batch data. May include payload information, in particular the sequence timestamps.

Returns

Sum of the relative positional and/or time attention biases, depending on which are enabled

Return type

torch.Tensor (batch_size, session_max_len, session_max_len)
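The return shape follows from broadcasting a batch-independent positional bias against a per-user time bias. The NumPy sketch below shows only that shape arithmetic. The array contents are random placeholders, and how the class combines the two terms internally is an assumption.

```python
import numpy as np

batch_size, session_max_len = 4, 5
rng = np.random.default_rng(0)

# Placeholder biases with the shapes documented for the two helper methods.
pos_bias = rng.normal(size=(1, session_max_len, session_max_len))
time_bias = rng.normal(size=(batch_size, session_max_len, session_max_len))

# The leading dimension of size 1 broadcasts across the batch.
total_bias = pos_bias + time_bias
```

Here `total_bias` has shape `(batch_size, session_max_len, session_max_len)`, matching the documented return type.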

forward_pos_attention() → Tensor[source]

Compute and return the relative positional attention bias matrix.

Return type

torch.Tensor (1, session_max_len, session_max_len)
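One common way to build a positional bias of this shape is to keep one weight per relative offset and gather the weights into a matrix. The NumPy sketch below assumes that scheme. The table `pos_weights` is a hypothetical stand-in for a learnable parameter and is not an attribute documented for this class.

```python
import numpy as np

session_max_len = 4
rng = np.random.default_rng(0)

# Hypothetical table: one weight per relative offset in [-(N - 1), N - 1].
pos_weights = rng.normal(size=2 * session_max_len - 1)

positions = np.arange(session_max_len)
# offsets[i, j] = j - i, shifted so the smallest offset maps to index 0.
offsets = positions[None, :] - positions[:, None] + (session_max_len - 1)

# Gather the weights and add a leading axis of size 1 for batch broadcasting.
pos_bias = pos_weights[offsets][None, :, :]
```

Because the bias depends only on `j - i`, every diagonal of the matrix holds a single repeated value, and the same matrix can be shared by all users in the batch.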

forward_time_attention(all_timestamps: Tensor) → Tensor[source]

Compute the relative time attention bias.

Parameters

all_timestamps (torch.Tensor (batch_size, session_max_len+1)) – User interaction timestamps including the target item timestamp

Returns

Relative time attention bias

Return type

torch.Tensor (batch_size, session_max_len, session_max_len)
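The shapes above suggest how `session_max_len + 1` timestamps can yield a square matrix: each row uses the timestamp of the following item, and each column uses the timestamp of an earlier position. The NumPy sketch below assumes that pairing together with log-scale bucketing. The function `relative_time_bias`, the table `time_weights`, and the constant 0.301 are hypothetical and are not taken from this class.

```python
import numpy as np

def relative_time_bias(all_timestamps, time_weights):
    """Sketch: (B, N + 1) timestamps -> (B, N, N) bias via a bucket lookup."""
    num_buckets = time_weights.shape[0]
    # Assumption: row i uses the timestamp of the item that follows position i.
    next_ts = all_timestamps[:, 1:, None]    # (B, N, 1)
    prev_ts = all_timestamps[:, None, :-1]   # (B, 1, N)
    diffs = next_ts - prev_ts                # (B, N, N)
    # Log-scale bucketing; magnitudes below 1 fall into bucket 0.
    magnitude = np.maximum(np.abs(diffs), 1)
    buckets = np.clip(np.log(magnitude) / 0.301, 0, num_buckets - 1).astype(np.int64)
    # Look up one weight per bucket id.
    return time_weights[buckets]

rng = np.random.default_rng(0)
time_weights = rng.normal(size=128)          # stand-in for a learnable table
all_timestamps = np.array([
    [0, 10, 70, 3670, 90070],
    [5, 5, 5, 65, 125],
])                                           # (batch_size=2, session_max_len + 1 = 5)
time_bias = relative_time_bias(all_timestamps, time_weights)
```

Unlike the positional bias, the result differs per user because it depends on each user's own timestamps.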