RelativeAttentionBias
- class rectools.models.nn.transformers.hstu.RelativeAttentionBias(session_max_len: int, relative_time_attention: bool, relative_pos_attention: bool, num_buckets: int = 128)
Bases: Module
Computes relative time and positional attention biases for STU (Sequential Transduction Unit) layers.
- Parameters
session_max_len (int) – Maximum length of a user's interaction sequence; longer sequences are truncated and shorter ones padded to this length
relative_time_attention (bool) – Whether to compute relative time attention from timestamps
relative_pos_attention (bool) – Whether to compute relative positional attention
num_buckets (int) – Number of buckets for quantizing timestamp differences
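To illustrate what quantizing timestamp differences into `num_buckets` buckets can look like, here is a minimal sketch. It assumes a log2-scaled rule (a common choice for time gaps, not necessarily the one rectools uses): gaps of 0 or 1 share bucket 0, every doubling of the gap moves one bucket up, and all gaps beyond the last bucket are merged into it. The function name `bucketize_time_diff` is hypothetical.

```python
import math


def bucketize_time_diff(diff: float, num_buckets: int = 128) -> int:
    """Map a timestamp difference to a bucket index in [0, num_buckets - 1].

    Hypothetical log2 rule (assumption, not necessarily what rectools uses):
    gaps of 0 or 1 share bucket 0, each doubling of the gap moves one bucket
    up, and every gap past the last bucket is merged into it.
    """
    magnitude = max(abs(diff), 1)        # clamp so the log is defined and non-negative
    bucket = math.floor(math.log2(magnitude))
    return min(bucket, num_buckets - 1)  # saturate at the last bucket


# Example gaps, if timestamps were in seconds: none, 1 s, 1 min, 1 h, 1 day
print([bucketize_time_diff(d) for d in [0, 1, 60, 3600, 86400]])  # [0, 0, 5, 11, 16]
```

Log scaling keeps short gaps (seconds to minutes) well separated while compressing very long gaps, which is why a modest default such as 128 buckets can cover a wide range of time differences.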
Methods
forward(batch) – Compute relative attention biases.
forward_pos_attention() – Compute and return the relative positional attention bias matrix.
forward_time_attention(all_timestamps) – Compute the relative time attention bias from user interaction timestamps, including the target item timestamp.
- forward(batch: Dict[str, Tensor]) → Tensor
Compute relative attention biases.
- Parameters
batch (Dict[str, torch.Tensor]) – Batch of model inputs; may contain payload information, in particular sequence timestamps.
- Returns
Sum of the enabled relative positional and time attention biases
- Return type
torch.Tensor (batch_size, session_max_len, session_max_len)
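The return value is a sum of the enabled bias terms: the positional bias has shape (1, session_max_len, session_max_len) and is broadcast over the batch, while the time bias has one matrix per sequence. The plain-Python sketch below shows that combination on nested lists. The helper name `combine_biases` is hypothetical, and treating a disabled term as contributing zero is an assumption, not documented behaviour.

```python
from typing import List, Optional

Matrix = List[List[float]]


def combine_biases(
    pos_bias: Optional[Matrix],         # (L, L): shared by all sequences, like a (1, L, L) tensor
    time_bias: Optional[List[Matrix]],  # (B, L, L): one matrix per user sequence
    batch_size: int,
    session_max_len: int,
) -> List[Matrix]:
    """Hypothetical helper: sum the enabled biases into a (B, L, L) nested list.

    A term passed as None (disabled) contributes zero; this is an assumption.
    """
    n = session_max_len
    out = [[[0.0] * n for _ in range(n)] for _ in range(batch_size)]
    for b in range(batch_size):
        for i in range(n):
            for j in range(n):
                if pos_bias is not None:
                    out[b][i][j] += pos_bias[i][j]  # same values for every b (broadcast)
                if time_bias is not None:
                    out[b][i][j] += time_bias[b][i][j]
    return out
```

With tensors, the same result would come from adding a (1, L, L) tensor to a (B, L, L) tensor and letting broadcasting expand the first operand.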
- forward_pos_attention() → Tensor
Compute and return the relative positional attention bias matrix.
- Return type
torch.Tensor (1, session_max_len, session_max_len)
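Since the positional bias does not depend on the batch, it is a single matrix. One common parameterization, assumed here rather than taken from the rectools source, keeps one learnable scalar per relative offset j - i, so every diagonal of the matrix holds the same value (a Toeplitz matrix). The sketch below builds that (L, L) matrix from a table of 2L - 1 weights; the function name `relative_position_bias` is hypothetical, and the documented (1, L, L) shape would add a leading batch axis of size 1.

```python
from typing import List


def relative_position_bias(weights: List[float], session_max_len: int) -> List[List[float]]:
    """Hypothetical sketch: build an (L, L) bias whose entry (i, j) depends only on j - i.

    `weights` holds one (normally learnable) scalar per offset in
    [-(L - 1), L - 1], i.e. 2L - 1 values; offset 0 sits at index L - 1.
    """
    n = session_max_len
    if len(weights) != 2 * n - 1:
        raise ValueError(f"expected {2 * n - 1} weights, got {len(weights)}")
    # Shift the offset j - i by n - 1 so it becomes a valid list index.
    return [[weights[j - i + n - 1] for j in range(n)] for i in range(n)]


print(relative_position_bias([-2.0, -1.0, 0.0, 1.0, 2.0], 3))
# [[0.0, 1.0, 2.0], [-1.0, 0.0, 1.0], [-2.0, -1.0, 0.0]]
```

Tying parameters to offsets keeps the parameter count linear in session_max_len instead of quadratic.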
- forward_time_attention(all_timestamps: Tensor) → Tensor
Compute the relative time attention bias matrix from user interaction timestamps.
- Parameters
all_timestamps (torch.Tensor (batch_size, session_max_len+1)) – User interaction timestamps including the target item timestamp
- Returns
Relative time attention bias
- Return type
torch.Tensor (batch_size, session_max_len, session_max_len)
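Because `all_timestamps` carries session_max_len + 1 values per user (the history plus the target item), an L x L matrix of pairwise gaps can be formed from two shifted views of the sequence. The sketch below assumes that row i is paired with the timestamp following position i (the last row therefore uses the target item timestamp) and column j with position j's own timestamp, that gaps are quantized with a hypothetical log2 rule, and that each bucket maps to one learnable scalar. None of these details are stated in this reference, and the names `relative_time_bias` and `_bucket` are hypothetical.

```python
import math
from typing import List


def _bucket(diff: int, num_buckets: int) -> int:
    # Hypothetical log2 bucketing (assumption), saturating at the last bucket.
    return min(math.floor(math.log2(max(abs(diff), 1))), num_buckets - 1)


def relative_time_bias(
    all_timestamps: List[List[int]],  # (B, L + 1): history plus the target item timestamp
    bucket_weights: List[float],      # one (normally learnable) scalar per bucket
) -> List[List[List[float]]]:
    """Hypothetical sketch: return a (B, L, L) time bias built from pairwise timestamp gaps."""
    num_buckets = len(bucket_weights)
    out = []
    for ts in all_timestamps:
        n = len(ts) - 1  # session_max_len
        # Row i uses ts[i + 1] (the timestamp after position i), column j uses ts[j].
        out.append(
            [[bucket_weights[_bucket(ts[i + 1] - ts[j], num_buckets)] for j in range(n)] for i in range(n)]
        )
    return out


print(relative_time_bias([[0, 1, 3, 7]], [0.0, 1.0, 2.0, 3.0]))
# [[[0.0, 0.0, 1.0], [1.0, 1.0, 0.0], [2.0, 2.0, 2.0]]]
```

Bucketing before the weight lookup means that arbitrary real-valued gaps share a fixed-size parameter table, so the bias generalizes to gaps never seen exactly during training.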