TransformerLightningModule
- class rectools.models.nn.transformers.lightning.TransformerLightningModule(torch_model: TransformerBackboneBase, model_config: Dict[str, Any], dataset_schema: Dict[str, Any], item_external_ids: Union[Sequence[Hashable], ndarray], item_extra_tokens: Sequence[Hashable], data_preparator: TransformerDataPreparatorBase, lr: float, gbce_t: float, loss: str, verbose: int = 0, train_loss_name: str = 'train_loss', val_loss_name: str = 'val_loss', adam_betas: Tuple[float, float] = (0.9, 0.98), logits_t: float = 1, **kwargs: Any)
Bases: TransformerLightningModuleBase
Lightning module to train transformer models.
- Parameters
torch_model (TransformerBackboneBase) – Torch model to make recommendations.
model_config (Dict[str, Any]) – Model config.
dataset_schema (DatasetSchemaDict) – Dataset schema.
item_external_ids (ExternalIds) – External item ids from train dataset.
item_extra_tokens (Sequence[Hashable]) – Elements used for sequence padding.
lr (float) – Learning rate.
gbce_t (float) – Calibration parameter for gBCE loss.
loss (str, default "softmax") – Loss function.
adam_betas (Tuple[float, float], default (0.9, 0.98)) – Coefficients the Adam optimizer uses to compute running averages of the gradient and its square.
data_preparator (TransformerDataPreparatorBase) – Data preparator.
verbose (int, default 0) – Verbosity level.
train_loss_name (str, default "train_loss") – Name of the training loss.
val_loss_name (str, default "val_loss") – Name of the validation loss.
logits_t (float, default 1) – Scale factor for logits.
kwargs (Any) – Additional keyword arguments.
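The gbce_t parameter calibrates the gBCE loss. The sketch below follows the gBCE definition from the gSASRec paper (Petrov and Macdonald, 2023) for a single target position: the positive probability sigmoid(s+) is raised to a power beta = alpha * (t * (1 - 1/alpha) + 1/alpha), where alpha is the negative sampling rate and t plays the role of gbce_t. It uses plain Python floats for readability. The function names are hypothetical, and it is an assumption, not confirmed by this page, that TransformerLightningModule computes the loss exactly this way (it works on tensors and may transform logits or clamp values differently).

```python
import math
from typing import Sequence


def log_sigmoid(x: float) -> float:
    # log(sigmoid(x)) = -log(1 + exp(-x))
    return -math.log1p(math.exp(-x))


def gbce_beta(n_negatives: int, n_items: int, t: float) -> float:
    # alpha: negative sampling rate = sampled negatives / (catalog size - 1).
    alpha = n_negatives / (n_items - 1)
    # beta = alpha * (t * (1 - 1/alpha) + 1/alpha), equivalently 1 + t * (alpha - 1).
    return alpha * (t * (1.0 - 1.0 / alpha) + 1.0 / alpha)


def bce_loss(pos_logit: float, neg_logits: Sequence[float]) -> float:
    # -log(sigmoid(s+)) - sum(log(1 - sigmoid(s-))), using 1 - sigmoid(s) = sigmoid(-s).
    return -log_sigmoid(pos_logit) - sum(log_sigmoid(-s) for s in neg_logits)


def gbce_loss(pos_logit: float, neg_logits: Sequence[float], n_items: int, t: float) -> float:
    # Hypothetical helper: identical to BCE except the positive term is scaled by beta,
    # i.e. -log(sigmoid(s+) ** beta) = -beta * log(sigmoid(s+)).
    beta = gbce_beta(len(neg_logits), n_items, t)
    return -beta * log_sigmoid(pos_logit) - sum(log_sigmoid(-s) for s in neg_logits)


if __name__ == "__main__":
    pos, negs = 2.0, [0.5, -1.0]  # one positive logit, two sampled negative logits
    print("BCE: ", bce_loss(pos, negs))
    print("gBCE:", gbce_loss(pos, negs, n_items=1001, t=0.75))
```

With t = 0 the exponent beta is 1 and gBCE reduces to plain BCE; with t = 1 beta equals alpha, which shrinks the positive term the most when only a few negatives are sampled from a large catalog.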
Methods
get_batch_logits(batch) – Get batch logits.
Save fitted state.
on_train_start() – Initialize parameters with values from a Xavier normal distribution.
Clear item embeddings.
Save item embeddings.
training_step(batch, batch_idx) – Training step.
validation_step(batch, batch_idx) – Validation step.
Attributes
i2i_dist
- get_batch_logits(batch: Dict[str, Tensor]) → Tensor
Get batch logits.
- Parameters
batch (Dict[str, Tensor]) –
- Return type
Tensor
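As a rough illustration of what "batch logits" usually means for SASRec/BERT4Rec-style transformers, the sketch below scores every item at every sequence position with a dot product between the position's output embedding and each item embedding. This is an assumption: this page does not document how get_batch_logits reads the batch dictionary or whether it scores the full catalog (as a softmax loss would) or only positive and sampled negative items (as BCE or gBCE would). The function batch_logits and its list-based inputs are hypothetical stand-ins for tensors.

```python
from typing import List

Vector = List[float]
Matrix = List[Vector]


def dot(a: Vector, b: Vector) -> float:
    # Inner product of two equal-length vectors.
    return sum(x * y for x, y in zip(a, b))


def batch_logits(session_embs: List[Matrix], item_embs: Matrix) -> List[Matrix]:
    # Hypothetical helper.
    # session_embs: [batch][position][dim] outputs of the transformer backbone.
    # item_embs:    [item][dim] item embedding table.
    # Returns logits[b][p][i] = <session_embs[b][p], item_embs[i]>.
    return [[[dot(pos, item) for item in item_embs] for pos in seq] for seq in session_embs]


if __name__ == "__main__":
    sessions = [[[1.0, 0.0], [0.0, 2.0]]]  # 1 sequence, 2 positions, dim 2
    items = [[1.0, 1.0], [3.0, -1.0]]      # 2 items
    print(batch_logits(sessions, items))
```

With tensors the same computation is a single batched matrix product of the session embeddings with the transposed item embedding matrix.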
- on_train_start() → None
Initialize parameters with values from a Xavier normal distribution.
- Return type
None
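Xavier (Glorot) normal initialization draws each weight from a normal distribution with mean 0 and standard deviation gain * sqrt(2 / (fan_in + fan_out)). PyTorch provides this as torch.nn.init.xavier_normal_. It is an assumption, not stated on this page, that on_train_start delegates to that function and initializes only weight matrices. The pure-Python sketch below shows the formula for a single 2-D weight of shape (out_features, in_features). The helper names are hypothetical.

```python
import math
import random
from typing import List


def xavier_normal_std(fan_in: int, fan_out: int, gain: float = 1.0) -> float:
    # Glorot/Xavier normal: std = gain * sqrt(2 / (fan_in + fan_out)).
    return gain * math.sqrt(2.0 / (fan_in + fan_out))


def xavier_normal(rows: int, cols: int, gain: float = 1.0, seed: int = 0) -> List[List[float]]:
    # Hypothetical helper for a weight of shape (out_features, in_features):
    # fan_out = rows, fan_in = cols.
    std = xavier_normal_std(fan_in=cols, fan_out=rows, gain=gain)
    rng = random.Random(seed)  # seeded for reproducibility
    return [[rng.gauss(0.0, std) for _ in range(cols)] for _ in range(rows)]


if __name__ == "__main__":
    w = xavier_normal(4, 6, seed=42)
    print(len(w), len(w[0]), xavier_normal_std(fan_in=6, fan_out=4))
```

Scaling the standard deviation by both fans keeps activation and gradient variances roughly constant across layers, which is why this scheme is a common default for transformer weight matrices.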