SimilarityModuleBase
- class rectools.models.nn.transformers.similarity.SimilarityModuleBase(*args: Any, **kwargs: Any)
Bases: Module
Base class for similarity module.
Methods
forward(session_embs, item_embs[, ...]) – Forward pass to get logits.
item_tower_forward(item_embs) – Forward pass for item tower.
session_tower_forward(session_embs) – Forward pass for session tower.
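The method list suggests a two-tower layout: each side is transformed by its own tower, and forward combines the results into logits. The sketch below illustrates that composition under stated assumptions. It is not the rectools implementation: plain Python lists stand in for torch Tensors, the class names ToySimilarityBase and ToyCosineSimilarity are hypothetical, dot-product scoring is an assumed choice, and candidate_item_ids is omitted for brevity.

```python
import math
from typing import List

# Hypothetical stand-in for a 2-D torch Tensor: one embedding per row.
Matrix = List[List[float]]


class ToySimilarityBase:
    """Hypothetical illustration of the tower/forward split; not rectools code."""

    def session_tower_forward(self, session_embs: Matrix) -> Matrix:
        # Default session tower: pass embeddings through unchanged.
        return session_embs

    def item_tower_forward(self, item_embs: Matrix) -> Matrix:
        # Default item tower: pass embeddings through unchanged.
        return item_embs

    def forward(self, session_embs: Matrix, item_embs: Matrix) -> Matrix:
        # Run each side through its tower, then score every (session, item)
        # pair. Assumption: the logit is the dot product of the tower outputs.
        sessions = self.session_tower_forward(session_embs)
        items = self.item_tower_forward(item_embs)
        return [[sum(s * v for s, v in zip(sess, item)) for item in items]
                for sess in sessions]


class ToyCosineSimilarity(ToySimilarityBase):
    """Hypothetical subclass: L2-normalizing both towers turns dot products into cosines."""

    @staticmethod
    def _l2_normalize(rows: Matrix) -> Matrix:
        out = []
        for row in rows:
            norm = math.sqrt(sum(x * x for x in row)) or 1.0  # guard against zero rows
            out.append([x / norm for x in row])
        return out

    def session_tower_forward(self, session_embs: Matrix) -> Matrix:
        return self._l2_normalize(session_embs)

    def item_tower_forward(self, item_embs: Matrix) -> Matrix:
        return self._l2_normalize(item_embs)
```

With sessions [[1.0, 0.0], [0.0, 2.0]] and items [[3.0, 4.0], [0.0, 1.0]], the base class returns [[3.0, 0.0], [8.0, 2.0]], while the cosine subclass returns approximately [[0.6, 0.0], [0.8, 1.0]]. Only the towers change between the two classes; forward stays the same, which is the kind of reuse a base class with separate tower hooks allows.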
Attributes
- Parameters
args (Any) – Positional arguments.
kwargs (Any) – Keyword arguments.
- forward(session_embs: Tensor, item_embs: Tensor, candidate_item_ids: Optional[Tensor] = None) → Tensor
Forward pass to get logits.
- Parameters
session_embs (Tensor) –
item_embs (Tensor) –
candidate_item_ids (Optional[Tensor]) –
- Return type
Tensor
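The signature makes candidate_item_ids optional, which suggests forward can either score all items or only a chosen subset. The function below sketches that reading. It is hypothetical, not the rectools implementation. It assumes that candidate_item_ids indexes rows of item_embs and that logits are dot products. toy_forward is an illustrative name, and plain Python lists stand in for torch Tensors.

```python
from typing import List, Optional

# Hypothetical stand-in for a 2-D torch Tensor: one embedding per row.
Matrix = List[List[float]]


def toy_forward(session_embs: Matrix, item_embs: Matrix,
                candidate_item_ids: Optional[List[int]] = None) -> Matrix:
    """Hypothetical dot-product forward; NOT the rectools implementation."""
    if candidate_item_ids is None:
        # No candidates given: score every session against every item.
        items = item_embs
    else:
        # Assumption: ids index rows of item_embs; only those items are
        # scored, in the order the ids are given.
        items = [item_embs[i] for i in candidate_item_ids]
    # One row of logits per session, one column per (candidate) item.
    return [[sum(s * v for s, v in zip(sess, item)) for item in items]
            for sess in session_embs]
```

For a single session [[1.0, 2.0]] and items [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], calling without candidates gives [[1.0, 2.0, 3.0]], and passing candidate_item_ids=[2, 0] gives [[3.0, 1.0]]. Restricting scoring to candidates keeps the logit matrix small when only a few items, such as sampled negatives, need scores.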