self_speculation/llama_model_utils.py: 48 changes (27 additions, 21 deletions)
@@ -149,6 +149,10 @@ def crop_past_key_values(
     return past_key_values
 
 
+def _compute_position_embeddings(model, hidden_states, position_ids):
+    return model.model.rotary_emb(hidden_states, position_ids)
+
+
 # Our forward_early(...) and forward_remainder(...) functions currently use the transformers library's legacy KV cache implementation, which is less efficient.
 # To ensure an apples-to-apples comparison, we created this forward function for autoregressive decoding so that it uses the same KV cache implementation.
 # FIXME: update forward_early(...) and forward_remainder(...) to use the newer, more efficient KV cache implementation.
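Note: in newer transformers releases, Llama's rotary embedding module lives on the model itself (model.model.rotary_emb) and returns a (cos, sin) pair that each decoder layer consumes through its position_embeddings keyword, which is why the helper above computes it once per forward pass. A minimal sketch of what such a call produces, assuming an illustrative head_dim of 64 and the conventional base of 10000.0 (neither value is taken from this repo):

import torch

def rotary_cos_sin(position_ids: torch.Tensor, head_dim: int = 64, base: float = 10000.0):
    # one inverse frequency per pair of channels, as in rotary position embeddings
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    # one angle per (position, frequency) combination: B x T x head_dim/2
    freqs = position_ids[..., None].float() * inv_freq
    emb = torch.cat((freqs, freqs), dim=-1)  # B x T x head_dim
    return emb.cos(), emb.sin()

cos, sin = rotary_cos_sin(torch.arange(8).unsqueeze(0))  # one sequence, positions 0..7
print(cos.shape)  # torch.Size([1, 8, 64])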
@@ -189,16 +193,19 @@ def forward(
     )
 
     hidden_states = inputs_embeds
+    position_embeddings = _compute_position_embeddings(model, hidden_states, position_ids)
 
     for decoder_layer in model.model.layers:
-        hidden_states, past_key_values = decoder_layer(
+        layer_outputs = decoder_layer(
             hidden_states,
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_value=past_key_values,
             output_attentions=False,
             use_cache=True,
             padding_mask=None,
+            position_embeddings=position_embeddings,
         )
+        hidden_states = layer_outputs[0]
 
     past_key_values = past_key_values.to_legacy_cache()
     hidden_states = model.model.norm(hidden_states)
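The to_legacy_cache() call above converts the Cache object back into the legacy tuple-of-(key, value)-pairs format that the rest of this file still expects. A hedged sketch of that round-trip using the transformers DynamicCache API (the tensor shapes below are made up for illustration):

import torch
from transformers import DynamicCache

cache = DynamicCache()
key = torch.zeros(1, 8, 4, 64)    # batch x num_heads x seq_len x head_dim
value = torch.zeros(1, 8, 4, 64)
cache.update(key, value, layer_idx=0)  # layers append their new KV states here

legacy = cache.to_legacy_cache()                   # tuple of per-layer (key, value) pairs
restored = DynamicCache.from_legacy_cache(legacy)  # and back again
assert torch.equal(restored.key_cache[0], key)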
@@ -249,20 +256,22 @@ def forward_early(
     )
 
     hidden_states = inputs_embeds
+    position_embeddings = _compute_position_embeddings(model, hidden_states, position_ids)
 
     for decoder_layer in model.model.layers[:exit_layer]:
-        hidden_states, past_key_values = decoder_layer(
+        layer_outputs = decoder_layer(
             hidden_states,
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_value=past_key_values,
             output_attentions=False,
             use_cache=True,
             padding_mask=None,
+            position_embeddings=position_embeddings,
         )
+        hidden_states = layer_outputs[0]
 
     past_key_values = past_key_values.to_legacy_cache()
 
-    # next_cache = next_decoder_cache
     if exit_query_cache is None:
         exit_query_cache = hidden_states
     else:
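The exit_query_cache bookkeeping above saves the hidden states leaving the exit layer so that forward_remainder can later run the remaining layers over the full sequence. A toy sketch of how such a cache grows along the sequence dimension, mirroring the torch.cat pattern used in forward_remainder below (all sizes are hypothetical):

import torch

exit_query_cache = None
# first a 4-token prefill, then a single generated token
for step_hidden in (torch.randn(1, 4, 16), torch.randn(1, 1, 16)):
    if exit_query_cache is None:
        exit_query_cache = step_hidden
    else:
        exit_query_cache = torch.cat([exit_query_cache, step_hidden], dim=1)

print(exit_query_cache.shape)  # torch.Size([1, 5, 16])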
@@ -336,51 +345,48 @@ def forward_remainder(
         full_past_key_values_length,  # we have no past for the full model
     )
 
-    next_decoder_cache = []
     hidden_states = inputs_embeds
     # TODO simplify
     full_hidden_states: Optional[torch.FloatTensor] = None
 
     for idx, decoder_layer in enumerate(model.model.layers):
         is_early_exit = idx < exit_layer
-        past_key_value = (
-            past_key_values[idx]
-            if (past_key_values is not None and idx < len(past_key_values))
-            else None
-        )
-
         if is_early_exit:
             # early hidden states: B x num_gen x C
             early_hidden_states = hidden_states[:, -num_tokens_to_generate:]
             early_position_ids = position_ids[:, -num_tokens_to_generate:]
-            hidden_states, past_key_values = decoder_layer(
+            early_position_embeddings = _compute_position_embeddings(model, early_hidden_states, early_position_ids)
+
+            layer_outputs = decoder_layer(
                 early_hidden_states,
                 attention_mask=early_attention_mask,
                 position_ids=early_position_ids,
                 past_key_value=past_key_values,
                 output_attentions=False,
                 use_cache=True,
                 padding_mask=None,
+                position_embeddings=early_position_embeddings,
             )
+            hidden_states = layer_outputs[0]
         else:
             if full_hidden_states is None and exit_query_cache is not None:
                 # first time seeing the full hidden states, so we rely on the
                 # query cache; only use it if the exit query cache exists,
                 # otherwise this is our first call
                 full_hidden_states = torch.cat(
                     [exit_query_cache, hidden_states[:, -num_tokens_to_generate:]],
                     dim=1,
                 )
             else:
                 # we have already seen the full hidden states, so we can re-use them
                 full_hidden_states = hidden_states
-            hidden_states, past_key_values = decoder_layer(
+
+            full_position_embeddings = _compute_position_embeddings(model, full_hidden_states, position_ids)
+
+            layer_outputs = decoder_layer(
                 full_hidden_states,
                 attention_mask=full_attention_mask,
                 position_ids=position_ids,
                 past_key_value=past_key_values,
                 output_attentions=False,
                 use_cache=True,
                 padding_mask=None,
+                position_embeddings=full_position_embeddings,
             )
+            hidden_states = layer_outputs[0]
 
     past_key_values = past_key_values.to_legacy_cache()
     hidden_states = model.model.norm(hidden_states)
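Since the early-exit branch processes only the newest num_tokens_to_generate positions while the remainder branch sees the whole sequence, the two branches above compute their rotary tables from different position-id slices. A small sketch of that slicing, with hypothetical sizes:

import torch

position_ids = torch.arange(10).unsqueeze(0)  # B=1, a 10-token sequence
num_tokens_to_generate = 3

# the early-exit branch only covers the newest tokens
early_position_ids = position_ids[:, -num_tokens_to_generate:]
print(early_position_ids.tolist())  # [[7, 8, 9]]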