Source code for rectools.model_selection.splitter

#  Copyright 2023 MTS (Mobile Telesystems)
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.

"""Splitter."""

import typing as tp

import numpy as np
import pandas as pd

from rectools import Columns
from rectools.dataset import Interactions
from rectools.model_selection.utils import get_not_seen_mask


class Splitter:
    """
    Base class to construct data splitters.
    It cannot be used directly.
    New splitter can be defined by subclassing the `Splitter` class and implementing `_split_without_filter` method.
    Check specific class descriptions to get more information.

    Parameters
    ----------
    filter_cold_users : bool, default ``True``
        If ``True``, test interactions of users that do not appear in the train part of a fold are dropped.
    filter_cold_items : bool, default ``True``
        If ``True``, test interactions of items that do not appear in the train part of a fold are dropped.
    filter_already_seen : bool, default ``True``
        If ``True``, only test interactions selected by `get_not_seen_mask` are kept
        (presumably user-item pairs that do not occur in the train part — see that helper).
    """

    def __init__(
        self, filter_cold_users: bool = True, filter_cold_items: bool = True, filter_already_seen: bool = True
    ) -> None:
        self.filter_cold_users = filter_cold_users
        self.filter_cold_items = filter_cold_items
        self.filter_already_seen = filter_already_seen

    def split(
        self,
        interactions: Interactions,
        collect_fold_stats: bool = False,
    ) -> tp.Iterator[tp.Tuple[np.ndarray, np.ndarray, tp.Dict[str, tp.Any]]]:
        """
        Split interactions into folds and apply filtration to the result.

        Parameters
        ----------
        interactions : Interactions
            User-item interactions.
        collect_fold_stats : bool, default False
            Add some stats to split info, like size of train and test part, number of users and items.

        Returns
        -------
        iterator(array, array, dict)
            Yields tuples with train part row numbers, test part row numbers and split info.
        """
        for train_idx, test_idx, split_info in self._split_without_filter(interactions, collect_fold_stats):
            yield self.filter(interactions, collect_fold_stats, train_idx, test_idx, split_info)

    def _split_without_filter(
        self,
        interactions: Interactions,
        collect_fold_stats: bool = False,
    ) -> tp.Iterator[tp.Tuple[np.ndarray, np.ndarray, tp.Dict[str, tp.Any]]]:
        """
        Split interactions into folds.

        Must be implemented by subclasses; the base implementation always raises.

        Parameters
        ----------
        interactions : Interactions
            User-item interactions.
        collect_fold_stats : bool, default False
            Add some stats to split info, like size of train and test part, number of users and items.

        Returns
        -------
        iterator(array, array, dict)
            Yields tuples with train part row numbers, test part row numbers and split info.

        Raises
        ------
        NotImplementedError
            Always, in the base class.
        """
        raise NotImplementedError

    @staticmethod
    def _apply_mask(mask: np.ndarray, *arrays: np.ndarray) -> tp.Tuple[np.ndarray, ...]:
        """Index every array in `arrays` with the same boolean `mask` and return the results in order."""
        return tuple(arr[mask] for arr in arrays)

    def filter(
        self,
        interactions: Interactions,
        collect_fold_stats: bool,
        train_idx: np.ndarray,
        test_idx: np.ndarray,
        split_info: tp.Dict[str, tp.Any],
    ) -> tp.Tuple[np.ndarray, np.ndarray, tp.Dict[str, tp.Any]]:
        """
        Filter train and test indexes from one fold based on `filter_cold_users`,
        `filter_cold_items`, `filter_already_seen` instance attributes.
        They are set to `True` by default.

        Only the test part is filtered; train row numbers are returned unchanged.

        Parameters
        ----------
        interactions : Interactions
            User-item interactions.
        collect_fold_stats : bool
            Add some stats to split info, like size of train and test part, number of users and items.
            This argument has no default and must be passed explicitly.
        train_idx : array
            Train part row numbers.
        test_idx : array
            Test part row numbers.
        split_info : dict
            Information about the split. Updated in place when `collect_fold_stats` is ``True``.

        Returns
        -------
        Tuple(array, array, dict)
            Returns tuple with filtered train part row numbers, test part row numbers and split info.
        """
        need_ui = self.filter_cold_users or self.filter_cold_items or self.filter_already_seen or collect_fold_stats
        if not need_ui:
            # Nothing to filter and no stats requested: pass the fold through untouched.
            return train_idx, test_idx, split_info

        df = interactions.df
        # Hoist column extraction so each column is read from the frame only once.
        user_values = df[Columns.User].values
        item_values = df[Columns.Item].values
        train_users = user_values[train_idx]
        train_items = item_values[train_idx]
        test_users = user_values[test_idx]
        test_items = item_values[test_idx]

        # Unique train users/items are computed lazily and reused for the fold stats below.
        unq_train_users = None
        unq_train_items = None

        # Filters are applied sequentially: each one sees the test part already reduced by the previous ones.
        if self.filter_cold_users:
            unq_train_users = pd.unique(train_users)
            mask = np.isin(test_users, unq_train_users)
            test_users, test_items, test_idx = self._apply_mask(mask, test_users, test_items, test_idx)

        if self.filter_cold_items:
            unq_train_items = pd.unique(train_items)
            mask = np.isin(test_items, unq_train_items)
            test_users, test_items, test_idx = self._apply_mask(mask, test_users, test_items, test_idx)

        if self.filter_already_seen:
            mask = get_not_seen_mask(train_users, train_items, test_users, test_items)
            test_users, test_items, test_idx = self._apply_mask(mask, test_users, test_items, test_idx)

        if collect_fold_stats:
            if unq_train_users is None:
                unq_train_users = pd.unique(train_users)
            if unq_train_items is None:
                unq_train_items = pd.unique(train_items)

            split_info["train"] = train_users.size
            split_info["train_users"] = unq_train_users.size
            split_info["train_items"] = unq_train_items.size
            split_info["test"] = test_users.size
            split_info["test_users"] = pd.unique(test_users).size
            split_info["test_items"] = pd.unique(test_items).size

        return train_idx, test_idx, split_info