TransformerModelConfig

class rectools.models.nn.transformers.base.TransformerModelConfig(
    *,
    cls: Optional[Type[ModelBase]] = None,
    verbose: int = 0,
    data_preparator_type: Type[TransformerDataPreparatorBase],
    n_blocks: int = 2,
    n_heads: int = 4,
    n_factors: int = 256,
    use_pos_emb: bool = True,
    use_causal_attn: bool = False,
    use_key_padding_mask: bool = False,
    dropout_rate: float = 0.2,
    session_max_len: int = 100,
    dataloader_num_workers: int = 0,
    batch_size: int = 128,
    loss: str = 'softmax',
    n_negatives: int = 1,
    gbce_t: float = 0.2,
    lr: float = 0.001,
    epochs: int = 3,
    deterministic: bool = False,
    recommend_batch_size: int = 256,
    recommend_torch_device: Optional[str] = None,
    train_min_user_interactions: int = 2,
    item_net_block_types: Sequence[Type[ItemNetBase]] = (rectools.models.nn.item_net.IdEmbeddingsItemNet, rectools.models.nn.item_net.CatFeaturesItemNet),
    item_net_constructor_type: Type[ItemNetConstructorBase] = rectools.models.nn.item_net.SumOfEmbeddingsConstructor,
    pos_encoding_type: Type[PositionalEncodingBase] = rectools.models.nn.transformers.net_blocks.LearnableInversePositionalEncoding,
    transformer_layers_type: Type[TransformerLayersBase] = rectools.models.nn.transformers.net_blocks.PreLNTransformerLayers,
    lightning_module_type: Type[TransformerLightningModuleBase] = rectools.models.nn.transformers.lightning.TransformerLightningModule,
    negative_sampler_type: Type[TransformerNegativeSamplerBase] = rectools.models.nn.transformers.negative_sampler.CatalogUniformSampler,
    similarity_module_type: Type[SimilarityModuleBase] = rectools.models.nn.transformers.similarity.DistanceSimilarityModule,
    backbone_type: Type[TransformerBackboneBase] = rectools.models.nn.transformers.torch_backbone.TransformerTorchBackbone,
    get_val_mask_func: Optional[Callable[..., ndarray]] = None,
    get_trainer_func: Optional[Callable[..., Trainer]] = None,
    get_val_mask_func_kwargs: Optional[Dict[str, Any]] = None,
    get_trainer_func_kwargs: Optional[Dict[str, Any]] = None,
    data_preparator_kwargs: Optional[Dict[str, Any]] = None,
    transformer_layers_kwargs: Optional[Dict[str, Any]] = None,
    item_net_constructor_kwargs: Optional[Dict[str, Any]] = None,
    pos_encoding_kwargs: Optional[Dict[str, Any]] = None,
    lightning_module_kwargs: Optional[Dict[str, Any]] = None,
    negative_sampler_kwargs: Optional[Dict[str, Any]] = None,
    similarity_module_kwargs: Optional[Dict[str, Any]] = None,
    backbone_kwargs: Optional[Dict[str, Any]] = None,
)

Bases: ModelConfig

Transformer model base config.
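Many parameters come in pairs: a `*_type` naming a class and an optional `*_kwargs` dict (for example `transformer_layers_type` and `transformer_layers_kwargs`). The sketch below shows one way such a pair can be resolved, assuming the `*_kwargs` dict is forwarded as extra keyword arguments to the chosen class's constructor. `build_component`, `ToyLayers` and `ffn_multiplier` are hypothetical names for illustration, not RecTools API.

```python
from typing import Any, Dict, Optional, Type


class ToyLayers:
    """Hypothetical stand-in for a component class such as a transformer-layers class."""

    def __init__(self, n_blocks: int, n_factors: int, ffn_multiplier: int = 4) -> None:
        self.n_blocks = n_blocks
        self.n_factors = n_factors
        self.ffn_multiplier = ffn_multiplier


def build_component(
    component_type: Type[Any],
    component_kwargs: Optional[Dict[str, Any]] = None,
    **core_args: Any,
) -> Any:
    """Instantiate component_type with core config values plus optional extra kwargs.

    A key present in both core_args and component_kwargs makes Python raise TypeError.
    """
    return component_type(**core_args, **(component_kwargs or {}))


# Core values come from the main fields; the *_kwargs dict adds class-specific extras.
layers = build_component(ToyLayers, {"ffn_multiplier": 2}, n_blocks=2, n_factors=256)
print(layers.n_blocks, layers.n_factors, layers.ffn_multiplier)  # 2 256 2
```

Passing `None` (the default for every `*_kwargs` field) simply leaves the class's own defaults in place.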


Parameters
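  • cls (Optional[Type[ModelBase]]) – Model class the config is associated with (default None).

  • verbose (int) – Verbosity level (default 0).

  • data_preparator_type (Type[TransformerDataPreparatorBase]) – Data preparator class used to process the dataset for training and inference. Required; it has no default.

  • n_blocks (int) – Number of transformer blocks.

  • n_heads (int) – Number of attention heads in each block.

  • n_factors (int) – Embedding size (number of latent factors).

  • use_pos_emb (bool) – Whether to add positional encoding to the sequence embeddings.

  • use_causal_attn (bool) – Whether to apply a causal attention mask, so that each position attends only to itself and earlier positions.

  • use_key_padding_mask (bool) – Whether to mask padding positions in attention.

  • dropout_rate (float) – Dropout probability.

  • session_max_len (int) – Maximum length of a user interaction sequence passed to the model.

  • dataloader_num_workers (int) – Number of worker processes for the PyTorch DataLoader.

  • batch_size (int) – Training batch size.

  • loss (str) – Name of the training loss (default 'softmax').

  • n_negatives (int) – Number of negative items sampled per positive when a negative-sampling loss is used.

  • gbce_t (float) – Calibration parameter t of the gBCE loss.

  • lr (float) – Learning rate.

  • epochs (int) – Number of training epochs.

  • deterministic (bool) – Whether to request deterministic training for reproducibility.

  • recommend_batch_size (int) – Batch size used when generating recommendations.

  • recommend_torch_device (Optional[str]) – Torch device string used for inference, for example 'cpu' or 'cuda' (default None).

  • train_min_user_interactions (int) – Minimum number of interactions a user must have to be included in training.

  • item_net_block_types (Sequence[Type[ItemNetBase]]) – Item network blocks used to build item embeddings; by default ID embeddings (IdEmbeddingsItemNet) and categorical feature embeddings (CatFeaturesItemNet).

  • item_net_constructor_type (Type[ItemNetConstructorBase]) – Class that combines the item network blocks; by default their embeddings are summed (SumOfEmbeddingsConstructor).

  • pos_encoding_type (Type[PositionalEncodingBase]) – Positional encoding class (default LearnableInversePositionalEncoding).

  • transformer_layers_type (Type[TransformerLayersBase]) – Transformer layers class (default PreLNTransformerLayers).

  • lightning_module_type (Type[TransformerLightningModuleBase]) – PyTorch Lightning module class used for training (default TransformerLightningModule).

  • negative_sampler_type (Type[TransformerNegativeSamplerBase]) – Class that samples negative items (default CatalogUniformSampler).

  • similarity_module_type (Type[SimilarityModuleBase]) – Class that scores items against session embeddings (default DistanceSimilarityModule).

  • backbone_type (Type[TransformerBackboneBase]) – Torch backbone class of the network (default TransformerTorchBackbone).

  • get_val_mask_func (Optional[Callable[..., ndarray]]) – Optional function returning a NumPy mask that marks interactions used for validation.

  • get_trainer_func (Optional[Callable[..., Trainer]]) – Optional function returning a custom PyTorch Lightning Trainer.

  • get_val_mask_func_kwargs (Optional[Dict[str, Any]]) – Keyword arguments for get_val_mask_func.

  • get_trainer_func_kwargs (Optional[Dict[str, Any]]) – Keyword arguments for get_trainer_func.

  • data_preparator_kwargs (Optional[Dict[str, Any]]) – Extra keyword arguments for data_preparator_type.

  • transformer_layers_kwargs (Optional[Dict[str, Any]]) – Extra keyword arguments for transformer_layers_type.

  • item_net_constructor_kwargs (Optional[Dict[str, Any]]) – Extra keyword arguments for item_net_constructor_type.

  • pos_encoding_kwargs (Optional[Dict[str, Any]]) – Extra keyword arguments for pos_encoding_type.

  • lightning_module_kwargs (Optional[Dict[str, Any]]) – Extra keyword arguments for lightning_module_type.

  • negative_sampler_kwargs (Optional[Dict[str, Any]]) – Extra keyword arguments for negative_sampler_type.

  • similarity_module_kwargs (Optional[Dict[str, Any]]) – Extra keyword arguments for similarity_module_type.

  • backbone_kwargs (Optional[Dict[str, Any]]) – Extra keyword arguments for backbone_type.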
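With a gBCE loss, `n_negatives` and `gbce_t` interact through the calibration exponent from gSASRec (Petrov and Macdonald, 2023): β = α(t(1 − 1/α) + 1/α), where α = n_negatives / (n_items − 1) is the negative sampling rate. The pure-Python sketch below computes this for one positive and its sampled negatives, assuming gBCE is selected through `loss` and follows that formulation. The function names and the `n_items` argument are illustrative; this is not RecTools' implementation.

```python
import math
from typing import Sequence


def gbce_beta(n_negatives: int, n_items: int, gbce_t: float) -> float:
    """Calibration exponent beta = alpha * (t * (1 - 1/alpha) + 1/alpha)."""
    # alpha: share of the other n_items - 1 items drawn as negatives.
    alpha = n_negatives / (n_items - 1)
    return alpha * (gbce_t * (1 - 1 / alpha) + 1 / alpha)


def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def gbce_loss(pos_logit: float, neg_logits: Sequence[float], n_items: int, gbce_t: float) -> float:
    """gBCE for one positive and its sampled negatives (n_negatives = len(neg_logits))."""
    beta = gbce_beta(len(neg_logits), n_items, gbce_t)
    # Positive term: -log(sigmoid(s_pos) ** beta) = -beta * log(sigmoid(s_pos)).
    pos_term = -beta * log_sigmoid(pos_logit)
    # Negative terms: -log(1 - sigmoid(s_neg)) = -log(sigmoid(-s_neg)).
    neg_term = -sum(log_sigmoid(-s) for s in neg_logits)
    return pos_term + neg_term


# gbce_t = 0 gives beta = 1 (plain BCE); gbce_t = 1 gives beta = alpha.
# Here alpha = 0.5, so beta = 0.2 * (0.5 - 1) + 1 = 0.9.
print(gbce_beta(n_negatives=1, n_items=3, gbce_t=0.2))
print(gbce_loss(2.0, [-1.0], n_items=3, gbce_t=0.2))
```

The exponent simplifies to β = t(α − 1) + 1. With the default `n_negatives=1` and a large catalogue, α is close to 0, so β is roughly 1 − t (about 0.8 for the default `gbce_t=0.2`).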

Methods

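construct([_fields_set])

Deprecated; use model_construct() instead.

copy(*[, include, exclude, update, deep])

Returns a copy of the model. Deprecated; use model_copy() instead.

dict(*[, include, exclude, by_alias, ...])

Deprecated; use model_dump() instead.

from_orm(obj)

Deprecated; use model_validate() with from_attributes enabled instead.

json(*[, include, exclude, by_alias, ...])

Deprecated; use model_dump_json() instead.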

model_construct([_fields_set])

Creates a new instance of the Model class from trusted or pre-validated data; no validation is performed.

model_copy(*[, update, deep])

Returns a copy of the model.

model_dump(*[, mode, include, exclude, ...])

Generates a dictionary representation of the model.

model_dump_json(*[, indent, ensure_ascii, ...])

Generates a JSON string representation of the model.

model_json_schema([by_alias, ref_template, ...])

Generates a JSON schema for a model class.

model_parametrized_name(params)

Compute the class name for parametrizations of generic classes.

model_post_init(context, /)

Override this method to perform additional initialization after __init__ and model_construct.

model_rebuild(*[, force, raise_errors, ...])

Try to rebuild the pydantic-core schema for the model.

model_validate(obj, *[, strict, extra, ...])

Validate a pydantic model instance.

model_validate_json(json_data, *[, strict, ...])

Validates the given JSON data against the Pydantic model.

model_validate_strings(obj, *[, strict, ...])

Validate the given object with string data against the Pydantic model.

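parse_file(path, *[, content_type, ...])

Deprecated; load the file, then use model_validate_json() for JSON data or model_validate() otherwise.

parse_obj(obj)

Deprecated; use model_validate() instead.

parse_raw(b, *[, content_type, encoding, ...])

Deprecated; use model_validate_json() for JSON data.

schema([by_alias, ref_template])

Deprecated; use model_json_schema() instead.

schema_json(*[, by_alias, ref_template])

Deprecated; use model_json_schema() and serialize the result, for example with json.dumps.

update_forward_refs(**localns)

Deprecated; use model_rebuild() instead.

validate(value)

Deprecated; use model_validate() instead.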

Attributes

model_computed_fields

model_config

Configuration for the model, should be a dictionary conforming to ConfigDict (pydantic.config.ConfigDict).

model_extra

Get extra fields set during validation.

model_fields

model_fields_set

Returns the set of fields that have been explicitly set on this model instance.

data_preparator_type

n_blocks

n_heads

n_factors

use_pos_emb

use_causal_attn

use_key_padding_mask

dropout_rate

session_max_len

dataloader_num_workers

batch_size

loss

n_negatives

gbce_t

lr

epochs

verbose

deterministic

recommend_batch_size

recommend_torch_device

train_min_user_interactions

item_net_block_types

item_net_constructor_type

pos_encoding_type

transformer_layers_type

lightning_module_type

negative_sampler_type

similarity_module_type

backbone_type

get_val_mask_func

get_trainer_func

get_val_mask_func_kwargs

get_trainer_func_kwargs

data_preparator_kwargs

transformer_layers_kwargs

item_net_constructor_kwargs

pos_encoding_kwargs

lightning_module_kwargs

negative_sampler_kwargs

similarity_module_kwargs

backbone_kwargs
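Because `n_factors` is the embedding size shared by all attention heads, it normally has to split evenly across `n_heads`. The stdlib sketch below runs that check on a plain dict of overrides before building a model. It assumes the transformer layers are built on `torch.nn.MultiheadAttention`, which requires `embed_dim` to be divisible by `num_heads`. `check_attention_dims` is a hypothetical helper, and the defaults are copied from the signature.

```python
from typing import Any, Dict, Optional

# Defaults for the two attention-size fields, copied from the signature.
ATTENTION_DEFAULTS: Dict[str, int] = {"n_heads": 4, "n_factors": 256}


def check_attention_dims(overrides: Optional[Dict[str, Any]] = None) -> int:
    """Return the per-head embedding size, or raise ValueError if n_factors does not split evenly."""
    params = {**ATTENTION_DEFAULTS, **(overrides or {})}
    n_factors, n_heads = params["n_factors"], params["n_heads"]
    # Checking n_heads first avoids a ZeroDivisionError in the modulo below.
    if n_heads <= 0 or n_factors % n_heads != 0:
        raise ValueError(f"n_factors={n_factors} must be divisible by n_heads={n_heads}")
    return n_factors // n_heads


print(check_attention_dims())  # 64
print(check_attention_dims({"n_factors": 128, "n_heads": 8}))  # 16
```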

model_config: ClassVar[ConfigDict] = {'extra': 'forbid'}

Configuration for the model, should be a dictionary conforming to ConfigDict (pydantic.config.ConfigDict). Here extra='forbid' makes validation reject any keyword that is not a declared field.
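With `{'extra': 'forbid'}`, pydantic rejects any key that is not a declared field, so a misspelled parameter fails validation instead of being silently ignored. With the real class this surfaces as a `pydantic.ValidationError`. The stdlib sketch below imitates the behaviour for a plain dict so the idea can be tried without pydantic or RecTools installed. `reject_extra_keys` is hypothetical, and `KNOWN_FIELDS` lists only a subset of the fields.

```python
import difflib
from typing import Any, Dict, FrozenSet

# A subset of the declared field names, for illustration only.
KNOWN_FIELDS: FrozenSet[str] = frozenset(
    {"n_blocks", "n_heads", "n_factors", "dropout_rate", "session_max_len", "batch_size", "loss", "lr", "epochs"}
)


def reject_extra_keys(params: Dict[str, Any], known: FrozenSet[str] = KNOWN_FIELDS) -> Dict[str, Any]:
    """Mimic extra='forbid': raise ValueError if params contains undeclared keys."""
    extra = sorted(set(params) - known)
    if extra:
        # Suggest the closest declared name for each unknown key.
        hints = {key: difflib.get_close_matches(key, sorted(known), n=1) for key in extra}
        # Message wording borrowed from pydantic's extra_forbidden error.
        raise ValueError(f"Extra inputs are not permitted: {hints}")
    return params


print(reject_extra_keys({"n_heads": 2, "epochs": 5}))  # {'n_heads': 2, 'epochs': 5}
try:
    reject_extra_keys({"n_head": 2})
except ValueError as err:
    print(err)  # Extra inputs are not permitted: {'n_head': ['n_heads']}
```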