DSSMModel
- class rectools.models.dssm.DSSMModel(dataset_type: Dataset[Any], model: Optional[DSSM] = None, max_epochs: int = 5, batch_size: int = 128, dataloader_num_workers: int = 0, trainer_sanity_steps: int = 2, trainer_devices: Union[str, int] = 1, trainer_accelerator: str = 'auto', callbacks: Optional[Sequence[Callback]] = None, loggers: Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool] = True, verbose: int = 0)
Bases: VectorModel
Wrapper for rectools.models.dssm.DSSM.
- Parameters
dataset_type (torch.utils.data.Dataset) – A subclass of torch.utils.data.Dataset that implements a from_dataset classmethod. It is used to construct a torch.utils.data.Dataset from a given rectools.dataset.dataset.Dataset.
model (Optional(DSSM), default None) – The DSSM instance to wrap. If model is None, a DSSM with default parameters is created during fit.
max_epochs (int, default 5) – Training stops once this number of epochs is reached. If an early stopping callback is passed in callbacks along with a validation dataset, training may end before max_epochs is reached.
batch_size (int, default 128) – How many samples per batch to load.
dataloader_num_workers (int, default 0) – How many subprocesses to use for data loading. 0 means that all data is loaded in the main process.
trainer_sanity_steps (int, default 2) – Number of validation batches to run as a sanity check before the training routine starts.
trainer_devices (str | int, default 1) – “auto” determines the number of available devices based on the trainer_accelerator type. In case of an integer, it is mapped to either gpus, tpu_cores, num_processes or ipus, depending on the accelerator type.
trainer_accelerator (str, default 'auto') – Supports different accelerator types (“cpu”, “gpu”, “tpu”, “ipu”, “auto”). The “auto” option detects the machine you are on and selects the matching accelerator.
callbacks (Optional(Sequence(Callback)), default None) – Which callbacks to use, for instance pytorch_lightning.callbacks.TQDMProgressBar.
loggers (LightningLoggerBase | iterable(LightningLoggerBase) | bool, default True) – Which loggers to use, for instance pytorch_lightning.loggers.TensorBoardLogger.
verbose (int, default 0) – Verbosity level (applies only to the recommend loop).
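Because dataset_type is passed as a class rather than an instance, the wrapper can presumably build its training dataset itself from whatever rectools Dataset it is given. The following is a minimal sketch of that contract, not rectools code: PairsDataset and its from_dataset argument are hypothetical, and a plain list of interaction dicts stands in for rectools.dataset.dataset.Dataset. A real DSSM training dataset would also have to yield the user and item features the network consumes.

```python
from typing import Any, Dict, List, Sequence, Tuple

try:
    from torch.utils.data import Dataset as TorchDataset
except ImportError:  # fall back to a plain base class so the sketch runs without torch
    TorchDataset = object  # type: ignore[misc, assignment]


class PairsDataset(TorchDataset):
    """Hypothetical map-style dataset of (user_id, item_id) training pairs."""

    def __init__(self, pairs: Sequence[Tuple[Any, Any]]) -> None:
        self.pairs: List[Tuple[Any, Any]] = list(pairs)

    @classmethod
    def from_dataset(cls, interactions: Sequence[Dict[str, Any]]) -> "PairsDataset":
        # Assumption: a list of interaction dicts stands in for a rectools Dataset.
        return cls([(row["user_id"], row["item_id"]) for row in interactions])

    def __len__(self) -> int:
        return len(self.pairs)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        return self.pairs[idx]


# The factory classmethod builds the dataset; no constructor call is needed by the caller.
train_ds = PairsDataset.from_dataset(
    [{"user_id": 1, "item_id": 10}, {"user_id": 2, "item_id": 20}]
)
print(len(train_ds), train_ds[1])  # 2 (2, 20)
```

Note that the class itself, not an instance such as train_ds, is what goes into the dataset_type argument.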
Methods
fit(dataset, *args, **kwargs) – Fit model.
get_vectors(dataset)
recommend(users, dataset, k, filter_viewed) – Recommend items for users.
recommend_to_items(target_items, dataset, k) – Recommend items for target items.
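At its core, recommend for a vector model like this one scores items by comparing user and item embeddings. The plain-Python sketch below illustrates that idea under stated assumptions: embeddings are already computed and keyed by integer ids, scoring is a dot product, and filter_viewed drops items the user has already interacted with. recommend_top_k and its arguments are hypothetical; the real method works on a rectools Dataset instead.

```python
from typing import Dict, List, Sequence, Set, Tuple


def recommend_top_k(
    user_vectors: Dict[int, Sequence[float]],
    item_vectors: Dict[int, Sequence[float]],
    viewed: Dict[int, Set[int]],
    k: int,
    filter_viewed: bool,
) -> Dict[int, List[int]]:
    """Hypothetical helper: score every item for every user by dot product, keep the k best."""
    recos: Dict[int, List[int]] = {}
    for user, u_vec in user_vectors.items():
        scored: List[Tuple[float, int]] = []
        for item, i_vec in item_vectors.items():
            if filter_viewed and item in viewed.get(user, set()):
                continue  # skip items the user has already interacted with
            score = sum(u * i for u, i in zip(u_vec, i_vec))
            scored.append((score, item))
        # Highest score first; ties broken by smaller item id for determinism.
        scored.sort(key=lambda pair: (-pair[0], pair[1]))
        recos[user] = [item for _, item in scored[:k]]
    return recos


users = {1: [1.0, 0.0], 2: [0.0, 1.0]}
items = {10: [0.9, 0.1], 20: [0.2, 0.8], 30: [0.5, 0.5]}
viewed = {1: {10}}
print(recommend_top_k(users, items, viewed, k=2, filter_viewed=True))
# {1: [30, 20], 2: [20, 30]}
```

With filter_viewed=False, user 1 would instead get [10, 30], since item 10 has the highest score for that user.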
Attributes
i2i_dist
u2i_dist
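The i2i_dist and u2i_dist attributes name the distance used for item-to-item and user-to-item scoring respectively. As a hedged illustration of why that choice matters, assuming the candidates include dot product and cosine similarity, the sketch below (not rectools code; similar_items is hypothetical) ranks the neighbours of one item both ways: dot product rewards long vectors, while cosine compares only direction.

```python
import math
from typing import Callable, Dict, List, Sequence, Tuple

Vector = Sequence[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def cosine(a: Vector, b: Vector) -> float:
    norm = math.sqrt(dot(a, a)) * math.sqrt(dot(b, b))
    return dot(a, b) / norm if norm else 0.0  # guard against zero-length vectors


def similar_items(
    target: int,
    item_vectors: Dict[int, Vector],
    k: int,
    similarity: Callable[[Vector, Vector], float],
) -> List[int]:
    """Hypothetical helper: rank all other items by similarity to `target`, keep the top k."""
    t_vec = item_vectors[target]
    scored: List[Tuple[float, int]] = [
        (similarity(t_vec, vec), item)
        for item, vec in item_vectors.items()
        if item != target  # never recommend the target item to itself
    ]
    scored.sort(key=lambda pair: (-pair[0], pair[1]))
    return [item for _, item in scored[:k]]


items = {1: [1.0, 0.0], 2: [3.0, 3.0], 3: [0.9, 0.1]}
print(similar_items(1, items, k=2, similarity=dot))     # [2, 3]
print(similar_items(1, items, k=2, similarity=cosine))  # [3, 2]
```

Item 2 wins under dot product only because its vector is long; item 3 points in almost the same direction as item 1, so cosine ranks it first.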