Source code for rectools.dataset.torch_datasets

#  Copyright 2022-2024 MTS (Mobile Telesystems)
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.

"""Special datasets used in neural models."""

from __future__ import annotations

import typing as tp

import numpy as np
import torch
from scipy import sparse
from torch.utils.data import Dataset as TorchDataset

from .dataset import Dataset

# Type variables bound to the DSSM dataset base classes of this module. They are used to
# annotate the ``from_dataset`` classmethods (``cls: tp.Type[T]`` -> ``T``) so that calling
# ``from_dataset`` on a subclass is typed as returning an instance of that subclass.
DSSMTrainDatasetT = tp.TypeVar("DSSMTrainDatasetT", bound="DSSMTrainDatasetBase")
DSSMItemDatasetT = tp.TypeVar("DSSMItemDatasetT", bound="DSSMItemDatasetBase")
DSSMUserDatasetT = tp.TypeVar("DSSMUserDatasetT", bound="DSSMUserDatasetBase")


class DSSMTrainDatasetBase(TorchDataset[tp.Any]):
    """
    Placeholder parent of every DSSM training dataset.

    It only exists so that type hints can refer to "some DSSM training dataset";
    neither the constructor nor ``from_dataset`` is usable on this class itself.
    """

    def __init__(self, *args: tp.Any, **kwargs: tp.Any) -> None:
        # Concrete subclasses define their own constructor.
        raise NotImplementedError

    @classmethod
    def from_dataset(cls: tp.Type[DSSMTrainDatasetT], dataset: Dataset) -> DSSMTrainDatasetT:
        """Construct the training dataset from a RecTools dataset; subclasses must override."""
        raise NotImplementedError
class DSSMTrainDataset(DSSMTrainDatasetBase):
    """
    Torch dataset wrapper for `rectools.dataset.dataset.Dataset`.

    Implements `torch.utils.data.Dataset` for subsequent usage with `torch.utils.data.DataLoader`.
    Does the following: for a given index takes a row of user interactions, a row of user features
    and samples one positive and one negative items and then returns them as tensors.
    This class is intended for internal usage or advanced users who want to implement
    more sophisticated sampling logic.

    Parameters
    ----------
    items : csr_matrix
        Item features.
    users : csr_matrix
        User features.
    interactions : csr_matrix
        Interactions matrix.

    Raises
    ------
    ValueError
        If some row of `interactions` sums to zero or if `interactions` contains negative values,
        since positive items cannot be sampled from such rows.
    """

    def __init__(
        self,
        items: sparse.csr_matrix,
        users: sparse.csr_matrix,
        interactions: sparse.csr_matrix,
    ) -> None:
        self.items = items
        self.users = users
        self.interactions = interactions
        # Positive items are sampled with probabilities proportional to the row values
        # (see `__getitem__`), so each row needs a non-zero total and no negative values.
        if not self.interactions.sum(1).all() or (self.interactions < 0).sum(1).any():
            raise ValueError(
                "Impossible to sample from a row that either contains only negative items"
                " or contains any negatively signed integers. "
                "Make sure that all rows from interactions have at least 1 positive item"
            )

    @classmethod
    def from_dataset(cls: tp.Type[DSSMTrainDatasetT], dataset: Dataset) -> DSSMTrainDatasetT:
        """
        Build the training dataset from a RecTools dataset.

        Parameters
        ----------
        dataset : Dataset
            Source dataset; its user-item matrix and hot user/item features are used.

        Raises
        ------
        AttributeError
            If the dataset has no hot item features or no hot user features.
        """
        ui_matrix = dataset.get_user_item_matrix()
        # We take hot here since this dataset is used for fit only
        item_features = dataset.get_hot_item_features()
        user_features = dataset.get_hot_user_features()
        if item_features is None:
            raise AttributeError("Item features attribute of dataset could not be None")
        if user_features is None:
            raise AttributeError("User features attribute of dataset could not be None")
        return cls(
            items=item_features.get_sparse(),
            users=user_features.get_sparse(),
            interactions=ui_matrix,
        )

    def __len__(self) -> int:
        """Return the number of rows of the interactions matrix."""
        return self.interactions.shape[0]

    def __getitem__(
        self, idx: int
    ) -> tp.Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        """
        Return ``(user_features, interactions, positive_item_features, negative_item_features)``
        for row `idx` as float tensors.
        """
        n_items = self.interactions.shape[1]
        interactions_vec = self.interactions[idx].toarray().flatten()
        # Positive item: drawn with probability proportional to its value in the row.
        probabilities = interactions_vec / interactions_vec.sum()
        # Passing the item count instead of `np.arange(n_items)` gives the same draw without
        # allocating an index array for every sample.
        pos_i = np.random.choice(n_items, p=probabilities)
        # Negative item: drawn uniformly over all items (it may coincide with a positive one).
        neg_i = np.random.choice(n_items)
        user_features = torch.FloatTensor(self.users[idx].toarray().flatten())
        interactions = torch.FloatTensor(interactions_vec)
        pos = torch.FloatTensor(self.items[pos_i].toarray().flatten())
        neg = torch.FloatTensor(self.items[neg_i].toarray().flatten())
        return user_features, interactions, pos, neg
