diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index cf18381..7fa8f48 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -9,6 +9,12 @@ on: jobs: build: runs-on: ubuntu-latest + env: + OMP_NUM_THREADS: 4 + MKL_NUM_THREADS: 4 + OPENBLAS_NUM_THREADS: 4 + NUMEXPR_NUM_THREADS: 4 + PYTORCH_NUM_THREADS: 4 steps: - name: Checkout repository @@ -30,4 +36,6 @@ jobs: run: poetry install - name: Run tests - run: poetry run pytest --cov=genrec --cov-report=term-missing --cov-fail-under=90 \ No newline at end of file + run: poetry run pytest --cov=genrec --cov-report=term-missing --cov-fail-under=90 + # in local test, you may run the following command to constrain the CPU usage + # run: OMP_NUM_THREADS=4 MKL_NUM_THREADS=4 OPENBLAS_NUM_THREADS=4 NUMEXPR_NUM_THREADS=4 PYTORCH_NUM_THREADS=4 poetry run pytest --cov=genrec --cov-report=term-missing --cov-fail-under=90 \ No newline at end of file diff --git a/src/genrec/datasets/base.py b/src/genrec/datasets/base.py index 22ea71e..4acc9db 100644 --- a/src/genrec/datasets/base.py +++ b/src/genrec/datasets/base.py @@ -203,12 +203,12 @@ def __init__( self._min_seq_length = int(min_seq_length) self._sid_cache = sid_cache - self._item_embeddings: Optional[Float[np.ndarray, "I+1 D"]] = None + self._item_textual_embeddings: Optional[Float[np.ndarray, "I+1 D"]] = None + self._item_textual_data: Optional[np.ndarray] = None if textual_data_path is not None: - if lm_encoder is None: # pragma: no cover - defensive guard - raise ValueError("textual_data_path provided without lm_encoder.") - assert isinstance(lm_encoder, LMEncoder) - self._item_embeddings = self._build_item_embeddings(textual_data_path, lm_encoder) + self._item_textual_embeddings, self._item_textual_data = self._build_item_textual_embeddings( + textual_data_path, lm_encoder + ) ( self._user_interactions, @@ -233,6 +233,9 @@ def _load_dataframe( containing `UserID`, `ItemID`, and `Timestamp` (Unix time). columns (Sequence[str]): Required column names. dtypes (Mapping[str, Any]): Expected dtypes per column. + + Returns: + pd.DataFrame: Loaded dataframe with required columns and dtypes. """ if isinstance(data_source, pd.DataFrame): frame = data_source.copy(deep=False) @@ -257,13 +260,25 @@ def _load_dataframe( return frame - def _build_item_embeddings( + def _build_item_textual_embeddings( self, textual_data_path: Union[pd.DataFrame, str, Path], - lm_encoder: LMEncoder, - ) -> Float[np.ndarray, "I+1 D"]: - """Encodes item titles into dense vectors using `encoder`. - Returns a `np.ndarray` with shape (num_items, embedding_dim). + lm_encoder: Optional[LMEncoder], + ) -> Tuple[Optional[Float[np.ndarray, "I+1 D"]], np.ndarray]: + """Loads the item titles and encodes them into dense vectors using `encoder`. + + Args: + textual_data_path (Union[pd.DataFrame, str, Path]): Pandas DataFrame + or path to a pickle file containing `ItemID` and `Title` columns. + lm_encoder (LMEncoder): Encoder used to transform item titles into dense embeddings. + + Returns: + Tuple[Optional[Float[np.ndarray, "I+1 D"]], np.array]: A tuple containing: + - The optional item textual embeddings as a np.ndarray (float) with shape + (num_items + 1, embedding_dim), where index 0 is reserved for padding. + If lm_encoder is None, returns None. + - The original titles as a np.ndarray (object) with shape (num_items + 1,), + where index 0 is reserved for padding. """ textual_frame = self._load_dataframe( textual_data_path, @@ -275,11 +290,16 @@ def _build_item_embeddings( assert textual_frame["ItemID"].nunique() == num_items, "ItemIDs must be contiguous integers." titles = textual_frame["Title"].to_list() - embeddings = lm_encoder.encode(titles).astype(np.float32, copy=False) - padding_embedding = np.zeros((1, embeddings.shape[1]), dtype=np.float32) - embeddings = np.vstack([padding_embedding, embeddings]) + embeddings: Optional[Float[np.ndarray, "I+1 D"]] = None + if lm_encoder is not None: + embeddings = lm_encoder.encode(titles).astype(np.float32, copy=False) + padding_embedding = np.zeros((1, embeddings.shape[1]), dtype=np.float32) + embeddings = np.vstack([padding_embedding, embeddings]) + + titles = np.array(titles, dtype=object) + titles = np.concatenate((np.array([""], dtype=object), titles), axis=0) - return embeddings + return embeddings, titles def _build_interactions( self, @@ -395,16 +415,16 @@ def sid_width(self) -> Optional[int]: return self._sid_cache.shape[1] @property - def item_embeddings(self) -> Optional[Float[np.ndarray, "I+1 D"]]: + def item_textual_embeddings(self) -> Optional[Float[np.ndarray, "I+1 D"]]: """Exposes the cached dense item embeddings, when available.""" - return self._item_embeddings + return self._item_textual_embeddings @property - def embedding_dim(self) -> Optional[int]: + def textual_embedding_dim(self) -> Optional[int]: """Returns the dimensionality of cached embeddings, if present.""" - if self._item_embeddings is None: # pragma: no cover - embedding absent + if self._item_textual_embeddings is None: # pragma: no cover - embedding absent return None - return self._item_embeddings.shape[1] + return self._item_textual_embeddings.shape[1] @property def item_size(self) -> int: @@ -412,8 +432,8 @@ def item_size(self) -> int: is provided in `textual_data_path`, we infer the size from there; otherwise, we estimate it from the maximum item ID observed in the interaction data. """ - if self._item_embeddings is not None: - return self._item_embeddings.shape[0] - 1 + if self._item_textual_data is not None: + return self._item_textual_data.shape[0] - 1 user_max_item_ids = [items[-1] if items.size > 0 else 0 for items in self._user_positive_items] return int(max(user_max_item_ids)) diff --git a/src/genrec/datasets/dataset_genrec.py b/src/genrec/datasets/dataset_genrec.py index 57a786e..315a37c 100644 --- a/src/genrec/datasets/dataset_genrec.py +++ b/src/genrec/datasets/dataset_genrec.py @@ -22,7 +22,6 @@ RecExampleFactory, ) from .modules.lm_encoders import LMEncoder -from .modules.negative_samplers import NegativeSamplerFactory __all__ = [ "GenRecCollator", @@ -35,28 +34,30 @@ @RecExampleFactory.register("genrec") @dataclass(slots=True) class GenRecExample(RecExample): - """Container storing a single training example for encoder-decoder models. + """Container storing a single training pair with item-level and SID views + that suit encoder-decoder style generative recommendation models. - The example pairs a truncated interaction history with the next positive - item that the model is asked to predict. + The example keeps both the original item identifiers (for bookkeeping) and + their flattened Semantic ID (SID) token counterparts that are directly + consumed by generative models. Attributes: user_id: Identifier of the user to which the example belongs. - input_ids: Interaction history (already truncated) that conditions the model. - labels: Next positive item the model is asked to predict. - timestamps: Timestamps aligned with `input_ids`, in Unix time. - input_sid_tokens: Optional matrix of SIDs corresponding to `input_ids`. - target_sid_tokens: Optional SIDs for `labels`. - input_embeddings: Optional dense embedding matrix aligned with `input_ids`. - target_embedding: Optional dense embedding vector for `labels`. + input_ids: Flattened SID tokens derived from `input_item_ids`. + labels: SID tokens representing `label_item_ids`. + input_item_ids: Interaction history expressed as item IDs. + label_item_ids: Positive item ID that should follow the history. + timestamps: Timestamps aligned with `input_item_ids` in Unix time. + input_embeddings: Optional dense embedding matrix aligned with `input_item_ids`. + target_embedding: Optional dense embedding vector for `label_item_ids`. """ user_id: int - input_ids: Int[np.ndarray, "L"] - labels: int + input_ids: Int[np.ndarray, "L*C"] + labels: Int[np.ndarray, "C"] + input_item_ids: Int[np.ndarray, "L"] + label_item_ids: int timestamps: Int[np.ndarray, "L"] - input_sid_tokens: Optional[Int[np.ndarray, "L C"]] = None - target_sid_tokens: Optional[Int[np.ndarray, "C"]] = None input_embeddings: Optional[Float[np.ndarray, "L D"]] = None target_embedding: Optional[Float[np.ndarray, "D"]] = None @@ -76,6 +77,7 @@ def __init__( sid_cache: Optional[Int[np.ndarray, "I+1 C"]] = None, textual_data_path: Optional[Union[pd.DataFrame, str, Path]] = None, lm_encoder: Optional[LMEncoder] = None, + truncation_strategy: str = "tail", ) -> None: """Initialises the dataset and materialises user-level metadata. @@ -93,7 +95,15 @@ def __init__( pickle file with `ItemID` and `Title` columns. lm_encoder (Optional[LMEncoder]): Optional encoder used to transform item titles into dense embeddings. + truncation_strategy (str): Strategy for truncating interaction histories, supported + options are `"tail"` and `"slide"`. `"tail"` will directly truncate the user history + to `max_seq_length`, then construct training examples with length of history from + `min_seq_length` to (up to) `max_seq_length`. `"slide"` will use a sliding window of + size `max_seq_length` over the entire user history to construct training examples. + Defaults to `"tail"`. """ + assert truncation_strategy in {"tail", "slide"}, f"Unsupported truncation strategy: {truncation_strategy}." + self.truncation_strategy = truncation_strategy super().__init__( interaction_data_path, split, @@ -103,6 +113,8 @@ def __init__( textual_data_path, lm_encoder, ) + if self._sid_cache is None: # pragma: no cover - defensive guard + raise ValueError("GenRecDataset requires `sid_cache` to materialise SID tokens.") # recompute training set item popularity self._train_item_popularity = self._compute_train_item_popularity() @@ -113,6 +125,7 @@ def _build_examples(self) -> List[GenRecExample]: examples: List[GenRecExample] = [] user_ids = np.arange(self.user_size, dtype=np.int64) for user_id, items, timestamps in zip(user_ids, self.user_interactions, self.user_interaction_timestamps): + items, timestamps = self._tail_truncate(items, timestamps) for context, target, times in self._iter_split(items, timestamps): if context.shape[0] < self._min_seq_length: # pragma: no cover - insufficient length continue @@ -145,6 +158,20 @@ def _iter_split( if seq_len >= 2: yield items[:-1], int(items[-1]), times[:-1] + def _tail_truncate( + self, + items: Int[np.ndarray, "..."], + times: Int[np.ndarray, "..."], + ) -> Tuple[Int[np.ndarray, "..."], Int[np.ndarray, "..."]]: + """Truncates the interaction history by keeping the most recent interactions. + We trim the history to `max_seq_length + 3` to account for the target and + validation/test items. + """ + if self.truncation_strategy == "tail": + return items[-self._max_seq_length - 3 :], times[-self._max_seq_length - 3 :] + else: + return items, times + def _construct_example( self, user_id: int, @@ -153,15 +180,29 @@ def _construct_example( times: Int[np.ndarray, "..."], ) -> GenRecExample: """Constructs a GenRecExample from the provided context and target.""" - example = GenRecExample(user_id=user_id, input_ids=context, labels=target, timestamps=times) - - if self._sid_cache is not None: - example.input_sid_tokens = self._sid_cache[context] - example.target_sid_tokens = self._sid_cache[target] + assert self._sid_cache is not None, "SID cache must be available after __init__ guard." + + input_item_ids = context.astype(np.int64, copy=True) + label_item_id = int(target) + timestamps = times.astype(np.int64, copy=True) + + sid_context: Int[np.ndarray, "L C"] = self._sid_cache[input_item_ids] + sid_target: Int[np.ndarray, "C"] = self._sid_cache[label_item_id] + flattened_context = sid_context.reshape(-1).astype(np.int64, copy=True) + sid_target = sid_target.astype(np.int64, copy=True) + + example = GenRecExample( + user_id=user_id, + input_ids=flattened_context, + labels=sid_target, + input_item_ids=input_item_ids, + label_item_ids=label_item_id, + timestamps=timestamps, + ) - if self._item_embeddings is not None: - example.input_embeddings = self._item_embeddings[context] - example.target_embedding = self._item_embeddings[target] + if self._item_textual_embeddings is not None: + example.input_embeddings = self._item_textual_embeddings[input_item_ids] + example.target_embedding = self._item_textual_embeddings[label_item_id] return example @@ -192,19 +233,11 @@ class GenRecCollatorConfig(RecCollatorConfig): Attributes: pad_sid: Padding value for Semantic ID token. - num_negative_samples: Number of negatives to sample per instance. - negative_sampling_strategy: Name of the negative sampling strategy to use. - need_sid_tokens: Whether to collate SID tokens if present. Note that - if the dataset examples do not contain SID tokens, this flag should - be set to `False` to avoid errors. need_embeddings: Whether to collate dense embeddings if present. Note that if the dataset examples do not contain embeddings, this flag should be set to `False` to avoid errors. """ - num_negative_samples: int = 0 - negative_sampling_strategy: str = "uniform" - need_sid_tokens: bool = True need_embeddings: bool = False @property @@ -221,28 +254,22 @@ class GenRecCollator(RecCollator[GenRecExample]): user_id: `Int[torch.Tensor, "B"]`. User IDs. - input_ids: `Int[torch.Tensor, "B L"]`. - Input item ID sequences. - labels: `Int[torch.Tensor, "B"]`. + input_ids: `Int[torch.Tensor, "B L*C"]`. + Flattened SID tokens aligned with `input_item_ids`. + labels: `Int[torch.Tensor, "B C"]`. + SID tokens describing `label_item_ids`. + input_item_ids: `Int[torch.Tensor, "B L"]`. + Input item ID sequences used for bookkeeping/metrics. + label_item_ids: `Int[torch.Tensor, "B"]`. Target item IDs. timestamps: `Int[torch.Tensor, "B L"]`. Input timestamp sequences. - attention_mask: `Int[torch.Tensor, "B L"]`. - Attention masks for inputs. - input_sid_tokens: `Optional[Int[torch.Tensor, "B L C"]]`. - Input Semantic ID token sequences, when needed. - target_sid_tokens: `Optional[Int[torch.Tensor, "B C"]]`. - Target Semantic ID tokens, when needed. + attention_mask: `Int[torch.Tensor, "B L*C"]`. + Attention masks for the flattened SID tokens. input_embeddings: `Optional[Float[torch.Tensor, "B L D"]]`. Input dense embeddings, when present and needed. target_embedding: `Optional[Float[torch.Tensor, "B D"]]`. Target dense embeddings, when present and needed. - negative_item_ids: `Optional[Int[torch.Tensor, "B N"]]`. - Sampled negative item IDs, when negatives are requested. - negative_sid_tokens: `Optional[Int[torch.Tensor, "B N C"]]`. - Sampled negative Semantic ID tokens, when needed. - negative_embeddings: `Optional[Float[torch.Tensor, "B N D"]]`. - Sampled negative dense embeddings, when present and needed. """ def __init__( @@ -262,38 +289,29 @@ def __init__( """ self._config = config or GenRecCollatorConfig() - assert self._config.num_negative_samples >= 0, "num_negative_samples must be non-negative." - self._negative_sampler = NegativeSamplerFactory.create( - self._config.negative_sampling_strategy, - dataset=dataset, - ) - self._item_size = dataset.item_size - self._sid_cache: Optional[Int[np.ndarray, "I+1 C"]] = dataset.sid_cache - self._item_embeddings: Optional[Float[np.ndarray, "I+1 D"]] = dataset.item_embeddings - if self._config.need_sid_tokens and self._sid_cache is None: # pragma: no cover - defensive guard - raise ValueError("Dataset must have SID cache when need_sid_tokens is True.") + self._sid_width = dataset.sid_width + self._item_embeddings: Optional[Float[np.ndarray, "I+1 D"]] = dataset.item_textual_embeddings if self._config.need_embeddings and self._item_embeddings is None: # pragma: no cover - defensive guard raise ValueError("Dataset must have item embeddings when need_embeddings is True.") need_pad_keys: Dict[str, type] = { "input_ids": np.int64, - "timestamps": np.int64, "attention_mask": np.int64, + "input_item_ids": np.int64, + "timestamps": np.int64, } no_pad_keys: Dict[str, type] = { "user_id": np.int64, + "label_item_ids": np.int64, "labels": np.int64, } pad_values: Dict[str, np.generic] = { - "input_ids": self._config.pad_item, - "timestamps": np.int64(0), + "input_ids": self._config.pad_sid, "attention_mask": np.int64(0), + "input_item_ids": self._config.pad_item, + "timestamps": np.int64(0), } - if self._config.need_sid_tokens: - need_pad_keys["input_sid_tokens"] = np.int64 - no_pad_keys["target_sid_tokens"] = np.int64 - pad_values["input_sid_tokens"] = self._config.pad_sid if self._config.need_embeddings: need_pad_keys["input_embeddings"] = np.float32 no_pad_keys["target_embedding"] = np.float32 @@ -306,30 +324,7 @@ def _process_before_padding( batch: List[Dict[str, np.ndarray]], batch_seed: int, ) -> None: - """Add attention masks to the batch before padding.""" + """Add attention masks aligned with flattened SID tokens before padding.""" for sample in batch: seq_len = sample["input_ids"].shape[0] sample["attention_mask"] = np.ones((seq_len,), dtype=np.int64) - - def _process_after_padding( - self, - need_pad_batch: Dict[str, np.ndarray], - no_pad_batch: Dict[str, np.ndarray], - batch_seed: int, - ) -> None: - """Perform negative sampling after padding, if requested.""" - if self._config.num_negative_samples == 0: # pragma: no cover - no negatives requested - return - user_histories = need_pad_batch["input_ids"] - negative_item_ids: Int[np.ndarray, "B N"] = self._negative_sampler( - history=user_histories, - num_samples=self._config.num_negative_samples, - batch_seed=batch_seed, - ) - no_pad_batch["negative_item_ids"] = negative_item_ids - if self._config.need_sid_tokens and self._sid_cache is not None: - negative_sid_tokens: Int[np.ndarray, "B N C"] = self._sid_cache[negative_item_ids] - no_pad_batch["negative_sid_tokens"] = negative_sid_tokens - if self._config.need_embeddings and self._item_embeddings is not None: - negative_embeddings: Float[np.ndarray, "B N D"] = self._item_embeddings[negative_item_ids] - no_pad_batch["negative_embeddings"] = negative_embeddings diff --git a/src/genrec/datasets/dataset_quantizer.py b/src/genrec/datasets/dataset_quantizer.py index 386dce1..c8dff3e 100644 --- a/src/genrec/datasets/dataset_quantizer.py +++ b/src/genrec/datasets/dataset_quantizer.py @@ -93,11 +93,11 @@ def __init__( def _build_examples(self) -> List[QuantizerExample]: """Generates training examples for quantizer training.""" - assert self._item_embeddings is not None, "Item embeddings are required for quantizer training." + assert self._item_textual_embeddings is not None, "Item embeddings are required for quantizer training." examples = [ QuantizerExample( i, - self._item_embeddings[i], + self._item_textual_embeddings[i], self._aux_item_embeddings[i] if self._aux_item_embeddings is not None else None, ) for i in range(1, self.item_size + 1) diff --git a/src/genrec/datasets/modules/utils.py b/src/genrec/datasets/modules/utils.py index 4e22512..e84ae54 100644 --- a/src/genrec/datasets/modules/utils.py +++ b/src/genrec/datasets/modules/utils.py @@ -47,22 +47,25 @@ def pad_batch( if len(all_keys) == 0: return {} - batch_max_length = max(sample[key].shape[0] for sample in batch for key in all_keys) - if max_length is not None: - batch_max_length = max(batch_max_length, max_length) - - if pad_to_multiple_of is not None and pad_to_multiple_of > 0: - if batch_max_length % pad_to_multiple_of != 0: - batch_max_length = ((batch_max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of + key_max_lengths: Dict[str, int] = {} + for key in all_keys: + key_max_length = max(sample[key].shape[0] for sample in batch) + if max_length is not None: + key_max_length = max(key_max_length, max_length) + if pad_to_multiple_of is not None and pad_to_multiple_of > 0: + if key_max_length % pad_to_multiple_of != 0: + key_max_length = ((key_max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of + key_max_lengths[key] = key_max_length padded_batch: Dict[str, np.ndarray] = {} for key in all_keys: pad_value = pad_values.get(key, 0) + target_length = key_max_lengths[key] field_rows: List[np.ndarray] = [] for sample in batch: array = sample[key] seq_len = array.shape[0] - pad_length = batch_max_length - seq_len + pad_length = target_length - seq_len if direction == "left": pad_width = (pad_length, 0) else: diff --git a/src/genrec/models/model_genrec/__init__.py b/src/genrec/models/model_genrec/__init__.py index e2db776..144fda1 100644 --- a/src/genrec/models/model_genrec/__init__.py +++ b/src/genrec/models/model_genrec/__init__.py @@ -1 +1,31 @@ """Models for generative recommendation tasks.""" + +__all__ = [] + +from .base import ( + GenRecModel, + GenRecModelConfig, + GenRecModelConfigFactory, + GenRecModelFactory, + GenRecOutput, + GenRecOutputFactory, + ShiftRightMixin, +) + +__all__ += [ + "GenRecModel", + "GenRecModelFactory", + "GenRecModelConfig", + "GenRecModelConfigFactory", + "GenRecOutput", + "GenRecOutputFactory", + "ShiftRightMixin", +] + +from .tiger import TIGERModel, TIGERModelConfig, TIGERModelOutput + +__all__ += [ + "TIGERModel", + "TIGERModelConfig", + "TIGERModelOutput", +] diff --git a/src/genrec/models/model_genrec/base.py b/src/genrec/models/model_genrec/base.py new file mode 100644 index 0000000..78b4c1c --- /dev/null +++ b/src/genrec/models/model_genrec/base.py @@ -0,0 +1,389 @@ +"""Base model for generative recommendation tasks.""" + +from __future__ import annotations + +import os +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, Callable, Generic, Optional, Type, TypeVar, Union + +from jaxtyping import Float, Int +import torch +import torch.nn as nn +from transformers import PretrainedConfig, PreTrainedModel +from transformers.cache_utils import Cache +from transformers.generation.utils import GenerationMixin +from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, Seq2SeqLMOutput + +__all__ = [ + "GenRecModel", + "GenRecModelFactory", + "GenRecModelConfig", + "GenRecModelConfigFactory", + "GenRecOutput", + "GenRecOutputFactory", + "ShiftRightMixin", +] + + +_GenRecModelConfig = TypeVar("_GenRecModelConfig", bound="GenRecModelConfig") +_GenRecOutput = TypeVar("_GenRecOutput", bound="GenRecOutput") +_GenRecEncoderDecoderOutput = TypeVar("_GenRecEncoderDecoderOutput", bound=BaseModelOutputWithPastAndCrossAttentions) +_GenRecModel = TypeVar("_GenRecModel", bound="GenRecModel[Any, Any, Any]") + + +class GenRecModelConfigFactory: # pragma: no cover - factory class + """Factory for creating `GenRecModelConfig` instances.""" + + _registry: dict[str, Type[GenRecModelConfig]] = {} + + @classmethod + def register(cls, name: str) -> Callable[[Type[_GenRecModelConfig]], Type[_GenRecModelConfig]]: + """Decorator to register a `GenRecModelConfig` implementation.""" + + def decorator(config_cls: Type[_GenRecModelConfig]) -> Type[_GenRecModelConfig]: + if name in cls._registry: + raise ValueError(f"GenRec model config '{name}' is already registered.") + cls._registry[name] = config_cls + return config_cls + + return decorator + + @classmethod + def create(cls, name: str, **kwargs) -> GenRecModelConfig: + """Creates an instance of a registered `GenRecModelConfig`.""" + if name not in cls._registry: + raise ValueError(f"GenRec model config '{name}' is not registered.") + config_cls = cls._registry[name] + return config_cls(**kwargs) + + +class GenRecModelConfig(PretrainedConfig): + """Base configuration class for generative recommendation models. + + This class extends the `PretrainedConfig` from the Hugging Face Transformers library + to include common configuration parameters for generative recommendation models. + """ + + model_type = "genrec" + + def __init__( + self, + vocab_size: int = 1024, + hidden_size: int = 256, + is_encoder_decoder: bool = True, + decoder_start_token_id: int = 0, + pad_token_id: int = 0, + tie_word_embeddings: bool = True, + **kwargs, + ) -> None: + """Initializes the configuration with model hyperparameters. + + Args: + vocab_size (int): Size of the Semantic ID vocabulary. Default is 1024. + hidden_size (int): Dimensionality of the model's hidden representations. Default is 256. + is_encoder_decoder (bool): Indicates if the model is an encoder-decoder architecture. Default is True. + decoder_start_token_id (int): The token ID to start decoding with. Default is 0. + pad_token_id (int): The token ID used for padding sequences. Default is 0. + tie_word_embeddings (bool): Whether to tie the input and output word embeddings. Default is True. + **kwargs: Additional keyword arguments for the base `PretrainedConfig`. + """ + super().__init__( + is_encoder_decoder=is_encoder_decoder, + decoder_start_token_id=decoder_start_token_id, + pad_token_id=pad_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + + if "is_encoder_decoder" in kwargs: # pragma: no cover - defensive check + assert kwargs["is_encoder_decoder"] is True, "GenRecModel only supports encoder-decoder architectures." + + +class GenRecOutputFactory: # pragma: no cover - factory class + """Factory for creating `GenRecOutput` instances.""" + + _registry: dict[str, Type[GenRecOutput]] = {} + + @classmethod + def register(cls, name: str) -> Callable[[Type[_GenRecOutput]], Type[_GenRecOutput]]: + """Decorator to register a `GenRecOutput` implementation.""" + + def decorator(output_cls: Type[_GenRecOutput]) -> Type[_GenRecOutput]: + if name in cls._registry: + raise ValueError(f"GenRec output '{name}' is already registered.") + cls._registry[name] = output_cls + return output_cls + + return decorator + + @classmethod + def create(cls, name: str, **kwargs) -> GenRecOutput: + """Creates an instance of a registered `GenRecOutput`.""" + if name not in cls._registry: + raise ValueError(f"GenRec output '{name}' is not registered.") + output_cls = cls._registry[name] + return output_cls(**kwargs) + + +@dataclass +class GenRecOutput(Seq2SeqLMOutput): + """Output class for encoder-decoder generative recommendation models. + + This class extends `Seq2SeqLMOutput` from the Hugging Face Transformers library + and includes additional attributes. + + Args: + model_loss (Optional[Float[torch.Tensor, ""]]): The computed model-specific + loss value, if applicable. Note that the model-agnostic loss (e.g., CE loss) + is handled outside of this class. + """ + + model_loss: Optional[Float[torch.Tensor, ""]] = None + + +class ShiftRightMixin(Generic[_GenRecModelConfig]): + """Mixin class providing the `_shift_right` utility method for shifting input IDs + to the right for decoder input sequences. + """ + + config: _GenRecModelConfig + + def _shift_right(self, input_ids: Int[torch.Tensor, "B L"]) -> Int[torch.Tensor, "B L"]: + """Shifts input IDs one position to the right, prepending the decoder start token. + This is used to create decoder input sequences. + + Args: + input_ids (Int[torch.Tensor, "B L"]): Input token sequences of shape (batch_size, seq_len). + + Returns: + Int[torch.Tensor, "B L"]: Shifted input token sequences. + """ + decoder_start_token_id = self.config.decoder_start_token_id + pad_token_id = self.config.pad_token_id + + assert decoder_start_token_id is not None, "decoder_start_token_id must be defined in GenRecModelConfig." + assert pad_token_id is not None, "pad_token_id must be defined in GenRecModelConfig." + + # shift input ids right by one position, prepending decoder_start_token_id + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() + shifted_input_ids[..., 0] = decoder_start_token_id + + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +class GenRecModelFactory: # pragma: no cover - factory class + """Factory for creating `GenRecModel` instances.""" + + _registry: dict[str, Type[GenRecModel[Any, Any, Any]]] = {} + + @classmethod + def register(cls, name: str) -> Callable[[Type[_GenRecModel]], Type[_GenRecModel]]: + """Decorator to register a `GenRecModel` implementation.""" + + def decorator(model_cls: Type[_GenRecModel]) -> Type[_GenRecModel]: + if name in cls._registry: + raise ValueError(f"GenRec model '{name}' is already registered.") + cls._registry[name] = model_cls + return model_cls + + return decorator + + @classmethod + def create(cls, name: str, **kwargs) -> GenRecModel[Any, Any, Any]: + """Creates an instance of a registered `GenRecModel`.""" + if name not in cls._registry: + raise ValueError(f"GenRec model '{name}' is not registered.") + model_cls = cls._registry[name] + return model_cls(**kwargs) + + @classmethod + def from_pretrained(cls, name: str, path: Union[str, os.PathLike], **kwargs) -> GenRecModel[Any, Any, Any]: + """Loads a pretrained instance of a registered `GenRecModel`.""" + if name not in cls._registry: + raise ValueError(f"GenRec model '{name}' is not registered.") + model_cls = cls._registry[name] + return model_cls.from_pretrained(path, **kwargs) + + +class GenRecModel( + PreTrainedModel, + ShiftRightMixin[_GenRecModelConfig], + Generic[_GenRecModelConfig, _GenRecOutput, _GenRecEncoderDecoderOutput], + GenerationMixin, + ABC, +): + """Base class for encoder-decoder generative recommendation models that support beam + search generation. + + This class extends the `PreTrainedModel` from the Hugging Face Transformers library + and serves as a base for implementing specific generative recommendation models. This + class also provides utilities for constrained generation (e.g., constrained beam search). + + Subclasses must specify the `config_class` attribute and implement the `forward` method. + """ + + config_class: Type[_GenRecModelConfig] + output_class: Type[_GenRecOutput] + main_input_name: str = "input_ids" # main input for generation, i.e., SIDs + supports_gradient_checkpointing = False # change to True if implemented in subclass + + def __init__( + self, + config: _GenRecModelConfig, + ) -> None: + """Initializes the generative recommendation model. + + Args: + config (_GenRecModelConfig): Configuration containing model hyperparameters. + """ + super().__init__(config) + self.config: _GenRecModelConfig + assert self.config.is_encoder_decoder, "GenRecModel only supports encoder-decoder architectures." + + # By default, we assume encoder and decoder share the same token embeddings + self.shared = nn.Embedding(config.vocab_size, config.hidden_size) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + @property + @abstractmethod + def encoder(self) -> nn.Module: # pragma: no cover - abstract method + """Returns the encoder module. You should implement this in subclasses.""" + pass + + @property + @abstractmethod + def decoder(self) -> nn.Module: # pragma: no cover - abstract method + """Returns the decoder module. You should implement this in subclasses.""" + pass + + def get_input_embeddings(self) -> nn.Module: + """Returns the input embedding module.""" + return self.shared + + def set_input_embeddings(self, value: nn.Module) -> None: + """Sets the input embedding module.""" + self.shared = value + assert hasattr(self.encoder, "set_input_embeddings"), "Encoder does not support setting input embeddings." + assert hasattr(self.decoder, "set_input_embeddings"), "Decoder does not support setting input embeddings." + self.encoder.set_input_embeddings(value) + self.decoder.set_input_embeddings(value) + + def get_encoder(self) -> nn.Module: # pragma: no cover - method in abstract base class + """Returns the encoder module.""" + return self.encoder + + def forward( # pragma: no cover - method in abstract base class + self, + input_ids: Int[torch.Tensor, "B L_enc"], + attention_mask: Int[torch.Tensor, "B L_enc"], + decoder_input_ids: Optional[Int[torch.Tensor, "B L_dec"]] = None, + decoder_attention_mask: Optional[Int[torch.Tensor, "B L_dec"]] = None, + encoder_outputs: Optional[_GenRecEncoderDecoderOutput] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[Float[torch.Tensor, "B L_enc d"]] = None, + decoder_inputs_embeds: Optional[Float[torch.Tensor, "B L_dec d"]] = None, + labels: Optional[Int[torch.Tensor, "B L_dec"]] = None, + cache_position: Optional[Int[torch.Tensor, "#L_dec"]] = None, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_model_loss: Optional[bool] = None, + **kwargs: Any, + ) -> _GenRecOutput: + """Defines the forward pass of the generative recommendation model. + + This method provides a typical interface for encoder-decoder models. You may override this method in + subclasses to implement specific model architectures. + + Args: + input_ids (Int[torch.Tensor, "B L_enc"]): Input token sequences of shape (batch_size, seq_len). + attention_mask (Optional[Int[torch.Tensor, "B L_enc"]]): Attention masks for inputs of shape + (batch_size, seq_len). + decoder_input_ids (Optional[Int[torch.Tensor, "B L_dec"]]): Decoder input token sequences + of shape (batch_size, dec_seq_len). If `past_key_values` is used, only the last token + of `decoder_input_ids` have to be input. Default is None. + decoder_attention_mask (Optional[Int[torch.Tensor, "B L_dec"]]): Attention masks for decoder inputs + of shape (batch_size, dec_seq_len). Default is None. + encoder_outputs (Optional[_GenRecEncoderOutput]): Precomputed encoder outputs. + This should be a subclass of `_GenRecEncoderOutput`. Default is None. + past_key_values (Optional[Cache]): Cached key and value tensors for faster decoding. Default is None. + inputs_embeds (Optional[Float[torch.Tensor, "B L d"]]): Input embeddings of `input_ids` of shape + (batch_size, seq_len, hidden_size). If provided, `input_ids` will be ignored. Default is None. + decoder_inputs_embeds (Optional[Float[torch.Tensor, "B L_dec d"]]): Input embeddings of + `decoder_input_ids` of shape (batch_size, dec_seq_len, hidden_size). If provided, + `decoder_input_ids` will be ignored. Default is None. + labels (Optional[Int[torch.Tensor, "B L_dec"]]): Target token sequences for computin the loss, of + shape (batch_size, dec_seq_len). Default is None. + cache_position (Optional[Int[torch.Tensor, "#L_dec"]]): Positions for caching in the decoder. + Default is None. + use_cache (Optional[bool]): Whether to use past key values to speed up decoding. Default is None. + output_attentions (Optional[bool]): Whether to return attention weights. Default is None. + output_hidden_states (Optional[bool]): Whether to return hidden states. Default is None. + output_model_loss (Optional[bool]): Whether to compute and return the model-specific loss. + Default is None. + **kwargs (Any): Additional keyword arguments for the model. + + Returns: + _GenRecOutput: Model outputs packaged as a `GenRecOutput` object. + """ + use_cache = use_cache if use_cache is not None else self.config.use_cache + + # Encode if needed (training, first stage of generation) + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + assert encoder_outputs is not None + hidden_states = encoder_outputs.last_hidden_state + + # Decoding + if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = self._shift_right(labels) + + decoder_outputs: _GenRecEncoderDecoderOutput = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + past_key_values=past_key_values, + cache_position=cache_position, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + sequence_output = decoder_outputs.last_hidden_state + + # Compute logits + # NOTE: Some models, e.g., T5, scale the logits by d_model ** -0.5 before the lm_head + # Here we assume that the decoder will apply a final layernorm and thus no scaling is needed + logits = self.lm_head(sequence_output) + + # NOTE: We do not compute the model-agnostic loss (e.g., CE loss) , compute it in `GenRecTrainer` instead + # By default, we set model_loss to None, if you have a model-specific loss, compute it and set it here + model_loss = torch.tensor(0.0, device=logits.device) + + return self.output_class( + loss=None, # model-agnostic loss to be computed outside + logits=logits, + past_key_values=decoder_outputs.past_key_values, # type: ignore - EncoderDecoderCache + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + model_loss=model_loss if output_model_loss else None, + ) diff --git a/src/genrec/models/model_genrec/tiger.py b/src/genrec/models/model_genrec/tiger.py new file mode 100644 index 0000000..6e30b17 --- /dev/null +++ b/src/genrec/models/model_genrec/tiger.py @@ -0,0 +1,483 @@ +"""GenRec Model: TIGER.""" + +from __future__ import annotations + +import copy +from dataclasses import dataclass +from typing import Any, Optional + +from jaxtyping import Float, Int +import torch +import torch.nn as nn +from transformers import PreTrainedModel +from transformers.cache_utils import Cache, EncoderDecoderCache, DynamicCache +from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions + +from ..modules import RMSNorm, RotaryEmbedding, T5Block +from ..modules.utils import create_attention_mask +from .base import ( + GenRecModel, + GenRecModelConfig, + GenRecModelConfigFactory, + GenRecModelFactory, + GenRecOutput, + GenRecOutputFactory, + ShiftRightMixin, +) + +__all__ = [ + "TIGERModel", + "TIGERModelConfig", + "TIGERModelOutput", +] + + +@GenRecModelConfigFactory.register("tiger") +class TIGERModelConfig(GenRecModelConfig): + """Configuration class for TIGER model, which extends the base `GenRecModelConfig`.""" + + def __init__( + self, + num_heads: int, + num_encoder_layers: int, + num_decoder_layers: int, + linear_dropout: float = 0.0, + attention_dropout: float = 0.0, + attention_bias: bool = False, + ffn_bias: bool = False, + relative_attention_num_buckets: int = 32, + relative_attention_max_distance: int = 128, + enable_rope: bool = False, + **kwargs, + ) -> None: + """Initializes the configuration with model hyperparameters. + + Args: + hidden_size (int): Dimensionality of the model's hidden representations. + num_heads (int): Number of attention heads. + num_encoder_layers (int): Number of layers in the encoder. + num_decoder_layers (int): Number of layers in the decoder. + linear_dropout (float): Dropout rate for the output of attention and feed-forward network. Default is 0.0. + attention_dropout (float): Dropout rate for attention weights. Default is 0.0. + attention_bias (bool): Whether to include bias terms in the attention projections. Default is False. + ffn_bias (bool): Whether to include bias terms in the feed-forward network projections. Default is False. + relative_attention_num_buckets (int): Number of buckets for relative positional embeddings. Default is 32. + relative_attention_max_distance (int): Maximum distance for relative positional embeddings. Default is 128. + enable_rope (bool): Whether to use RoPE instead of learnable relative positional bias. If False, the original + learnable relative positional bias in T5 will be used. Default is False. + **kwargs (Any): Additional keyword arguments for the base `GenRecModelConfig`. + """ + super().__init__(**kwargs) + self.num_heads = num_heads + assert self.hidden_size % num_heads == 0, "hidden_size must be divisible by num_heads" + self.head_dim = self.hidden_size // num_heads + self.num_encoder_layers = num_encoder_layers + self.num_decoder_layers = num_decoder_layers + self.linear_dropout = linear_dropout + self.attention_dropout = attention_dropout + self.attention_bias = attention_bias + self.ffn_bias = ffn_bias + self.relative_attention_num_buckets = relative_attention_num_buckets + self.relative_attention_max_distance = relative_attention_max_distance + self.enable_rope = enable_rope + + +@GenRecOutputFactory.register("tiger") +@dataclass +class TIGERModelOutput(GenRecOutput): + """Output class for TIGER model. + + The `TIGERModelOutput` extends the base `GenRecModelOutput` without adding new fields. + """ + + pass + + +class TIGERStack(PreTrainedModel, ShiftRightMixin[TIGERModelConfig]): + """Standard T5 stack implementation used in the TIGER model, following HuggingFace's `T5Stack`. + + .. note:: + We do not use the `T5PreTrainedModel._init_weights` method to initialize weights, but directly use + HuggingFace's default initialization in `PreTrainedModel`. + """ + + config_class = TIGERModelConfig + supports_gradient_checkpointing = True + + def __init__(self, config: TIGERModelConfig, embed_tokens: nn.Module) -> None: + """Initializes the T5 stack with the given configuration and token embeddings. + + Args: + config (TIGERModelConfig): Configuration containing model hyperparameters. + embed_tokens (nn.Module): Token embedding module. + """ + super().__init__(config) + self.config: TIGERModelConfig + + self.embed_tokens = embed_tokens + self.is_decoder = config.is_decoder + + self.num_layers = config.num_decoder_layers if config.is_decoder else config.num_encoder_layers + self.block = nn.ModuleList( + [ + T5Block( + hidden_size=config.hidden_size, + head_dim=config.head_dim, + num_heads=config.num_heads, + intermediate_size=4 * config.hidden_size, + linear_dropout=config.linear_dropout, + attention_dropout=config.attention_dropout, + attention_bias=config.attention_bias, + ffn_bias=config.ffn_bias, + is_decoder=config.is_decoder, + # only the first layer may compute relative attention bias + # the other layers will reuse the same bias + has_relative_attention_bias=bool(layer_idx == 0), + relative_attention_num_buckets=config.relative_attention_num_buckets, + relative_attention_max_distance=config.relative_attention_max_distance, + enable_rope=config.enable_rope, + layer_idx=layer_idx, + ) + for layer_idx in range(self.num_layers) + ] + ) + self.final_layer_norm = RMSNorm(config.hidden_size, eps=1e-6) + self.dropout = nn.Dropout(config.dropout_rate) + + self.rotary_emb = RotaryEmbedding(head_dim=config.head_dim) + + self.gradient_checkpointing = False # disable gradient checkpointing by default + self.post_init() # use PretrainedModel's default weight initialization + + def forward( + self, + input_ids: Optional[Int[torch.Tensor, "B L"]] = None, + attention_mask: Optional[Int[torch.Tensor, "B L"]] = None, + inputs_embeds: Optional[Float[torch.Tensor, "B L d"]] = None, + encoder_hidden_states: Optional[Float[torch.Tensor, "B L_enc d"]] = None, + encoder_attention_mask: Optional[Int[torch.Tensor, "B L_enc"]] = None, + past_key_values: Optional[EncoderDecoderCache] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[Int[torch.Tensor, "L"]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + **kwargs: Any, + ) -> BaseModelOutputWithPastAndCrossAttentions: + """Forward pass for the T5 stack. + + Args: + input_ids (Optional[Int[torch.Tensor, "B L"]]): Input token sequences of shape (batch_size, seq_len). + Default is None. + attention_mask (Optional[Int[torch.Tensor, "B L"]]): Attention masks for inputs of shape + (batch_size, seq_len). Default is None. + inputs_embeds (Optional[Float[torch.Tensor, "B L d"]]): Input embeddings of `input_ids` of shape + (batch_size, seq_len, hidden_size). If provided, `input_ids` will be ignored. Default is None. + encoder_hidden_states (Optional[Float[torch.Tensor, "B L_enc d"]]): Encoder hidden states for decoder + stacks of shape (batch_size, seq_len_enc, hidden_size). + encoder_attention_mask (Optional[Int[torch.Tensor, "B L_enc"]]): Attention masks for encoder hidden states + of shape (batch_size, seq_len_enc). + past_key_values (Optional[EncoderDecoderCache]): Past key values for faster decoding. Default is None. + use_cache (Optional[bool]): Whether to use past key values to speed up decoding. Default is None. + cache_position (Optional[Int[torch.Tensor, "L"]]): Positions for caching in decoding of shape (seq_len,). + Default is None. + output_attentions (Optional[bool]): Whether to return attention weights. Default is None. + output_hidden_states (Optional[bool]): Whether to return hidden states. Default is None. + **kwargs (Any): Additional keyword arguments. + + Returns: + BaseModelOutputWithPastAndCrossAttentions: Model outputs including hidden states, attentions, + and past key values. + """ + use_cache = use_cache if use_cache is not None else self.config.use_cache + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + # Prepare inputs + if inputs_embeds is None: + assert input_ids is not None, "Either input_ids or inputs_embeds must be provided." + inputs_embeds = self.embed_tokens(input_ids) + + assert inputs_embeds is not None + batch_size, seq_length = inputs_embeds.size()[:2] + + # If gradient checkpointing is enabled, disable use_cache + if self.gradient_checkpointing and self.training: + use_cache = False + + # If this module is not a decoder but use_cache is True, raise an error + if not self.is_decoder and use_cache: + raise ValueError("`use_cache=True` is only supported for decoder stacks.") + + # Prepare cache for decoder + if self.is_decoder: + if use_cache and past_key_values is None: + past_key_values = EncoderDecoderCache( + DynamicCache(config=self.config), DynamicCache(config=self.config) + ) + else: + past_key_values = None + + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device + ) + + # Prepare attention mask for inputs + if attention_mask is None: + mask_seq_len = past_key_values_length + seq_length + attention_mask = torch.ones((batch_size, mask_seq_len), device=inputs_embeds.device) + + if self.is_decoder: + # for static/compileable cache, we may need to provide kv_seq_len (i.e., max cache size) + kv_seq_len = None + if past_key_values is not None and past_key_values.is_compileable: # pragma: no cover - static cache + max_cache_shape = past_key_values.get_max_cache_shape() + if max_cache_shape > 0: + kv_seq_len = max_cache_shape + + # decoder causal mask of shape (batch_size, 1, L, L_k) + # where L_k = max(seq_len, past_key_values_length + seq_len, kv_seq_len or 0) + causal_mask: Float[torch.Tensor, "B 1 L L_k"] + causal_mask = create_attention_mask( + attention_mask, + tgt_len=seq_length, + is_causal=True, + dtype=inputs_embeds.dtype, + cache_position=cache_position, + past_key_values_length=past_key_values_length, + kv_seq_len=kv_seq_len, + ) + elif attention_mask is not None: + # encoder attention mask of shape (batch_size, 1, 1, seq_len) + causal_mask: Float[torch.Tensor, "B 1 1 L"] + causal_mask = create_attention_mask( + attention_mask, + tgt_len=1, + is_causal=False, + dtype=inputs_embeds.dtype, + ) + else: # pragma: no cover - attention mask is None + causal_mask = None + + # Process encoder attention mask for decoder, of shape (batch_size, 1, 1, seq_len_enc) + if self.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_seq_len, _ = encoder_hidden_states.size() + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones( + (encoder_batch_size, encoder_seq_len), device=inputs_embeds.device, dtype=torch.long + ) + encoder_extended_attention_mask: Optional[Float[torch.Tensor, "B 1 1 L_enc"]] + encoder_extended_attention_mask = create_attention_mask( + encoder_attention_mask, + tgt_len=1, + is_causal=False, + dtype=inputs_embeds.dtype, + ) + else: + encoder_extended_attention_mask = None + + # Prepare optional outputs, position bias for attention, and RoPE cache + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + all_cross_attentions = () if (output_attentions and self.is_decoder) else None + + position_bias = None + encoder_decoder_position_bias = None + + position_embeddings = None + if self.config.enable_rope: + rope_position_ids: Optional[torch.Tensor] = None + if cache_position is not None and self.is_decoder: + # cache_position may be 1D (shared across batch) or 2D (per-sample); normalize to (B, L) + rope_position_ids = cache_position + if rope_position_ids.dim() == 1: # pragma: no cover - resize position ids + rope_position_ids = rope_position_ids.unsqueeze(0) + if rope_position_ids.size(0) != batch_size: # pragma: no cover - resize position ids + rope_position_ids = rope_position_ids.expand(batch_size, -1) + position_embeddings = self.rotary_emb(inputs_embeds, position_ids=rope_position_ids) + + # Model forward pass through each layer + hidden_states = self.dropout(inputs_embeds) + + for layer_idx, layer in enumerate(self.block): + # Save current hidden states if needed + if output_hidden_states: + assert all_hidden_states is not None + all_hidden_states = all_hidden_states + (hidden_states,) + + # Model forward, layer outputs is a tuple of: + # - output hidden states + # - self attention outputs: position bias, attn weights (if requested) + # - cross attention outputs: position bias, attn weights (if requested) + hidden_states, self_attn_outputs, cross_attn_outputs = layer( + hidden_states, + attention_mask=causal_mask, + position_bias=position_bias, + position_embeddings=position_embeddings, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + encoder_decoder_position_bias=encoder_decoder_position_bias, + past_key_values=past_key_values, + cache_position=cache_position, + output_attentions=output_attentions, + ) + + # Share the position bias across layers, save self- and cross-attentions + position_bias = self_attn_outputs[0] + if self.is_decoder and encoder_hidden_states is not None: + encoder_decoder_position_bias = cross_attn_outputs[0] + + if output_attentions: + assert all_attentions is not None + all_attentions = all_attentions + (self_attn_outputs[1],) + + if output_attentions and self.is_decoder: + assert all_cross_attentions is not None + all_cross_attentions = all_cross_attentions + (cross_attn_outputs[1],) + + # Final layer norm + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states) + + # Save final hidden states if needed + if output_hidden_states: + assert all_hidden_states is not None + all_hidden_states = all_hidden_states + (hidden_states,) + + # Prepare outputs + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ) + + +@GenRecModelFactory.register("tiger") +class TIGERModel(GenRecModel[TIGERModelConfig, TIGERModelOutput, BaseModelOutputWithPastAndCrossAttentions]): + """TIGER model implementation, following the `T5ForConditionalGeneration` architecture from HuggingFace Transformers. + This model can be viewed as a base implementation for generative recommendation tasks. + + References: + - Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer. JMLR '20. + - Recommender Systems with Generative Retrieval. NeurIPS '23. + """ + + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] + + config_class = TIGERModelConfig + output_class = TIGERModelOutput + supports_gradient_checkpointing = True + + def __init__(self, config: TIGERModelConfig) -> None: + """Initializes the TIGER model with the given configuration.""" + super().__init__(config) + self.config: TIGERModelConfig + + # Set up encoder configuration + encoder_config = copy.deepcopy(config) + encoder_config.is_decoder = False + encoder_config.use_cache = False + encoder_config.tie_encoder_decoder = False + encoder_config.num_hidden_layers = config.num_encoder_layers + self._encoder = TIGERStack(encoder_config, self.shared) + + # Set up decoder configuration + decoder_config = copy.deepcopy(config) + decoder_config.is_decoder = True + decoder_config.tie_encoder_decoder = False + decoder_config.num_hidden_layers = config.num_decoder_layers + self._decoder = TIGERStack(decoder_config, self.shared) + + self.gradient_checkpointing = False # disable gradient checkpointing by default + self.post_init() # use PretrainedModel's default weight initialization + + def _tie_weights(self) -> None: + """Ties the weights of the encoder and decoder embeddings to the shared embeddings if needed.""" + if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared) + self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared) + + @property + def encoder(self) -> TIGERStack: + """Returns the encoder module.""" + return self._encoder + + @property + def decoder(self) -> TIGERStack: + """Returns the decoder module.""" + return self._decoder + + def forward( + self, + input_ids: Int[torch.Tensor, "B L_enc"], + attention_mask: Int[torch.Tensor, "B L_enc"], + decoder_input_ids: Optional[Int[torch.Tensor, "B L_dec"]] = None, + decoder_attention_mask: Optional[Int[torch.Tensor, "B L_dec"]] = None, + encoder_outputs: Optional[BaseModelOutputWithPastAndCrossAttentions] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[Float[torch.Tensor, "B L_enc d"]] = None, + decoder_inputs_embeds: Optional[Float[torch.Tensor, "B L_dec d"]] = None, + labels: Optional[Int[torch.Tensor, "B L_dec"]] = None, + cache_position: Optional[Int[torch.Tensor, "#L_dec"]] = None, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_model_loss: Optional[bool] = None, + **kwargs: Any, + ) -> TIGERModelOutput: + """Defines the forward pass of TIGERModel. We directly follow the standard implementation in the + base `GenRecModel`. + + Args: + input_ids (Int[torch.Tensor, "B L_enc"]): Input token sequences of shape (batch_size, seq_len). + attention_mask (Optional[Int[torch.Tensor, "B L_enc"]]): Attention masks for inputs of shape + (batch_size, seq_len). + decoder_input_ids (Optional[Int[torch.Tensor, "B L_dec"]]): Decoder input token sequences + of shape (batch_size, dec_seq_len). If `past_key_values` is used, only the last token + of `decoder_input_ids` have to be input. Default is None. + decoder_attention_mask (Optional[Int[torch.Tensor, "B L_dec"]]): Attention masks for decoder inputs + of shape (batch_size, dec_seq_len). Default is None. + encoder_outputs (Optional[BaseModelOutputWithPastAndCrossAttentions]): Precomputed encoder outputs. + Default is None. + past_key_values (Optional[Cache]): Cached key and value tensors for faster decoding. Default is None. + inputs_embeds (Optional[Float[torch.Tensor, "B L d"]]): Input embeddings of `input_ids` of shape + (batch_size, seq_len, hidden_size). If provided, `input_ids` will be ignored. Default is None. + decoder_inputs_embeds (Optional[Float[torch.Tensor, "B L_dec d"]]): Input embeddings of + `decoder_input_ids` of shape (batch_size, dec_seq_len, hidden_size). If provided, + `decoder_input_ids` will be ignored. Default is None. + labels (Optional[Int[torch.Tensor, "B L_dec"]]): Target token sequences for computin the loss, of + shape (batch_size, dec_seq_len). Default is None. + cache_position (Optional[Int[torch.Tensor, "#L_dec"]]): Positions for caching in the decoder. + Default is None. + use_cache (Optional[bool]): Whether to use past key values to speed up decoding. Default is None. + output_attentions (Optional[bool]): Whether to return attention weights. Default is None. + output_hidden_states (Optional[bool]): Whether to return hidden states. Default is None. + output_model_loss (Optional[bool]): Whether to compute and return the model-specific loss. + Default is None. + **kwargs (Any): Additional keyword arguments for the model. + + Returns: + TIGERModelOutput: Model outputs packaged as a `GenRecOutput` object. + """ + return super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + encoder_outputs=encoder_outputs, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + labels=labels, + cache_position=cache_position, + use_cache=use_cache, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + output_model_loss=output_model_loss, + **kwargs, + ) diff --git a/src/genrec/models/model_seqrec/base.py b/src/genrec/models/model_seqrec/base.py index e3c573a..e05caf9 100644 --- a/src/genrec/models/model_seqrec/base.py +++ b/src/genrec/models/model_seqrec/base.py @@ -182,7 +182,7 @@ class SeqRecModel(PreTrainedModel, Generic[_SeqRecModelConfig, _SeqRecOutput], A """ config_class: Type[_SeqRecModelConfig] - supports_gradient_checkpointing = True + supports_gradient_checkpointing = False # change to True if implemented in subclass def __init__(self, config: _SeqRecModelConfig) -> None: """Initializes the sequential recommendation model. diff --git a/src/genrec/models/model_seqrec/hstu.py b/src/genrec/models/model_seqrec/hstu.py index c3b17b2..4be2c22 100644 --- a/src/genrec/models/model_seqrec/hstu.py +++ b/src/genrec/models/model_seqrec/hstu.py @@ -123,6 +123,7 @@ class HSTUModel(SeqRecModel[HSTUModelConfig, HSTUModelOutput]): """ config_class = HSTUModelConfig + supports_gradient_checkpointing = True def __init__(self, config: HSTUModelConfig) -> None: """Initializes HSTU model with the given configuration.""" diff --git a/src/genrec/models/model_seqrec/hstu_spring.py b/src/genrec/models/model_seqrec/hstu_spring.py index 47a5b2e..38be1ee 100644 --- a/src/genrec/models/model_seqrec/hstu_spring.py +++ b/src/genrec/models/model_seqrec/hstu_spring.py @@ -97,6 +97,7 @@ class HSTUSpringModel(SeqRecModel[HSTUSpringModelConfig, HSTUSpringModelOutput]) """ config_class = HSTUSpringModelConfig + supports_gradient_checkpointing = True def __init__(self, config: HSTUSpringModelConfig) -> None: """Initializes HSTU model with the given configuration.""" diff --git a/src/genrec/models/model_seqrec/sasrec.py b/src/genrec/models/model_seqrec/sasrec.py index fb71b16..1fb9eb8 100644 --- a/src/genrec/models/model_seqrec/sasrec.py +++ b/src/genrec/models/model_seqrec/sasrec.py @@ -79,6 +79,7 @@ class SASRecModel(SeqRecModel[SASRecModelConfig, SASRecModelOutput]): """ config_class = SASRecModelConfig + supports_gradient_checkpointing = True def __init__(self, config: SASRecModelConfig) -> None: """Initializes SASRec model with the given configuration.""" diff --git a/src/genrec/models/model_seqrec/sasrec_spring.py b/src/genrec/models/model_seqrec/sasrec_spring.py index 07b36e3..5a25c4a 100644 --- a/src/genrec/models/model_seqrec/sasrec_spring.py +++ b/src/genrec/models/model_seqrec/sasrec_spring.py @@ -83,6 +83,7 @@ class SASRecSpringModel(SeqRecModel[SASRecSpringModelConfig, SASRecSpringModelOu """ config_class = SASRecSpringModelConfig + supports_gradient_checkpointing = True def __init__(self, config: SASRecSpringModelConfig) -> None: """Initializes SASRec model with the given configuration.""" diff --git a/src/genrec/models/modules/__init__.py b/src/genrec/models/modules/__init__.py index 5167ba9..d0cad98 100644 --- a/src/genrec/models/modules/__init__.py +++ b/src/genrec/models/modules/__init__.py @@ -2,11 +2,12 @@ __all__ = [] -from .attention import MaskedHSTUAttention, MaskedSelfAttentionWithRoPE +from .attention import MaskedHSTUAttention, MaskedSelfAttentionWithRoPE, T5Attention __all__ += [ "MaskedHSTUAttention", "MaskedSelfAttentionWithRoPE", + "T5Attention", ] from .feedforward import FeedForwardNetwork, MLP, SwiGLU @@ -25,18 +26,20 @@ from .layers import ( LlamaDecoderLayer, - SequentialTransductionUnit, SpringLlamaDecoderLayer, + SequentialTransductionUnit, SpringSequentialTransductionUnit, + T5Block, spring_attention_weight_spectral_norm, spring_power_iteration, ) __all__ += [ "LlamaDecoderLayer", - "SequentialTransductionUnit", "SpringLlamaDecoderLayer", + "SequentialTransductionUnit", "SpringSequentialTransductionUnit", + "T5Block", "spring_attention_weight_spectral_norm", "spring_power_iteration", ] @@ -45,6 +48,7 @@ LearnableInputPositionalEmbedding, RelativeBucketedTimeAndPositionAttentionBias, RotaryEmbedding, + T5RelativePositionBias, apply_rotary_pos_emb, ) @@ -52,6 +56,7 @@ "LearnableInputPositionalEmbedding", "RelativeBucketedTimeAndPositionAttentionBias", "RotaryEmbedding", + "T5RelativePositionBias", "apply_rotary_pos_emb", ] diff --git a/src/genrec/models/modules/attention.py b/src/genrec/models/modules/attention.py index 1ba48e3..5b72faa 100644 --- a/src/genrec/models/modules/attention.py +++ b/src/genrec/models/modules/attention.py @@ -2,19 +2,21 @@ from __future__ import annotations -from typing import Optional, Tuple +from typing import Optional, Tuple, Union from jaxtyping import Float, Int import torch import torch.nn as nn import torch.nn.functional as F +from transformers.cache_utils import Cache, EncoderDecoderCache from .layernorm import RMSNorm -from .posemb import RelativeBucketedTimeAndPositionAttentionBias, apply_rotary_pos_emb +from .posemb import RelativeBucketedTimeAndPositionAttentionBias, T5RelativePositionBias, apply_rotary_pos_emb __all__ = [ "MaskedHSTUAttention", "MaskedSelfAttentionWithRoPE", + "T5Attention", ] @@ -261,3 +263,214 @@ def forward( attn_output = self.o_proj(av_output) return attn_output, attn_weights + + +class T5Attention(nn.Module): + """Multi-Head Attention module used in T5 model, following `T5Attention`'s implementation. + + Compared to standard T5Attention, this module provides several options to generalize + and enhance the attention mechanism, including: + - Option to switch the original learnable relative attention bias with Rotary Positional + Embeddings (RoPE) for better extrapolation to longer sequences and improved performance. + + References: + - Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer. JMLR '20. + """ + + def __init__( + self, + hidden_size: int, + head_dim: int, + num_heads: int, + attention_dropout: float = 0.0, + attention_bias: bool = False, + is_decoder: bool = False, + has_relative_attention_bias: bool = False, + relative_attention_num_buckets: int = 32, + relative_attention_max_distance: int = 128, + enable_rope: bool = False, + layer_idx: Optional[int] = None, + ) -> None: + """Initializes T5Attention module. + + Args: + hidden_size (int): Dimensionality of the model's hidden representations. + head_dim (int): Dimensionality of each attention head. + num_heads (int): Number of attention heads. + attention_dropout (float): Dropout rate for attention weights. Default is 0.0. + attention_bias (bool): Whether to include bias terms in the attention projections. Default is False. + is_decoder (bool): Whether this attention module is used in the decoder. This is used to determine + the directionality of relative positional embeddings. Default is False. + has_relative_attention_bias (bool): Whether to compute learnable relative positional bias. If False, this + module will not initialize a `T5RelativePositionBias` instance. Typically, T5 set `has_relative_attention_bias` + to True only for the first block, while the rest blocks reuse the same relative positional bias. Note + that when `enable_rope` is True, this argument will be ignored. Default is False. + relative_attention_num_buckets (int): Number of buckets for relative positional embeddings. Default is 32. + relative_attention_max_distance (int): Maximum distance for relative positional embeddings. Default is 128. + enable_rope (bool): Whether to use RoPE instead of learnable relative positional bias. If False, the original + learnable relative positional bias in T5 will be used. Default is False. + layer_idx (Optional[int]): Optional layer index of this attention module in the model. This should be set + when caching past key/values in the decoder for autoregressive generation. Default is None. + """ + super().__init__() + self.hidden_size = hidden_size + self.head_dim = head_dim + self.num_heads = num_heads + self.attention_dropout = attention_dropout + self.attention_bias = attention_bias + self.is_decoder = is_decoder + self.has_relative_attention_bias = has_relative_attention_bias + self.relative_attention_num_buckets = relative_attention_num_buckets + self.relative_attention_max_distance = relative_attention_max_distance + self.enable_rope = enable_rope + self.layer_idx = layer_idx + + self.q_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=attention_bias) + self.k_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=attention_bias) + self.v_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=attention_bias) + self.o_proj = nn.Linear(num_heads * head_dim, hidden_size, bias=attention_bias) + + self.rel_pos_bias: Optional[T5RelativePositionBias] = None + if self.has_relative_attention_bias and not self.enable_rope: + self.rel_pos_bias = T5RelativePositionBias( + num_buckets=relative_attention_num_buckets, + max_distance=relative_attention_max_distance, + num_heads=num_heads, + is_bidirectional=(not is_decoder), + ) + + def forward( + self, + hidden_states: Float[torch.Tensor, "B L_q d"], + attention_mask: Optional[Float[torch.Tensor, "B 1 #L_q L_k"]] = None, + key_value_states: Optional[Float[torch.Tensor, "B L_k d"]] = None, + position_bias: Optional[Float[torch.Tensor, "#B H L_q L_k"]] = None, + position_embeddings: Optional[ + Tuple[Float[torch.Tensor, "B L_q head_dim"], Float[torch.Tensor, "B L_q head_dim"]] + ] = None, + past_key_values: Optional[EncoderDecoderCache] = None, + cache_position: Optional[Int[torch.Tensor, "L_q"]] = None, + output_attentions: bool = False, + ) -> Tuple[ + Float[torch.Tensor, "B L_q d"], + Optional[Float[torch.Tensor, "#B H L_q L_k"]], + Optional[Float[torch.Tensor, "B H L_q L_k"]], + ]: + """Forward pass for T5Attention. + + Args: + hidden_states (Float[torch.Tensor, "B L_q d"]): Input tensor of shape (batch_size, query_len, hidden_size). + attention_mask (Optional[Float[torch.Tensor, "B 1 #L_q L_k"]]): Optional attention mask added to the scores + before softmax, with shape either (batch_size, 1, query_len, key_len) for causal self-attention in + decoder, or (batch_size, 1, 1, key_len) for self-attention in encoder and cross-attention in decoder. + Specifically, if the dimension `L_k` is longer than the actual key length (which can happen during + autoregressive generation in the decoder), the mask will be sliced accordingly. Default is None. + key_value_states (Optional[Float[torch.Tensor, "B L_k d"]]): Optional tensor for key and value states, which + assumes to be the output of the encoder and is used in the decoder cross-attention. If None, self-attention + is performed; otherwise, cross-attention is performed. Default is None. + position_bias (Optional[Float[torch.Tensor, "#B H L_q L_k"]]): Optional precomputed position bias to be added + to the attention scores. If None, and if `has_relative_attention_bias` is True and `enable_rope` is False, + the position bias will be computed based on the relative positions of the queries and keys from the T5 + learnable relative positional embeddings. Default is None. + position_embeddings (Optional[Tuple[Float[torch.Tensor, "B L_q head_dim"], Float[torch.Tensor, "B L_q head_dim"]]]): + Optional tuple of cosine and sine embeddings for RoPE. Note that when `enable_rope` is False, this argument + will be ignored. Default is None. + past_key_values (Optional[EncoderDecoderCache]): Optional cache for previously computed key and value states, + used in the decoder for faster autoregressive generation. Default is None. + cache_position (Optional[Int[torch.Tensor, "L_q"]]): Optional position IDs used to compute relative positions + when caching past key/values in the decoder. If provided, it should contain the absolute positions of the + current query tokens. Default is None. + output_attentions (bool): Whether to return the attention weights along with the output. Default is False. + + Returns: + Tuple[ + Float[torch.Tensor, "B L_q d"], + Optional[Float[torch.Tensor, "#B H L_q L_k"]], + Optional[Float[torch.Tensor, "B H L_q L_k"]], + ]: Output tensor, position bias, and attention weights tensor. If `output_attentions` is False, the attention weights + will be None. + """ + B, L_q, _ = hidden_states.shape + H, head_dim = self.num_heads, self.head_dim + + # Compute query states + query_states: Float[torch.Tensor, "B H L_q head_dim"] + query_states = self.q_proj(hidden_states).view(B, L_q, H, head_dim).transpose(1, 2) + + # If key_value_states are provided this layer is used as a cross-attention layer for the decoder + is_cross_attention = key_value_states is not None + + # Compute key and value states from either key_value_states or hidden_states + is_updated = False # indicates whether the cross-attention kv states are updated in the cache + curr_past_key_value: Optional[Cache] = None + if past_key_values is not None: + assert self.layer_idx is not None, "layer_idx must not be None when accessing cached layers" + is_updated = past_key_values.is_updated.get(self.layer_idx) + if is_cross_attention: + curr_past_key_value = past_key_values.cross_attention_cache + else: + curr_past_key_value = past_key_values.self_attention_cache + + # Compute key and value states, updating the cache if necessary + key_states: Float[torch.Tensor, "B H L_k head_dim"] + value_states: Float[torch.Tensor, "B H L_k head_dim"] + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_values is not None and is_updated: + # use cached key and value states for cross-attention + assert curr_past_key_value is not None and self.layer_idx is not None + key_states = curr_past_key_value.layers[self.layer_idx].keys # type: ignore - lazy init in DynamicLayer + value_states = curr_past_key_value.layers[self.layer_idx].values # type: ignore - lazy init in DynamicLayer + else: + # cache miss for cross-attention, or in self-attention case + key_states = self.k_proj(current_states).view(B, -1, H, head_dim).transpose(1, 2) + value_states = self.v_proj(current_states).view(B, -1, H, head_dim).transpose(1, 2) + + # apply RoPE if not using learnable relative PE (only for self-attention, so L_q = L_k) + if not is_cross_attention and position_embeddings is not None and self.enable_rope: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + # update cache with new key and value states + if past_key_values is not None: + assert curr_past_key_value is not None and self.layer_idx is not None + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + if is_cross_attention: # update cross-attention cache flag, as it is static after being set + past_key_values.is_updated[self.layer_idx] = True + + # Compute attention scores + attn_weights: Float[torch.Tensor, "B H L_q L_k"] + attn_weights = query_states @ key_states.transpose(-1, -2) + + # Apply relative position bias if enabled. When enabling RoPE, rel_pos_bias is ignored. + L_k = key_states.size(-2) + if position_bias is None: + if self.rel_pos_bias is not None: + real_L_q = cache_position[-1] + 1 if cache_position is not None else L_q + position_bias = self.rel_pos_bias( + real_L_q, L_k, cache_position=cache_position, device=hidden_states.device + )[:, :, -L_q:, :] + else: + position_bias = torch.zeros((1, H, L_q, L_k), device=hidden_states.device) + + # Apply attention mask, note that the returned position bias will contain the mask if provided + assert position_bias is not None + if attention_mask is not None: # slice for longer attention mask + causal_mask = attention_mask[:, :, :, :L_k] + position_bias = position_bias + causal_mask + + attn_weights = attn_weights + position_bias + + # Compute attention scores and attention output + attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(attn_weights) + attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training) + + softmax_output: Float[torch.Tensor, "B L_q H head_dim"] + softmax_output = (attn_weights @ value_states).transpose(1, 2).contiguous() + + attn_output: Float[torch.Tensor, "B L_q d"] + attn_output = self.o_proj(softmax_output.view(B, L_q, -1)) + + return attn_output, position_bias, attn_weights if output_attentions else None diff --git a/src/genrec/models/modules/feedforward.py b/src/genrec/models/modules/feedforward.py index 6d663a0..1fc9b50 100644 --- a/src/genrec/models/modules/feedforward.py +++ b/src/genrec/models/modules/feedforward.py @@ -24,6 +24,8 @@ def __init__( hidden_size: int, intermediate_size: int, ffn_bias: bool = False, + activation: nn.Module = nn.ReLU(), + dropout: float = 0.0, ) -> None: """Initializes FeedForwardNetwork module. @@ -31,6 +33,8 @@ def __init__( hidden_size (int): Dimensionality of the input and output. intermediate_size (int): Dimensionality of the intermediate layer. ffn_bias (bool): Whether to include bias terms in the linear projections. Default is False. + activation (nn.Module): Activation function to use between layers. Default is ReLU. + dropout (float): Dropout rate to apply after the activation. Default is 0.0. """ super().__init__() self.hidden_size = hidden_size @@ -38,7 +42,8 @@ def __init__( self.fc1 = nn.Linear(hidden_size, intermediate_size, bias=ffn_bias) self.fc2 = nn.Linear(intermediate_size, hidden_size, bias=ffn_bias) - self.act_fn = nn.functional.gelu + self.activation = activation + self.dropout = nn.Dropout(dropout) def forward(self, x: Float[torch.Tensor, "... d"]) -> Float[torch.Tensor, "... d"]: """Forward pass for FeedForwardNetwork. @@ -49,7 +54,7 @@ def forward(self, x: Float[torch.Tensor, "... d"]) -> Float[torch.Tensor, "... d Returns: Float[torch.Tensor, "... d"]: Output tensor of shape (..., hidden_size). """ - return self.fc2(self.act_fn(self.fc1(x))) + return self.fc2(self.dropout(self.activation(self.fc1(x)))) class MLP(nn.Module): @@ -101,6 +106,7 @@ def __init__( hidden_size: int, intermediate_size: int, ffn_bias: bool = False, + dropout: float = 0.0, ) -> None: """Initializes SwiGLU module. @@ -108,6 +114,7 @@ def __init__( hidden_size (int): Dimensionality of the input and output. intermediate_size (int): Dimensionality of the intermediate layer. ffn_bias (bool): Whether to include bias terms in the linear projections. Default is False. + dropout (float): Dropout rate to apply after the activation. Default is 0.0. """ super().__init__() self.hidden_size = hidden_size @@ -117,6 +124,7 @@ def __init__( self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=ffn_bias) self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=ffn_bias) self.act_fn = nn.functional.silu + self.dropout = nn.Dropout(dropout) def forward(self, x: Float[torch.Tensor, "... d"]) -> Float[torch.Tensor, "... d"]: """Forward pass for SwiGLU. @@ -127,4 +135,4 @@ def forward(self, x: Float[torch.Tensor, "... d"]) -> Float[torch.Tensor, "... d Returns: Float[torch.Tensor, "... d"]: Output tensor of shape (..., hidden_size). """ - return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return self.down_proj(self.dropout(self.act_fn(self.gate_proj(x)) * self.up_proj(x))) diff --git a/src/genrec/models/modules/layers.py b/src/genrec/models/modules/layers.py index 5a95c8e..2deba30 100644 --- a/src/genrec/models/modules/layers.py +++ b/src/genrec/models/modules/layers.py @@ -7,19 +7,20 @@ from jaxtyping import Bool, Float, Int import torch import torch.nn as nn -import torch.nn.functional as F - -from .attention import MaskedHSTUAttention, MaskedSelfAttentionWithRoPE +from transformers.cache_utils import EncoderDecoderCache from transformers.modeling_layers import GradientCheckpointingLayer + +from .attention import MaskedHSTUAttention, MaskedSelfAttentionWithRoPE, T5Attention from .feedforward import SwiGLU from .layernorm import RMSNorm from .utils import create_attention_mask __all__ = [ "LlamaDecoderLayer", - "SequentialTransductionUnit", "SpringLlamaDecoderLayer", + "SequentialTransductionUnit", "SpringSequentialTransductionUnit", + "T5Block", "spring_attention_weight_spectral_norm", "spring_power_iteration", ] @@ -595,3 +596,238 @@ def forward( attn_weight_sn = None return hidden_states, attn_weights, attn_weight_sn + + +class T5Block(GradientCheckpointingLayer): + """A standard T5 Transformer Block, following `transformers.T5Block`'s implementation. + + Compared to standard T5Block, our attention module provides several options to generalize + and enhance the attention mechanism, including: + - Option to switch the original learnable relative attention bias with Rotary Positional + Embeddings (RoPE) for better extrapolation to longer sequences and improved performance. + - We replace the original LayerNorm with RMSNorm for better training stability. + - We replace the original feed-forward network with SwiGLU for improved model capacity. + + References: + - Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer. JMLR '20. + """ + + def __init__( + self, + hidden_size: int, + head_dim: int, + num_heads: int, + intermediate_size: int, + linear_dropout: float = 0.0, + attention_dropout: float = 0.0, + attention_bias: bool = False, + ffn_bias: bool = False, + is_decoder: bool = False, + has_relative_attention_bias: bool = False, + relative_attention_num_buckets: int = 32, + relative_attention_max_distance: int = 128, + enable_rope: bool = False, + layer_idx: Optional[int] = None, + ) -> None: + """Initializes T5Block module. + + Args: + hidden_size (int): Dimensionality of the model's hidden representations. + head_dim (int): Dimensionality of each attention head. + num_heads (int): Number of attention heads. + intermediate_size (int): Dimensionality of the feed-forward network's intermediate layer. + linear_dropout (float): Dropout rate for the output of attention and feed-forward network. Default is 0.0. + attention_dropout (float): Dropout rate for attention weights. Default is 0.0. + attention_bias (bool): Whether to include bias terms in the attention projections. Default is False. + ffn_bias (bool): Whether to include bias terms in the feed-forward network projections. Default is False. + is_decoder (bool): Whether this attention module is used in the decoder. This is used to determine + the directionality of relative positional embeddings. Default is False. + has_relative_attention_bias (bool): Whether to compute learnable relative positional bias. If False, this + module will not initialize a `T5RelativePositionBias` instance. Typically, T5 set `has_relative_attention_bias` + to True only for the first block, while the rest blocks reuse the same relative positional bias. Note + that when `enable_rope` is True, this argument will be ignored. Default is False. + relative_attention_num_buckets (int): Number of buckets for relative positional embeddings. Default is 32. + relative_attention_max_distance (int): Maximum distance for relative positional embeddings. Default is 128. + enable_rope (bool): Whether to use RoPE instead of learnable relative positional bias. If False, the original + learnable relative positional bias in T5 will be used. Default is False. + layer_idx (Optional[int]): Optional layer index of this attention module in the model. This should be set + when caching past key/values in the decoder for autoregressive generation. Default is None. + """ + super().__init__() + + self.hidden_size = hidden_size + self.head_dim = head_dim + self.num_heads = num_heads + self.intermediate_size = intermediate_size + self.attention_dropout = attention_dropout + self.attention_bias = attention_bias + self.ffn_bias = ffn_bias + self.is_decoder = is_decoder + self.has_relative_attention_bias = has_relative_attention_bias + self.relative_attention_num_buckets = relative_attention_num_buckets + self.relative_attention_max_distance = relative_attention_max_distance + self.enable_rope = enable_rope + self.layer_idx = layer_idx + + self.self_attn = T5Attention( + hidden_size=hidden_size, + head_dim=head_dim, + num_heads=num_heads, + attention_dropout=attention_dropout, + attention_bias=attention_bias, + is_decoder=is_decoder, + has_relative_attention_bias=has_relative_attention_bias, + relative_attention_num_buckets=relative_attention_num_buckets, + relative_attention_max_distance=relative_attention_max_distance, + enable_rope=enable_rope, + layer_idx=layer_idx, + ) + self.self_attn_layernorm = RMSNorm(hidden_size) + self.self_attn_dropout = nn.Dropout(linear_dropout) + + self.cross_attn: Optional[T5Attention] = None + self.cross_attn_layernorm: Optional[RMSNorm] = None + self.cross_attn_dropout: Optional[nn.Dropout] = None + if is_decoder: + # cross attention disables relative positional embeddings or RoPE + self.cross_attn = T5Attention( + hidden_size=hidden_size, + head_dim=head_dim, + num_heads=num_heads, + attention_dropout=attention_dropout, + attention_bias=attention_bias, + is_decoder=is_decoder, + has_relative_attention_bias=False, + enable_rope=False, + layer_idx=layer_idx, + ) + self.cross_attn_layernorm = RMSNorm(hidden_size) + self.cross_attn_dropout = nn.Dropout(linear_dropout) + + self.mlp = SwiGLU( + hidden_size=hidden_size, + intermediate_size=intermediate_size, + ffn_bias=ffn_bias, + dropout=linear_dropout, + ) + self.post_attention_layernorm = RMSNorm(hidden_size) + + def forward( + self, + hidden_states: Float[torch.Tensor, "B L d"], + attention_mask: Optional[Float[torch.Tensor, "B 1 #L L_k"]] = None, + position_bias: Optional[Float[torch.Tensor, "#B H L L"]] = None, + position_embeddings: Optional[ + Tuple[Float[torch.Tensor, "B L head_dim"], Float[torch.Tensor, "B L head_dim"]] + ] = None, + encoder_hidden_states: Optional[Float[torch.Tensor, "B L_enc d"]] = None, + encoder_attention_mask: Optional[Float[torch.Tensor, "B 1 1 L_enc"]] = None, + encoder_decoder_position_bias: Optional[Float[torch.Tensor, "#B H L L_enc"]] = None, + past_key_values: Optional[EncoderDecoderCache] = None, + cache_position: Optional[Int[torch.Tensor, "L"]] = None, + output_attentions: bool = False, + ) -> Tuple[ + Float[torch.Tensor, "B L d"], + Tuple[ + Optional[Float[torch.Tensor, "#B H L L"]], + Optional[Float[torch.Tensor, "B H L L"]], + ], + Tuple[ + Optional[Float[torch.Tensor, "#B H L L_enc"]], + Optional[Float[torch.Tensor, "B H L L_enc"]], + ], + ]: + """Forward pass for T5Block. + + Args: + hidden_states (Float[torch.Tensor, "B L d"]): Input tensor of shape (batch_size, seq_len, hidden_size). + Note that for cross-attention in the decoder, the encoder output should be provided in the argument + `encoder_hidden_states`. + attention_mask (Optional[Float[torch.Tensor, "B 1 #L L_k"]]): Optional attention mask added to the self-attention + scores before softmax, with shape either (batch_size, 1, seq_len, key_len) for causal self-attention in + decoder, or (batch_size, 1, 1, key_len) for self-attention in encoder. Specifically, if the dimension + `L_k` is longer than the actual key length `L` (which can happen during autoregressive generation in the + decoder), the mask will be sliced accordingly. For cross-attention in the decoder, please use + `encoder_attention_mask`. Default is None. + position_bias (Optional[Float[torch.Tensor, "#B H L L"]]): Optional precomputed position bias to be added + to the attention scores. If None, and if `has_relative_attention_bias` is True and `enable_rope` is False, + the position bias will be computed based on the relative positions of the queries and keys from the T5 + learnable relative positional embeddings. Default is None. + position_embeddings (Optional[Tuple[Float[torch.Tensor, "B L head_dim"], Float[torch.Tensor, "B L head_dim"]]]): + Optional tuple of cosine and sine embeddings for RoPE. Note that when `enable_rope` is False, this argument + will be ignored. Default is None. + encoder_hidden_states (Optional[Float[torch.Tensor, "B L_enc d"]]): Optional encoder hidden states for + cross-attention in the decoder, of shape (batch_size, encoder_len, hidden_size). If provided, cross-attention + will be performed using these encoder hidden states as keys and values. Default is None. + encoder_attention_mask (Optional[Float[torch.Tensor, "B 1 1 L_enc"]]): Optional attention mask for + cross-attention in the decoder, of shape (batch_size, 1, 1, encoder_len). Default is None. + encoder_decoder_position_bias (Optional[Float[torch.Tensor, "#B H L L_enc"]]): Optional precomputed + position bias to be added to the cross-attention scores. If None, no positional bias will be used in + cross-attention. In usual T5 implementations, cross-attention does not use relative positional embeddings + or RoPE, and this argument is typically set to None. This argument can be provided manually if needed. + Default is None. + past_key_values (Optional[EncoderDecoderCache]): Optional cache for previously computed key and value states, + used in the decoder for faster autoregressive generation. Default is None. + cache_position (Optional[Int[torch.Tensor, "L"]]): Optional position IDs used to compute relative positions + when caching past key/values in the decoder. If provided, it should contain the absolute positions of the + current query tokens. Default is None. + output_attentions (bool): Whether to return the attention weights. Default is False. + + Returns: + Tuple[ + Float[torch.Tensor, "B L d"], + Tuple[ + Optional[Float[torch.Tensor, "#B H L L"]], + Optional[Float[torch.Tensor, "B H L L"]], + ], + Tuple[ + Optional[Float[torch.Tensor, "#B H L L_enc"]], + Optional[Float[torch.Tensor, "B H L L_enc"]], + ], + ]: A tuple containing: + - output tensor of shape (batch_size, seq_len, hidden_size). + - a tuple of optional self-attention position bias and attention weights. + - a tuple of optional cross-attention position bias and attention weights. + """ + # Self-Attention + residual = hidden_states + hidden_states = self.self_attn_layernorm(hidden_states) + self_attn_outputs: Tuple = self.self_attn( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + position_embeddings=position_embeddings, + past_key_values=past_key_values, + cache_position=cache_position, + output_attentions=output_attentions, + ) + hidden_states = residual + self.self_attn_dropout(self_attn_outputs[0]) + self_attn_outputs = self_attn_outputs[1:] # remove the output hidden states + + # Cross-Attention + do_cross_attention = self.is_decoder and encoder_hidden_states is not None + cross_attn_outputs = (None, None) + if do_cross_attention: + assert self.cross_attn is not None + assert self.cross_attn_layernorm is not None + assert self.cross_attn_dropout is not None + + residual = hidden_states + hidden_states = self.cross_attn_layernorm(hidden_states) + cross_attn_outputs = self.cross_attn( + hidden_states, + attention_mask=encoder_attention_mask, + key_value_states=encoder_hidden_states, + position_bias=encoder_decoder_position_bias, + past_key_values=past_key_values, + output_attentions=output_attentions, + ) + hidden_states = residual + self.cross_attn_dropout(cross_attn_outputs[0]) + cross_attn_outputs = cross_attn_outputs[1:] # remove the output hidden states + + # Feed-Forward Network + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = residual + self.mlp(hidden_states) + + return hidden_states, self_attn_outputs, cross_attn_outputs diff --git a/src/genrec/models/modules/posemb.py b/src/genrec/models/modules/posemb.py index caab231..e94d166 100644 --- a/src/genrec/models/modules/posemb.py +++ b/src/genrec/models/modules/posemb.py @@ -2,6 +2,7 @@ from __future__ import annotations +import math from typing import Optional, Callable, Tuple from jaxtyping import Float, Int @@ -12,6 +13,7 @@ "LearnableInputPositionalEmbedding", "RelativeBucketedTimeAndPositionAttentionBias", "RotaryEmbedding", + "T5RelativePositionBias", "apply_rotary_pos_emb", ] @@ -242,3 +244,119 @@ def apply_rotary_pos_emb( query_rotated = (query * cos) + (rotate_half(query) * sin) key_rotated = (key * cos) + (rotate_half(key) * sin) return query_rotated, key_rotated + + +class T5RelativePositionBias(nn.Module): + """T5 relative position bias module. + + References: + - Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer. JMLR '20. + """ + + def __init__( + self, + num_buckets: int, + max_distance: int, + num_heads: int, + is_bidirectional: bool, + ) -> None: + """Initializes T5RelativePositionBias module. + + Args: + num_buckets (int): Number of relative position buckets. + max_distance (int): Maximum distance for relative positions. + num_heads (int): Number of attention heads. + is_bidirectional (bool): Whether the attention is bidirectional. + """ + super().__init__() + self.num_buckets = num_buckets + self.max_distance = max_distance + self.num_heads = num_heads + self.relative_attention_bias = nn.Embedding(num_buckets, num_heads) + self.is_bidirectional = is_bidirectional + + @staticmethod + def _relative_position_bucket( + relative_position: Int[torch.Tensor, "..."], + bidirectional: bool, + num_buckets: int, + max_distance: int, + ) -> Int[torch.Tensor, "..."]: + """Converts relative position to a bucket number for relative attention. + + Args: + relative_position (Int[torch.Tensor, "..."]): Relative position, which is + defined as the distance from the query position to the key position, + i.e., key_pos - query_pos. + bidirectional (bool): Whether the attention is bidirectional. + num_buckets (int): Number of relative position buckets. + max_distance (int): Maximum distance for relative positions. + + Returns: + Int[torch.Tensor, "..."]: Bucketed relative position tensor, within [0, num_buckets). + """ + relative_buckets = 0 + if bidirectional: # both directions use half of the buckets + num_buckets //= 2 + relative_buckets += (relative_position > 0).to(torch.long) * num_buckets + relative_position = torch.abs(relative_position) + else: # decoder-only attention + relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # the other half of the buckets are for logarithmically bigger bins in positions up to max_distance + # i.e., N + log(relative_position / N) / log(max_distance / N) * N, where N = num_buckets / 2 + # if relative_position > max_distance, just put it in the last bucket + relative_position_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_position_if_large = torch.min( + relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1) + ) + + relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) + return relative_buckets + + def forward( + self, + query_length: int, + key_length: int, + cache_position: Optional[Int[torch.Tensor, "q_len"]] = None, + device: Optional[torch.device] = None, + ) -> Float[torch.Tensor, "1 H q_len k_len"]: + """Computes T5 relative position bias. + + Args: + query_length (int): Length of the query sequence. + key_length (int): Length of the key sequence. + cache_position (Optional[Int[torch.Tensor, "q_len"]]): Optional position IDs used to compute + relative positions when caching past key/values in the decoder. If provided, it should + contain the absolute positions of the current query tokens. Default is None. + device (Optional[torch.device]): Device identifier to perform computations on. + + Returns: + Float[torch.Tensor, "1 H q_len k_len"]: Relative position bias tensor of shape + (1, num_heads, query_length, key_length). If `cache_position` is provided, + `q_len` corresponds to the length of `cache_position`; otherwise, it equals + `query_length`. + """ + device = self.relative_attention_bias.weight.device if device is None else device + if cache_position is None: + query_pos: Int[torch.Tensor, "q_len"] = torch.arange(query_length, dtype=torch.long, device=device) + else: + query_pos = cache_position.to(device) + key_pos: Int[torch.Tensor, "k_len"] = torch.arange(key_length, dtype=torch.long, device=device) + rel_pos: Int[torch.Tensor, "q_len k_len"] = key_pos[None, :] - query_pos[:, None] + rel_pos_buckets: Int[torch.Tensor, "q_len k_len"] = self._relative_position_bucket( + rel_pos, + bidirectional=self.is_bidirectional, + num_buckets=self.num_buckets, + max_distance=self.max_distance, + ) + values: Float[torch.Tensor, "q_len k_len H"] = self.relative_attention_bias(rel_pos_buckets) + return values.permute(2, 0, 1).unsqueeze(0) diff --git a/src/genrec/models/modules/utils.py b/src/genrec/models/modules/utils.py index 6d443dc..2a0bd9d 100644 --- a/src/genrec/models/modules/utils.py +++ b/src/genrec/models/modules/utils.py @@ -18,7 +18,10 @@ def create_attention_mask( is_causal: bool = True, mask_value: Optional[float] = None, dtype: torch.dtype = torch.float32, -) -> Float[torch.Tensor, "B 1 tgt_len seq_len"]: + cache_position: Optional[Int[torch.Tensor, "seq_len"]] = None, + past_key_values_length: int = 0, + kv_seq_len: Optional[int] = None, +) -> Float[torch.Tensor, "B 1 tgt_len key_len"]: """Creates a 4D attention mask from a 2D attention mask. Args: @@ -29,28 +32,65 @@ def create_attention_mask( mask_value (Optional[float]): Value to use for masked positions. If None, uses the minimum representable value for the specified `dtype`. dtype (torch.dtype): Data type of the output mask. Default is `torch.float32`. + cache_position (Optional[Int[torch.Tensor, "seq_len"]]): Absolute positions of the target tokens when + using KV-cache during decoding. If None, positions are inferred from `past_key_values_length`. + past_key_values_length (int): Number of cached tokens already seen by the decoder. Default is 0. + kv_seq_len (Optional[int]): Optional static key length (e.g., when using a compileable cache). If provided, + the attention mask will be expanded or trimmed to this length. Returns: - Float[torch.Tensor, "B 1 tgt_len seq_len"]: Attention mask tensor where masked positions - are set to `mask_value` (default to `-inf`) and unmasked positions are set to 0. + Float[torch.Tensor, "B 1 tgt_len key_len"]: Attention mask tensor where masked positions + are set to `mask_value` (default to `-inf`) and unmasked positions are set to 0. Here + `key_len` equals `max(seq_len, past_key_values_length + tgt_len, kv_seq_len or 0)`. """ + device = attention_mask.device batch_size, seq_len = attention_mask.shape tgt_len = tgt_len if tgt_len is not None else seq_len + required_key_len = past_key_values_length + tgt_len + if kv_seq_len is not None and kv_seq_len > 0: + required_key_len = max(required_key_len, kv_seq_len) + else: + required_key_len = max(required_key_len, seq_len) + + if seq_len < required_key_len: + pad = attention_mask.new_zeros(batch_size, required_key_len - seq_len) + attention_mask = torch.cat([attention_mask, pad], dim=-1) + elif seq_len > required_key_len: + attention_mask = attention_mask[:, :required_key_len] + padding_mask = attention_mask == 0 - padding_mask = padding_mask[:, None, None, :].expand(batch_size, 1, tgt_len, seq_len) + padding_mask = padding_mask[:, None, None, :].expand(batch_size, 1, tgt_len, required_key_len) + combined_mask = padding_mask if is_causal: - causal_mask = torch.tril(torch.ones((tgt_len, seq_len), device=attention_mask.device, dtype=torch.bool)) - causal_mask = causal_mask[None, None, :, :] - combined_mask = padding_mask | ~causal_mask - else: - combined_mask = padding_mask + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, + past_key_values_length + tgt_len, + device=device, + ) + cache_position = cache_position.to(device=device, dtype=torch.long) + if cache_position.dim() == 1: # pragma: no cover - defensive guard + cache_position = cache_position.unsqueeze(0) + if cache_position.dim() != 2: # pragma: no cover - defensive guard + raise ValueError("`cache_position` must be 1D or 2D tensor") + if cache_position.size(-1) != tgt_len: # pragma: no cover - defensive guard + raise ValueError("`cache_position` length must match target length") + if cache_position.size(0) == 1 and batch_size > 1: # pragma: no cover - defensive guard + cache_position = cache_position.expand(batch_size, -1) + elif cache_position.size(0) not in {1, batch_size}: # pragma: no cover - defensive guard + raise ValueError("`cache_position` batch dimension must be 1 or match batch size") + if cache_position.size(0) != batch_size: # pragma: no cover - defensive guard + cache_position = cache_position.expand(batch_size, -1) - float_mask = torch.zeros_like(combined_mask, dtype=dtype, device=attention_mask.device) - if mask_value is not None: - float_mask = float_mask.masked_fill(combined_mask, mask_value) - else: - float_mask = float_mask.masked_fill(combined_mask, torch.finfo(dtype).min) + key_positions = torch.arange(required_key_len, device=device) + causal_mask = cache_position.unsqueeze(-1) >= key_positions # (B, tgt_len, key_len) + causal_mask = causal_mask[:, None, :, :] + combined_mask = combined_mask | ~causal_mask + + float_mask = torch.zeros_like(combined_mask, dtype=dtype, device=device) + fill_value = mask_value if mask_value is not None else torch.finfo(dtype).min + float_mask = float_mask.masked_fill(combined_mask, fill_value) return float_mask diff --git a/src/genrec/trainers/trainer_quantizer/base.py b/src/genrec/trainers/trainer_quantizer/base.py index 8159f65..6dc8ff2 100644 --- a/src/genrec/trainers/trainer_quantizer/base.py +++ b/src/genrec/trainers/trainer_quantizer/base.py @@ -7,7 +7,7 @@ from functools import partial from typing import Any, Callable, Dict, Generic, List, Optional, Sequence, Tuple, Type, TypeVar, Union -from jaxtyping import Float +from jaxtyping import Float, Int import torch import torch.nn as nn from transformers import EvalPrediction, Trainer, TrainerCallback, TrainingArguments @@ -245,7 +245,7 @@ def initialize_codebooks(self) -> None: assert isinstance( self.train_dataset, QuantizerDataset ), "Train dataset must be an instance of QuantizerDataset." - item_embeddings = self.train_dataset.item_embeddings + item_embeddings = self.train_dataset.item_textual_embeddings assert item_embeddings is not None, "Item embeddings are required to initialize codebooks." model.initialize_codebooks(torch.from_numpy(item_embeddings).to(model.device)) @@ -255,7 +255,7 @@ def compute_loss( inputs: dict[str, Union[torch.Tensor, Any]], return_outputs: bool = False, num_items_in_batch: Optional[torch.Tensor] = None, - ) -> Union[Float[torch.Tensor, ""], Tuple[Float[torch.Tensor, ""], Dict[str, torch.Tensor]]]: + ) -> Union[Float[torch.Tensor, ""], Tuple[Float[torch.Tensor, ""], QuantizerOutput]]: """Computes the loss for a batch of inputs. Args: @@ -267,11 +267,10 @@ def compute_loss( valid items in each sequence in the batch (excluding padding). Returns: - Union[Float[torch.Tensor, ""], Tuple[Float[torch.Tensor, ""], Dict[str, torch.Tensor]]]: - Either the scalar loss or a tuple containing the loss and a dictionary with - loss and top-k indices of predicted items when `return_outputs` is True. The loss - combines the model-specific loss (if provided) with the quantizer losses via - `compute_quantizer_loss`. + Union[Float[torch.Tensor, ""], Tuple[Float[torch.Tensor, ""], QuantizerOutput]]: + Either the scalar loss or a tuple containing the loss and the raw `QuantizerOutput` + when `return_outputs` is True. The loss combines the model-specific loss (if provided) + with the quantizer losses via `compute_quantizer_loss`. """ model = model.module if hasattr(model, "module") else model # type: ignore - for distributed training assert isinstance(model, QuantizerModel), "Model must be an instance of QuantizerModel." @@ -291,24 +290,8 @@ def compute_loss( loss = quantizer_loss if return_outputs: - assert outputs.semantic_ids is not None, "Semantic IDs must be available in outputs." - assert outputs.reconstruction_loss is not None, "Reconstruction loss should be available in outputs." - assert outputs.codebook_loss is not None, "Codebook loss should be available in outputs." - assert outputs.commitment_loss is not None, "Commitment loss should be available in outputs." - assert "item_id" in inputs and isinstance( - inputs["item_id"], torch.Tensor - ), "Input batch must contain 'item_id' tensor." - output_dict: Dict[str, torch.Tensor] = { - "loss": loss, - "semantic_ids": outputs.semantic_ids, - "reconstruction_loss": outputs.reconstruction_loss, - "codebook_loss": outputs.codebook_loss, - "commitment_loss": outputs.commitment_loss, - "item_id": inputs["item_id"], - } - return loss, output_dict - else: - return loss + return loss, outputs + return loss @abstractmethod def compute_quantizer_loss( # pragma: no cover - abstract method @@ -333,3 +316,70 @@ def compute_quantizer_loss( # pragma: no cover - abstract method Float[torch.Tensor, ""]: Scalar tensor representing the computed quantizer loss. """ pass + + def _build_prediction_outputs( + self, + outputs: QuantizerOutput, + inputs: dict[str, Union[torch.Tensor, Any]], + ) -> tuple[torch.Tensor, ...]: + """Package tensors required by quantizer metrics into a tuple.""" + + assert outputs.semantic_ids is not None, "Semantic IDs must be available in outputs." + assert outputs.reconstruction_loss is not None, "Reconstruction loss should be available in outputs." + assert outputs.codebook_loss is not None, "Codebook loss should be available in outputs." + assert outputs.commitment_loss is not None, "Commitment loss should be available in outputs." + item_id = inputs.get("item_id") + assert isinstance(item_id, torch.Tensor), "Input batch must contain 'item_id' tensor." + + semantic_ids: Int[torch.Tensor, "B C"] = outputs.semantic_ids.detach() + reconstruction_loss: Float[torch.Tensor, "B"] = outputs.reconstruction_loss.detach() + codebook_loss: Float[torch.Tensor, "B"] = outputs.codebook_loss.detach() + commitment_loss: Float[torch.Tensor, "B"] = outputs.commitment_loss.detach() + item_id_tensor: Int[torch.Tensor, "B"] = item_id.detach() + + return (semantic_ids, reconstruction_loss, codebook_loss, commitment_loss, item_id_tensor) + + def prediction_step( # type: ignore[override] + self, + model: nn.Module, + inputs: dict[str, Union[torch.Tensor, Any]], + prediction_loss_only: bool, + ignore_keys: Optional[list[str]] = None, + ) -> tuple[Optional[torch.Tensor], Optional[tuple[torch.Tensor, ...]], Optional[torch.Tensor]]: + """Evaluation step that surfaces loss plus the semantic-ID payloads required by the metrics. + + Args: + model (nn.Module): Quantizer under evaluation. May be wrapped by `nn.DataParallel`/`nn.DistributedDataParallel`. + inputs (dict[str, Union[torch.Tensor, Any]]): Batch prepared by :class:`QuantizerCollator`. + prediction_loss_only (bool): Whether to suppress prediction tensors and only output the loss. + ignore_keys (Optional[list[str]]): Present for :class:`~transformers.Trainer` compatibility; unused. + + Returns: + tuple[Optional[torch.Tensor], Optional[tuple[torch.Tensor, ...]], Optional[Int[torch.Tensor, "B"]]]: + `(loss, payload, labels)` where `payload` matches the tuple expected by + :func:`compute_quantizer_metrics` and `labels` is just the detached `item_id` tensor. + """ + + inputs = self._prepare_inputs(inputs) + + label_tensor = inputs[self.label_names[0]] + assert isinstance(label_tensor, torch.Tensor), "Item IDs must be a tensor." + labels: Int[torch.Tensor, "B"] = label_tensor.detach() + + with torch.no_grad(): + num_items_in_batch = self._get_num_items_in_batch([inputs], self.args.device) + loss, outputs = self.compute_loss( + model, + inputs, + return_outputs=True, + num_items_in_batch=num_items_in_batch, # type: ignore - num_items_in_batch can be int + ) + loss = loss.detach().mean() if loss is not None else None + + if prediction_loss_only: # pragma: no cover - prediction loss only + return loss, None, None + + assert isinstance(outputs, QuantizerOutput) + predictions = tuple(t.detach() for t in self._build_prediction_outputs(outputs, inputs)) + + return loss, predictions, labels diff --git a/src/genrec/trainers/trainer_quantizer/utils/evaluations.py b/src/genrec/trainers/trainer_quantizer/utils/evaluations.py index 8163f48..c7c4a9f 100644 --- a/src/genrec/trainers/trainer_quantizer/utils/evaluations.py +++ b/src/genrec/trainers/trainer_quantizer/utils/evaluations.py @@ -32,9 +32,9 @@ def compute_quantizer_metrics( Args: prediction (EvalPrediction): Object containing model predictions and labels. Predictions are - expected to be the dict values from `QuantizerTrainer.compute_loss`'s output dict, - at least including `semantic_ids`, `reconstruction_loss`, `codebook_loss`, `commitment_loss`, - and `item_id` in the first 5 elements. The labels are ignored for quantizer metrics. + expected to be the tuple returned by `QuantizerTrainer.prediction_step`, whose first + five elements are `semantic_ids`, `reconstruction_loss`, `codebook_loss`, + `commitment_loss`, and `item_id`. The labels are ignored for quantizer metrics. train_dataset (QuantizerDataset): Dataset used during training; required for global metrics. codebook_size (int): Size of the codebook used in the quantizer. metrics (Sequence[Tuple[str, Dict[str, Any]]]): Metric specifications, where each tuple @@ -55,7 +55,7 @@ def compute_quantizer_metrics( commitment_loss: Float[np.ndarray, "B"] = prediction.predictions[3] item_id: Int[np.ndarray, "B"] = prediction.predictions[4] else: - raise ValueError("Predictions should be a tuple containing model output dict's values.") + raise ValueError("Predictions should be a tuple matching QuantizerTrainer.prediction_step outputs.") results: Dict[str, float] = { "reconstruction_loss": float(np.mean(reconstruction_loss)), diff --git a/src/genrec/trainers/trainer_seqrec/base.py b/src/genrec/trainers/trainer_seqrec/base.py index f2d1ca4..4bcea96 100644 --- a/src/genrec/trainers/trainer_seqrec/base.py +++ b/src/genrec/trainers/trainer_seqrec/base.py @@ -7,7 +7,7 @@ from functools import partial from typing import Any, Callable, Dict, Generic, List, Optional, Sequence, Tuple, Type, TypeVar, Union -from jaxtyping import Float +from jaxtyping import Float, Int import torch import torch.nn as nn import torch.nn.functional as F @@ -266,7 +266,7 @@ def compute_loss( inputs: dict[str, Union[torch.Tensor, Any]], return_outputs: bool = False, num_items_in_batch: Optional[torch.Tensor] = None, - ) -> Union[Float[torch.Tensor, ""], Tuple[Float[torch.Tensor, ""], Dict[str, torch.Tensor]]]: + ) -> Union[Float[torch.Tensor, ""], Tuple[Float[torch.Tensor, ""], SeqRecOutput]]: """Computes the loss for a batch of inputs. Args: @@ -278,11 +278,10 @@ def compute_loss( valid items in each sequence in the batch (excluding padding). Returns: - Union[Float[torch.Tensor, ""], Tuple[Float[torch.Tensor, ""], Dict[str, torch.Tensor]]]: - Either the scalar loss or a tuple containing the loss and a dictionary with - loss and top-k indices of predicted items when `return_outputs` is True. The loss - combines the model-specific loss (if provided) with the recommendation loss computed - via `compute_rec_loss`. + Union[Float[torch.Tensor, ""], Tuple[Float[torch.Tensor, ""], SeqRecOutput]]: + Either the scalar loss or a tuple containing the loss and the raw `SeqRecOutput` + when `return_outputs` is True. The loss combines the model-specific loss (if provided) + with the recommendation loss computed via `compute_rec_loss`. """ model = model.module if hasattr(model, "module") else model # type: ignore - for distributed training assert isinstance(model, SeqRecModel), "Model must be an instance of SeqRecModel." @@ -300,27 +299,8 @@ def compute_loss( loss = rec_loss if return_outputs: - last_step_hidden_states: Float[torch.Tensor, "B d"] - last_step_hidden_states = outputs.last_hidden_state[:, -1, :] - item_embed_weight: Float[torch.Tensor, "I+1 d"] = model.item_embed_weight - - if self.args.norm_embeddings: - last_step_hidden_states = F.normalize(last_step_hidden_states, p=2, dim=-1) - item_embed_weight = F.normalize(item_embed_weight, p=2, dim=-1) - - logits: Float[torch.Tensor, "B I+1"] - logits = last_step_hidden_states @ item_embed_weight.T - - effective_top_k = max(1, min(self.max_top_k, self.item_size)) - _, topk_indices = torch.topk(logits, k=effective_top_k, dim=-1) # may predict padding index - - output_dict: Dict[str, torch.Tensor] = { - "loss": loss, - "topk_indices": topk_indices.detach(), - } - return loss, output_dict - else: - return loss + return loss, outputs + return loss @abstractmethod def compute_rec_loss( # pragma: no cover - abstract method @@ -345,3 +325,74 @@ def compute_rec_loss( # pragma: no cover - abstract method Float[torch.Tensor, ""]: Computed recommendation loss as a scalar tensor. """ ... + + def _compute_topk_indices( + self, + outputs: SeqRecOutput, + model: SeqRecModel[Any, Any], + ) -> Int[torch.Tensor, "B K"]: + """Compute top-K item predictions from the current model outputs.""" + + last_step_hidden_states: Float[torch.Tensor, "B d"] + last_step_hidden_states = outputs.last_hidden_state[:, -1, :] + item_embed_weight: Float[torch.Tensor, "I+1 d"] = model.item_embed_weight + + if self.args.norm_embeddings: + last_step_hidden_states = F.normalize(last_step_hidden_states, p=2, dim=-1) + item_embed_weight = F.normalize(item_embed_weight, p=2, dim=-1) + + logits: Float[torch.Tensor, "B I+1"] + logits = last_step_hidden_states @ item_embed_weight.T + + effective_top_k = max(1, min(self.max_top_k, self.item_size)) + _, topk_indices = torch.topk(logits, k=effective_top_k, dim=-1) # may predict padding index + + return topk_indices + + def prediction_step( # type: ignore[override] + self, + model: nn.Module, + inputs: dict[str, Union[torch.Tensor, Any]], + prediction_loss_only: bool, + ignore_keys: Optional[list[str]] = None, + ) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: + """Evaluation forward pass that returns mean loss, top-K indices, and detached labels. + + Args: + model (nn.Module): Model under evaluation. May be wrapped for distributed training. + inputs (dict[str, Union[torch.Tensor, Any]]): Batch produced by :class:`SeqRecCollator`. + prediction_loss_only (bool): Whether to skip prediction tensors and only surface the loss. + ignore_keys (Optional[list[str]]): Unused placeholder required by :class:`~transformers.Trainer`. + + Returns: + tuple[Optional[torch.Tensor], Optional[Int[torch.Tensor, "B K"]], Optional[Int[torch.Tensor, "B L"]]]: + `(loss, topk_indices, labels)` where `loss` is the batch-averaged value, `topk_indices` matches + `max_top_k` from the trainer config, and `labels` mirrors the collator's ground-truth tensor. + """ + + inputs = self._prepare_inputs(inputs) + + label_tensor = inputs[self.label_names[0]] + assert isinstance(label_tensor, torch.Tensor), "Labels must be a tensor." + labels: Int[torch.Tensor, "B L"] = label_tensor.detach() + + with torch.no_grad(): + num_items_in_batch = self._get_num_items_in_batch([inputs], self.args.device) + loss, outputs = self.compute_loss( + model, + inputs, + return_outputs=True, + num_items_in_batch=num_items_in_batch, # type: ignore - num_items_in_batch can be int + ) + loss = loss.detach().mean() if loss is not None else None + + if prediction_loss_only: # pragma: no cover - prediction loss only + return loss, None, None + + unwrapped_model = model.module if hasattr(model, "module") else model # type: ignore - distributed + assert isinstance(unwrapped_model, SeqRecModel) + assert isinstance(outputs, SeqRecOutput) + predictions: Int[torch.Tensor, "B K"] + predictions = self._compute_topk_indices(outputs, unwrapped_model).detach() + + return loss, predictions, labels diff --git a/tests/datasets/test_dataset.py b/tests/datasets/test_dataset.py index 9449c77..6409b8e 100644 --- a/tests/datasets/test_dataset.py +++ b/tests/datasets/test_dataset.py @@ -75,6 +75,13 @@ def _make_short_interaction_frame(length: int) -> pd.DataFrame: ) +def _make_sid_cache(item_pool: int, sid_width: int = 3) -> np.ndarray: + cache = np.zeros((item_pool + 1, sid_width), dtype=np.int64) + for item_id in range(1, item_pool + 1): + cache[item_id] = np.arange(item_id, item_id + sid_width, dtype=np.int64) + return cache + + def _expected_batch_sizes(dataset_len: int, batch_size: int) -> list[int]: full_batches, remainder = divmod(dataset_len, batch_size) sizes = [batch_size] * full_batches @@ -83,19 +90,20 @@ def _expected_batch_sizes(dataset_len: int, batch_size: int) -> list[int]: return sizes -def _assert_genrec_batches(loader: DataLoader, expected_sizes: list[int], num_negatives: int) -> None: +def _assert_genrec_batches(loader: DataLoader, expected_sizes: list[int]) -> None: batches = list(loader) assert len(batches) == len(expected_sizes) for idx, batch in enumerate(batches): expected_size = expected_sizes[idx] assert batch["user_id"].shape == (expected_size,) - assert batch["labels"].shape == (expected_size,) + assert batch["labels"].ndim == 2 + sid_width = batch["labels"].shape[1] + assert batch["input_ids"].shape == batch["attention_mask"].shape assert batch["input_ids"].shape[0] == expected_size - assert batch["attention_mask"].shape == batch["input_ids"].shape - assert batch["timestamps"].shape == batch["input_ids"].shape - assert batch["input_ids"].shape[1] <= 4 - assert batch["negative_item_ids"].shape == (expected_size, num_negatives) - assert batch["negative_item_ids"].dtype == torch.int32 + assert batch["input_item_ids"].shape == batch["timestamps"].shape + assert batch["input_item_ids"].shape[0] == expected_size + assert batch["input_ids"].shape[1] == batch["input_item_ids"].shape[1] * sid_width + assert batch["labels"].shape == (expected_size, sid_width) def _assert_seqrec_batches(loader: DataLoader, expected_sizes: list[int], num_negatives: int) -> None: @@ -161,12 +169,7 @@ def textual_frame() -> pd.DataFrame: @pytest.fixture() def sid_cache() -> np.ndarray: - sid_width = 3 - cache = np.zeros((8, sid_width), dtype=np.int64) - # Populate rows 1..7 with unique token patterns. - for item_id in range(1, 8): - cache[item_id] = np.array([item_id, item_id + 10, item_id + 20], dtype=np.int64) - return cache + return _make_sid_cache(item_pool=7, sid_width=3) @pytest.fixture() @@ -244,14 +247,16 @@ def test_genrec_dataset_examples(genrec_dataset, sid_cache, dummy_encoder): assert genrec_dataset.user_size == 3 assert genrec_dataset.item_size == 7 assert genrec_dataset.sid_width == sid_cache.shape[1] - assert genrec_dataset.embedding_dim == dummy_encoder.embedding_dim + assert genrec_dataset.sid_cache.shape == sid_cache.shape + assert genrec_dataset.textual_embedding_dim == dummy_encoder.embedding_dim example = genrec_dataset[0] assert example.user_id in {0, 1, 2} - assert example.input_ids.ndim == 1 - assert example.timestamps.shape == example.input_ids.shape - assert example.input_sid_tokens.shape[1] == sid_cache.shape[1] - assert example.target_sid_tokens.shape == (sid_cache.shape[1],) + assert example.input_item_ids.ndim == 1 + assert example.timestamps.shape == example.input_item_ids.shape + assert example.input_ids.shape == (example.input_item_ids.shape[0] * sid_cache.shape[1],) + assert example.labels.shape == (sid_cache.shape[1],) + assert example.label_item_ids in genrec_dataset.user_positive_items[example.user_id] assert example.input_embeddings.shape[1] == dummy_encoder.embedding_dim assert example.target_embedding.shape[0] == dummy_encoder.embedding_dim @@ -269,55 +274,42 @@ def test_genrec_dataset_examples(genrec_dataset, sid_cache, dummy_encoder): def test_genrec_collator_with_dataloader(genrec_dataset, dummy_encoder): - config = GenRecCollatorConfig(num_negative_samples=2, need_sid_tokens=True, need_embeddings=True) + config = GenRecCollatorConfig(need_embeddings=True) collator = GenRecCollator(genrec_dataset, config=config, seed=123) - max_item_id = genrec_dataset.item_size - - def fake_negative_sampler(history, num_samples, batch_seed=None): - return np.full((history.shape[0], num_samples), fill_value=max_item_id, dtype=np.int32) - - collator._negative_sampler = fake_negative_sampler # type: ignore[assignment] - loader = DataLoader(genrec_dataset, batch_size=2, collate_fn=collator) batch = next(iter(loader)) assert batch["input_ids"].shape[0] == 2 assert batch["attention_mask"].shape == batch["input_ids"].shape - assert batch["timestamps"].shape == batch["input_ids"].shape - assert batch["input_sid_tokens"].shape[2] == genrec_dataset.sid_width - assert batch["target_sid_tokens"].shape == (2, genrec_dataset.sid_width) + assert batch["input_item_ids"].shape == batch["timestamps"].shape + assert batch["input_ids"].shape[1] == batch["input_item_ids"].shape[1] * genrec_dataset.sid_width + assert batch["labels"].shape == (2, genrec_dataset.sid_width) + assert batch["label_item_ids"].shape == (2,) assert batch["input_embeddings"].shape[2] == dummy_encoder.embedding_dim assert batch["target_embedding"].shape == (2, dummy_encoder.embedding_dim) - assert batch["negative_item_ids"].shape == (2, 2) - assert batch["negative_sid_tokens"].shape == (2, 2, genrec_dataset.sid_width) - assert batch["negative_embeddings"].shape == (2, 2, dummy_encoder.embedding_dim) for tensor in batch.values(): assert isinstance(tensor, torch.Tensor) -def test_genrec_collator_without_sid_or_embedding_features(interaction_frame): +def test_genrec_collator_without_sid_or_embedding_features(interaction_frame, sid_cache): dataset = GenRecDataset( interaction_data_path=interaction_frame, split=DatasetSplitLiteral.TRAIN, max_seq_length=3, min_seq_length=1, + sid_cache=sid_cache, ) - config = GenRecCollatorConfig(num_negative_samples=1, need_sid_tokens=False, need_embeddings=False) + config = GenRecCollatorConfig(need_embeddings=False) collator = GenRecCollator(dataset, config=config, seed=11) - def fake_negative_sampler(history, num_samples, batch_seed=None): - return np.full((history.shape[0], num_samples), fill_value=dataset.item_size, dtype=np.int32) - - collator._negative_sampler = fake_negative_sampler # type: ignore[assignment] - loader = DataLoader(dataset, batch_size=2, collate_fn=collator) batch = next(iter(loader)) - assert batch["negative_item_ids"].shape == (2, 1) - assert "negative_sid_tokens" not in batch - assert "negative_embeddings" not in batch + assert "input_embeddings" not in batch + assert "target_embedding" not in batch + assert batch["labels"].shape[1] == sid_cache.shape[1] def test_seqrec_dataset_and_collator(seqrec_dataset): @@ -366,12 +358,15 @@ def test_quantizer_dataset_train_item_popularity_defaults_to_global(quantizer_da (DatasetSplitLiteral.TEST, [1, 2, 3, 4], 5), ], ) -def test_genrec_iter_split_handles_eval_and_test(interaction_frame, split, expected_context, expected_target): +def test_genrec_iter_split_handles_eval_and_test( + interaction_frame, sid_cache, split, expected_context, expected_target +): dataset = GenRecDataset( interaction_data_path=interaction_frame, split=split, max_seq_length=5, min_seq_length=1, + sid_cache=sid_cache, ) items = dataset.user_interactions[0] times = dataset.user_interaction_timestamps[0] @@ -385,6 +380,97 @@ def test_genrec_iter_split_handles_eval_and_test(interaction_frame, split, expec assert context_times.tolist() == times[: len(expected_context)].tolist() +def test_genrec_item_size_prefers_textual_titles_without_encoding(dummy_encoder): + interaction_frame = _make_short_interaction_frame(length=3) + textual_frame = _make_textual_frame(item_pool=6) + sid_cache_short = _make_sid_cache(item_pool=3) + sid_cache_text = _make_sid_cache(item_pool=6) + + dataset_text_only = GenRecDataset( + interaction_data_path=interaction_frame, + split=DatasetSplitLiteral.TRAIN, + max_seq_length=4, + min_seq_length=1, + textual_data_path=textual_frame, + sid_cache=sid_cache_text, + ) + assert dataset_text_only.item_size == 6 + assert dataset_text_only.item_textual_embeddings is None + + dataset_no_textual = GenRecDataset( + interaction_data_path=interaction_frame, + split=DatasetSplitLiteral.TRAIN, + max_seq_length=4, + min_seq_length=1, + sid_cache=sid_cache_short, + ) + assert dataset_no_textual.item_size == 3 + + dataset_with_encoder = GenRecDataset( + interaction_data_path=interaction_frame, + split=DatasetSplitLiteral.TRAIN, + max_seq_length=4, + min_seq_length=1, + textual_data_path=textual_frame, + lm_encoder=dummy_encoder, + sid_cache=sid_cache_text, + ) + assert dataset_with_encoder.item_size == 6 + assert dataset_with_encoder.item_textual_embeddings is not None + assert dataset_with_encoder.textual_embedding_dim == dummy_encoder.embedding_dim + + +def test_genrec_tail_truncation_keeps_recent_history_only(): + frame = _make_short_interaction_frame(length=12) + raw_items = np.array(frame.iloc[0]["ItemID"], dtype=np.int64) + raw_times = np.array(frame.iloc[0]["Timestamp"], dtype=np.int64) + max_seq_length = 4 + sid_cache = _make_sid_cache(item_pool=int(raw_items.max())) + + dataset = GenRecDataset( + interaction_data_path=frame, + split=DatasetSplitLiteral.TRAIN, + max_seq_length=max_seq_length, + min_seq_length=1, + truncation_strategy="tail", + sid_cache=sid_cache, + ) + + truncated_items, truncated_times = dataset._tail_truncate(raw_items, raw_times) + np.testing.assert_array_equal(truncated_items, raw_items[-(max_seq_length + 3) :]) + np.testing.assert_array_equal(truncated_times, raw_times[-(max_seq_length + 3) :]) + assert len(dataset) == max_seq_length + + concatenated_contexts = np.concatenate([example.input_item_ids for example in dataset]) + assert concatenated_contexts.min() == truncated_items[0] + + +def test_genrec_slide_truncation_retains_full_history_for_windows(): + frame = _make_short_interaction_frame(length=12) + raw_items = np.array(frame.iloc[0]["ItemID"], dtype=np.int64) + raw_times = np.array(frame.iloc[0]["Timestamp"], dtype=np.int64) + max_seq_length = 4 + sid_cache = _make_sid_cache(item_pool=int(raw_items.max())) + + dataset = GenRecDataset( + interaction_data_path=frame, + split=DatasetSplitLiteral.TRAIN, + max_seq_length=max_seq_length, + min_seq_length=1, + truncation_strategy="slide", + sid_cache=sid_cache, + ) + + truncated_items, truncated_times = dataset._tail_truncate(raw_items, raw_times) + np.testing.assert_array_equal(truncated_items, raw_items) + np.testing.assert_array_equal(truncated_times, raw_times) + assert len(dataset) == raw_items.shape[0] - 3 + + concatenated_contexts = np.concatenate([example.input_item_ids for example in dataset]) + assert concatenated_contexts.min() == raw_items[0] + assert dataset[0].input_item_ids.tolist() == [int(raw_items[0])] + + @pytest.mark.parametrize("dataset_fixture", ["genrec_dataset", "seqrec_dataset"]) def test_dataset_item_popularity_matches_item_size(request, dataset_fixture): dataset = request.getfixturevalue(dataset_fixture) @@ -414,11 +500,13 @@ def test_quantizer_dataset_and_collator(quantizer_dataset, dummy_encoder): def test_genrec_iter_split_requires_minimum_length_for_validation(): frame = _make_short_interaction_frame(length=2) + sid_cache = _make_sid_cache(item_pool=2) dataset = GenRecDataset( interaction_data_path=frame, split=DatasetSplitLiteral.VALIDATION, max_seq_length=3, min_seq_length=1, + sid_cache=sid_cache, ) items = dataset.user_interactions[0] times = dataset.user_interaction_timestamps[0] @@ -427,11 +515,13 @@ def test_genrec_iter_split_requires_minimum_length_for_validation(): def test_genrec_iter_split_requires_minimum_length_for_test(): frame = _make_short_interaction_frame(length=1) + sid_cache = _make_sid_cache(item_pool=1) dataset = GenRecDataset( interaction_data_path=frame, split=DatasetSplitLiteral.TEST, max_seq_length=3, min_seq_length=1, + sid_cache=sid_cache, ) items = dataset.user_interactions[0] times = dataset.user_interaction_timestamps[0] @@ -441,12 +531,13 @@ def test_genrec_iter_split_requires_minimum_length_for_test(): def test_large_scale_multiworker_uniform_negative_sampler(): num_users = 10032 batch_size = 4096 - num_negatives = 256 seq_len = 4 item_pool = num_users + sid_width = 2 interactions = _make_large_interaction_frame(num_users, seq_len=seq_len, item_pool=item_pool) textual = _make_textual_frame(item_pool) + sid_cache = _make_sid_cache(item_pool=item_pool, sid_width=sid_width) expected_sizes = _expected_batch_sizes(num_users, batch_size) expected_batches = math.ceil(num_users / batch_size) assert len(expected_sizes) == expected_batches @@ -463,19 +554,11 @@ def test_large_scale_multiworker_uniform_negative_sampler(): split=split, max_seq_length=3, min_seq_length=1, + sid_cache=sid_cache, ) assert len(dataset) == num_users assert dataset.split == split - collator = GenRecCollator( - dataset, - config=GenRecCollatorConfig( - num_negative_samples=num_negatives, - need_sid_tokens=False, - need_embeddings=False, - ), - seed=2025, - ) - assert collator._negative_sampler.__class__.__name__ == "UniformNegativeSampler" + collator = GenRecCollator(dataset, config=GenRecCollatorConfig(need_embeddings=False), seed=2025) loader = DataLoader( dataset, batch_size=batch_size, @@ -484,7 +567,7 @@ def test_large_scale_multiworker_uniform_negative_sampler(): shuffle=False, multiprocessing_context="fork", ) - _assert_genrec_batches(loader, expected_sizes, num_negatives) + _assert_genrec_batches(loader, expected_sizes) del loader for split in splits: @@ -498,7 +581,7 @@ def test_large_scale_multiworker_uniform_negative_sampler(): assert dataset.split == split collator = SeqRecCollator( dataset, - config=SeqRecCollatorConfig(num_negative_samples=num_negatives), + config=SeqRecCollatorConfig(num_negative_samples=256), seed=2025, ) assert collator._negative_sampler.__class__.__name__ == "UniformNegativeSampler" @@ -510,7 +593,7 @@ def test_large_scale_multiworker_uniform_negative_sampler(): shuffle=False, multiprocessing_context="fork", ) - _assert_seqrec_batches(loader, expected_sizes, num_negatives) + _assert_seqrec_batches(loader, expected_sizes, 256) del loader encoder = DummyEncoder() diff --git a/tests/models/model_genrec/test_tiger.py b/tests/models/model_genrec/test_tiger.py new file mode 100644 index 0000000..e61b6ae --- /dev/null +++ b/tests/models/model_genrec/test_tiger.py @@ -0,0 +1,384 @@ +from __future__ import annotations + +import numpy as np +import pandas as pd +import pytest +import torch +import torch.nn as nn +from types import MethodType + +from genrec.datasets import DatasetSplitLiteral, GenRecCollator, GenRecDataset +from genrec.models.model_genrec.tiger import TIGERModel, TIGERModelConfig + + +def _make_interaction_frame() -> pd.DataFrame: + return pd.DataFrame( + { + "UserID": np.array([0, 1, 2], dtype=np.int64), + "ItemID": [ + [1, 2, 3, 4, 5], + [2, 3, 4, 5, 6], + [3, 4, 5, 6, 7], + ], + "Timestamp": [ + [100, 101, 102, 103, 104], + [200, 201, 202, 203, 204], + [300, 301, 302, 303, 304], + ], + } + ) + + +def _make_sid_cache(num_items: int, sid_width: int = 3) -> np.ndarray: + cache = np.zeros((num_items + 1, sid_width), dtype=np.int64) + for item_id in range(1, num_items + 1): + base = item_id * sid_width + cache[item_id] = np.arange(base, base + sid_width, dtype=np.int64) + return cache + + +def _build_batch(batch_size: int = 2) -> tuple[dict[str, torch.Tensor], np.ndarray]: + sid_cache = _make_sid_cache(num_items=7, sid_width=3) + dataset = GenRecDataset( + interaction_data_path=_make_interaction_frame(), + split=DatasetSplitLiteral.TRAIN, + max_seq_length=3, + min_seq_length=1, + sid_cache=sid_cache, + ) + collator = GenRecCollator(dataset) + examples = [dataset[idx] for idx in range(batch_size)] + batch = collator(examples) + return batch, sid_cache + + +def _make_config( + vocab_size: int, + *, + enable_rope: bool = False, + tie_word_embeddings: bool = True, +) -> TIGERModelConfig: + return TIGERModelConfig( + hidden_size=32, + num_heads=4, + num_encoder_layers=2, + num_decoder_layers=2, + linear_dropout=0.1, + attention_dropout=0.0, + attention_bias=False, + ffn_bias=False, + enable_rope=enable_rope, + vocab_size=vocab_size, + decoder_start_token_id=0, + pad_token_id=0, + dropout_rate=0.0, + use_cache=True, + tie_word_embeddings=tie_word_embeddings, + ) + + +def test_tiger_model_forward_matches_batch_shapes() -> None: + batch, sid_cache = _build_batch(batch_size=2) + vocab_size = int(np.max(sid_cache) + 32) + model = TIGERModel(_make_config(vocab_size)) + + output = model( + input_ids=batch["input_ids"], + attention_mask=batch["attention_mask"], + labels=batch["labels"], + decoder_attention_mask=torch.ones_like(batch["labels"]), + ) + + batch_size, enc_seq_len = batch["input_ids"].shape + label_width = batch["labels"].shape[1] + + assert output.logits.shape == (batch_size, label_width, vocab_size) + assert output.encoder_last_hidden_state.shape == (batch_size, enc_seq_len, model.config.hidden_size) + assert output.past_key_values is not None + assert output.model_loss is None + + +def test_tiger_model_returns_optional_collections() -> None: + batch, sid_cache = _build_batch(batch_size=2) + vocab_size = int(np.max(sid_cache) + 32) + model = TIGERModel(_make_config(vocab_size)) + + output = model( + input_ids=batch["input_ids"], + attention_mask=batch["attention_mask"], + labels=batch["labels"], + decoder_attention_mask=torch.ones_like(batch["labels"]), + output_hidden_states=True, + output_attentions=True, + ) + + assert output.decoder_hidden_states is not None + assert len(output.decoder_hidden_states) == model.config.num_decoder_layers + 1 + for layer_state in output.decoder_hidden_states: + assert layer_state.shape[-1] == model.config.hidden_size + + assert output.decoder_attentions is not None + assert len(output.decoder_attentions) == model.config.num_decoder_layers + for attn in output.decoder_attentions: + assert attn.shape[1] == model.config.num_heads + + assert output.cross_attentions is not None + assert len(output.cross_attentions) == model.config.num_decoder_layers + + assert output.encoder_hidden_states is not None + assert len(output.encoder_hidden_states) == model.config.num_encoder_layers + 1 + + assert output.encoder_attentions is not None + assert len(output.encoder_attentions) == model.config.num_encoder_layers + + +def test_tiger_model_reuses_past_key_values_for_incremental_generation() -> None: + batch, sid_cache = _build_batch(batch_size=1) + vocab_size = int(np.max(sid_cache) + 32) + model = TIGERModel(_make_config(vocab_size)) + + first_step = model( + input_ids=batch["input_ids"], + attention_mask=batch["attention_mask"], + labels=batch["labels"], + decoder_attention_mask=torch.ones_like(batch["labels"]), + use_cache=True, + ) + + assert first_step.past_key_values is not None + initial_decoder_length = first_step.past_key_values.get_seq_length() + assert initial_decoder_length == batch["labels"].shape[1] + + next_token = batch["labels"][:, -1:] + second_step = model( + input_ids=batch["input_ids"], + attention_mask=batch["attention_mask"], + decoder_input_ids=next_token, + decoder_attention_mask=torch.ones_like(next_token), + past_key_values=first_step.past_key_values, + use_cache=True, + ) + + assert second_step.logits.shape == (1, 1, vocab_size) + assert second_step.past_key_values is not None + assert second_step.past_key_values.get_seq_length() == initial_decoder_length + 1 + + +def test_tiger_model_uses_provided_encoder_outputs() -> None: + batch, sid_cache = _build_batch(batch_size=1) + vocab_size = int(np.max(sid_cache) + 32) + model = TIGERModel(_make_config(vocab_size)) + + with torch.no_grad(): + encoder_outputs = model.encoder( + input_ids=batch["input_ids"], + attention_mask=batch["attention_mask"], + output_hidden_states=True, + output_attentions=True, + ) + + decoder_input_ids = batch["labels"] + decoder_attention_mask = torch.ones_like(decoder_input_ids) + output = model( + input_ids=batch["input_ids"], + attention_mask=batch["attention_mask"], + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + encoder_outputs=encoder_outputs, + ) + + assert output.encoder_last_hidden_state is encoder_outputs.last_hidden_state + assert output.encoder_attentions is encoder_outputs.attentions + assert output.encoder_hidden_states is encoder_outputs.hidden_states + + +def test_tiger_model_returns_model_loss_when_requested() -> None: + batch, sid_cache = _build_batch(batch_size=1) + vocab_size = int(np.max(sid_cache) + 32) + model = TIGERModel(_make_config(vocab_size)) + + output = model( + input_ids=batch["input_ids"], + attention_mask=batch["attention_mask"], + labels=batch["labels"], + decoder_attention_mask=torch.ones_like(batch["labels"]), + output_model_loss=True, + ) + + assert output.model_loss is not None + torch.testing.assert_close(output.model_loss, torch.tensor(0.0, device=output.model_loss.device)) + + +def test_tiger_encoder_stack_rejects_cache_usage() -> None: + batch, sid_cache = _build_batch(batch_size=1) + vocab_size = int(np.max(sid_cache) + 32) + model = TIGERModel(_make_config(vocab_size)) + + with pytest.raises(ValueError): + model.encoder( + input_ids=batch["input_ids"], + attention_mask=batch["attention_mask"], + use_cache=True, + ) + + +def test_tiger_decoder_turns_off_cache_with_gradient_checkpointing() -> None: + batch, sid_cache = _build_batch(batch_size=1) + vocab_size = int(np.max(sid_cache) + 32) + model = TIGERModel(_make_config(vocab_size)) + + encoder_outputs = model.encoder( + input_ids=batch["input_ids"], + attention_mask=batch["attention_mask"], + ) + + decoder = model.decoder + decoder.gradient_checkpointing_enable() + decoder.train() + + outputs = decoder( + input_ids=batch["labels"], + attention_mask=torch.ones_like(batch["labels"]), + encoder_hidden_states=encoder_outputs.last_hidden_state, + encoder_attention_mask=batch["attention_mask"], + use_cache=True, + ) + + assert outputs.past_key_values is None + + +def test_tiger_decoder_builds_attention_mask_when_missing() -> None: + batch, sid_cache = _build_batch(batch_size=1) + vocab_size = int(np.max(sid_cache) + 32) + model = TIGERModel(_make_config(vocab_size)) + + encoder_outputs = model.encoder( + input_ids=batch["input_ids"], + attention_mask=batch["attention_mask"], + ) + + outputs = model.decoder( + input_ids=batch["labels"], + attention_mask=None, + encoder_hidden_states=encoder_outputs.last_hidden_state, + encoder_attention_mask=batch["attention_mask"], + ) + + assert outputs.last_hidden_state.shape[0] == batch["labels"].shape[0] + + +def test_tiger_decoder_passes_cache_position_to_rope() -> None: + batch, sid_cache = _build_batch(batch_size=2) + vocab_size = int(np.max(sid_cache) + 32) + model = TIGERModel(_make_config(vocab_size, enable_rope=True)) + + class RecordingRotary(torch.nn.Module): + def __init__(self, head_dim: int) -> None: + super().__init__() + self.position_ids: torch.Tensor | None = None + self.head_dim = head_dim + + def forward( + self, + hidden_states: torch.Tensor, + position_ids: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + if position_ids is not None: + self.position_ids = position_ids.clone() + sliced = hidden_states[..., : self.head_dim] + return sliced, sliced + + recorder = RecordingRotary(model.config.head_dim) + model.decoder.rotary_emb = recorder + + encoder_outputs = model.encoder( + input_ids=batch["input_ids"], + attention_mask=batch["attention_mask"], + ) + cache_position = torch.arange(batch["labels"].shape[1], device=batch["labels"].device) + + model.decoder( + input_ids=batch["labels"], + attention_mask=torch.ones_like(batch["labels"]), + encoder_hidden_states=encoder_outputs.last_hidden_state, + encoder_attention_mask=batch["attention_mask"], + cache_position=cache_position, + ) + + assert recorder.position_ids is not None + assert recorder.position_ids.shape == (batch["labels"].shape[0], batch["labels"].shape[1]) + expected = cache_position.unsqueeze(0).expand(batch["labels"].shape[0], -1) + assert torch.equal(recorder.position_ids, expected) + + +def test_tiger_model_set_input_embeddings_propagates_to_stacks() -> None: + vocab_size = 64 + model = TIGERModel(_make_config(vocab_size)) + + new_embeddings = nn.Embedding(vocab_size, model.config.hidden_size) + model.set_input_embeddings(new_embeddings) + + assert model.get_input_embeddings() is new_embeddings + assert model.encoder.embed_tokens is new_embeddings + assert model.decoder.embed_tokens is new_embeddings + + +def test_tiger_decoder_supports_inputs_embeds_path() -> None: + batch, sid_cache = _build_batch(batch_size=1) + vocab_size = int(np.max(sid_cache) + 32) + model = TIGERModel(_make_config(vocab_size)) + + encoder_outputs = model.encoder( + input_ids=batch["input_ids"], + attention_mask=batch["attention_mask"], + ) + + decoder_embeds = model.decoder.embed_tokens(batch["labels"]) + outputs = model.decoder( + input_ids=None, + inputs_embeds=decoder_embeds, + attention_mask=torch.ones_like(batch["labels"]), + encoder_hidden_states=encoder_outputs.last_hidden_state, + encoder_attention_mask=batch["attention_mask"], + ) + + assert outputs.last_hidden_state.shape == ( + batch["labels"].shape[0], + batch["labels"].shape[1], + model.config.hidden_size, + ) + + +def test_tiger_decoder_infers_missing_encoder_attention_mask() -> None: + batch, sid_cache = _build_batch(batch_size=1) + vocab_size = int(np.max(sid_cache) + 32) + model = TIGERModel(_make_config(vocab_size)) + + encoder_outputs = model.encoder( + input_ids=batch["input_ids"], + attention_mask=batch["attention_mask"], + ) + + outputs = model.decoder( + input_ids=batch["labels"], + attention_mask=torch.ones_like(batch["labels"]), + encoder_hidden_states=encoder_outputs.last_hidden_state, + encoder_attention_mask=None, + ) + + assert outputs.last_hidden_state.shape[-1] == model.config.hidden_size + + +def test_tiger_model_allows_disabling_tied_embeddings() -> None: + vocab_size = 32 + model = TIGERModel(_make_config(vocab_size, tie_word_embeddings=False)) + + sentinel = [] + + def recorder(self: TIGERModel, module: nn.Module, shared: nn.Module) -> None: # type: ignore[override] + sentinel.append((module, shared)) + + model._tie_or_clone_weights = MethodType(recorder, model) + model._tie_weights() + + assert sentinel == [] diff --git a/tests/models/modules/test_attention.py b/tests/models/modules/test_attention.py index 8ab672c..b33c472 100644 --- a/tests/models/modules/test_attention.py +++ b/tests/models/modules/test_attention.py @@ -1,10 +1,49 @@ +from __future__ import annotations + import torch -from genrec.models.modules.attention import MaskedSelfAttentionWithRoPE +from genrec.models.modules.attention import MaskedSelfAttentionWithRoPE, T5Attention from genrec.models.modules.posemb import RotaryEmbedding from genrec.models.modules.utils import create_attention_mask +class _MockCacheLayer: + """Lightweight container to mimic cached key/value tensors.""" + + def __init__(self) -> None: + self.keys: torch.Tensor | None = None + self.values: torch.Tensor | None = None + + +class _MockCache: + """Minimal cache that supports the DynamicCache interface subset.""" + + def __init__(self, num_layers: int) -> None: + self.layers = [_MockCacheLayer() for _ in range(num_layers)] + + def update(self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int, info: dict | None): + cache_position = info.get("cache_position") if info is not None else None + layer = self.layers[layer_idx] + if layer.keys is None: + layer.keys = key_states + layer.values = value_states + else: + should_append = cache_position is None + if cache_position is not None: + should_append = int(cache_position.min().item()) >= layer.keys.size(-2) + if should_append: + layer.keys = torch.cat([layer.keys, key_states], dim=-2) + layer.values = torch.cat([layer.values, value_states], dim=-2) + return layer.keys, layer.values + + +class _MockEncoderDecoderCache: + def __init__(self, num_layers: int) -> None: + self.is_updated = {idx: False for idx in range(num_layers)} + self.self_attention_cache = _MockCache(num_layers) + self.cross_attention_cache = _MockCache(num_layers) + + def test_masked_self_attention_with_rope_preserves_shapes() -> None: hidden_size = 8 num_heads = 2 @@ -59,3 +98,273 @@ def test_masked_self_attention_without_rope_or_mask_behaves_well() -> None: assert attn_weights.shape == (batch_size, num_heads, seq_len, seq_len) row_sums = attn_weights.sum(dim=-1) torch.testing.assert_close(row_sums, torch.ones_like(row_sums)) + + +def test_t5_attention_with_relative_bias_and_padding_mask() -> None: + hidden_size = 12 + num_heads = 3 + head_dim = hidden_size // num_heads + + attention = T5Attention( + hidden_size=hidden_size, + head_dim=head_dim, + num_heads=num_heads, + attention_dropout=0.0, + attention_bias=True, + is_decoder=False, + has_relative_attention_bias=True, + enable_rope=False, + ) + attention.eval() + + batch_size, seq_len = 2, 5 + hidden_states = torch.randn(batch_size, seq_len, hidden_size) + base_mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 0]], dtype=torch.long) + attention_mask = create_attention_mask(base_mask, is_causal=False) + + attn_output, position_bias, attn_weights = attention( + hidden_states, + attention_mask=attention_mask, + output_attentions=True, + ) + + assert attn_output.shape == (batch_size, seq_len, hidden_size) + assert attn_weights.shape == (batch_size, num_heads, seq_len, seq_len) + assert position_bias is not None + assert torch.count_nonzero(position_bias).item() > 0 + torch.testing.assert_close(attn_weights[..., -1], torch.zeros_like(attn_weights[..., -1])) + + +def test_t5_attention_switches_to_rope_when_bias_disabled() -> None: + torch.manual_seed(42) + hidden_size = 12 + num_heads = 2 + head_dim = hidden_size // num_heads + + rotary = RotaryEmbedding(head_dim=head_dim) + + torch.manual_seed(123) + attention_with_rope = T5Attention( + hidden_size=hidden_size, + head_dim=head_dim, + num_heads=num_heads, + attention_dropout=0.0, + attention_bias=False, + has_relative_attention_bias=False, + enable_rope=True, + ) + attention_with_rope.eval() + + torch.manual_seed(123) + attention_without_rope = T5Attention( + hidden_size=hidden_size, + head_dim=head_dim, + num_heads=num_heads, + attention_dropout=0.0, + attention_bias=False, + has_relative_attention_bias=False, + enable_rope=False, + ) + attention_without_rope.eval() + + batch_size, seq_len = 1, 4 + hidden_states = torch.randn(batch_size, seq_len, hidden_size) + position_embeddings = rotary(torch.zeros(batch_size, seq_len, head_dim)) + + rope_output, rope_bias, _ = attention_with_rope(hidden_states, position_embeddings=position_embeddings) + plain_output, plain_bias, _ = attention_without_rope(hidden_states) + + assert rope_bias is not None and torch.count_nonzero(rope_bias).item() == 0 + assert plain_bias is not None and torch.count_nonzero(plain_bias).item() == 0 + assert rope_output.shape == plain_output.shape + assert not torch.allclose(rope_output, plain_output) + + +def test_t5_attention_cross_attention_cache_reuse() -> None: + hidden_size = 9 + num_heads = 3 + head_dim = hidden_size // num_heads + + attention = T5Attention( + hidden_size=hidden_size, + head_dim=head_dim, + num_heads=num_heads, + attention_dropout=0.0, + attention_bias=False, + is_decoder=True, + layer_idx=0, + ) + attention.eval() + + cache = _MockEncoderDecoderCache(num_layers=1) + + batch_size, query_len, key_len = 1, 2, 4 + hidden_states = torch.randn(batch_size, query_len, hidden_size) + encoder_states = torch.randn(batch_size, key_len, hidden_size) + overlong_mask = torch.ones(batch_size, key_len + 2, dtype=torch.long) + attention_mask = create_attention_mask(overlong_mask, tgt_len=query_len, is_causal=False) + + attn_output, position_bias, attn_weights = attention( + hidden_states, + attention_mask=attention_mask, + key_value_states=encoder_states, + past_key_values=cache, + output_attentions=True, + ) + + assert attn_output.shape == (batch_size, query_len, hidden_size) + assert attn_weights.shape[-1] == key_len + assert position_bias is not None and position_bias.shape[-1] == key_len + assert cache.is_updated[0] + + cached_keys = cache.cross_attention_cache.layers[0].keys + cached_values = cache.cross_attention_cache.layers[0].values + assert cached_keys is not None and cached_keys.shape[-2] == key_len + assert cached_values is not None and cached_values.shape[-2] == key_len + + zero_encoder = torch.zeros_like(encoder_states) + cached_output, _, _ = attention( + hidden_states, + attention_mask=attention_mask, + key_value_states=zero_encoder, + past_key_values=cache, + output_attentions=True, + ) + + torch.testing.assert_close(attn_output, cached_output) + + +def test_t5_attention_self_attention_cache_position_appends_tokens() -> None: + hidden_size = 8 + num_heads = 2 + head_dim = hidden_size // num_heads + + attention = T5Attention( + hidden_size=hidden_size, + head_dim=head_dim, + num_heads=num_heads, + attention_dropout=0.0, + attention_bias=False, + is_decoder=True, + has_relative_attention_bias=True, + enable_rope=False, + layer_idx=0, + ) + attention.eval() + + cache = _MockEncoderDecoderCache(num_layers=1) + + first_hidden = torch.randn(1, 2, hidden_size) + first_mask = create_attention_mask(torch.ones(1, 2, dtype=torch.long), is_causal=True) + first_positions = torch.arange(0, 2) + + _, first_bias, first_weights = attention( + first_hidden, + attention_mask=first_mask, + past_key_values=cache, + cache_position=first_positions, + output_attentions=True, + ) + + assert first_weights.shape[-1] == 2 + assert first_bias is not None and first_bias.shape[-1] == 2 + + second_hidden = torch.randn(1, 1, hidden_size) + # key length becomes 3 after appending cached tokens + decoder_memory_mask = torch.ones(1, 3, dtype=torch.long) + second_mask = create_attention_mask(decoder_memory_mask, tgt_len=1, is_causal=True) + second_positions = torch.tensor([2]) + + _, second_bias, second_weights = attention( + second_hidden, + attention_mask=second_mask, + past_key_values=cache, + cache_position=second_positions, + output_attentions=True, + ) + + cached_layer = cache.self_attention_cache.layers[0] + assert cached_layer.keys is not None and cached_layer.keys.shape[-2] == 3 + assert cached_layer.values is not None and cached_layer.values.shape[-2] == 3 + assert second_weights.shape[-1] == 3 + assert second_bias is not None and second_bias.shape[-1] == 3 + + +def test_t5_attention_zero_bias_branch_and_mask_slicing() -> None: + hidden_size = 8 + num_heads = 2 + head_dim = hidden_size // num_heads + + attention = T5Attention( + hidden_size=hidden_size, + head_dim=head_dim, + num_heads=num_heads, + attention_dropout=0.0, + attention_bias=False, + has_relative_attention_bias=False, + ) + attention.eval() + + batch_size, seq_len = 1, 3 + hidden_states = torch.randn(batch_size, seq_len, hidden_size) + + # create a mask whose key dimension is longer than the actual key length to trigger slicing + base_mask = torch.tensor([[1, 1, 1, 0, 0]], dtype=torch.long) + attention_mask = create_attention_mask(base_mask, tgt_len=seq_len, is_causal=False) + + attn_output, position_bias, attn_weights = attention( + hidden_states, + attention_mask=attention_mask, + output_attentions=True, + ) + + assert attn_output.shape == (batch_size, seq_len, hidden_size) + assert attn_weights.shape == (batch_size, num_heads, seq_len, seq_len) + assert position_bias is not None and position_bias.shape[-1] == seq_len + + expected_mask = attention_mask[:, :, :, :seq_len] + expanded_mask = expected_mask.expand(-1, num_heads, -1, -1) + torch.testing.assert_close(position_bias, expanded_mask) + + +def test_t5_attention_relative_bias_branch_computes_and_adds_bias() -> None: + torch.manual_seed(7) + hidden_size = 16 + num_heads = 4 + head_dim = hidden_size // num_heads + + attention = T5Attention( + hidden_size=hidden_size, + head_dim=head_dim, + num_heads=num_heads, + attention_dropout=0.0, + attention_bias=True, + has_relative_attention_bias=True, + enable_rope=False, + ) + attention.eval() + + batch_size, seq_len = 1, 5 + hidden_states = torch.randn(batch_size, seq_len, hidden_size) + + attn_output_with_bias, position_bias, attn_weights_with_bias = attention( + hidden_states, + output_attentions=True, + ) + + assert attn_output_with_bias.shape == (batch_size, seq_len, hidden_size) + assert attn_weights_with_bias.shape == (batch_size, num_heads, seq_len, seq_len) + assert position_bias is not None + + expected_bias = attention.rel_pos_bias(seq_len, seq_len, device=hidden_states.device) + expected_bias = expected_bias[:, :, -seq_len:, :] + torch.testing.assert_close(position_bias, expected_bias) + + zero_bias = torch.zeros_like(position_bias) + _, _, attn_weights_without_bias = attention( + hidden_states, + position_bias=zero_bias, + output_attentions=True, + ) + + assert not torch.allclose(attn_weights_with_bias, attn_weights_without_bias) diff --git a/tests/models/modules/test_feedforward.py b/tests/models/modules/test_feedforward.py index 2df9228..f1f7630 100644 --- a/tests/models/modules/test_feedforward.py +++ b/tests/models/modules/test_feedforward.py @@ -1,4 +1,5 @@ import torch +import torch.nn as nn from genrec.models.modules.feedforward import FeedForwardNetwork, SwiGLU @@ -38,3 +39,27 @@ def test_swiglu_backward_pass_with_bias() -> None: assert outputs.shape == inputs.shape outputs.mean().backward() assert inputs.grad is not None + + +def test_feedforward_network_respects_custom_activation() -> None: + hidden_size = 2 + intermediate_size = 3 + ffn = FeedForwardNetwork( + hidden_size=hidden_size, + intermediate_size=intermediate_size, + ffn_bias=True, + activation=nn.Identity(), + ) + + with torch.no_grad(): + ffn.fc1.weight.fill_(1.0) + ffn.fc1.bias.zero_() + ffn.fc2.weight.fill_(1.0) + ffn.fc2.bias.zero_() + + inputs = torch.tensor([[[1.0, 2.0]]]) + outputs = ffn(inputs) + + # fc1: sum to 3 along hidden dims, replicated across intermediate_size; fc2: sums three 3's to 9 for each dim + expected = torch.full_like(inputs, 9.0) + torch.testing.assert_close(outputs, expected) diff --git a/tests/models/modules/test_layers.py b/tests/models/modules/test_layers.py index 55c5fa2..909d148 100644 --- a/tests/models/modules/test_layers.py +++ b/tests/models/modules/test_layers.py @@ -1,6 +1,7 @@ import torch +from transformers.cache_utils import DynamicCache, EncoderDecoderCache -from genrec.models.modules.layers import LlamaDecoderLayer, SequentialTransductionUnit +from genrec.models.modules.layers import LlamaDecoderLayer, SequentialTransductionUnit, T5Block from genrec.models.modules.posemb import RotaryEmbedding from genrec.models.modules.utils import create_attention_mask @@ -243,3 +244,177 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: assert isinstance(unit.mlp, RecordingMLP) assert unit.mlp.called + + +def test_t5_block_encoder_supports_masks_bias_and_rope() -> None: + torch.manual_seed(0) + + hidden_size = 16 + num_heads = 4 + head_dim = hidden_size // num_heads + block = T5Block( + hidden_size=hidden_size, + head_dim=head_dim, + num_heads=num_heads, + intermediate_size=hidden_size * 2, + linear_dropout=0.0, + attention_dropout=0.0, + is_decoder=False, + has_relative_attention_bias=False, + enable_rope=True, + ) + + batch_size, seq_len = 2, 5 + hidden_states = torch.randn(batch_size, seq_len, hidden_size) + + padding_mask = torch.ones(batch_size, seq_len, dtype=torch.long) + padding_mask[1, -2:] = 0 + full_mask = create_attention_mask(padding_mask, is_causal=False) + attention_mask = full_mask[:, :, :1, :] + + position_bias = torch.randn(1, num_heads, seq_len, seq_len) + rotary = RotaryEmbedding(head_dim=head_dim) + rope_embeddings = rotary(hidden_states) + + encoded_rope, self_outputs, cross_outputs = block( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + position_embeddings=rope_embeddings, + output_attentions=True, + ) + + assert cross_outputs == (None, None) + assert encoded_rope.shape == (batch_size, seq_len, hidden_size) + + self_position_bias, self_attn_weights = self_outputs + assert self_position_bias is not None + assert self_attn_weights is not None + assert self_position_bias.shape == (batch_size, num_heads, seq_len, seq_len) + assert self_attn_weights.shape == (batch_size, num_heads, seq_len, seq_len) + + mask_threshold = torch.finfo(self_position_bias.dtype).min / 2 + assert torch.all(self_position_bias[1, :, :, -2:] <= mask_threshold) + assert torch.allclose(self_attn_weights[1, :, :, -2:], torch.zeros_like(self_attn_weights[1, :, :, -2:])) + + encoded_no_rope, _, _ = block( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_bias=position_bias.clone(), + position_embeddings=None, + output_attentions=False, + ) + assert not torch.allclose(encoded_rope, encoded_no_rope) + + +def test_t5_block_decoder_cross_attention_and_cache_behaves_autoregressively() -> None: + torch.manual_seed(1) + + hidden_size = 16 + num_heads = 4 + head_dim = hidden_size // num_heads + prefill_len = 3 + encoder_len = 4 + + block = T5Block( + hidden_size=hidden_size, + head_dim=head_dim, + num_heads=num_heads, + intermediate_size=hidden_size * 2, + linear_dropout=0.0, + attention_dropout=0.0, + is_decoder=True, + has_relative_attention_bias=False, + enable_rope=True, + layer_idx=0, + ) + + rotary = RotaryEmbedding(head_dim=head_dim) + + encoder_hidden_states = torch.randn(1, encoder_len, hidden_size) + encoder_token_mask = torch.tensor([[1, 1, 1, 0]], dtype=torch.long) + encoder_attention_mask_full = create_attention_mask(encoder_token_mask, is_causal=False) + encoder_attention_mask = encoder_attention_mask_full[:, :, :1, :] + + encoder_decoder_position_bias = torch.randn(1, num_heads, prefill_len, encoder_len) + + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + + decoder_states = torch.randn(1, prefill_len, hidden_size) + decoder_attention_mask = create_attention_mask(torch.ones(1, prefill_len, dtype=torch.long), is_causal=True) + cache_position = torch.arange(prefill_len, dtype=torch.long) + + decoder_outputs = block( + hidden_states=decoder_states, + attention_mask=decoder_attention_mask, + position_embeddings=rotary(decoder_states), + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + encoder_decoder_position_bias=encoder_decoder_position_bias, + past_key_values=past_key_values, + cache_position=cache_position, + output_attentions=True, + ) + + decoded_prefill, self_prefill, cross_prefill = decoder_outputs + assert decoded_prefill.shape == (1, prefill_len, hidden_size) + + self_bias_prefill, self_attn_prefill = self_prefill + assert self_bias_prefill is not None + assert self_attn_prefill is not None + assert self_bias_prefill.shape == (1, num_heads, prefill_len, prefill_len) + assert self_attn_prefill.shape == (1, num_heads, prefill_len, prefill_len) + + cross_bias_prefill, cross_attn_prefill = cross_prefill + assert cross_bias_prefill is not None + assert cross_attn_prefill is not None + assert cross_bias_prefill.shape == (1, num_heads, prefill_len, encoder_len) + assert cross_attn_prefill.shape == (1, num_heads, prefill_len, encoder_len) + + mask_threshold = torch.finfo(cross_bias_prefill.dtype).min / 2 + assert torch.all(cross_bias_prefill[..., -1] <= mask_threshold) + + self_cache = past_key_values.self_attention_cache + cross_cache = past_key_values.cross_attention_cache + assert self_cache.layers[0].keys.shape == (1, num_heads, prefill_len, head_dim) + assert cross_cache.layers[0].keys.shape == (1, num_heads, encoder_len, head_dim) + cross_keys_before = cross_cache.layers[0].keys + assert past_key_values.is_updated[0] is True + + next_state = torch.randn(1, 1, hidden_size) + next_attention_mask = create_attention_mask( + torch.ones(1, prefill_len + 1, dtype=torch.long), + tgt_len=1, + is_causal=True, + ) + next_bias = torch.randn(1, num_heads, 1, encoder_len) + next_cache_position = torch.tensor([prefill_len], dtype=torch.long) + + decoder_next_outputs = block( + hidden_states=next_state, + attention_mask=next_attention_mask, + position_embeddings=rotary(next_state), + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + encoder_decoder_position_bias=next_bias, + past_key_values=past_key_values, + cache_position=next_cache_position, + output_attentions=True, + ) + + _, self_next, cross_next = decoder_next_outputs + self_bias_next, self_attn_next = self_next + cross_bias_next, cross_attn_next = cross_next + + assert self_bias_next is not None + assert self_attn_next is not None + assert cross_bias_next is not None + assert cross_attn_next is not None + + assert self_bias_next.shape == (1, num_heads, 1, prefill_len + 1) + assert cross_bias_next.shape == (1, num_heads, 1, encoder_len) + assert self_attn_next.shape == (1, num_heads, 1, prefill_len + 1) + assert cross_attn_next.shape == (1, num_heads, 1, encoder_len) + + assert self_cache.layers[0].keys.shape[-2] == prefill_len + 1 + assert cross_cache.layers[0].keys is cross_keys_before diff --git a/tests/models/modules/test_model_utils.py b/tests/models/modules/test_model_utils.py index 13b06bf..25e732d 100644 --- a/tests/models/modules/test_model_utils.py +++ b/tests/models/modules/test_model_utils.py @@ -22,3 +22,52 @@ def test_create_attention_mask_applies_causal_blocking_and_tgt_len() -> None: assert mask[0, 0, 0, 1] == min_value assert mask[0, 0, 1, 3] == min_value assert mask[0, 0, 1, 0] == 0 + + +def test_create_attention_mask_respects_cache_positions_and_past() -> None: + attention_mask = torch.ones(1, 6, dtype=torch.float32) + cache_position = torch.tensor([[3, 4]]) + + mask = create_attention_mask( + attention_mask, + tgt_len=2, + is_causal=True, + cache_position=cache_position, + past_key_values_length=3, + ) + + assert mask.shape == (1, 1, 2, 6) + min_value = torch.finfo(mask.dtype).min + assert mask[0, 0, 0, 4] == min_value # cannot attend to future key + assert mask[0, 0, 0, 2] == 0 # can attend to earlier key + assert mask[0, 0, 1, 4] == 0 # later query can attend to aligned key + + +def test_create_attention_mask_extends_with_kv_seq_len() -> None: + attention_mask = torch.ones(1, 3, dtype=torch.float32) + + mask = create_attention_mask( + attention_mask, + tgt_len=1, + is_causal=True, + kv_seq_len=6, + ) + + assert mask.shape == (1, 1, 1, 6) + min_value = torch.finfo(mask.dtype).min + assert mask[0, 0, 0, 0] == 0 + assert mask[0, 0, 0, 5] == min_value # padded keys beyond seq_len are masked + + +def test_create_attention_mask_trims_when_kv_seq_len_is_smaller() -> None: + attention_mask = torch.tensor([[1, 1, 1, 0, 0]], dtype=torch.float32) + + mask = create_attention_mask( + attention_mask, + tgt_len=1, + is_causal=False, + kv_seq_len=3, + ) + + assert mask.shape == (1, 1, 1, 3) + assert torch.all(mask == 0) diff --git a/tests/models/modules/test_posemb.py b/tests/models/modules/test_posemb.py index 7d3a7f9..d8506bd 100644 --- a/tests/models/modules/test_posemb.py +++ b/tests/models/modules/test_posemb.py @@ -4,6 +4,7 @@ LearnableInputPositionalEmbedding, RelativeBucketedTimeAndPositionAttentionBias, RotaryEmbedding, + T5RelativePositionBias, apply_rotary_pos_emb, ) @@ -117,3 +118,38 @@ def test_relative_bucketed_bias_combines_time_and_position_components() -> None: expected_time = module.time_bias_table.weight[bucketed].squeeze(-1) torch.testing.assert_close(bias[:, 0], expected_pos + expected_time) + + +def test_t5_relative_position_bias_matches_bucket_computation() -> None: + num_heads = 2 + module = T5RelativePositionBias( + num_buckets=8, + max_distance=16, + num_heads=num_heads, + is_bidirectional=False, + ) + + with torch.no_grad(): + weights = torch.arange(module.num_buckets * num_heads, dtype=torch.float32).view(module.num_buckets, num_heads) + module.relative_attention_bias.weight.copy_(weights) + + query_length, key_length = 3, 4 + cache_position = torch.tensor([5, 6, 7], dtype=torch.long) + + bias = module(query_length, key_length, cache_position=cache_position) + assert bias.shape == (1, num_heads, query_length, key_length) + + # replicate bucket logic for expected values + query_pos = cache_position + key_pos = torch.arange(key_length, dtype=torch.long) + rel_pos = key_pos[None, :] - query_pos[:, None] + expected_buckets = module._relative_position_bucket( + rel_pos, + bidirectional=module.is_bidirectional, + num_buckets=module.num_buckets, + max_distance=module.max_distance, + ) + expected_values = module.relative_attention_bias(expected_buckets) + expected_bias = expected_values.permute(2, 0, 1).unsqueeze(0) + + torch.testing.assert_close(bias, expected_bias) diff --git a/tests/trainers/trainer_quantizer/utils/test_base.py b/tests/trainers/trainer_quantizer/test_base.py similarity index 82% rename from tests/trainers/trainer_quantizer/utils/test_base.py rename to tests/trainers/trainer_quantizer/test_base.py index 3192679..0baebc3 100644 --- a/tests/trainers/trainer_quantizer/utils/test_base.py +++ b/tests/trainers/trainer_quantizer/test_base.py @@ -148,7 +148,7 @@ def test_quantizer_trainer_defaults_and_initialize_codebooks(tmp_path) -> None: assert model.initialize_called is True assert model.last_init_embeddings is not None - torch.testing.assert_close(model.last_init_embeddings.cpu(), torch.from_numpy(dataset.item_embeddings)) + torch.testing.assert_close(model.last_init_embeddings.cpu(), torch.from_numpy(dataset.item_textual_embeddings)) def test_quantizer_trainer_compute_loss_with_model_loss(tmp_path) -> None: @@ -179,13 +179,57 @@ def test_quantizer_trainer_compute_loss_with_model_loss(tmp_path) -> None: batch = [dataset[0], dataset[1]] inputs = collator(batch) - loss, output_dict = trainer.compute_loss(trainer.model, inputs, return_outputs=True) + loss, outputs = trainer.compute_loss(trainer.model, inputs, return_outputs=True) expected = torch.tensor(1.0 + 4.0 * args.model_loss_weight) torch.testing.assert_close(loss, expected) - torch.testing.assert_close(output_dict["loss"], expected) - assert output_dict["semantic_ids"].shape == (2, model.config.num_codebooks) - assert output_dict["item_id"].shape == (2,) + assert isinstance(outputs, QuantizerOutput) + assert outputs.semantic_ids is not None + assert outputs.semantic_ids.shape == (2, model.config.num_codebooks) + + +def test_quantizer_trainer_prediction_step_returns_metric_payload(tmp_path) -> None: + dataset = _build_quantizer_dataset() + collator = QuantizerCollator(dataset) + model = DummyQuantizerModel( + QuantizerModelConfig( + embed_dim=4, + hidden_sizes=(3,), + num_codebooks=2, + codebook_size=8, + codebook_dim=2, + ) + ) + args = QuantizerTrainingArguments(output_dir=str(tmp_path)) + + trainer = MinimalQuantizerTrainer( + model=model, + args=args, + data_collator=collator, + train_dataset=dataset, + ) + + batch = [dataset[0], dataset[1]] + inputs = collator(batch) + + loss, predictions, labels = trainer.prediction_step( + trainer.model, + inputs, + prediction_loss_only=False, + ) + + assert loss is not None + assert predictions is not None + assert isinstance(predictions, tuple) + assert len(predictions) == 5 + semantic_ids, reconstruction_loss, codebook_loss, commitment_loss, item_id = predictions + assert semantic_ids.shape == (len(batch), model.config.num_codebooks) + assert reconstruction_loss.shape == (len(batch),) + assert codebook_loss.shape == (len(batch),) + assert commitment_loss.shape == (len(batch),) + assert item_id.shape == (len(batch),) + assert labels is not None + torch.testing.assert_close(item_id, labels) def test_quantizer_trainer_compute_loss_without_model_loss(tmp_path) -> None: diff --git a/tests/trainers/trainer_quantizer/test_evaluations.py b/tests/trainers/trainer_quantizer/utils/test_evaluations.py similarity index 100% rename from tests/trainers/trainer_quantizer/test_evaluations.py rename to tests/trainers/trainer_quantizer/utils/test_evaluations.py diff --git a/tests/trainers/trainer_seqrec/test_base.py b/tests/trainers/trainer_seqrec/test_base.py index f78740b..aa61c2f 100644 --- a/tests/trainers/trainer_seqrec/test_base.py +++ b/tests/trainers/trainer_seqrec/test_base.py @@ -11,7 +11,7 @@ import torch.nn.functional as F from transformers import EvalPrediction -from genrec.models.model_seqrec.base import SeqRecModelConfig +from genrec.models.model_seqrec.base import SeqRecModelConfig, SeqRecOutput from genrec.trainers.trainer_seqrec.base import compute_seqrec_metrics from genrec.trainers.trainer_seqrec.utils.evaluations import clip_top_k from genrec.trainers.trainer_seqrec.utils.callbacks import EpochIntervalEvalCallback @@ -131,7 +131,7 @@ def test_seqrec_trainer_compute_loss_with_model_loss_and_outputs( inputs = collator(batch) num_items_in_batch = torch.tensor([dataset.seq_len, dataset.seq_len], dtype=torch.long) - loss, output_dict = trainer.compute_loss( + loss, outputs = trainer.compute_loss( trainer.model, inputs, return_outputs=True, @@ -140,15 +140,9 @@ def test_seqrec_trainer_compute_loss_with_model_loss_and_outputs( expected_loss = 1.5 + 2.0 * args.model_loss_weight torch.testing.assert_close(loss, torch.tensor(expected_loss)) - torch.testing.assert_close(output_dict["loss"], loss) - assert "topk_indices" in output_dict - assert output_dict["topk_indices"].shape == (len(batch), trainer.max_top_k) + assert isinstance(outputs, SeqRecOutput) forward_outputs = trainer.model(**inputs) - last_hidden = forward_outputs.last_hidden_state[:, -1, :] - item_embed_weight = trainer.model.item_embed_weight - expected_logits = last_hidden @ item_embed_weight.T - _, expected_topk = torch.topk(expected_logits, k=trainer.max_top_k, dim=1) - torch.testing.assert_close(output_dict["topk_indices"], expected_topk) + torch.testing.assert_close(outputs.last_hidden_state, forward_outputs.last_hidden_state) assert torch.equal(trainer.last_seen_num_items, num_items_in_batch) @@ -221,8 +215,39 @@ def test_seqrec_trainer_compute_loss_without_model_loss_returns_outputs( loss, outputs = trainer.compute_loss(trainer.model, inputs, return_outputs=True) torch.testing.assert_close(loss, torch.tensor(0.6)) - assert "topk_indices" in outputs - assert outputs["topk_indices"].shape == (len(batch), trainer.max_top_k) + assert isinstance(outputs, SeqRecOutput) + + +def test_seqrec_trainer_prediction_step_returns_topk_indices(tmp_path: Path) -> None: + args = build_training_args(tmp_path, eval_interval=1) + model = DummySeqRecModel(SeqRecModelConfig(item_size=10, hidden_size=4)) + dataset = DummySeqRecDataset(seq_len=2, num_negatives=1, item_size=model.config.item_size) + trainer = MinimalSeqRecTrainer( + model=model, + args=args, + data_collator=DummySeqRecCollator(), + train_dataset=dataset, + eval_dataset=dataset, + ) + + batch = [dataset[0], dataset[1]] + inputs = trainer.data_collator(batch) + + loss, predictions, labels = trainer.prediction_step( + trainer.model, + inputs, + prediction_loss_only=False, + ) + + assert loss is not None + assert predictions is not None + assert labels is not None + assert predictions.shape == (len(batch), trainer.max_top_k) + forward_outputs = trainer.model(**inputs) + last_hidden = forward_outputs.last_hidden_state[:, -1, :] + logits = last_hidden @ trainer.model.item_embed_weight.T + _, expected_topk = torch.topk(logits, k=trainer.max_top_k, dim=1) + torch.testing.assert_close(predictions, expected_topk) def test_seqrec_trainer_normalizes_logits_when_enabled(tmp_path: Path) -> None: @@ -240,8 +265,12 @@ def test_seqrec_trainer_normalizes_logits_when_enabled(tmp_path: Path) -> None: batch = [dataset[0], dataset[1]] inputs = trainer.data_collator(batch) - loss, output_dict = trainer.compute_loss(trainer.model, inputs, return_outputs=True) - torch.testing.assert_close(loss, torch.zeros((), device=loss.device)) + loss, predictions, _ = trainer.prediction_step( + trainer.model, + inputs, + prediction_loss_only=False, + ) + assert loss is not None forward_outputs = trainer.model(**inputs) last_hidden = forward_outputs.last_hidden_state[:, -1, :] @@ -249,7 +278,8 @@ def test_seqrec_trainer_normalizes_logits_when_enabled(tmp_path: Path) -> None: expected_logits = F.normalize(last_hidden, p=2, dim=-1) @ F.normalize(item_embed_weight, p=2, dim=-1).T _, expected_topk = torch.topk(expected_logits, k=trainer.max_top_k, dim=1) - torch.testing.assert_close(output_dict["topk_indices"], expected_topk) + assert predictions is not None + torch.testing.assert_close(predictions, expected_topk) def test_seqrec_trainer_predict_returns_metrics(tmp_path: Path) -> None: