DSSMModel
- class rectools.models.dssm.DSSMModel(train_dataset_type: Type[rectools.dataset.torch_datasets.DSSMTrainDatasetBase] = rectools.dataset.torch_datasets.DSSMTrainDataset, user_dataset_type: Type[rectools.dataset.torch_datasets.DSSMUserDatasetBase] = rectools.dataset.torch_datasets.DSSMUserDataset, item_dataset_type: Type[rectools.dataset.torch_datasets.DSSMItemDatasetBase] = rectools.dataset.torch_datasets.DSSMItemDataset, model: Optional[rectools.models.dssm.DSSM] = None, n_factors: int = 128, max_epochs: int = 5, batch_size: int = 128, dataloader_num_workers: int = 0, trainer_sanity_steps: int = 2, trainer_devices: Union[str, int] = 1, trainer_accelerator: str = 'auto', callbacks: Optional[Union[List[pytorch_lightning.callbacks.callback.Callback], pytorch_lightning.callbacks.callback.Callback]] = None, loggers: Union[pytorch_lightning.loggers.logger.Logger, Iterable[pytorch_lightning.loggers.logger.Logger], bool] = True, verbose: int = 0, deterministic: bool = False)
Bases: VectorModel
Wrapper for rectools.models.dssm.DSSM
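Because DSSMModel derives from VectorModel, its recommendations come from comparing user vectors with item vectors produced by the wrapped DSSM's user and item networks. The NumPy sketch below illustrates only that scoring step, under stated assumptions: the two networks are replaced by hypothetical fixed random linear layers with a tanh activation (untrained, and not the rectools architecture), the features are random, and relevance is assumed to be a plain dot product.

```python
import numpy as np

rng = np.random.default_rng(0)

n_users, n_items = 4, 6
user_feat_dim, item_feat_dim = 5, 3
n_factors = 8  # plays the role of the n_factors parameter

# Hypothetical feature matrices: one row per user / item.
user_features = rng.normal(size=(n_users, user_feat_dim))
item_features = rng.normal(size=(n_items, item_feat_dim))

# Stand-ins for the trained towers (assumption): one linear layer + tanh each.
w_user = rng.normal(size=(user_feat_dim, n_factors))
w_item = rng.normal(size=(item_feat_dim, n_factors))


def user_tower(x: np.ndarray) -> np.ndarray:
    """Map user features to n_factors-dimensional user vectors."""
    return np.tanh(x @ w_user)


def item_tower(x: np.ndarray) -> np.ndarray:
    """Map item features to n_factors-dimensional item vectors."""
    return np.tanh(x @ w_item)


user_vectors = user_tower(user_features)  # shape (n_users, n_factors)
item_vectors = item_tower(item_features)  # shape (n_items, n_factors)

# Dot-product relevance of every item for every user (assumed distance).
scores = user_vectors @ item_vectors.T  # shape (n_users, n_items)
print(scores.shape)  # (4, 6)
```

Both towers must output vectors of the same size, which is why a single n_factors value sizes the user and item networks.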
- Parameters
train_dataset_type (Type(DSSMTrainDatasetBase), default DSSMTrainDataset) – Type of dataset used for training. A child of torch.utils.data.Dataset that implements the from_dataset classmethod, which is used to construct a torch.utils.data.Dataset from a given rectools.dataset.dataset.Dataset.
user_dataset_type (Type(DSSMUserDatasetBase), default DSSMUserDataset) – Type of dataset used for user inference. A child of torch.utils.data.Dataset that implements the from_dataset classmethod, which is used to construct a torch.utils.data.Dataset from a given rectools.dataset.dataset.Dataset.
item_dataset_type (Type(DSSMItemDatasetBase), default DSSMItemDataset) – Type of dataset used for item inference. A child of torch.utils.data.Dataset that implements the from_dataset classmethod, which is used to construct a torch.utils.data.Dataset from a given rectools.dataset.dataset.Dataset.
model (Optional(DSSM), default None) – Which model to wrap. If model is None, a default DSSM instance is created during fit.
n_factors (int, default 128) – How many hidden units to use in user and item networks. Used only if model is None.
max_epochs (int, default 5) – Stop training if this number of epochs is reached. Keep in mind that if any kind of early stopping callback is passed as one of the callbacks along with a validation dataset, then hitting exactly max_epochs is not guaranteed.
batch_size (int, default 128) – How many samples per batch to load.
dataloader_num_workers (int, default 0) – How many processes to use for data loading. The default, 0, means that all data is loaded in the main process.
trainer_sanity_steps (int, default 2) – Number of validation batches to run as a sanity check before the training routine starts.
trainer_devices (str | int, default 1) – “auto” means that the number of available devices is determined from the trainer_accelerator type. An integer is mapped to gpus, tpu_cores, num_processes or ipus, depending on the accelerator type.
trainer_accelerator (str, default 'auto') – Supports different accelerator types (“cpu”, “gpu”, “tpu”, “ipu”, “auto”). The “auto” option detects the machine you are on and selects the corresponding accelerator.
callbacks (Optional(Callback | list(Callback)), default None) – Which callbacks to use, for instance pytorch_lightning.callbacks.TQDMProgressBar.
loggers (Logger | iterable(Logger) | bool, default True) – Which loggers to use, for instance pytorch_lightning.loggers.TensorBoardLogger.
verbose (int, default 0) – Verbosity level (applies only to the recommend loop).
deterministic (bool, default False) – If True, PyTorch operations are forced to use deterministic algorithms. Use pytorch_lightning.seed_everything together with this parameter to fix the random state.
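The trainer_devices description can be read as a lookup from the accelerator type to the legacy Trainer argument that an integer value stands for. The sketch below encodes that reading only. resolve_devices, LEGACY_ARG and the available-device counts are hypothetical; the sketch assumes trainer_accelerator has already been resolved from “auto” to a concrete type, and Lightning's own resolution logic is more involved.

```python
from typing import Dict, Union

# Hypothetical table: which legacy Trainer argument an integer
# `devices` value maps to for each accelerator type (as described above).
LEGACY_ARG = {
    "gpu": "gpus",
    "tpu": "tpu_cores",
    "cpu": "num_processes",
    "ipu": "ipus",
}


def resolve_devices(
    devices: Union[str, int],
    accelerator: str,
    available: Dict[str, int],
) -> Dict[str, int]:
    """Illustrative only: turn (devices, accelerator) into a legacy kwarg."""
    if accelerator not in LEGACY_ARG:
        raise ValueError(f"unsupported accelerator: {accelerator!r}")
    if devices == "auto":
        # "auto": use every detected device of this accelerator type.
        count = available.get(accelerator, 0)
    elif isinstance(devices, int):
        count = devices
    else:
        raise ValueError(f"unsupported devices value: {devices!r}")
    return {LEGACY_ARG[accelerator]: count}


print(resolve_devices(1, "gpu", {"gpu": 2}))       # {'gpus': 1}
print(resolve_devices("auto", "gpu", {"gpu": 2}))  # {'gpus': 2}
print(resolve_devices(4, "cpu", {}))               # {'num_processes': 4}
```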
Methods
fit(dataset, *args, **kwargs) – Fit model.
get_vectors(dataset)
recommend(users, dataset, k, filter_viewed) – Recommend items for users.
recommend_to_items(target_items, dataset, k) – Recommend items for target items.
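For a vector model, recommend(users, dataset, k, filter_viewed) amounts to scoring every item against each user vector, optionally hiding items the user has already interacted with, and keeping the k best. The NumPy sketch below illustrates that logic on plain arrays. The function name recommend_top_k and the boolean viewed matrix are hypothetical, and dot-product scoring is an assumption; the real method works with a rectools Dataset and external ids rather than raw arrays.

```python
import numpy as np


def recommend_top_k(
    user_vectors: np.ndarray,
    item_vectors: np.ndarray,
    viewed: np.ndarray,
    k: int,
    filter_viewed: bool = True,
):
    """Return (item_indices, scores), each of shape (n_users, k), best first."""
    scores = user_vectors @ item_vectors.T  # assumed dot-product relevance
    if filter_viewed:
        # Push already-seen items to the bottom. They keep a score of -inf,
        # so drop such entries if k exceeds the number of unseen items.
        scores = np.where(viewed, -np.inf, scores)
    # Sort each row by descending score and keep the first k columns.
    order = np.argsort(-scores, axis=1, kind="stable")[:, :k]
    top_scores = np.take_along_axis(scores, order, axis=1)
    return order, top_scores


user_vectors = np.array([[1.0, 0.0], [0.0, 1.0]])
item_vectors = np.array([[3.0, 0.0], [2.0, 1.0], [0.0, 5.0]])
viewed = np.array([[True, False, False], [False, False, False]])

items, item_scores = recommend_top_k(user_vectors, item_vectors, viewed, k=2)
print(items.tolist())        # [[1, 2], [2, 1]]
print(item_scores.tolist())  # [[2.0, 0.0], [5.0, 1.0]]
```

With filter_viewed=False, user 0 would get item 0 first (score 3.0), since nothing is masked.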
Attributes
i2i_dist
n_threads
recommends_for_cold
recommends_for_warm
u2i_dist
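The u2i_dist and i2i_dist attributes name the distances used when scoring users against items and items against items. The sketch below illustrates item-to-item lookup in the spirit of recommend_to_items(target_items, dataset, k). It assumes cosine similarity on the item vectors and that a target item is excluded from its own list; both choices, and the helper names cosine_similarity and similar_items, are assumptions for illustration rather than details taken from this page.

```python
import numpy as np


def cosine_similarity(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Pairwise cosine similarity between rows of a and rows of b.

    Zero-length rows would yield NaN and need separate handling.
    """
    a_norm = a / np.linalg.norm(a, axis=1, keepdims=True)
    b_norm = b / np.linalg.norm(b, axis=1, keepdims=True)
    return a_norm @ b_norm.T


def similar_items(item_vectors: np.ndarray, target: int, k: int) -> np.ndarray:
    """Indices of the k items most similar to `target`, excluding itself."""
    sims = cosine_similarity(item_vectors[target : target + 1], item_vectors)[0]
    sims[target] = -np.inf  # never return the target item for itself
    return np.argsort(-sims, kind="stable")[:k]


item_vectors = np.array([
    [1.0, 0.0],
    [10.0, 1.0],  # almost the same direction as item 0, but much longer
    [0.0, 1.0],
    [-1.0, 0.0],
])
print(similar_items(item_vectors, target=0, k=2))  # [1 2]
```

Because both vectors are normalized, item 1 ranks first for item 0 purely because it points in nearly the same direction; its larger length plays no part in the score.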