BERT4RecDataPreparator
- class rectools.models.nn.transformers.bert4rec.BERT4RecDataPreparator(session_max_len: int, n_negatives: Optional[int], batch_size: int, dataloader_num_workers: int, train_min_user_interactions: int, negative_sampler: Optional[TransformerNegativeSamplerBase] = None, mask_prob: float = 0.15, get_val_mask_func: Optional[Callable[[...], ndarray]] = None, shuffle_train: bool = True, get_val_mask_func_kwargs: Optional[Dict[str, Any]] = None, **kwargs: Any)
Bases: TransformerDataPreparatorBase
Data preparator for BERT4RecModel: processes a dataset and builds the dataloaders used for training, validation and recommendation.
- Parameters
session_max_len (int) – Maximum length of user sequence.
batch_size (int) – How many samples per batch to load.
dataloader_num_workers (int) – Number of loader worker processes.
shuffle_train (bool, default True) – If True, reshuffles the training data at each epoch.
train_min_user_interactions (int, default 2) – Minimum length of a user sequence; users with fewer interactions are excluded from training. Cannot be less than 2.
get_val_mask_func (Callable, default None) – Function to get validation mask.
n_negatives (optional(int), default None) – Number of negatives for the BCE, gBCE and sampled_softmax losses.
negative_sampler (optional(TransformerNegativeSamplerBase), default None) – Negative sampler used to draw the negatives.
mask_prob (float, default 0.15) – Probability of masking an item in the interactions sequence.
get_val_mask_func_kwargs (optional(InitKwargs), default None) – Additional keyword arguments for get_val_mask_func. Make sure all dict values have JSON-serializable types.
kwargs (Any) – Additional keyword arguments.
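To make the roles of session_max_len, train_min_user_interactions and mask_prob concrete, here is a minimal pure-Python sketch of how one BERT4Rec-style training sample can be built from a single user's item history. It is not the RecTools implementation: the padding id 0, the mask id 1, left padding, and the helper names build_session and mask_session are hypothetical choices made for illustration.

```python
import random
from typing import List, Optional, Tuple

PAD_ID = 0   # hypothetical padding id (not taken from RecTools)
MASK_ID = 1  # hypothetical mask-token id (not taken from RecTools)


def build_session(
    items: List[int], session_max_len: int, train_min_user_interactions: int
) -> Optional[List[int]]:
    """Hypothetical helper: keep the most recent items and left-pad to a fixed length."""
    if len(items) < train_min_user_interactions:
        return None  # sequence too short: user is dropped from training
    recent = items[-session_max_len:]  # truncate to the last session_max_len items
    return [PAD_ID] * (session_max_len - len(recent)) + recent


def mask_session(
    session: List[int], mask_prob: float, rng: random.Random
) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: replace each non-padding item with MASK_ID with probability mask_prob.

    Returns (inputs, targets). Targets hold the original item at masked
    positions and PAD_ID elsewhere, so only masked positions are scored.
    """
    inputs: List[int] = []
    targets: List[int] = []
    for item in session:
        if item != PAD_ID and rng.random() < mask_prob:
            inputs.append(MASK_ID)
            targets.append(item)
        else:
            inputs.append(item)
            targets.append(PAD_ID)
    return inputs, targets


rng = random.Random(42)
session = build_session([10, 11, 12, 13, 14, 15], session_max_len=4, train_min_user_interactions=2)
print(session)  # [12, 13, 14, 15]
inputs, targets = mask_session(session, mask_prob=0.15, rng=rng)
```

The model is trained to recover the targets at the masked positions from the surrounding context; raising mask_prob gives more training signal per sequence but less visible context.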
Methods
get_dataloader_recommend(dataset, batch_size) – Construct recommend dataloader from processed dataset.
get_dataloader_train() – Construct train dataloader from processed dataset.
get_dataloader_val() – Construct validation dataloader from processed dataset.
get_known_item_ids() – Return external item ids from processed dataset in sorted order.
get_known_items_sorted_internal_ids() – Return internal item ids from processed dataset in sorted order.
process_dataset_train(dataset) – Process train dataset and save data.
transform_dataset_i2i(dataset) – Process dataset for i2i recommendations.
transform_dataset_u2i(dataset, users[, context]) – Process dataset for u2i recommendations.
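The data served by get_dataloader_val() depends on the validation mask produced by get_val_mask_func, with extra arguments supplied through get_val_mask_func_kwargs. The sketch below shows one common choice of mask, leave-last-out for a limited number of users. It is an illustration under stated assumptions: the row layout (user_id, item_id, timestamp), the function name leave_last_out_mask and its keyword n_val_users are hypothetical, and the sketch returns a plain list of booleans whereas the real callable is typed as returning an ndarray.

```python
from typing import Any, Dict, List, Tuple

# Hypothetical row layout: (user_id, item_id, timestamp).
Interaction = Tuple[str, int, int]


def leave_last_out_mask(interactions: List[Interaction], n_val_users: int = 2) -> List[bool]:
    """Hypothetical validation-mask function.

    Flags the latest interaction of the first n_val_users users (in sorted
    user order) as validation (True); every other row stays in training (False).
    """
    last_row: Dict[str, int] = {}
    for idx, (user, _item, ts) in enumerate(interactions):
        # Remember the row index of each user's most recent interaction.
        if user not in last_row or ts >= interactions[last_row[user]][2]:
            last_row[user] = idx
    val_users = sorted(last_row)[:n_val_users]
    val_rows = {last_row[user] for user in val_users}
    return [idx in val_rows for idx in range(len(interactions))]


# Extra arguments kept in a plain dict, in the spirit of get_val_mask_func_kwargs
# (values are JSON serializable).
val_mask_kwargs: Dict[str, Any] = {"n_val_users": 1}

rows = [("u1", 10, 1), ("u1", 11, 2), ("u2", 12, 1), ("u2", 13, 3)]
print(leave_last_out_mask(rows, **val_mask_kwargs))  # [False, True, False, False]
```

Keeping the extra arguments in a JSON-serializable dict, rather than baking them into a closure, is what allows them to be stored alongside the rest of the configuration.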
Attributes
item_extra_tokens
n_item_extra_tokens – Return number of padding elements.
train_session_max_len_addition
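n_item_extra_tokens counts the special tokens reserved in the item vocabulary, which is why internal item ids and external item ids are exposed through separate methods. The sketch below assumes, without confirmation from this page, that the extra tokens occupy the first internal ids so that real items start at internal id n_item_extra_tokens. The token names "PAD" and "MASK" and both helper functions are hypothetical.

```python
from typing import Dict, List, Sequence, Tuple

# Hypothetical special tokens; the real item_extra_tokens may differ.
ITEM_EXTRA_TOKENS: Tuple[str, ...] = ("PAD", "MASK")


def build_item_id_map(
    external_ids: Sequence[str], extra_tokens: Sequence[str] = ITEM_EXTRA_TOKENS
) -> Dict[str, int]:
    """Hypothetical helper: extra tokens get ids first, then items in order of first appearance."""
    id_map: Dict[str, int] = {}
    for value in list(extra_tokens) + list(external_ids):
        if value not in id_map:
            id_map[value] = len(id_map)
    return id_map


def known_items_sorted_internal_ids(id_map: Dict[str, int], n_extra: int) -> List[int]:
    """Hypothetical helper: internal ids of real items only, skipping reserved token ids."""
    return sorted(i for i in id_map.values() if i >= n_extra)


id_map = build_item_id_map(["apple", "banana", "apple", "cherry"])
n_item_extra_tokens = len(ITEM_EXTRA_TOKENS)
print(n_item_extra_tokens)  # 2
print(known_items_sorted_internal_ids(id_map, n_item_extra_tokens))  # [2, 3, 4]
```

Under this layout, item scores for recommendation can be taken from the slice of internal ids starting at n_item_extra_tokens, so special tokens are never recommended.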