DSSMModel

class rectools.models.dssm.DSSMModel(dataset_type: Dataset[Any], model: Optional[DSSM] = None, max_epochs: int = 5, batch_size: int = 128, dataloader_num_workers: int = 0, trainer_sanity_steps: int = 2, trainer_devices: Union[str, int] = 1, trainer_accelerator: str = 'auto', callbacks: Optional[Sequence[Callback]] = None, loggers: Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool] = True, verbose: int = 0)

Bases: VectorModel

Wrapper for rectools.models.dssm.DSSM

Parameters
  • dataset_type (torch.utils.data.Dataset) – A subclass of torch.utils.data.Dataset that implements a from_dataset classmethod. It is used to construct a torch.utils.data.Dataset from a given rectools.dataset.dataset.Dataset.

  • model (Optional(DSSM), default None) – The model to wrap. If None, a DSSM instance with default parameters is created during fit.

  • max_epochs (int, default 5) – Maximum number of training epochs. If an early-stopping callback is passed in callbacks together with a validation dataset, training may stop before max_epochs is reached.

  • batch_size (int, default 128) – How many samples per batch to load.

  • dataloader_num_workers (int, default 0) – Number of worker processes used for data loading. With 0, all data is loaded in the main process.

  • trainer_sanity_steps (int, default 2) – Number of validation batches run as a sanity check before the training routine starts.

  • trainer_devices (str | int, default 1) – “auto” determines the number of available devices from the trainer_accelerator type. An integer is mapped to gpus, tpu_cores, num_processes or ipus, depending on the accelerator type.

  • trainer_accelerator (str, default 'auto') – Accelerator type: “cpu”, “gpu”, “tpu”, “ipu” or “auto”. With “auto”, the accelerator is selected automatically based on the machine you are running on.

  • callbacks (Optional(Sequence(Callback)), default None) – Callbacks to use, for instance pytorch_lightning.callbacks.TQDMProgressBar.

  • loggers (LightningLoggerBase | iterable(LightningLoggerBase) | bool, default True) – Loggers to use, for instance pytorch_lightning.loggers.TensorBoardLogger.

  • verbose (int, default 0) – Verbosity level (applies only to the recommend loop).
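The dataset_type contract can be illustrated with a minimal sketch. Everything in it is hypothetical: InteractionPairs stands in for a real torch.utils.data.Dataset subclass, and a plain list of (user_id, item_id) pairs stands in for a rectools.dataset.dataset.Dataset. It only shows the pattern the parameter describes, a from_dataset classmethod that builds a map-style dataset (__len__ / __getitem__) which a data loader could then batch.

```python
from typing import List, Sequence, Tuple


class InteractionPairs:
    """Hypothetical stand-in for a torch.utils.data.Dataset subclass.

    A real implementation would inherit from torch.utils.data.Dataset;
    here only the map-style protocol (__len__ / __getitem__) is mimicked.
    """

    def __init__(self, users: List[int], items: List[int]) -> None:
        self.users = users
        self.items = items

    @classmethod
    def from_dataset(cls, interactions: Sequence[Tuple[int, int]]) -> "InteractionPairs":
        # Assumption: the source data is a plain list of (user, item) pairs,
        # not a real rectools Dataset.
        users = [u for u, _ in interactions]
        items = [i for _, i in interactions]
        return cls(users, items)

    def __len__(self) -> int:
        return len(self.users)

    def __getitem__(self, idx: int) -> Tuple[int, int]:
        # One training sample: a (user, item) pair.
        return self.users[idx], self.items[idx]


train = InteractionPairs.from_dataset([(0, 10), (0, 11), (1, 10)])
print(len(train), train[1])  # 3 (0, 11)
```

Because from_dataset is a classmethod, the class itself can presumably be passed as dataset_type and instantiated only once the training data is available.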

Methods

fit(dataset, *args, **kwargs)

Fit model.

get_vectors(dataset)

recommend(users, dataset, k, filter_viewed)

Recommend items for users.

recommend_to_items(target_items, dataset, k)

Recommend items for target items.
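Since DSSMModel derives from VectorModel, recommendations are computed from user and item embedding vectors. The sketch below illustrates the general idea behind recommend(users, dataset, k, filter_viewed) with plain Python lists: score every item by its dot product with the user vector, optionally drop items the user has already interacted with, and keep the top k. The dot-product scoring, the helper name recommend_top_k and the toy vectors are assumptions for illustration only; the real model's scoring follows its u2i_dist setting.

```python
from typing import Dict, List, Sequence, Set, Tuple


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def recommend_top_k(
    user_vec: Sequence[float],
    item_vecs: Dict[int, Sequence[float]],
    k: int,
    viewed: Set[int],
    filter_viewed: bool = True,
) -> List[Tuple[int, float]]:
    """Hypothetical helper: rank items for one user by dot-product score."""
    scored = [
        (item_id, dot(user_vec, vec))
        for item_id, vec in item_vecs.items()
        # Skip already-seen items only when filter_viewed is set.
        if not (filter_viewed and item_id in viewed)
    ]
    # Highest score first; ties broken by item id for determinism.
    scored.sort(key=lambda pair: (-pair[1], pair[0]))
    return scored[:k]


items = {10: [1.0, 0.0], 11: [0.5, 0.5], 12: [0.0, 1.0]}
user = [0.75, 0.25]
print(recommend_top_k(user, items, k=2, viewed={10}))  # [(11, 0.5), (12, 0.25)]
```

With filter_viewed=False, item 10 (score 0.75) would be ranked first even though the user has already seen it.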

Attributes

i2i_dist

u2i_dist
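i2i_dist and u2i_dist name the distances used for item-to-item and user-to-item scoring. This page does not list their values, so the sketch below only contrasts two common choices, dot product and cosine similarity, on toy item vectors, in the kind of "most similar items to a target item" lookup that recommend_to_items performs. The helper names and vectors are hypothetical.

```python
import math
from typing import Callable, Dict, List, Sequence

Metric = Callable[[Sequence[float], Sequence[float]], float]


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    # Dot product divided by both vector norms: ignores vector length.
    return dot(a, b) / (math.sqrt(dot(a, a)) * math.sqrt(dot(b, b)))


def similar_items(
    target: int, item_vecs: Dict[int, Sequence[float]], k: int, metric: Metric
) -> List[int]:
    """Hypothetical i2i lookup: the k items most similar to target, excluding itself."""
    target_vec = item_vecs[target]
    candidates = [i for i in item_vecs if i != target]
    # Higher similarity first.
    candidates.sort(key=lambda i: -metric(target_vec, item_vecs[i]))
    return candidates[:k]


items = {1: [1.0, 0.0], 2: [3.0, 3.0], 3: [0.9, 0.1]}
print(similar_items(1, items, k=1, metric=dot))     # [2]
print(similar_items(1, items, k=1, metric=cosine))  # [3]
```

Item 2 wins under the dot product only because of its large magnitude, while cosine similarity picks item 3, whose direction is closest to the target. This is why the choice of distance matters for both i2i and u2i scoring.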