TransformerLightningModule

class rectools.models.nn.transformers.lightning.TransformerLightningModule(torch_model: TransformerBackboneBase, model_config: Dict[str, Any], dataset_schema: Dict[str, Any], item_external_ids: Union[Sequence[Hashable], ndarray], item_extra_tokens: Sequence[Hashable], data_preparator: TransformerDataPreparatorBase, lr: float, gbce_t: float, loss: str, verbose: int = 0, train_loss_name: str = 'train_loss', val_loss_name: str = 'val_loss', adam_betas: Tuple[float, float] = (0.9, 0.98), logits_t: float = 1, **kwargs: Any)

Bases: TransformerLightningModuleBase

Lightning module to train transformer models.

Parameters
  • torch_model (TransformerBackboneBase) – Torch model to make recommendations.

  • model_config (Dict[str, Any]) – Model config.

  • dataset_schema (DatasetSchemaDict) – Dataset schema.

  • item_external_ids (ExternalIds) – External item ids from the train dataset.

  • item_extra_tokens (Sequence[Hashable]) – Elements used for sequence padding.

  • lr (float) – Learning rate.

  • gbce_t (float) – Calibration parameter for gBCE loss.

  • loss (str, default "softmax") – Loss function.

  • adam_betas (Tuple[float, float], default (0.9, 0.98)) – Adam optimizer coefficients used for computing running averages of the gradient and its square.

  • data_preparator (TransformerDataPreparatorBase) – Data preparator.

  • verbose (int, default 0) – Verbosity level.

  • train_loss_name (str, default "train_loss") – Name of the training loss.

  • val_loss_name (str, default "val_loss") – Name of the validation loss.

  • logits_t (float, default 1) – Scale factor for logits.

  • kwargs (Any) – Additional keyword arguments.
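The parameters above pass item_extra_tokens alongside item_external_ids. The sketch below illustrates one plausible way such tokens are combined with real items, assuming the extra tokens take the first internal ids and sessions are left-padded to a fixed length. The helper names build_item_vocab and left_pad and the "PAD" token are hypothetical illustrations, not RecTools API.

```python
from typing import Dict, Hashable, List, Sequence


def build_item_vocab(
    extra_tokens: Sequence[Hashable], external_ids: Sequence[Hashable]
) -> Dict[Hashable, int]:
    # Hypothetical layout: extra tokens get the first internal ids, real items follow.
    vocab: Dict[Hashable, int] = {}
    for token in list(extra_tokens) + list(external_ids):
        if token not in vocab:
            vocab[token] = len(vocab)
    return vocab


def left_pad(
    session: Sequence[Hashable], vocab: Dict[Hashable, int], max_len: int, pad_token: Hashable
) -> List[int]:
    # Keep the most recent `max_len` interactions and fill the front with the padding id.
    ids = [vocab[item] for item in session][-max_len:]
    return [vocab[pad_token]] * (max_len - len(ids)) + ids


vocab = build_item_vocab(["PAD"], ["item_a", "item_b", "item_c"])
print(vocab)  # {'PAD': 0, 'item_a': 1, 'item_b': 2, 'item_c': 3}
print(left_pad(["item_c", "item_a"], vocab, max_len=4, pad_token="PAD"))  # [0, 0, 3, 1]
```

Reserving the lowest ids for extra tokens keeps the padding id fixed regardless of how many items the train dataset contains.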
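For the "softmax" loss, a common formulation is cross-entropy over the logits of all items, with the target item as the correct class. The sketch below shows that computation together with logits_t. It assumes logits are divided by logits_t before the softmax (temperature-style scaling); the parameter description only says "scale factor", so whether the module divides or multiplies is an assumption here.

```python
import math
from typing import Sequence


def softmax_cross_entropy(logits: Sequence[float], target: int, logits_t: float = 1.0) -> float:
    # Assumption: logits are divided by logits_t (temperature-style scaling).
    scaled = [x / logits_t for x in logits]
    # Log-sum-exp with max subtraction for numerical stability.
    m = max(scaled)
    log_norm = m + math.log(sum(math.exp(x - m) for x in scaled))
    # Negative log-probability of the target item.
    return log_norm - scaled[target]


loss = softmax_cross_entropy([2.0, 1.0, 0.0], target=0)
```

Under this assumption, logits_t > 1 flattens the item distribution and logits_t < 1 sharpens it; the default of 1 leaves logits unchanged.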
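gbce_t calibrates the gBCE loss. The sketch below follows the gBCE formulation from gSASRec, where the positive probability is raised to a power β = α(t(1 − 1/α) + 1/α) and α is the negative sampling rate (sampled negatives divided by the number of non-target items). RecTools' internal implementation may express this differently (for example by transforming logits), so treat it as an illustration of the formula, not of the module's code.

```python
import math
from typing import Sequence


def gbce_beta(n_negatives: int, n_items: int, gbce_t: float) -> float:
    # alpha: share of the n_items - 1 non-target items that were sampled as negatives.
    alpha = n_negatives / (n_items - 1)
    return alpha * (gbce_t * (1 - 1 / alpha) + 1 / alpha)


def log_sigmoid(x: float) -> float:
    # Numerically stable log(sigmoid(x)).
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def gbce_loss(pos_logit: float, neg_logits: Sequence[float], n_items: int, gbce_t: float) -> float:
    beta = gbce_beta(len(neg_logits), n_items, gbce_t)
    # Positive term: -log(sigmoid(s_pos) ** beta) == -beta * log(sigmoid(s_pos)).
    pos_term = -beta * log_sigmoid(pos_logit)
    # Negative terms: -log(1 - sigmoid(s_neg)) == -log(sigmoid(-s_neg)).
    neg_term = -sum(log_sigmoid(-s) for s in neg_logits)
    return pos_term + neg_term
```

With gbce_t = 0, β = 1 and the loss reduces to plain binary cross-entropy; with gbce_t = 1, β = α, which down-weights the positive term to counter the overconfidence caused by negative sampling.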
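adam_betas are the two decay rates of the Adam optimizer. The scalar sketch below shows the textbook Adam update to make their role concrete, assuming the module hands adam_betas to an Adam optimizer such as torch.optim.Adam, whose betas argument has the same meaning. PyTorch places eps slightly differently, so results can differ in the last digits.

```python
import math
from typing import Tuple


def adam_step(
    param: float,
    grad: float,
    m: float,
    v: float,
    step: int,
    lr: float,
    betas: Tuple[float, float] = (0.9, 0.98),
    eps: float = 1e-8,
) -> Tuple[float, float, float]:
    beta1, beta2 = betas
    m = beta1 * m + (1 - beta1) * grad  # running average of the gradient
    v = beta2 * v + (1 - beta2) * grad * grad  # running average of its square
    m_hat = m / (1 - beta1**step)  # bias correction; step starts at 1
    v_hat = v / (1 - beta2**step)
    param = param - lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v


new_param, m, v = adam_step(param=1.0, grad=0.5, m=0.0, v=0.0, step=1, lr=0.01)
```

A second beta of 0.98 (instead of PyTorch's default 0.999) makes the squared-gradient average react faster, a setting often used for transformer training.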

Methods

get_batch_logits(batch)

Get batch logits.

on_train_end()

Save fitted state.

on_train_start()

Initialize parameters with values drawn from a Xavier normal distribution.

on_validation_end()

Clear item embeddings.

on_validation_start()

Save item embeddings.

training_step(batch, batch_idx)

Training step.

validation_step(batch, batch_idx)

Validation step.

Attributes

i2i_dist

get_batch_logits(batch: Dict[str, Tensor]) → Tensor

Get batch logits.

Parameters

batch (Dict[str, Tensor]) –

Return type

Tensor
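In SASRec-style models, logits are typically the dot products between each session representation and every item embedding. The sketch below assumes that scoring rule and works on precomputed embeddings as plain lists rather than on the batch dictionary, whose keys are not documented here; batch_logits is a hypothetical name.

```python
from typing import List

Matrix = List[List[float]]


def batch_logits(session_embs: Matrix, item_embs: Matrix) -> Matrix:
    # Assumption: the score of item j for session i is the dot product of their embeddings.
    return [[sum(s * e for s, e in zip(sess, item)) for item in item_embs] for sess in session_embs]


sessions = [[1.0, 0.0], [0.5, 2.0]]  # 2 sessions, embedding dim 2
items = [[1.0, 1.0], [0.0, 1.0], [2.0, 0.0]]  # 3 items
print(batch_logits(sessions, items))  # [[1.0, 0.0, 2.0], [2.5, 2.0, 1.0]]
```

The result has one row per session and one column per item, the shape a softmax or gBCE loss expects.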

on_train_end() → None

Save fitted state.

Return type

None

on_train_start() → None

Initialize parameters with values drawn from a Xavier normal distribution.

Return type

None
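Xavier (Glorot) normal initialization draws weights from a zero-mean normal distribution with standard deviation gain · sqrt(2 / (fan_in + fan_out)). The pure-Python sketch below builds a weight matrix of shape (fan_out, fan_in), the layout PyTorch uses for linear layers; xavier_normal is a hypothetical helper, not the module's code.

```python
import math
import random
from typing import List, Tuple


def xavier_normal(
    fan_out: int, fan_in: int, gain: float = 1.0, seed: int = 0
) -> Tuple[List[List[float]], float]:
    # Standard deviation that keeps activation variance roughly constant across layers.
    std = gain * math.sqrt(2.0 / (fan_in + fan_out))
    rng = random.Random(seed)
    weights = [[rng.gauss(0.0, std) for _ in range(fan_in)] for _ in range(fan_out)]
    return weights, std


weights, std = xavier_normal(fan_out=64, fan_in=32)
```

In PyTorch the equivalent call is torch.nn.init.xavier_normal_, which requires a tensor with at least two dimensions, so initialization loops over model parameters usually skip 1-D tensors such as biases.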

on_validation_end() → None

Clear item embeddings.

Return type

None

on_validation_start() → None

Save item embeddings.

Return type

None
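The two validation hooks above save item embeddings before validation and clear them afterwards. A plausible rationale: weights do not change during validation, so item embeddings can be computed once and reused by every validation step, then discarded so no stale copy survives the next training updates. The class below is a hypothetical sketch of that pattern, not the module's implementation.

```python
from typing import Callable, List, Optional


class ItemEmbeddingCache:
    """Hypothetical helper mirroring the save/clear pattern of the validation hooks."""

    def __init__(self, compute_item_embeddings: Callable[[], List[List[float]]]) -> None:
        self._compute = compute_item_embeddings
        self.item_embs: Optional[List[List[float]]] = None

    def on_validation_start(self) -> None:
        # Compute item embeddings once; weights stay fixed for the whole validation run.
        self.item_embs = self._compute()

    def on_validation_end(self) -> None:
        # Drop the cached embeddings so they are recomputed after the next training updates.
        self.item_embs = None
```

Each validation step can then score sessions against the cached item_embs instead of re-encoding the full item catalog per batch.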

training_step(batch: Dict[str, Tensor], batch_idx: int) → Tensor

Training step.

Parameters
  • batch (Dict[str, Tensor]) –

  • batch_idx (int) –

Return type

Tensor

validation_step(batch: Dict[str, Tensor], batch_idx: int) → Dict[str, Tensor]

Validation step.

Parameters
  • batch (Dict[str, Tensor]) –

  • batch_idx (int) –

Return type

Dict[str, Tensor]
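The hooks and steps documented above are called by the Lightning Trainer. The sketch below is a simplified stand-in for its fit loop that records the call order, assuming validation runs at the end of every epoch (the Trainer's default check_val_every_n_epoch=1). It ignores optimizer steps, epoch-level hooks, and the sanity-check validation Lightning runs before on_train_start. RecordingModule and simplified_fit are hypothetical names.

```python
from typing import Any, Dict, List, Sequence


class RecordingModule:
    """Hypothetical stand-in that records which hook ran, in order."""

    def __init__(self) -> None:
        self.calls: List[str] = []

    def on_train_start(self) -> None:
        self.calls.append("on_train_start")

    def training_step(self, batch: Any, batch_idx: int) -> float:
        self.calls.append(f"training_step[{batch_idx}]")
        return 0.0  # a real module returns the loss tensor to backpropagate

    def on_validation_start(self) -> None:
        self.calls.append("on_validation_start")

    def validation_step(self, batch: Any, batch_idx: int) -> Dict[str, float]:
        self.calls.append(f"validation_step[{batch_idx}]")
        return {}

    def on_validation_end(self) -> None:
        self.calls.append("on_validation_end")

    def on_train_end(self) -> None:
        self.calls.append("on_train_end")


def simplified_fit(module: RecordingModule, train_batches: Sequence[Any], val_batches: Sequence[Any], n_epochs: int) -> None:
    # Simplified: no optimizer, no sanity check, no epoch-level hooks.
    module.on_train_start()
    for _ in range(n_epochs):
        for idx, batch in enumerate(train_batches):
            module.training_step(batch, idx)
        module.on_validation_start()
        for idx, batch in enumerate(val_batches):
            module.validation_step(batch, idx)
        module.on_validation_end()
    module.on_train_end()


module = RecordingModule()
simplified_fit(module, ["b0", "b1"], ["v0"], n_epochs=1)
```

The recorded order shows why item embeddings saved in on_validation_start are available to every validation_step and are cleared before training resumes.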