BERT4RecDataPreparator

class rectools.models.nn.transformers.bert4rec.BERT4RecDataPreparator(session_max_len: int, n_negatives: Optional[int], batch_size: int, dataloader_num_workers: int, train_min_user_interactions: int, negative_sampler: Optional[TransformerNegativeSamplerBase] = None, mask_prob: float = 0.15, get_val_mask_func: Optional[Callable[[...], ndarray]] = None, shuffle_train: bool = True, get_val_mask_func_kwargs: Optional[Dict[str, Any]] = None, **kwargs: Any)

Bases: TransformerDataPreparatorBase

Data preparator for BERT4RecModel. It processes datasets and builds the train, validation and recommendation dataloaders.

Parameters
  • session_max_len (int) – Maximum length of a user's interaction sequence.

  • batch_size (int) – How many samples per batch to load.

  • dataloader_num_workers (int) – Number of loader worker processes.

  • shuffle_train (bool, default True) – If True, reshuffles data at each epoch.

  • train_min_user_interactions (int, default 2) – Minimum length of a user's interaction sequence for the user to be used in training. Cannot be less than 2.

  • get_val_mask_func (Callable, default None) – Function that returns a mask marking which interactions belong to the validation set.

  • n_negatives (optional(int), default None) – Number of negatives for BCE, gBCE and sampled_softmax losses.

  • negative_sampler (optional(TransformerNegativeSamplerBase), default None) – Negative sampler.

  • mask_prob (float, default 0.15) – Probability of masking an item in the interaction sequence.

  • get_val_mask_func_kwargs (optional(InitKwargs), default None) – Additional keyword arguments for get_val_mask_func. Make sure all dict values have JSON-serializable types.

  • kwargs (Any) – Additional keyword arguments.
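Two of the parameters above, mask_prob and get_val_mask_func, are easier to picture in code. The following is a simplified, self-contained sketch in plain Python, not the RecTools implementation. PAD_ID, MASK_ID, mask_session and leave_last_out_mask are hypothetical names. Masked positions always receive the mask token here; the original BERT recipe sometimes keeps or randomly replaces the chosen token instead, and this page does not say whether this class does. The validation mask works on a list of (user, item, timestamp) tuples and returns a list of booleans, whereas the real get_val_mask_func is typed to return a NumPy ndarray and its input format is not specified here.

```python
import random
from typing import Dict, Hashable, List, Optional, Sequence, Tuple

# Hypothetical internal ids for the special tokens (assumption, not the
# library's actual values).
PAD_ID = 0
MASK_ID = 1


def mask_session(
    session: Sequence[int],
    mask_prob: float = 0.15,
    rng: Optional[random.Random] = None,
) -> Tuple[List[int], List[int]]:
    """Return (masked_input, targets) for one padded session of internal ids.

    Each non-padding position is replaced by MASK_ID with probability
    mask_prob, and the original item becomes that position's target.
    Unmasked positions get PAD_ID as target, i.e. they contribute no loss.
    """
    if rng is None:
        rng = random.Random()
    masked_input: List[int] = []
    targets: List[int] = []
    for item in session:
        if item != PAD_ID and rng.random() < mask_prob:
            masked_input.append(MASK_ID)
            targets.append(item)
        else:
            masked_input.append(item)
            targets.append(PAD_ID)
    return masked_input, targets


def leave_last_out_mask(rows: Sequence[Tuple[Hashable, Hashable, float]]) -> List[bool]:
    """Hypothetical validation mask: True for each user's latest interaction.

    rows are (user_id, item_id, timestamp) tuples; on timestamp ties the
    later row wins.
    """
    latest: Dict[Hashable, int] = {}
    for idx, (user, _item, ts) in enumerate(rows):
        if user not in latest or ts >= rows[latest[user]][2]:
            latest[user] = idx
    mask = [False] * len(rows)
    for idx in latest.values():
        mask[idx] = True
    return mask
```

For example, mask_session([0, 0, 5, 7, 9], mask_prob=0.15, rng=random.Random(42)) masks each of the three real items with probability 0.15 and never touches the two padding slots.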


Methods

get_dataloader_recommend(dataset, batch_size)

Construct recommend dataloader from processed dataset.

get_dataloader_train()

Construct train dataloader from processed dataset.

get_dataloader_val()

Construct validation dataloader from processed dataset.

get_known_item_ids()

Return external item ids from processed dataset in sorted order.

get_known_items_sorted_internal_ids()

Return internal item ids from processed dataset in sorted order.

process_dataset_train(dataset)

Process train dataset and save data.

transform_dataset_i2i(dataset)

Process dataset for i2i recommendations.

transform_dataset_u2i(dataset, users[, context])

Process dataset for u2i recommendations.
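The training-side processing behind process_dataset_train can be pictured as grouping interactions per user, dropping users whose history is shorter than train_min_user_interactions, and truncating long histories to session_max_len. The sketch below is an assumption-based illustration, not the library code: build_train_sessions is a hypothetical helper, rows are plain (user, item, timestamp) tuples, sessions are assumed to keep their most recent items, and left padding with a "PAD" placeholder is also an assumption.

```python
from collections import defaultdict
from typing import Dict, Hashable, List, Sequence, Tuple


def build_train_sessions(
    interactions: Sequence[Tuple[Hashable, Hashable, float]],
    session_max_len: int,
    train_min_user_interactions: int = 2,
    pad: Hashable = "PAD",
) -> Dict[Hashable, List[Hashable]]:
    """Hypothetical helper: group (user, item, timestamp) rows into sessions.

    Users with fewer than train_min_user_interactions rows are dropped,
    each session keeps only its last session_max_len items in time order,
    and shorter sessions are left-padded so all have the same length.
    """
    if train_min_user_interactions < 2:
        raise ValueError("train_min_user_interactions cannot be less than 2")
    by_user: Dict[Hashable, List[Tuple[float, Hashable]]] = defaultdict(list)
    for user, item, ts in interactions:
        by_user[user].append((ts, item))
    sessions: Dict[Hashable, List[Hashable]] = {}
    for user, events in by_user.items():
        if len(events) < train_min_user_interactions:
            continue  # too short to be used for training
        events.sort(key=lambda e: e[0])  # chronological order (stable)
        items = [item for _, item in events][-session_max_len:]
        sessions[user] = [pad] * (session_max_len - len(items)) + items
    return sessions
```

With session_max_len=3, a user who interacted with item "z" and then "y" becomes ["PAD", "z", "y"], while a user with a single interaction is dropped.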

Attributes

item_extra_tokens

n_item_extra_tokens

Return the number of item extra tokens (such as padding elements).

train_session_max_len_addition
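The extra-token attributes and the known-item accessors can be related roughly as follows. This sketch assumes that the extra tokens take the lowest internal ids, that n_item_extra_tokens equals len(item_extra_tokens), and that get_known_item_ids and get_known_items_sorted_internal_ids skip the extra tokens. The token values ("PAD", "MASK") and the helpers build_item_id_map, known_item_ids and known_items_internal_ids are hypothetical, not part of the documented API.

```python
from typing import Dict, Hashable, List, Sequence, Tuple

# Hypothetical extra tokens (assumption; the real item_extra_tokens
# values are not listed on this page).
ITEM_EXTRA_TOKENS: Tuple[str, ...] = ("PAD", "MASK")


def build_item_id_map(
    external_ids: Sequence[Hashable],
    item_extra_tokens: Sequence[Hashable] = ITEM_EXTRA_TOKENS,
) -> Dict[Hashable, int]:
    """Map ids to internal ints, reserving the lowest ids for extra tokens."""
    id_map: Dict[Hashable, int] = {}
    for key in list(item_extra_tokens) + list(external_ids):
        if key not in id_map:  # first occurrence fixes the internal id
            id_map[key] = len(id_map)
    return id_map


def known_item_ids(id_map: Dict[Hashable, int], n_item_extra_tokens: int) -> List[Hashable]:
    """External ids sorted by internal id, with the extra tokens skipped."""
    ordered = sorted(id_map, key=id_map.__getitem__)
    return ordered[n_item_extra_tokens:]


def known_items_internal_ids(id_map: Dict[Hashable, int], n_item_extra_tokens: int) -> List[int]:
    """Sorted internal ids of real items, with the extra tokens skipped."""
    return sorted(id_map.values())[n_item_extra_tokens:]
```

Reserving the lowest ids means every real item's internal id is at least n_item_extra_tokens, so in an embedding table indexed by internal id the first rows belong to the special tokens.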