diff --git a/tests/test_lenses.py b/tests/test_lenses.py index efbfafc..b4238f2 100644 --- a/tests/test_lenses.py +++ b/tests/test_lenses.py @@ -156,3 +156,74 @@ def test_tuned_lens_generate_smoke(random_small_model: trf.PreTrainedModel): ) assert tokens.shape[-1] <= 11 assert tokens.shape[-1] > 1 + + +# --- Tests for negative indexing --- + + +def test_tuned_lens_negative_index(random_tuned_lens: TunedLens): + """Negative index -1 should return the last translator.""" + last = random_tuned_lens[-1] + explicit_last = random_tuned_lens[len(random_tuned_lens) - 1] + assert last is explicit_last + + +def test_tuned_lens_negative_index_minus_n(random_tuned_lens: TunedLens): + """Negative index -N should return the first translator.""" + n = len(random_tuned_lens) + first = random_tuned_lens[-n] + explicit_first = random_tuned_lens[0] + assert first is explicit_first + + +def test_tuned_lens_index_out_of_range(random_tuned_lens: TunedLens): + """Out-of-range indices should raise IndexError.""" + n = len(random_tuned_lens) + with pytest.raises(IndexError): + random_tuned_lens[n] + with pytest.raises(IndexError): + random_tuned_lens[-(n + 1)] + + +def test_tuned_lens_forward_negative_idx(random_tuned_lens: TunedLens): + """forward() should accept negative layer indices.""" + randn = th.randn(1, 10, 128) + logits_neg = random_tuned_lens.forward(randn, -1) + logits_pos = random_tuned_lens.forward(randn, len(random_tuned_lens) - 1) + assert th.allclose(logits_neg, logits_pos) + + +def test_tuned_lens_transform_hidden_negative_idx(random_tuned_lens: TunedLens): + """transform_hidden() should accept negative layer indices.""" + randn = th.randn(1, 10, 128) + h_neg = random_tuned_lens.transform_hidden(randn, -1) + h_pos = random_tuned_lens.transform_hidden(randn, len(random_tuned_lens) - 1) + assert th.allclose(h_neg, h_pos) + + +# --- Tests for forward_all --- + + +def test_logit_lens_forward_all(logit_lens): + """forward_all() should return one logit tensor per layer.""" + hidden_states = [th.randn(1, 10, 128) for _ in range(3)] + results = logit_lens.forward_all(hidden_states) + assert len(results) == 3 + for r in results: + assert r.shape == (1, 10, 100) + + +def test_tuned_lens_forward_all(random_tuned_lens: TunedLens): + """forward_all() should return one logit tensor per layer.""" + hidden_states = [th.randn(1, 10, 128) for _ in range(3)] + results = random_tuned_lens.forward_all(hidden_states) + assert len(results) == 3 + + +def test_forward_all_matches_sequential(random_tuned_lens: TunedLens): + """forward_all() results should match sequential forward() calls.""" + hidden_states = [th.randn(1, 10, 128) for _ in range(3)] + batch_results = random_tuned_lens.forward_all(hidden_states) + for i, h in enumerate(hidden_states): + single = random_tuned_lens.forward(h, i) + assert th.allclose(batch_results[i], single) diff --git a/tuned_lens/nn/lenses.py b/tuned_lens/nn/lenses.py index 0ef308b..bdb5860 100644 --- a/tuned_lens/nn/lenses.py +++ b/tuned_lens/nn/lenses.py @@ -6,7 +6,7 @@ from copy import deepcopy from dataclasses import asdict, dataclass from pathlib import Path -from typing import Dict, Generator, Optional, Union +from typing import Dict, Generator, Optional, Sequence, Union import torch as th from transformers import PreTrainedModel @@ -47,6 +47,25 @@ def forward(self, h: th.Tensor, idx: int) -> th.Tensor: """Decode hidden states into logits.""" ... + @th.inference_mode() + def forward_all( + self, hidden_states: Sequence[th.Tensor] + ) -> list[th.Tensor]: + """Decode hidden states from all layers into logits. + + Convenience method that applies :meth:`forward` to each layer's + hidden states in a single call under ``torch.inference_mode``. + + Args: + hidden_states: Sequence of hidden state tensors, one per layer. + Each tensor should have shape ``(batch, seq_len, d_model)``. + + Returns: + List of logit tensors, one per layer, each of shape + ``(batch, seq_len, vocab_size)``. + """ + return [self.forward(h, idx=i) for i, h in enumerate(hidden_states)] + class LogitLens(Lens): """Unembeds the residual stream into logits.""" @@ -169,8 +188,22 @@ def __init__( ) def __getitem__(self, item: int) -> th.nn.Module: - """Get the probe module at the given index.""" - return self.layer_translators[item] + """Get the probe module at the given index. + + Supports Python-style negative indexing (e.g., ``-1`` for the last + layer translator). + + Args: + item: Layer index. Negative values count from the end. + + Returns: + The translator module for the requested layer. + + Raises: + IndexError: If the index is out of range. + """ + resolved_idx = self._resolve_idx(item) + return self.layer_translators[resolved_idx] def __iter__(self) -> Generator[th.nn.Module, None, None]: """Get iterator over the translators within the lens.""" @@ -303,15 +336,57 @@ def save( with open(path / config, "w") as f: json.dump(self.config.to_dict(), f) + def _resolve_idx(self, idx: int) -> int: + """Normalize a possibly-negative layer index. + + Args: + idx: Layer index. Negative values count from the end. + + Returns: + A non-negative layer index. + + Raises: + IndexError: If the resolved index is out of range. + """ + num_layers = len(self.layer_translators) + resolved = idx if idx >= 0 else num_layers + idx + if resolved < 0 or resolved >= num_layers: + raise IndexError( + f"Layer index {idx} out of range for lens with " + f"{num_layers} translators." + ) + return resolved + def transform_hidden(self, h: th.Tensor, idx: int) -> th.Tensor: - """Transform hidden state from layer `idx`.""" + """Transform hidden state from layer ``idx``. + + Supports negative indexing (e.g., ``-1`` for the last layer). + + Args: + h: Hidden state tensor of shape ``(batch, seq_len, d_model)``. + idx: Layer index. Negative values count from the end. + + Returns: + Transformed hidden state, same shape as input. + """ # Note that we add the translator output residually, in contrast to the formula # in the paper. By parametrizing it this way we ensure that weight decay # regularizes the transform toward the identity, not the zero transformation. + idx = self._resolve_idx(idx) return h + self[idx](h) def forward(self, h: th.Tensor, idx: int) -> th.Tensor: - """Transform and then decode the hidden states into logits.""" + """Transform and then decode the hidden states into logits. + + Supports negative indexing (e.g., ``-1`` for the last layer). + + Args: + h: Hidden state tensor of shape ``(batch, seq_len, d_model)``. + idx: Layer index. Negative values count from the end. + + Returns: + Logit tensor of shape ``(batch, seq_len, vocab_size)``. + """ h = self.transform_hidden(h, idx) return self.unembed.forward(h)