DistanceSimilarityModule

class rectools.models.nn.transformers.similarity.DistanceSimilarityModule(distance: str = 'dot', **kwargs: Any)

Bases: SimilarityModuleBase

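Distance similarity module: scores session embeddings against item embeddings with a configurable distance (dot product by default) and returns the scores as logits.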

Methods

forward(session_embs, item_embs[, ...])

Forward pass to get logits.

Attributes

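dist_available

Names of the distances accepted by the distance parameter.

epsilon_cosine_dist

Small constant used as a lower bound on embedding norms when normalizing for cosine distance, so that zero-norm embeddings do not cause division by zero.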
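The role of a constant like epsilon_cosine_dist can be illustrated with a minimal NumPy sketch. NumPy arrays stand in for torch tensors, and the value 1e-8 is an assumption chosen for illustration, not necessarily the module's constant. Dividing by a raw L2 norm turns an all-zero embedding into NaN, while clamping the norm from below keeps it at zero.

```python
import numpy as np

# Assumed value for illustration only; the module's actual constant may differ.
EPS = 1e-8

embs = np.array([[3.0, 4.0], [0.0, 0.0]])  # second row is an all-zero embedding
norms = np.linalg.norm(embs, axis=-1, keepdims=True)  # [[5.0], [0.0]]

with np.errstate(invalid="ignore"):
    naive = embs / norms  # 0 / 0 gives NaN for the zero row

# Clamping the denominator at EPS keeps the zero row at zero instead of NaN.
guarded = embs / np.maximum(norms, EPS)
```

After this runs, the first row of both arrays is [0.6, 0.8]; only the guarded version keeps the second row finite.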

Parameters
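  • distance (str) – Distance used to score sessions against items. Must be one of the values in dist_available; defaults to 'dot' (dot product).

  • kwargs (Any) – Additional keyword arguments.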
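To show how the choice of distance changes the scores, here is a hedged NumPy sketch. similarity_logits and ALLOWED_DISTANCES are hypothetical names, not part of RecTools, and treating cosine scoring as "L2-normalize both sides, then take the dot product" is an assumption rather than something this page specifies.

```python
import numpy as np

# Hypothetical stand-in for the module's list of accepted distances.
ALLOWED_DISTANCES = ("dot", "cosine")


def _l2_normalize(embs: np.ndarray, eps: float) -> np.ndarray:
    """Scale rows to unit L2 norm, clamping tiny norms to eps."""
    norms = np.linalg.norm(embs, axis=-1, keepdims=True)
    return embs / np.maximum(norms, eps)


def similarity_logits(
    session_embs: np.ndarray, item_embs: np.ndarray, distance: str = "dot", eps: float = 1e-8
) -> np.ndarray:
    """Hypothetical helper: score every session row against every item row."""
    if distance not in ALLOWED_DISTANCES:
        raise ValueError(f"distance must be one of {ALLOWED_DISTANCES}, got {distance!r}")
    if distance == "cosine":
        # Assumption: cosine scoring = dot product of unit-normalized vectors.
        session_embs = _l2_normalize(session_embs, eps)
        item_embs = _l2_normalize(item_embs, eps)
    # Shape (n_sessions, n_items).
    return session_embs @ item_embs.T


session = np.array([[1.0, 0.0]])
items = np.array([[1.0, 0.0], [5.0, 0.0], [0.0, 1.0]])
dot_logits = similarity_logits(session, items, "dot")  # [[1., 5., 0.]]
cos_logits = similarity_logits(session, items, "cosine")  # [[1., 1., 0.]]
```

The example shows the practical difference: the dot product rewards the item with the larger norm, whereas cosine similarity gives the same score to items pointing in the same direction regardless of their magnitude.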

forward(session_embs: Tensor, item_embs: Tensor, candidate_item_ids: Optional[Tensor] = None) → Tensor

Forward pass to get logits: scores the session embeddings against the item embeddings using the configured distance.

Parameters
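  • session_embs (Tensor) – Session embeddings to score.

  • item_embs (Tensor) – Item embeddings to score the sessions against.

  • candidate_item_ids (Optional[Tensor]) – Indices of the candidate items to score. If None (the default), logits are computed against all items in item_embs.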

Return type

Tensor
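The two modes of forward (with and without candidate_item_ids) can be sketched in NumPy as follows. forward_sketch is a hypothetical function, only dot-product scoring is shown, and the shapes in the comments (batch, session length, number of candidates, number of factors) are assumptions for illustration rather than shapes documented on this page.

```python
from typing import Optional

import numpy as np


def forward_sketch(
    session_embs: np.ndarray,
    item_embs: np.ndarray,
    candidate_item_ids: Optional[np.ndarray] = None,
) -> np.ndarray:
    """Hypothetical dot-product version of the two scoring modes."""
    if candidate_item_ids is None:
        # Assumed shapes: (batch, seq_len, n_factors) @ (n_factors, n_items)
        # -> (batch, seq_len, n_items): every position scored against the full catalog.
        return session_embs @ item_embs.T
    # Assumed shape of candidate_item_ids: (batch, seq_len, n_candidates), integer row indices.
    cand_embs = item_embs[candidate_item_ids]  # (batch, seq_len, n_candidates, n_factors)
    # Dot product of each position's session embedding with its own candidates only.
    return np.einsum("bsf,bscf->bsc", session_embs, cand_embs)


rng = np.random.default_rng(0)
session_embs = rng.normal(size=(2, 3, 4))
item_embs = rng.normal(size=(5, 4))
candidate_item_ids = rng.integers(0, 5, size=(2, 3, 2))

full_logits = forward_sketch(session_embs, item_embs)  # shape (2, 3, 5)
cand_logits = forward_sketch(session_embs, item_embs, candidate_item_ids)  # shape (2, 3, 2)
```

Both modes agree on the items they share: each candidate logit equals the corresponding entry of the full-catalog logits. Scoring only gathered candidates avoids computing scores for the whole catalog when just a few items per position are needed.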