Transformer Models Advanced Training Guide
This guide shows advanced features of RecTools transformer model training.
Table of Contents
Prepare data
Advanced training guide
Validation fold
Validation loss
Callback for Early Stopping
Callbacks for Checkpoints
Loading Checkpoints
Callbacks for RecSys metrics
RecSys metrics for Early Stopping and Checkpoints
Advanced training full example
Running full training with all of the described validation features on Kion dataset
More RecTools features for transformers
Saving and loading models
Configs for transformer models
Classes and functions in configs
Multi-gpu training
[1]:
import os
import itertools
import typing as tp
import warnings
from collections import Counter
from pathlib import Path
import pandas as pd
import numpy as np
import torch
from lightning_fabric import seed_everything
from pytorch_lightning import Trainer, LightningModule
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, Callback
from rectools import Columns, ExternalIds
from rectools.dataset import Dataset
from rectools.metrics import NDCG, Recall, Serendipity, calc_metrics
from rectools.models import BERT4RecModel, SASRecModel, load_model
from rectools.models.nn.item_net import IdEmbeddingsItemNet
from rectools.models.nn.transformers.base import TransformerModelBase
# Enable deterministic behaviour with CUDA >= 10.2
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
warnings.simplefilter("ignore", UserWarning)
warnings.simplefilter("ignore", FutureWarning)
Prepare data
[2]:
# %%time
!wget -q https://github.com/irsafilo/KION_DATASET/raw/f69775be31fa5779907cf0a92ddedb70037fb5ae/data_en.zip -O data_en.zip
!unzip -o data_en.zip
!rm data_en.zip
[2]:
# Load dataset
DATA_PATH = Path("./data_en")
items = pd.read_csv(DATA_PATH / 'items_en.csv', index_col=0)
interactions = (
pd.read_csv(DATA_PATH / 'interactions.csv', parse_dates=["last_watch_dt"])
.rename(columns={"last_watch_dt": Columns.Datetime})
)
print(interactions.shape)
interactions.head(2)
(5476251, 5)
[2]:
| | user_id | item_id | datetime | total_dur | watched_pct |
|---|---|---|---|---|---|
| 0 | 176549 | 9506 | 2021-05-11 | 4250 | 72.0 |
| 1 | 699317 | 1659 | 2021-05-29 | 8317 | 100.0 |
[3]:
interactions[Columns.User].nunique(), interactions[Columns.Item].nunique()
[3]:
(962179, 15706)
[4]:
# Process interactions
interactions[Columns.Weight] = np.where(interactions['watched_pct'] > 10, 3, 1)
raw_interactions = interactions[["user_id", "item_id", "datetime", "weight"]]
print(raw_interactions.shape)
raw_interactions.head(2)
dataset = Dataset.construct(raw_interactions)
(5476251, 4)
[5]:
RANDOM_STATE=60
torch.use_deterministic_algorithms(True)
seed_everything(RANDOM_STATE, workers=True)
Seed set to 60
[5]:
60
Advanced Training
Validation fold
By default, models do not create a validation fold during fit. However, there is a simple way to enable it.
Let’s assume that we want to use Leave-One-Out validation for a specific set of users. To apply it, we need to implement get_val_mask_func with the required logic and pass it to the model during initialization.
This function should receive interactions with standard RecTools columns and return a binary mask that identifies the interactions to exclude from model training. These interactions are used for validation loss calculation instead. They are also available to Lightning Callbacks, which allows RecSys metrics computation.
Please make sure you do not use ``partial`` while doing this. Partial functions cannot be serialized by RecTools.
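One way to see why ``partial`` is a problem: the configs shown later in this guide store functions as dotted import paths such as `__main__.get_val_mask_func`. A module-level function carries the `__module__` and `__qualname__` needed to build such a path, while a `functools.partial` object does not. The sketch below only illustrates that difference. The names `mask_for_users` and `named_mask_func` are hypothetical, and the code does not reproduce RecTools serialization.

```python
from functools import partial


def mask_for_users(interactions, val_users):
    """Hypothetical helper: flag interactions that belong to validation users."""
    return [user in val_users for user, _item in interactions]


def named_mask_func(interactions):
    """Module-level wrapper with the validation users fixed inside the body."""
    return mask_for_users(interactions, val_users={1})


partial_mask_func = partial(mask_for_users, val_users={1})

# A plain function exposes the pieces of a dotted import path.
named_path = f"{named_mask_func.__module__}.{named_mask_func.__qualname__}"

# A partial object has no `__qualname__`, so no such path can be derived from it.
partial_has_qualname = hasattr(partial_mask_func, "__qualname__")

# Both callables compute the same mask; only the named one is addressable by path.
toy_interactions = [(1, 10), (2, 11), (1, 12)]
same_result = named_mask_func(toy_interactions) == partial_mask_func(toy_interactions)
print(named_path, partial_has_qualname, same_result)
```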
[6]:
# Implement `get_val_mask_func`
N_VAL_USERS = 2048
unique_users = raw_interactions[Columns.User].unique()
VAL_USERS = unique_users[: N_VAL_USERS]
def leave_one_out_mask_for_users(interactions: pd.DataFrame, val_users: ExternalIds) -> np.ndarray:
    rank = (
        interactions
        .sort_values(Columns.Datetime, ascending=False, kind="stable")
        .groupby(Columns.User, sort=False)
        .cumcount()
    )
    val_mask = (
        (interactions[Columns.User].isin(val_users))
        & (rank == 0)
    )
    return val_mask.values

# We do not use `partial` for correct serialization of the model
def get_val_mask_func(interactions: pd.DataFrame):
    return leave_one_out_mask_for_users(interactions, val_users=VAL_USERS)
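To see what such a mask looks like, here is a minimal sketch that applies the same logic to a made-up six-row frame. It uses plain column names instead of `Columns.*` so that it runs without RecTools, and `toy`, `toy_val_users` and `toy_leave_one_out_mask` are hypothetical names used only for illustration.

```python
import pandas as pd

# Made-up interactions: user 1 has three rows, user 2 has two, user 3 has one
toy = pd.DataFrame(
    {
        "user_id": [1, 1, 1, 2, 2, 3],
        "item_id": [10, 11, 12, 10, 13, 14],
        "datetime": pd.to_datetime(
            ["2021-05-01", "2021-05-03", "2021-05-02", "2021-05-01", "2021-05-04", "2021-05-05"]
        ),
    }
)
toy_val_users = [1, 3]  # user 2 is not part of validation


def toy_leave_one_out_mask(interactions, val_users):
    # Rank 0 marks the most recent interaction of every user
    rank = (
        interactions.sort_values("datetime", ascending=False, kind="stable")
        .groupby("user_id", sort=False)
        .cumcount()
    )
    # Keep only the latest interaction of validation users
    return (interactions["user_id"].isin(val_users) & (rank == 0)).to_numpy()


toy_mask = toy_leave_one_out_mask(toy, toy_val_users)
print(toy[toy_mask])
```

Only the latest interaction of users 1 and 3 is marked for validation. All rows of user 2 stay in the training data.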
In this guide we are going to use custom Lightning trainers. We need to implement a function that returns the desired Lightning trainer and pass it to the model during initialization.
[7]:
# Function to get custom trainer
def get_debug_trainer() -> Trainer:
    return Trainer(
        accelerator="gpu",
        devices=1,
        min_epochs=2,
        max_epochs=2,
        deterministic=True,
        enable_model_summary=False,
        enable_progress_bar=False,
        enable_checkpointing=False,
        limit_train_batches=2,  # limit train batches for quick debug runs
        logger=CSVLogger("test_logs"),  # We use CSV logging for this guide but there are many other options
    )
[8]:
model = SASRecModel(
n_factors=64,
n_blocks=2,
n_heads=2,
dropout_rate=0.2,
train_min_user_interactions=5,
session_max_len=50,
verbose=0,
deterministic=True,
item_net_block_types=(IdEmbeddingsItemNet,),
get_val_mask_func=get_val_mask_func, # pass our custom `get_val_mask_func`
get_trainer_func=get_debug_trainer, # pass our custom trainer func
)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Validation loss
Let’s check how the validation loss is being logged.
[33]:
# Fit model. Validation fold and validation loss computation will be done under the hood.
model.fit(dataset);
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
`Trainer.fit` stopped: `max_epochs=2` reached.
Let’s look at the model logs. We can access the logs directory with model.fit_trainer.log_dir.
[34]:
# What's inside the logs directory?
!ls $model.fit_trainer.log_dir
hparams.yaml metrics.csv
[35]:
# Losses and metrics are in the `metrics.csv`
# Let's look at logs
!tail $model.fit_trainer.log_dir/metrics.csv
epoch,step,train_loss,val_loss
0,1,,22.365339279174805
0,1,22.38391876220703,
1,3,,22.189851760864258
1,3,22.898216247558594,
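As the output shows, the validation loss and the train loss of one epoch land on separate rows, each with the other column left empty. A minimal standard-library sketch of collapsing these rows per epoch follows. `losses_by_epoch` is a hypothetical helper, not part of RecTools.

```python
import csv
import io

# The same rows as in the `tail` output above
raw_log = """epoch,step,train_loss,val_loss
0,1,,22.365339279174805
0,1,22.38391876220703,
1,3,,22.189851760864258
1,3,22.898216247558594,
"""


def losses_by_epoch(csv_text):
    """Collapse the half-empty rows into one {loss_name: value} dict per epoch."""
    merged = {}
    for row in csv.DictReader(io.StringIO(csv_text)):
        epoch_losses = merged.setdefault(int(row["epoch"]), {})
        for name in ("train_loss", "val_loss"):
            if row[name]:  # an empty string means the value was not logged in this row
                epoch_losses[name] = float(row[name])
    return merged


epoch_losses = losses_by_epoch(raw_log)
for epoch, losses in sorted(epoch_losses.items()):
    print(epoch, losses)
```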
Callback for Early Stopping
By default, RecTools transformers train for an exact number of epochs (specified in the epochs argument). When get_trainer_func is provided, the number of training epochs depends on the Lightning trainer arguments instead.
Now that we have validation loss logged, let’s use it for model Early Stopping. It ensures that the model stops training once validation loss (or any other custom metric) no longer improves. Lightning Callbacks handle this.
[36]:
early_stopping_callback = EarlyStopping(
monitor=SASRecModel.val_loss_name, # or just pass "val_loss" here
mode="min",
min_delta=1. # just for a quick test of functionality
)
trainer = Trainer(
accelerator='gpu',
devices=1,
min_epochs=1, # minimum number of epochs to train before early stopping
max_epochs=20, # maximum number of epochs to train
deterministic=True,
limit_train_batches=2, # use only 2 batches for each epoch for a test run
enable_checkpointing=False,
logger = CSVLogger("test_logs"),
callbacks=early_stopping_callback, # pass our callback
enable_progress_bar=False,
enable_model_summary=False,
)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
We replace the model’s trainer with our custom one through the _trainer attribute.
[38]:
# Replace trainer with our custom one
model._trainer = trainer
# Fit model. Everything will happen under the hood
model.fit(dataset);
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
Here the model stopped training after 4 epochs because validation loss wasn’t improving by our specified min_delta.
[39]:
# Let's check out logs
!tail $model.fit_trainer.log_dir/metrics.csv
epoch,step,train_loss,val_loss
0,1,,22.343637466430664
0,1,22.36273765563965,
1,3,,22.159835815429688
1,3,22.33755874633789,
2,5,,21.94308853149414
2,5,22.244243621826172,
3,7,,21.702259063720703
3,7,22.196012496948242,
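The stopping rule behind this run can be sketched in plain Python. This is a simplified illustration, not Lightning's implementation: it ignores `min_epochs` and assumes a patience of 3. The callback above does not set `patience`, so the value 3 is an assumption about the Lightning default. `epochs_until_stop` is a hypothetical helper.

```python
import math


def epochs_until_stop(val_losses, min_delta, patience):
    """Return how many epochs run before a min-mode early stopping rule fires."""
    best = math.inf
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best - min_delta:
            # Improvement by more than `min_delta`: remember it and reset the counter
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return len(val_losses)


# Validation losses from the log above, plus one made-up value that is never reached
logged_val_losses = [22.343637, 22.159836, 21.943089, 21.702259, 21.5]
stopped_after = epochs_until_stop(logged_val_losses, min_delta=1.0, patience=3)
print(stopped_after)
```

The first epoch always counts as an improvement. The next three epochs each lower the loss by far less than `min_delta=1.0`, so the counter reaches the assumed patience after the fourth epoch.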
Callback for Checkpoints
Checkpoints are model states that are saved periodically during training.
[40]:
# Checkpoint last epoch
last_epoch_ckpt = ModelCheckpoint(filename="last_epoch")
# Checkpoints based on validation loss
least_val_loss_ckpt = ModelCheckpoint(
monitor=SASRecModel.val_loss_name, # or just pass "val_loss" here,
mode="min",
filename="{epoch}-{val_loss:.2f}",
save_top_k=2, # Let's save top 2 checkpoints for validation loss
)
[41]:
trainer = Trainer(
accelerator="gpu",
devices=1,
min_epochs=1,
max_epochs=6,
deterministic=True,
limit_train_batches=2, # use only 2 batches for each epoch for a test run
logger = CSVLogger("test_logs"),
callbacks=[last_epoch_ckpt, least_val_loss_ckpt], # pass our callbacks for checkpoints
enable_progress_bar=False,
enable_model_summary=False,
)
# Replace trainer with our custom one
model._trainer = trainer
# Fit model. Everything will happen under the hood
model.fit(dataset);
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
`Trainer.fit` stopped: `max_epochs=6` reached.
Let’s look at the model checkpoints that were saved. By default they are saved to the checkpoints directory in model.fit_trainer.log_dir.
[42]:
# We have 2 checkpoints for 2 best validation loss values and one for last epoch
!ls $model.fit_trainer.log_dir/checkpoints
epoch=4-val_loss=21.52.ckpt epoch=5-val_loss=21.24.ckpt last_epoch.ckpt
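The effect of `save_top_k=2` with `mode="min"` amounts to keeping the two epochs with the lowest monitored value. A minimal sketch of that selection follows. The losses for epochs 4 and 5 come from the file names above, the other values are made up for illustration, and `top_k_epochs` is a hypothetical helper.

```python
import heapq

# Epochs 4 and 5 match the checkpoint names above; epochs 0-3 are made-up values
val_loss_by_epoch = {0: 22.34, 1: 22.16, 2: 21.94, 3: 21.70, 4: 21.52, 5: 21.24}


def top_k_epochs(loss_by_epoch, k):
    """Return the k epochs with the lowest loss, best first."""
    return heapq.nsmallest(k, loss_by_epoch, key=loss_by_epoch.get)


kept = top_k_epochs(val_loss_by_epoch, k=2)
print(kept)
```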
Loading checkpoints is very simple with the load_weights_from_checkpoint method.
[44]:
ckpt_path = os.path.join(model.fit_trainer.log_dir, "checkpoints", "last_epoch.ckpt")
model.load_weights_from_checkpoint(ckpt_path)
model.recommend(users=VAL_USERS[:1], dataset=dataset, filter_viewed=True, k=5)
[44]:
| | user_id | item_id | score | rank |
|---|---|---|---|---|
| 0 | 176549 | 15297 | 0.675964 | 1 |
| 1 | 176549 | 2657 | 0.661444 | 2 |
| 2 | 176549 | 10440 | 0.562942 | 3 |
| 3 | 176549 | 4495 | 0.557208 | 4 |
| 4 | 176549 | 6443 | 0.546108 | 5 |
You can also load both the model and its weights from a checkpoint using the load_from_checkpoint class method. Note that there is an important limitation: the loaded model will not have ``fit_trainer`` and can’t be saved again. But it is fully ready for recommendations.
[45]:
ckpt_path = os.path.join(model.fit_trainer.log_dir, "checkpoints", "last_epoch.ckpt")
loaded = SASRecModel.load_from_checkpoint(ckpt_path)
loaded.recommend(users=VAL_USERS[:1], dataset=dataset, filter_viewed=True, k=5)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
[45]:
| | user_id | item_id | score | rank |
|---|---|---|---|---|
| 0 | 176549 | 15297 | 0.675964 | 1 |
| 1 | 176549 | 2657 | 0.661444 | 2 |
| 2 | 176549 | 10440 | 0.562942 | 3 |
| 3 | 176549 | 4495 | 0.557208 | 4 |
| 4 | 176549 | 6443 | 0.546108 | 5 |
Callbacks for RecSys metrics during training
Monitoring RecSys metrics (or any other custom values) on the validation fold is not available out of the box, but we can create a custom Lightning Callback for that.
Below is an example of calculating standard RecTools metrics on the validation fold during training. We use it as an explicit example that any customization is possible. However, it is recommended to implement metrics calculation in torch for faster computation.
Please look at the PyTorch Lightning documentation for more details on custom callbacks.
[9]:
# Implement custom Callback for RecTools metrics computation within validation epochs during training.
class ValidationMetrics(Callback):

    def __init__(self, top_k: int, val_metrics: tp.Dict, verbose: int = 0) -> None:
        self.top_k = top_k
        self.val_metrics = val_metrics
        self.verbose = verbose
        self.epoch_n_users: int = 0
        self.batch_metrics: tp.List[tp.Dict[str, float]] = []

    def on_validation_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        outputs: tp.Dict[str, torch.Tensor],
        batch: tp.Dict[str, torch.Tensor],
        batch_idx: int,
        dataloader_idx: int = 0
    ) -> None:
        logits = outputs["logits"]
        if logits is None:
            logits = pl_module.torch_model.encode_sessions(batch, pl_module.item_embs)[:, -1, :]
        _, sorted_batch_recos = logits.topk(k=self.top_k)
        batch_recos = sorted_batch_recos.tolist()
        targets = batch["y"].tolist()
        batch_val_users = list(
            itertools.chain.from_iterable(
                itertools.repeat(idx, len(recos)) for idx, recos in enumerate(batch_recos)
            )
        )
        batch_target_users = list(
            itertools.chain.from_iterable(
                itertools.repeat(idx, len(user_targets)) for idx, user_targets in enumerate(targets)
            )
        )
        batch_recos_df = pd.DataFrame(
            {
                Columns.User: batch_val_users,
                Columns.Item: list(itertools.chain.from_iterable(batch_recos)),
            }
        )
        batch_recos_df[Columns.Rank] = batch_recos_df.groupby(Columns.User, sort=False).cumcount() + 1
        interactions = pd.DataFrame(
            {
                Columns.User: batch_target_users,
                Columns.Item: list(itertools.chain.from_iterable(targets)),
            }
        )
        prev_interactions = pl_module.data_preparator.train_dataset.interactions.df
        catalog = prev_interactions[Columns.Item].unique()
        batch_metrics = calc_metrics(
            self.val_metrics,
            batch_recos_df,
            interactions,
            prev_interactions,
            catalog
        )
        batch_n_users = batch["x"].shape[0]
        self.batch_metrics.append({metric: value * batch_n_users for metric, value in batch_metrics.items()})
        self.epoch_n_users += batch_n_users

    def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        epoch_metrics = dict(sum(map(Counter, self.batch_metrics), Counter()))
        epoch_metrics = {metric: value / self.epoch_n_users for metric, value in epoch_metrics.items()}
        self.log_dict(epoch_metrics, on_step=False, on_epoch=True, prog_bar=self.verbose > 0)
        self.batch_metrics.clear()
        self.epoch_n_users = 0
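The epoch-level aggregation in this callback is a user-weighted average: each batch value is multiplied by the number of users in the batch, and the sum is divided by the total number of users. A minimal plain-Python sketch of that step, with made-up batch values and a hypothetical helper name `aggregate_epoch_metrics`:

```python
from collections import defaultdict


def aggregate_epoch_metrics(batch_results):
    """Average per-batch metric values, weighting every batch by its number of users.

    `batch_results` is a list of (n_users, {metric_name: value}) pairs.
    """
    weighted_sums = defaultdict(float)
    total_users = 0
    for n_users, metrics in batch_results:
        total_users += n_users
        for name, value in metrics.items():
            weighted_sums[name] += value * n_users
    return {name: total / total_users for name, total in weighted_sums.items()}


# Made-up per-batch values: the last batch is smaller than the others
toy_batches = [
    (128, {"Recall@10": 0.25, "NDCG@10": 0.0}),
    (128, {"Recall@10": 0.75, "NDCG@10": 0.0}),
    (64, {"Recall@10": 0.5, "NDCG@10": 0.0}),
]
epoch_values = aggregate_epoch_metrics(toy_batches)
print(epoch_values)
```

One difference from the callback is worth knowing: `Counter` addition drops entries that are not positive, so a metric that equals zero in every batch would be missing from the callback's epoch metrics, while this sketch keeps it.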
Once a custom metrics callback is implemented, we can use the values of these metrics for both Early Stopping and Checkpoints.
[10]:
# Initialize callbacks for metrics calculation and checkpoint based on NDCG value
metrics = {
"NDCG@10": NDCG(k=10),
"Recall@10": Recall(k=10),
"Serendipity@10": Serendipity(k=10),
}
top_k = max([metric.k for metric in metrics.values()])
# Callback for calculating RecSys metrics
val_metrics_callback = ValidationMetrics(top_k=top_k, val_metrics=metrics, verbose=0)
# Callback for checkpoint based on maximization of NDCG@10
best_ndcg_ckpt = ModelCheckpoint(
monitor="NDCG@10",
mode="max",
filename="{epoch}-{NDCG@10:.2f}",
)
[11]:
trainer = Trainer(
accelerator="gpu",
devices=1,
min_epochs=1,
max_epochs=6,
deterministic=True,
limit_train_batches=2, # use only 2 batches for each epoch for a test run
logger = CSVLogger("test_logs"),
callbacks=[val_metrics_callback, best_ndcg_ckpt], # pass our callbacks
enable_progress_bar=False,
enable_model_summary=False,
)
# Replace trainer with our custom one
model._trainer = trainer
# Fit model. Everything will happen under the hood
model.fit(dataset)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
`Trainer.fit` stopped: `max_epochs=6` reached.
[11]:
<rectools.models.nn.transformers.sasrec.SASRecModel at 0x7f27b1e9a430>
We have checkpoint for best NDCG@10 model in the usual directory for checkpoints
[12]:
!ls $model.fit_trainer.log_dir/checkpoints
epoch=5-NDCG@10=0.01.ckpt
We also now have metrics in our logs. Let’s load them
[13]:
def get_logs(model: TransformerModelBase) -> tp.Tuple[pd.DataFrame, ...]:
    log_path = Path(model.fit_trainer.log_dir) / "metrics.csv"
    epoch_metrics_df = pd.read_csv(log_path)
    loss_df = epoch_metrics_df[["epoch", "train_loss"]].dropna()
    val_loss_df = epoch_metrics_df[["epoch", "val_loss"]].dropna()
    loss_df = pd.merge(loss_df, val_loss_df, how="inner", on="epoch")
    loss_df.reset_index(drop=True, inplace=True)
    metrics_df = epoch_metrics_df.drop(columns=["train_loss", "val_loss"]).dropna()
    metrics_df.reset_index(drop=True, inplace=True)
    return loss_df, metrics_df

loss_df, metrics_df = get_logs(model)
loss_df.head()
[13]:
| | epoch | train_loss | val_loss |
|---|---|---|---|
| 0 | 0 | 22.383919 | 22.365339 |
| 1 | 1 | 22.898216 | 22.189852 |
| 2 | 2 | 22.218102 | 21.964468 |
| 3 | 3 | 22.875019 | 21.701391 |
| 4 | 4 | 21.739164 | 21.417864 |
[14]:
metrics_df.head()
[14]:
| | NDCG@10 | Recall@10 | Serendipity@10 | epoch | step |
|---|---|---|---|---|---|
| 0 | 0.000052 | 0.000657 | 0.000004 | 0 | 1 |
| 1 | 0.002204 | 0.024984 | 0.000006 | 1 | 3 |
| 2 | 0.006865 | 0.071006 | 0.000004 | 2 | 5 |
| 3 | 0.009856 | 0.097304 | 0.000003 | 3 | 7 |
| 4 | 0.010442 | 0.107824 | 0.000002 | 4 | 9 |
[15]:
del model
torch.cuda.empty_cache()
Advanced training full example
Running full training with all of the described validation features on Kion dataset
[17]:
# seed again for reproducibility of this piece of code
seed_everything(RANDOM_STATE, workers=True)
# Callbacks
val_metrics_callback = ValidationMetrics(top_k=top_k, val_metrics=metrics, verbose=0)
best_ndcg_ckpt = ModelCheckpoint(
monitor="NDCG@10",
mode="max",
filename="{epoch}-{NDCG@10:.2f}",
)
last_epoch_ckpt = ModelCheckpoint(filename="{epoch}-last_epoch")
early_stopping_callback = EarlyStopping(
monitor="NDCG@10",
patience=5,
mode="max",
)
# Function to get custom trainer with desired callbacks
def get_custom_trainer() -> Trainer:
    return Trainer(
        accelerator="gpu",
        devices=[1],
        min_epochs=1,
        max_epochs=100,
        deterministic=True,
        logger=CSVLogger("sasrec_logs"),
        enable_progress_bar=False,
        enable_model_summary=False,
        callbacks=[
            val_metrics_callback,  # calculate RecSys metrics
            best_ndcg_ckpt,  # save best NDCG model checkpoint
            last_epoch_ckpt,  # save model checkpoint after last epoch
            early_stopping_callback,  # early stopping on NDCG
        ],
    )

# Model
model = SASRecModel(
n_factors=256,
n_blocks=2,
n_heads=4,
dropout_rate=0.2,
train_min_user_interactions=5,
session_max_len=50,
verbose=1,
deterministic=True,
item_net_block_types=(IdEmbeddingsItemNet,),
get_val_mask_func=get_val_mask_func, # pass our custom `get_val_mask_func`
get_trainer_func=get_custom_trainer, # pass function to initialize our custom trainer
)
# Fit model. Everything will happen under the hood
model.fit(dataset);
Seed set to 60
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
Early stopping was triggered. We have checkpoints for the best NDCG model (epoch 14) and for the last epoch (epoch 19).
[28]:
!ls $model.fit_trainer.log_dir/checkpoints
epoch=14-NDCG@10=0.03.ckpt epoch=19-last_epoch.ckpt
Loading best NDCG model from checkpoint and recommending
[29]:
ckpt_path = os.path.join(model.fit_trainer.log_dir, "checkpoints", "epoch=14-NDCG@10=0.03.ckpt")
best_model = SASRecModel.load_from_checkpoint(ckpt_path)
best_model.recommend(users=VAL_USERS[:1], dataset=dataset, filter_viewed=True, k=5)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
[29]:
| | user_id | item_id | score | rank |
|---|---|---|---|---|
| 0 | 176549 | 11749 | 2.610277 | 1 |
| 1 | 176549 | 2025 | 2.577398 | 2 |
| 2 | 176549 | 9342 | 2.394489 | 3 |
| 3 | 176549 | 14488 | 2.366664 | 4 |
| 4 | 176549 | 7571 | 2.289778 | 5 |
Let’s also look at our logs for losses and metrics
[30]:
loss_df, metrics_df = get_logs(model)
pd.concat([metrics_df.head(5), metrics_df.tail(5)])
[30]:
| | NDCG@10 | Recall@10 | Serendipity@10 | epoch | step |
|---|---|---|---|---|---|
| 0 | 0.023663 | 0.183432 | 0.000067 | 0 | 2362 |
| 1 | 0.027919 | 0.209730 | 0.000122 | 1 | 4725 |
| 2 | 0.029360 | 0.216305 | 0.000166 | 2 | 7088 |
| 3 | 0.030170 | 0.226824 | 0.000203 | 3 | 9451 |
| 4 | 0.030412 | 0.225510 | 0.000161 | 4 | 11814 |
| 15 | 0.031640 | 0.226167 | 0.000186 | 15 | 37807 |
| 16 | 0.031333 | 0.230769 | 0.000203 | 16 | 40170 |
| 17 | 0.031238 | 0.228139 | 0.000184 | 17 | 42533 |
| 18 | 0.031893 | 0.232084 | 0.000195 | 18 | 44896 |
| 19 | 0.031560 | 0.230112 | 0.000179 | 19 | 47259 |
Don’t be surprised that validation loss is lower than train loss in the plot below:
- First, this is data-specific; you may not see it on other datasets.
- Second, validation loss is calculated after the full training epoch, while train loss is computed for each batch during training, when the model still hasn’t seen the other batches and hasn’t finished updating its weights.
- Third, validation loss is calculated only on the last item in each validation user’s history, while train loss for SASRec is calculated for each item in the user history except the first one and the validation one.
[31]:
loss_df.plot(kind="line", x="epoch", title="Losses");
[32]:
metrics_df[["epoch", "NDCG@10"]].plot(kind="line", x="epoch", title="NDCG");
More RecTools features for transformers
Saving and loading models
Transformer models can be saved and loaded just like any other RecTools models.
Note that you can’t use these common functions for saving and loading Lightning checkpoints. Use the ``load_from_checkpoint`` method instead.
Note that you shouldn’t change the code of custom functions and classes that were passed to the model during initialization if you want model saving and loading to work correctly.
[33]:
model.save("my_model.pkl")
[33]:
54579980
[34]:
loaded = load_model("my_model.pkl")
print(type(loaded))
loaded.recommend(users=VAL_USERS[:1], dataset=dataset, filter_viewed=True, k=5)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
<class 'rectools.models.nn.sasrec.SASRecModel'>
[34]:
| | user_id | item_id | score | rank |
|---|---|---|---|---|
| 0 | 176549 | 2599 | 2.681841 | 1 |
| 1 | 176549 | 12225 | 2.516873 | 2 |
| 2 | 176549 | 2025 | 2.416028 | 3 |
| 3 | 176549 | 11749 | 2.410308 | 4 |
| 4 | 176549 | 14120 | 2.356824 | 5 |
Configs for transformer models
from_config, from_params, get_config and get_params methods are fully available for transformers just like for any other models.
[2]:
config = {
"epochs": 2,
"n_blocks": 1,
"n_heads": 1,
"n_factors": 64,
}
model = SASRecModel.from_config(config)
model.get_params(simple_types=True)
[2]:
{'cls': 'SASRecModel',
'verbose': 0,
'data_preparator_type': 'rectools.models.nn.transformers.sasrec.SASRecDataPreparator',
'n_blocks': 1,
'n_heads': 1,
'n_factors': 64,
'use_pos_emb': True,
'use_causal_attn': True,
'use_key_padding_mask': False,
'dropout_rate': 0.2,
'session_max_len': 100,
'dataloader_num_workers': 0,
'batch_size': 128,
'loss': 'softmax',
'n_negatives': 1,
'gbce_t': 0.2,
'lr': 0.001,
'epochs': 2,
'deterministic': False,
'recommend_batch_size': 256,
'recommend_device': None,
'train_min_user_interactions': 2,
'item_net_block_types': ['rectools.models.nn.item_net.IdEmbeddingsItemNet',
'rectools.models.nn.item_net.CatFeaturesItemNet'],
'item_net_constructor_type': 'rectools.models.nn.item_net.SumOfEmbeddingsConstructor',
'pos_encoding_type': 'rectools.models.nn.transformers.net_blocks.LearnableInversePositionalEncoding',
'transformer_layers_type': 'rectools.models.nn.transformers.sasrec.SASRecTransformerLayers',
'lightning_module_type': 'rectools.models.nn.transformers.lightning.TransformerLightningModule',
'get_val_mask_func': None,
'get_trainer_func': None,
'data_preparator_kwargs': None,
'transformer_layers_kwargs': None,
'item_net_constructor_kwargs': None,
'pos_encoding_kwargs': None,
'lightning_module_kwargs': None}
Transformer models in RecTools may accept functions and classes as arguments. These types of arguments are fully compatible with RecTools configs. You can either pass them as Python objects or as strings that define their import paths.
Note that you shouldn’t change the code of those functions and classes if you want a reproducible config and correct model saving and loading.
Below is an example of both approaches to pass them to configs:
[3]:
config = {
"get_val_mask_func": get_val_mask_func, # function to get validation mask
"get_trainer_func": get_custom_trainer, # function to get custom trainer
# path to transformer layers class:
"transformer_layers_type": "rectools.models.nn.transformers.sasrec.SASRecTransformerLayers",
}
model = SASRecModel.from_config(config)
model.get_params(simple_types=True)
[3]:
{'cls': 'SASRecModel',
'verbose': 0,
'data_preparator_type': 'rectools.models.nn.transformers.sasrec.SASRecDataPreparator',
'n_blocks': 2,
'n_heads': 4,
'n_factors': 256,
'use_pos_emb': True,
'use_causal_attn': True,
'use_key_padding_mask': False,
'dropout_rate': 0.2,
'session_max_len': 100,
'dataloader_num_workers': 0,
'batch_size': 128,
'loss': 'softmax',
'n_negatives': 1,
'gbce_t': 0.2,
'lr': 0.001,
'epochs': 3,
'deterministic': False,
'recommend_batch_size': 256,
'recommend_device': None,
'train_min_user_interactions': 2,
'item_net_block_types': ['rectools.models.nn.item_net.IdEmbeddingsItemNet',
'rectools.models.nn.item_net.CatFeaturesItemNet'],
'item_net_constructor_type': 'rectools.models.nn.item_net.SumOfEmbeddingsConstructor',
'pos_encoding_type': 'rectools.models.nn.transformers.net_blocks.LearnableInversePositionalEncoding',
'transformer_layers_type': 'rectools.models.nn.transformers.sasrec.SASRecTransformerLayers',
'lightning_module_type': 'rectools.models.nn.transformers.lightning.TransformerLightningModule',
'get_val_mask_func': '__main__.get_val_mask_func',
'get_trainer_func': '__main__.get_custom_trainer',
'data_preparator_kwargs': None,
'transformer_layers_kwargs': None,
'item_net_constructor_kwargs': None,
'pos_encoding_kwargs': None,
'lightning_module_kwargs': None}
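In the output above, the two functions appear as dotted import paths (`__main__.get_val_mask_func`, `__main__.get_custom_trainer`). A minimal sketch of how such a path can be turned back into an object follows, assuming plain `importlib`. `import_object` is a hypothetical helper, and this is not necessarily how RecTools resolves paths internally.

```python
import importlib


def import_object(path):
    """Resolve a dotted path such as 'package.module.name' to the object it names."""
    module_path, _, attr_name = path.rpartition(".")
    module = importlib.import_module(module_path)
    return getattr(module, attr_name)


# Standard-library target, used here only so the sketch runs anywhere
counter_cls = import_object("collections.Counter")
print(counter_cls("abca").most_common(1))
```

A path is resolved at load time, so it yields whatever code lives at that location at that moment. This is why editing a function after a config was created affects reproducibility.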
Note that if you didn’t pass a custom get_trainer_func, you can still replace the default trainer after model initialization. But in this case the custom trainer will not be saved with the model and will not appear in the model config and params.
[37]:
model._trainer = trainer
Multi-gpu training
RecTools models use PyTorch Lightning to handle multi-gpu training. Please refer to Lightning documentation for details. We do not cover it in this guide.