Source code for rectools.metrics.scoring

#  Copyright 2022-2025 MTS (Mobile Telesystems)
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.

"""Metrics calculation module."""

import typing as tp
import warnings

import pandas as pd

from rectools.utils import select_by_type

from .auc import AucMetric, calc_auc_metrics
from .base import Catalog, MetricAtK, merge_reco
from .catalog import CatalogMetric, calc_catalog_metrics
from .classification import ClassificationMetric, SimpleClassificationMetric, calc_classification_metrics
from .diversity import DiversityMetric, calc_diversity_metrics
from .dq import CrossDQMetric, RecoDQMetric, calc_cross_dq_metrics, calc_reco_dq_metrics
from .intersection import IntersectionMetric, calc_intersection_metrics
from .novelty import NoveltyMetric, calc_novelty_metrics
from .popularity import PopularityMetric, calc_popularity_metrics
from .ranking import RankingMetric, calc_ranking_metrics
from .serendipity import SerendipityMetric, calc_serendipity_metrics


def calc_metrics(  # noqa  # pylint: disable=too-many-branches,too-many-locals,too-many-statements
    metrics: tp.Mapping[str, MetricAtK],
    reco: pd.DataFrame,
    interactions: tp.Optional[pd.DataFrame] = None,
    prev_interactions: tp.Optional[pd.DataFrame] = None,
    catalog: tp.Optional[Catalog] = None,
    ref_reco: tp.Optional[tp.Union[pd.DataFrame, tp.Dict[tp.Hashable, pd.DataFrame]]] = None,
) -> tp.Dict[str, float]:
    """
    Calculate metrics.

    Metrics are grouped by type and every group is calculated with its dedicated
    function. Arguments that are required by a group are validated only if
    at least one metric of that group is present in `metrics`.

    Parameters
    ----------
    metrics : dict(str -> Metric)
        Dict of metric objects to calculate,
        where key is metric name and value is metric object.
    reco : pd.DataFrame
        Recommendations table with columns `Columns.User`, `Columns.Item`, `Columns.Rank`.
    interactions : pd.DataFrame, optional
        Interactions table with columns `Columns.User`, `Columns.Item`.
        Obligatory only for some types of metrics.
    prev_interactions : pd.DataFrame, optional
        Table with previous user-item interactions,
        with columns `Columns.User`, `Columns.Item`.
        Obligatory only for some types of metrics.
    catalog : collection, optional
        Collection of unique item ids that could be used for recommendations.
        Obligatory only if `ClassificationMetric`, `CatalogMetric` or `SerendipityMetric`
        instances present in `metrics`.
    ref_reco : Union[pd.DataFrame, Dict[Hashable, pd.DataFrame]], optional
        Reference recommendations table(s) with columns `Columns.User`, `Columns.Item`, `Columns.Rank`.
        For multiple intersection calculations we can pass multiple models recommendations in a dict:
        ``ref_reco = {"one": ref_reco_one, "two": ref_reco_two}``
        Obligatory only if `IntersectionMetric` instances present in `metrics`.
        An empty dict is treated the same way as a missing value.

    Returns
    -------
    dict(str->float)
        Dictionary where keys are the same with keys in `metrics`
        and values are metric calculation results.

    Raises
    ------
    ValueError
        If obligatory argument for some metric not set.

    Warns
    -----
    UserWarning
        If fewer results than expected were produced, i.e. some of the passed
        metrics do not belong to any of the supported metric types.

    Examples
    --------
    >>> from rectools import Columns
    >>> from rectools.metrics import Accuracy, NDCG
    >>> reco = pd.DataFrame(
    ...     {
    ...         Columns.User: [1, 1, 2, 2, 3, 3, 3, 3, 4, 4, 4],
    ...         Columns.Item: [7, 8, 1, 2, 1, 2, 3, 4, 1, 2, 3],
    ...         Columns.Rank: [1, 2, 1, 2, 1, 2, 3, 4, 1, 2, 3],
    ...     }
    ... )
    >>> interactions = pd.DataFrame(
    ...     {
    ...         Columns.User: [1, 1, 2, 3, 3, 3, 4, 4, 4],
    ...         Columns.Item: [1, 2, 1, 1, 3, 4, 1, 2, 3],
    ...         Columns.Datetime: [1, 1, 1, 1, 1, 2, 2, 2, 2],
    ...     }
    ... )
    >>> split_dt = 2
    >>> df_train = interactions.loc[interactions[Columns.Datetime] < split_dt]
    >>> df_test = interactions.loc[interactions[Columns.Datetime] >= split_dt]
    >>> metrics = {
    ...     'ndcg@1': NDCG(k=1),
    ...     'accuracy@1': Accuracy(k=1)
    ... }
    >>> calc_metrics(
    ...     metrics,
    ...     reco=reco,
    ...     interactions=df_test,
    ...     prev_interactions=df_train,
    ...     catalog=df_train[Columns.Item].unique()
    ... )
    {'accuracy@1': 0.3333333333333333, 'ndcg@1': 0.5}
    """
    # Merged reco/interactions table is built lazily and shared between
    # classification and ranking metrics to avoid merging twice.
    merged = None
    results: tp.Dict[str, tp.Any] = {}
    expected_results_len = len(metrics)

    # Classification
    classification_metrics = select_by_type(metrics, (ClassificationMetric, SimpleClassificationMetric))
    if classification_metrics:
        if interactions is None:
            raise ValueError("For calculating classification metrics it's necessary to set 'interactions'")
        merged = merge_reco(reco, interactions)
        classification_values = calc_classification_metrics(classification_metrics, merged, catalog)
        results.update(classification_values)

    # Ranking
    ranking_metrics = select_by_type(metrics, RankingMetric)
    if ranking_metrics:
        if interactions is None:
            raise ValueError("For calculating ranking metrics it's necessary to set 'interactions'")
        merged = merged if merged is not None else merge_reco(reco, interactions)
        ranking_values = calc_ranking_metrics(ranking_metrics, merged)
        results.update(ranking_values)

    # AUC based ranking
    auc_metrics = select_by_type(metrics, AucMetric)
    if auc_metrics:
        if interactions is None:
            raise ValueError("For calculating AUC-like metrics it's necessary to set 'interactions'")
        auc_values = calc_auc_metrics(auc_metrics, reco, interactions)
        results.update(auc_values)

    # Novelty
    novelty_metrics = select_by_type(metrics, NoveltyMetric)
    if novelty_metrics:
        if prev_interactions is None:
            raise ValueError("For calculating novelty metrics it's necessary to set 'prev_interactions'")
        novelty_values = calc_novelty_metrics(novelty_metrics, reco, prev_interactions)
        results.update(novelty_values)

    # Catalog
    catalog_metrics = select_by_type(metrics, CatalogMetric)
    if catalog_metrics:
        if catalog is None:
            raise ValueError("For calculating catalog metrics it's necessary to set 'catalog'")
        catalog_values = calc_catalog_metrics(catalog_metrics, reco, catalog)
        results.update(catalog_values)

    # Popularity
    popularity_metrics = select_by_type(metrics, PopularityMetric)
    if popularity_metrics:
        if prev_interactions is None:
            raise ValueError("For calculating popularity metrics it's necessary to set 'prev_interactions'")
        popularity_values = calc_popularity_metrics(popularity_metrics, reco, prev_interactions)
        results.update(popularity_values)

    # Diversity
    diversity_metrics = select_by_type(metrics, DiversityMetric)
    if diversity_metrics:
        diversity_values = calc_diversity_metrics(diversity_metrics, reco)
        results.update(diversity_values)

    # Serendipity
    serendipity_metrics = select_by_type(metrics, SerendipityMetric)
    if serendipity_metrics:
        if interactions is None:
            raise ValueError("For calculating serendipity metrics it's necessary to set 'interactions'")
        if prev_interactions is None:
            raise ValueError("For calculating serendipity metrics it's necessary to set 'prev_interactions'")
        if catalog is None:
            raise ValueError("For calculating serendipity metrics it's necessary to set 'catalog'")
        serendipity_values = calc_serendipity_metrics(
            serendipity_metrics,
            reco,
            interactions,
            prev_interactions,
            catalog,
        )
        results.update(serendipity_values)

    # Intersection
    intersection_metrics = select_by_type(metrics, IntersectionMetric)
    if intersection_metrics:
        # `ref_reco` may be a DataFrame, whose truth value is ambiguous (pandas raises
        # on `bool(df)`), so check for `None` explicitly. An empty dict is still rejected.
        if ref_reco is None or (isinstance(ref_reco, dict) and not ref_reco):
            raise ValueError("For calculating intersection metrics it's necessary to set 'ref_reco'")
        intersection_values = calc_intersection_metrics(
            intersection_metrics,
            reco,
            ref_reco,
        )
        results.update(intersection_values)
        # The number of intersection results may differ from the number of intersection
        # metrics, so the expected total is corrected by the difference.
        expected_results_len += len(intersection_values) - len(intersection_metrics)

    # DQ
    cross_dq_metrics = select_by_type(metrics, CrossDQMetric)
    if cross_dq_metrics:
        if interactions is None:
            raise ValueError("For calculating some of the required DQ metrics it's necessary to set 'interactions'")
        cross_dq_values = calc_cross_dq_metrics(cross_dq_metrics, reco, interactions)
        results.update(cross_dq_values)

    reco_dq_metrics = select_by_type(metrics, RecoDQMetric)
    if reco_dq_metrics:
        reco_dq_values = calc_reco_dq_metrics(reco_dq_metrics, reco)
        results.update(reco_dq_values)

    # Metrics that matched none of the known types were silently skipped above.
    if len(results) < expected_results_len:
        warnings.warn("Custom metrics are not supported.")

    # Convert values exposing `.item()` (e.g. numpy scalars) to plain Python numbers.
    return {k: v.item() if hasattr(v, "item") else v for k, v in results.items()}