SASRecDataPreparator
- class rectools.models.nn.transformers.sasrec.SASRecDataPreparator(session_max_len: int, batch_size: int, dataloader_num_workers: int, train_min_user_interactions: int = 2, get_val_mask_func: Optional[Callable] = None, shuffle_train: bool = True, n_negatives: Optional[int] = None, negative_sampler: Optional[TransformerNegativeSamplerBase] = None, get_val_mask_func_kwargs: Optional[Dict[str, Any]] = None, extra_cols: Optional[List[str]] = None, add_unix_ts: bool = False, **kwargs: Any)
Bases: TransformerDataPreparatorBase

Data preparator for SASRecModel.
- Parameters
session_max_len (int) – Maximum length of user sequence.
batch_size (int) – How many samples per batch to load.
dataloader_num_workers (int) – Number of loader worker processes.
item_extra_tokens (Sequence(Hashable)) – Extra tokens used for sequence padding.
shuffle_train (bool, default True) – If True, reshuffles data at each epoch.
train_min_user_interactions (int, default 2) – Minimum length of user sequence. Cannot be less than 2.
get_val_mask_func (Callable, default None) – Function to get validation mask.
n_negatives (optional(int), default None) – Number of negatives for BCE, gBCE and sampled_softmax losses.
negative_sampler (optional(TransformerNegativeSamplerBase), default None) – Negative sampler.
get_val_mask_func_kwargs (optional(InitKwargs), default None) – Additional arguments for get_val_mask_func. Make sure all dict values have JSON-serializable types.
extra_cols (list(str) | None, default None) – Additional columns from the dataset to keep besides Columns.Interactions.
add_unix_ts (bool, default False) – Add an extra column unix_ts containing Columns.Datetime converted to seconds since the Unix epoch.
kwargs (Any) – Additional keyword arguments.
Methods
get_dataloader_recommend(dataset, batch_size) – Construct recommend dataloader from processed dataset.
get_dataloader_train() – Construct train dataloader from processed dataset.
get_dataloader_val() – Construct validation dataloader from processed dataset.
get_known_item_ids() – Return external item ids from processed dataset in sorted order.
get_known_items_sorted_internal_ids() – Return internal item ids from processed dataset in sorted order.
process_dataset_train(dataset) – Process train dataset and save data.
transform_dataset_i2i(dataset) – Process dataset for i2i recommendations.
transform_dataset_u2i(dataset, users[, context]) – Process dataset for u2i recommendations.
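The get_val_mask_func and get_val_mask_func_kwargs parameters decide which interactions are held out for validation, that is, what get_dataloader_val() later serves. This page does not state the callable's exact input type, so in the hedged sketch below a list of (user, item, timestamp) tuples stands in for the interactions table, and the mask is a plain list of booleans. The function leave_last_one_out_mask and its val_users argument are hypothetical names chosen for illustration.

```python
import json
from typing import Dict, Hashable, List, Sequence, Tuple


def leave_last_one_out_mask(
    interactions: Sequence[Tuple[Hashable, Hashable, int]],
    val_users: Sequence[Hashable],
) -> List[bool]:
    """Mark the latest interaction of each user in val_users as validation."""
    targets = set(val_users)
    last_idx: Dict[Hashable, int] = {}
    for idx, (user, _item, ts) in enumerate(interactions):
        if user not in targets:
            continue
        best = last_idx.get(user)
        # Keep the row with the largest timestamp (later rows win ties).
        if best is None or ts >= interactions[best][2]:
            last_idx[user] = idx
    chosen = set(last_idx.values())
    return [idx in chosen for idx in range(len(interactions))]


rows = [("u1", "a", 1), ("u1", "b", 2), ("u2", "x", 5), ("u2", "y", 3)]
# Kwargs are kept JSON-serializable, as get_val_mask_func_kwargs requires;
# json.dumps would raise TypeError for a non-serializable value such as a set.
mask_kwargs = {"val_users": ["u2"]}
json.dumps(mask_kwargs)
print(leave_last_one_out_mask(rows, **mask_kwargs))  # [False, False, True, False]
```

Passing val_users as a list rather than a set is deliberate: the documentation asks for JSON-serializable values, and Python's json module cannot encode sets.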
Attributes
item_extra_tokens
n_item_extra_tokens – Return number of padding elements.
train_session_max_len_addition
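These attributes are easiest to read against the shape of a SASRec training example. SASRec is trained on next-item prediction, so an input sequence of length session_max_len is paired with the same sequence shifted by one position, which requires session_max_len + 1 items. Our reading, an assumption not confirmed on this page, is that train_session_max_len_addition holds that extra 1, and that n_item_extra_tokens equals len(item_extra_tokens), with padding occupying the lowest internal ids. The constants and make_train_example below are hypothetical.

```python
from typing import List, Sequence, Tuple

# Hypothetical values; the real class may define these differently.
ITEM_EXTRA_TOKENS: Tuple[str, ...] = ("PAD",)  # one padding token
N_ITEM_EXTRA_TOKENS = len(ITEM_EXTRA_TOKENS)  # real items start at this internal id
PAD_ID = 0  # assumed internal id of the padding token
TRAIN_SESSION_MAX_LEN_ADDITION = 1  # assumed: one extra item for the shifted target


def make_train_example(
    session: Sequence[int], session_max_len: int
) -> Tuple[List[int], List[int]]:
    """Turn a session of internal item ids into (inputs, targets) for next-item prediction."""
    # Keep session_max_len + 1 most recent items so the shifted pair has full length.
    window = list(session)[-(session_max_len + TRAIN_SESSION_MAX_LEN_ADDITION):]
    # Inputs are every item but the last; targets are every item but the first.
    inputs, targets = window[:-1], window[1:]
    # Left-pad both sequences to session_max_len with the padding id.
    pad = session_max_len - len(inputs)
    return [PAD_ID] * pad + inputs, [PAD_ID] * pad + targets


print(make_train_example([1, 2, 3, 4, 5], session_max_len=3))  # ([2, 3, 4], [3, 4, 5])
print(make_train_example([1, 2], session_max_len=3))  # ([0, 0, 1], [0, 0, 2])
```

With only one padding token, real items receive internal ids starting at 1, so the padding id never collides with an item.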