class DSSMItemDatasetBase(TorchDataset[tp.Any]):
    """
    Placeholder parent of every DSSM item dataset.

    It only exists so that type hints can refer to "some DSSM item dataset";
    neither the constructor nor ``from_dataset`` is usable on this class itself.
    """

    def __init__(self, *args: tp.Any, **kwargs: tp.Any) -> None:
        # Concrete subclasses define their own constructor.
        raise NotImplementedError

    @classmethod
    def from_dataset(cls: tp.Type[DSSMItemDatasetT], dataset: Dataset) -> DSSMItemDatasetT:
        """Construct the item dataset from a RecTools dataset; subclasses must override."""
        raise NotImplementedError
class DSSMItemDataset(DSSMItemDatasetBase):
    """
    Torch dataset wrapper for `rectools.dataset.dataset.Dataset`.

    Implements `torch.utils.data.Dataset` for subsequent usage with `torch.utils.data.DataLoader`.
    Indexing returns one row of the item features matrix as a dense float tensor.
    This class is intended for internal usage or advanced users.

    Parameters
    ----------
    items : csr_matrix
        Item features.
    """

    def __init__(self, items: sparse.csr_matrix):
        self.items = items

    @classmethod
    def from_dataset(cls: tp.Type[DSSMItemDatasetT], dataset: Dataset) -> DSSMItemDatasetT:
        """
        Build the item dataset from all item features of a RecTools dataset.

        Raises
        ------
        AttributeError
            If the dataset has no item features.
        """
        # All features (not only hot ones) are used: this dataset serves recommend, not fit.
        features = dataset.item_features
        if features is None:
            raise AttributeError("Item features attribute of dataset could not be None")
        return cls(features.get_sparse())

    def __len__(self) -> int:
        """Return the number of rows of the item features matrix."""
        n_rows = self.items.shape[0]
        return n_rows

    def __getitem__(self, idx: int) -> torch.FloatTensor:
        """Return row `idx` of the item features matrix as a flat float tensor."""
        dense_row = self.items[idx].toarray().flatten()
        return torch.FloatTensor(dense_row)
class DSSMUserDatasetBase(TorchDataset[tp.Any]):
    """Base class for DSSM user datasets. Used only for type hinting."""

    def __init__(self, *args: tp.Any, **kwargs: tp.Any) -> None:
        # Concrete subclasses define their own constructor.
        raise NotImplementedError()

    @classmethod
    def from_dataset(
        cls: tp.Type[DSSMUserDatasetT],
        dataset: Dataset,
        keep_users: tp.Optional[tp.Sequence[int]] = None,
    ) -> DSSMUserDatasetT:
        """Construct the user dataset from a RecTools dataset; subclasses must override."""
        raise NotImplementedError()
class DSSMUserDataset(DSSMUserDatasetBase):
    """
    Torch dataset wrapper for `rectools.dataset.dataset.Dataset`.

    Implements `torch.utils.data.Dataset` for subsequent usage with `torch.utils.data.DataLoader`.
    Indexing returns the user's feature row and interaction row, both as dense float tensors.
    This class is intended for internal usage or advanced users.

    Parameters
    ----------
    users : csr_matrix
        User features.
    interactions : csr_matrix
        Interactions matrix; must have as many rows as `users`.
    keep_users : sequence of int, optional
        Row indices to keep in both matrices. If ``None``, all rows are kept.

    Raises
    ------
    ValueError
        If `users` and `interactions` have different numbers of rows.
    """

    def __init__(
        self,
        users: sparse.csr_matrix,
        interactions: sparse.csr_matrix,
        keep_users: tp.Optional[tp.Sequence[int]] = None,
    ):
        # Row counts are compared before any `keep_users` selection is applied.
        n_feature_rows, n_interaction_rows = users.shape[0], interactions.shape[0]
        if n_feature_rows != n_interaction_rows:
            raise ValueError("Number of rows in user features matrix and in interactions matrix must be the same")
        if keep_users is None:
            self.users = users
            self.interactions = interactions
        else:
            self.users = users[keep_users]
            self.interactions = interactions[keep_users]

    @classmethod
    def from_dataset(
        cls: tp.Type[DSSMUserDatasetT],
        dataset: Dataset,
        keep_users: tp.Optional[tp.Sequence[int]] = None,
    ) -> DSSMUserDatasetT:
        """
        Build the user dataset from all user features and the user-item matrix of a RecTools dataset.

        Raises
        ------
        AttributeError
            If the dataset has no user features.
        """
        # All features (not only hot ones) are used: this dataset serves recommend, not fit.
        features = dataset.user_features
        if features is None:
            raise AttributeError("User features attribute of dataset could not be None")
        user_matrix = features.get_sparse()
        # NOTE(review): `include_warm_users=True` presumably makes the matrix rows line up with
        # the user feature rows (their counts are checked in `__init__`) — confirm.
        ui_matrix = dataset.get_user_item_matrix(include_warm_users=True)
        return cls(user_matrix, ui_matrix, keep_users)

    def __len__(self) -> int:
        """Return the number of (kept) rows of the user features matrix."""
        n_rows = self.users.shape[0]
        return n_rows

    def __getitem__(self, idx: int) -> tp.Tuple[torch.FloatTensor, torch.FloatTensor]:
        """Return ``(user_features, interactions)`` for row `idx` as flat float tensors."""
        features_row, interactions_row = (
            torch.FloatTensor(matrix[idx].toarray().flatten()) for matrix in (self.users, self.interactions)
        )
        return features_row, interactions_row