SimilarityModuleBase

class rectools.models.nn.transformers.similarity.SimilarityModuleBase(*args: Any, **kwargs: Any)

Bases: torch.nn.Module

Base class for similarity modules, which compute logits between session embeddings and item embeddings.

Methods

forward(session_embs, item_embs[, ...])

Forward pass to get logits.

item_tower_forward(item_embs)

Forward pass for the item tower.

session_tower_forward(session_embs)

Forward pass for the session tower.

Parameters
  • args (Any)

  • kwargs (Any)
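The base class defines the interface that concrete similarity modules implement: a session tower, an item tower, and a forward pass that turns their outputs into logits. A minimal sketch of a subclass follows. It assumes dot-product scoring, identity towers, session embeddings shaped (batch, seq_len, n_factors), item embeddings shaped (n_items, n_factors), and candidate_item_ids shaped (batch, seq_len, n_candidates); these shapes are assumptions, not documented behavior. The class name DotSimilaritySketch is hypothetical and not part of RecTools.

```python
from typing import Optional

import torch

try:
    # Use the real base class when RecTools is installed.
    from rectools.models.nn.transformers.similarity import SimilarityModuleBase
except ImportError:
    # Assumption: without RecTools, a plain torch Module stands in for the base class.
    SimilarityModuleBase = torch.nn.Module


class DotSimilaritySketch(SimilarityModuleBase):
    """Hypothetical subclass: identity towers and dot-product logits."""

    def item_tower_forward(self, item_embs: torch.Tensor) -> torch.Tensor:
        return item_embs  # identity: item embeddings are used as-is

    def session_tower_forward(self, session_embs: torch.Tensor) -> torch.Tensor:
        return session_embs  # identity: session embeddings are used as-is

    def forward(
        self,
        session_embs: torch.Tensor,  # assumed shape (batch, seq_len, n_factors)
        item_embs: torch.Tensor,  # assumed shape (n_items, n_factors)
        candidate_item_ids: Optional[torch.Tensor] = None,  # assumed (batch, seq_len, n_candidates)
    ) -> torch.Tensor:
        session_embs = self.session_tower_forward(session_embs)
        item_embs = self.item_tower_forward(item_embs)
        # Dot product of every session position with every item: (batch, seq_len, n_items).
        logits = session_embs @ item_embs.T
        if candidate_item_ids is not None:
            # Keep only the candidates' scores: (batch, seq_len, n_candidates).
            logits = logits.gather(-1, candidate_item_ids)
        return logits
```

For example, DotSimilaritySketch()(torch.randn(2, 3, 4), torch.randn(5, 4)) returns a tensor of shape (2, 3, 5). Gathering from the full-catalog logits keeps the sketch short, but it still scores every item before discarding most of the scores.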

forward(session_embs: Tensor, item_embs: Tensor, candidate_item_ids: Optional[Tensor] = None) → Tensor

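Forward pass to get logits: scores between session embeddings and item embeddings, optionally computed only for the items listed in candidate_item_ids.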

Parameters
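  • session_embs (Tensor) – Session embeddings.

  • item_embs (Tensor) – Item embeddings.

  • candidate_item_ids (Optional[Tensor]) – Ids of candidate items to compute logits for. Defaults to None.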

Return type

Tensor
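When only a few candidates per position matter (for example, a positive item plus sampled negatives during training; this use of candidate_item_ids is an assumption), it can be cheaper to look up just the candidates' embeddings than to score the whole catalog. The sketch below assumes dot-product similarity, session embeddings shaped (batch, seq_len, n_factors), item embeddings shaped (n_items, n_factors), and int64 candidate_item_ids shaped (batch, seq_len, n_candidates). The helper name candidate_logits is hypothetical.

```python
import torch


def candidate_logits(
    session_embs: torch.Tensor,  # assumed shape (batch, seq_len, n_factors)
    item_embs: torch.Tensor,  # assumed shape (n_items, n_factors)
    candidate_item_ids: torch.Tensor,  # assumed int64, shape (batch, seq_len, n_candidates)
) -> torch.Tensor:
    """Hypothetical helper: dot-product logits for candidate items only."""
    # Look up only the candidates' embeddings: (batch, seq_len, n_candidates, n_factors).
    cand_embs = item_embs[candidate_item_ids]
    # Batched matrix-vector product, one per (batch, position) pair.
    scores = cand_embs @ session_embs.unsqueeze(-1)  # (batch, seq_len, n_candidates, 1)
    return scores.squeeze(-1)  # (batch, seq_len, n_candidates)
```

Up to floating-point rounding, the result matches gathering the candidate columns from the full product session_embs @ item_embs.T, while only n_candidates scores are materialized per position.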

item_tower_forward(item_embs: Tensor) → Tensor

Forward pass for the item tower.

Parameters

item_embs (Tensor) – Item embeddings.

Return type

Tensor
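An item tower can transform item embeddings before they are scored. A minimal sketch, assuming the tower simply L2-normalizes each item embedding along its last dimension; if the session tower normalizes as well, dot products between the two outputs become cosine similarities. The class name is hypothetical, and only the item tower is shown.

```python
import torch
import torch.nn.functional as F

try:
    # Use the real base class when RecTools is installed.
    from rectools.models.nn.transformers.similarity import SimilarityModuleBase
except ImportError:
    # Assumption: without RecTools, a plain torch Module stands in for the base class.
    SimilarityModuleBase = torch.nn.Module


class NormalizedItemTowerSketch(SimilarityModuleBase):
    """Hypothetical subclass; only the item tower is shown."""

    def item_tower_forward(self, item_embs: torch.Tensor) -> torch.Tensor:
        # Scale each item embedding (last dimension) to unit L2 norm.
        return F.normalize(item_embs, p=2.0, dim=-1)
```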

session_tower_forward(session_embs: Tensor) → Tensor

Forward pass for the session tower.

Parameters

session_embs (Tensor) – Session embeddings.

Return type

Tensor
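A session tower might instead apply a learnable transformation. The sketch below assumes a square torch.nn.Linear projection, so session embeddings keep their size (n_factors) and remain compatible with dot-product scoring against item embeddings. The class name, the n_factors argument, and the session_proj attribute are hypothetical; only the session tower is shown.

```python
import torch

try:
    # Use the real base class when RecTools is installed.
    from rectools.models.nn.transformers.similarity import SimilarityModuleBase
except ImportError:
    # Assumption: without RecTools, a plain torch Module stands in for the base class.
    SimilarityModuleBase = torch.nn.Module


class ProjectedSessionTowerSketch(SimilarityModuleBase):
    """Hypothetical subclass; only the session tower is shown."""

    def __init__(self, n_factors: int) -> None:
        super().__init__()
        # Square projection: the output keeps n_factors, so it still matches item embeddings.
        self.session_proj = torch.nn.Linear(n_factors, n_factors)

    def session_tower_forward(self, session_embs: torch.Tensor) -> torch.Tensor:
        # nn.Linear acts on the last dimension, so leading dimensions are preserved.
        return self.session_proj(session_embs)
```

Because the projection is registered as a submodule, its weights are trained together with the rest of the model.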