Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 57 additions & 15 deletions optimum/amd/brevitas/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from datasets import load_dataset
from tqdm import tqdm

from optimum.exporters.onnx.utils import MODEL_TYPES_REQUIRING_POSITION_IDS
from optimum.utils.normalized_config import NormalizedConfigManager
from transformers import AutoConfig

Expand Down Expand Up @@ -40,7 +41,14 @@ def recursive_to_device(tensor_or_iterable: Union[Iterable, torch.Tensor], devic


@torch.no_grad()
def compute_perplexity(model: torch.nn.Module, data: List[Dict], context_length: int, tokenizer: Any, seed: int = 0):
def compute_perplexity(
model: torch.nn.Module,
data: List[Dict],
context_length: int,
tokenizer: Any,
seed: int = 0,
add_bos_token_id: bool = True,
):
random.seed(seed)
np.random.seed(seed)
torch.random.manual_seed(seed)
Expand All @@ -50,10 +58,15 @@ def compute_perplexity(model: torch.nn.Module, data: List[Dict], context_length:
cross_entropy_loss = nn.CrossEntropyLoss()

nlls = []
total_eval_length = 0
for sample in tqdm(data, desc="Computing perplexity..."):
sample_length = sample["input_ids"].shape[1]
for start_index in range(0, sample_length, context_length * 2):
end_index = min(start_index + sample_length, sample_length - 1)
batch_size, sample_length = sample["input_ids"].shape
for start_index in range(0, sample_length, context_length):
end_index = min(start_index + 2 * context_length, sample_length - 1)

eval_length = end_index - start_index + 1 - context_length
if eval_length <= 0:
continue

subsample = {
"input_ids": sample["input_ids"][:, start_index : end_index + 1],
Expand All @@ -64,8 +77,24 @@ def compute_perplexity(model: torch.nn.Module, data: List[Dict], context_length:
if "past_key_values" in sample and isinstance(model, torch.fx.GraphModule):
subsample["past_key_values"] = sample["past_key_values"]

dtype = subsample["input_ids"].dtype
device = subsample["input_ids"].device

# Add BOS token.
subsample["input_ids"][:, 0] = tokenizer.bos_token_id
current_context_length = context_length
if add_bos_token_id and not torch.all(subsample["input_ids"][:, 0] == tokenizer.bos_token_id):
bos_tensor = torch.full((batch_size, 1), fill_value=tokenizer.bos_token_id, dtype=dtype, device=device)
bos_mask_tensor = torch.full((batch_size, 1), fill_value=1, dtype=dtype, device=device)

subsample["input_ids"] = torch.cat((subsample["input_ids"], bos_tensor), dim=-1)
subsample["attention_mask"] = torch.cat((subsample["attention_mask"], bos_mask_tensor), dim=-1)

current_context_length += 1

if "position_ids" in sample:
subsample["position_ids"] = torch.arange(
subsample["input_ids"].shape[1], dtype=dtype, device=device
).expand(batch_size, -1)

use_accelerate = hasattr(model, "hf_device_map")
if not use_accelerate or (use_accelerate and not hasattr(model, "_hf_hook")):
Expand All @@ -80,19 +109,21 @@ def compute_perplexity(model: torch.nn.Module, data: List[Dict], context_length:

lm_logits = model(**subsample)["logits"]

reference_labels = subsample["input_ids"][:, context_length:]
reference_labels = subsample["input_ids"][:, current_context_length:]

shift_logits = lm_logits[:, context_length - 1 : -1]
shift_logits = lm_logits[:, current_context_length - 1 : -1]

# Fuse batch and sequence length dimensions.
reference_labels = reference_labels.view(reference_labels.shape[-1])
shift_logits = shift_logits.view(-1, shift_logits.shape[-1])

loss = cross_entropy_loss(shift_logits, reference_labels)
neg_log_likelihood = loss.float() * eval_length

nlls.append(loss)
total_eval_length += eval_length
nlls.append(neg_log_likelihood)

ppl = torch.exp(torch.stack(nlls).mean())
ppl = torch.exp(torch.stack(nlls).sum() / total_eval_length)

return ppl

Expand Down Expand Up @@ -261,26 +292,37 @@ def get_dataset_for_model(
tokenizer=tokenizer, nsamples=nsamples, seqlen=seqlen, split=split, fuse_sequences=fuse_sequences, seed=seed
)

config = AutoConfig.from_pretrained(model_name_or_path)

# In case the dataset is loaded to be used with an fx.GraphModule, we need to add empty past_key_values inputs in the dataset.
if qconfig.requires_fx_graph():
config = AutoConfig.from_pretrained(model_name_or_path)

normalized_config_class = NormalizedConfigManager.get_normalized_config_class(config.model_type)
normalized_config = normalized_config_class(config)

num_heads = normalized_config.num_attention_heads
head_dim = normalized_config.hidden_size // num_heads
num_kv_heads = (
normalized_config.num_key_value_heads
if hasattr(normalized_config, "num_key_value_heads")
else normalized_config.numattention_heads
)
head_dim = normalized_config.hidden_size // normalized_config.num_attention_heads
num_layers = normalized_config.num_layers

for sample in data:
sample["past_key_values"] = tuple(
(
torch.zeros(1, num_heads, 0, head_dim, device=sample["input_ids"].device),
torch.zeros(1, num_heads, 0, head_dim, device=sample["input_ids"].device),
torch.zeros(1, num_kv_heads, 0, head_dim, device=sample["input_ids"].device),
torch.zeros(1, num_kv_heads, 0, head_dim, device=sample["input_ids"].device),
)
for _ in range(num_layers)
)

if config.model_type in MODEL_TYPES_REQUIRING_POSITION_IDS:
for sample in data:
input_ids = sample["input_ids"]
sample["position_ids"] = torch.arange(
input_ids.shape[1], device=input_ids.device, dtype=input_ids.dtype
).unsqueeze(0)

data = DatasetToDevice(data, device=device)

return data
4 changes: 4 additions & 0 deletions optimum/amd/brevitas/quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from tqdm import tqdm

from optimum.exporters import TasksManager
from optimum.exporters.onnx.utils import MODEL_TYPES_REQUIRING_POSITION_IDS
from optimum.quantization_base import OptimumQuantizer
from transformers.utils.fx import symbolic_trace

Expand Down Expand Up @@ -160,6 +161,9 @@ def quantize(
input_name in forward_signature for input_name in ["input_ids", "attention_mask", "past_key_values"]
):
input_names = ["input_ids", "attention_mask", "past_key_values"]

if self.config.model_type in MODEL_TYPES_REQUIRING_POSITION_IDS:
input_names.append("position_ids")
else:
raise ValueError(
f"Quantization with an FX graph is currently only supported for models taking `input_ids`, `attention_mask` and `past_key_values` as inputs. The model only has the following inputs: {forward_signature}"
Expand Down