TransformerModelBase
- class rectools.models.nn.transformers.base.TransformerModelBase(
data_preparator_type: Type[TransformerDataPreparatorBase],
transformer_layers_type: Type[TransformerLayersBase] = rectools.models.nn.transformers.net_blocks.PreLNTransformerLayers,
n_blocks: int = 2,
n_heads: int = 4,
n_factors: int = 256,
use_pos_emb: bool = True,
use_causal_attn: bool = False,
use_key_padding_mask: bool = False,
dropout_rate: float = 0.2,
session_max_len: int = 100,
dataloader_num_workers: int = 0,
batch_size: int = 128,
loss: str = 'softmax',
n_negatives: int = 1,
gbce_t: float = 0.2,
lr: float = 0.001,
epochs: int = 3,
verbose: int = 0,
deterministic: bool = False,
recommend_batch_size: int = 256,
recommend_torch_device: Optional[str] = None,
train_min_user_interactions: int = 2,
item_net_block_types: Sequence[Type[ItemNetBase]] = (rectools.models.nn.item_net.IdEmbeddingsItemNet, rectools.models.nn.item_net.CatFeaturesItemNet),
item_net_constructor_type: Type[ItemNetConstructorBase] = rectools.models.nn.item_net.SumOfEmbeddingsConstructor,
pos_encoding_type: Type[PositionalEncodingBase] = rectools.models.nn.transformers.net_blocks.LearnableInversePositionalEncoding,
lightning_module_type: Type[TransformerLightningModuleBase] = rectools.models.nn.transformers.lightning.TransformerLightningModule,
negative_sampler_type: Type[TransformerNegativeSamplerBase] = rectools.models.nn.transformers.negative_sampler.CatalogUniformSampler,
similarity_module_type: Type[SimilarityModuleBase] = rectools.models.nn.transformers.similarity.DistanceSimilarityModule,
backbone_type: Type[TransformerBackboneBase] = rectools.models.nn.transformers.torch_backbone.TransformerTorchBackbone,
get_val_mask_func: Optional[Callable[[...], ndarray]] = None,
get_trainer_func: Optional[Callable[[...], Trainer]] = None,
get_val_mask_func_kwargs: Optional[Dict[str, Any]] = None,
get_trainer_func_kwargs: Optional[Dict[str, Any]] = None,
data_preparator_kwargs: Optional[Dict[str, Any]] = None,
transformer_layers_kwargs: Optional[Dict[str, Any]] = None,
item_net_constructor_kwargs: Optional[Dict[str, Any]] = None,
pos_encoding_kwargs: Optional[Dict[str, Any]] = None,
lightning_module_kwargs: Optional[Dict[str, Any]] = None,
negative_sampler_kwargs: Optional[Dict[str, Any]] = None,
similarity_module_kwargs: Optional[Dict[str, Any]] = None,
backbone_kwargs: Optional[Dict[str, Any]] = None,
**kwargs: Any)
Bases: ModelBase[TransformerModelConfig_T]
Base model for all recommender algorithms built on a transformer architecture (e.g. SASRec, Bert4Rec). To create a custom transformer model, inherit from this class and implement the self.data_preparator initialization logic.
- Parameters
data_preparator_type (Type[TransformerDataPreparatorBase]) –
transformer_layers_type (Type[TransformerLayersBase]) –
n_blocks (int) –
n_heads (int) –
n_factors (int) –
use_pos_emb (bool) –
use_causal_attn (bool) –
use_key_padding_mask (bool) –
dropout_rate (float) –
session_max_len (int) –
dataloader_num_workers (int) –
batch_size (int) –
loss (str) –
n_negatives (int) –
gbce_t (float) –
lr (float) –
epochs (int) –
verbose (int) –
deterministic (bool) –
recommend_batch_size (int) –
recommend_torch_device (Optional[str]) –
train_min_user_interactions (int) –
item_net_block_types (Sequence[Type[ItemNetBase]]) –
item_net_constructor_type (Type[ItemNetConstructorBase]) –
pos_encoding_type (Type[PositionalEncodingBase]) –
lightning_module_type (Type[TransformerLightningModuleBase]) –
negative_sampler_type (Type[TransformerNegativeSamplerBase]) –
similarity_module_type (Type[SimilarityModuleBase]) –
backbone_type (Type[TransformerBackboneBase]) –
get_val_mask_func (Optional[Callable[[...], ndarray]]) –
get_trainer_func (Optional[Callable[[...], Trainer]]) –
get_val_mask_func_kwargs (Optional[Dict[str, Any]]) –
get_trainer_func_kwargs (Optional[Dict[str, Any]]) –
data_preparator_kwargs (Optional[Dict[str, Any]]) –
transformer_layers_kwargs (Optional[Dict[str, Any]]) –
item_net_constructor_kwargs (Optional[Dict[str, Any]]) –
pos_encoding_kwargs (Optional[Dict[str, Any]]) –
lightning_module_kwargs (Optional[Dict[str, Any]]) –
negative_sampler_kwargs (Optional[Dict[str, Any]]) –
similarity_module_kwargs (Optional[Dict[str, Any]]) –
backbone_kwargs (Optional[Dict[str, Any]]) –
kwargs (Any) –
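The class description says a custom model inherits from this class and supplies the self.data_preparator initialization logic. The sketch below illustrates that pattern with stand-in classes only; it does not import rectools. StubTransformerModelBase, StubDataPreparator, the hook name _init_data_preparator and the shuffle_train option are all assumptions made for illustration, so check the library source for the actual hook and constructor arguments.

```python
from typing import Any, Dict, Optional, Type


class StubDataPreparator:
    """Stand-in for a data preparator; it only records its settings."""

    def __init__(self, session_max_len: int, batch_size: int, **kwargs: Any) -> None:
        self.session_max_len = session_max_len
        self.batch_size = batch_size
        self.extra = kwargs


class StubTransformerModelBase:
    """Stand-in for the base class: stores hyperparameters, then calls a hook."""

    def __init__(
        self,
        data_preparator_type: Type[Any],
        session_max_len: int = 100,  # same default as in the signature above
        batch_size: int = 128,  # same default as in the signature above
        data_preparator_kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        self.data_preparator_type = data_preparator_type
        self.session_max_len = session_max_len
        self.batch_size = batch_size
        self.data_preparator_kwargs = data_preparator_kwargs
        self.data_preparator: Any = None
        self._init_data_preparator()  # hypothetical hook name

    def _init_data_preparator(self) -> None:
        raise NotImplementedError("Subclasses must build self.data_preparator")


class MyTransformerModel(StubTransformerModelBase):
    """Custom model: the only thing it adds is the data preparator setup."""

    def _init_data_preparator(self) -> None:
        self.data_preparator = self.data_preparator_type(
            session_max_len=self.session_max_len,
            batch_size=self.batch_size,
            **(self.data_preparator_kwargs or {}),
        )


model = MyTransformerModel(
    StubDataPreparator,
    session_max_len=50,
    data_preparator_kwargs={"shuffle_train": True},  # hypothetical option
)
```

In this sketch the paired data_preparator_type and data_preparator_kwargs arguments are combined inside the hook; whether the real class forwards the other *_kwargs dictionaries the same way is not stated on this page.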
Methods
dumps() – Serialize model to bytes.
fit(dataset, *args, **kwargs) – Fit model.
fit_partial(dataset, *args, **kwargs) – Fit model.
from_config(config) – Create model from config.
from_params(params[, sep]) – Create model from parameters.
get_config([mode, simple_types]) – Return model config.
get_params([simple_types, sep]) – Return model parameters.
load(f) – Load model from file.
load_from_checkpoint(checkpoint_path[, ...]) – Load model from Lightning checkpoint path.
load_weights_from_checkpoint(checkpoint_path) – Load model weights from Lightning checkpoint path.
loads(data) – Load model from bytes.
recommend(users, dataset, k, filter_viewed) – Recommend items for users.
recommend_to_items(target_items, dataset, k) – Recommend items for target items.
save(f) – Save model to file.
Attributes
recommends_for_cold
recommends_for_warm
require_recommend_context – Indicates whether recommendation context is required for predictions.
torch_model – PyTorch model.
train_loss_name
val_loss_name
config_class
- classmethod load_from_checkpoint(checkpoint_path: Union[str, Path], map_location: Optional[Union[device, str]] = None, model_params_update: Optional[Dict[str, Any]] = None) -> Self
Load model from Lightning checkpoint path.
- Parameters
checkpoint_path (Union[str, Path]) – Path to checkpoint location.
map_location (Union[str, torch.device], optional) – Target device to load the checkpoint onto (e.g., ‘cpu’, ‘cuda:0’). If None, the device the checkpoint was saved on is used.
model_params_update (Dict[str, Any], optional) – Custom values for checkpoint[‘hyper_parameters’][‘model_config’]. The dictionary must be flattened with the ‘dot’ reducer before it is passed. You can use this argument to remove training-specific parameters that are no longer needed, e.g. ‘get_trainer_func’.
- Returns
Model instance.
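The model_params_update argument above expects keys flattened with a ‘dot’ reducer. As a minimal sketch of what that shape looks like, the helper below joins nested keys with dots. flatten_dot is a hypothetical helper written for this example, not a rectools function, and the nested backbone_kwargs entry with hypothetical_option is invented purely to show the key format; using None to drop get_trainer_func is likewise an assumption.

```python
from typing import Any, Dict


def flatten_dot(nested: Dict[str, Any], parent: str = "") -> Dict[str, Any]:
    """Flatten a nested dict into one level, joining key paths with dots."""
    flat: Dict[str, Any] = {}
    for key, value in nested.items():
        full_key = f"{parent}.{key}" if parent else key
        if isinstance(value, dict) and value:
            # Descend into non-empty dicts and prefix their keys.
            flat.update(flatten_dot(value, full_key))
        else:
            # Leaves (and empty dicts) are stored under the dotted key.
            flat[full_key] = value
    return flat


# Overrides written in nested form for readability (illustrative values only).
nested_update = {
    "get_trainer_func": None,  # assumed way to drop a training-only parameter
    "backbone_kwargs": {"hypothetical_option": 1},  # hypothetical nested entry
}

flat_update = flatten_dot(nested_update)
# flat_update == {"get_trainer_func": None, "backbone_kwargs.hypothetical_option": 1}
```

A dictionary shaped like flat_update is what would be passed as model_params_update, assuming the library's ‘dot’ reducer uses the same key-joining rule.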
- load_weights_from_checkpoint(checkpoint_path: Union[str, Path]) -> None
Load model weights from Lightning checkpoint path.
- Parameters
checkpoint_path (Union[str, Path]) – Path to checkpoint location.
- Return type
None
- property torch_model: TransformerBackboneBase
PyTorch model.