DSSMModel

class rectools.models.dssm.DSSMModel(train_dataset_type: ~typing.Type[~rectools.dataset.torch_datasets.DSSMTrainDatasetBase] = <class 'rectools.dataset.torch_datasets.DSSMTrainDataset'>, user_dataset_type: ~typing.Type[~rectools.dataset.torch_datasets.DSSMUserDatasetBase] = <class 'rectools.dataset.torch_datasets.DSSMUserDataset'>, item_dataset_type: ~typing.Type[~rectools.dataset.torch_datasets.DSSMItemDatasetBase] = <class 'rectools.dataset.torch_datasets.DSSMItemDataset'>, model: ~typing.Optional[~rectools.models.dssm.DSSM] = None, n_factors: int = 128, max_epochs: int = 5, batch_size: int = 128, dataloader_num_workers: int = 0, trainer_sanity_steps: int = 2, trainer_devices: ~typing.Union[str, int] = 1, trainer_accelerator: str = 'auto', callbacks: ~typing.Optional[~typing.Union[~typing.List[~pytorch_lightning.callbacks.callback.Callback], ~pytorch_lightning.callbacks.callback.Callback]] = None, loggers: ~typing.Union[~pytorch_lightning.loggers.logger.Logger, ~typing.Iterable[~pytorch_lightning.loggers.logger.Logger], bool] = True, verbose: int = 0, deterministic: bool = False)

Bases: VectorModel

Wrapper for rectools.models.dssm.DSSM that exposes it through the VectorModel interface (fit, recommend, recommend_to_items).

Parameters
  • train_dataset_type (Type(DSSMTrainDatasetBase), default DSSMTrainDataset) – Type of dataset used for training: a subclass of torch.utils.data.Dataset that implements the from_dataset classmethod, which constructs a torch.utils.data.Dataset from a given rectools.dataset.dataset.Dataset.

  • user_dataset_type (Type(DSSMUserDatasetBase), default DSSMUserDataset) – Type of dataset used for user inference: a subclass of torch.utils.data.Dataset that implements the from_dataset classmethod, which constructs a torch.utils.data.Dataset from a given rectools.dataset.dataset.Dataset.

  • item_dataset_type (Type(DSSMItemDatasetBase), default DSSMItemDataset) – Type of dataset used for item inference: a subclass of torch.utils.data.Dataset that implements the from_dataset classmethod, which constructs a torch.utils.data.Dataset from a given rectools.dataset.dataset.Dataset.

  • model (Optional(DSSM), default None) – Model to wrap. If None, a default DSSM instance is created during fit.

  • n_factors (int, default 128) – How many hidden units to use in user and item networks. Used only if model is None.

  • max_epochs (int, default 5) – Maximum number of training epochs. If an early-stopping callback is passed in callbacks together with a validation dataset, training may stop before max_epochs is reached.

  • batch_size (int, default 128) – How many samples per batch to load.

  • dataloader_num_workers (int, default 0) – Number of subprocesses to use for data loading. With 0, all data is loaded in the main process.

  • trainer_sanity_steps (int, default 2) – Number of validation batches to run as a sanity check before the training routine starts.

  • trainer_devices (str | int, default 1) – Devices to train on. “auto” determines the number of available devices from the trainer_accelerator type. An integer is mapped to gpus, tpu_cores, num_processes or ipus, depending on the accelerator type.

  • trainer_accelerator (str, default 'auto') – Accelerator type: “cpu”, “gpu”, “tpu”, “ipu” or “auto”. The “auto” option detects the hardware of the current machine and selects the matching accelerator.

  • callbacks (Optional(Callback | list(Callback)), default None) – Callbacks to pass to the trainer, for example pytorch_lightning.callbacks.TQDMProgressBar.

  • loggers (Logger | iterable(Logger) | bool, default True) – Loggers to use, for example pytorch_lightning.loggers.TensorBoardLogger. True enables the default PyTorch Lightning logger and False disables logging.

  • verbose (int, default 0) – Verbosity level. Applies only to the recommend loop.

  • deterministic (bool, default False) – If True, PyTorch operations are forced to use deterministic algorithms. Use pytorch_lightning.seed_everything together with this parameter to fix the random state.
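The three *_dataset_type parameters share one pattern: a dataset class with a from_dataset classmethod that builds a map-style dataset from the source data. The sketch below shows that pattern in a dependency-free form. It is not the rectools implementation: the class name ToyTripletDataset, the raw (user, item) input and the (user, positive item, negative item) samples are assumptions for illustration. A real train_dataset_type would subclass torch.utils.data.Dataset and accept a rectools.dataset.dataset.Dataset.

```python
import random
from typing import Dict, List, Sequence, Set, Tuple


class ToyTripletDataset:
    """Hypothetical stand-in for a train_dataset_type.

    A real one would subclass torch.utils.data.Dataset and build itself from a
    rectools Dataset; this one takes raw (user, item) pairs instead.
    """

    def __init__(self, triplets: List[Tuple[int, int, int]]) -> None:
        self.triplets = triplets

    @classmethod
    def from_dataset(cls, interactions: Sequence[Tuple[int, int]], n_items: int,
                     seed: int = 0) -> "ToyTripletDataset":
        # Collect the items each user has interacted with.
        seen: Dict[int, Set[int]] = {}
        for user, item in interactions:
            seen.setdefault(user, set()).add(item)
        rng = random.Random(seed)
        triplets = []
        for user, item in interactions:
            # Negative sampling: pick an item this user has not interacted with.
            negatives = [i for i in range(n_items) if i not in seen[user]]
            if negatives:
                triplets.append((user, item, rng.choice(negatives)))
        return cls(triplets)

    # __len__ and __getitem__ form the map-style protocol a DataLoader relies on.
    def __len__(self) -> int:
        return len(self.triplets)

    def __getitem__(self, idx: int) -> Tuple[int, int, int]:
        return self.triplets[idx]


train_ds = ToyTripletDataset.from_dataset([(0, 0), (0, 1), (1, 2)], n_items=4)
print(len(train_ds))  # 3
```

Keeping construction in a classmethod lets the wrapper receive only a type and build the dataset itself, which is why the parameters take Type(...) rather than instances.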
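To make the interplay of batch_size, max_epochs and early stopping concrete, here is a plain-Python sketch of an epoch loop with patience-based early stopping. It only mimics what a PyTorch Lightning early-stopping callback does; the helper run_training, the patience argument and the precomputed validation losses are hypothetical. It assumes the last, smaller batch of each epoch is kept, so an epoch has ceil(n_samples / batch_size) batches.

```python
import math
from typing import List, Tuple


def run_training(val_losses: List[float], n_samples: int, batch_size: int = 128,
                 max_epochs: int = 5, patience: int = 1) -> Tuple[int, int]:
    """Hypothetical helper: return (epochs run, optimisation steps taken).

    val_losses[e] stands in for the validation loss measured after epoch e.
    """
    # The last, possibly smaller, batch is kept, hence the ceiling.
    batches_per_epoch = math.ceil(n_samples / batch_size)
    best_loss = math.inf
    epochs_without_improvement = 0
    epochs_run = 0
    steps = 0
    for epoch in range(max_epochs):
        steps += batches_per_epoch  # one optimisation step per batch
        epochs_run += 1
        if val_losses[epoch] < best_loss:
            best_loss = val_losses[epoch]
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                break  # early stopping: training ends before max_epochs


    return epochs_run, steps


losses = [0.9, 0.8, 0.85, 0.7, 0.6]
print(run_training(losses, n_samples=1000))               # (3, 24)
print(run_training(losses, n_samples=1000, patience=10))  # (5, 40)
```

With patience=1 the loss rise after the second epoch stops training at epoch 3, which is exactly why hitting max_epochs is not guaranteed once an early-stopping callback is in play.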


Methods

fit(dataset, *args, **kwargs)

Fit model.

get_vectors(dataset)

recommend(users, dataset, k, filter_viewed)

Recommend items for users.
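The idea behind recommend with filter_viewed can be sketched on plain Python lists: score every item by its dot product with the user vector, drop items the user has already viewed, and keep the top k. This is a hedged illustration, not rectools internals (which work on whole matrices): dot-product scoring is an assumption here (the distance DSSMModel actually uses is given by u2i_dist), and recommend_dot is a hypothetical helper.

```python
from typing import AbstractSet, Dict, List, Sequence, Tuple


def recommend_dot(user_vec: Sequence[float],
                  item_vecs: Dict[int, Sequence[float]],
                  k: int,
                  viewed: AbstractSet[int] = frozenset(),
                  filter_viewed: bool = True) -> List[Tuple[int, float]]:
    """Hypothetical helper: top-k items for one user by dot-product score."""
    scored = []
    for item_id, item_vec in item_vecs.items():
        if filter_viewed and item_id in viewed:
            continue  # skip items the user already interacted with
        score = sum(u * v for u, v in zip(user_vec, item_vec))
        scored.append((item_id, score))
    # Highest score first; keep only the first k.
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:k]


items = {10: [0.9, 0.1], 11: [0.2, 0.8], 12: [0.5, 0.5]}
print(recommend_dot([1.0, 0.0], items, k=2, viewed={10}))
# [(12, 0.5), (11, 0.2)]
print(recommend_dot([1.0, 0.0], items, k=2, viewed={10}, filter_viewed=False))
# [(10, 0.9), (12, 0.5)]
```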

recommend_to_items(target_items, dataset, k)

Recommend items for target items.
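Item-to-item recommendation follows the same shape but compares item vectors with each other. The sketch below ranks items by cosine similarity to a target item and skips the target itself; both choices are assumptions for illustration (the distance rectools actually uses is given by i2i_dist), and cosine and similar_items are hypothetical helpers.

```python
import math
from typing import Dict, List, Sequence, Tuple


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity; 0.0 if either vector is all zeros."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)


def similar_items(target: int, item_vecs: Dict[int, Sequence[float]],
                  k: int) -> List[Tuple[int, float]]:
    """Hypothetical helper: top-k items most similar to target (target excluded)."""
    target_vec = item_vecs[target]
    scored = [(item_id, cosine(target_vec, vec))
              for item_id, vec in item_vecs.items() if item_id != target]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:k]


items = {1: [1.0, 0.0], 2: [2.0, 0.1], 3: [0.0, 1.0], 4: [1.0, 1.0]}
print([item_id for item_id, _ in similar_items(1, items, k=2)])  # [2, 4]
```

Cosine similarity ignores vector length, so item 2 ranks first even though its vector is twice as long as the target's; a dot product would instead favour items with large norms.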

Attributes

i2i_dist

n_threads

recommends_for_cold

recommends_for_warm

u2i_dist