TransformerModelBase

class rectools.models.nn.transformers.base.TransformerModelBase(data_preparator_type: typing.Type[rectools.models.nn.transformers.data_preparator.TransformerDataPreparatorBase], transformer_layers_type: typing.Type[rectools.models.nn.transformers.net_blocks.TransformerLayersBase] = <class 'rectools.models.nn.transformers.net_blocks.PreLNTransformerLayers'>, n_blocks: int = 2, n_heads: int = 4, n_factors: int = 256, use_pos_emb: bool = True, use_causal_attn: bool = False, use_key_padding_mask: bool = False, dropout_rate: float = 0.2, session_max_len: int = 100, dataloader_num_workers: int = 0, batch_size: int = 128, loss: str = 'softmax', n_negatives: int = 1, gbce_t: float = 0.2, lr: float = 0.001, epochs: int = 3, verbose: int = 0, deterministic: bool = False, recommend_batch_size: int = 256, recommend_torch_device: typing.Optional[str] = None, train_min_user_interactions: int = 2, item_net_block_types: typing.Sequence[typing.Type[rectools.models.nn.item_net.ItemNetBase]] = (<class 'rectools.models.nn.item_net.IdEmbeddingsItemNet'>, <class 'rectools.models.nn.item_net.CatFeaturesItemNet'>), item_net_constructor_type: typing.Type[rectools.models.nn.item_net.ItemNetConstructorBase] = <class 'rectools.models.nn.item_net.SumOfEmbeddingsConstructor'>, pos_encoding_type: typing.Type[rectools.models.nn.transformers.net_blocks.PositionalEncodingBase] = <class 'rectools.models.nn.transformers.net_blocks.LearnableInversePositionalEncoding'>, lightning_module_type: typing.Type[rectools.models.nn.transformers.lightning.TransformerLightningModuleBase] = <class 'rectools.models.nn.transformers.lightning.TransformerLightningModule'>, negative_sampler_type: typing.Type[rectools.models.nn.transformers.negative_sampler.TransformerNegativeSamplerBase] = <class 'rectools.models.nn.transformers.negative_sampler.CatalogUniformSampler'>, similarity_module_type: typing.Type[rectools.models.nn.transformers.similarity.SimilarityModuleBase] = <class 'rectools.models.nn.transformers.similarity.DistanceSimilarityModule'>, backbone_type: typing.Type[rectools.models.nn.transformers.torch_backbone.TransformerBackboneBase] = <class 'rectools.models.nn.transformers.torch_backbone.TransformerTorchBackbone'>, get_val_mask_func: typing.Optional[collections.abc.Callable[[...], numpy.ndarray]] = None, get_trainer_func: typing.Optional[collections.abc.Callable[[...], pytorch_lightning.trainer.trainer.Trainer]] = None, get_val_mask_func_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None, get_trainer_func_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None, data_preparator_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None, transformer_layers_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None, item_net_constructor_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None, pos_encoding_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None, lightning_module_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None, negative_sampler_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None, similarity_module_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None, backbone_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None, **kwargs: typing.Any)

Bases: ModelBase[TransformerModelConfig_T]

Base model for all recommender algorithms built on the transformer architecture (e.g. SASRec, BERT4Rec). To create a custom transformer model, inherit from this class and implement the initialization logic for self.data_preparator.

Parameters
  • data_preparator_type (Type[TransformerDataPreparatorBase]) –

  • transformer_layers_type (Type[TransformerLayersBase]) –

  • n_blocks (int) –

  • n_heads (int) –

  • n_factors (int) –

  • use_pos_emb (bool) –

  • use_causal_attn (bool) –

  • use_key_padding_mask (bool) –

  • dropout_rate (float) –

  • session_max_len (int) –

  • dataloader_num_workers (int) –

  • batch_size (int) –

  • loss (str) –

  • n_negatives (int) –

  • gbce_t (float) –

  • lr (float) –

  • epochs (int) –

  • verbose (int) –

  • deterministic (bool) –

  • recommend_batch_size (int) –

  • recommend_torch_device (Optional[str]) –

  • train_min_user_interactions (int) –

  • item_net_block_types (Sequence[Type[ItemNetBase]]) –

  • item_net_constructor_type (Type[ItemNetConstructorBase]) –

  • pos_encoding_type (Type[PositionalEncodingBase]) –

  • lightning_module_type (Type[TransformerLightningModuleBase]) –

  • negative_sampler_type (Type[TransformerNegativeSamplerBase]) –

  • similarity_module_type (Type[SimilarityModuleBase]) –

  • backbone_type (Type[TransformerBackboneBase]) –

  • get_val_mask_func (Optional[Callable[[...], ndarray]]) –

  • get_trainer_func (Optional[Callable[[...], Trainer]]) –

  • get_val_mask_func_kwargs (Optional[Dict[str, Any]]) –

  • get_trainer_func_kwargs (Optional[Dict[str, Any]]) –

  • data_preparator_kwargs (Optional[Dict[str, Any]]) –

  • transformer_layers_kwargs (Optional[Dict[str, Any]]) –

  • item_net_constructor_kwargs (Optional[Dict[str, Any]]) –

  • pos_encoding_kwargs (Optional[Dict[str, Any]]) –

  • lightning_module_kwargs (Optional[Dict[str, Any]]) –

  • negative_sampler_kwargs (Optional[Dict[str, Any]]) –

  • similarity_module_kwargs (Optional[Dict[str, Any]]) –

  • backbone_kwargs (Optional[Dict[str, Any]]) –

  • kwargs (Any) –
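The sequence-related parameters (session_max_len, train_min_user_interactions) can be illustrated with a minimal sketch. It assumes that a training session is a user's items in chronological order, truncated to the most recent session_max_len items and left-padded, and that users with fewer than train_min_user_interactions interactions are dropped. The helper build_sessions and the padding id PAD = 0 are hypothetical; this is not the actual TransformerDataPreparatorBase logic.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

PAD = 0  # hypothetical padding id; assumes real item ids start at 1


def build_sessions(
    interactions: List[Tuple[int, int, int]],  # (user_id, item_id, timestamp)
    session_max_len: int,
    train_min_user_interactions: int,
) -> Dict[int, List[int]]:
    """Group interactions into chronological, truncated, left-padded sessions."""
    by_user: Dict[int, List[Tuple[int, int]]] = defaultdict(list)
    for user, item, ts in interactions:
        by_user[user].append((ts, item))

    sessions: Dict[int, List[int]] = {}
    for user, events in by_user.items():
        if len(events) < train_min_user_interactions:
            continue  # too few interactions to form a training example
        events.sort()  # chronological order (ties broken by item id)
        items = [item for _, item in events][-session_max_len:]  # keep the most recent
        sessions[user] = [PAD] * (session_max_len - len(items)) + items  # left-pad
    return sessions
```

Left padding keeps the most recent interaction at the last position regardless of session length, which is convenient when the final position is used to predict the next item.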
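use_causal_attn, use_key_padding_mask, n_factors and n_heads are easiest to read in terms of the boolean mask convention of torch.nn.MultiheadAttention, where True in attn_mask blocks a query-key pair and True in key_padding_mask marks a key to ignore. The sketch below assumes the transformer layers follow that convention; the helper names are hypothetical, and plain lists stand in for tensors.

```python
from typing import List


def causal_mask(seq_len: int) -> List[List[bool]]:
    """True marks a blocked pair: query position i may not attend to a later key j."""
    return [[j > i for j in range(seq_len)] for i in range(seq_len)]


def key_padding_mask(session: List[int], pad_id: int = 0) -> List[bool]:
    """True marks padding positions that no query should attend to."""
    return [item == pad_id for item in session]


def head_dim(n_factors: int, n_heads: int) -> int:
    """Per-head size; multi-head attention needs n_factors divisible by n_heads."""
    if n_factors % n_heads != 0:
        raise ValueError("n_factors must be divisible by n_heads")
    return n_factors // n_heads
```

With the defaults n_factors=256 and n_heads=4, each attention head works in 64 dimensions.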
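Assuming gbce_t is the calibration parameter t of the gBCE loss introduced with gSASRec, its effect can be computed directly. With sampling rate α = n_negatives / (n_items − 1), the positive probability is raised to the power β = α(t(1 − 1/α) + 1/α), so t = 0 gives plain BCE (β = 1) and t = 1 gives β = α. The sketch computes β and the per-position loss in pure Python; n_items (catalog size) and both helper names are assumptions, not part of this API.

```python
import math
from typing import Sequence


def gbce_beta(n_negatives: int, n_items: int, gbce_t: float) -> float:
    """Calibration exponent beta from gSASRec; assumes n_items > 1."""
    alpha = n_negatives / (n_items - 1)  # negative sampling rate
    return alpha * (gbce_t * (1 - 1 / alpha) + 1 / alpha)


def gbce_loss(pos_logit: float, neg_logits: Sequence[float], beta: float) -> float:
    """-log(sigmoid(s+) ** beta) - sum(log(1 - sigmoid(s-))) for one position."""

    def log_sigmoid(x: float) -> float:
        # numerically stable log(sigmoid(x))
        return -math.log1p(math.exp(-x)) if x >= 0 else x - math.log1p(math.exp(x))

    pos_term = -beta * log_sigmoid(pos_logit)
    # log(1 - sigmoid(x)) == log(sigmoid(-x))
    neg_term = -sum(log_sigmoid(-s) for s in neg_logits)
    return pos_term + neg_term
```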

Methods

dumps()

Serialize model to bytes.

fit(dataset, *args, **kwargs)

Fit model.

fit_partial(dataset, *args, **kwargs)

Fit model. Unlike fit, repeated calls continue training from the current state.

from_config(config)

Create model from config.

from_params(params[, sep])

Create model from parameters.

get_config([mode, simple_types])

Return model config.

get_params([simple_types, sep])

Return model parameters.

load(f)

Load model from file.

load_from_checkpoint(checkpoint_path[, ...])

Load model from Lightning checkpoint path.

load_weights_from_checkpoint(checkpoint_path)

Load model weights from Lightning checkpoint path.

loads(data)

Load model from bytes.

recommend(users, dataset, k, filter_viewed)

Recommend items for users.

recommend_to_items(target_items, dataset, k)

Recommend items for target items.

save(f)

Save model to file.
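Conceptually, recommend scores candidate items for each user and keeps the top k, optionally excluding items the user has already interacted with (filter_viewed). The sketch below assumes plain dot-product scoring between a user embedding and item embeddings; the real scoring is delegated to the configured similarity_module_type, and recommend_top_k is a hypothetical helper.

```python
import heapq
from typing import Dict, List, Set


def recommend_top_k(
    user_emb: List[float],
    item_embs: Dict[int, List[float]],
    k: int,
    viewed: Set[int],
    filter_viewed: bool = True,
) -> List[int]:
    """Score items by dot product and return the ids of the k highest-scoring ones."""
    scores = {
        item_id: sum(u * v for u, v in zip(user_emb, emb))
        for item_id, emb in item_embs.items()
        if not (filter_viewed and item_id in viewed)  # drop already-seen items
    }
    # nlargest avoids sorting the whole catalog when k is small
    return heapq.nlargest(k, scores, key=scores.__getitem__)
```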

Attributes

recommends_for_cold

recommends_for_warm

require_recommend_context

Indicates whether recommendation context is required for predictions.

torch_model

PyTorch model.

train_loss_name

val_loss_name

config_class

classmethod load_from_checkpoint(checkpoint_path: Union[str, Path], map_location: Optional[Union[device, str]] = None, model_params_update: Optional[Dict[str, Any]] = None) → Self

Load model from Lightning checkpoint path.

Parameters
  • checkpoint_path (Union[str, Path]) – Path to checkpoint location.

  • map_location (Union[str, torch.device], optional) – Target device to load the checkpoint onto (e.g. 'cpu', 'cuda:0'). If None, the device the checkpoint was saved on is used.

  • model_params_update (Dict[str, tp.Any], optional) –

    Custom values that override entries in checkpoint['hyper_parameters']['model_config']. They must be flattened with the 'dot' reducer before being passed. This argument can be used, for example, to remove training-specific parameters that are no longer needed, such as 'get_trainer_func'.
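model_params_update must already be flattened with the 'dot' reducer, meaning nested keys are joined with '.'. A minimal stdlib sketch of that flattening follows; flatten_dot is a hypothetical helper written in the spirit of the flatten_dict package's 'dot' reducer, not a RecTools function, and it may differ on edge cases such as empty nested dicts.

```python
from typing import Any, Dict


def flatten_dot(config: Dict[str, Any], parent: str = "") -> Dict[str, Any]:
    """Flatten nested dicts, joining keys with '.' (like a 'dot' reducer)."""
    flat: Dict[str, Any] = {}
    for key, value in config.items():
        full_key = f"{parent}.{key}" if parent else key
        if isinstance(value, dict) and value:
            flat.update(flatten_dot(value, full_key))  # recurse into nested sections
        else:
            flat[full_key] = value
    return flat


# Hypothetical override: reset the trainer factory stored in the checkpoint config.
update = flatten_dot({"get_trainer_func": None})
```

A dict built this way could then be passed as model_params_update when calling load_from_checkpoint.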

Returns

Model instance.

load_weights_from_checkpoint(checkpoint_path: Union[str, Path]) → None

Load model weights from Lightning checkpoint path.

Parameters

checkpoint_path (Union[str, Path]) – Path to checkpoint location.

Return type

None

property torch_model: TransformerBackboneBase

PyTorch model.