RandomSplitter

class rectools.model_selection.random_split.RandomSplitter(test_fold_frac: float, n_splits: int = 1, random_state: Optional[int] = None, filter_cold_users: bool = True, filter_cold_items: bool = True, filter_already_seen: bool = True)

Bases: Splitter

Splitter for cross-validation by random sampling. Generates train and test folds with a fixed test-part ratio and no intersections between test folds. Random splitting is applied to individual interactions; users and items are not taken into account while preparing splits.
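The idea of non-intersecting random test folds can be illustrated with a small sketch. It assumes the interactions are shuffled once and that consecutive chunks of the shuffled indexes become the test folds, with the test size rounded down; this is an illustration of the technique, not RecTools' actual implementation. The function name random_disjoint_folds is hypothetical, and it uses Python's random module, so its indexes will not match the RandomSplitter output shown in the examples.

```python
import random
from typing import List, Optional, Tuple


def random_disjoint_folds(
    n_interactions: int,
    test_fold_frac: float,
    n_splits: int = 1,
    random_state: Optional[int] = None,
) -> List[Tuple[List[int], List[int]]]:
    """Hypothetical helper: random train/test folds whose test parts never overlap."""
    if not 0 < test_fold_frac < 1:
        raise ValueError("test_fold_frac must be between 0 and 1")
    # Assumption: the per-fold test size is rounded down.
    test_size = int(n_interactions * test_fold_frac)
    if test_size * n_splits > n_interactions:
        raise ValueError("not enough interactions for non-intersecting test folds")
    rng = random.Random(random_state)  # same seed -> same folds
    # Shuffle all indexes once; consecutive chunks are then disjoint by construction.
    perm = list(range(n_interactions))
    rng.shuffle(perm)
    folds = []
    for k in range(n_splits):
        test = perm[k * test_size:(k + 1) * test_size]
        test_set = set(test)
        # Everything outside the current test chunk goes to train.
        train = [i for i in perm if i not in test_set]
        folds.append((train, test))
    return folds


for train_ids, test_ids in random_disjoint_folds(8, 0.25, n_splits=2, random_state=42):
    print(train_ids, test_ids)
```

Shuffling once and slicing guarantees that no interaction appears in two test folds, while each train part is simply the complement of its own test fold.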

Cold users, cold items, and already seen (user, item) pairs can optionally be excluded from the test part.

Parameters
  • test_fold_frac (float) – Relative size of the test part; must be between 0 and 1.

  • n_splits (int, default 1) – Number of test folds.

  • random_state (int, default None) – Controls the randomness of each fold. Pass an int to get reproducible results across multiple split calls.

  • filter_cold_users (bool, default True) – If True, users that are not present in train will be excluded from test. WARNING: both cold and warm users will be excluded from test.

  • filter_cold_items (bool, default True) – If True, items that are not present in train will be excluded from test. WARNING: both cold and warm items will be excluded from test.

  • filter_already_seen (bool, default True) – If True, pairs (user, item) that are present in train will be excluded from test.
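The three filtering flags can be sketched in plain Python. This is a minimal sketch, assuming interactions are given as (user, item) tuples indexed by position rather than as an Interactions object, and it only filters the test part of a fold. The name filter_test_fold is hypothetical. Applied to the (user, item) pairs from the example data below, it reproduces the filtered second fold: interaction 7 belongs to user 4, who never appears in train.

```python
from typing import List, Sequence, Tuple

Pair = Tuple[int, int]  # (user_id, item_id)


def filter_test_fold(
    interactions: Sequence[Pair],
    train_ids: Sequence[int],
    test_ids: Sequence[int],
    filter_cold_users: bool = True,
    filter_cold_items: bool = True,
    filter_already_seen: bool = True,
) -> List[int]:
    """Hypothetical helper: drop test interactions according to the three flags."""
    train_pairs = {interactions[i] for i in train_ids}
    train_users = {user for user, _ in train_pairs}
    train_items = {item for _, item in train_pairs}
    kept = []
    for i in test_ids:
        user, item = interactions[i]
        if filter_cold_users and user not in train_users:
            continue  # cold user: absent from train
        if filter_cold_items and item not in train_items:
            continue  # cold item: absent from train
        if filter_already_seen and (user, item) in train_pairs:
            continue  # this (user, item) pair already occurs in train
        kept.append(i)
    return kept


# (user, item) pairs of rows 0..7 from the example DataFrame
pairs = [(1, 2), (2, 1), (2, 3), (3, 2), (3, 3), (3, 4), (1, 2), (4, 2)]
print(filter_test_fold(pairs, train_ids=[3, 4, 6, 1, 5, 0], test_ids=[2, 7]))  # [2]
```

Each condition depends only on the train part, so the order in which the checks are applied does not change the result.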

Examples

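>>> import pandas as pd
>>> from rectools import Columns
>>> from rectools.dataset import Interactions
>>> from rectools.model_selection.random_split import RandomSplitter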
>>> df = pd.DataFrame(
...     [
...         [1, 2, 1, "2021-09-01"],  # 0
...         [2, 1, 1, "2021-09-02"],  # 1
...         [2, 3, 1, "2021-09-03"],  # 2
...         [3, 2, 1, "2021-09-03"],  # 3
...         [3, 3, 1, "2021-09-04"],  # 4
...         [3, 4, 1, "2021-09-04"],  # 5
...         [1, 2, 1, "2021-09-05"],  # 6
...         [4, 2, 1, "2021-09-05"],  # 7
...     ],
...     columns=[Columns.User, Columns.Item, Columns.Weight, Columns.Datetime],
... ).astype({Columns.Datetime: "datetime64[ns]"})
>>> interactions = Interactions(df)
>>>
>>> splitter = RandomSplitter(test_fold_frac=0.25, random_state=42, n_splits=2, filter_cold_users=False,
...                     filter_cold_items=False, filter_already_seen=False)
>>> for train_ids, test_ids, _ in splitter.split(interactions):
...     print(train_ids, test_ids)
[2 7 6 1 5 0] [3 4]
[3 4 6 1 5 0] [2 7]
>>>
>>> splitter = RandomSplitter(test_fold_frac=0.25, random_state=42, n_splits=2, filter_cold_users=True,
...                     filter_cold_items=True, filter_already_seen=True)
>>> for train_ids, test_ids, _ in splitter.split(interactions):
...     print(train_ids, test_ids)
[2 7 6 1 5 0] [3 4]
[3 4 6 1 5 0] [2]

Methods

filter(interactions, collect_fold_stats, ...)

Filter train and test indexes from one fold based on the filter_cold_users, filter_cold_items, and filter_already_seen class fields.

split(interactions[, collect_fold_stats])

Split interactions into folds and apply filtration to the result.