HSTUModel
- class rectools.models.nn.transformers.hstu.HSTUModel(n_blocks: int = 2, n_heads: int = 4, n_factors: int = 256, dropout_rate: float = 0.2, session_max_len: int = 100, train_min_user_interactions: int = 2, loss: str = 'softmax', n_negatives: int = 1, gbce_t: float = 0.2, lr: float = 0.001, batch_size: int = 128, epochs: int = 3, deterministic: bool = False, verbose: int = 0, dataloader_num_workers: int = 0, use_pos_emb: bool = True, use_key_padding_mask: bool = False, use_causal_attn: bool = True, relative_time_attention: bool = True, relative_pos_attention: bool = True, item_net_block_types: ~typing.Sequence[~typing.Type[~rectools.models.nn.item_net.ItemNetBase]] = (<class 'rectools.models.nn.item_net.IdEmbeddingsItemNet'>, <class 'rectools.models.nn.item_net.CatFeaturesItemNet'>), item_net_constructor_type: ~typing.Type[~rectools.models.nn.item_net.ItemNetConstructorBase] = <class 'rectools.models.nn.item_net.SumOfEmbeddingsConstructor'>, pos_encoding_type: ~typing.Type[~rectools.models.nn.transformers.net_blocks.PositionalEncodingBase] = <class 'rectools.models.nn.transformers.net_blocks.LearnableInversePositionalEncoding'>, transformer_layers_type: ~typing.Type[~rectools.models.nn.transformers.net_blocks.TransformerLayersBase] = <class 'rectools.models.nn.transformers.hstu.STULayers'>, data_preparator_type: ~typing.Type[~rectools.models.nn.transformers.data_preparator.TransformerDataPreparatorBase] = <class 'rectools.models.nn.transformers.sasrec.SASRecDataPreparator'>, lightning_module_type: ~typing.Type[~rectools.models.nn.transformers.lightning.TransformerLightningModuleBase] = <class 'rectools.models.nn.transformers.lightning.TransformerLightningModule'>, negative_sampler_type: ~typing.Type[~rectools.models.nn.transformers.negative_sampler.TransformerNegativeSamplerBase] = <class 'rectools.models.nn.transformers.negative_sampler.CatalogUniformSampler'>, similarity_module_type: ~typing.Type[~rectools.models.nn.transformers.similarity.SimilarityModuleBase] = 
<class 'rectools.models.nn.transformers.similarity.DistanceSimilarityModule'>, backbone_type: ~typing.Type[~rectools.models.nn.transformers.torch_backbone.TransformerBackboneBase] = <class 'rectools.models.nn.transformers.torch_backbone.TransformerTorchBackbone'>, get_val_mask_func: ~typing.Optional[~collections.abc.Callable[[...], ~numpy.ndarray]] = None, get_trainer_func: ~typing.Optional[~collections.abc.Callable[[...], ~pytorch_lightning.trainer.trainer.Trainer]] = None, get_val_mask_func_kwargs: ~typing.Optional[~typing.Dict[str, ~typing.Any]] = None, get_trainer_func_kwargs: ~typing.Optional[~typing.Dict[str, ~typing.Any]] = None, recommend_batch_size: int = 256, recommend_torch_device: ~typing.Optional[str] = None, recommend_use_torch_ranking: bool = True, recommend_n_threads: int = 0, data_preparator_kwargs: ~typing.Optional[~typing.Dict[str, ~typing.Any]] = None, transformer_layers_kwargs: ~typing.Optional[~typing.Dict[str, ~typing.Any]] = None, item_net_constructor_kwargs: ~typing.Optional[~typing.Dict[str, ~typing.Any]] = None, pos_encoding_kwargs: ~typing.Optional[~typing.Dict[str, ~typing.Any]] = None, lightning_module_kwargs: ~typing.Optional[~typing.Dict[str, ~typing.Any]] = None, negative_sampler_kwargs: ~typing.Optional[~typing.Dict[str, ~typing.Any]] = None, similarity_module_kwargs: ~typing.Optional[~typing.Dict[str, ~typing.Any]] = None, backbone_kwargs: ~typing.Optional[~typing.Dict[str, ~typing.Any]] = None)
Bases: TransformerModelBase[HSTUModelConfig]
HSTU model: transformer-based sequential model with a unidirectional pointwise aggregated attention mechanism, combined with the “Shifted Sequence” training objective. Our implementation covers multiple loss functions and a variable number of negatives for them.
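A minimal pure-Python sketch of how “Shifted Sequence” training samples can be built, assuming each user session is a chronologically sorted list of internal item ids starting at 1, with 0 reserved for left padding. `make_shifted_samples` and `causal_mask` are hypothetical helpers for illustration only; in RecTools this preparation is handled by the data preparator (`data_preparator_type`), whose details may differ. The target at each position is the next item, which shows why `train_min_user_interactions` must be greater than 1 and why causal masking (`use_causal_attn`) is required: without it, position i could attend to position i + 1 and read its own target. The mask uses the same convention as a boolean `attn_mask` in `torch.nn.MultiheadAttention`, where True means attention is not allowed.

```python
from typing import Dict, List, Tuple

PAD_ID = 0  # assumption: 0 is reserved for padding, real item ids start at 1


def make_shifted_samples(
    sessions: Dict[str, List[int]],
    session_max_len: int,
    train_min_user_interactions: int,
) -> Dict[str, Tuple[List[int], List[int]]]:
    """Hypothetical helper: build (inputs, targets) for each eligible user."""
    samples: Dict[str, Tuple[List[int], List[int]]] = {}
    for user, items in sessions.items():
        # Users with too few interactions are skipped entirely.
        if len(items) < train_min_user_interactions:
            continue
        # Keep the latest session_max_len + 1 items, so that after the shift
        # both inputs and targets span at most session_max_len positions.
        items = items[-(session_max_len + 1):]
        inputs, targets = items[:-1], items[1:]  # target at i is the item at i + 1
        # Left-pad so that the most recent interactions stay at the end.
        n_pad = session_max_len - len(inputs)
        samples[user] = ([PAD_ID] * n_pad + inputs, [PAD_ID] * n_pad + targets)
    return samples


def causal_mask(n: int) -> List[List[bool]]:
    """Hypothetical helper: True marks a blocked (query i, key j) pair, i.e. j > i."""
    return [[j > i for j in range(n)] for i in range(n)]


sessions = {"u1": [5, 3, 8, 2], "u2": [7]}
print(make_shifted_samples(sessions, session_max_len=5, train_min_user_interactions=2))
print(causal_mask(3))
```

Here u2 is dropped because it has a single interaction, and u1 yields inputs [0, 0, 5, 3, 8] with targets [0, 0, 3, 8, 2].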
References
HSTU tutorial: https://rectools.readthedocs.io/en/stable/examples/tutorials/transformers_HSTU_tutorial.html
Original paper: https://arxiv.org/abs/2402.17152
- Parameters
n_blocks (int, default 2) – Number of transformer blocks.
n_heads (int, default 4) – Number of attention heads.
n_factors (int, default 256) – Latent embeddings size.
dropout_rate (float, default 0.2) – Probability of a hidden unit to be zeroed.
session_max_len (int, default 100) – Maximum length of user sequence.
train_min_user_interactions (int, default 2) – Minimum number of interactions a user must have to be used for training. Should be greater than 1.
loss ({"softmax", "BCE", "gBCE", "sampled_softmax"}, default "softmax") – Loss function.
n_negatives (int, default 1) – Number of negatives for BCE, gBCE and sampled_softmax losses.
gbce_t (float, default 0.2) – Calibration parameter for gBCE loss.
lr (float, default 0.001) – Learning rate.
batch_size (int, default 128) – How many samples per batch to load.
epochs (int, default 3) – Exact number of training epochs. Ignored if get_trainer_func is specified.
deterministic (bool, default False) – deterministic flag passed to the lightning trainer during initialization. Use pytorch_lightning.seed_everything together with this parameter to fix the random seed. Ignored if get_trainer_func is specified.
verbose (int, default 0) – Verbosity level. Enables progress bar, model summary and logging in the default lightning trainer when set to a positive integer. Ignored if get_trainer_func is specified.
dataloader_num_workers (int, default 0) – Number of loader worker processes.
use_pos_emb (bool, default True) – If True, learnable positional encoding will be added to session item embeddings.
use_key_padding_mask (bool, default False) – If True, key_padding_mask will be added in Multi-head Attention.
use_causal_attn (bool, default True) – If True, a causal mask will be added as attn_mask in Multi-head Attention. Please note that the default SASRec training task (“Shifted Sequence”) does not work without causal masking. Set this parameter to False only when you change the training task with a custom data_preparator_type or if you are absolutely sure of what you are doing.
relative_time_attention (bool, default True) – Whether to use relative time attention.
relative_pos_attention (bool, default True) – Whether to use relative positional attention.
item_net_block_types (sequence of type(ItemNetBase), default (IdEmbeddingsItemNet, CatFeaturesItemNet)) – Type of network returning item embeddings. (IdEmbeddingsItemNet,) - item embeddings based on ids. (CatFeaturesItemNet,) - item embeddings based on categorical features. (IdEmbeddingsItemNet, CatFeaturesItemNet) - item embeddings based on ids and categorical features.
item_net_constructor_type (type(ItemNetConstructorBase), default SumOfEmbeddingsConstructor) – Type of item net blocks aggregation constructor.
pos_encoding_type (type(PositionalEncodingBase), default LearnableInversePositionalEncoding) – Type of positional encoding.
transformer_layers_type (type(TransformerLayersBase), default STULayers) – Type of transformer layers architecture.
data_preparator_type (type(TransformerDataPreparatorBase), default SASRecDataPreparator) – Type of data preparator used for dataset processing and dataloader creation.
lightning_module_type (type(TransformerLightningModuleBase), default TransformerLightningModule) – Type of lightning module defining training procedure.
negative_sampler_type (type(TransformerNegativeSamplerBase), default CatalogUniformSampler) – Type of negative sampler.
similarity_module_type (type(SimilarityModuleBase), default DistanceSimilarityModule) – Type of similarity module.
backbone_type (type(TransformerBackboneBase), default TransformerTorchBackbone) – Type of torch backbone.
get_val_mask_func (Callable, default None) – Function to get validation mask.
get_trainer_func (Callable, default None) – Function to get a custom lightning trainer. If get_trainer_func is None, a default trainer will be created based on the epochs, deterministic and verbose argument values. The model will be trained for the exact number of epochs, and checkpointing will be disabled. If you want to assign a custom trainer after the model is initialized, you can manually assign a new value to the model _trainer attribute.
recommend_batch_size (int, default 256) – How many samples per batch to load during recommend. If you want to change this parameter after the model is initialized, you can manually assign a new value to the model recommend_batch_size attribute.
recommend_torch_device ({“cpu”, “cuda”, “cuda:0”, …}, default None) – String representation for torch.device used for model inference. When set to None, “cuda” will be used if it is available, “cpu” otherwise. If you want to change this parameter after the model is initialized, you can manually assign a new value to the model recommend_torch_device attribute.
get_val_mask_func_kwargs (optional(InitKwargs), default None) – Additional keyword arguments for the get_val_mask_func. Make sure all dict values have JSON serializable types.
get_trainer_func_kwargs (optional(InitKwargs), default None) – Additional keyword arguments for the get_trainer_func. Make sure all dict values have JSON serializable types.
data_preparator_kwargs (optional(dict), default None) – Additional keyword arguments to pass during data_preparator_type initialization. Make sure all dict values have JSON serializable types.
transformer_layers_kwargs (optional(dict), default None) – Additional keyword arguments to pass during transformer_layers_type initialization. Make sure all dict values have JSON serializable types.
item_net_constructor_kwargs (optional(dict), default None) – Additional keyword arguments to pass during item_net_constructor_type initialization. Make sure all dict values have JSON serializable types.
pos_encoding_kwargs (optional(dict), default None) – Additional keyword arguments to pass during pos_encoding_type initialization. Make sure all dict values have JSON serializable types.
lightning_module_kwargs (optional(dict), default None) – Additional keyword arguments to pass during lightning_module_type initialization. Make sure all dict values have JSON serializable types.
negative_sampler_kwargs (optional(dict), default None) – Additional keyword arguments to pass during negative_sampler_type initialization. Make sure all dict values have JSON serializable types.
similarity_module_kwargs (optional(dict), default None) – Additional keyword arguments to pass during similarity_module_type initialization. Make sure all dict values have JSON serializable types.
backbone_kwargs (optional(dict), default None) – Additional keyword arguments to pass during backbone_type initialization. Make sure all dict values have JSON serializable types.
recommend_use_torch_ranking (bool, default True) – Whether to use torch for ranking items during recommend.
recommend_n_threads (int, default 0) – Number of threads to use for ranking during recommend.
Note: to precisely follow the original authors’ implementation of the model, the following kwargs for specific modules are replaced from the default versions used in other Transformer models, unless other values are explicitly provided: 1) use_scale_factor in pos_encoding_kwargs is set to True; 2) distance in similarity_module_kwargs is set to cosine.
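The kwargs replacement described at the end of the parameter list behaves like a dictionary merge in which the HSTU-specific defaults are applied first and explicitly passed values take precedence. A hypothetical sketch: `resolve_hstu_kwargs` is not part of the RecTools API, and only the keys `use_scale_factor` and `distance` and the values True and "cosine" come from this page.

```python
from typing import Any, Dict, Optional


def resolve_hstu_kwargs(
    pos_encoding_kwargs: Optional[Dict[str, Any]] = None,
    similarity_module_kwargs: Optional[Dict[str, Any]] = None,
) -> Dict[str, Dict[str, Any]]:
    """Hypothetical helper: apply HSTU defaults unless the user set the key."""
    # Defaults go first, user values last, so explicit user values win.
    pos = {"use_scale_factor": True, **(pos_encoding_kwargs or {})}
    sim = {"distance": "cosine", **(similarity_module_kwargs or {})}
    # New dicts are built, so the caller's dicts are never mutated.
    return {"pos_encoding_kwargs": pos, "similarity_module_kwargs": sim}


print(resolve_hstu_kwargs())
print(resolve_hstu_kwargs(similarity_module_kwargs={"distance": "dot"}))
```

Placing user values last in the merge is what lets an explicitly provided option override the HSTU default while leaving the other default untouched.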
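For loss="gBCE", the calibration parameter gbce_t plays the role of t in the generalized binary cross-entropy introduced in the gSASRec paper: the positive term is raised to a power beta that depends on the negative sampling rate alpha = n_negatives / (n_items - 1). The pure-Python sketch below follows that published definition for a single position; it assumes n_items is the catalog size and is not taken from the RecTools source, which operates on batched tensors and may compute beta differently in detail. With t = 0, beta = 1 and gBCE reduces to plain BCE; with t = 1, beta = alpha.

```python
import math
from typing import Sequence


def gbce_beta(n_negatives: int, n_items: int, t: float) -> float:
    """Power applied to the positive probability in gBCE (gSASRec definition)."""
    if n_negatives < 1 or n_items < 2:
        raise ValueError("need at least one negative and two catalog items")
    # alpha: fraction of the n_items - 1 non-target items that are sampled
    alpha = n_negatives / (n_items - 1)
    return alpha * (t * (1.0 - 1.0 / alpha) + 1.0 / alpha)


def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def bce_loss(pos_logit: float, neg_logits: Sequence[float]) -> float:
    # -log(sigmoid(s+)) - sum(log(1 - sigmoid(s-))), using 1 - sigmoid(s) = sigmoid(-s)
    return -log_sigmoid(pos_logit) - sum(log_sigmoid(-s) for s in neg_logits)


def gbce_loss(pos_logit: float, neg_logits: Sequence[float], n_items: int, t: float) -> float:
    beta = gbce_beta(len(neg_logits), n_items, t)
    # -log(sigmoid(s+) ** beta) == -beta * log(sigmoid(s+)); negatives as in BCE
    return -beta * log_sigmoid(pos_logit) - sum(log_sigmoid(-s) for s in neg_logits)


# One positive and two sampled negatives from a 101-item catalog
print(gbce_loss(0.5, [0.1, -0.3], n_items=101, t=0.2))
```

Because beta < 1 for 0 < t <= 1 whenever sampled negatives are a small fraction of the catalog, gBCE down-weights the positive term relative to BCE, which is how it counteracts the overconfidence caused by negative sampling.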
Methods
dumps() – Serialize model to bytes.
fit(dataset, *args, **kwargs) – Fit model.
fit_partial(dataset, *args, **kwargs) – Fit model.
from_config(config) – Create model from config.
from_params(params[, sep]) – Create model from parameters.
get_config([mode, simple_types]) – Return model config.
get_params([simple_types, sep]) – Return model parameters.
load(f) – Load model from file.
load_from_checkpoint(checkpoint_path[, ...]) – Load model from Lightning checkpoint path.
load_weights_from_checkpoint(checkpoint_path) – Load model weights from Lightning checkpoint path.
loads(data) – Load model from bytes.
recommend(users, dataset, k, filter_viewed) – Recommend items for users.
recommend_to_items(target_items, dataset, k) – Recommend items for target items.
save(f) – Save model to file.
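Conceptually, recommend scores catalog items against each user's sequence representation and returns the k highest-scoring ones, optionally dropping items the user has already interacted with when filter_viewed is True. The sketch below is a hypothetical, single-user illustration in pure Python: it assumes the user representation is one vector and uses cosine similarity, the distance this page lists as the HSTU default for similarity_module_kwargs. `recommend_top_k` is not a RecTools function; the real method operates on a Dataset and processes users in batches of recommend_batch_size.

```python
import math
from typing import Dict, List, Sequence, Set


def cosine(u: Sequence[float], v: Sequence[float]) -> float:
    """Cosine similarity; returns 0.0 if either vector has zero norm."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm > 0 else 0.0


def recommend_top_k(
    user_emb: Sequence[float],
    item_embs: Dict[int, Sequence[float]],
    viewed: Set[int],
    k: int,
    filter_viewed: bool,
) -> List[int]:
    """Hypothetical helper: top-k item ids by cosine score for one user."""
    scores = {item: cosine(user_emb, emb) for item, emb in item_embs.items()}
    # Optionally exclude items already present in the user's history.
    candidates = [item for item in scores if not (filter_viewed and item in viewed)]
    # Highest score first; ties broken by item id for a deterministic order.
    candidates.sort(key=lambda item: (-scores[item], item))
    return candidates[:k]


items = {1: [1.0, 0.0], 2: [0.6, 0.8], 3: [0.0, 1.0], 4: [-1.0, 0.0]}
print(recommend_top_k([1.0, 0.0], items, viewed={1}, k=2, filter_viewed=True))
```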
Attributes
recommends_for_cold
recommends_for_warm
require_recommend_context – Indicates whether the model requires context for accurate recommendations.
torch_model – PyTorch model.
train_loss_name
val_loss_name
- config_class
alias of HSTUModelConfig
- property require_recommend_context: bool
Indicates whether the model requires context for accurate recommendations.
Return type: bool