{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Model configs and saving examples\n", "\n", "RecTools models share a set of common methods that let you run experiments from configs and simplify integration with experiment trackers (e.g. MLflow). They include:\n", "\n", "* `from_config`\n", "* `from_params`\n", "* `get_config`\n", "* `get_params`\n", "\n", "Models can also be saved and loaded with the following methods:\n", "\n", "* `save`\n", "* `load`\n", "\n", "For convenience, there are also common functions that do not depend on a specific model class or instance. They can be used with any RecTools model:\n", "* `model_from_config`\n", "* `model_from_params`\n", "* `load_model`\n", "\n", "\n", "In this example we show basic usage of all these methods and common functions, as well as config examples for our models." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from datetime import timedelta\n", "import pandas as pd\n", "\n", "from rectools import Columns\n", "from rectools.models import (\n", " SASRecModel,\n", " BERT4RecModel,\n", " ImplicitItemKNNWrapperModel, \n", " ImplicitALSWrapperModel, \n", " ImplicitBPRWrapperModel, \n", " EASEModel, \n", " PopularInCategoryModel, \n", " PopularModel, \n", " RandomModel, \n", " LightFMWrapperModel,\n", " PureSVDModel,\n", " model_from_config,\n", " load_model,\n", " model_from_params\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Basic usage\n", "### `from_config` and `model_from_config`" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `from_config` method initializes a model from a dictionary of model hyper-params." 
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "config = {\n", " \"popularity\": \"n_interactions\",\n", " \"period\": timedelta(weeks=2),\n", "}\n", "model = PopularModel.from_config(config)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can also use the `model_from_config` function to initialize any RecTools model. " ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "config = {\n", " \"cls\": \"PopularModel\", # always specify \"cls\" for `model_from_config` function\n", " # \"cls\": \"rectools.models.PopularModel\", # will work too\n", " \"popularity\": \"n_interactions\",\n", " \"period\": timedelta(weeks=2),\n", "}\n", "model = model_from_config(config)\n", "model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### `get_config` and `get_params`\n", "The `get_config` method returns a dictionary of model hyper-params. Unlike the config passed to `from_config`, it contains the full list of model parameters, including those that were not specified during model initialization and were set to their default values." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'cls': rectools.models.popular.PopularModel,\n", " 'verbose': 0,\n", " 'popularity': ,\n", " 'period': datetime.timedelta(days=14),\n", " 'begin_from': None,\n", " 'add_cold': False,\n", " 'inverse': False}" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.get_config()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can pass the output of `get_config` directly to `from_config` to create new model instances. The new instances will have exactly the same hyper-params as the source model." 
] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "source_config = model.get_config()\n", "new_model = PopularModel.from_config(source_config)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To get the model config in a JSON-compatible format, pass `simple_types=True`. See how the `popularity` parameter changes for the Popular model in the example below:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'cls': 'PopularModel',\n", " 'verbose': 0,\n", " 'popularity': 'n_interactions',\n", " 'period': {'days': 14},\n", " 'begin_from': None,\n", " 'add_cold': False,\n", " 'inverse': False}" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.get_config(simple_types=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `get_params` method returns model hyper-parameters as a flat dictionary, which is often more convenient for experiment trackers. \n", "\n", "\n", "Don't forget to pass `simple_types=True` to make the format JSON-compatible. Note that you can't initialize a new model from the output of this method." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'cls': 'PopularModel',\n", " 'verbose': 0,\n", " 'popularity': 'n_interactions',\n", " 'period.days': 14,\n", " 'begin_from': None,\n", " 'add_cold': False,\n", " 'inverse': False}" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.get_params(simple_types=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### `from_params` and `model_from_params`\n", "The `from_params` class method and the `model_from_params` function act exactly like `from_config` but always expect a dict of model parameters in a \"flat\" form. \n", "The \"flat-dict\" form of configs is very useful for hyper-parameter search (e.g. 
with Optuna).\n", "\n", "See example below:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "params = {\n", " \"cls\": \"PopularModel\",\n", " \"popularity\": \"n_interactions\",\n", " \"period.days\": 14, # flat form with ``.`` as a separator\n", "}\n", "model = model_from_params(params)\n", "model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### `save`, `load` and `load_model`\n", "The `save` and `load` model methods do exactly what you would expect from their names :)\n", "Fit the model to a dataset before saving. The weights are restored by the `load` method." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "220" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.save(\"pop_model.pkl\")" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "loaded = PopularModel.load(\"pop_model.pkl\")\n", "loaded" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can also use the `load_model` function to load any RecTools model." 
] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "loaded = load_model(\"pop_model.pkl\")\n", "loaded" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Config examples for all models" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### SASRec" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "GPU available: True (cuda), used: True\n", "TPU available: False, using: 0 TPU cores\n", "HPU available: False, using: 0 HPUs\n", "/data/home/dmtikhono1/RecTools/.venv/lib/python3.9/site-packages/pydantic/main.py:426: UserWarning: Pydantic serializer warnings:\n", " Expected `str` but got `tuple` with value `('rectools.models.nn.item...net.CatFeaturesItemNet')` - serialized value may not be as expected\n", " return self.__pydantic_serializer__.to_python(\n" ] }, { "data": { "text/plain": [ "{'cls': 'SASRecModel',\n", " 'verbose': 0,\n", " 'data_preparator_type': 'rectools.models.nn.transformers.sasrec.SASRecDataPreparator',\n", " 'n_blocks': 1,\n", " 'n_heads': 1,\n", " 'n_factors': 64,\n", " 'use_pos_emb': True,\n", " 'use_causal_attn': True,\n", " 'use_key_padding_mask': False,\n", " 'dropout_rate': 0.2,\n", " 'session_max_len': 100,\n", " 'dataloader_num_workers': 0,\n", " 'batch_size': 128,\n", " 'loss': 'softmax',\n", " 'n_negatives': 1,\n", " 'gbce_t': 0.2,\n", " 'lr': 0.001,\n", " 'epochs': 2,\n", " 'deterministic': False,\n", " 'recommend_batch_size': 256,\n", " 'recommend_torch_device': None,\n", " 'train_min_user_interactions': 2,\n", " 'item_net_block_types': ['rectools.models.nn.item_net.IdEmbeddingsItemNet',\n", " 'rectools.models.nn.item_net.CatFeaturesItemNet'],\n", " 'item_net_constructor_type': 'rectools.models.nn.item_net.SumOfEmbeddingsConstructor',\n", " 'pos_encoding_type': 
'rectools.models.nn.transformers.net_blocks.LearnableInversePositionalEncoding',\n", " 'transformer_layers_type': 'rectools.models.nn.transformers.sasrec.SASRecTransformerLayers',\n", " 'lightning_module_type': 'rectools.models.nn.transformers.lightning.TransformerLightningModule',\n", " 'get_val_mask_func': None,\n", " 'get_trainer_func': None,\n", " 'data_preparator_kwargs': None,\n", " 'transformer_layers_kwargs': None,\n", " 'item_net_constructor_kwargs': None,\n", " 'pos_encoding_kwargs': None,\n", " 'lightning_module_kwargs': None}" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "config = {\n", " \"epochs\": 2,\n", " \"n_blocks\": 1,\n", " \"n_heads\": 1,\n", " \"n_factors\": 64, \n", "}\n", "\n", "model = SASRecModel.from_config(config)\n", "model.get_params(simple_types=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Transformer models (SASRec and BERT4Rec) in RecTools may accept functions and classes as arguments. These types of arguments are fully compatible with RecTools configs. 
You can either pass them as Python objects or as strings that define their import paths.\n", "\n", "Below is an example of both approaches:" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "GPU available: True (cuda), used: True\n", "TPU available: False, using: 0 TPU cores\n", "HPU available: False, using: 0 HPUs\n" ] }, { "data": { "text/plain": [ "{'cls': 'SASRecModel',\n", " 'verbose': 0,\n", " 'data_preparator_type': 'rectools.models.nn.transformers.sasrec.SASRecDataPreparator',\n", " 'n_blocks': 2,\n", " 'n_heads': 4,\n", " 'n_factors': 256,\n", " 'use_pos_emb': True,\n", " 'use_causal_attn': True,\n", " 'use_key_padding_mask': False,\n", " 'dropout_rate': 0.2,\n", " 'session_max_len': 100,\n", " 'dataloader_num_workers': 0,\n", " 'batch_size': 128,\n", " 'loss': 'softmax',\n", " 'n_negatives': 1,\n", " 'gbce_t': 0.2,\n", " 'lr': 0.001,\n", " 'epochs': 3,\n", " 'deterministic': False,\n", " 'recommend_batch_size': 256,\n", " 'recommend_torch_device': None,\n", " 'train_min_user_interactions': 2,\n", " 'item_net_block_types': ['rectools.models.nn.item_net.IdEmbeddingsItemNet',\n", " 'rectools.models.nn.item_net.CatFeaturesItemNet'],\n", " 'item_net_constructor_type': 'rectools.models.nn.item_net.SumOfEmbeddingsConstructor',\n", " 'pos_encoding_type': 'rectools.models.nn.transformers.net_blocks.LearnableInversePositionalEncoding',\n", " 'transformer_layers_type': 'rectools.models.nn.transformers.sasrec.SASRecTransformerLayers',\n", " 'lightning_module_type': 'rectools.models.nn.transformers.lightning.TransformerLightningModule',\n", " 'get_val_mask_func': '__main__.leave_one_out_mask',\n", " 'get_trainer_func': None,\n", " 'data_preparator_kwargs': None,\n", " 'transformer_layers_kwargs': None,\n", " 'item_net_constructor_kwargs': None,\n", " 'pos_encoding_kwargs': None,\n", " 'lightning_module_kwargs': None}" ] }, "execution_count": 13, "metadata": {}, "output_type": 
"execute_result" } ], "source": [ "def leave_one_out_mask(interactions: pd.DataFrame) -> pd.Series:\n", " rank = (\n", " interactions\n", " .sort_values(Columns.Datetime, ascending=False, kind=\"stable\")\n", " .groupby(Columns.User, sort=False)\n", " .cumcount()\n", " )\n", " return rank == 0\n", "\n", "config = {\n", " # function to get validation mask\n", " \"get_val_mask_func\": leave_one_out_mask,\n", " # path to transformer layers class\n", " \"transformer_layers_type\": \"rectools.models.nn.transformers.sasrec.SASRecTransformerLayers\",\n", "}\n", "\n", "model = SASRecModel.from_config(config)\n", "model.get_params(simple_types=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### BERT4Rec" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "GPU available: True (cuda), used: True\n", "TPU available: False, using: 0 TPU cores\n", "HPU available: False, using: 0 HPUs\n" ] }, { "data": { "text/plain": [ "{'cls': 'BERT4RecModel',\n", " 'verbose': 0,\n", " 'data_preparator_type': 'rectools.models.nn.transformers.bert4rec.BERT4RecDataPreparator',\n", " 'n_blocks': 1,\n", " 'n_heads': 1,\n", " 'n_factors': 64,\n", " 'use_pos_emb': True,\n", " 'use_causal_attn': False,\n", " 'use_key_padding_mask': True,\n", " 'dropout_rate': 0.2,\n", " 'session_max_len': 100,\n", " 'dataloader_num_workers': 0,\n", " 'batch_size': 128,\n", " 'loss': 'softmax',\n", " 'n_negatives': 1,\n", " 'gbce_t': 0.2,\n", " 'lr': 0.001,\n", " 'epochs': 2,\n", " 'deterministic': False,\n", " 'recommend_batch_size': 256,\n", " 'recommend_torch_device': None,\n", " 'train_min_user_interactions': 2,\n", " 'item_net_block_types': ['rectools.models.nn.item_net.IdEmbeddingsItemNet',\n", " 'rectools.models.nn.item_net.CatFeaturesItemNet'],\n", " 'item_net_constructor_type': 'rectools.models.nn.item_net.SumOfEmbeddingsConstructor',\n", " 'pos_encoding_type': 
'rectools.models.nn.transformers.net_blocks.LearnableInversePositionalEncoding',\n", " 'transformer_layers_type': 'rectools.models.nn.transformers.net_blocks.PreLNTransformerLayers',\n", " 'lightning_module_type': 'rectools.models.nn.transformers.lightning.TransformerLightningModule',\n", " 'get_val_mask_func': '__main__.leave_one_out_mask',\n", " 'get_trainer_func': None,\n", " 'data_preparator_kwargs': None,\n", " 'transformer_layers_kwargs': None,\n", " 'item_net_constructor_kwargs': None,\n", " 'pos_encoding_kwargs': None,\n", " 'lightning_module_kwargs': None,\n", " 'mask_prob': 0.2}" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "config = {\n", " \"epochs\": 2,\n", " \"n_blocks\": 1,\n", " \"n_heads\": 1,\n", " \"n_factors\": 64,\n", " \"mask_prob\": 0.2,\n", " \"get_val_mask_func\": leave_one_out_mask, # function to get validation mask\n", " # path to transformer layers class\n", " \"transformer_layers_type\": \"rectools.models.nn.transformers.base.PreLNTransformerLayers\", \n", "}\n", "\n", "model = BERT4RecModel.from_config(config)\n", "model.get_params(simple_types=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### ItemKNN\n", "`ImplicitItemKNNWrapperModel` is a wrapper. \n", "Use \"model\" key in config to specify wrapped model class and params:\n", "\n", "Specify which model you want to wrap under the \"model.cls\" key. Options are:\n", "- \"TFIDFRecommender\"\n", "- \"CosineRecommender\"\n", "- \"BM25Recommender\"\n", "- \"ItemItemRecommender\"\n", "- A path to a class (including any custom class) that can be imported. Like \"implicit.nearest_neighbours.TFIDFRecommender\"\n", "\n", "Specify wrapped model hyper-params under the \"model\" dict relevant keys." 
] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'cls': 'ImplicitItemKNNWrapperModel',\n", " 'verbose': 0,\n", " 'model.cls': 'TFIDFRecommender',\n", " 'model.K': 50,\n", " 'model.num_threads': 1}" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "config = {\n", " \"model\": {\n", " \"cls\": \"TFIDFRecommender\", # or \"implicit.nearest_neighbours.TFIDFRecommender\"\n", " \"K\": 50, \n", " \"num_threads\": 1\n", " } \n", "}\n", "\n", "model = ImplicitItemKNNWrapperModel.from_config(config)\n", "model.get_params(simple_types=True)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'cls': 'ImplicitItemKNNWrapperModel',\n", " 'verbose': 0,\n", " 'model.cls': 'TFIDFRecommender',\n", " 'model.K': 50,\n", " 'model.num_threads': 1}" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "params = { # flat form\n", " \"model.cls\": \"TFIDFRecommender\", \n", " \"model.K\": 50,\n", " \"model.num_threads\": 1,\n", "}\n", "model = ImplicitItemKNNWrapperModel.from_params(params)\n", "model.get_params(simple_types=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### iALS\n", "`ImplicitALSWrapperModel` is a wrapper. \n", "Use \"model\" key in config to specify wrapped model class and params: \n", "\n", "Specify which model you want to wrap under the \"model.cls\" key. Since there is only one default model, you can skip this step. \"implicit.als.AlternatingLeastSquares\" will be used by default. Also you can pass a path to a class (including any custom class) that can be imported.\n", "\n", "Specify wrapped model hyper-params under the \"model\" dict relevant keys. \n", "\n", "Specify wrapper hyper-params under relevant config keys." 
] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'cls': 'ImplicitALSWrapperModel',\n", " 'verbose': 0,\n", " 'model.cls': 'AlternatingLeastSquares',\n", " 'model.factors': 16,\n", " 'model.regularization': 0.01,\n", " 'model.alpha': 1.0,\n", " 'model.dtype': 'float32',\n", " 'model.use_gpu': True,\n", " 'model.iterations': 2,\n", " 'model.calculate_training_loss': False,\n", " 'model.random_state': 32,\n", " 'fit_features_together': True,\n", " 'recommend_n_threads': None,\n", " 'recommend_use_gpu_ranking': None}" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "config = {\n", " \"model\": {\n", " # \"cls\": \"AlternatingLeastSquares\", # will work too\n", " # \"cls\": \"implicit.als.AlternatingLeastSquares\", # will work too\n", " \"factors\": 16,\n", " \"num_threads\": 2,\n", " \"iterations\": 2,\n", " \"random_state\": 32\n", " },\n", " \"fit_features_together\": True,\n", "}\n", "\n", "model = ImplicitALSWrapperModel.from_config(config)\n", "model.get_params(simple_types=True)" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'cls': 'ImplicitALSWrapperModel',\n", " 'verbose': 0,\n", " 'model.cls': 'AlternatingLeastSquares',\n", " 'model.factors': 16,\n", " 'model.regularization': 0.01,\n", " 'model.alpha': 1.0,\n", " 'model.dtype': 'float32',\n", " 'model.use_gpu': True,\n", " 'model.iterations': 2,\n", " 'model.calculate_training_loss': False,\n", " 'model.random_state': 32,\n", " 'fit_features_together': False,\n", " 'recommend_n_threads': None,\n", " 'recommend_use_gpu_ranking': False}" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "params = { # flat form\n", " \"model.factors\": 16, \n", " \"model.iterations\": 2,\n", " \"model.random_state\": 32,\n", " \"recommend_use_gpu_ranking\": False,\n", "}\n", "model = 
ImplicitALSWrapperModel.from_params(params)\n", "model.get_params(simple_types=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### BPR-MF\n", "`ImplicitBPRWrapperModel` is a wrapper. \n", "Use \"model\" key in config to specify wrapped model class and params: \n", "\n", "Specify which model you want to wrap under the \"model.cls\" key. Since there is only one default model, you can skip this step. \"implicit.bpr.BayesianPersonalizedRanking\" will be used by default. Also you can pass a path to a class (including any custom class) that can be imported.\n", "\n", "Specify wrapped model hyper-params under the \"model\" dict relevant keys. \n", "\n", "Specify wrapper hyper-params under relevant config keys." ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'cls': 'ImplicitBPRWrapperModel',\n", " 'verbose': 0,\n", " 'model.cls': 'BayesianPersonalizedRanking',\n", " 'model.factors': 16,\n", " 'model.learning_rate': 0.01,\n", " 'model.regularization': 0.01,\n", " 'model.dtype': 'float64',\n", " 'model.iterations': 2,\n", " 'model.verify_negative_samples': True,\n", " 'model.random_state': 32,\n", " 'model.use_gpu': True,\n", " 'recommend_n_threads': None,\n", " 'recommend_use_gpu_ranking': False}" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "config = {\n", " \"model\": {\n", " # \"cls\": \"BayesianPersonalizedRanking\", # will work too\n", " # \"cls\": \"implicit.bpr.BayesianPersonalizedRanking\", # will work too\n", " \"factors\": 16,\n", " \"iterations\": 2,\n", " \"random_state\": 32\n", " },\n", " \"recommend_use_gpu_ranking\": False,\n", "}\n", "\n", "model = ImplicitBPRWrapperModel.from_config(config)\n", "model.get_params(simple_types=True)" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'cls': 'ImplicitBPRWrapperModel',\n", " 'verbose': 0,\n", " 'model.cls': 
'BayesianPersonalizedRanking',\n", " 'model.factors': 16,\n", " 'model.learning_rate': 0.01,\n", " 'model.regularization': 0.01,\n", " 'model.dtype': 'float64',\n", " 'model.iterations': 2,\n", " 'model.verify_negative_samples': True,\n", " 'model.random_state': 32,\n", " 'model.use_gpu': True,\n", " 'recommend_n_threads': None,\n", " 'recommend_use_gpu_ranking': False}" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "params = { # flat form\n", " \"model.factors\": 16, \n", " \"model.iterations\": 2,\n", " \"model.random_state\": 32,\n", " \"recommend_use_gpu_ranking\": False,\n", "}\n", "model = ImplicitBPRWrapperModel.from_params(params)\n", "model.get_params(simple_types=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### EASE" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'cls': 'EASEModel',\n", " 'verbose': 1,\n", " 'regularization': 100.0,\n", " 'recommend_n_threads': 0,\n", " 'recommend_use_gpu_ranking': True}" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "config = {\n", " \"regularization\": 100,\n", " \"verbose\": 1,\n", "}\n", "\n", "model = EASEModel.from_config(config)\n", "model.get_params(simple_types=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### PureSVD" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'cls': 'PureSVDModel',\n", " 'verbose': 0,\n", " 'factors': 32,\n", " 'tol': 0.0,\n", " 'maxiter': None,\n", " 'random_state': None,\n", " 'use_gpu': False,\n", " 'recommend_n_threads': 0,\n", " 'recommend_use_gpu_ranking': True}" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "config = {\n", " \"factors\": 32,\n", "}\n", "\n", "model = PureSVDModel.from_config(config)\n", "model.get_params(simple_types=True)" ] }, { "cell_type": "markdown", 
"metadata": {}, "source": [ "### LightFM\n", "\n", "`LightFMWrapperModel` is a wrapper. \n", "Use \"model\" key in config to specify wrapped model class and params: \n", "\n", "Specify which model you want to wrap under the \"model.cls\" key. Since there is only one default model, you can skip this step. \"LightFM\" will be used by default. Also you can pass a path to a class (including any custom class) that can be imported. Like \"lightfm.lightfm.LightFM\"\n", "\n", "Specify wrapped model hyper-params under the \"model\" dict relevant keys. \n", "\n", "Specify wrapper hyper-params under relevant config keys." ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'cls': 'LightFMWrapperModel',\n", " 'verbose': 0,\n", " 'model.cls': 'LightFM',\n", " 'model.no_components': 16,\n", " 'model.k': 5,\n", " 'model.n': 10,\n", " 'model.learning_schedule': 'adagrad',\n", " 'model.loss': 'warp',\n", " 'model.learning_rate': 0.03,\n", " 'model.rho': 0.95,\n", " 'model.epsilon': 1e-06,\n", " 'model.item_alpha': 0.0,\n", " 'model.user_alpha': 0.0,\n", " 'model.max_sampled': 10,\n", " 'model.random_state': 32,\n", " 'epochs': 2,\n", " 'num_threads': 1,\n", " 'recommend_n_threads': None,\n", " 'recommend_use_gpu_ranking': True}" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "config = {\n", " \"model\": {\n", " # \"cls\": \"lightfm.lightfm.LightFM\", # will work too \n", " # \"cls\": \"LightFM\", # will work too \n", " \"no_components\": 16,\n", " \"learning_rate\": 0.03,\n", " \"random_state\": 32,\n", " \"loss\": \"warp\"\n", " },\n", " \"epochs\": 2,\n", "}\n", "\n", "model = LightFMWrapperModel.from_config(config)\n", "model.get_params(simple_types=True)" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'cls': 'LightFMWrapperModel',\n", " 'verbose': 0,\n", " 'model.cls': 'LightFM',\n", " 'model.no_components': 
16,\n", " 'model.k': 5,\n", " 'model.n': 10,\n", " 'model.learning_schedule': 'adagrad',\n", " 'model.loss': 'warp',\n", " 'model.learning_rate': 0.03,\n", " 'model.rho': 0.95,\n", " 'model.epsilon': 1e-06,\n", " 'model.item_alpha': 0.0,\n", " 'model.user_alpha': 0.0,\n", " 'model.max_sampled': 10,\n", " 'model.random_state': 32,\n", " 'epochs': 2,\n", " 'num_threads': 1,\n", " 'recommend_n_threads': None,\n", " 'recommend_use_gpu_ranking': True}" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "params = { # flat form\n", " \"model.no_components\": 16, \n", " \"model.learning_rate\": 0.03,\n", " \"model.random_state\": 32,\n", " \"model.loss\": \"warp\",\n", " \"epochs\": 2,\n", "}\n", "\n", "model = LightFMWrapperModel.from_params(params)\n", "model.get_params(simple_types=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Popular" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'cls': 'PopularModel',\n", " 'verbose': 0,\n", " 'popularity': 'n_interactions',\n", " 'period.days': 14,\n", " 'begin_from': None,\n", " 'add_cold': False,\n", " 'inverse': False}" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from datetime import timedelta\n", "config = {\n", " \"popularity\": \"n_interactions\",\n", " \"period\": timedelta(weeks=2),\n", "}\n", "\n", "model = PopularModel.from_config(config)\n", "model.get_params(simple_types=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Popular in category" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'cls': 'PopularInCategoryModel',\n", " 'verbose': 0,\n", " 'popularity': 'n_interactions',\n", " 'period.days': 1,\n", " 'begin_from': None,\n", " 'add_cold': False,\n", " 'inverse': False,\n", " 'category_feature': 'genres',\n", " 'n_categories': None,\n", " 'mixing_strategy': 
'group',\n", " 'ratio_strategy': 'proportional'}" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "config = {\n", " \"popularity\": \"n_interactions\",\n", " \"period\": timedelta(days=1),\n", " \"category_feature\": \"genres\",\n", " \"mixing_strategy\": \"group\"\n", "}\n", "\n", "model = PopularInCategoryModel.from_config(config)\n", "model.get_params(simple_types=True)" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'cls': 'PopularInCategoryModel',\n", " 'verbose': 0,\n", " 'popularity': 'n_interactions',\n", " 'period.days': 1,\n", " 'begin_from': None,\n", " 'add_cold': False,\n", " 'inverse': False,\n", " 'category_feature': 'genres',\n", " 'n_categories': None,\n", " 'mixing_strategy': 'group',\n", " 'ratio_strategy': 'proportional'}" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "params = { # flat form\n", " \"popularity\": \"n_interactions\",\n", " \"period.days\": 1,\n", " \"category_feature\": \"genres\",\n", " \"mixing_strategy\": \"group\"\n", "}\n", "\n", "model = PopularInCategoryModel.from_params(params)\n", "model.get_params(simple_types=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Random" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'cls': 'RandomModel', 'verbose': 0, 'random_state': 32}" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "config = {\n", " \"random_state\": 32,\n", "}\n", "\n", "model = RandomModel.from_config(config)\n", "model.get_params(simple_types=True)" ] } ], "metadata": { "kernelspec": { "display_name": "rectools", "language": "python", "name": "rectools" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", 
"pygments_lexer": "ipython3", "version": "3.9.12" } }, "nbformat": 4, "nbformat_minor": 2 }