From 084a1169057fe178e865c98f8d9a16a377d38534 Mon Sep 17 00:00:00 2001 From: Tiny_Snow Date: Sun, 25 Jan 2026 14:16:06 +0800 Subject: [PATCH 1/2] fix: correct docs in main_seqrec.py --- src/genrec/main_seqrec.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/genrec/main_seqrec.py b/src/genrec/main_seqrec.py index 48634da..00e1276 100644 --- a/src/genrec/main_seqrec.py +++ b/src/genrec/main_seqrec.py @@ -9,7 +9,7 @@ save_predictions: false # whether to save the predictions on the test set dataset: type: seqrec - interaction_data_path: data/movielens/train.parquet + interaction_data_path: /path/to/movielens-1m/proc/user2item.pkl ... # dataset-specific parameters collator: type: seqrec From b1552514a19b62d1f512cf82ff3ab9b3dfaea549 Mon Sep 17 00:00:00 2001 From: Tiny_Snow Date: Tue, 27 Jan 2026 21:03:29 +0800 Subject: [PATCH 2/2] feat: construct quantizer pipeline and rqvae model/trainer --- .gitignore | 2 + scripts/configs/quantizer/rqvae.yaml | 75 ++++ scripts/configs/quantizer/template.yaml | 103 ++++++ scripts/configs/seqrec/template.yaml | 4 +- src/genrec/datasets/__init__.py | 10 + src/genrec/datasets/base.py | 4 + src/genrec/datasets/dataset_genrec.py | 17 + src/genrec/datasets/dataset_quantizer.py | 72 +++- src/genrec/datasets/dataset_seqrec.py | 27 +- src/genrec/datasets/modules/lm_encoders.py | 14 +- src/genrec/main_quantizer.py | 298 ++++++++++++++++ src/genrec/main_seqrec.py | 2 +- src/genrec/models/__init__.py | 18 +- src/genrec/models/model_quantizer/__init__.py | 32 ++ src/genrec/models/model_quantizer/base.py | 323 +++++++++++++++++ src/genrec/models/model_quantizer/rqvae.py | 268 ++++++++++++++ src/genrec/models/model_seqrec/base.py | 142 ++++---- src/genrec/models/model_seqrec/hstu.py | 4 +- src/genrec/models/model_seqrec/hstu_spring.py | 6 +- src/genrec/models/model_seqrec/sasrec.py | 2 +- src/genrec/models/modules/__init__.py | 3 +- src/genrec/models/modules/attention.py | 6 +- src/genrec/models/modules/feedforward.py | 45 +++ 
src/genrec/models/modules/layers.py | 16 +- src/genrec/models/modules/posemb.py | 4 +- src/genrec/trainers/__init__.py | 14 +- .../trainers/trainer_quantizer/__init__.py | 24 ++ src/genrec/trainers/trainer_quantizer/base.py | 335 ++++++++++++++++++ .../trainers/trainer_quantizer/rqvae.py | 82 +++++ .../trainer_quantizer/utils/__init__.py | 18 + .../trainer_quantizer/utils/callbacks.py | 68 ++++ .../trainer_quantizer/utils/evaluations.py | 184 ++++++++++ src/genrec/trainers/trainer_seqrec/base.py | 163 +++------ .../trainers/trainer_seqrec/bce_d2lr.py | 2 +- .../trainers/trainer_seqrec/bce_dros.py | 2 +- .../trainers/trainer_seqrec/bce_logdet.py | 3 +- .../trainers/trainer_seqrec/bce_r2rec.py | 2 +- .../trainers/trainer_seqrec/bce_resn.py | 2 +- src/genrec/trainers/trainer_seqrec/sl.py | 2 +- src/genrec/trainers/trainer_seqrec/sl_d2lr.py | 2 +- src/genrec/trainers/trainer_seqrec/sl_dros.py | 2 +- .../trainers/trainer_seqrec/sl_logdet.py | 2 +- .../trainers/trainer_seqrec/sl_r2rec.py | 2 +- src/genrec/trainers/trainer_seqrec/sl_resn.py | 2 +- .../trainers/trainer_seqrec/utils/__init__.py | 19 + .../{ => trainer_seqrec}/utils/callbacks.py | 2 +- .../{ => trainer_seqrec}/utils/evaluations.py | 108 +++++- src/genrec/trainers/utils/__init__.py | 17 - tests/datasets/test_dataset.py | 63 +++- tests/models/model_quantizer/__init__.py | 1 + tests/models/model_quantizer/test_base.py | 46 +++ tests/models/model_quantizer/test_rqvae.py | 253 +++++++++++++ tests/trainers/trainer_quantizer/__init__.py | 1 + .../trainer_quantizer/test_evaluations.py | 93 +++++ .../trainers/trainer_quantizer/test_rqvae.py | 93 +++++ .../trainer_quantizer/utils/__init__.py | 1 + .../trainer_quantizer/utils/test_base.py | 252 +++++++++++++ .../trainer_quantizer/utils/test_callbacks.py | 64 ++++ tests/trainers/trainer_seqrec/__init__.py | 1 + tests/trainers/trainer_seqrec/test_base.py | 4 +- .../trainers/trainer_seqrec/utils/__init__.py | 1 + .../utils/test_callbacks.py | 2 +- 
.../utils/test_evaluations.py | 8 +- 63 files changed, 3140 insertions(+), 297 deletions(-) create mode 100644 scripts/configs/quantizer/rqvae.yaml create mode 100644 scripts/configs/quantizer/template.yaml create mode 100644 src/genrec/models/model_quantizer/base.py create mode 100644 src/genrec/models/model_quantizer/rqvae.py create mode 100644 src/genrec/trainers/trainer_quantizer/__init__.py create mode 100644 src/genrec/trainers/trainer_quantizer/base.py create mode 100644 src/genrec/trainers/trainer_quantizer/rqvae.py create mode 100644 src/genrec/trainers/trainer_quantizer/utils/__init__.py create mode 100644 src/genrec/trainers/trainer_quantizer/utils/callbacks.py create mode 100644 src/genrec/trainers/trainer_quantizer/utils/evaluations.py create mode 100644 src/genrec/trainers/trainer_seqrec/utils/__init__.py rename src/genrec/trainers/{ => trainer_seqrec}/utils/callbacks.py (98%) rename src/genrec/trainers/{ => trainer_seqrec}/utils/evaluations.py (62%) delete mode 100644 src/genrec/trainers/utils/__init__.py create mode 100644 tests/models/model_quantizer/__init__.py create mode 100644 tests/models/model_quantizer/test_base.py create mode 100644 tests/models/model_quantizer/test_rqvae.py create mode 100644 tests/trainers/trainer_quantizer/__init__.py create mode 100644 tests/trainers/trainer_quantizer/test_evaluations.py create mode 100644 tests/trainers/trainer_quantizer/test_rqvae.py create mode 100644 tests/trainers/trainer_quantizer/utils/__init__.py create mode 100644 tests/trainers/trainer_quantizer/utils/test_base.py create mode 100644 tests/trainers/trainer_quantizer/utils/test_callbacks.py create mode 100644 tests/trainers/trainer_seqrec/utils/__init__.py rename tests/trainers/{ => trainer_seqrec}/utils/test_callbacks.py (96%) rename tests/trainers/{ => trainer_seqrec}/utils/test_evaluations.py (92%) diff --git a/.gitignore b/.gitignore index 45d15ac..94bcddc 100644 --- a/.gitignore +++ b/.gitignore @@ -10,6 +10,8 @@ __pycache__/ # Data, 
scripts & artifacts data data/ +.hf-models +.hf-models/ outputs/ logs/ scripts_running/ diff --git a/scripts/configs/quantizer/rqvae.yaml b/scripts/configs/quantizer/rqvae.yaml new file mode 100644 index 0000000..f74c06b --- /dev/null +++ b/scripts/configs/quantizer/rqvae.yaml @@ -0,0 +1,75 @@ +# An example config file for RQ-VAE on Movielens-1M dataset with grid search +# If you want to add hyperparameter search space, use the "search__" prefix before the parameter name. + +# global settings +pretrained_ckpt: null # optional path to a pretrained checkpoint to load +save_predictions: true # whether to save the predictions on the test set + +# dataset settings +dataset: + type: quantizer + + interaction_data_path: /path/to/movielens-1m/proc/user2item.pkl # TODO: path to interaction data file + textual_data_path: /path/to/movielens-1m/proc/item2title.pkl # TODO: path to textual data file + + lm_encoder_type: sentence_t5 + lm_encoder_path: /path/to/sentence-transformers/sentence-t5-base # TODO: path to pretrained language model encoder + +# collator settings +collator: + type: quantizer + +# model settings +model: + type: rqvae + + config: + # base model parameters + hidden_sizes: [512, 256, 128] + num_codebooks: 3 + codebook_size: 256 + codebook_dim: 32 + + # subclass model parameters + kmeans_init: true + kmeans_max_iter: 10 + +# trainer settings +trainer: + type: rqvae + + config: + # training arguments - Run control + do_train: true + do_eval: true + do_predict: true + + # training arguments - Optimization & schedule + num_train_epochs: 20000 + per_device_train_batch_size: 1024 + per_device_eval_batch_size: 1024 + gradient_accumulation_steps: 1 # batch_size = per_device_train_batch_size * num_devices * gradient_accumulation_steps + learning_rate: 1.0e-3 + weight_decay: 0.1 + lr_scheduler_type: linear + warmup_ratio: 0.05 + + # training arguments - Evaluation & checkpointing + metric_for_best_model: eval_loss # should exist in the metrics + greater_is_better: false 
# use loss to evaluate the best model + + # training arguments - Parallelism & precision + bf16: false + tf32: true + + # base trainer parameters + eval_interval: 100 # interval between metric evaluations (presumably in epochs) + train_stop_epoch: -1 # by default, do not stop training early + metrics: + - ["codebook_usage", {}] + - ["code_collision", {}] + codebook_loss_weight: 1.0 + commitment_loss_weight: 0.25 + model_loss_weight: 0.0 + + # subclass trainer parameters diff --git a/scripts/configs/quantizer/template.yaml b/scripts/configs/quantizer/template.yaml new file mode 100644 index 0000000..7402d48 --- /dev/null +++ b/scripts/configs/quantizer/template.yaml @@ -0,0 +1,103 @@ +# A template config file for quantizer training +# If you want to add hyperparameter search space, use the "search__" prefix before the parameter name. + +# global settings +seed: 42 +output_dir: null # TODO: output directory to save model checkpoints, logs, and results +pretrained_ckpt: null # optional path to a pretrained checkpoint to load +save_predictions: true # whether to save the predictions on the test set + +# dataset settings +dataset: + type: quantizer + + interaction_data_path: null # TODO: path to interaction data file + textual_data_path: null # TODO: path to textual data file + + lm_encoder_type: sentence_t5 + lm_encoder_path: null # TODO: path to pretrained language model encoder + + aux_item_embeddings_path: null # optional path to an auxiliary item embeddings file (expected to be a .npy file) + +# collator settings +collator: + type: quantizer + + # no default parameters for quantizer collator at the moment + +# model settings +model: + type: rqvae + + config: + # base model parameters + hidden_sizes: [512, 256, 128] + num_codebooks: 3 + codebook_size: 256 + codebook_dim: 32 + + # subclass model parameters + kmeans_init: true + kmeans_max_iter: 10 + +# trainer settings +trainer: + type: rqvae + + config: + # training arguments - Run control + do_train: true + do_eval: true + do_predict: true + 
overwrite_output_dir: true + remove_unused_columns: false + + # training arguments - Optimization & schedule + num_train_epochs: 20000 + per_device_train_batch_size: 1024 + per_device_eval_batch_size: 1024 + gradient_accumulation_steps: 1 # batch_size = per_device_train_batch_size * num_devices * gradient_accumulation_steps + learning_rate: 1.0e-3 + weight_decay: 0.1 + max_grad_norm: 1.0 + optim: adamw_torch + lr_scheduler_type: linear + warmup_ratio: 0.05 + + # training arguments - Evaluation & checkpointing + eval_strategy: epoch + save_strategy: epoch + eval_delay: 0 # skip warmup + eval_accumulation_steps: 1 + save_total_limit: 1 # keep only the best checkpoint + load_best_model_at_end: true # load the best model when finished training + metric_for_best_model: eval_loss # should exist in the metrics + greater_is_better: false # use loss to evaluate the best model + prediction_loss_only: false + save_safetensors: true + + # training arguments - Parallelism & precision + dataloader_num_workers: 0 + dataloader_pin_memory: true + dataloader_drop_last: false + ddp_find_unused_parameters: true + ddp_broadcast_buffers: false + gradient_checkpointing: false + bf16: false + tf32: true + + # training arguments - Logging / tracking + logging_strategy: epoch + report_to: ["tensorboard"] + + # base trainer parameters + eval_interval: 100 # interval between metric evaluations (presumably in epochs) + train_stop_epoch: -1 # by default, do not stop training early + metrics: + - ["codebook_usage", {}] + - ["code_collision", {}] + codebook_loss_weight: 1.0 + commitment_loss_weight: 0.25 + model_loss_weight: 0.0 + + # subclass trainer parameters diff --git a/scripts/configs/seqrec/template.yaml b/scripts/configs/seqrec/template.yaml index bc5b4e6..1cec5e7 100644 --- a/scripts/configs/seqrec/template.yaml +++ b/scripts/configs/seqrec/template.yaml @@ -1,4 +1,4 @@ -# Example config file for sequence recommendation tasks +# A template config file for sequence recommendation tasks # If you want to add hyperparameter 
search space, use the "search__" prefix before the parameter name. # global settings @@ -50,7 +50,6 @@ trainer: # training arguments - Optimization & schedule num_train_epochs: 200 - train_stop_epoch: -1 # by default, do not stop training early per_device_train_batch_size: 512 per_device_eval_batch_size: 1024 gradient_accumulation_steps: 1 # batch_size = per_device_train_batch_size * num_devices * gradient_accumulation_steps @@ -90,6 +89,7 @@ trainer: # base trainer parameters norm_embeddings: false # whether to L2-normalize user and item embeddings eval_interval: 5 # run metrics every epoch + train_stop_epoch: -1 # by default, do not stop training early metrics: - ["hr", {}] - ["ndcg", {}] diff --git a/src/genrec/datasets/__init__.py b/src/genrec/datasets/__init__.py index ee92ed6..be787be 100644 --- a/src/genrec/datasets/__init__.py +++ b/src/genrec/datasets/__init__.py @@ -52,3 +52,13 @@ "SeqRecDataset", "SeqRecExample", ] + +from .modules import LMEncoder, LMEncoderFactory, NegativeSampler, NegativeSamplerFactory, PrefixTree + +__all__ += [ + "LMEncoder", + "LMEncoderFactory", + "NegativeSampler", + "NegativeSamplerFactory", + "PrefixTree", +] diff --git a/src/genrec/datasets/base.py b/src/genrec/datasets/base.py index 4f1ecb8..22ea71e 100644 --- a/src/genrec/datasets/base.py +++ b/src/genrec/datasets/base.py @@ -171,6 +171,7 @@ def __init__( sid_cache: Optional[Int[np.ndarray, "I+1 C"]] = None, textual_data_path: Optional[Union[pd.DataFrame, str, Path]] = None, lm_encoder: Optional[LMEncoder] = None, + **kwargs: Any, ) -> None: """Initialises the dataset and materialises user-level metadata. @@ -188,6 +189,7 @@ def __init__( pickle file with `ItemID` and `Title` columns. lm_encoder (Optional[LMEncoder]): Optional encoder used to transform item titles into dense embeddings. + **kwargs (Any): Additional keyword arguments for the dataset. 
""" if split not in { "train", @@ -489,6 +491,7 @@ def __init__( no_pad_keys: Dict[str, type], pad_values: Dict[str, np.generic], seed: int = 42, + **kwargs: Any, ) -> None: """Configures the collator. @@ -498,6 +501,7 @@ def __init__( pad_values (Dict[str, np.generic]): Padding values per field, e.g., {"field1": 0, "field2": -100}. If a field is missing, defaults to 0. seed (int): Random seed for the collator's internal RNG. + **kwargs (Any): Additional keyword arguments for the collator. """ SeedWorkerMixin.__init__(self, global_seed=seed) diff --git a/src/genrec/datasets/dataset_genrec.py b/src/genrec/datasets/dataset_genrec.py index 613e4f2..57a786e 100644 --- a/src/genrec/datasets/dataset_genrec.py +++ b/src/genrec/datasets/dataset_genrec.py @@ -77,6 +77,23 @@ def __init__( textual_data_path: Optional[Union[pd.DataFrame, str, Path]] = None, lm_encoder: Optional[LMEncoder] = None, ) -> None: + """Initialises the dataset and materialises user-level metadata. + + Args: + interaction_data_path (Union[pd.DataFrame, str, Path]): Pandas DataFrame or path to a + pickle file containing `UserID` and `ItemID` columns. We assume that the `UserID` + begins from 0 and that `ItemID` begins from 1, both being contiguous integers. The + `ItemID` of 0 is reserved for padding. + split (DatasetSplitLiteral): Dataset split controlling example generation strategy. + max_seq_length (int): Maximum length of interaction histories. + min_seq_length (int): Minimum length of interaction histories. + sid_cache (Optional[Int[np.ndarray, "I+1 C"]]): Optional mapping from item ID to SID + sequence, stored as numpy arrays. + textual_data_path (Optional[Union[pd.DataFrame, str, Path]]): Optional DataFrame or + pickle file with `ItemID` and `Title` columns. + lm_encoder (Optional[LMEncoder]): Optional encoder used to transform item titles into + dense embeddings. 
+ """ super().__init__( interaction_data_path, split, diff --git a/src/genrec/datasets/dataset_quantizer.py b/src/genrec/datasets/dataset_quantizer.py index 5c0c4e0..386dce1 100644 --- a/src/genrec/datasets/dataset_quantizer.py +++ b/src/genrec/datasets/dataset_quantizer.py @@ -4,7 +4,7 @@ from dataclasses import dataclass from pathlib import Path -from typing import Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union from jaxtyping import Float, Int import numpy as np @@ -39,10 +39,13 @@ class QuantizerExample(RecExample): Attributes: item_id: Identifier of the item. item_embedding: Dense embedding vector of the item. + aux_item_embedding: Optional auxiliary dense embedding vector of the item, + e.g., produced by SeqRec model. """ item_id: int item_embedding: Float[np.ndarray, "D"] + aux_item_embedding: Optional[Float[np.ndarray, "D_aux"]] = None @RecDatasetFactory.register("quantizer") @@ -52,31 +55,67 @@ class QuantizerDataset(RecDataset[QuantizerExample]): def __init__( self, interaction_data_path: Union[pd.DataFrame, str, Path], - split: DatasetSplitLiteral, - max_seq_length: int, - min_seq_length: int, - sid_cache: Optional[Int[np.ndarray, "I+1 C"]] = None, textual_data_path: Optional[Union[pd.DataFrame, str, Path]] = None, lm_encoder: Optional[LMEncoder] = None, + aux_item_embeddings: Optional[Float[np.ndarray, "I+1 D_aux"]] = None, + **kwargs: Any, ) -> None: - assert split == "train", "QuantizerDataset only supports the 'train' split." + """Initialises the dataset and materialises user-level metadata. + + Args: + interaction_data_path (Union[pd.DataFrame, str, Path]): Pandas DataFrame or path to a + pickle file containing `UserID` and `ItemID` columns. We assume that the `UserID` + begins from 0 and that `ItemID` begins from 1, both being contiguous integers. The + `ItemID` of 0 is reserved for padding. 
+ textual_data_path (Optional[Union[pd.DataFrame, str, Path]]): Optional DataFrame or + pickle file with `ItemID` and `Title` columns. + lm_encoder (Optional[LMEncoder]): Optional encoder used to transform item titles into + dense embeddings. + aux_item_embeddings (Optional[Float[np.ndarray, "I+1 D_aux"]]): Optional auxiliary + item embeddings, e.g., produced by SeqRec model. + **kwargs (Any): Additional keyword arguments for the dataset. + """ + + self._aux_item_embeddings = aux_item_embeddings super().__init__( interaction_data_path, - split, - max_seq_length, - min_seq_length, - sid_cache, - textual_data_path, - lm_encoder, + DatasetSplitLiteral.TRAIN, + textual_data_path=textual_data_path, + lm_encoder=lm_encoder, + **kwargs, ) + if aux_item_embeddings is not None: + assert ( + aux_item_embeddings.shape[0] == self.item_size + 1 + ), "The number of auxiliary item embeddings must equal item_size + 1." + def _build_examples(self) -> List[QuantizerExample]: """Generates training examples for quantizer training.""" assert self._item_embeddings is not None, "Item embeddings are required for quantizer training." 
- examples = [QuantizerExample(i, self._item_embeddings[i]) for i in range(1, self.item_size + 1)] + examples = [ + QuantizerExample( + i, + self._item_embeddings[i], + self._aux_item_embeddings[i] if self._aux_item_embeddings is not None else None, + ) + for i in range(1, self.item_size + 1) + ] return examples + @property + def aux_item_embeddings(self) -> Optional[Float[np.ndarray, "I+1 D_aux"]]: + """Exposes the auxiliary item embeddings, if available.""" + return self._aux_item_embeddings + + @property + def aux_embedding_dim(self) -> Optional[int]: + """Returns the dimensionality of the auxiliary item embeddings, if available.""" + if self._aux_item_embeddings is None: # pragma: no cover - embedding absent + return None + return self._aux_item_embeddings.shape[1] + @RecCollatorConfigFactory.register("quantizer") @dataclass @@ -97,6 +136,8 @@ class QuantizerCollator(RecCollator[QuantizerExample]): Item IDs. item_embedding: `Float[torch.Tensor, "B D"]`. Item dense embeddings. + aux_item_embedding: `Optional[Float[torch.Tensor, "B D_aux"]]`. + Auxiliary item dense embeddings, if available. """ def __init__( @@ -104,6 +145,7 @@ def __init__( dataset: QuantizerDataset, config: Optional[QuantizerCollatorConfig] = None, seed: int = 42, + **kwargs: Any, ) -> None: """Configures the collator. @@ -111,6 +153,7 @@ def __init__( dataset (QuantizerDataset): Dataset split from which examples are drawn. config (Optional[QuantizerCollatorConfig]): Optional collator configuration instance. seed (int): Random seed for the collator's internal RNG. + **kwargs (Any): Additional keyword arguments for the collator. 
""" self._config = config or QuantizerCollatorConfig() @@ -118,7 +161,8 @@ def __init__( no_pad_keys: Dict[str, type] = { "item_id": np.int64, "item_embedding": np.float32, + "aux_item_embedding": np.float32, } pad_values: Dict[str, np.generic] = {} - super().__init__(need_pad_keys, no_pad_keys, pad_values, seed) + super().__init__(need_pad_keys, no_pad_keys, pad_values, seed, **kwargs) diff --git a/src/genrec/datasets/dataset_seqrec.py b/src/genrec/datasets/dataset_seqrec.py index 784e0e7..85de599 100644 --- a/src/genrec/datasets/dataset_seqrec.py +++ b/src/genrec/datasets/dataset_seqrec.py @@ -4,9 +4,9 @@ from dataclasses import dataclass from pathlib import Path -from typing import Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union -from jaxtyping import Float, Int +from jaxtyping import Int import numpy as np import pandas as pd @@ -67,7 +67,26 @@ def __init__( sid_cache: Optional[Int[np.ndarray, "I+1 C"]] = None, textual_data_path: Optional[Union[pd.DataFrame, str, Path]] = None, lm_encoder: Optional[LMEncoder] = None, + **kwargs: Any, ) -> None: + """Initialises the dataset and materialises user-level metadata. + + Args: + interaction_data_path (Union[pd.DataFrame, str, Path]): Pandas DataFrame or path to a + pickle file containing `UserID` and `ItemID` columns. We assume that the `UserID` + begins from 0 and that `ItemID` begins from 1, both being contiguous integers. The + `ItemID` of 0 is reserved for padding. + split (DatasetSplitLiteral): Dataset split controlling example generation strategy. + max_seq_length (int): Maximum length of interaction histories. + min_seq_length (int): Minimum length of interaction histories. + sid_cache (Optional[Int[np.ndarray, "I+1 C"]]): Optional mapping from item ID to SID + sequence, stored as numpy arrays. + textual_data_path (Optional[Union[pd.DataFrame, str, Path]]): Optional DataFrame or + pickle file with `ItemID` and `Title` columns. 
+ lm_encoder (Optional[LMEncoder]): Optional encoder used to transform item titles into + dense embeddings. + **kwargs (Any): Additional keyword arguments for the dataset. + """ super().__init__( interaction_data_path, split, @@ -76,6 +95,7 @@ def __init__( sid_cache, textual_data_path, lm_encoder, + **kwargs, ) # recompute training set item popularity self._train_item_popularity = self._compute_train_item_popularity() @@ -188,6 +208,7 @@ def __init__( dataset: SeqRecDataset, config: Optional[SeqRecCollatorConfig] = None, seed: int = 42, + **kwargs: Any, ) -> None: """Configures the collator. @@ -224,7 +245,7 @@ def __init__( "attention_mask": np.int64(0), } - super().__init__(need_pad_keys, no_pad_keys, pad_values, seed) + super().__init__(need_pad_keys, no_pad_keys, pad_values, seed, **kwargs) def _process_before_padding( self, diff --git a/src/genrec/datasets/modules/lm_encoders.py b/src/genrec/datasets/modules/lm_encoders.py index c03425e..37da7aa 100644 --- a/src/genrec/datasets/modules/lm_encoders.py +++ b/src/genrec/datasets/modules/lm_encoders.py @@ -3,7 +3,7 @@ from __future__ import annotations from pathlib import Path -from typing import Callable, Protocol, Sequence, Type, runtime_checkable +from typing import Callable, Optional, Protocol, Sequence, Type, Union, runtime_checkable from jaxtyping import Float import numpy as np @@ -68,8 +68,8 @@ class SentenceT5Encoder: def __init__( self, model_name: str = "sentence-transformers/sentence-t5-base", - device: str | None = None, - local_model_dir: str | None = None, + device: Optional[str] = None, + local_model_dir: Optional[Union[str, Path]] = None, allow_download: bool = True, ) -> None: """Initialises the encoder. @@ -78,8 +78,8 @@ def __init__( model_name (str): Hugging Face identifier for the model to load when a local directory is not provided. device (Optional[str]): Optional device string forwarded to SentenceTransformers. 
- local_model_dir (Optional[str]): Directory containing a cached model or to be used as - a cache location. + local_model_dir (Optional[Union[str, Path]]): Directory containing a cached model or to + be used as a cache location. allow_download (bool): Whether to download the model if `local_model_dir` does not exist yet. @@ -89,8 +89,8 @@ def __init__( """ model_source = model_name - self._model_dir: Path | None = None - cache_folder: str | None = None + self._model_dir: Optional[Path] = None + cache_folder: Optional[str] = None if local_model_dir is not None: resolved_dir = Path(local_model_dir).expanduser().resolve() self._model_dir = resolved_dir diff --git a/src/genrec/main_quantizer.py b/src/genrec/main_quantizer.py index e69de29..e617633 100644 --- a/src/genrec/main_quantizer.py +++ b/src/genrec/main_quantizer.py @@ -0,0 +1,298 @@ +"""Entry point for quantizer training experiments. + +Expected configuration schema:: + + output_dir: /path/to/trial_output # directory to save model checkpoints, logs, and results + seed: 42 + pretrained_ckpt: /optional/pretrained/run # optional path to a pretrained checkpoint to load + save_predictions: true # whether to save the predictions on the test set + dataset: + type: quantizer + interaction_data_path: /path/to/movielens-1m/proc/user2item.pkl + textual_data_path: /path/to/movielens-1m/proc/item2title.pkl + lm_encoder_type: sentence_t5 + lm_encoder_path: /path/to/sentence-t5-base + ... # dataset-specific parameters + collator: + type: quantizer + ... # collator-specific parameters + model: + type: rqvae + config: + num_codebooks: 3 + ... # model hyper-parameters + trainer: + type: rqvae + config: + num_train_epochs: 20000 + ... 
# trainer hyper-parameters +""" + +from __future__ import annotations + +import argparse +import copy +import gzip +import json +import pickle +from pathlib import Path +from typing import Any, BinaryIO, Dict, Optional, cast + +from accelerate import Accelerator +from accelerate.utils import set_seed +from jaxtyping import Float, Int +import numpy as np +from rich import print_json +import torch +from transformers.utils import logging +import yaml + +from .datasets import ( + DatasetSplitLiteral, + LMEncoderFactory, + QuantizerCollator, + QuantizerCollatorConfig, + QuantizerDataset, +) +from .models import QuantizerModel, QuantizerModelConfigFactory, QuantizerModelFactory +from .trainers import QuantizerTrainerFactory, QuantizerTrainingArgumentsFactory + +__all__ = [ + "main", +] + + +logger = logging.get_logger(__name__) + + +def load_config(config_path: Path, *, is_main_process: bool) -> Dict[str, Any]: + """Loads and prints a configuration file in JSON/YAML format.""" + if not config_path.exists(): # pragma: no cover - defensive check + raise FileNotFoundError(f"Configuration file not found: {config_path}") + + if config_path.suffix in {".yaml", ".yml"}: # pragma: no cover - not default format + with open(config_path, "r", encoding="utf-8") as f: + configs: Dict[str, Any] = yaml.safe_load(f) + elif config_path.suffix == ".json": + with open(config_path, "r", encoding="utf-8") as f: + configs: Dict[str, Any] = json.load(f) + else: # pragma: no cover - defensive check + raise ValueError(f"Unsupported configuration file format: {config_path.suffix}") + + if is_main_process: + logger.info(f"Loaded configuration from {config_path}:") + print_json(data=configs) + + return configs + + +def save_experiment_config( + configs: Dict[str, Any], + *, + output_dir: Path, + config_path: Path, + is_main_process: bool, +) -> None: + """Persists the resolved configuration alongside useful metadata.""" + + if not is_main_process: + return + + serialisable_cfg = 
copy.deepcopy(configs) + metadata = serialisable_cfg.setdefault("_meta", {}) + metadata["config_path"] = str(config_path.resolve()) + + output_config_path = output_dir / "experiment_config.json" + if output_config_path.exists(): # pragma: no cover - defensive check + logger.warning(f"Overwriting existing configuration file at: {output_config_path}") + with open(output_config_path, "w", encoding="utf-8") as f: + json.dump(serialisable_cfg, f, indent=4) + + +def main(): + """Main function for quantizer training experiments.""" + + accelerator = Accelerator() + + # Set up logging (only master process should emit verbose logs) + if accelerator.is_main_process: + logging.set_verbosity_info() # set logging level to INFO + else: + logging.set_verbosity_error() + if accelerator.is_local_main_process: + logging.enable_progress_bar() # enable tqdm progress bar for local main process + else: + logging.disable_progress_bar() + + # Parse args and load configuration + parser = argparse.ArgumentParser(description="Quantizer Training Experiment") + parser.add_argument( + "--config", + type=Path, + required=True, + help="Path to the experiment configuration file", + ) + args = parser.parse_args() + if accelerator.is_main_process: + logger.info(f"Parsed command-line arguments: config={args.config}") + cfg = load_config(args.config, is_main_process=accelerator.is_main_process) + raw_cfg = copy.deepcopy(cfg) + + # Extract experiment-level parameters from the configuration + try: + output_dir = Path(cfg.pop("output_dir")).expanduser() + except KeyError as exc: + raise KeyError("`output_dir` must be specified in the configuration file.") from exc + + seed = int(cfg.pop("seed", 42)) + pretrained_ckpt_value = cfg.pop("pretrained_ckpt", None) + pretrained_ckpt = Path(pretrained_ckpt_value).expanduser() if pretrained_ckpt_value is not None else None + save_predictions = bool(cfg.pop("save_predictions", True)) + + if accelerator.is_main_process: + output_dir.mkdir(parents=True, 
exist_ok=True) + accelerator.wait_for_everyone() + + save_experiment_config( + raw_cfg, + output_dir=output_dir, + config_path=args.config, + is_main_process=accelerator.is_main_process, + ) + accelerator.wait_for_everyone() + + if pretrained_ckpt is not None: + if not pretrained_ckpt.exists(): + raise FileNotFoundError(f"Pretrained checkpoint not found: {pretrained_ckpt}") + if not pretrained_ckpt.is_dir(): + raise ValueError(f"`pretrained_ckpt` should be a directory: {pretrained_ckpt}") + + # Set up seed + set_seed(seed) + + # Builds datasets. Refer to the constructor of `QuantizerDataset`. + assert "dataset" in cfg, "`dataset` configuration section is missing." + dataset_cfg: Dict[str, Any] = cfg["dataset"] + dataset_type = dataset_cfg.pop("type", None) + assert dataset_type == "quantizer", f"Unsupported dataset type: {dataset_type}" + + lm_encoder_type = dataset_cfg.pop("lm_encoder_type", None) + assert lm_encoder_type is not None, "`lm_encoder_type` must be specified in the dataset configuration." + lm_encoder_path = dataset_cfg.pop("lm_encoder_path", None) + lm_encoder = LMEncoderFactory.create( + name=lm_encoder_type, + local_model_dir=lm_encoder_path, + device="cuda" if torch.cuda.is_available() else "cpu", + ) + + aux_item_embeddings: Optional[Float[np.ndarray, "I+1 D_aux"]] = None + aux_item_embeddings_path = dataset_cfg.pop("aux_item_embeddings_path", None) + if aux_item_embeddings_path is not None: + aux_path = Path(aux_item_embeddings_path).expanduser() + if not aux_path.exists(): + raise FileNotFoundError(f"Auxiliary item embeddings file not found: {aux_path}") + aux_item_embeddings = np.load(aux_path) + + dataset = QuantizerDataset( + **dataset_cfg, + lm_encoder=lm_encoder, + aux_item_embeddings=aux_item_embeddings, + ) + + if accelerator.is_main_process: + logger.info(f"Loaded datasets from {dataset_cfg['interaction_data_path']}.") + logger.info(f"Train dataset: {dataset.stats()}") + + # Builds collator. 
Refer to the constructor of `QuantizerCollator`. + assert "collator" in cfg, "`collator` configuration section is missing." + collator_cfg = cfg["collator"] + collator_type: str = collator_cfg.pop("type", None) + assert collator_type == "quantizer", f"Unsupported collator type: {collator_type}" + + collator_cfg = QuantizerCollatorConfig(**collator_cfg) + collator = QuantizerCollator(dataset, collator_cfg, seed=seed) + + # Builds model. Refer to the constructor of `QuantizerModel` and `QuantizerModelConfig`. + assert "model" in cfg, "`model` configuration section is missing." + model_cfg = cfg["model"] + model_type: str = model_cfg.pop("type", None) + + model_config_cfg = model_cfg.pop("config", {}) + model_config = QuantizerModelConfigFactory.create( + model_type, + embed_dim=lm_encoder.embedding_dim, + **model_config_cfg, + ) + + # Load pretrained checkpoint if provided + if pretrained_ckpt is not None: + model = QuantizerModelFactory.from_pretrained(model_type, pretrained_ckpt, config=model_config) + assert isinstance(model, QuantizerModel), f"Pretrained model is not an instance of SeqRecModel: {type(model)}" + if accelerator.is_main_process: + logger.info(f"Loaded pretrained model {model_type} checkpoint from {pretrained_ckpt}.") + else: + model = QuantizerModelFactory.create(model_type, config=model_config, **model_cfg) + if accelerator.is_main_process: + logger.info(f"Initialized model {model_type}.") + + # Builds trainer. Refer to the quantizer trainer selected via `QuantizerTrainerFactory`. + assert "trainer" in cfg, "`trainer` configuration section is missing." 
+ trainer_cfg = cfg["trainer"] + trainer_type: str = trainer_cfg.pop("type", None) + + training_args_cfg = trainer_cfg.pop("config", {}) + training_args = QuantizerTrainingArgumentsFactory.create( + trainer_type, + output_dir=output_dir, + logging_dir=output_dir / "runs", + seed=seed, + **training_args_cfg, + ) + + trainer = QuantizerTrainerFactory.create( + trainer_type, + model=model, + args=training_args, + data_collator=collator, + train_dataset=dataset, + eval_dataset=dataset, + **trainer_cfg, + ) + if accelerator.is_main_process: + logger.info(f"Initialized trainer {trainer_type}.") + + # Training + if training_args.do_train: + trainer.train() + + # Predicting, save results and metrics + if training_args.do_predict: + pred = trainer.predict(dataset) + + if save_predictions and accelerator.is_main_process: + save_path = output_dir / "test_predictions.pkl.gz" + with cast(BinaryIO, gzip.open(save_path, "wb")) as f: + pickle.dump(pred, f) + logger.info(f"Saved test predictions to {save_path}.") + + # also post-process save the semantic_ids separately for easier usage + semantic_ids: Int[torch.Tensor, "I C"] = torch.as_tensor(pred.predictions[0]) + processed_semantic_ids: Int[np.ndarray, "I C+1"] = model.post_process_quantized_ids(semantic_ids).numpy() + semantic_ids_save_path = output_dir / "test_processed_semantic_ids.npy" + np.save(semantic_ids_save_path, processed_semantic_ids) + logger.info(f"Saved test semantic IDs to {semantic_ids_save_path}.") + + if accelerator.is_main_process: + logger.info("Test set metrics:") + print_json(data=pred.metrics) + save_path = output_dir / "test_metrics.json" + with open(save_path, "w", encoding="utf-8") as f: + json.dump(pred.metrics, f, indent=4) + logger.info(f"Saved test metrics to {save_path}.") + + # Exit + accelerator.end_training() + + +if __name__ == "__main__": + main() diff --git a/src/genrec/main_seqrec.py b/src/genrec/main_seqrec.py index 00e1276..99391a1 100644 --- a/src/genrec/main_seqrec.py +++ 
b/src/genrec/main_seqrec.py @@ -36,11 +36,11 @@ from pathlib import Path from typing import Any, BinaryIO, Dict, cast -import yaml from accelerate import Accelerator from accelerate.utils import set_seed from rich import print_json from transformers.utils import logging +import yaml from .datasets import DatasetSplitLiteral, SeqRecCollator, SeqRecCollatorConfig, SeqRecDataset from .models import SeqRecModel, SeqRecModelConfigFactory, SeqRecModelFactory diff --git a/src/genrec/models/__init__.py b/src/genrec/models/__init__.py index 9a848d0..6912d20 100644 --- a/src/genrec/models/__init__.py +++ b/src/genrec/models/__init__.py @@ -6,9 +6,23 @@ __all__ += [] -# from .model_quantizer import () +from .model_quantizer import ( + QuantizerModel, + QuantizerModelConfig, + QuantizerModelConfigFactory, + QuantizerModelFactory, + QuantizerOutput, + QuantizerOutputFactory, +) -__all__ += [] +__all__ += [ + "QuantizerModel", + "QuantizerModelConfig", + "QuantizerModelConfigFactory", + "QuantizerModelFactory", + "QuantizerOutput", + "QuantizerOutputFactory", +] from .model_seqrec import ( SeqRecModel, diff --git a/src/genrec/models/model_quantizer/__init__.py b/src/genrec/models/model_quantizer/__init__.py index 23a03bd..3eaef00 100644 --- a/src/genrec/models/model_quantizer/__init__.py +++ b/src/genrec/models/model_quantizer/__init__.py @@ -1 +1,33 @@ """Models for quantizers used in generative recommendation tasks.""" + +__all__ = [] + +from .base import ( + QuantizerModel, + QuantizerModelConfig, + QuantizerModelConfigFactory, + QuantizerModelFactory, + QuantizerOutput, + QuantizerOutputFactory, +) + +__all__ += [ + "QuantizerModel", + "QuantizerModelConfig", + "QuantizerModelConfigFactory", + "QuantizerModelFactory", + "QuantizerOutput", + "QuantizerOutputFactory", +] + +from .rqvae import ( + RQVAEModel, + RQVAEModelConfig, + RQVAEModelOutput, +) + +__all__ += [ + "RQVAEModel", + "RQVAEModelConfig", + "RQVAEModelOutput", +] diff --git 
a/src/genrec/models/model_quantizer/base.py b/src/genrec/models/model_quantizer/base.py new file mode 100644 index 0000000..81b1c6b --- /dev/null +++ b/src/genrec/models/model_quantizer/base.py @@ -0,0 +1,323 @@ +"""Base model for quantizer.""" + +from __future__ import annotations + +import os +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, Callable, Generic, Optional, Sequence, Type, TypeVar, Union + +from jaxtyping import Float, Int +import torch +from transformers import PretrainedConfig, PreTrainedModel +from transformers.utils.generic import ModelOutput + +__all__ = [ + "QuantizerModel", + "QuantizerModelFactory", + "QuantizerModelConfig", + "QuantizerModelConfigFactory", + "QuantizerOutput", + "QuantizerOutputFactory", +] + +_QuantizerModelConfig = TypeVar("_QuantizerModelConfig", bound="QuantizerModelConfig") +_QuantizerOutput = TypeVar("_QuantizerOutput", bound="QuantizerOutput") +_QuantizerModel = TypeVar("_QuantizerModel", bound="QuantizerModel[Any, Any]") + + +class QuantizerModelConfigFactory: # pragma: no cover - factory class + """Factory for creating `QuantizerModelConfig` instances.""" + + _registry: dict[str, Type[QuantizerModelConfig]] = {} + + @classmethod + def register(cls, name: str) -> Callable[[Type[_QuantizerModelConfig]], Type[_QuantizerModelConfig]]: + """Decorator to register a `QuantizerModelConfig` implementation.""" + + def decorator(config_cls: Type[_QuantizerModelConfig]) -> Type[_QuantizerModelConfig]: + if name in cls._registry: + raise ValueError(f"Quantizer model config '{name}' is already registered.") + cls._registry[name] = config_cls + return config_cls + + return decorator + + @classmethod + def create(cls, name: str, **kwargs) -> QuantizerModelConfig: + """Creates an instance of a registered `QuantizerModelConfig`.""" + if name not in cls._registry: + raise ValueError(f"Quantizer model config '{name}' is not registered.") + config_cls = cls._registry[name] + return 
config_cls(**kwargs) + + +class QuantizerModelConfig(PretrainedConfig): + """Base configuration class for quantizer models. + + This class extends the `PretrainedConfig` from the Hugging Face Transformers library + and serves as a base for implementing specific quantizer model configurations. + + Subclasses must specify the `model_type` attribute. + """ + + model_type = "quantizer" + + def __init__( + self, + embed_dim: int = 768, + hidden_sizes: Sequence[int] = (512, 256, 128), + num_codebooks: int = 3, + codebook_size: int = 256, + codebook_dim: int = 32, + **kwargs, + ) -> None: + """Initializes the configuration with model hyperparameters. + + Args: + embed_dim (int): Dimensionality of the input dense embeddings. Default is 768. + hidden_sizes (Sequence[int]): Sizes of hidden layers in the quantizer encoder. Note + that a Linear(hidden_sizes[-1], codebook_dim) layer is appended to the encoder. + The decoder will be symmetric to the encoder. Default is (512, 256, 128). + num_codebooks (int): Number of codebooks in the quantizer. Default is 3. + codebook_size (int): Number of codes in each codebook. Default is 256. + codebook_dim (int): Dimensionality of each code in the codebooks. Default is 32. + **kwargs (Any): Additional keyword arguments for the base `PretrainedConfig`.
+ """ + super().__init__(**kwargs) + + self.embed_dim = embed_dim + self.hidden_sizes = hidden_sizes + self.num_codebooks = num_codebooks + self.codebook_size = codebook_size + self.codebook_dim = codebook_dim + + +class QuantizerOutputFactory: # pragma: no cover - factory class + """Factory for creating `QuantizerOutput` instances.""" + + _registry: dict[str, Type[QuantizerOutput]] = {} + + @classmethod + def register(cls, name: str) -> Callable[[Type[_QuantizerOutput]], Type[_QuantizerOutput]]: + """Decorator to register a `QuantizerOutput` implementation.""" + + def decorator(output_cls: Type[_QuantizerOutput]) -> Type[_QuantizerOutput]: + if name in cls._registry: + raise ValueError(f"Quantizer output '{name}' is already registered.") + cls._registry[name] = output_cls + return output_cls + + return decorator + + @classmethod + def create(cls, name: str, **kwargs) -> QuantizerOutput: + """Creates an instance of a registered `QuantizerOutput`.""" + if name not in cls._registry: + raise ValueError(f"Quantizer output '{name}' is not registered.") + output_cls = cls._registry[name] + return output_cls(**kwargs) + + +@dataclass +class QuantizerOutput(ModelOutput): + """Base output class for quantizer models. + + Attributes: + semantic_ids (Int[torch.Tensor, "B C"]): Semantic IDs assigned by the quantizer, + i.e., the indices of the selected codes from each codebook. The shape is + (batch_size, num_codebooks). + quantized_embeddings (Optional[Float[torch.Tensor, "B C D_c"]]): Quantized + embeddings corresponding to the semantic IDs, if applicable. The shape is + (batch_size, num_codebooks, codebook_dim). + residual_embeddings (Optional[Float[torch.Tensor, "B C D_c"]]): Residual embeddings + before quantization to the corresponding codebooks, if applicable. The shape is + (batch_size, num_codebooks, codebook_dim). + decoded_embeddings (Optional[Float[torch.Tensor, "B D"]]): Reconstructed dense + embeddings from the quantizer, if applicable. 
The shape is (batch_size, embed_dim). + reconstruction_loss (Optional[Float[torch.Tensor, "B"]]): The reconstruction loss + value, if applicable. + codebook_loss (Optional[Float[torch.Tensor, "B"]]): The codebook loss value, + if applicable. + commitment_loss (Optional[Float[torch.Tensor, "B"]]): The commitment loss value, + if applicable. + model_loss (Optional[Float[torch.Tensor, ""]]): The computed model-specific + loss value, if applicable. Note that the model-agnostic loss (e.g., + reconstruction or commitment losses) is handled outside of this class. + + .. note:: + In typical quantizers, e.g., RQ-VAE, the `residual_embeddings` are supposed to be + close to the `quantized_embeddings` by optimizing the commitment loss. The + `decoded_embeddings` are supposed to be close to the original dense embeddings + by optimizing the reconstruction loss. + + .. note:: + The STE (Straight-Through Estimator) trick is applied for `quantized_embeddings` + to allow gradient backpropagation during training. That is, the forward pass uses + the quantized embeddings, while in the backward pass, the gradients are directly + passed to the input embeddings before quantization. + + .. note:: + The output `semantic_ids`, e.g., `A_12, B_23, C_34`, may not be completely + unique for different input embeddings due to collisions in the quantization process. + In practice, some strategies such as adding an additional anti-collision code, e.g., + `Z_0`, can be employed to mitigate this issue. In addition, each code, e.g., `B_23`, + originally ranges from 0 to `codebook_size - 1` within each codebook. To convert + them to globally unique IDs, an offset can be added based on the codebook index, e.g., + `B_23` -> `B_23 + codebook_size * 1 + 1` (the last `+1` is for reserving the padding ID). + This logic is expected to be handled in `QuantizerModel.post_process_quantized_ids`.
+ """ + + semantic_ids: Int[torch.Tensor, "B C"] + quantized_embeddings: Optional[Float[torch.Tensor, "B C D_c"]] = None + residual_embeddings: Optional[Float[torch.Tensor, "B C D_c"]] = None + decoded_embeddings: Optional[Float[torch.Tensor, "B D"]] = None + reconstruction_loss: Optional[Float[torch.Tensor, "B"]] = None + codebook_loss: Optional[Float[torch.Tensor, "B"]] = None + commitment_loss: Optional[Float[torch.Tensor, "B"]] = None + model_loss: Optional[Float[torch.Tensor, ""]] = None + + +class QuantizerModelFactory: # pragma: no cover - factory class + """Factory for creating `QuantizerModel` instances.""" + + _registry: dict[str, Type[QuantizerModel[Any, Any]]] = {} + + @classmethod + def register(cls, name: str) -> Callable[[Type[_QuantizerModel]], Type[_QuantizerModel]]: + """Decorator to register a `QuantizerModel` implementation.""" + + def decorator(model_cls: Type[_QuantizerModel]) -> Type[_QuantizerModel]: + if name in cls._registry: + raise ValueError(f"Quantizer model '{name}' is already registered.") + cls._registry[name] = model_cls + return model_cls + + return decorator + + @classmethod + def create(cls, name: str, config: QuantizerModelConfig) -> QuantizerModel[Any, Any]: + """Creates an instance of a registered `QuantizerModel`.""" + if name not in cls._registry: + raise ValueError(f"Quantizer model '{name}' is not registered.") + model_cls = cls._registry[name] + return model_cls(config) + + @classmethod + def from_pretrained(cls, name: str, path: Union[str, os.PathLike], **kwargs) -> QuantizerModel[Any, Any]: + """Loads a pretrained instance of a registered `QuantizerModel`.""" + if name not in cls._registry: + raise ValueError(f"Quantizer model '{name}' is not registered.") + model_cls = cls._registry[name] + return model_cls.from_pretrained(path, **kwargs) + + +class QuantizerModel(PreTrainedModel, Generic[_QuantizerModelConfig, _QuantizerOutput], ABC): + """Base class for quantizer models. 
+ + This class extends the `PreTrainedModel` from the Hugging Face Transformers library + and serves as a base for implementing specific quantizer models. + + Subclasses must specify the `config_class` attribute and implement the `forward` and + `initialize_codebooks` methods. We provide a default implementation for + `post_process_quantized_ids`, which can be overridden if needed. + """ + + config_class: Type[_QuantizerModelConfig] + + def __init__(self, config: _QuantizerModelConfig) -> None: + """Initializes the quantizer model. + + Args: + config (_QuantizerModelConfig): Configuration containing model hyperparameters. + """ + super().__init__(config) + self.config: _QuantizerModelConfig + + @abstractmethod + def forward( # pragma: no cover - abstract method + self, + item_id: Int[torch.Tensor, "B"], + item_embedding: Float[torch.Tensor, "B D"], + output_loss: bool = False, + output_model_loss: bool = False, + output_embeddings: bool = False, + **kwargs, + ) -> _QuantizerOutput: + """Performs a forward pass through the quantizer model. + + Args: + item_id (Int[torch.Tensor, "B"]): Item IDs corresponding to the input embeddings. + item_embedding (Float[torch.Tensor, "B D"]): Dense item embeddings to be quantized. + output_loss (bool): Whether to compute and return the reconstruction and commitment losses. Default is False. + output_model_loss (bool): Whether to compute and return the model-specific loss. Default is False. + output_embeddings (bool): Whether to return the (quantized, residual, and decoded) embeddings. Default is False. + **kwargs (Any): Additional keyword arguments for the forward pass. + + Returns: + _QuantizerOutput: Model outputs packaged as a `QuantizerOutput` instance. + """ + ... + + @abstractmethod + def initialize_codebooks( # pragma: no cover - abstract method + self, + item_embeddings: Float[torch.Tensor, "I D"], + **kwargs, + ) -> None: + """Initializes the codebooks using the provided item embeddings. 
+ + This method is typically called when a specific global initialization strategy + is desired, e.g., when the `kmeans_init` is set to True in the RQ-VAE model + configuration. In this case, it performs k-means clustering on the item embeddings + to initialize the codebooks. You may override this method in subclasses to implement + custom initialization logic. + + Args: + item_embeddings (Float[torch.Tensor, "I D"]): Dense item embeddings used for + initializing the codebooks. + """ + ... + + def post_process_quantized_ids( + self, + semantic_ids: Int[torch.Tensor, "B C"], + ) -> Int[torch.Tensor, "B C_new"]: + """Post-processes the semantic IDs to ensure global uniqueness. + + This method converts the local semantic IDs (ranging from 0 to `codebook_size - 1` + within each codebook) to globally unique IDs by adding an offset based on the + codebook index. Additionally, it handles anti-collision codes by appending an + extra code `Z_0` to the codebooks (so `C_new = C + 1` by default). Note that the + final codes are shifted by 1 to reserve the padding ID zero (i.e., no real code is + ever mapped to 0). You may override this method in subclasses to + implement custom post-processing logic. + + Args: + semantic_ids (Int[torch.Tensor, "B C"]): Local semantic IDs assigned by the quantizer. + The shape is (batch_size, num_codebooks), where the `batch_size` is expected to be + the item number in most cases. + + Returns: + Int[torch.Tensor, "B C_new"]: Globally unique semantic IDs after post-processing. + In this implementation, `C_new = C + 1` to account for the anti-collision codes. + """ + B, C = semantic_ids.shape + assert C == self.config.num_codebooks, "Unexpected number of codebooks in semantic IDs."
+ + # append anti-collision code at the end of each codebook + anti_collision_codes = torch.zeros((B, 1), dtype=semantic_ids.dtype, device=semantic_ids.device) + seen_counts = {} + for i in range(B): + key = tuple(semantic_ids[i].tolist()) + count = seen_counts.get(key, 0) + anti_collision_codes[i, 0] = count + seen_counts[key] = count + 1 + semantic_ids = torch.cat([semantic_ids, anti_collision_codes], dim=1) + + # add offset based on codebook index + for codebook_idx in range(semantic_ids.shape[1]): + offset = codebook_idx * self.config.codebook_size + 1 # +1 for reserving padding ID + semantic_ids[:, codebook_idx] += offset + + return semantic_ids diff --git a/src/genrec/models/model_quantizer/rqvae.py b/src/genrec/models/model_quantizer/rqvae.py new file mode 100644 index 0000000..fb3b188 --- /dev/null +++ b/src/genrec/models/model_quantizer/rqvae.py @@ -0,0 +1,268 @@ +"""Quantizer Model: RQ-VAE.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional, Tuple + +from jaxtyping import Float, Int +from sklearn.cluster import KMeans +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..modules import MLP +from .base import ( + QuantizerModel, + QuantizerModelConfig, + QuantizerModelConfigFactory, + QuantizerModelFactory, + QuantizerOutput, + QuantizerOutputFactory, +) + +__all__ = [ + "RQVAEModel", + "RQVAEModelConfig", + "RQVAEModelOutput", +] + + +@QuantizerModelConfigFactory.register("rqvae") +class RQVAEModelConfig(QuantizerModelConfig): + """Configuration class for RQ-VAE model, which extends the base `QuantizerModelConfig`.""" + + def __init__( + self, + kmeans_init: bool = True, + kmeans_max_iter: int = 10, + **kwargs, + ) -> None: + """Initializes the configuration with model hyperparameters. + + Args: + kmeans_init (bool): Whether to initialize the codebooks using k-means clustering + on the provided item embeddings. If False, random initialization is used. 
+ kmeans_max_iter (int): Maximum number of iterations for the k-means algorithm. + **kwargs (Any): Additional keyword arguments for the base `QuantizerModelConfig`. + """ + super().__init__(**kwargs) + self.kmeans_init = kmeans_init + self.kmeans_max_iter = kmeans_max_iter + + +@QuantizerOutputFactory.register("rqvae") +@dataclass +class RQVAEModelOutput(QuantizerOutput): + """Output class for RQ-VAE model. + + The `RQVAEModelOutput` class extends the base `QuantizerOutput` without adding any additional attributes. + """ + + pass + + +@QuantizerModelFactory.register("rqvae") +class RQVAEModel(QuantizerModel): + """Residual-Quantized Variational AutoEncoder (RQ-VAE) model implementation. + + Here we implement the RQ-VAE model with Kmeans initialization. + + References: + - Neural Discrete Representation Learning. NeurIPS '17. + - Autoregressive Image Generation Using Residual Quantization. CVPR '22. + - Recommender Systems with Generative Retrieval. NeurIPS '23. + """ + + config_class = RQVAEModelConfig + + def __init__(self, config: RQVAEModelConfig) -> None: + super().__init__(config) + self.config: RQVAEModelConfig + + self.codebooks = nn.ModuleList( + nn.Embedding( + self.config.codebook_size, + self.config.codebook_dim, + ) + for _ in range(config.num_codebooks) + ) + self.encoder = MLP( + input_size=self.config.embed_dim, + hidden_sizes=list(self.config.hidden_sizes), + output_size=self.config.codebook_dim, + activation=nn.ReLU(), + ffn_bias=True, + ) + self.decoder = MLP( + input_size=self.config.codebook_dim, + hidden_sizes=list(reversed(self.config.hidden_sizes)), + output_size=self.config.embed_dim, + activation=nn.ReLU(), + ffn_bias=True, + ) + + self.gradient_checkpointing = False # disable gradient checkpointing by default + self.post_init() # use PretrainedModel's default weight initialization + + @torch.no_grad() + def initialize_codebooks( + self, + item_embeddings: Float[torch.Tensor, "I D"], + **kwargs, + ) -> None: + """Initializes the codebooks 
using the provided item embeddings. + + Applies k-means clustering to initialize the codebooks if `kmeans_init` is True. + + Args: + item_embeddings (Float[torch.Tensor, "I D"]): Dense item embeddings used for + initializing the codebooks. + + .. note:: + This implementation is not identical to RQ-Kmeans, as the latter does not + use an encoder. + """ + if not self.config.kmeans_init: + return # Use random initialization + + embeddings = self.encoder(item_embeddings).cpu().numpy() + + residual = embeddings.copy() + for code_idx in range(self.config.num_codebooks): + kmeans = KMeans( + n_clusters=self.config.codebook_size, + n_init='auto', + max_iter=self.config.kmeans_max_iter, + random_state=42, + ) + kmeans.fit(residual) + centroids = torch.tensor(kmeans.cluster_centers_, dtype=torch.float32) + self.codebooks[code_idx].weight.data.copy_(centroids) + + distances = torch.cdist(torch.tensor(residual), centroids, p=2) + nearest_codes = torch.argmin(distances, dim=1) + quantized = centroids[nearest_codes].numpy() + residual = residual - quantized + + def _quantize( + self, + embeddings: Float[torch.Tensor, "B D_c"], + codebook: Float[torch.Tensor, "K_c D_c"], + ) -> Tuple[ + Int[torch.Tensor, "B"], Float[torch.Tensor, "B D_c"], Float[torch.Tensor, "B"], Float[torch.Tensor, "B"] + ]: + """Quantizes the input embeddings using the provided codebook. + + In the default implementation, we select the nearest code from the codebook for + each input embedding based on Euclidean distance. You may override this method + in subclasses to implement alternative quantization strategies. + + Args: + embeddings (Float[torch.Tensor, "B D_c"]): Input embeddings to be quantized. + codebook (Float[torch.Tensor, "K_c D_c"]): Codebook weight used for quantization. + + Returns: + Tuple[Int[torch.Tensor, "B"], Float[torch.Tensor, "B D_c"], Float[torch.Tensor, "B"], Float[torch.Tensor, "B"]]: + A tuple containing: + - semantic_ids: Indices of the selected codes from the codebook.
+ - quantized_embeddings: Quantized embeddings corresponding to the selected codes. + - codebook_loss: The codebook loss values. + - commitment_loss: The commitment loss values. + """ + # Compute pairwise distances between embeddings and codebook, select nearest code + distances: Float[torch.Tensor, "B K_c"] = torch.cdist(embeddings, codebook, p=2) + semantic_ids: Int[torch.Tensor, "B"] = torch.argmin(distances, dim=1) + quantized_embeddings: Float[torch.Tensor, "B D_c"] = codebook[semantic_ids] + + # Compute Codebook loss and Commitment loss + codebook_loss = F.mse_loss(quantized_embeddings, embeddings.detach(), reduction="none").mean(dim=-1) + commitment_loss = F.mse_loss(embeddings, quantized_embeddings.detach(), reduction="none").mean(dim=-1) + + # Apply Straight-Through Estimator (STE) trick for backpropagation + # This makes the gradient of quantized_embeddings equal to that of embeddings + # Then during backpropagation, the gradients of the decoder input (i.e., sum of quantized embeddings) + # is directly passed to the encoder output (i.e., 0-th residual) + quantized_embeddings = embeddings + (quantized_embeddings - embeddings).detach() + + return semantic_ids, quantized_embeddings, codebook_loss, commitment_loss + + def forward( + self, + item_id: Int[torch.Tensor, "B"], + item_embedding: Float[torch.Tensor, "B D"], + output_loss: bool = False, + output_model_loss: bool = False, + output_embeddings: bool = False, + **kwargs, + ) -> RQVAEModelOutput: + """Performs a forward pass through the quantizer model. + + Args: + item_id (Int[torch.Tensor, "B"]): Item IDs corresponding to the input embeddings. + item_embedding (Float[torch.Tensor, "B D"]): Dense item embeddings to be quantized. + output_loss (bool): Whether to compute and return the reconstruction and commitment losses. Default is False. + output_model_loss (bool): Whether to compute and return the model-specific loss. Default is False. 
+ output_embeddings (bool): Whether to return the (quantized, residual, and decoded) embeddings. Default is False. + **kwargs (Any): Additional keyword arguments for the forward pass. + + Returns: + RQVAEModelOutput: Model outputs packaged as a `RQVAEModelOutput` object. + """ + B = item_embedding.shape[0] + C = self.config.num_codebooks + D_c = self.config.codebook_dim + + model_loss = None # By default, RQ-VAE does not compute model loss internally. + tot_codebook_loss: Float[torch.Tensor, "B"] = torch.zeros(B, device=item_embedding.device) + tot_commitment_loss: Float[torch.Tensor, "B"] = torch.zeros(B, device=item_embedding.device) + all_semantic_ids: Int[torch.Tensor, "B C"] = torch.empty(B, C, dtype=torch.long, device=item_embedding.device) + all_quantized_embeddings: Float[torch.Tensor, "B C D_c"] = torch.empty(B, C, D_c, device=item_embedding.device) + all_residual_embeddings: Float[torch.Tensor, "B C D_c"] = torch.empty(B, C, D_c, device=item_embedding.device) + + residual = self.encoder(item_embedding) + accumulated_quantized = torch.zeros_like(residual) + + for code_idx in range(C): + codebook = self.codebooks[code_idx] + assert isinstance(codebook, nn.Embedding), "Codebook must be an instance of nn.Embedding." 
+ + if output_embeddings: + all_residual_embeddings[:, code_idx, :] = residual + + semantic_ids, quantized_embeddings, codebook_loss, commitment_loss = self._quantize( + residual, codebook.weight + ) + residual = residual - quantized_embeddings + accumulated_quantized = accumulated_quantized + quantized_embeddings + + all_semantic_ids[:, code_idx] = semantic_ids + if output_embeddings: + all_quantized_embeddings[:, code_idx, :] = quantized_embeddings + + tot_codebook_loss = tot_codebook_loss + codebook_loss + tot_commitment_loss = tot_commitment_loss + commitment_loss + + decoded_embeddings: Optional[Float[torch.Tensor, "B D"]] = None + if output_embeddings or output_loss: + decoded_embeddings = self.decoder(accumulated_quantized) + + tot_codebook_loss = tot_codebook_loss / C + tot_commitment_loss = tot_commitment_loss / C + reconstruction_loss = torch.zeros(B, device=item_embedding.device) + if output_loss: + assert ( + decoded_embeddings is not None + ), "Decoded embeddings must be computed to calculate reconstruction loss." 
+ reconstruction_loss = F.mse_loss(decoded_embeddings, item_embedding, reduction="none").mean(dim=-1) + + return RQVAEModelOutput( + semantic_ids=all_semantic_ids, + quantized_embeddings=all_quantized_embeddings if output_embeddings else None, + residual_embeddings=all_residual_embeddings if output_embeddings else None, + decoded_embeddings=decoded_embeddings if output_embeddings else None, + reconstruction_loss=reconstruction_loss if output_loss else None, + codebook_loss=tot_codebook_loss if output_loss else None, + commitment_loss=tot_commitment_loss if output_loss else None, + model_loss=model_loss if output_model_loss else None, + ) diff --git a/src/genrec/models/model_seqrec/base.py b/src/genrec/models/model_seqrec/base.py index 57f26d9..e3c573a 100644 --- a/src/genrec/models/model_seqrec/base.py +++ b/src/genrec/models/model_seqrec/base.py @@ -54,38 +54,40 @@ def create(cls, name: str, **kwargs) -> SeqRecModelConfig: return config_cls(**kwargs) -class SeqRecModelFactory: # pragma: no cover - factory class - """Factory for creating `SeqRecModel` instances.""" +class SeqRecModelConfig(PretrainedConfig): + """Base configuration class for sequential recommendation models. - _registry: dict[str, Type[SeqRecModel[Any, Any]]] = {} + This class extends the `PretrainedConfig` from the Hugging Face Transformers library + and serves as a base for implementing specific sequential recommendation model configurations. - @classmethod - def register(cls, name: str) -> Callable[[Type[_SeqRecModel]], Type[_SeqRecModel]]: - """Decorator to register a `SeqRecModel` implementation.""" + Subclasses must specify the `model_type` attribute. 
+ """ - def decorator(model_cls: Type[_SeqRecModel]) -> Type[_SeqRecModel]: - if name in cls._registry: - raise ValueError(f"SeqRec model '{name}' is already registered.") - cls._registry[name] = model_cls - return model_cls + model_type = "seqrec" - return decorator + def __init__( + self, + item_size: int = 1024, + hidden_size: int = 256, + num_attention_heads: int = 4, + num_hidden_layers: int = 4, + **kwargs, + ) -> None: + """Initializes the configuration with model hyperparameters. - @classmethod - def create(cls, name: str, **kwargs) -> SeqRecModel[Any, Any]: - """Creates an instance of a registered `SeqRecModel`.""" - if name not in cls._registry: - raise ValueError(f"SeqRec model '{name}' is not registered.") - model_cls = cls._registry[name] - return model_cls(**kwargs) + Args: + item_size (int): Size of the item vocabulary, excluding the padding token (0-th). Default is 1024. + hidden_size (int): Dimensionality of the model's hidden representations. Default is 256. + num_attention_heads (int): Number of attention heads in the model. Default is 4. + num_hidden_layers (int): Number of hidden layers in the model. Default is 4. + **kwargs (Any): Additional keyword arguments for the base `PretrainedConfig`. 
+ """ + super().__init__(**kwargs) - @classmethod - def from_pretrained(cls, name: str, path: Union[str, os.PathLike], **kwargs) -> SeqRecModel[Any, Any]: - """Loads a pretrained instance of a registered `SeqRecModel`.""" - if name not in cls._registry: - raise ValueError(f"SeqRec model '{name}' is not registered.") - model_cls = cls._registry[name] - return model_cls.from_pretrained(path, **kwargs) + self.item_size = item_size + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.num_hidden_layers = num_hidden_layers class SeqRecOutputFactory: # pragma: no cover - factory class @@ -114,56 +116,20 @@ def create(cls, name: str, **kwargs) -> SeqRecOutput: return output_cls(**kwargs) -class SeqRecModelConfig(PretrainedConfig): - """Base configuration class for sequential recommendation models. - - This class extends the `PretrainedConfig` from the Hugging Face Transformers library - and serves as a base for implementing specific sequential recommendation model configurations. - - Subclasses must specify the `model_type` attribute. - """ - - model_type = "seqrec" - - def __init__( - self, - item_size: int = 1024, - hidden_size: int = 256, - num_attention_heads: int = 4, - num_hidden_layers: int = 4, - **kwargs, - ) -> None: - """Initializes the configuration with model hyperparameters. - - Args: - item_size (int): Size of the item vocabulary, excluding the padding token (0-th). - hidden_size (int): Dimensionality of the model's hidden representations. - num_attention_heads (int): Number of attention heads in the model. - num_hidden_layers (int): Number of hidden layers in the model. - **kwargs (Any): Additional keyword arguments for the base `PretrainedConfig`. 
- """ - super().__init__(**kwargs) - - self.item_size = item_size - self.hidden_size = hidden_size - self.num_attention_heads = num_attention_heads - self.num_hidden_layers = num_hidden_layers - - @dataclass class SeqRecOutput(ModelOutput): """Base output class for sequential recommendation models. Attributes: - last_hidden_state: Hidden states from the last layer of the model. - The shape is (batch_size, seq_len, hidden_size). - model_loss: The computed model-specific loss value, if applicable. - Note that the model-agnostic loss (e.g., NLL loss) is handled outside - of this class. - hidden_states: Hidden states from the model, if applicable. The shape is - (batch_size, seq_len, hidden_size). - attentions: Attention weights from the model, if applicable. The shape is - (batch_size, num_heads, seq_len, seq_len). + last_hidden_state (Float[torch.Tensor, "B L d"]): Hidden states from the + last layer of the model. The shape is (batch_size, seq_len, hidden_size). + model_loss (Optional[Float[torch.Tensor, ""]]): The computed model-specific + loss value, if applicable. Note that the model-agnostic loss (e.g., NLL loss) + is handled outside of this class. + hidden_states (Optional[Tuple[Float[torch.Tensor, "B L d"], ...]]): Hidden states + from the model, if applicable. The shape is (batch_size, seq_len, hidden_size). + attentions (Optional[Tuple[Float[torch.Tensor, "B H L L"], ...]]): Attention weights + from the model, if applicable. The shape is (batch_size, num_heads, seq_len, seq_len). 
""" last_hidden_state: Float[torch.Tensor, "B L d"] @@ -172,6 +138,40 @@ class SeqRecOutput(ModelOutput): attentions: Optional[Tuple[Float[torch.Tensor, "B H L L"], ...]] = None +class SeqRecModelFactory: # pragma: no cover - factory class + """Factory for creating `SeqRecModel` instances.""" + + _registry: dict[str, Type[SeqRecModel[Any, Any]]] = {} + + @classmethod + def register(cls, name: str) -> Callable[[Type[_SeqRecModel]], Type[_SeqRecModel]]: + """Decorator to register a `SeqRecModel` implementation.""" + + def decorator(model_cls: Type[_SeqRecModel]) -> Type[_SeqRecModel]: + if name in cls._registry: + raise ValueError(f"SeqRec model '{name}' is already registered.") + cls._registry[name] = model_cls + return model_cls + + return decorator + + @classmethod + def create(cls, name: str, **kwargs) -> SeqRecModel[Any, Any]: + """Creates an instance of a registered `SeqRecModel`.""" + if name not in cls._registry: + raise ValueError(f"SeqRec model '{name}' is not registered.") + model_cls = cls._registry[name] + return model_cls(**kwargs) + + @classmethod + def from_pretrained(cls, name: str, path: Union[str, os.PathLike], **kwargs) -> SeqRecModel[Any, Any]: + """Loads a pretrained instance of a registered `SeqRecModel`.""" + if name not in cls._registry: + raise ValueError(f"SeqRec model '{name}' is not registered.") + model_cls = cls._registry[name] + return model_cls.from_pretrained(path, **kwargs) + + class SeqRecModel(PreTrainedModel, Generic[_SeqRecModelConfig, _SeqRecOutput], ABC): """Base class for sequential recommendation models. diff --git a/src/genrec/models/model_seqrec/hstu.py b/src/genrec/models/model_seqrec/hstu.py index 8c9fb91..c3b17b2 100644 --- a/src/genrec/models/model_seqrec/hstu.py +++ b/src/genrec/models/model_seqrec/hstu.py @@ -118,8 +118,8 @@ class HSTUModel(SeqRecModel[HSTUModelConfig, HSTUModelOutput]): original STU design, to enhance the model's capacity. 
References: - - Actions Speak Louder than Words: Trillion-Parameter Sequential Transducers for - Generative Recommendations. ICML '24. + - Actions Speak Louder than Words: Trillion-Parameter Sequential Transducers for + Generative Recommendations. ICML '24. """ config_class = HSTUModelConfig diff --git a/src/genrec/models/model_seqrec/hstu_spring.py b/src/genrec/models/model_seqrec/hstu_spring.py index bef5543..47a5b2e 100644 --- a/src/genrec/models/model_seqrec/hstu_spring.py +++ b/src/genrec/models/model_seqrec/hstu_spring.py @@ -91,9 +91,9 @@ class HSTUSpringModel(SeqRecModel[HSTUSpringModelConfig, HSTUSpringModelOutput]) as the `model_loss` in the `HSTUSpringModelOutput`. References: - - Actions Speak Louder than Words: Trillion-Parameter Sequential Transducers for - Generative Recommendations. ICML '24. - - ... + - Actions Speak Louder than Words: Trillion-Parameter Sequential Transducers for + Generative Recommendations. ICML '24. + - ... """ config_class = HSTUSpringModelConfig diff --git a/src/genrec/models/model_seqrec/sasrec.py b/src/genrec/models/model_seqrec/sasrec.py index f6bdf36..fb71b16 100644 --- a/src/genrec/models/model_seqrec/sasrec.py +++ b/src/genrec/models/model_seqrec/sasrec.py @@ -75,7 +75,7 @@ class SASRecModel(SeqRecModel[SASRecModelConfig, SASRecModelOutput]): implementations in Llama model. References: - - Self-Attentive Sequential Recommendation. ICDM '18. + - Self-Attentive Sequential Recommendation. ICDM '18. 
""" config_class = SASRecModelConfig diff --git a/src/genrec/models/modules/__init__.py b/src/genrec/models/modules/__init__.py index 023b04a..5167ba9 100644 --- a/src/genrec/models/modules/__init__.py +++ b/src/genrec/models/modules/__init__.py @@ -9,10 +9,11 @@ "MaskedSelfAttentionWithRoPE", ] -from .feedforward import FeedForwardNetwork, SwiGLU +from .feedforward import FeedForwardNetwork, MLP, SwiGLU __all__ += [ "FeedForwardNetwork", + "MLP", "SwiGLU", ] diff --git a/src/genrec/models/modules/attention.py b/src/genrec/models/modules/attention.py index 7cfe40d..1ba48e3 100644 --- a/src/genrec/models/modules/attention.py +++ b/src/genrec/models/modules/attention.py @@ -13,8 +13,8 @@ from .posemb import RelativeBucketedTimeAndPositionAttentionBias, apply_rotary_pos_emb __all__ = [ - "MaskedSelfAttentionWithRoPE", "MaskedHSTUAttention", + "MaskedSelfAttentionWithRoPE", ] @@ -125,8 +125,8 @@ class MaskedHSTUAttention(nn.Module): computation, which can be beneficial in certain scenarios where gating may not be stable. References: - - Actions Speak Louder than Words: Trillion-Parameter Sequential Transducers for - Generative Recommendations. ICML '24. + - Actions Speak Louder than Words: Trillion-Parameter Sequential Transducers for + Generative Recommendations. ICML '24. """ def __init__( diff --git a/src/genrec/models/modules/feedforward.py b/src/genrec/models/modules/feedforward.py index cc25519..6d663a0 100644 --- a/src/genrec/models/modules/feedforward.py +++ b/src/genrec/models/modules/feedforward.py @@ -2,12 +2,16 @@ from __future__ import annotations +import copy +from typing import List + from jaxtyping import Float import torch import torch.nn as nn __all__ = [ "FeedForwardNetwork", + "MLP", "SwiGLU", ] @@ -48,6 +52,47 @@ def forward(self, x: Float[torch.Tensor, "... d"]) -> Float[torch.Tensor, "... 
d return self.fc2(self.act_fn(self.fc1(x))) +class MLP(nn.Module): + """Multi-Layer Perceptron (MLP).""" + + def __init__( + self, + input_size: int, + hidden_sizes: List[int], + output_size: int, + activation: nn.Module = nn.ReLU(), + ffn_bias: bool = False, + ) -> None: + """Initializes MLP module. + + Args: + input_size (int): Dimensionality of the input. + hidden_sizes (List[int]): List of hidden layer sizes. + output_size (int): Dimensionality of the output. + activation (nn.Module): Activation function to use between layers. Default is ReLU. + ffn_bias (bool): Whether to include bias terms in the linear projections. Default is False. + """ + super().__init__() + layer_sizes = [input_size] + hidden_sizes + [output_size] + layers = [] + for i in range(len(layer_sizes) - 1): + layers.append(nn.Linear(layer_sizes[i], layer_sizes[i + 1], bias=ffn_bias)) + if i < len(layer_sizes) - 2: + layers.append(copy.deepcopy(activation)) + self.network = nn.Sequential(*layers) + + def forward(self, x: Float[torch.Tensor, "... d"]) -> Float[torch.Tensor, "... d"]: + """Forward pass for MLP. + + Args: + x (Float[torch.Tensor, "... d"]): Input tensor of shape (..., input_size). + + Returns: + Float[torch.Tensor, "... d"]: Output tensor of shape (..., output_size). + """ + return self.network(x) + + class SwiGLU(nn.Module): """SwiGLU-based Feed-Forward Network, following `LlamaMLP`'s implementation.""" diff --git a/src/genrec/models/modules/layers.py b/src/genrec/models/modules/layers.py index 6458a77..5a95c8e 100644 --- a/src/genrec/models/modules/layers.py +++ b/src/genrec/models/modules/layers.py @@ -125,8 +125,8 @@ class SequentialTransductionUnit(GradientCheckpointingLayer): the model's capacity. References: - - Actions Speak Louder than Words: Trillion-Parameter Sequential Transducers for - Generative Recommendations. ICML '24. + - Actions Speak Louder than Words: Trillion-Parameter Sequential Transducers for + Generative Recommendations. ICML '24. 
""" def __init__( @@ -279,8 +279,8 @@ def spring_power_iteration( function should not be called in the checkpointed forward pass. References: - - Spectral Norm Regularization for Improving the Generalizability of Deep Learning. arXiv '17. - - Spectral Normalization for Generative Adversarial Networks. ICLR '18. + - Spectral Norm Regularization for Improving the Generalizability of Deep Learning. arXiv '17. + - Spectral Normalization for Generative Adversarial Networks. ICLR '18. """ assert W.dim() == 2, "Input weight matrix must be 2-dimensional." m, n = W.shape @@ -378,7 +378,7 @@ class SpringLlamaDecoderLayer(GradientCheckpointingLayer): See `LlamaDecoderLayer` for more details. References: - - ... + - ... """ def __init__( @@ -476,9 +476,9 @@ class SpringSequentialTransductionUnit(GradientCheckpointingLayer): See `SequentialTransductionUnit` for more details. References: - - Actions Speak Louder than Words: Trillion-Parameter Sequential Transducers for - Generative Recommendations. ICML '24. - - ... + - Actions Speak Louder than Words: Trillion-Parameter Sequential Transducers for + Generative Recommendations. ICML '24. + - ... """ def __init__( diff --git a/src/genrec/models/modules/posemb.py b/src/genrec/models/modules/posemb.py index 77b8309..caab231 100644 --- a/src/genrec/models/modules/posemb.py +++ b/src/genrec/models/modules/posemb.py @@ -81,8 +81,8 @@ class RelativeBucketedTimeAndPositionAttentionBias(nn.Module): facilitating next-item prediction tasks. References: - - Actions Speak Louder than Words: Trillion-Parameter Sequential Transducers for - Generative Recommendations. ICML '24. + - Actions Speak Louder than Words: Trillion-Parameter Sequential Transducers for + Generative Recommendations. ICML '24. 
""" def __init__( diff --git a/src/genrec/trainers/__init__.py b/src/genrec/trainers/__init__.py index b2ada31..dc89484 100644 --- a/src/genrec/trainers/__init__.py +++ b/src/genrec/trainers/__init__.py @@ -6,9 +6,19 @@ __all__ += [] -# from .trainer_quantizer import ... +from .trainer_quantizer import ( + QuantizerTrainer, + QuantizerTrainerFactory, + QuantizerTrainingArguments, + QuantizerTrainingArgumentsFactory, +) -__all__ += [] +__all__ += [ + "QuantizerTrainer", + "QuantizerTrainerFactory", + "QuantizerTrainingArguments", + "QuantizerTrainingArgumentsFactory", +] from .trainer_seqrec import SeqRecTrainer, SeqRecTrainerFactory, SeqRecTrainingArguments, SeqRecTrainingArgumentsFactory diff --git a/src/genrec/trainers/trainer_quantizer/__init__.py b/src/genrec/trainers/trainer_quantizer/__init__.py new file mode 100644 index 0000000..6aaf38d --- /dev/null +++ b/src/genrec/trainers/trainer_quantizer/__init__.py @@ -0,0 +1,24 @@ +"""Trainers and training utilities for quantization tasks.""" + +__all__ = [] + +from .base import ( + QuantizerTrainer, + QuantizerTrainerFactory, + QuantizerTrainingArguments, + QuantizerTrainingArgumentsFactory, +) + +__all__ += [ + "QuantizerTrainer", + "QuantizerTrainerFactory", + "QuantizerTrainingArguments", + "QuantizerTrainingArgumentsFactory", +] + +from .rqvae import RQVAEQuantizerTrainer, RQVAEQuantizerTrainingArguments + +__all__ += [ + "RQVAEQuantizerTrainer", + "RQVAEQuantizerTrainingArguments", +] diff --git a/src/genrec/trainers/trainer_quantizer/base.py b/src/genrec/trainers/trainer_quantizer/base.py new file mode 100644 index 0000000..8159f65 --- /dev/null +++ b/src/genrec/trainers/trainer_quantizer/base.py @@ -0,0 +1,335 @@ +"""Base trainer for quantizer models.""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from functools import partial +from typing import Any, Callable, Dict, Generic, List, Optional, Sequence, Tuple, Type, TypeVar, Union + 
+from jaxtyping import Float +import torch +import torch.nn as nn +from transformers import EvalPrediction, Trainer, TrainerCallback, TrainingArguments + +from ...datasets import QuantizerCollator, QuantizerDataset +from ...models import QuantizerModel, QuantizerOutput +from .utils.callbacks import EpochIntervalEvalCallback, HardStopCallback +from .utils.evaluations import compute_quantizer_metrics + +__all__ = [ + "QuantizerTrainer", + "QuantizerTrainerFactory", + "QuantizerTrainingArguments", + "QuantizerTrainingArgumentsFactory", +] + + +_QuantizerModel = TypeVar("_QuantizerModel", bound="QuantizerModel[Any, Any]") +_QuantizerTrainer = TypeVar("_QuantizerTrainer", bound="QuantizerTrainer[Any, Any]") +_QuantizerTrainingArguments = TypeVar("_QuantizerTrainingArguments", bound="QuantizerTrainingArguments") + + +class QuantizerTrainingArgumentsFactory: # pragma: no cover - factory class + """Factory for creating `QuantizerTrainingArguments` instances.""" + + _registry: dict[str, Type[QuantizerTrainingArguments]] = {} + + @classmethod + def register(cls, name: str) -> Callable[[Type[_QuantizerTrainingArguments]], Type[_QuantizerTrainingArguments]]: + """Decorator to register a `QuantizerTrainingArguments` implementation.""" + + def decorator( + training_args_cls: Type[_QuantizerTrainingArguments], + ) -> Type[_QuantizerTrainingArguments]: + if name in cls._registry: + raise ValueError(f"Quantizer training arguments '{name}' is already registered.") + cls._registry[name] = training_args_cls + return training_args_cls + + return decorator + + @classmethod + def create(cls, name: str, **kwargs) -> QuantizerTrainingArguments: + """Creates an instance of a registered `QuantizerTrainingArguments`.""" + if name not in cls._registry: + raise ValueError(f"Quantizer training arguments '{name}' is not registered.") + training_args_cls = cls._registry[name] + return training_args_cls(**kwargs) + + +@dataclass +class QuantizerTrainingArguments(TrainingArguments): + """Training 
arguments for quantizer trainers. + + Args: + eval_interval (int): Number of epochs between each evaluation. Default is 100. + train_stop_epoch (int): Number of epochs to stop training. Default is -1 (no early stop). + metrics (Sequence[Tuple[str, Dict[str, Any]]]): Metric names and their parameters to + compute during evaluation. Default is [('codebook_usage', {}), ('code_collision', {})]. + codebook_loss_weight (float): Weight for the codebook loss term. Default is 1.0. + commitment_loss_weight (float): Weight for the commitment loss term. Default is 0.25. + model_loss_weight (float): Weight for the model-specific loss. Default is 0.0. + """ + + eval_interval: int = field( + default=100, + metadata={"help": "Number of epochs between each evaluation. Default is 100."}, + ) + + train_stop_epoch: int = field( + default=-1, + metadata={"help": "Number of epochs to stop training. Default is -1 (no early stop)."}, + ) + + metrics: Sequence[Tuple[str, Dict[str, Any]]] = field( + default_factory=lambda: [ + ("codebook_usage", {}), + ("code_collision", {}), + ], + metadata={ + "help": ( + "Metric names and their parameters to compute during evaluation. " + "Default is [('codebook_usage', {}), ('code_collision', {})]." + ) + }, + ) + + codebook_loss_weight: float = field( + default=1.0, + metadata={"help": "Weight for the codebook loss term. Default is 1.0."}, + ) + + commitment_loss_weight: float = field( + default=0.25, + metadata={"help": "Weight for the commitment loss term. Default is 0.25."}, + ) + + model_loss_weight: float = field( + default=0.0, + metadata={"help": "Weight for the model-specific loss. 
Default is 0.0."}, + ) + + +class QuantizerTrainerFactory: # pragma: no cover - factory class + """Factory for creating `QuantizerTrainer` instances.""" + + _registry: dict[str, Type[QuantizerTrainer[Any, Any]]] = {} + + @classmethod + def register(cls, name: str) -> Callable[[Type[_QuantizerTrainer]], Type[_QuantizerTrainer]]: + """Decorator to register a `QuantizerTrainer` implementation.""" + + def decorator(trainer_cls: Type[_QuantizerTrainer]) -> Type[_QuantizerTrainer]: + if name in cls._registry: + raise ValueError(f"Quantizer trainer '{name}' is already registered.") + cls._registry[name] = trainer_cls + return trainer_cls + + return decorator + + @classmethod + def create( + cls, + name: str, + model: QuantizerModel[Any, Any], + args: QuantizerTrainingArguments, + data_collator: QuantizerCollator, + train_dataset: QuantizerDataset, + eval_dataset: Optional[QuantizerDataset] = None, + compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, + callbacks: Optional[List[TrainerCallback]] = None, + **kwargs, + ) -> QuantizerTrainer[Any, Any]: + """Creates an instance of a registered `QuantizerTrainer`.""" + if name not in cls._registry: + raise ValueError(f"Quantizer trainer '{name}' is not registered.") + trainer_cls = cls._registry[name] + return trainer_cls( + model=model, + args=args, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + compute_metrics=compute_metrics, + callbacks=callbacks, + **kwargs, + ) + + +class QuantizerTrainer(Trainer, Generic[_QuantizerModel, _QuantizerTrainingArguments], ABC): + """Base trainer class for quantizer models. + + This class extends the `Trainer` class from the `transformers` library. You should + implement specific training logic, i.e., `compute_loss`, in subclasses to + compute the model-agnostic loss for quantizer training. + + .. 
note:: + We set up the default callbacks to include `EpochIntervalEvalCallback` + which performs evaluation every `eval_interval` epochs (default is 100). + + .. note:: + We also set up the `compute_metrics` function to use `compute_quantizer_metrics` by default. + """ + + args: _QuantizerTrainingArguments + model: _QuantizerModel + + def __init__( + self, + model: _QuantizerModel, + args: _QuantizerTrainingArguments, + data_collator: QuantizerCollator, + train_dataset: QuantizerDataset, + eval_dataset: Optional[QuantizerDataset] = None, + compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, + callbacks: Optional[List[TrainerCallback]] = None, + **kwargs, + ) -> None: + """Initializes the QuantizerTrainer with the given model and training arguments. + + Args: + model (_QuantizerModel): Quantizer to be trained. + args (_QuantizerTrainingArguments): Training arguments specific to quantizer training. + data_collator (QuantizerCollator): Data collator that prepares model inputs. + train_dataset (QuantizerDataset): Dataset used for training. + eval_dataset (Optional[QuantizerDataset]): Dataset used for evaluation. + compute_metrics (Optional[Callable[[EvalPrediction], Dict]]): Function used to compute + metrics during evaluation. Defaults to :func:`compute_quantizer_metrics`. + callbacks (Optional[List[TrainerCallback]]): Trainer callbacks. Defaults to + `[EpochIntervalEvalCallback, HardStopCallback]`. + **kwargs (Any): Additional keyword arguments forwarded to the base `Trainer`. 
+ """ + if compute_metrics is None: + compute_metrics = partial( + compute_quantizer_metrics, + train_dataset=train_dataset, + metrics=args.metrics, + codebook_size=model.config.codebook_size, + ) + + if callbacks is None: + callbacks = [ + EpochIntervalEvalCallback(eval_interval=args.eval_interval), + HardStopCallback(stop_epoch=args.train_stop_epoch), + ] + + super().__init__( + model=model, + args=args, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + compute_metrics=compute_metrics, + callbacks=callbacks, + **kwargs, + ) + + # HuggingFace Trainer requires label_names to be set for evaluation, and + # we set it to ["item_id"] by default. + # Your model's forward method and data collator should ensure that + # the input batch contains a key "item_id" corresponding to the ground truth labels. + # You may override this attribute if your label key is different in subclasses. + self.label_names = ["item_id"] + + # initialize codebooks + self.initialize_codebooks() + + def initialize_codebooks(self) -> None: + """Initializes the quantizer codebooks before training. + You may override this method in subclasses if custom initialization is needed. + """ + model = self.model.module if hasattr(self.model, "module") else self.model # type: ignore - for distributed training + assert isinstance(model, QuantizerModel), "Model must be an instance of QuantizerModel." + + assert isinstance( + self.train_dataset, QuantizerDataset + ), "Train dataset must be an instance of QuantizerDataset." + item_embeddings = self.train_dataset.item_embeddings + assert item_embeddings is not None, "Item embeddings are required to initialize codebooks." 
+ model.initialize_codebooks(torch.from_numpy(item_embeddings).to(model.device)) + + def compute_loss( + self, + model: nn.Module, + inputs: dict[str, Union[torch.Tensor, Any]], + return_outputs: bool = False, + num_items_in_batch: Optional[torch.Tensor] = None, + ) -> Union[Float[torch.Tensor, ""], Tuple[Float[torch.Tensor, ""], Dict[str, torch.Tensor]]]: + """Computes the loss for a batch of inputs. + + Args: + model (nn.Module): Model being trained. + inputs (dict[str, Union[torch.Tensor, Any]]): Dictionary of input tensors, i.e., the + `QuantizerCollator` output. + return_outputs (bool): Whether to return the model outputs along with the loss. + num_items_in_batch (Optional[torch.Tensor]): Optional tensor indicating the number of + valid items in each sequence in the batch (excluding padding). + + Returns: + Union[Float[torch.Tensor, ""], Tuple[Float[torch.Tensor, ""], Dict[str, torch.Tensor]]]: + Either the scalar loss or a tuple containing the loss and a dictionary with + the loss, semantic IDs, individual quantizer loss terms, and item IDs when `return_outputs` is True. The loss + combines the model-specific loss (if provided) with the quantizer losses via + `compute_quantizer_loss`. + """ + model = model.module if hasattr(model, "module") else model # type: ignore - for distributed training + assert isinstance(model, QuantizerModel), "Model must be an instance of QuantizerModel." + + outputs: QuantizerOutput = model( + **inputs, + output_loss=True, + output_model_loss=model.training, # only compute model loss during training + output_embeddings=False, # no need to output embeddings in Trainer + ) + assert isinstance(outputs, QuantizerOutput), "Model output must be an instance of QuantizerOutput."
+ + quantizer_loss = self.compute_quantizer_loss(inputs, outputs, num_items_in_batch) + if outputs.model_loss is not None: + loss = quantizer_loss + outputs.model_loss * self.args.model_loss_weight + else: + loss = quantizer_loss + + if return_outputs: + assert outputs.semantic_ids is not None, "Semantic IDs must be available in outputs." + assert outputs.reconstruction_loss is not None, "Reconstruction loss should be available in outputs." + assert outputs.codebook_loss is not None, "Codebook loss should be available in outputs." + assert outputs.commitment_loss is not None, "Commitment loss should be available in outputs." + assert "item_id" in inputs and isinstance( + inputs["item_id"], torch.Tensor + ), "Input batch must contain 'item_id' tensor." + output_dict: Dict[str, torch.Tensor] = { + "loss": loss, + "semantic_ids": outputs.semantic_ids, + "reconstruction_loss": outputs.reconstruction_loss, + "codebook_loss": outputs.codebook_loss, + "commitment_loss": outputs.commitment_loss, + "item_id": inputs["item_id"], + } + return loss, output_dict + else: + return loss + + @abstractmethod + def compute_quantizer_loss( # pragma: no cover - abstract method + self, + inputs: dict[str, Union[torch.Tensor, Any]], + outputs: QuantizerOutput, + num_items_in_batch: Optional[torch.Tensor] = None, + ) -> Float[torch.Tensor, ""]: + """Computes the model-agnostic quantizer loss for a batch of inputs and model outputs. + + This method should be implemented by all subclasses to compute the quantizer-specific loss + components, e.g., reconstruction loss, codebook loss, and commitment loss. + + Args: + inputs (dict[str, Union[torch.Tensor, Any]]): Dictionary of input tensors, i.e., the + `QuantizerCollator` output. + outputs (QuantizerOutput): Output from the quantizer model's forward pass. + num_items_in_batch (Optional[torch.Tensor]): Optional tensor indicating the number of + valid items in each sequence in the batch (excluding padding). 
+ + Returns: + Float[torch.Tensor, ""]: Scalar tensor representing the computed quantizer loss. + """ + pass diff --git a/src/genrec/trainers/trainer_quantizer/rqvae.py b/src/genrec/trainers/trainer_quantizer/rqvae.py new file mode 100644 index 0000000..d545b12 --- /dev/null +++ b/src/genrec/trainers/trainer_quantizer/rqvae.py @@ -0,0 +1,82 @@ +"""Basic RQ-VAE Trainer for quantizer models.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Optional, TypeVar, Union + +from jaxtyping import Float +import torch + +from ...models import QuantizerModel, QuantizerOutput +from .base import ( + QuantizerTrainer, + QuantizerTrainerFactory, + QuantizerTrainingArguments, + QuantizerTrainingArgumentsFactory, +) + +__all__ = [ + "RQVAEQuantizerTrainer", + "RQVAEQuantizerTrainingArguments", +] + + +_RQVAEModel = TypeVar("_RQVAEModel", bound="QuantizerModel[Any, Any]") + + +@QuantizerTrainingArgumentsFactory.register("rqvae") +@dataclass +class RQVAEQuantizerTrainingArguments(QuantizerTrainingArguments): + """Training arguments for `RQVAEQuantizerTrainer`.""" + + pass + + +@QuantizerTrainerFactory.register("rqvae") +class RQVAEQuantizerTrainer(QuantizerTrainer[_RQVAEModel, RQVAEQuantizerTrainingArguments]): + """Basic RQ-VAE Trainer for quantizer models. + + This trainer extends the base `QuantizerTrainer` to implement the loss function + specific to RQ-VAE models. No additional training arguments are required beyond + those provided by the base class. + """ + + args: RQVAEQuantizerTrainingArguments + model: _RQVAEModel + + def compute_quantizer_loss( + self, + inputs: dict[str, Union[torch.Tensor, Any]], + outputs: QuantizerOutput, + num_items_in_batch: Optional[torch.Tensor] = None, + ) -> Float[torch.Tensor, ""]: + """Computes the model-agnostic quantizer loss for a batch of inputs and model outputs. 
+ + The loss is the mean reconstruction loss plus the mean codebook and commitment losses, + weighted by `codebook_loss_weight` and `commitment_loss_weight`, respectively. + + Args: + inputs (dict[str, Union[torch.Tensor, Any]]): Dictionary of input tensors, i.e., the + `QuantizerCollator` output. + outputs (QuantizerOutput): Output from the quantizer model's forward pass. + num_items_in_batch (Optional[torch.Tensor]): Optional tensor indicating the number of + valid items in each sequence in the batch (excluding padding). + + Returns: + Float[torch.Tensor, ""]: Scalar tensor representing the computed quantizer loss. + """ + assert outputs.reconstruction_loss is not None, "Reconstruction loss must be provided in outputs." + assert outputs.codebook_loss is not None, "Codebook loss must be provided in outputs." + assert outputs.commitment_loss is not None, "Commitment loss must be provided in outputs." + + reconstruction_loss: Float[torch.Tensor, "B"] = outputs.reconstruction_loss + codebook_loss: Float[torch.Tensor, "B"] = outputs.codebook_loss + commitment_loss: Float[torch.Tensor, "B"] = outputs.commitment_loss + + loss = ( + reconstruction_loss.mean() + + self.args.codebook_loss_weight * codebook_loss.mean() + + self.args.commitment_loss_weight * commitment_loss.mean() + ) + return loss diff --git a/src/genrec/trainers/trainer_quantizer/utils/__init__.py b/src/genrec/trainers/trainer_quantizer/utils/__init__.py new file mode 100644 index 0000000..3cce65a --- /dev/null +++ b/src/genrec/trainers/trainer_quantizer/utils/__init__.py @@ -0,0 +1,18 @@ +"""Common utilities for quantizer trainers.""" + +__all__ = [] + +from .callbacks import EpochIntervalEvalCallback, HardStopCallback + +__all__ += [ + "EpochIntervalEvalCallback", + "HardStopCallback", +] + +from .evaluations import QuantizerMetricFactory, QuantizerMetricFn, compute_quantizer_metrics + +__all__ += [ + "QuantizerMetricFactory", + "QuantizerMetricFn", + "compute_quantizer_metrics", +] diff --git
a/src/genrec/trainers/trainer_quantizer/utils/callbacks.py b/src/genrec/trainers/trainer_quantizer/utils/callbacks.py new file mode 100644 index 0000000..ec137ec --- /dev/null +++ b/src/genrec/trainers/trainer_quantizer/utils/callbacks.py @@ -0,0 +1,68 @@ +"""Callbacks for quantizer trainers.""" + +from __future__ import annotations + +from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments + +__all__ = [ + "EpochIntervalEvalCallback", + "HardStopCallback", +] + + +class EpochIntervalEvalCallback(TrainerCallback): + """Callback to perform evaluation every `eval_interval` epochs.""" + + def __init__(self, eval_interval: int = 5) -> None: + """Initializes the callback with evaluation parameters. + + Args: + eval_interval (int): Number of epochs between evaluations. + """ + self.eval_interval = eval_interval + + def on_epoch_end( # type: ignore - must return TrainerControl + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ) -> TrainerControl: + """Called at the end of an epoch to determine if evaluation should be performed. + If the current epoch is not a multiple of `eval_interval`, both evaluation and saving are skipped. + """ + assert state.epoch is not None, "EpochEvalCallback requires `state.epoch` to be not None." + current_epoch = int(state.epoch) + if current_epoch % self.eval_interval != 0: + control.should_evaluate = False + control.should_save = False + return control + + +class HardStopCallback(TrainerCallback): + """Callback to stop training at a specific epoch.""" + + def __init__(self, stop_epoch: int = -1) -> None: + """Initializes the callback with the stopping epoch. + + Args: + stop_epoch (int): The epoch at which to stop training. If less than 0, training continues + until the maximum number of epochs. Default is -1.
+ """ + self.stop_epoch = stop_epoch + + def on_epoch_end( # type: ignore - must return TrainerControl + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ) -> TrainerControl: + """Called at the end of an epoch to determine if training should be stopped. + If the current epoch is greater than or equal to `stop_epoch`, training is stopped. + """ + assert state.epoch is not None, "HardStopCallback requires `state.epoch` to be not None." + current_epoch = int(state.epoch) + if self.stop_epoch >= 0 and current_epoch >= self.stop_epoch: + control.should_training_stop = True + return control diff --git a/src/genrec/trainers/trainer_quantizer/utils/evaluations.py b/src/genrec/trainers/trainer_quantizer/utils/evaluations.py new file mode 100644 index 0000000..8163f48 --- /dev/null +++ b/src/genrec/trainers/trainer_quantizer/utils/evaluations.py @@ -0,0 +1,184 @@ +"""Evaluation utilities for quantizer trainers.""" + +from __future__ import annotations + +from typing import Any, Callable, Dict, Protocol, Sequence, Tuple + +from jaxtyping import Float, Int +import numpy as np +from transformers import EvalPrediction + +from ....datasets import QuantizerDataset + +__all__ = [ + "QuantizerMetricFactory", + "QuantizerMetricFn", + "calc_metric_codebook_usage", + "calc_metric_code_collision", + "compute_quantizer_metrics", +] + + +def compute_quantizer_metrics( + prediction: EvalPrediction, + train_dataset: QuantizerDataset, + codebook_size: int, + metrics: Sequence[Tuple[str, Dict[str, Any]]] = ( + ("codebook_usage", {}), + ("code_collision", {}), + ), +) -> Dict[str, float]: + """Compute metrics for quantizer trainers. + + Args: + prediction (EvalPrediction): Object containing model predictions and labels. 
Predictions are + expected to be the dict values from `QuantizerTrainer.compute_loss`'s output dict, + at least including `semantic_ids`, `reconstruction_loss`, `codebook_loss`, `commitment_loss`, + and `item_id` in the first 5 elements. The labels are ignored for quantizer metrics. + train_dataset (QuantizerDataset): Dataset used during training; required for global metrics. + codebook_size (int): Size of the codebook used in the quantizer. + metrics (Sequence[Tuple[str, Dict[str, Any]]]): Metric specifications, where each tuple + comprises the metric name and an optional parameter dictionary. Default is + [('codebook_usage', {}), ('code_collision', {})]. + + Returns: + Dict[str, float]: Dictionary containing computed metric values keyed by metric name. + """ + if isinstance(prediction.predictions, tuple): + assert len(prediction.predictions) >= 5, ( + "Predictions should contain at least 5 elements: " + "`semantic_ids`, `reconstruction_loss`, `codebook_loss`, `commitment_loss`, and `item_id`." 
+ ) + semantic_ids: Int[np.ndarray, "B C"] = prediction.predictions[0] + reconstruction_loss: Float[np.ndarray, "B"] = prediction.predictions[1] + codebook_loss: Float[np.ndarray, "B"] = prediction.predictions[2] + commitment_loss: Float[np.ndarray, "B"] = prediction.predictions[3] + item_id: Int[np.ndarray, "B"] = prediction.predictions[4] + else: + raise ValueError("Predictions should be a tuple containing model output dict's values.") + + results: Dict[str, float] = { + "reconstruction_loss": float(np.mean(reconstruction_loss)), + "codebook_loss": float(np.mean(codebook_loss)), + "commitment_loss": float(np.mean(commitment_loss)), + } + for metric_name, metric_params in metrics: + metric_fn = QuantizerMetricFactory.create(metric_name) + metric_results = metric_fn( + semantic_ids=semantic_ids, + item_id=item_id, + train_dataset=train_dataset, + codebook_size=codebook_size, + **metric_params, + ) + results.update(metric_results) + + return results + + +class QuantizerMetricFn(Protocol): + """Protocol for quantizer metric functions.""" + + def __call__( # pragma: no cover - protocol + self, + semantic_ids: Int[np.ndarray, "B C"], + item_id: Int[np.ndarray, "B"], + train_dataset: QuantizerDataset, + codebook_size: int, + **kwargs: Any, + ) -> Dict[str, float]: + """Compute the metric. + + Args: + semantic_ids (Int[np.ndarray, "B C"]): Semantic IDs predicted by the quantizer. + item_id (Int[np.ndarray, "B"]): Item IDs corresponding to the input embeddings. + train_dataset (QuantizerDataset): Dataset used during training; required for global metrics. + codebook_size (int): Size of the codebook used in the quantizer. + **kwargs (Any): Additional keyword arguments for metric computation. + + Returns: + Dict[str, float]: Dictionary containing computed metric values keyed by metric name. + """ + ... 
+ + +class QuantizerMetricFactory: # pragma: no cover - factory class + """Factory for creating quantizer metric functions.""" + + _registry: dict[str, QuantizerMetricFn] = {} + + @classmethod + def register(cls, name: str) -> Callable[[QuantizerMetricFn], QuantizerMetricFn]: + """Decorator to register a metric function.""" + + def decorator(fn: QuantizerMetricFn) -> QuantizerMetricFn: + if name in cls._registry: + raise ValueError(f"Metric '{name}' is already registered.") + cls._registry[name] = fn + return fn + + return decorator + + @classmethod + def create(cls, name: str) -> QuantizerMetricFn: + """Create a metric function by name.""" + if name not in cls._registry: + raise ValueError(f"Metric '{name}' is not registered.") + metric_fn = cls._registry[name] + return metric_fn + + +@QuantizerMetricFactory.register("codebook_usage") +def calc_metric_codebook_usage( + semantic_ids: Int[np.ndarray, "B C"], + item_id: Int[np.ndarray, "B"], + train_dataset: QuantizerDataset, + codebook_size: int, + **kwargs: Any, +) -> Dict[str, float]: + """Calculate codebook usage metric. + + Args: + semantic_ids (Int[np.ndarray, "B C"]): Semantic IDs predicted by the quantizer. + item_id (Int[np.ndarray, "B"]): Item IDs corresponding to the input embeddings. + train_dataset (QuantizerDataset): Dataset used during training; required for global metrics. + codebook_size (int): Size of the codebook used in the quantizer. + **kwargs (Any): Additional keyword arguments for metric computation. + + Returns: + Dict[str, float]: Dictionary containing the codebook usage metric. 
+ """ + num_codebooks = semantic_ids.shape[1] + results: Dict[str, float] = {} + for codebook_idx in range(num_codebooks): + used_codes = np.unique(semantic_ids[:, codebook_idx]) + usage_ratio = len(used_codes) / codebook_size + results[f"codebook_{codebook_idx}_usage"] = usage_ratio + return results + + +@QuantizerMetricFactory.register("code_collision") +def calc_metric_code_collision( + semantic_ids: Int[np.ndarray, "B C"], + item_id: Int[np.ndarray, "B"], + train_dataset: QuantizerDataset, + codebook_size: int, + **kwargs: Any, +) -> Dict[str, float]: + """Calculate code collision metric. + + Args: + semantic_ids (Int[np.ndarray, "B C"]): Semantic IDs predicted by the quantizer. + item_id (Int[np.ndarray, "B"]): Item IDs corresponding to the input embeddings. + train_dataset (QuantizerDataset): Dataset used during training; required for global metrics. + codebook_size (int): Size of the codebook used in the quantizer. + **kwargs (Any): Additional keyword arguments for metric computation. + + Returns: + Dict[str, float]: Dictionary containing the code collision metric. 
+ """ + codes = [tuple(code_ids) for code_ids in semantic_ids] + unique_codes = set(codes) + num_collisions = len(codes) - len(unique_codes) + collision_rate = num_collisions / len(codes) if len(codes) > 0 else 0.0 + return {"code_collision_rate": collision_rate} diff --git a/src/genrec/trainers/trainer_seqrec/base.py b/src/genrec/trainers/trainer_seqrec/base.py index 6805fe0..f2d1ca4 100644 --- a/src/genrec/trainers/trainer_seqrec/base.py +++ b/src/genrec/trainers/trainer_seqrec/base.py @@ -7,8 +7,7 @@ from functools import partial from typing import Any, Callable, Dict, Generic, List, Optional, Sequence, Tuple, Type, TypeVar, Union -from jaxtyping import Float, Int -import numpy as np +from jaxtyping import Float import torch import torch.nn as nn import torch.nn.functional as F @@ -16,11 +15,10 @@ from ...datasets import SeqRecCollator, SeqRecDataset from ...models import SeqRecModel, SeqRecOutput -from ..utils.callbacks import EpochIntervalEvalCallback, HardStopCallback -from ..utils.evaluations import MetricFactory, clip_top_k +from .utils.callbacks import EpochIntervalEvalCallback, HardStopCallback +from .utils.evaluations import clip_top_k, compute_seqrec_metrics __all__ = [ - "compute_seqrec_metrics", "SeqRecTrainer", "SeqRecTrainerFactory", "SeqRecTrainingArguments", @@ -61,52 +59,6 @@ def create(cls, name: str, **kwargs) -> SeqRecTrainingArguments: return training_args_cls(**kwargs) -class SeqRecTrainerFactory: # pragma: no cover - factory class - """Factory for creating `SeqRecTrainer` instances.""" - - _registry: dict[str, Type[SeqRecTrainer[Any, Any]]] = {} - - @classmethod - def register(cls, name: str) -> Callable[[Type[_SeqRecTrainer]], Type[_SeqRecTrainer]]: - """Decorator to register a `SeqRecTrainer` implementation.""" - - def decorator(trainer_cls: Type[_SeqRecTrainer]) -> Type[_SeqRecTrainer]: - if name in cls._registry: - raise ValueError(f"SeqRec trainer '{name}' is already registered.") - cls._registry[name] = trainer_cls - return 
trainer_cls - - return decorator - - @classmethod - def create( - cls, - name: str, - model: SeqRecModel[Any, Any], - args: SeqRecTrainingArguments, - data_collator: SeqRecCollator, - train_dataset: SeqRecDataset, - eval_dataset: Optional[SeqRecDataset] = None, - compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, - callbacks: Optional[List[TrainerCallback]] = None, - **kwargs, - ) -> SeqRecTrainer[Any, Any]: - """Creates an instance of a registered `SeqRecTrainer`.""" - if name not in cls._registry: - raise ValueError(f"SeqRec trainer '{name}' is not registered.") - trainer_cls = cls._registry[name] - return trainer_cls( - model=model, - args=args, - data_collator=data_collator, - train_dataset=train_dataset, - eval_dataset=eval_dataset, - compute_metrics=compute_metrics, - callbacks=callbacks, - **kwargs, - ) - - @dataclass class SeqRecTrainingArguments(TrainingArguments): """Training arguments for sequential recommendation trainers. @@ -116,6 +68,7 @@ class SeqRecTrainingArguments(TrainingArguments): computation and evaluation. If True, both user and item embeddings are normalized to unit length, and the dot product is equivalent to cosine similarity. Default is False. eval_interval (int): Number of epochs between each evaluation. Default is 5. + train_stop_epoch (int): Number of epochs to stop training. Default is -1 (no early stop). metrics (Sequence[Tuple[str, Dict[str, Any]]]): Metric names and their parameters to compute during evaluation. Default is [('hr', {}), ('ndcg', {}), ('popularity', {'p': (0.1, 0.2)}), ("unpopularity", {"p": (0.2, 0.4)})]. 
@@ -175,73 +128,50 @@ class SeqRecTrainingArguments(TrainingArguments): ) -def compute_seqrec_metrics( - prediction: EvalPrediction, - train_dataset: SeqRecDataset, - top_k: Sequence[int] = (1, 5, 10), - metrics: Sequence[Tuple[str, Dict[str, Any]]] = ( - ("hr", {}), - ("ndcg", {}), - ("popularity", {"p": (0.1, 0.2)}), - ("unpopularity", {"p": (0.2, 0.4)}), - ), - device: Union[torch.device, str, None] = None, -) -> Dict[str, float]: - """Compute metrics for sequential recommendation tasks. +class SeqRecTrainerFactory: # pragma: no cover - factory class + """Factory for creating `SeqRecTrainer` instances.""" - Args: - prediction (EvalPrediction): Object containing model predictions and labels. Predictions are - expected to be the precomputed top-k item indices per user (shape: ``[num_users, max_k]``). - train_dataset (SeqRecDataset): Dataset used during training; required for global metrics - such as popularity-based measurements. - top_k (Sequence[int]): Cutoff values for computing top-K metrics, determining how many - predictions to consider for each metric. Default is (1, 5, 10). - metrics (Sequence[Tuple[str, Dict[str, Any]]]): Metric specifications, where each tuple - comprises the metric name and an optional parameter dictionary. Default is - [('hr', {}), ('ndcg', {}), ('popularity', {'p': (0.1, 0.2)}), - ('unpopularity', {'p': (0.2, 0.4)})]. - device (Union[torch.device, str, None]): Device used for metric computations. - If None, defaults to CPU. Default is None. - - Returns: - Dict[str, float]: Dictionary containing computed metric values keyed by metric name. + _registry: dict[str, Type[SeqRecTrainer[Any, Any]]] = {} - .. note:: - As we may call this evaluation function for global metrics (e.g., popularity/fairness), - you should ensure the `train_dataset` is provided if any global metrics are specified. - In addition, `batch_eval_metrics` in `SeqRecTrainingArguments` should be set to `False` - to avoid conflicts. 
- """ - torch_device = torch.device(device) if device is not None else torch.device("cpu") - - topk_indices: Int[torch.Tensor, "B K"] - if isinstance(prediction.predictions, tuple): # pragma: no cover - rarely used - topk_indices = torch.as_tensor(prediction.predictions[0], dtype=torch.long, device=torch_device) - else: - topk_indices = torch.as_tensor(prediction.predictions, dtype=torch.long, device=torch_device) - - labels: Int[np.ndarray, "B L"] - if isinstance(prediction.label_ids, tuple): # pragma: no cover - rarely used - labels = prediction.label_ids[0].astype(np.int64) - else: - labels = prediction.label_ids.astype(np.int64) - last_step_labels: Int[torch.Tensor, "B"] - last_step_labels = torch.as_tensor(labels[:, -1], dtype=torch.long, device=torch_device) - - results: Dict[str, float] = {} - for k in top_k: - sliced_topk_indices = topk_indices[:, :k] - for metric_name, metric_params in metrics: - metric_fn = MetricFactory.create(metric_name) - metric_results = metric_fn( - topk_indices=sliced_topk_indices, - labels=last_step_labels, - train_dataset=train_dataset, - **metric_params, - ) - results.update(metric_results) + @classmethod + def register(cls, name: str) -> Callable[[Type[_SeqRecTrainer]], Type[_SeqRecTrainer]]: + """Decorator to register a `SeqRecTrainer` implementation.""" - return results + def decorator(trainer_cls: Type[_SeqRecTrainer]) -> Type[_SeqRecTrainer]: + if name in cls._registry: + raise ValueError(f"SeqRec trainer '{name}' is already registered.") + cls._registry[name] = trainer_cls + return trainer_cls + + return decorator + + @classmethod + def create( + cls, + name: str, + model: SeqRecModel[Any, Any], + args: SeqRecTrainingArguments, + data_collator: SeqRecCollator, + train_dataset: SeqRecDataset, + eval_dataset: Optional[SeqRecDataset] = None, + compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, + callbacks: Optional[List[TrainerCallback]] = None, + **kwargs, + ) -> SeqRecTrainer[Any, Any]: + """Creates an 
instance of a registered `SeqRecTrainer`.""" + if name not in cls._registry: + raise ValueError(f"SeqRec trainer '{name}' is not registered.") + trainer_cls = cls._registry[name] + return trainer_cls( + model=model, + args=args, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + compute_metrics=compute_metrics, + callbacks=callbacks, + **kwargs, + ) class SeqRecTrainer(Trainer, Generic[_SeqRecModel, _SeqRecTrainingArguments], ABC): @@ -337,7 +267,7 @@ def compute_loss( return_outputs: bool = False, num_items_in_batch: Optional[torch.Tensor] = None, ) -> Union[Float[torch.Tensor, ""], Tuple[Float[torch.Tensor, ""], Dict[str, torch.Tensor]]]: - """Computes the loss for a batch of inputs. This should be overridden by all subclasses. + """Computes the loss for a batch of inputs. Args: model (nn.Module): Model being trained. @@ -401,6 +331,7 @@ def compute_rec_loss( # pragma: no cover - abstract method norm_embeddings: bool = False, ) -> Float[torch.Tensor, ""]: """Computes the recommendation loss for a batch of inputs and model outputs. + This should be implemented by all subclasses. Args: inputs (dict[str, Union[torch.Tensor, Any]]): Dictionary of input tensors, i.e., the diff --git a/src/genrec/trainers/trainer_seqrec/bce_d2lr.py b/src/genrec/trainers/trainer_seqrec/bce_d2lr.py index c6ad349..ec92e2c 100644 --- a/src/genrec/trainers/trainer_seqrec/bce_d2lr.py +++ b/src/genrec/trainers/trainer_seqrec/bce_d2lr.py @@ -44,7 +44,7 @@ class BCED2LRSeqRecTrainer(SeqRecTrainer[_SeqRecModel, BCED2LRSeqRecTrainingArgu weights. References: - - Dual Debiasing in LLM-based Recommendation. SIGIR '25. + - Dual Debiasing in LLM-based Recommendation. SIGIR '25. 
""" args: BCED2LRSeqRecTrainingArguments diff --git a/src/genrec/trainers/trainer_seqrec/bce_dros.py b/src/genrec/trainers/trainer_seqrec/bce_dros.py index b611ef9..6b0fab4 100644 --- a/src/genrec/trainers/trainer_seqrec/bce_dros.py +++ b/src/genrec/trainers/trainer_seqrec/bce_dros.py @@ -58,7 +58,7 @@ class BCEDROSSeqRecTrainer(SeqRecTrainer[_SeqRecModel, BCEDROSSeqRecTrainingArgu MSE loss. We have fixed these issues in this implementation. References: - - A Generic Learning Framework for Sequential Recommendation with Distribution Shifts. SIGIR '23. + - A Generic Learning Framework for Sequential Recommendation with Distribution Shifts. SIGIR '23. """ args: BCEDROSSeqRecTrainingArguments diff --git a/src/genrec/trainers/trainer_seqrec/bce_logdet.py b/src/genrec/trainers/trainer_seqrec/bce_logdet.py index 5455073..5c2748c 100644 --- a/src/genrec/trainers/trainer_seqrec/bce_logdet.py +++ b/src/genrec/trainers/trainer_seqrec/bce_logdet.py @@ -52,8 +52,7 @@ class BCELogDetSeqRecTrainer(SeqRecTrainer[_SeqRecModel, BCELogDetSeqRecTraining MSE loss, is replaced with the standard BCE loss for recommendation. References: - - Mitigating the Popularity Bias of Graph Collaborative Filtering: A Dimensional Collapse Perspective. NeurIPS '23. - + - Mitigating the Popularity Bias of Graph Collaborative Filtering: A Dimensional Collapse Perspective. NeurIPS '23. """ args: BCELogDetSeqRecTrainingArguments diff --git a/src/genrec/trainers/trainer_seqrec/bce_r2rec.py b/src/genrec/trainers/trainer_seqrec/bce_r2rec.py index fb045c4..c2491b0 100644 --- a/src/genrec/trainers/trainer_seqrec/bce_r2rec.py +++ b/src/genrec/trainers/trainer_seqrec/bce_r2rec.py @@ -58,7 +58,7 @@ class BCER2RecSeqRecTrainer(SeqRecTrainer[_SeqRecModel, BCER2RecSeqRecTrainingAr stabilize training, a temperature parameter is applied to the reweighting factors. References: - - Reembedding and Reweighting are Needed for Tail Item Sequential Recommendation, WWW '25. 
+ - Reembedding and Reweighting are Needed for Tail Item Sequential Recommendation, WWW '25. """ args: BCER2RecSeqRecTrainingArguments diff --git a/src/genrec/trainers/trainer_seqrec/bce_resn.py b/src/genrec/trainers/trainer_seqrec/bce_resn.py index f2f26f5..42e4552 100644 --- a/src/genrec/trainers/trainer_seqrec/bce_resn.py +++ b/src/genrec/trainers/trainer_seqrec/bce_resn.py @@ -45,7 +45,7 @@ class BCEReSNSeqRecTrainer(SeqRecTrainer[_SeqRecModel, BCEReSNSeqRecTrainingArgu the model outputs in each step is treated as user embeddings. References: - - How Do Recommendation Models Amplify Popularity Bias? An Analysis from the Spectral Perspective. WSDM '25. + - How Do Recommendation Models Amplify Popularity Bias? An Analysis from the Spectral Perspective. WSDM '25. """ args: BCEReSNSeqRecTrainingArguments diff --git a/src/genrec/trainers/trainer_seqrec/sl.py b/src/genrec/trainers/trainer_seqrec/sl.py index 9557490..3858cc8 100644 --- a/src/genrec/trainers/trainer_seqrec/sl.py +++ b/src/genrec/trainers/trainer_seqrec/sl.py @@ -51,7 +51,7 @@ class SLSeqRecTrainer(SeqRecTrainer[_SeqRecModel, SLSeqRecTrainingArguments]): which is commonly used in sequential recommendation tasks. References: - - PSL: Rethinking and Improving Softmax Loss from Pairwise Perspective for Recommendation. NeurIPS '24. + - PSL: Rethinking and Improving Softmax Loss from Pairwise Perspective for Recommendation. NeurIPS '24. """ args: SLSeqRecTrainingArguments diff --git a/src/genrec/trainers/trainer_seqrec/sl_d2lr.py b/src/genrec/trainers/trainer_seqrec/sl_d2lr.py index 84da488..2ffe549 100644 --- a/src/genrec/trainers/trainer_seqrec/sl_d2lr.py +++ b/src/genrec/trainers/trainer_seqrec/sl_d2lr.py @@ -59,7 +59,7 @@ class SLD2LRSeqRecTrainer(SeqRecTrainer[_SeqRecModel, SLD2LRSeqRecTrainingArgume Softmax Loss. To stabilize training, we apply a temperature parameter to the IPS weights. References: - - Dual Debiasing in LLM-based Recommendation. SIGIR '25. 
+ - Dual Debiasing in LLM-based Recommendation. SIGIR '25. """ args: SLD2LRSeqRecTrainingArguments diff --git a/src/genrec/trainers/trainer_seqrec/sl_dros.py b/src/genrec/trainers/trainer_seqrec/sl_dros.py index eb6e5b8..b4560c8 100644 --- a/src/genrec/trainers/trainer_seqrec/sl_dros.py +++ b/src/genrec/trainers/trainer_seqrec/sl_dros.py @@ -74,7 +74,7 @@ class SLDROSSeqRecTrainer(SeqRecTrainer[_SeqRecModel, SLDROSSeqRecTrainingArgume MSE loss. We have fixed these issues in this implementation. References: - - A Generic Learning Framework for Sequential Recommendation with Distribution Shifts. SIGIR '23. + - A Generic Learning Framework for Sequential Recommendation with Distribution Shifts. SIGIR '23. """ args: SLDROSSeqRecTrainingArguments diff --git a/src/genrec/trainers/trainer_seqrec/sl_logdet.py b/src/genrec/trainers/trainer_seqrec/sl_logdet.py index 9d4f2d9..94d3e4d 100644 --- a/src/genrec/trainers/trainer_seqrec/sl_logdet.py +++ b/src/genrec/trainers/trainer_seqrec/sl_logdet.py @@ -68,7 +68,7 @@ class SLLogDetSeqRecTrainer(SeqRecTrainer[_SeqRecModel, SLLogDetSeqRecTrainingAr MSE loss, is replaced with the standard SL loss for recommendation. References: - - Mitigating the Popularity Bias of Graph Collaborative Filtering: A Dimensional Collapse Perspective. NeurIPS '23. + - Mitigating the Popularity Bias of Graph Collaborative Filtering: A Dimensional Collapse Perspective. NeurIPS '23. """ args: SLLogDetSeqRecTrainingArguments diff --git a/src/genrec/trainers/trainer_seqrec/sl_r2rec.py b/src/genrec/trainers/trainer_seqrec/sl_r2rec.py index f42ee27..7986388 100644 --- a/src/genrec/trainers/trainer_seqrec/sl_r2rec.py +++ b/src/genrec/trainers/trainer_seqrec/sl_r2rec.py @@ -74,7 +74,7 @@ class SLR2RecSeqRecTrainer(SeqRecTrainer[_SeqRecModel, SLR2RecSeqRecTrainingArgu temperature parameter is applied to the reweighting factors. References: - - Reembedding and Reweighting are Needed for Tail Item Sequential Recommendation. WWW '25. 
+ - Reembedding and Reweighting are Needed for Tail Item Sequential Recommendation. WWW '25. """ args: SLR2RecSeqRecTrainingArguments diff --git a/src/genrec/trainers/trainer_seqrec/sl_resn.py b/src/genrec/trainers/trainer_seqrec/sl_resn.py index 1d0f0a9..981316e 100644 --- a/src/genrec/trainers/trainer_seqrec/sl_resn.py +++ b/src/genrec/trainers/trainer_seqrec/sl_resn.py @@ -61,7 +61,7 @@ class SLReSNSeqRecTrainer(SeqRecTrainer[_SeqRecModel, SLReSNSeqRecTrainingArgume the model outputs in each step is treated as user embeddings. References: - - How Do Recommendation Models Amplify Popularity Bias? An Analysis from the Spectral Perspective. WSDM '25. + - How Do Recommendation Models Amplify Popularity Bias? An Analysis from the Spectral Perspective. WSDM '25. """ args: SLReSNSeqRecTrainingArguments diff --git a/src/genrec/trainers/trainer_seqrec/utils/__init__.py b/src/genrec/trainers/trainer_seqrec/utils/__init__.py new file mode 100644 index 0000000..102b51e --- /dev/null +++ b/src/genrec/trainers/trainer_seqrec/utils/__init__.py @@ -0,0 +1,19 @@ +"""Common utilities for seqrec trainers.""" + +__all__ = [] + +from .callbacks import EpochIntervalEvalCallback, HardStopCallback + +__all__ += [ + "EpochIntervalEvalCallback", + "HardStopCallback", +] + +from .evaluations import SeqRecMetricFactory, SeqRecMetricFn, clip_top_k, compute_seqrec_metrics + +__all__ += [ + "SeqRecMetricFactory", + "SeqRecMetricFn", + "clip_top_k", + "compute_seqrec_metrics", +] diff --git a/src/genrec/trainers/utils/callbacks.py b/src/genrec/trainers/trainer_seqrec/utils/callbacks.py similarity index 98% rename from src/genrec/trainers/utils/callbacks.py rename to src/genrec/trainers/trainer_seqrec/utils/callbacks.py index 8897420..ab4e716 100644 --- a/src/genrec/trainers/utils/callbacks.py +++ b/src/genrec/trainers/trainer_seqrec/utils/callbacks.py @@ -1,4 +1,4 @@ -"""Callbacks for trainers.""" +"""Callbacks for seqrec trainers.""" from __future__ import annotations diff --git 
a/src/genrec/trainers/utils/evaluations.py b/src/genrec/trainers/trainer_seqrec/utils/evaluations.py similarity index 62% rename from src/genrec/trainers/utils/evaluations.py rename to src/genrec/trainers/trainer_seqrec/utils/evaluations.py index 0a0ed52..e445977 100644 --- a/src/genrec/trainers/utils/evaluations.py +++ b/src/genrec/trainers/trainer_seqrec/utils/evaluations.py @@ -1,33 +1,105 @@ -"""Evaluation utilities for recommendation tasks.""" +"""Evaluation utilities for seqrec models.""" from __future__ import annotations -from typing import Any, Callable, Dict, Protocol, Sequence, Tuple +from typing import Any, Callable, Dict, Protocol, Sequence, Tuple, Union -import torch from jaxtyping import Int +import numpy as np +import torch +from transformers import EvalPrediction -from ...datasets import SeqRecDataset +from ....datasets import SeqRecDataset __all__ = [ + "SeqRecMetricFactory", + "SeqRecMetricFn", "calc_metric_hr", "calc_metric_ndcg", "calc_metric_popularity", "calc_metric_unpopularity", "clip_top_k", - "MetricFactory", - "MetricFn", + "compute_seqrec_metrics", ] +def compute_seqrec_metrics( + prediction: EvalPrediction, + train_dataset: SeqRecDataset, + top_k: Sequence[int] = (1, 5, 10), + metrics: Sequence[Tuple[str, Dict[str, Any]]] = ( + ("hr", {}), + ("ndcg", {}), + ("popularity", {"p": (0.1, 0.2)}), + ("unpopularity", {"p": (0.2, 0.4)}), + ), + device: Union[torch.device, str, None] = None, +) -> Dict[str, float]: + """Compute metrics for sequential recommendation tasks. + + Args: + prediction (EvalPrediction): Object containing model predictions and labels. Predictions are + expected to be the precomputed top-k item indices per user (shape: ``[num_users, max_k]``). + train_dataset (SeqRecDataset): Dataset used during training; required for global metrics + such as popularity-based measurements. + top_k (Sequence[int]): Cutoff values for computing top-K metrics, determining how many + predictions to consider for each metric. 
Default is (1, 5, 10). + metrics (Sequence[Tuple[str, Dict[str, Any]]]): Metric specifications, where each tuple + comprises the metric name and an optional parameter dictionary. Default is + [('hr', {}), ('ndcg', {}), ('popularity', {'p': (0.1, 0.2)}), + ('unpopularity', {'p': (0.2, 0.4)})]. + device (Union[torch.device, str, None]): Device used for metric computations. + If None, defaults to CPU. Default is None. + + Returns: + Dict[str, float]: Dictionary containing computed metric values keyed by metric name. + + .. note:: + As we may call this evaluation function for global metrics (e.g., popularity/fairness), + you should ensure the `train_dataset` is provided if any global metrics are specified. + In addition, `batch_eval_metrics` in `SeqRecTrainingArguments` should be set to `False` + to avoid conflicts. + """ + torch_device = torch.device(device) if device is not None else torch.device("cpu") + + topk_indices: Int[torch.Tensor, "B K"] + if isinstance(prediction.predictions, tuple): # pragma: no cover - rarely used + topk_indices = torch.as_tensor(prediction.predictions[0], dtype=torch.long, device=torch_device) + else: + topk_indices = torch.as_tensor(prediction.predictions, dtype=torch.long, device=torch_device) + + labels: Int[np.ndarray, "B L"] + if isinstance(prediction.label_ids, tuple): # pragma: no cover - rarely used + labels = prediction.label_ids[0].astype(np.int64) + else: + labels = prediction.label_ids.astype(np.int64) + last_step_labels: Int[torch.Tensor, "B"] + last_step_labels = torch.as_tensor(labels[:, -1], dtype=torch.long, device=torch_device) + + results: Dict[str, float] = {} + for k in top_k: + sliced_topk_indices = topk_indices[:, :k] + for metric_name, metric_params in metrics: + metric_fn = SeqRecMetricFactory.create(metric_name) + metric_results = metric_fn( + topk_indices=sliced_topk_indices, + labels=last_step_labels, + train_dataset=train_dataset, + **metric_params, + ) + results.update(metric_results) + + return results + + 
def clip_top_k(top_k: Sequence[int], item_size: int) -> Tuple[int, ...]: """Clamp sorted ``top_k`` cutoffs to ``item_size``, sort and remove duplicates if any.""" top_k_set = set([min(k, item_size) for k in top_k]) return tuple(sorted(top_k_set)) -class MetricFn(Protocol): # pragma: no cover - protocol - """Protocol for metric functions.""" +class SeqRecMetricFn(Protocol): # pragma: no cover - protocol + """Protocol for seqrec metric functions.""" def __call__( self, @@ -51,16 +123,16 @@ def __call__( ... -class MetricFactory: # pragma: no cover - factory class - """Factory for creating metric functions.""" +class SeqRecMetricFactory: # pragma: no cover - factory class + """Factory for creating seqrec metric functions.""" - _registry: dict[str, MetricFn] = {} + _registry: dict[str, SeqRecMetricFn] = {} @classmethod - def register(cls, name: str) -> Callable[[MetricFn], MetricFn]: + def register(cls, name: str) -> Callable[[SeqRecMetricFn], SeqRecMetricFn]: """Decorator to register a metric function.""" - def decorator(fn: MetricFn) -> MetricFn: + def decorator(fn: SeqRecMetricFn) -> SeqRecMetricFn: if name in cls._registry: raise ValueError(f"Metric '{name}' is already registered.") cls._registry[name] = fn @@ -69,7 +141,7 @@ def decorator(fn: MetricFn) -> MetricFn: return decorator @classmethod - def create(cls, name: str) -> MetricFn: + def create(cls, name: str) -> SeqRecMetricFn: """Create a metric function by name.""" if name not in cls._registry: raise ValueError(f"Metric '{name}' is not registered.") @@ -77,7 +149,7 @@ def create(cls, name: str) -> MetricFn: return metric_fn -@MetricFactory.register("hr") +@SeqRecMetricFactory.register("hr") def calc_metric_hr( topk_indices: Int[torch.Tensor, "B K"], labels: Int[torch.Tensor, "B"], @@ -101,7 +173,7 @@ def calc_metric_hr( return {f"hr@{K}": hr} -@MetricFactory.register("ndcg") +@SeqRecMetricFactory.register("ndcg") def calc_metric_ndcg( topk_indices: Int[torch.Tensor, "B K"], labels: Int[torch.Tensor, "B"], 
@@ -127,7 +199,7 @@ def calc_metric_ndcg( return {f"ndcg@{K}": ndcg} -@MetricFactory.register("popularity") +@SeqRecMetricFactory.register("popularity") def calc_metric_popularity( topk_indices: Int[torch.Tensor, "B K"], labels: Int[torch.Tensor, "B"], @@ -165,7 +237,7 @@ def calc_metric_popularity( return results -@MetricFactory.register("unpopularity") +@SeqRecMetricFactory.register("unpopularity") def calc_metric_unpopularity( topk_indices: Int[torch.Tensor, "B K"], labels: Int[torch.Tensor, "B"], diff --git a/src/genrec/trainers/utils/__init__.py b/src/genrec/trainers/utils/__init__.py deleted file mode 100644 index a997e58..0000000 --- a/src/genrec/trainers/utils/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -"""Common utilities for trainers.""" - -__all__ = [] - -from .callbacks import EpochIntervalEvalCallback, HardStopCallback - -__all__ += [ - "EpochIntervalEvalCallback", - "HardStopCallback", -] - -from .evaluations import MetricFactory, MetricFn - -__all__ += [ - "MetricFactory", - "MetricFn", -] diff --git a/tests/datasets/test_dataset.py b/tests/datasets/test_dataset.py index 59e8919..9449c77 100644 --- a/tests/datasets/test_dataset.py +++ b/tests/datasets/test_dataset.py @@ -58,6 +58,13 @@ def _make_textual_frame(item_pool: int) -> pd.DataFrame: ) +def _make_aux_item_embeddings(item_pool: int, aux_dim: int) -> np.ndarray: + values = np.arange((item_pool + 1) * aux_dim, dtype=np.float32) + values = values.reshape(item_pool + 1, aux_dim) + values[0] = 0.0 + return values + + def _make_short_interaction_frame(length: int) -> pd.DataFrame: return pd.DataFrame( { @@ -105,7 +112,12 @@ def _assert_seqrec_batches(loader: DataLoader, expected_sizes: list[int], num_ne assert batch["negative_item_ids"].dtype == torch.int32 -def _assert_quantizer_batches(loader: DataLoader, expected_sizes: list[int], embedding_dim: int) -> None: +def _assert_quantizer_batches( + loader: DataLoader, + expected_sizes: list[int], + embedding_dim: int, + aux_embedding_dim: int | None = 
None, +) -> None: batches = list(loader) assert len(batches) == len(expected_sizes) for idx, batch in enumerate(batches): @@ -113,6 +125,9 @@ def _assert_quantizer_batches(loader: DataLoader, expected_sizes: list[int], emb assert batch["item_id"].shape == (expected_size,) assert batch["item_embedding"].shape == (expected_size, embedding_dim) assert batch["item_embedding"].dtype == torch.float32 + if aux_embedding_dim is not None: + assert batch["aux_item_embedding"].shape == (expected_size, aux_embedding_dim) + assert batch["aux_item_embedding"].dtype == torch.float32 @pytest.fixture() @@ -185,16 +200,45 @@ def seqrec_dataset(interaction_frame, sid_cache) -> SeqRecDataset: @pytest.fixture() def quantizer_dataset(interaction_frame, textual_frame, dummy_encoder) -> QuantizerDataset: + item_pool = int(textual_frame["ItemID"].max()) + aux_embeddings = _make_aux_item_embeddings(item_pool, aux_dim=3) return QuantizerDataset( interaction_data_path=interaction_frame, - split=DatasetSplitLiteral.TRAIN, max_seq_length=4, min_seq_length=1, textual_data_path=textual_frame, lm_encoder=dummy_encoder, + aux_item_embeddings=aux_embeddings, ) +def test_quantizer_dataset_aux_properties(interaction_frame, textual_frame, dummy_encoder) -> None: + item_pool = int(textual_frame["ItemID"].max()) + aux_embeddings = _make_aux_item_embeddings(item_pool, aux_dim=2) + dataset = QuantizerDataset( + interaction_data_path=interaction_frame, + max_seq_length=4, + min_seq_length=1, + textual_data_path=textual_frame, + lm_encoder=dummy_encoder, + aux_item_embeddings=aux_embeddings, + ) + + assert dataset.aux_item_embeddings is not None + np.testing.assert_allclose(dataset.aux_item_embeddings[1:], aux_embeddings[1:]) + assert dataset.aux_embedding_dim == aux_embeddings.shape[1] + + dataset_no_aux = QuantizerDataset( + interaction_data_path=interaction_frame, + max_seq_length=2, + min_seq_length=1, + textual_data_path=textual_frame, + lm_encoder=dummy_encoder, + ) + assert 
dataset_no_aux.aux_item_embeddings is None + assert dataset_no_aux.aux_embedding_dim is None + + def test_genrec_dataset_examples(genrec_dataset, sid_cache, dummy_encoder): assert len(genrec_dataset) == 6 assert genrec_dataset.user_size == 3 @@ -353,6 +397,9 @@ def test_quantizer_dataset_and_collator(quantizer_dataset, dummy_encoder): assert len(quantizer_dataset) == quantizer_dataset.item_size example = quantizer_dataset[0] assert example.item_embedding.shape[0] == dummy_encoder.embedding_dim + assert example.aux_item_embedding is not None + assert example.aux_item_embedding.ndim == 1 + aux_dim = example.aux_item_embedding.shape[0] collator = QuantizerCollator(quantizer_dataset) loader = DataLoader(quantizer_dataset, batch_size=4, collate_fn=collator) @@ -361,6 +408,8 @@ def test_quantizer_dataset_and_collator(quantizer_dataset, dummy_encoder): assert batch["item_id"].shape == (4,) assert batch["item_embedding"].shape == (4, dummy_encoder.embedding_dim) assert batch["item_embedding"].dtype == torch.float32 + assert batch["aux_item_embedding"].shape == (4, aux_dim) + assert batch["aux_item_embedding"].dtype == torch.float32 def test_genrec_iter_split_requires_minimum_length_for_validation(): @@ -465,13 +514,14 @@ def test_large_scale_multiworker_uniform_negative_sampler(): del loader encoder = DummyEncoder() + aux_embeddings = _make_aux_item_embeddings(item_pool, aux_dim=5) dataset = QuantizerDataset( interaction_data_path=interactions, - split=DatasetSplitLiteral.TRAIN, max_seq_length=4, min_seq_length=1, textual_data_path=textual, lm_encoder=encoder, + aux_item_embeddings=aux_embeddings, ) assert len(dataset) == item_pool assert dataset.split == DatasetSplitLiteral.TRAIN @@ -484,5 +534,10 @@ def test_large_scale_multiworker_uniform_negative_sampler(): shuffle=False, multiprocessing_context="fork", ) - _assert_quantizer_batches(loader, _expected_batch_sizes(len(dataset), batch_size), encoder.embedding_dim) + _assert_quantizer_batches( + loader, + 
_expected_batch_sizes(len(dataset), batch_size), + encoder.embedding_dim, + aux_embedding_dim=aux_embeddings.shape[1], + ) del loader diff --git a/tests/models/model_quantizer/__init__.py b/tests/models/model_quantizer/__init__.py new file mode 100644 index 0000000..232368a --- /dev/null +++ b/tests/models/model_quantizer/__init__.py @@ -0,0 +1 @@ +"""Tests for quantizer model implementations.""" diff --git a/tests/models/model_quantizer/test_base.py b/tests/models/model_quantizer/test_base.py new file mode 100644 index 0000000..96aef38 --- /dev/null +++ b/tests/models/model_quantizer/test_base.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +import torch + +from genrec.models.model_quantizer.base import QuantizerModel, QuantizerModelConfig + + +class TinyQuantizerModel(QuantizerModel[QuantizerModelConfig, None]): + config_class = QuantizerModelConfig + + def __init__(self, config: QuantizerModelConfig) -> None: + super().__init__(config) + self._param = torch.nn.Parameter(torch.zeros(1)) + + def forward(self, *args, **kwargs): # pragma: no cover - not needed for this test + raise NotImplementedError + + def initialize_codebooks(self, item_embeddings: torch.Tensor, **kwargs) -> None: # pragma: no cover - unused + return None + + +def test_post_process_quantized_ids_offsets_and_anticollision() -> None: + config = QuantizerModelConfig( + embed_dim=4, + hidden_sizes=(2,), + num_codebooks=2, + codebook_size=10, + codebook_dim=2, + ) + model = TinyQuantizerModel(config) + + semantic_ids = torch.tensor( + [ + [0, 1], + [0, 1], + [2, 3], + ], + dtype=torch.long, + ) + + processed = model.post_process_quantized_ids(semantic_ids) + + assert processed.shape == (3, config.num_codebooks + 1) + torch.testing.assert_close(processed[:, 0], torch.tensor([1, 1, 3])) + torch.testing.assert_close(processed[:, 1], torch.tensor([12, 12, 14])) + torch.testing.assert_close(processed[:, 2], torch.tensor([21, 22, 21])) diff --git a/tests/models/model_quantizer/test_rqvae.py 
b/tests/models/model_quantizer/test_rqvae.py new file mode 100644 index 0000000..5d684d4 --- /dev/null +++ b/tests/models/model_quantizer/test_rqvae.py @@ -0,0 +1,253 @@ +import numpy as np +import torch +import torch.nn as nn + +from genrec.models.model_quantizer.rqvae import RQVAEModel, RQVAEModelConfig + + +def test_rqvae_forward_outputs_and_losses() -> None: + config = RQVAEModelConfig( + embed_dim=12, + hidden_sizes=(8, 4), + num_codebooks=2, + codebook_size=16, + codebook_dim=6, + kmeans_init=False, + ) + model = RQVAEModel(config) + + batch_size = 3 + item_id = torch.arange(1, batch_size + 1) + item_embedding = torch.randn(batch_size, config.embed_dim) + + output = model( + item_id=item_id, + item_embedding=item_embedding, + output_loss=True, + output_model_loss=True, + output_embeddings=True, + ) + + assert output.semantic_ids.shape == (batch_size, config.num_codebooks) + assert output.quantized_embeddings is not None + assert output.quantized_embeddings.shape == (batch_size, config.num_codebooks, config.codebook_dim) + assert output.residual_embeddings is not None + assert output.residual_embeddings.shape == (batch_size, config.num_codebooks, config.codebook_dim) + assert output.decoded_embeddings is not None + assert output.decoded_embeddings.shape == (batch_size, config.embed_dim) + assert output.reconstruction_loss is not None + assert output.reconstruction_loss.shape == (batch_size,) + assert output.codebook_loss is not None + assert output.codebook_loss.shape == (batch_size,) + assert output.commitment_loss is not None + assert output.commitment_loss.shape == (batch_size,) + assert output.model_loss is None + + +def test_rqvae_forward_without_optional_embeddings() -> None: + config = RQVAEModelConfig( + embed_dim=10, + hidden_sizes=(6,), + num_codebooks=1, + codebook_size=8, + codebook_dim=4, + kmeans_init=False, + ) + model = RQVAEModel(config) + + batch_size = 2 + item_id = torch.arange(1, batch_size + 1) + item_embedding = torch.randn(batch_size, 
config.embed_dim) + + output = model( + item_id=item_id, + item_embedding=item_embedding, + output_loss=True, + output_embeddings=False, + ) + + assert output.quantized_embeddings is None + assert output.residual_embeddings is None + assert output.decoded_embeddings is None + assert output.reconstruction_loss is not None + + +def test_rqvae_quantize_returns_expected_losses() -> None: + config = RQVAEModelConfig( + embed_dim=2, + hidden_sizes=(2,), + num_codebooks=1, + codebook_size=2, + codebook_dim=2, + kmeans_init=False, + ) + model = RQVAEModel(config) + + embeddings = torch.tensor( + [ + [0.0, 0.0], + [1.0, 1.0], + ], + requires_grad=True, + ) + codebook = torch.tensor( + [ + [0.0, 0.0], + [0.5, 0.5], + ], + dtype=torch.float32, + ) + + ids, quantized, codebook_loss, commitment_loss = model._quantize(embeddings, codebook) + + assert ids.tolist() == [0, 1] + torch.testing.assert_close(quantized, torch.tensor([[0.0, 0.0], [0.5, 0.5]])) + assert codebook_loss.shape == (2,) + assert commitment_loss.shape == (2,) + + quantized.sum().backward() + assert embeddings.grad is not None + + +def test_rqvae_initialize_codebooks_uses_kmeans(monkeypatch) -> None: + config = RQVAEModelConfig( + embed_dim=6, + hidden_sizes=(4,), + num_codebooks=2, + codebook_size=2, + codebook_dim=3, + kmeans_init=True, + kmeans_max_iter=1, + ) + model = RQVAEModel(config) + + class DummyKMeans: + call_count = 0 + + def __init__(self, n_clusters, **kwargs): + self.n_clusters = n_clusters + + def fit(self, residual): + DummyKMeans.call_count += 1 + fill = float(DummyKMeans.call_count) + self.cluster_centers_ = np.full((self.n_clusters, residual.shape[1]), fill, dtype=np.float32) + + monkeypatch.setattr("genrec.models.model_quantizer.rqvae.KMeans", DummyKMeans) + + item_embeddings = torch.randn(3, config.embed_dim) + model.initialize_codebooks(item_embeddings) + + for idx, codebook in enumerate(model.codebooks, start=1): + expected = torch.full((config.codebook_size, config.codebook_dim), 
float(idx)) + torch.testing.assert_close(codebook.weight.data, expected) + + +def test_rqvae_initialize_codebooks_skips_when_disabled() -> None: + config = RQVAEModelConfig(kmeans_init=False) + model = RQVAEModel(config) + + before = [codebook.weight.clone() for codebook in model.codebooks] + embeddings = torch.randn(5, config.embed_dim) + model.initialize_codebooks(embeddings) + + for original, codebook in zip(before, model.codebooks, strict=True): + torch.testing.assert_close(codebook.weight, original) + + +def test_rqvae_forward_computes_reconstruction_loss_without_embeddings() -> None: + config = RQVAEModelConfig( + embed_dim=2, + hidden_sizes=(), + num_codebooks=1, + codebook_size=1, + codebook_dim=2, + kmeans_init=False, + ) + model = RQVAEModel(config) + + model.encoder = nn.Identity() # type: ignore[assignment] + model.decoder = nn.Identity() # type: ignore[assignment] + model.codebooks[0].weight.data.zero_() + + item_embedding = torch.tensor([[1.0, -1.0]]) + output = model( + item_id=torch.tensor([1]), + item_embedding=item_embedding, + output_loss=True, + output_embeddings=False, + ) + + assert output.quantized_embeddings is None + assert output.residual_embeddings is None + assert output.decoded_embeddings is None + assert output.reconstruction_loss is not None + torch.testing.assert_close(output.reconstruction_loss, torch.tensor([1.0])) + + +def test_rqvae_forward_skips_decoder_when_not_requested() -> None: + config = RQVAEModelConfig( + embed_dim=2, + hidden_sizes=(), + num_codebooks=1, + codebook_size=1, + codebook_dim=2, + kmeans_init=False, + ) + model = RQVAEModel(config) + + class FailingDecoder(nn.Module): + def forward(self, x): # pragma: no cover - should not be called + raise AssertionError("Decoder should not be invoked") + + model.decoder = FailingDecoder() + model.encoder = nn.Identity() # type: ignore[assignment] + model.codebooks[0].weight.data.zero_() + + output = model( + item_id=torch.tensor([1]), + item_embedding=torch.zeros(1, 
config.embed_dim), + output_loss=False, + output_embeddings=False, + ) + + assert output.decoded_embeddings is None + assert output.reconstruction_loss is None + assert output.codebook_loss is None + assert output.commitment_loss is None + + +def test_rqvae_forward_invokes_decoder_when_loss_requested() -> None: + config = RQVAEModelConfig( + embed_dim=3, + hidden_sizes=(), + num_codebooks=1, + codebook_size=1, + codebook_dim=3, + kmeans_init=False, + ) + model = RQVAEModel(config) + + class RecordingDecoder(nn.Module): + def __init__(self, output_dim: int) -> None: + super().__init__() + self.called = False + self.output_dim = output_dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: + self.called = True + return torch.ones(x.size(0), self.output_dim, device=x.device) + + recording_decoder = RecordingDecoder(output_dim=config.embed_dim) + model.decoder = recording_decoder + model.encoder = nn.Identity() # type: ignore[assignment] + model.codebooks[0].weight.data.zero_() + + output = model( + item_id=torch.tensor([1]), + item_embedding=torch.zeros(1, config.embed_dim), + output_loss=True, + output_embeddings=False, + ) + + assert recording_decoder.called is True + assert output.reconstruction_loss is not None diff --git a/tests/trainers/trainer_quantizer/__init__.py b/tests/trainers/trainer_quantizer/__init__.py new file mode 100644 index 0000000..031d4ab --- /dev/null +++ b/tests/trainers/trainer_quantizer/__init__.py @@ -0,0 +1 @@ +"""Tests for quantizer trainer implementations.""" diff --git a/tests/trainers/trainer_quantizer/test_evaluations.py b/tests/trainers/trainer_quantizer/test_evaluations.py new file mode 100644 index 0000000..6d72a4b --- /dev/null +++ b/tests/trainers/trainer_quantizer/test_evaluations.py @@ -0,0 +1,93 @@ +from __future__ import annotations + +import numpy as np +import pytest +from transformers import EvalPrediction + +from genrec.trainers.trainer_quantizer.utils.evaluations import ( + QuantizerMetricFactory, + 
calc_metric_code_collision, + calc_metric_codebook_usage, + compute_quantizer_metrics, +) + + +class DummyTrainDataset: + item_popularity = np.ones(5, dtype=np.int64) + + +def test_compute_quantizer_metrics_returns_expected_values() -> None: + semantic_ids = np.array( + [ + [0, 1], + [0, 1], + [2, 3], + ], + dtype=np.int64, + ) + reconstruction_loss = np.array([1.0, 2.0, 3.0], dtype=np.float32) + codebook_loss = np.array([0.5, 1.5, 2.5], dtype=np.float32) + commitment_loss = np.array([2.0, 0.0, 1.0], dtype=np.float32) + item_id = np.array([1, 2, 3], dtype=np.int64) + + prediction = EvalPrediction( + predictions=(semantic_ids, reconstruction_loss, codebook_loss, commitment_loss, item_id), + label_ids=None, + ) + + metrics = compute_quantizer_metrics( + prediction, + train_dataset=DummyTrainDataset(), + codebook_size=4, + ) + + assert metrics["reconstruction_loss"] == pytest.approx(2.0) + assert metrics["codebook_loss"] == pytest.approx(1.5) + assert metrics["commitment_loss"] == pytest.approx(1.0) + assert metrics["codebook_0_usage"] == pytest.approx(0.5) + assert metrics["codebook_1_usage"] == pytest.approx(0.5) + assert metrics["code_collision_rate"] == pytest.approx(1.0 / 3.0) + + +def test_compute_quantizer_metrics_requires_tuple_predictions() -> None: + prediction = EvalPrediction(predictions=np.zeros((2, 2)), label_ids=np.zeros(2)) + + with pytest.raises(ValueError, match="Predictions should be a tuple"): + compute_quantizer_metrics( + prediction, + train_dataset=DummyTrainDataset(), + codebook_size=4, + ) + + +def test_metric_functions_directly() -> None: + semantic_ids = np.array( + [ + [1, 1], + [1, 2], + ], + dtype=np.int64, + ) + item_id = np.array([1, 2], dtype=np.int64) + + usage = calc_metric_codebook_usage( + semantic_ids=semantic_ids, + item_id=item_id, + train_dataset=DummyTrainDataset(), + codebook_size=4, + ) + collision = calc_metric_code_collision( + semantic_ids=semantic_ids, + item_id=item_id, + train_dataset=DummyTrainDataset(), + 
codebook_size=4, + ) + + assert usage["codebook_0_usage"] == pytest.approx(1.0 / 4.0) + assert usage["codebook_1_usage"] == pytest.approx(2.0 / 4.0) + assert collision["code_collision_rate"] == pytest.approx(0.0) + + +def test_metric_factory_raises_for_unknown_metric() -> None: + with pytest.raises(ValueError, match="not registered"): + QuantizerMetricFactory.create("missing_metric") diff --git a/tests/trainers/trainer_quantizer/test_rqvae.py b/tests/trainers/trainer_quantizer/test_rqvae.py new file mode 100644 index 0000000..6317a8d --- /dev/null +++ b/tests/trainers/trainer_quantizer/test_rqvae.py @@ -0,0 +1,93 @@ +from __future__ import annotations + +from typing import Any, Dict + +import torch + +from genrec.models.model_quantizer.base import QuantizerModel, QuantizerModelConfig, QuantizerOutput +from genrec.trainers.trainer_quantizer.rqvae import RQVAEQuantizerTrainer, RQVAEQuantizerTrainingArguments + + +class DummyQuantizerModel(QuantizerModel[QuantizerModelConfig, QuantizerOutput]): + config_class = QuantizerModelConfig + + def __init__(self, config: QuantizerModelConfig) -> None: + super().__init__(config) + + def forward( + self, + item_id: torch.Tensor, + item_embedding: torch.Tensor, + output_loss: bool = False, + output_model_loss: bool = False, + output_embeddings: bool = False, + **kwargs: Any, + ) -> QuantizerOutput: + return QuantizerOutput(semantic_ids=torch.zeros(item_embedding.shape[0], 1, dtype=torch.long)) + + def initialize_codebooks(self, item_embeddings: torch.Tensor, **kwargs: Any) -> None: + return None + + +class DummyDataset: + def __len__(self) -> int: + return 2 + + def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: + return { + "item_id": torch.tensor(idx + 1, dtype=torch.long), + "item_embedding": torch.randn(4), + } + + +class DummyCollator: + def __call__(self, batch: list[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]: + item_id = torch.stack([item["item_id"] for item in batch]) + item_embedding = 
torch.stack([item["item_embedding"] for item in batch]) + return {"item_id": item_id, "item_embedding": item_embedding} + + +class MinimalRQVAETrainer(RQVAEQuantizerTrainer[DummyQuantizerModel]): + def initialize_codebooks(self) -> None: + return None + + +def test_rqvae_trainer_compute_quantizer_loss(tmp_path) -> None: + args = RQVAEQuantizerTrainingArguments( + output_dir=str(tmp_path), + codebook_loss_weight=0.5, + commitment_loss_weight=0.25, + ) + model = DummyQuantizerModel( + QuantizerModelConfig( + embed_dim=4, + hidden_sizes=(3,), + num_codebooks=1, + codebook_size=8, + codebook_dim=2, + ) + ) + + trainer = MinimalRQVAETrainer( + model=model, + args=args, + data_collator=DummyCollator(), + train_dataset=DummyDataset(), + eval_dataset=None, + compute_metrics=lambda _: {}, + ) + + outputs = QuantizerOutput( + semantic_ids=torch.zeros(2, 1, dtype=torch.long), + reconstruction_loss=torch.tensor([1.0, 2.0]), + codebook_loss=torch.tensor([0.5, 1.5]), + commitment_loss=torch.tensor([2.0, 0.0]), + ) + loss = trainer.compute_quantizer_loss(inputs={}, outputs=outputs) + + expected = ( + outputs.reconstruction_loss.mean() + + args.codebook_loss_weight * outputs.codebook_loss.mean() + + args.commitment_loss_weight * outputs.commitment_loss.mean() + ) + torch.testing.assert_close(loss, expected) diff --git a/tests/trainers/trainer_quantizer/utils/__init__.py b/tests/trainers/trainer_quantizer/utils/__init__.py new file mode 100644 index 0000000..a4cbedd --- /dev/null +++ b/tests/trainers/trainer_quantizer/utils/__init__.py @@ -0,0 +1 @@ +"""Tests for quantizer trainer utilities.""" diff --git a/tests/trainers/trainer_quantizer/utils/test_base.py b/tests/trainers/trainer_quantizer/utils/test_base.py new file mode 100644 index 0000000..3192679 --- /dev/null +++ b/tests/trainers/trainer_quantizer/utils/test_base.py @@ -0,0 +1,252 @@ +from __future__ import annotations + +from functools import partial +from typing import Any, Dict, Sequence + +import numpy as np +import 
pandas as pd +import torch + +from genrec.datasets import DatasetSplitLiteral, QuantizerCollator, QuantizerDataset +from genrec.models.model_quantizer.base import QuantizerModel, QuantizerModelConfig, QuantizerOutput +from genrec.trainers.trainer_quantizer.base import QuantizerTrainer, QuantizerTrainingArguments +from genrec.trainers.trainer_quantizer.utils.callbacks import EpochIntervalEvalCallback, HardStopCallback +from genrec.trainers.trainer_quantizer.utils.evaluations import compute_quantizer_metrics +from transformers import TrainerCallback + + +class DummyEncoder: + embedding_dim: int = 4 + + def encode(self, texts: Sequence[str]) -> np.ndarray: + if not texts: + return np.empty((0, self.embedding_dim), dtype=np.float32) + base = np.arange(len(texts) * self.embedding_dim, dtype=np.float32) + return base.reshape(len(texts), self.embedding_dim) + + +class DummyQuantizerModel(QuantizerModel[QuantizerModelConfig, QuantizerOutput]): + config_class = QuantizerModelConfig + + def __init__(self, config: QuantizerModelConfig) -> None: + super().__init__(config) + self.initialize_called = False + self.last_init_embeddings: torch.Tensor | None = None + self._dummy_param = torch.nn.Parameter(torch.zeros(1)) + + def forward( + self, + item_id: torch.Tensor, + item_embedding: torch.Tensor, + output_loss: bool = False, + output_model_loss: bool = False, + output_embeddings: bool = False, + **kwargs: Any, + ) -> QuantizerOutput: + batch_size = item_embedding.shape[0] + semantic_ids = torch.zeros(batch_size, self.config.num_codebooks, dtype=torch.long) + reconstruction_loss = torch.full((batch_size,), 1.0, device=item_embedding.device) if output_loss else None + codebook_loss = torch.full((batch_size,), 2.0, device=item_embedding.device) if output_loss else None + commitment_loss = torch.full((batch_size,), 3.0, device=item_embedding.device) if output_loss else None + model_loss = torch.tensor(4.0, device=item_embedding.device) if output_model_loss else None + return 
QuantizerOutput( + semantic_ids=semantic_ids, + reconstruction_loss=reconstruction_loss, + codebook_loss=codebook_loss, + commitment_loss=commitment_loss, + model_loss=model_loss, + ) + + def initialize_codebooks(self, item_embeddings: torch.Tensor, **kwargs: Any) -> None: + self.initialize_called = True + self.last_init_embeddings = item_embeddings + + +class MinimalQuantizerTrainer(QuantizerTrainer[DummyQuantizerModel, QuantizerTrainingArguments]): + def compute_quantizer_loss( + self, + inputs: dict[str, torch.Tensor], + outputs: QuantizerOutput, + num_items_in_batch: torch.Tensor | None = None, + ) -> torch.Tensor: + assert outputs.reconstruction_loss is not None + return outputs.reconstruction_loss.mean() + + +def _make_interaction_frame() -> pd.DataFrame: + return pd.DataFrame( + { + "UserID": np.array([0, 1], dtype=np.int64), + "ItemID": [ + [1, 2, 3], + [2, 3, 4], + ], + "Timestamp": [ + [1, 2, 3], + [1, 2, 3], + ], + } + ) + + +def _make_textual_frame(item_pool: int) -> pd.DataFrame: + return pd.DataFrame( + { + "ItemID": np.arange(1, item_pool + 1, dtype=np.int64), + "Title": [f"Item {idx}" for idx in range(1, item_pool + 1)], + } + ) + + +def _build_quantizer_dataset() -> QuantizerDataset: + interaction_frame = _make_interaction_frame() + textual_frame = _make_textual_frame(item_pool=4) + return QuantizerDataset( + interaction_data_path=interaction_frame, + max_seq_length=4, + min_seq_length=1, + textual_data_path=textual_frame, + lm_encoder=DummyEncoder(), + ) + + +def test_quantizer_trainer_defaults_and_initialize_codebooks(tmp_path) -> None: + dataset = _build_quantizer_dataset() + collator = QuantizerCollator(dataset) + model = DummyQuantizerModel( + QuantizerModelConfig( + embed_dim=4, + hidden_sizes=(3,), + num_codebooks=2, + codebook_size=8, + codebook_dim=2, + ) + ) + args = QuantizerTrainingArguments( + output_dir=str(tmp_path), + eval_interval=2, + train_stop_epoch=5, + ) + + trainer = MinimalQuantizerTrainer( + model=model, + args=args, + 
data_collator=collator, + train_dataset=dataset, + ) + + assert trainer.label_names == ["item_id"] + assert isinstance(trainer.compute_metrics, partial) + assert trainer.compute_metrics.func is compute_quantizer_metrics + assert trainer.compute_metrics.keywords["metrics"] == args.metrics + assert trainer.compute_metrics.keywords["codebook_size"] == model.config.codebook_size + assert trainer.compute_metrics.keywords["train_dataset"] is dataset + + callback_types = tuple(type(cb) for cb in trainer.callback_handler.callbacks) + assert any(isinstance(cb, EpochIntervalEvalCallback) for cb in trainer.callback_handler.callbacks), callback_types + assert any(isinstance(cb, HardStopCallback) for cb in trainer.callback_handler.callbacks), callback_types + + assert model.initialize_called is True + assert model.last_init_embeddings is not None + torch.testing.assert_close(model.last_init_embeddings.cpu(), torch.from_numpy(dataset.item_embeddings)) + + +def test_quantizer_trainer_compute_loss_with_model_loss(tmp_path) -> None: + dataset = _build_quantizer_dataset() + collator = QuantizerCollator(dataset) + model = DummyQuantizerModel( + QuantizerModelConfig( + embed_dim=4, + hidden_sizes=(3,), + num_codebooks=1, + codebook_size=4, + codebook_dim=2, + ) + ) + args = QuantizerTrainingArguments( + output_dir=str(tmp_path), + model_loss_weight=0.5, + ) + + trainer = MinimalQuantizerTrainer( + model=model, + args=args, + data_collator=collator, + train_dataset=dataset, + ) + + model.train() + batch = [dataset[0], dataset[1]] + inputs = collator(batch) + + loss, output_dict = trainer.compute_loss(trainer.model, inputs, return_outputs=True) + + expected = torch.tensor(1.0 + 4.0 * args.model_loss_weight) + torch.testing.assert_close(loss, expected) + torch.testing.assert_close(output_dict["loss"], expected) + assert output_dict["semantic_ids"].shape == (2, model.config.num_codebooks) + assert output_dict["item_id"].shape == (2,) + + +def 
test_quantizer_trainer_compute_loss_without_model_loss(tmp_path) -> None: + dataset = _build_quantizer_dataset() + collator = QuantizerCollator(dataset) + model = DummyQuantizerModel( + QuantizerModelConfig( + embed_dim=4, + hidden_sizes=(3,), + num_codebooks=1, + codebook_size=4, + codebook_dim=2, + ) + ) + args = QuantizerTrainingArguments(output_dir=str(tmp_path)) + + trainer = MinimalQuantizerTrainer( + model=model, + args=args, + data_collator=collator, + train_dataset=dataset, + ) + + model.eval() + batch = [dataset[0]] + inputs = collator(batch) + + loss = trainer.compute_loss(trainer.model, inputs, return_outputs=False) + torch.testing.assert_close(loss, torch.tensor(1.0)) + + +def test_quantizer_trainer_respects_custom_callbacks(tmp_path) -> None: + dataset = _build_quantizer_dataset() + collator = QuantizerCollator(dataset) + model = DummyQuantizerModel( + QuantizerModelConfig( + embed_dim=4, + hidden_sizes=(3,), + num_codebooks=1, + codebook_size=4, + codebook_dim=2, + ) + ) + args = QuantizerTrainingArguments(output_dir=str(tmp_path)) + + class DummyCallback(TrainerCallback): + pass + + custom_callback = DummyCallback() + + def custom_metrics(_): + return {"custom": 0.0} + + trainer = MinimalQuantizerTrainer( + model=model, + args=args, + data_collator=collator, + train_dataset=dataset, + compute_metrics=custom_metrics, + callbacks=[custom_callback], + ) + + assert trainer.compute_metrics is custom_metrics + assert any(isinstance(cb, DummyCallback) for cb in trainer.callback_handler.callbacks) diff --git a/tests/trainers/trainer_quantizer/utils/test_callbacks.py b/tests/trainers/trainer_quantizer/utils/test_callbacks.py new file mode 100644 index 0000000..d14a0ed --- /dev/null +++ b/tests/trainers/trainer_quantizer/utils/test_callbacks.py @@ -0,0 +1,64 @@ +from __future__ import annotations + +from transformers import TrainerControl, TrainerState, TrainingArguments + +from genrec.trainers.trainer_quantizer.utils.callbacks import 
EpochIntervalEvalCallback, HardStopCallback + + +def _build_state(epoch: float) -> TrainerState: + state = TrainerState() + state.epoch = epoch + return state + + +def _build_control(should_evaluate: bool = True, should_save: bool = True) -> TrainerControl: + control = TrainerControl() + control.should_evaluate = should_evaluate + control.should_save = should_save + return control + + +def test_epoch_interval_eval_callback_disables_evaluation_when_not_due(tmp_path) -> None: + callback = EpochIntervalEvalCallback(eval_interval=2) + args = TrainingArguments(output_dir=str(tmp_path)) + state = _build_state(epoch=3.0) + control = _build_control(should_evaluate=True, should_save=True) + + updated = callback.on_epoch_end(args=args, state=state, control=control) + + assert updated.should_evaluate is False + assert updated.should_save is False + + +def test_epoch_interval_eval_callback_keeps_evaluation_when_due(tmp_path) -> None: + callback = EpochIntervalEvalCallback(eval_interval=2) + args = TrainingArguments(output_dir=str(tmp_path)) + state = _build_state(epoch=4.0) + control = _build_control(should_evaluate=True, should_save=True) + + updated = callback.on_epoch_end(args=args, state=state, control=control) + + assert updated.should_evaluate is True + assert updated.should_save is True + + +def test_hard_stop_callback_requests_stop_at_epoch(tmp_path) -> None: + callback = HardStopCallback(stop_epoch=3) + args = TrainingArguments(output_dir=str(tmp_path)) + state = _build_state(epoch=3.0) + control = TrainerControl() + + updated = callback.on_epoch_end(args=args, state=state, control=control) + + assert updated.should_training_stop is True + + +def test_hard_stop_callback_ignores_negative_stop_epoch(tmp_path) -> None: + callback = HardStopCallback(stop_epoch=-1) + args = TrainingArguments(output_dir=str(tmp_path)) + state = _build_state(epoch=10.0) + control = TrainerControl() + + updated = callback.on_epoch_end(args=args, state=state, control=control) + + assert 
updated.should_training_stop is False diff --git a/tests/trainers/trainer_seqrec/__init__.py b/tests/trainers/trainer_seqrec/__init__.py index e69de29..e5701f1 100644 --- a/tests/trainers/trainer_seqrec/__init__.py +++ b/tests/trainers/trainer_seqrec/__init__.py @@ -0,0 +1 @@ +"""Tests for sequential recommender trainer implementations.""" diff --git a/tests/trainers/trainer_seqrec/test_base.py b/tests/trainers/trainer_seqrec/test_base.py index 906d43d..f78740b 100644 --- a/tests/trainers/trainer_seqrec/test_base.py +++ b/tests/trainers/trainer_seqrec/test_base.py @@ -13,8 +13,8 @@ from genrec.models.model_seqrec.base import SeqRecModelConfig from genrec.trainers.trainer_seqrec.base import compute_seqrec_metrics -from genrec.trainers.utils.evaluations import clip_top_k -from genrec.trainers.utils.callbacks import EpochIntervalEvalCallback +from genrec.trainers.trainer_seqrec.utils.evaluations import clip_top_k +from genrec.trainers.trainer_seqrec.utils.callbacks import EpochIntervalEvalCallback from tests.trainers.trainer_seqrec.helpers import ( DummySeqRecCollator, DummySeqRecDataset, diff --git a/tests/trainers/trainer_seqrec/utils/__init__.py b/tests/trainers/trainer_seqrec/utils/__init__.py new file mode 100644 index 0000000..ad51a8e --- /dev/null +++ b/tests/trainers/trainer_seqrec/utils/__init__.py @@ -0,0 +1 @@ +"""Tests for sequential recommender trainer utilities.""" diff --git a/tests/trainers/utils/test_callbacks.py b/tests/trainers/trainer_seqrec/utils/test_callbacks.py similarity index 96% rename from tests/trainers/utils/test_callbacks.py rename to tests/trainers/trainer_seqrec/utils/test_callbacks.py index 14318eb..f3a9090 100644 --- a/tests/trainers/utils/test_callbacks.py +++ b/tests/trainers/trainer_seqrec/utils/test_callbacks.py @@ -1,7 +1,7 @@ import pytest from transformers import TrainerControl, TrainerState, TrainingArguments -from genrec.trainers.utils.callbacks import EpochIntervalEvalCallback, HardStopCallback +from 
genrec.trainers.trainer_seqrec.utils.callbacks import EpochIntervalEvalCallback, HardStopCallback def _build_state(epoch): diff --git a/tests/trainers/utils/test_evaluations.py b/tests/trainers/trainer_seqrec/utils/test_evaluations.py similarity index 92% rename from tests/trainers/utils/test_evaluations.py rename to tests/trainers/trainer_seqrec/utils/test_evaluations.py index 6711cef..805c72d 100644 --- a/tests/trainers/utils/test_evaluations.py +++ b/tests/trainers/trainer_seqrec/utils/test_evaluations.py @@ -4,8 +4,8 @@ import pytest import torch -from genrec.trainers.utils.evaluations import ( - MetricFactory, +from genrec.trainers.trainer_seqrec.utils.evaluations import ( + SeqRecMetricFactory, calc_metric_hr, calc_metric_ndcg, calc_metric_popularity, @@ -64,11 +64,11 @@ def test_calc_metric_unpopularity_counts_rare_items(): def test_metric_factory_returns_registered_metric(): - hr_metric = MetricFactory.create("hr") + hr_metric = SeqRecMetricFactory.create("hr") assert hr_metric is calc_metric_hr def test_metric_factory_raises_for_unknown_metric(): with pytest.raises(ValueError): - MetricFactory.create("unknown") + SeqRecMetricFactory.create("unknown")