BERT4RecModel
- class rectools.models.nn.transformers.bert4rec.BERT4RecModel(n_blocks: int = 2, n_heads: int = 4, n_factors: int = 256, dropout_rate: float = 0.2, mask_prob: float = 0.15, session_max_len: int = 100, train_min_user_interactions: int = 2, loss: str = 'softmax', n_negatives: int = 1, gbce_t: float = 0.2, lr: float = 0.001, batch_size: int = 128, epochs: int = 3, deterministic: bool = False, verbose: int = 0, dataloader_num_workers: int = 0, use_pos_emb: bool = True, use_key_padding_mask: bool = True, use_causal_attn: bool = False, item_net_block_types: ~typing.Sequence[~typing.Type[~rectools.models.nn.item_net.ItemNetBase]] = (<class 'rectools.models.nn.item_net.IdEmbeddingsItemNet'>, <class 'rectools.models.nn.item_net.CatFeaturesItemNet'>), item_net_constructor_type: ~typing.Type[~rectools.models.nn.item_net.ItemNetConstructorBase] = <class 'rectools.models.nn.item_net.SumOfEmbeddingsConstructor'>, pos_encoding_type: ~typing.Type[~rectools.models.nn.transformers.net_blocks.PositionalEncodingBase] = <class 'rectools.models.nn.transformers.net_blocks.LearnableInversePositionalEncoding'>, transformer_layers_type: ~typing.Type[~rectools.models.nn.transformers.net_blocks.TransformerLayersBase] = <class 'rectools.models.nn.transformers.net_blocks.PreLNTransformerLayers'>, data_preparator_type: ~typing.Type[~rectools.models.nn.transformers.data_preparator.TransformerDataPreparatorBase] = <class 'rectools.models.nn.transformers.bert4rec.BERT4RecDataPreparator'>, lightning_module_type: ~typing.Type[~rectools.models.nn.transformers.lightning.TransformerLightningModuleBase] = <class 'rectools.models.nn.transformers.lightning.TransformerLightningModule'>, negative_sampler_type: ~typing.Type[~rectools.models.nn.transformers.negative_sampler.TransformerNegativeSamplerBase] = <class 'rectools.models.nn.transformers.negative_sampler.CatalogUniformSampler'>, similarity_module_type: ~typing.Type[~rectools.models.nn.transformers.similarity.SimilarityModuleBase] = <class 'rectools.models.nn.transformers.similarity.DistanceSimilarityModule'>, backbone_type: ~typing.Type[~rectools.models.nn.transformers.torch_backbone.TransformerBackboneBase] = <class 'rectools.models.nn.transformers.torch_backbone.TransformerTorchBackbone'>, get_val_mask_func: ~typing.Optional[~collections.abc.Callable[[...], ~numpy.ndarray]] = None, get_trainer_func: ~typing.Optional[~collections.abc.Callable[[...], ~pytorch_lightning.trainer.trainer.Trainer]] = None, get_val_mask_func_kwargs: ~typing.Optional[~typing.Dict[str, ~typing.Any]] = None, get_trainer_func_kwargs: ~typing.Optional[~typing.Dict[str, ~typing.Any]] = None, recommend_batch_size: int = 256, recommend_torch_device: ~typing.Optional[str] = None, recommend_use_torch_ranking: bool = True, recommend_n_threads: int = 0, data_preparator_kwargs: ~typing.Optional[~typing.Dict[str, ~typing.Any]] = None, transformer_layers_kwargs: ~typing.Optional[~typing.Dict[str, ~typing.Any]] = None, item_net_block_kwargs: ~typing.Optional[~typing.Dict[str, ~typing.Any]] = None, item_net_constructor_kwargs: ~typing.Optional[~typing.Dict[str, ~typing.Any]] = None, pos_encoding_kwargs: ~typing.Optional[~typing.Dict[str, ~typing.Any]] = None, lightning_module_kwargs: ~typing.Optional[~typing.Dict[str, ~typing.Any]] = None, negative_sampler_kwargs: ~typing.Optional[~typing.Dict[str, ~typing.Any]] = None, similarity_module_kwargs: ~typing.Optional[~typing.Dict[str, ~typing.Any]] = None, backbone_kwargs: ~typing.Optional[~typing.Dict[str, ~typing.Any]] = None)
Bases: TransformerModelBase[BERT4RecModelConfig]
BERT4Rec model: a transformer-based sequential model with a bidirectional attention mechanism and an “MLM” (masked item prediction in the user sequence) training objective. This implementation supports multiple loss functions and a configurable number of negatives for them.
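To make the “MLM” objective concrete, here is a minimal pure-Python sketch of how one training sample could be built from a user's item sequence: keep the last session_max_len items, left-pad, and replace each real item with a mask token with probability mask_prob, recording the hidden item as the prediction target. The PAD and MASK ids and the make_mlm_sample helper are hypothetical; this is not BERT4RecDataPreparator, which may differ in details such as padding side or how masked positions are filled.

```python
import random
from typing import List, Optional, Tuple

PAD = 0   # hypothetical padding id (real item ids are assumed to start at 2)
MASK = 1  # hypothetical [MASK] token id


def make_mlm_sample(
    item_ids: List[int],
    session_max_len: int = 100,
    mask_prob: float = 0.15,
    rng: Optional[random.Random] = None,
) -> Tuple[List[int], List[int]]:
    """Return (masked_input, targets); targets hold PAD where no loss is computed."""
    rng = rng or random.Random()
    # Keep only the most recent session_max_len interactions and left-pad the rest.
    seq = item_ids[-session_max_len:]
    seq = [PAD] * (session_max_len - len(seq)) + seq
    masked, targets = [], []
    for item in seq:
        if item != PAD and rng.random() < mask_prob:
            masked.append(MASK)   # hide the item from the model
            targets.append(item)  # the model must recover it from both sides
        else:
            masked.append(item)
            targets.append(PAD)   # no prediction at this position
    return masked, targets
```

Because attention is bidirectional, the model can use items both before and after a masked position to predict it, which is why the default training task does not combine with use_causal_attn=True.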
References
Transformers tutorial: https://rectools.readthedocs.io/en/stable/examples/tutorials/transformers_tutorial.html
Advanced training guide: https://rectools.readthedocs.io/en/stable/examples/tutorials/transformers_advanced_training_guide.html
Public benchmark: https://github.com/blondered/bert4rec_repro
Original BERT4Rec paper: https://arxiv.org/abs/1904.06690
gBCE loss paper: https://arxiv.org/pdf/2308.07192
- Parameters
n_blocks (int, default 2) – Number of transformer blocks.
n_heads (int, default 4) – Number of attention heads.
n_factors (int, default 256) – Latent embeddings size.
dropout_rate (float, default 0.2) – Probability of a hidden unit to be zeroed.
mask_prob (float, default 0.15) – Probability of masking each item in the interaction sequence.
session_max_len (int, default 100) – Maximum length of the user sequence.
train_min_user_interactions (int, default 2) – Minimum number of interactions a user must have to be used for training. Must be greater than 1.
loss ({"softmax", "BCE", "gBCE", "sampled_softmax"}, default "softmax") – Loss function.
n_negatives (int, default 1) – Number of negatives for BCE, gBCE and sampled_softmax losses.
gbce_t (float, default 0.2) – Calibration parameter for gBCE loss.
lr (float, default 0.001) – Learning rate.
batch_size (int, default 128) – How many samples per batch to load.
epochs (int, default 3) – Exact number of training epochs. Ignored if get_trainer_func is specified.
deterministic (bool, default False) – deterministic flag passed to the Lightning trainer during initialization. Use pytorch_lightning.seed_everything together with this parameter to fix the random seed. Ignored if get_trainer_func is specified.
verbose (int, default 0) – Verbosity level. When set to a positive integer, enables the progress bar, model summary and logging in the default Lightning trainer. Ignored if get_trainer_func is specified.
dataloader_num_workers (int, default 0) – Number of loader worker processes.
use_pos_emb (bool, default True) – If True, learnable positional encoding will be added to session item embeddings.
use_key_padding_mask (bool, default True) – If True, key_padding_mask will be added in Multi-head Attention.
use_causal_attn (bool, default False) – If True, a causal mask will be added as attn_mask in Multi-head Attention. Note that the default BERT4Rec training task (“MLM”) does not work with causal masking. Set this parameter to True only when you change the training task with a custom data_preparator_type or if you are absolutely sure of what you are doing.
item_net_block_types (sequence of type(ItemNetBase), default (IdEmbeddingsItemNet, CatFeaturesItemNet)) – Types of networks returning item embeddings. (IdEmbeddingsItemNet,): item embeddings based on ids. (CatFeaturesItemNet,): item embeddings based on categorical features. (IdEmbeddingsItemNet, CatFeaturesItemNet): item embeddings based on both ids and categorical features.
item_net_constructor_type (type(ItemNetConstructorBase), default SumOfEmbeddingsConstructor) – Type of item net blocks aggregation constructor.
pos_encoding_type (type(PositionalEncodingBase), default LearnableInversePositionalEncoding) – Type of positional encoding.
transformer_layers_type (type(TransformerLayersBase), default PreLNTransformerLayers) – Type of transformer layers architecture.
data_preparator_type (type(TransformerDataPreparatorBase), default BERT4RecDataPreparator) – Type of data preparator used for dataset processing and dataloader creation.
lightning_module_type (type(TransformerLightningModuleBase), default TransformerLightningModule) – Type of lightning module defining training procedure.
negative_sampler_type (type(TransformerNegativeSamplerBase), default CatalogUniformSampler) – Type of negative sampler.
similarity_module_type (type(SimilarityModuleBase), default DistanceSimilarityModule) – Type of similarity module.
backbone_type (type(TransformerBackboneBase), default TransformerTorchBackbone) – Type of torch backbone.
get_val_mask_func (Callable, default None) – Function to get validation mask.
get_trainer_func (Callable, default None) – Function returning a custom Lightning trainer. If get_trainer_func is None, a default trainer is created based on the epochs, deterministic and verbose argument values; the model is then trained for the exact number of epochs and checkpointing is disabled. To assign a custom trainer after the model is initialized, manually assign a new value to the model's _trainer attribute.
recommend_batch_size (int, default 256) – How many samples per batch to load during recommend. To change this parameter after the model is initialized, manually assign a new value to the model's recommend_batch_size attribute.
recommend_torch_device ({“cpu”, “cuda”, “cuda:0”, …}, default None) – String representation of the torch.device used for model inference. When set to None, “cuda” is used if it is available, “cpu” otherwise. To change this parameter after the model is initialized, manually assign a new value to the model's recommend_torch_device attribute.
get_val_mask_func_kwargs (optional(InitKwargs), default None) – Additional keyword arguments for get_val_mask_func. Make sure all dict values have JSON serializable types.
get_trainer_func_kwargs (optional(InitKwargs), default None) – Additional keyword arguments for get_trainer_func. Make sure all dict values have JSON serializable types.
data_preparator_kwargs (optional(dict), default None) – Additional keyword arguments to pass during data_preparator_type initialization. Make sure all dict values have JSON serializable types.
transformer_layers_kwargs (optional(dict), default None) – Additional keyword arguments to pass during transformer_layers_type initialization. Make sure all dict values have JSON serializable types.
item_net_constructor_kwargs (optional(dict), default None) – Additional keyword arguments to pass during item_net_constructor_type initialization. Make sure all dict values have JSON serializable types.
pos_encoding_kwargs (optional(dict), default None) – Additional keyword arguments to pass during pos_encoding_type initialization. Make sure all dict values have JSON serializable types.
lightning_module_kwargs (optional(dict), default None) – Additional keyword arguments to pass during lightning_module_type initialization. Make sure all dict values have JSON serializable types.
negative_sampler_kwargs (optional(dict), default None) – Additional keyword arguments to pass during negative_sampler_type initialization. Make sure all dict values have JSON serializable types.
similarity_module_kwargs (optional(dict), default None) – Additional keyword arguments to pass during similarity_module_type initialization. Make sure all dict values have JSON serializable types.
backbone_kwargs (optional(dict), default None) – Additional keyword arguments to pass during backbone_type initialization. Make sure all dict values have JSON serializable types.
recommend_use_torch_ranking (bool, default True) –
recommend_n_threads (int, default 0) –
item_net_block_kwargs (optional(dict), default None) – Additional keyword arguments to pass during item_net_block_types initialization. Make sure all dict values have JSON serializable types.
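The gbce_t parameter is the calibration parameter t from the gBCE loss paper listed in References. The sketch below follows that paper's formulation as we understand it: with n_negatives sampled negatives from a catalogue of n_items, the sampling rate is α = n_negatives / (n_items − 1) and the positive term is scaled by β = α(t(1 − 1/α) + 1/α). It is a hypothetical scalar illustration for one positive and its negatives, not the RecTools implementation; n_items and all helper names are assumptions.

```python
import math
from typing import Sequence


def log_sigmoid(x: float) -> float:
    # Numerically stable log(sigmoid(x)).
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def gbce_beta(n_negatives: int, n_items: int, gbce_t: float) -> float:
    alpha = n_negatives / (n_items - 1)  # negative sampling rate
    return alpha * (gbce_t * (1 - 1 / alpha) + 1 / alpha)


def bce_loss(pos_logit: float, neg_logits: Sequence[float]) -> float:
    # log(1 - sigmoid(s)) == log_sigmoid(-s)
    return -log_sigmoid(pos_logit) - sum(log_sigmoid(-s) for s in neg_logits)


def gbce_loss(pos_logit: float, neg_logits: Sequence[float], n_items: int, gbce_t: float = 0.2) -> float:
    beta = gbce_beta(len(neg_logits), n_items, gbce_t)
    # -log(sigmoid(s+) ** beta) == -beta * log(sigmoid(s+))
    pos_term = -beta * log_sigmoid(pos_logit)
    neg_term = -sum(log_sigmoid(-s) for s in neg_logits)
    return pos_term + neg_term
```

Algebraically β = 1 − t(1 − α), so gbce_t = 0 recovers plain BCE and gbce_t = 1 gives β = α; larger values down-weight the positive term, which the paper motivates as a correction for the overconfidence caused by negative sampling.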
Methods
dumps() – Serialize model to bytes.
fit(dataset, *args, **kwargs) – Fit model.
fit_partial(dataset, *args, **kwargs) – Fit model.
from_config(config) – Create model from config.
from_params(params[, sep]) – Create model from parameters.
get_config([mode, simple_types]) – Return model config.
get_params([simple_types, sep]) – Return model parameters.
load(f) – Load model from file.
load_from_checkpoint(checkpoint_path[, ...]) – Load model from Lightning checkpoint path.
load_weights_from_checkpoint(checkpoint_path) – Load model weights from Lightning checkpoint path.
loads(data) – Load model from bytes.
recommend(users, dataset, k, filter_viewed) – Recommend items for users.
recommend_to_items(target_items, dataset, k) – Recommend items for target items.
save(f) – Save model to file.
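Judging by their signatures, get_params and from_params work with a flat view of the model config in which nested keys are joined with sep. The stdlib sketch below illustrates such flattening and its inverse; flatten and unflatten are hypothetical helpers, the nested key some_option is made up, and RecTools' own implementation may handle more cases.

```python
from typing import Any, Dict


def flatten(config: Dict[str, Any], sep: str = ".", prefix: str = "") -> Dict[str, Any]:
    """Join nested dict keys with sep, e.g. {"a": {"b": 1}} -> {"a.b": 1}."""
    flat: Dict[str, Any] = {}
    for key, value in config.items():
        full_key = f"{prefix}{sep}{key}" if prefix else key
        if isinstance(value, dict) and value:
            flat.update(flatten(value, sep, full_key))  # recurse into non-empty dicts
        else:
            flat[full_key] = value  # leaves (and empty dicts) are kept as-is
    return flat


def unflatten(flat: Dict[str, Any], sep: str = ".") -> Dict[str, Any]:
    """Inverse of flatten, assuming no original key contains sep."""
    nested: Dict[str, Any] = {}
    for full_key, value in flat.items():
        *parents, leaf = full_key.split(sep)
        node = nested
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return nested
```

A flat representation is convenient for hyperparameter search tools that expect one value per key, while the nested form mirrors the *_kwargs arguments of the constructor.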
Attributes
recommends_for_cold
recommends_for_warm
require_recommend_context – Indicates whether recommendation context is required for predictions.
torch_model – PyTorch model.
train_loss_name
val_loss_name
config_class – alias of BERT4RecModelConfig