HSTUModelConfig
- class rectools.models.nn.transformers.hstu.HSTUModelConfig(
    *,
    cls: Optional[Type[ModelBase]] = None,
    verbose: int = 0,
    data_preparator_type: Type[TransformerDataPreparatorBase] = <class 'rectools.models.nn.transformers.sasrec.SASRecDataPreparator'>,
    n_blocks: int = 2,
    n_heads: int = 4,
    n_factors: int = 256,
    use_pos_emb: bool = True,
    use_causal_attn: bool = True,
    use_key_padding_mask: bool = False,
    dropout_rate: float = 0.2,
    session_max_len: int = 100,
    dataloader_num_workers: int = 0,
    batch_size: int = 128,
    loss: str = 'softmax',
    n_negatives: int = 1,
    gbce_t: float = 0.2,
    lr: float = 0.001,
    epochs: int = 3,
    deterministic: bool = False,
    recommend_batch_size: int = 256,
    recommend_torch_device: Optional[str] = None,
    train_min_user_interactions: int = 2,
    item_net_block_types: Sequence[Type[ItemNetBase]] = (<class 'rectools.models.nn.item_net.IdEmbeddingsItemNet'>, <class 'rectools.models.nn.item_net.CatFeaturesItemNet'>),
    item_net_constructor_type: Type[ItemNetConstructorBase] = <class 'rectools.models.nn.item_net.SumOfEmbeddingsConstructor'>,
    pos_encoding_type: Type[PositionalEncodingBase] = <class 'rectools.models.nn.transformers.net_blocks.LearnableInversePositionalEncoding'>,
    transformer_layers_type: Type[TransformerLayersBase] = <class 'rectools.models.nn.transformers.hstu.STULayers'>,
    lightning_module_type: Type[TransformerLightningModuleBase] = <class 'rectools.models.nn.transformers.lightning.TransformerLightningModule'>,
    negative_sampler_type: Type[TransformerNegativeSamplerBase] = <class 'rectools.models.nn.transformers.negative_sampler.CatalogUniformSampler'>,
    similarity_module_type: Type[SimilarityModuleBase] = <class 'rectools.models.nn.transformers.similarity.DistanceSimilarityModule'>,
    backbone_type: Type[TransformerBackboneBase] = <class 'rectools.models.nn.transformers.torch_backbone.TransformerTorchBackbone'>,
    get_val_mask_func: Optional[Callable[[...], ndarray]] = None,
    get_trainer_func: Optional[Callable[[...], Trainer]] = None,
    get_val_mask_func_kwargs: Optional[Dict[str, Any]] = None,
    get_trainer_func_kwargs: Optional[Dict[str, Any]] = None,
    data_preparator_kwargs: Optional[Dict[str, Any]] = None,
    transformer_layers_kwargs: Optional[Dict[str, Any]] = None,
    item_net_constructor_kwargs: Optional[Dict[str, Any]] = None,
    pos_encoding_kwargs: Optional[Dict[str, Any]] = None,
    lightning_module_kwargs: Optional[Dict[str, Any]] = None,
    negative_sampler_kwargs: Optional[Dict[str, Any]] = None,
    similarity_module_kwargs: Optional[Dict[str, Any]] = None,
    backbone_kwargs: Optional[Dict[str, Any]] = None,
    relative_time_attention: bool = True,
    relative_pos_attention: bool = True
  )
Bases: TransformerModelConfig
HSTU model config.
- Parameters
cls (Optional[Type[ModelBase]]) –
verbose (int) –
data_preparator_type (Type[TransformerDataPreparatorBase]) –
n_blocks (int) –
n_heads (int) –
n_factors (int) –
use_pos_emb (bool) –
use_causal_attn (bool) –
use_key_padding_mask (bool) –
dropout_rate (float) –
session_max_len (int) –
dataloader_num_workers (int) –
batch_size (int) –
loss (str) –
n_negatives (int) –
gbce_t (float) –
lr (float) –
epochs (int) –
deterministic (bool) –
recommend_batch_size (int) –
recommend_torch_device (Optional[str]) –
train_min_user_interactions (int) –
item_net_block_types (Sequence[Type[ItemNetBase]]) –
item_net_constructor_type (Type[ItemNetConstructorBase]) –
pos_encoding_type (Type[PositionalEncodingBase]) –
transformer_layers_type (Type[TransformerLayersBase]) –
lightning_module_type (Type[TransformerLightningModuleBase]) –
negative_sampler_type (Type[TransformerNegativeSamplerBase]) –
similarity_module_type (Type[SimilarityModuleBase]) –
backbone_type (Type[TransformerBackboneBase]) –
get_val_mask_func (Optional[Callable[[...], ndarray]]) –
get_trainer_func (Optional[Callable[[...], Trainer]]) –
get_val_mask_func_kwargs (Optional[Dict[str, Any]]) –
get_trainer_func_kwargs (Optional[Dict[str, Any]]) –
data_preparator_kwargs (Optional[Dict[str, Any]]) –
transformer_layers_kwargs (Optional[Dict[str, Any]]) –
item_net_constructor_kwargs (Optional[Dict[str, Any]]) –
pos_encoding_kwargs (Optional[Dict[str, Any]]) –
lightning_module_kwargs (Optional[Dict[str, Any]]) –
negative_sampler_kwargs (Optional[Dict[str, Any]]) –
similarity_module_kwargs (Optional[Dict[str, Any]]) –
backbone_kwargs (Optional[Dict[str, Any]]) –
relative_time_attention (bool) –
relative_pos_attention (bool) –
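Several parameters above come in pairs: a `*_type` field naming a class and a `*_kwargs` field holding optional extra arguments for it. This page does not say how the two are combined, so the following is only a minimal sketch of one plausible wiring, assuming the kwargs dict is unpacked into the class constructor and that `None` means "no extra arguments". `PositionalEncodingStub`, its `scale` argument, and `build_component` are hypothetical names used for illustration and are not part of RecTools.

```python
from typing import Any, Dict, Optional, Type


class PositionalEncodingStub:
    """Hypothetical stand-in for a class passed as `pos_encoding_type`."""

    def __init__(self, session_max_len: int, n_factors: int, scale: float = 1.0) -> None:
        self.session_max_len = session_max_len
        self.n_factors = n_factors
        self.scale = scale  # hypothetical extra argument, supplied via kwargs


def build_component(
    component_type: Type[Any],
    component_kwargs: Optional[Dict[str, Any]],
    **required: Any,
) -> Any:
    """Instantiate `component_type`, treating a missing kwargs dict as empty.

    Assumption: extra kwargs are unpacked directly into the constructor.
    """
    return component_type(**required, **(component_kwargs or {}))


# Values 100 and 256 are the documented defaults of session_max_len and n_factors.
default_enc = build_component(PositionalEncodingStub, None, session_max_len=100, n_factors=256)
custom_enc = build_component(PositionalEncodingStub, {"scale": 0.5}, session_max_len=100, n_factors=256)
```

Under this assumption, leaving a `*_kwargs` field at its default of `None` keeps the component's own defaults, while a dict overrides only the arguments it names.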
Methods
construct([_fields_set])
copy(*[, include, exclude, update, deep]) – Returns a copy of the model.
dict(*[, include, exclude, by_alias, ...])
from_orm(obj)
json(*[, include, exclude, by_alias, ...])
model_construct([_fields_set]) – Creates a new instance of the Model class with validated data.
model_copy(*[, update, deep])
model_dump(*[, mode, include, exclude, ...])
model_dump_json(*[, indent, ensure_ascii, ...])
model_json_schema([by_alias, ref_template, ...]) – Generates a JSON schema for a model class.
model_parametrized_name(params) – Compute the class name for parametrizations of generic classes.
model_post_init(context, /) – Override this method to perform additional initialization after __init__ and model_construct.
model_rebuild(*[, force, raise_errors, ...]) – Try to rebuild the pydantic-core schema for the model.
model_validate(obj, *[, strict, extra, ...]) – Validate a pydantic model instance.
model_validate_json(json_data, *[, strict, ...])
model_validate_strings(obj, *[, strict, ...]) – Validate the given object with string data against the Pydantic model.
parse_file(path, *[, content_type, ...])
parse_obj(obj)
parse_raw(b, *[, content_type, encoding, ...])
schema([by_alias, ref_template])
schema_json(*[, by_alias, ref_template])
update_forward_refs(**localns)
validate(value)
Attributes
model_computed_fields
model_config – Configuration for the model, should be a dictionary conforming to pydantic.config.ConfigDict.
model_extra – Get extra fields set during validation.
model_fields
model_fields_set – Returns the set of fields that have been explicitly set on this model instance.
data_preparator_type
transformer_layers_type
use_causal_attn
relative_time_attention
relative_pos_attention
- model_config: ClassVar[ConfigDict] = {'extra': 'forbid'}
Configuration for the model, should be a dictionary conforming to pydantic.config.ConfigDict.
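The `'extra': 'forbid'` setting and the `model_*` methods listed above are standard pydantic v2 behavior. The sketch below illustrates them on `MiniHSTUConfig`, a hypothetical stand-in that copies only a handful of field names and defaults from the signature so that it runs without RecTools installed. It assumes pydantic v2 is available, and it is not the real `HSTUModelConfig`, which may apply additional validation.

```python
from pydantic import BaseModel, ConfigDict, ValidationError


class MiniHSTUConfig(BaseModel):
    """Hypothetical stand-in mirroring a few HSTUModelConfig fields and defaults."""

    model_config = ConfigDict(extra="forbid")

    n_blocks: int = 2
    n_heads: int = 4
    n_factors: int = 256
    relative_time_attention: bool = True
    relative_pos_attention: bool = True


config = MiniHSTUConfig(n_blocks=4, relative_time_attention=False)

# Only explicitly passed fields are recorded in model_fields_set.
explicit = config.model_fields_set

# model_dump() returns a plain dict that includes fields left at their defaults.
dumped = config.model_dump()

# model_copy(update=...) returns a new instance; the original is unchanged.
wider = config.model_copy(update={"n_factors": 512})

# A misspelled field name is rejected because extra fields are forbidden.
try:
    MiniHSTUConfig(n_block=4)
    rejected = False
except ValidationError:
    rejected = True
```

With extra fields forbidden, a typo such as `n_block` fails at construction time instead of being silently ignored, which is the main practical consequence of this `model_config` value.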