diff --git a/Makefile b/Makefile
index 5e6c4bf2..35774ba8 100644
--- a/Makefile
+++ b/Makefile
@@ -36,3 +36,31 @@ clean:
 	rm -rf build/
 	rm -rf dist/
 	rm -rf optimum_amd.egg-info/
+
+# These targets only run Docker commands; none of them creates a file named after itself.
+.PHONY: build-quark interact interact-quark
+
+# Build the `quark-mht` image from docker/quantization-quark/.
+build-quark:
+	docker build -t quark-mht docker/quantization-quark/
+
+# Open a bash shell in the `tgi-mht:2.5` image with /dev/kfd and /dev/dri passed through,
+# the host Hugging Face hub cache mounted at /data and the working directory at /tgi.
+interact:
+	docker run --rm -it --entrypoint bash \
+	  --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --device=/dev/kfd \
+	  --device=/dev/dri --group-add video --ipc=host --shm-size 64g --net host \
+	  -v /home/amd/.cache/huggingface/hub:/data \
+	  -v $(PWD):/tgi \
+	  tgi-mht:2.5
+
+# Same device and cache setup as `interact`, but in the `quark-mht` image, with the working
+# directory mounted at /quark and the sibling ../transformers checkout mounted at /tr.
+interact-quark:
+	docker run --rm -it --entrypoint bash \
+	  --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --device=/dev/kfd \
+	  --device=/dev/dri --group-add video --ipc=host --shm-size 64g --net host \
+	  -v /home/amd/.cache/huggingface/hub:/data \
+	  -v $(PWD):/quark \
+	  -v $(PWD)/../transformers:/tr \
+	  quark-mht
diff --git a/docker/quantization-quark/Dockerfile b/docker/quantization-quark/Dockerfile
new file mode 100644
index 00000000..eaa2943b
--- /dev/null
+++ b/docker/quantization-quark/Dockerfile
@@ -0,0 +1,42 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Licensed under the MIT License.
+
+FROM rocm/dev-ubuntu-22.04:6.1
+
+LABEL maintainer="Hugging Face"
+
+ARG DEBIAN_FRONTEND=noninteractive
+
+RUN apt-get update && apt-get install -y --no-install-recommends \
+    sudo \
+    python3.10 \
+    python3.10-dev \
+    python3-pip \
+    git \
+    wget \
+    unzip \
+    libsndfile1-dev \
+    tesseract-ocr \
+    espeak-ng \
+    rocthrust-dev \
+    hipsparse-dev \
+    hipblaslt-dev \
+    hipsolver-dev \
+    hipblas-dev && \
+    apt-get clean && \
+    rm -rf /var/lib/apt/lists/* && \
+    update-alternatives --install /usr/bin/python python /usr/bin/python3.10 1 && \
+    python -m pip install -U pip
+
+RUN pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.1 --no-cache-dir
+
+WORKDIR /quark
+RUN wget -O quark-0.2.0-23-py3-none-any.whl "https://www.xilinx.com/bin/public/openDownload?filename=quark-0.2.0+6af1bac23-py3-none-any.whl" && \
+    pip install --no-cache-dir quark-0.2.0-23-py3-none-any.whl && \
+    rm -rf quark-0.2.0-23-py3-none-any.whl
+
+RUN python -c "import quark.torch.kernel"
+
+RUN pip install --no-cache-dir git+https://github.com/mht-sharma/transformers.git@fc62c00e1f2a927acb354e28e43828a47fa776b6
+
+ENTRYPOINT ["bash"]
\ No newline at end of file
diff --git a/optimum/amd/quantizers/__init__.py b/optimum/amd/quantizers/__init__.py
new file mode 100644
index 00000000..d92810ee
--- /dev/null
+++ b/optimum/amd/quantizers/__init__.py
@@ -0,0 +1,4 @@
+from .quark import (
+    AutoQuantizationConfig,
+    QuarkPlugin,
+)
\ No newline at end of file
diff --git a/optimum/amd/quantizers/quark/__init__.py b/optimum/amd/quantizers/quark/__init__.py
new file mode 100644
index 00000000..d85ebafc
--- /dev/null
+++ b/optimum/amd/quantizers/quark/__init__.py
@@ -0,0 +1,2 @@
+from .configuration import AutoQuantizationConfig
+from .quantizer import QuarkPlugin
\ No newline at end of file
diff --git a/optimum/amd/quantizers/quark/algo_config_constants.py b/optimum/amd/quantizers/quark/algo_config_constants.py
new file mode 100644
index 00000000..24283625
--- /dev/null
+++ 
b/optimum/amd/quantizers/quark/algo_config_constants.py
@@ -0,0 +1,198 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+# Licensed under the MIT License.
+
+
+# Per-model-type layer names that AutoQuantizationConfig copies onto Quark algorithm configs.
+ALGO_CONFIG_PARAMS = {
+    "llama": {
+        "scaling_layers": [
+            {
+                "prev_op": "input_layernorm",
+                "layers": ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"],
+                "inp": "self_attn.q_proj",
+                "module2inspect": "self_attn",
+                "has_kwargs": True,
+                "help": "attention input",
+            },
+            {
+                "prev_op": "self_attn.v_proj",
+                "layers": ["self_attn.o_proj"],
+                "inp": "self_attn.o_proj",
+                "module2inspect": None,
+                "has_kwargs": False,
+                "help": "attention out, Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696, if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape",
+                "condition": "module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape",
+            },
+            {
+                "prev_op": "post_attention_layernorm",
+                "layers": ["mlp.gate_proj", "mlp.up_proj"],
+                "inp": "mlp.gate_proj",
+                "module2inspect": "mlp",
+                "has_kwargs": False,
+                "help": "linear 1",
+            },
+            {
+                "prev_op": "mlp.up_proj",
+                "layers": ["mlp.down_proj"],
+                "inp": "mlp.down_proj",
+                "module2inspect": None,
+                "has_kwargs": False,
+                "help": "linear 2",
+            },
+        ],
+        "inside_layer_modules": [
+            "self_attn.k_proj",
+            "self_attn.v_proj",
+            "self_attn.q_proj",
+            "self_attn.o_proj",
+            "mlp.up_proj",
+            "mlp.gate_proj",
+            "mlp.down_proj",
+        ],
+        "model_decoder_layers": "model.layers",
+        "embedding_layers": ["model.embed_tokens"],
+    },
+    "mistral": {
+        "scaling_layers": [
+            {
+                "prev_op": "input_layernorm",
+                "layers": ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"],
+                "inp": "self_attn.q_proj",
+                "module2inspect": "self_attn",
+                "has_kwargs": True,
+                "help": "attention input",
+            },
+            {
+                "prev_op": "self_attn.v_proj",
+                "layers": ["self_attn.o_proj"],
+                "inp": "self_attn.o_proj",
+                "module2inspect": None,
+                "has_kwargs": False,
+                "help": "attention out, Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696, if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape",
+                "condition": "module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape",
+            },
+            {
+                "prev_op": "post_attention_layernorm",
+                "layers": ["mlp.gate_proj", "mlp.up_proj"],
+                "inp": "mlp.gate_proj",
+                "module2inspect": "mlp",
+                "has_kwargs": False,
+                "help": "linear 1",
+            },
+            {
+                "prev_op": "mlp.up_proj",
+                "layers": ["mlp.down_proj"],
+                "inp": "mlp.down_proj",
+                "module2inspect": None,
+                "has_kwargs": False,
+                "help": "linear 2",
+            },
+        ],
+        "inside_layer_modules": [
+            "self_attn.k_proj",
+            "self_attn.v_proj",
+            "self_attn.q_proj",
+            "self_attn.o_proj",
+            "mlp.up_proj",
+            "mlp.gate_proj",
+            "mlp.down_proj",
+        ],
+        "model_decoder_layers": "model.layers",
+        "embedding_layers": ["model.embed_tokens"],
+    },
+    "opt": {
+        "scaling_layers": [
+            {
+                "prev_op": "self_attn_layer_norm",
+                "layers": ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"],
+                "inp": "self_attn.q_proj",
+                "module2inspect": "self_attn",
+                "has_kwargs": True,
+                "help": "attention input",
+            },
+            {
+                "prev_op": "self_attn.v_proj",
+                "layers": ["self_attn.out_proj"],
+                "inp": "self_attn.out_proj",
+                "module2inspect": None,
+                "has_kwargs": False,
+                "help": "attention out",
+            },
+            {
+                "prev_op": "final_layer_norm",
+                "layers": ["fc1"],
+                "inp": "fc1",
+                "module2inspect": None,
+                "has_kwargs": False,
+                "help": "linear 1",
+            },
+            {
+                "prev_op": "fc1",
+                "layers": ["fc2"],
+                "inp": "fc2",
+                "module2inspect": None,
+                "has_kwargs": False,
+                "help": "linear 2",
+            },
+        ],
+        "inside_layer_modules": [
+            "self_attn.k_proj",
+            "self_attn.v_proj",
+            "self_attn.q_proj",
+            "self_attn.out_proj",
+            "fc1",
+            "fc2",
+        ],
+        "model_decoder_layers": "model.decoder.layers",
+        "embedding_layers": ["model.decoder.embed_tokens", "model.decoder.embed_positions"],
+    },
+    "qwen2": {
+        "scaling_layers": [
+            {
+                "prev_op": "input_layernorm",
+                "layers": ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"],
+                "inp": "self_attn.q_proj",
+                "module2inspect": "self_attn",
+                "has_kwargs": True,
+                "help": "attention input",
+            },
+            {
+                "prev_op": "self_attn.v_proj",
+                "layers": ["self_attn.o_proj"],
+                "inp": "self_attn.o_proj",
+                "module2inspect": None,
+                "has_kwargs": False,
+                "help": "attention out, Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696, if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape",
+                # Same guard as the llama/mistral entries; the help text above already describes it.
+                "condition": "module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape",
+            },
+            {
+                "prev_op": "post_attention_layernorm",
+                "layers": ["mlp.gate_proj", "mlp.up_proj"],
+                "inp": "mlp.gate_proj",
+                "module2inspect": "mlp",
+                "has_kwargs": False,
+                "help": "linear 1",
+            },
+            {
+                "prev_op": "mlp.up_proj",
+                "layers": ["mlp.down_proj"],
+                "inp": "mlp.down_proj",
+                "module2inspect": None,
+                "has_kwargs": False,
+                "help": "linear 2",
+            },
+        ],
+        "inside_layer_modules": [
+            "self_attn.k_proj",
+            "self_attn.v_proj",
+            "self_attn.q_proj",
+            "self_attn.o_proj",
+            "mlp.up_proj",
+            "mlp.gate_proj",
+            "mlp.down_proj",
+        ],
+        "model_decoder_layers": "model.layers",
+        "embedding_layers": ["model.embed_tokens"],
+    },
+}
diff --git a/optimum/amd/quantizers/quark/configuration.py b/optimum/amd/quantizers/quark/configuration.py
new file mode 100644
index 00000000..4e228ac2
--- /dev/null
+++ b/optimum/amd/quantizers/quark/configuration.py
@@ -0,0 +1,339 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+# Licensed under the MIT License.
+"""Configuration classes for quantization with AMD Quark."""
+
+from enum import Enum
+from typing import Dict, List, Optional, Union
+
+from quark.torch.quantization.config.config import (
+    AlgoConfig,
+    AWQConfig,
+    Config,
+    GPTQConfig,
+    PreQuantOptConfig,
+    QuantizationConfig,
+    QuantizationSpec,
+    SmoothQuantConfig,
+)
+from quark.torch.quantization.config.type import QuantizationMode
+
+from .algo_config_constants import ALGO_CONFIG_PARAMS
+from .custom_configs import (
+    FLOAT16_CONFIG,
+    FP8_PER_TENSOR_SPEC,
+    W_FP8_A_FP8_OFP8_PER_TENSOR_CONFIG,
+    W_FP8_A_FP8_PER_TENSOR_CONFIG,
+    W_INT4_PER_CHANNEL_CONFIG,
+    W_INT4_PER_GROUP_SYM_CONFIG,
+    W_INT4_PER_TENSOR_CONFIG,
+    W_INT8_A_INT8_PER_TENSOR_CONFIG,
+    W_INT8_A_INT8_PER_TENSOR_DYNAMIC_CONFIG,
+    W_INT8_PER_GROUP_CONFIG,
+    W_INT8_PER_TENSOR_CONFIG,
+    W_MX_FP8_A_MX_FP8_CONFIG,
+    W_MX_FP8_CONFIG,
+    W_UINT4_A_BFLOAT16_PER_GROUP_CONFIG,
+    W_UINT4_PER_GROUP_CONFIG,
+)
+from .quantizer import QuarkConfig
+
+
+# No-op self-assignment; `Config` is re-exported through `__all__` below.
+Config = Config
+
+__all__ = [
+    "Config",
+    "QuantizationSpec",
+    "SmoothQuantConfig",
+    "AWQConfig",
+    "GPTQConfig",
+    "AlgoConfig",
+    "AutoQuantizationConfig",
+    "KVCacheDType",
+]
+
+
+class KVCacheDType(Enum):
+    """Data types accepted for the quantized key/value cache."""
+
+    FP8 = "fp8"
+
+
+class AutoQuantizationConfig:
+    """Factory that builds a `QuarkConfig` from a quantization scheme name."""
+
+    # scheme name -> (global QuantizationConfig template, whether `group_size` is required)
+    QUANT_CONFIG_MAP = {
+        "w_fp8_a_fp8": (W_FP8_A_FP8_PER_TENSOR_CONFIG, False),
+        "w_fp8_a_fp8_o_fp8": (W_FP8_A_FP8_OFP8_PER_TENSOR_CONFIG, False),
+        "w_int4_per_tensor": (W_INT4_PER_TENSOR_CONFIG, False),
+        "w_int4_per_channel_sym": (W_INT4_PER_CHANNEL_CONFIG, False),
+        "w_int4_per_group_sym": (W_INT4_PER_GROUP_SYM_CONFIG, True),
+        "w_uint4_per_group_asym": (W_UINT4_PER_GROUP_CONFIG, True),
+        "w_uint4_a_bfloat16_per_group_asym": (W_UINT4_A_BFLOAT16_PER_GROUP_CONFIG, True),
+        "w_int8_per_tensor_sym": (W_INT8_PER_TENSOR_CONFIG, False),
+        "w_int8_per_group_sym": (W_INT8_PER_GROUP_CONFIG, True),
+        "w_int8_a_int8_per_tensor_sym": (W_INT8_A_INT8_PER_TENSOR_CONFIG, False),
+        "w_int8_a_int8_per_tensor_sym_dynamic": (W_INT8_A_INT8_PER_TENSOR_DYNAMIC_CONFIG, False),
+        "w_mx_fp8": (W_MX_FP8_CONFIG, False),
+        "w_mx_fp8_a_mx_fp8": (W_MX_FP8_A_MX_FP8_CONFIG, False),
+        "float16": (FLOAT16_CONFIG, False),
+    }
+
+    # Model types with an entry in ALGO_CONFIG_PARAMS ("llama", "mistral", "opt", "qwen2").
+    SUPPORTED_MODEL_TYPES = list(ALGO_CONFIG_PARAMS)
+
+    @staticmethod
+    def from_quant_scheme(
+        quant_scheme: str,
+        layer_type_quant_config: Optional[Dict[str, QuantizationConfig]] = None,
+        layer_quant_config: Optional[Dict[str, QuantizationConfig]] = None,
+        algo_config: Optional[AlgoConfig] = None,
+        exclude: Optional[List[str]] = None,
+        pre_quant_opt_config: Optional[Union[PreQuantOptConfig, List[PreQuantOptConfig]]] = None,
+        kv_cache_dtype: Optional[Union[KVCacheDType, str]] = None,
+        quant_algo: Optional[str] = None,
+        model_type: Optional[str] = None,
+        group_size: Optional[int] = None,
+        quant_mode: QuantizationMode = QuantizationMode.eager_mode,
+    ) -> QuarkConfig:
+        """
+        Generates a quantization configuration based on the specified quantization scheme.
+
+        Args:
+            quant_scheme (str): The name of the quantization scheme to be used. Available schemes are:
+
+                - `w_fp8_a_fp8`: Quantization with FP8 weights and activations.
+                - `w_fp8_a_fp8_o_fp8`: Quantization with FP8 weights, activations, and outputs.
+                - `w_int4_per_tensor`: Quantization with INT4 weights per-tensor configuration.
+                - `w_int4_per_channel_sym`: Quantization with INT4 weights per-channel symmetric configuration.
+                - `w_int4_per_group_sym`: Quantization with INT4 weights per-group symmetric configuration.
+                - `w_uint4_per_group_asym`: Quantization with UINT4 weights per-group asymmetric configuration.
+                - `w_uint4_a_bfloat16_per_group_asym`: Quantization with UINT4 weights and BFLOAT16 activations per-group asymmetric configuration.
+                - `w_int8_per_tensor_sym`: Quantization with INT8 weights per-tensor symmetric configuration.
+                - `w_int8_per_group_sym`: Quantization with INT8 weights per-group symmetric configuration.
+                - `w_int8_a_int8_per_tensor_sym`: Quantization with INT8 weights and activations per-tensor symmetric configuration.
+                - `w_int8_a_int8_per_tensor_sym_dynamic`: Dynamic quantization with INT8 weights and activations per-tensor symmetric configuration.
+                - `w_mx_fp8`: Quantization with MX-FP8 configuration.
+                - `w_mx_fp8_a_mx_fp8`: Quantization with MX-FP8 weights and activations configuration.
+                - `float16`: Quantization with FLOAT16 configuration.
+            layer_type_quant_config (Optional[Dict[str, QuantizationConfig]], optional):
+                A dictionary mapping from layer types (e.g., `nn.Conv2d`, `nn.Linear`) to their quantization configurations. Defaults to `None` (no per-type overrides).
+            layer_quant_config (Optional[Dict[str, QuantizationConfig]], optional):
+                A dictionary mapping from layer names to their quantization configurations, allowing for per-layer customization. Defaults to `None` (no per-layer overrides).
+            algo_config (Optional[AlgoConfig], optional):
+                Optional configuration for the quantization algorithm (e.g., GPTQ, AWQ). Defaults to `None`.
+            exclude (Optional[List[str]], optional):
+                A list of layer names to exclude from quantization. Defaults to `None` (nothing excluded).
+            pre_quant_opt_config (Optional[Union[PreQuantOptConfig, List[PreQuantOptConfig]]], optional):
+                Optional pre-quantization optimization configurations. Defaults to `None`.
+            kv_cache_dtype (Optional[Union[KVCacheDType, str]], optional):
+                Optional data type for the key-value cache, given as `KVCacheDType.FP8` or the string 'fp8'. Defaults to `None`.
+            quant_algo (Optional[str], optional):
+                The name of the quantization algorithm (e.g., 'awq', 'gptq'). Defaults to `None`.
+            model_type (Optional[str], optional):
+                The type of the model (e.g., 'llama', 'opt'). Required if `quant_algo` is provided. Defaults to `None`.
+            group_size (Optional[int], optional):
+                Group size for the quantization scheme, if required. Defaults to `None`.
+            quant_mode (QuantizationMode, optional):
+                The quantization mode (e.g., EAGER_MODE or POST_TRAINING_MODE). Defaults to `QuantizationMode.eager_mode`.
+
+        Returns:
+            QuarkConfig:
+                The final quantization configuration for the specified quantization scheme, wrapped in a `QuarkConfig`.
+
+        Raises:
+            ValueError:
+                If an invalid `quant_scheme` is provided.
+                If `group_size` is required by the `quant_scheme` but not provided.
+                If both `quant_algo` and `algo_config` are provided.
+                If `quant_algo` is provided but `model_type` is missing.
+        """
+        if quant_scheme not in AutoQuantizationConfig.QUANT_CONFIG_MAP:
+            raise ValueError(f"Invalid quantization scheme: {quant_scheme}")
+
+        global_quant_config, requires_group_size = AutoQuantizationConfig.QUANT_CONFIG_MAP[quant_scheme]
+
+        if requires_group_size:
+            if group_size is None:
+                raise ValueError(f"Quantization scheme '{quant_scheme}' requires 'group_size'.")
+            # NOTE(review): this writes into the module-level template shared by every caller of
+            # this scheme, so a later call with another `group_size` also changes configs returned
+            # earlier. The AWQ/GPTQ compatibility checks in `_get_algo_config` match these template
+            # objects, so copying the template here needs those checks reworked first.
+            global_quant_config.weight.group_size = group_size
+
+        config = AutoQuantizationConfig._get_config(
+            global_quant_config=global_quant_config,
+            layer_type_quant_config=layer_type_quant_config,
+            layer_quant_config=layer_quant_config,
+            algo_config=algo_config,
+            exclude=exclude,
+            pre_quant_opt_config=pre_quant_opt_config,
+            kv_cache_dtype=kv_cache_dtype,
+            quant_algo=quant_algo,
+            model_type=model_type,
+            quant_mode=quant_mode,
+        )
+
+        return QuarkConfig(config)
+
+    @staticmethod
+    def _get_config(
+        global_quant_config: QuantizationConfig,
+        layer_type_quant_config: Optional[Dict[str, QuantizationConfig]] = None,
+        layer_quant_config: Optional[Dict[str, QuantizationConfig]] = None,
+        algo_config: Optional[AlgoConfig] = None,
+        exclude: Optional[List[str]] = None,
+        pre_quant_opt_config: Optional[Union[PreQuantOptConfig, List[PreQuantOptConfig]]] = None,
+        kv_cache_dtype: Optional[Union[KVCacheDType, str]] = None,
+        quant_algo: Optional[str] = None,
+        model_type: Optional[str] = None,
+        quant_mode: QuantizationMode = QuantizationMode.eager_mode,
+    ) -> Config:
+        """
+        Assembles the Quark `Config` from a global quantization config and optional overrides.
+
+        The arguments mirror `from_quant_scheme`, except that the already-resolved
+        `global_quant_config` takes the place of `quant_scheme` and `group_size`.
+
+        Returns:
+            Config: The assembled Quark configuration.
+
+        Raises:
+            ValueError: If both `quant_algo` and `algo_config` are given, if `quant_algo` is given
+                without `model_type`, or if `model_type` is not a supported model type.
+        """
+        if quant_algo and algo_config:
+            raise ValueError("Only one of quant_algo and algo_config can be provided.")
+        if quant_algo and model_type is None:
+            raise ValueError("model_type must be provided when quant_algo is provided.")
+
+        # `None` stands for "empty": build fresh containers so no default object is shared between calls.
+        layer_type_quant_config = {} if layer_type_quant_config is None else layer_type_quant_config
+        layer_quant_config = {} if layer_quant_config is None else layer_quant_config
+        exclude = [] if exclude is None else exclude
+
+        algo_config = algo_config or AutoQuantizationConfig._get_algo_config(
+            quant_algo, global_quant_config, model_type
+        )
+
+        # Apply kv_cache_dtype if necessary
+        if kv_cache_dtype:
+            layer_quant_config = AutoQuantizationConfig._apply_fp8_config(
+                global_quant_config, layer_quant_config, kv_cache_dtype
+            )
+
+        quant_config = Config(
+            global_quant_config=global_quant_config,
+            layer_type_quant_config=layer_type_quant_config,
+            layer_quant_config=layer_quant_config,
+            exclude=exclude,
+            pre_quant_opt_config=pre_quant_opt_config,
+            algo_config=algo_config,
+            quant_mode=quant_mode,
+        )
+
+        # Fill in the model-specific layer names whenever the model type is known. This covers
+        # the `quant_algo` path (where `model_type` is mandatory) as well as a caller-supplied
+        # `algo_config`, and no longer fails when such a config comes without `model_type`.
+        if algo_config and model_type is not None:
+            model_type = AutoQuantizationConfig._validate_model_type(model_type)
+            algo_config_info = ALGO_CONFIG_PARAMS[model_type]
+            target = quant_config.algo_config
+            target.inside_layer_modules = algo_config_info["inside_layer_modules"]
+            target.model_decoder_layers = algo_config_info["model_decoder_layers"]
+            target.embedding_layers = algo_config_info["embedding_layers"]
+            # The table also carries "scaling_layers"; copy them only onto algorithm configs that
+            # expose such a field (presumably AWQ and SmoothQuant -- confirm against Quark).
+            if hasattr(target, "scaling_layers"):
+                target.scaling_layers = algo_config_info["scaling_layers"]
+
+        return quant_config
+
+    @staticmethod
+    def _apply_fp8_config(global_quant_config, layer_quant_config, kv_cache_dtype):
+        """
+        Applies FP8 configuration if kv_cache_dtype is 'fp8'.
+
+        Args:
+            global_quant_config (QuantizationConfig): The global quantization configuration.
+            layer_quant_config (Dict[str, QuantizationConfig]): The layer-specific quantization configuration.
+            kv_cache_dtype (Union[KVCacheDType, str]): The data type for the key-value cache (`KVCacheDType.FP8` or 'fp8').
+
+        Returns:
+            Dict[str, QuantizationConfig]: A new dict with the `*.v_proj` / `*.k_proj` entries added; the input is not modified.
+
+        Raises:
+            ValueError: If an invalid kv_cache_dtype is provided.
+        """
+        # Enum members have no `.lower()`, so unwrap them to their string value first.
+        dtype_name = kv_cache_dtype.value if isinstance(kv_cache_dtype, KVCacheDType) else kv_cache_dtype
+        if not isinstance(dtype_name, str) or dtype_name.lower() != "fp8":
+            raise ValueError(f"Invalid value for kv_cache_dtype: {kv_cache_dtype}. Expected 'fp8'.")
+
+        # Key and value projections keep the global input/weight specs and gain an FP8 output spec.
+        fp8_config = {
+            pattern: QuantizationConfig(
+                input_tensors=global_quant_config.input_tensors,
+                output_tensors=FP8_PER_TENSOR_SPEC,
+                weight=global_quant_config.weight,
+            )
+            for pattern in ("*.v_proj", "*.k_proj")
+        }
+        return {**layer_quant_config, **fp8_config}
+
+    @staticmethod
+    def _get_algo_config(
+        quant_algo: Optional[str], global_quant_config: QuantizationConfig, model_type: Optional[str] = None
+    ) -> Optional[AlgoConfig]:
+        """
+        Retrieves the appropriate algorithm configuration for the quantization algorithm.
+
+        Args:
+            quant_algo (Optional[str]): The name of the quantization algorithm (e.g., 'awq', 'gptq').
+            global_quant_config (QuantizationConfig): The global quantization configuration.
+            model_type (Optional[str], optional): The model type, required if using quant_algo. Defaults to None.
+
+        Returns:
+            Optional[AlgoConfig]: The algorithm configuration if a valid `quant_algo` is provided, `None` otherwise.
+
+        Raises:
+            ValueError: If an invalid model_type is provided or if unsupported configurations are used.
+        """
+        if quant_algo is None:
+            return None
+
+        if model_type is None:
+            raise ValueError("model_type must be provided when quant_algo is specified.")
+
+        if model_type not in AutoQuantizationConfig.SUPPORTED_MODEL_TYPES:
+            raise ValueError(
+                f"Invalid model_type: {model_type}. Expected one of {AutoQuantizationConfig.SUPPORTED_MODEL_TYPES}."
+            )
+
+        algo_config_map = {"awq": AWQConfig, "smoothquant": SmoothQuantConfig, "gptq": GPTQConfig}
+
+        if quant_algo not in algo_config_map:
+            return None
+
+        # Ensure compatibility of quant_algo with global_quant_config (raised, not asserted, so `-O` keeps it)
+        if quant_algo == "awq":
+            supported = (W_UINT4_PER_GROUP_CONFIG, W_INT4_PER_GROUP_SYM_CONFIG, W_INT8_PER_GROUP_CONFIG)
+            if global_quant_config not in supported:
+                raise ValueError("quant_algo 'awq' requires a uint4, int4 or int8 per-group weight-only scheme.")
+        elif quant_algo == "gptq":
+            if global_quant_config not in (W_UINT4_PER_GROUP_CONFIG,):
+                raise ValueError("quant_algo 'gptq' requires the 'w_uint4_per_group_asym' scheme.")
+
+        return algo_config_map[quant_algo]()
+
+    @staticmethod
+    def _validate_model_type(model_type: str):
+        """Returns `model_type` unchanged, raising `ValueError` if it is not a supported model type."""
+        if model_type not in AutoQuantizationConfig.SUPPORTED_MODEL_TYPES:
+            raise ValueError(
+                f"Invalid value for model_type for AutoQuantizationConfig: {model_type}. Expected one of {AutoQuantizationConfig.SUPPORTED_MODEL_TYPES}."
+            )
+
+        return model_type
\ No newline at end of file
diff --git a/optimum/amd/quantizers/quark/custom_configs.py b/optimum/amd/quantizers/quark/custom_configs.py
new file mode 100644
index 00000000..171128a0
--- /dev/null
+++ b/optimum/amd/quantizers/quark/custom_configs.py
@@ -0,0 +1,164 @@
+#
+# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
+# SPDX-License-Identifier: MIT
+#
+
+
+from quark.torch.quantization.config.config import (
+    QuantizationConfig,
+    QuantizationSpec,
+)
+from quark.torch.quantization.config.type import Dtype, QSchemeType, RoundType, ScaleType
+from quark.torch.quantization.observer.observer import (
+    PerBlockMXObserver,
+    PerChannelMinMaxObserver,
+    PerGroupMinMaxObserver,
+    PerTensorMinMaxObserver,
+    PlaceholderObserver,
+)
+
+
+# Configure `QuantizationSpec` for torch.Tensors. Specify attributes such as dtype, observer_cls, etc.
+FLOAT16_SPEC = QuantizationSpec(dtype=Dtype.float16, observer_cls=PlaceholderObserver)
+
+BFLOAT16_SPEC = QuantizationSpec(dtype=Dtype.bfloat16, observer_cls=PlaceholderObserver)
+
+FP8_PER_TENSOR_SPEC = QuantizationSpec(
+    dtype=Dtype.fp8_e4m3, qscheme=QSchemeType.per_tensor, observer_cls=PerTensorMinMaxObserver, is_dynamic=False
+)
+
+INT4_PER_TENSER_SPEC = QuantizationSpec(  # "TENSER" (sic): the spelling W_INT4_PER_TENSOR_CONFIG below refers to
+    dtype=Dtype.int4,
+    qscheme=QSchemeType.per_tensor,
+    observer_cls=PerTensorMinMaxObserver,
+    symmetric=True,
+    scale_type=ScaleType.float,
+    round_method=RoundType.half_even,
+    is_dynamic=False,
+)
+
+INT4_PER_CHANNEL_SPEC = QuantizationSpec(
+    dtype=Dtype.int4,
+    observer_cls=PerChannelMinMaxObserver,
+    symmetric=True,
+    scale_type=ScaleType.float,
+    round_method=RoundType.half_even,
+    qscheme=QSchemeType.per_channel,
+    ch_axis=0,
+    is_dynamic=False,
+)
+
+INT4_PER_GROUP_SYM_SPEC = QuantizationSpec(
+    dtype=Dtype.int4,
+    observer_cls=PerGroupMinMaxObserver,
+    symmetric=True,
+    scale_type=ScaleType.float,
+    round_method=RoundType.half_even,
+    qscheme=QSchemeType.per_group,
+    ch_axis=1,
+    is_dynamic=False,
+    group_size=128,  # template value; AutoQuantizationConfig.from_quant_scheme overwrites it with the caller's group_size
+)
+
+UINT4_PER_GROUP_ASYM_SPEC = QuantizationSpec(
+    dtype=Dtype.uint4,
+    observer_cls=PerGroupMinMaxObserver,
+    symmetric=False,
+    scale_type=ScaleType.float,
+    round_method=RoundType.half_even,
+    qscheme=QSchemeType.per_group,
+    ch_axis=1,
+    is_dynamic=False,
+    group_size=128,  # template value; AutoQuantizationConfig.from_quant_scheme overwrites it with the caller's group_size
+)
+
+INT8_PER_TENSER_SPEC = QuantizationSpec(
+    dtype=Dtype.int8,
+    qscheme=QSchemeType.per_tensor,
+    observer_cls=PerTensorMinMaxObserver,
+    symmetric=True,
+    scale_type=ScaleType.float,
+    round_method=RoundType.half_even,
+    is_dynamic=False,
+)
+
+INT8_PER_TENSER_DYNAMIC_SPEC = QuantizationSpec(
+    dtype=Dtype.int8,
+    qscheme=QSchemeType.per_tensor,
+    observer_cls=PerTensorMinMaxObserver,
+    symmetric=True,
+    scale_type=ScaleType.float,
+    round_method=RoundType.half_even,
+    is_dynamic=True,  # the only field that differs from INT8_PER_TENSER_SPEC
+)
+
+INT8_PER_GROUP_SYM_SPEC = QuantizationSpec(
+    dtype=Dtype.int8,
+    observer_cls=PerGroupMinMaxObserver,
+    symmetric=True,
+    scale_type=ScaleType.float,
+    round_method=RoundType.half_even,
+    qscheme=QSchemeType.per_group,
+    ch_axis=1,
+    is_dynamic=False,
+    group_size=128,  # template value; AutoQuantizationConfig.from_quant_scheme overwrites it with the caller's group_size
+)
+
+W_MX_FP8_SPEC = QuantizationSpec(
+    dtype=Dtype.mx,
+    observer_cls=PerBlockMXObserver,
+    qscheme=QSchemeType.per_group,
+    mx_element_dtype=Dtype.fp8_e4m3,
+    ch_axis=-1,
+    group_size=32,
+    is_dynamic=False,
+    round_method=RoundType.half_even,
+)
+A_MX_FP8_SPEC = QuantizationSpec(
+    dtype=Dtype.mx,
+    observer_cls=PerBlockMXObserver,
+    qscheme=QSchemeType.per_group,
+    mx_element_dtype=Dtype.fp8_e4m3,
+    ch_axis=-1,
+    group_size=32,
+    is_dynamic=True,  # the only field that differs from W_MX_FP8_SPEC
+    round_method=RoundType.half_even,
+)
+
+
+# Establish `QuantizationConfig` for nn.Module. Define the QuantizationSpec of input_tensors, output_tensors, weight, and bias.
+# Float16 config
+FLOAT16_CONFIG = QuantizationConfig(input_tensors=FLOAT16_SPEC, weight=FLOAT16_SPEC)
+
+# Fp8 (e4m3) per-tensor configs
+W_FP8_A_FP8_PER_TENSOR_CONFIG = QuantizationConfig(input_tensors=FP8_PER_TENSOR_SPEC, weight=FP8_PER_TENSOR_SPEC)
+
+W_FP8_A_FP8_OFP8_PER_TENSOR_CONFIG = QuantizationConfig(
+    input_tensors=FP8_PER_TENSOR_SPEC, weight=FP8_PER_TENSOR_SPEC, output_tensors=FP8_PER_TENSOR_SPEC
+)
+
+# Int per-tensor configs
+W_INT4_PER_TENSOR_CONFIG = QuantizationConfig(weight=INT4_PER_TENSER_SPEC)
+
+W_INT8_PER_TENSOR_CONFIG = QuantizationConfig(weight=INT8_PER_TENSER_SPEC)
+
+W_INT8_A_INT8_PER_TENSOR_CONFIG = QuantizationConfig(input_tensors=INT8_PER_TENSER_SPEC, weight=INT8_PER_TENSER_SPEC)
+
+W_INT8_A_INT8_PER_TENSOR_DYNAMIC_CONFIG = QuantizationConfig(
+    input_tensors=INT8_PER_TENSER_DYNAMIC_SPEC, weight=INT8_PER_TENSER_DYNAMIC_SPEC
+)
+
+# Int per-channel config
+W_INT4_PER_CHANNEL_CONFIG = QuantizationConfig(weight=INT4_PER_CHANNEL_SPEC)
+
+# Int per-group configs (weight-only)
+W_INT4_PER_GROUP_SYM_CONFIG = QuantizationConfig(weight=INT4_PER_GROUP_SYM_SPEC)
+
+W_UINT4_PER_GROUP_CONFIG = QuantizationConfig(weight=UINT4_PER_GROUP_ASYM_SPEC)
+
+W_UINT4_A_BFLOAT16_PER_GROUP_CONFIG = QuantizationConfig(input_tensors=BFLOAT16_SPEC, weight=UINT4_PER_GROUP_ASYM_SPEC)
+
+W_INT8_PER_GROUP_CONFIG = QuantizationConfig(weight=INT8_PER_GROUP_SYM_SPEC)
+
+W_MX_FP8_CONFIG = QuantizationConfig(weight=W_MX_FP8_SPEC)
+W_MX_FP8_A_MX_FP8_CONFIG = QuantizationConfig(weight=W_MX_FP8_SPEC, input_tensors=A_MX_FP8_SPEC)
diff --git a/optimum/amd/quantizers/quark/quantizer.py b/optimum/amd/quantizers/quark/quantizer.py
new file mode 100644
index 00000000..5d941961
--- /dev/null
+++ b/optimum/amd/quantizers/quark/quantizer.py
@@ -0,0 +1,108 @@
+from enum import Enum
+from typing import (
+    Iterable,
+    Optional,
+)
+
+import torch
+
+from quark.torch import ModelQuantizer
+from quark.torch.quantization.config.config import Config
+from transformers.quantizers import HfQuantizer, HfQuantizerPlugin
+from transformers.utils.quantization_config import QuantizationConfigMixin
+
+
+class QuantizationMethod(str, Enum):
+    """Quantization method identifier reported by `QuarkConfig.quant_method`."""
+
+    QUARK = "quark"
+
+
+class QuarkPlugin(HfQuantizerPlugin):
+    """Plugin that exposes the Quark quantizer and configuration classes."""
+
+    @staticmethod
+    def get_quantizer():
+        """Return the quantizer class provided by this plugin."""
+        return QuarkQuantizer
+
+    @staticmethod
+    def get_config():
+        """Return the configuration class provided by this plugin."""
+        return QuarkConfig
+
+
+class QuarkConfig(QuantizationConfigMixin):
+    """
+    Pairs a Quark `Config` with an optional calibration dataset.
+
+    Args:
+        qconfig (Config): The Quark quantization configuration.
+        dataset (Optional[Iterable], optional): Calibration data handed to the quantizer. Defaults to `None`.
+    """
+
+    def __init__(
+        self,
+        qconfig: Config,
+        dataset: Optional[Iterable] = None,
+    ):
+        self._qconfig = qconfig
+        self._dataset = dataset
+
+    @property
+    def qconfig(self) -> Config:
+        """The wrapped Quark `Config`."""
+        return self._qconfig
+
+    @property
+    def dataset(self) -> Optional[Iterable]:
+        """The calibration dataset, or `None` when none was given."""
+        return self._dataset
+
+    @property
+    def quant_method(self):
+        """Always `QuantizationMethod.QUARK`."""
+        return QuantizationMethod.QUARK
+
+
+class QuarkQuantizer(HfQuantizer):
+    """Quantizer that runs Quark's `ModelQuantizer` on the model after its weights are loaded."""
+
+    requires_calibration = False
+    optimum_quantizer = None
+
+    def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
+        super().__init__(quantization_config, **kwargs)
+
+        if self.pre_quantized is True:
+            raise ValueError("Pre-quantized models are not supported by QuarkQuantizer")
+
+        self._quantizer = ModelQuantizer(quantization_config._qconfig)
+
+    def _process_model_before_weight_loading(self, model, **kwargs):
+        """Return the model unchanged; nothing is done before the weights are loaded."""
+        return model
+
+    def _process_model_after_weight_loading(self, model, **kwargs):
+        """
+        Quantize the loaded model under `torch.inference_mode()` and return the result.
+
+        Raises:
+            ValueError: If the quantizer is not fully dynamic and the config carries no dataset.
+        """
+        dataset = self.quantization_config.dataset
+        if not self._quantizer.is_all_dynamic and dataset is None:
+            raise ValueError("A calibration dataset is required for the quantization method.")
+
+        with torch.inference_mode():
+            qmodel = self._quantizer.quantize_model(model, dataset)
+
+        return qmodel
+
+    @property
+    def is_serializable(self):
+        return True
+
+    @property
+    def is_trainable(self):
+        return False
diff --git a/optimum/amd/quantizers/quark_back/configuration.py b/optimum/amd/quantizers/quark_back/configuration.py
new file mode 100644
index 00000000..d82ea586
--- /dev/null
+++ b/optimum/amd/quantizers/quark_back/configuration.py
@@ -0,0 +1,367 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+# Licensed under the MIT License.
+"""Configuration classes for quantization with AMD Quark."""
+
+from enum import Enum
+from typing import Dict, List, Optional
+
+from quark.torch.quantization.config.config import (
+    AlgoConfig,
+    AWQConfig,
+    Config,
+    GPTQConfig,
+    QuantizationConfig,
+    QuantizationSpec,
+    SmoothQuantConfig,
+)
+from quark.torch.quantization.config.custom_config import (
+    DEFAULT_AWQ_CONFIG,
+    DEFAULT_FLOAT16_CONFIG,
+    DEFAULT_GPTQ_CONFIG,
+    DEFAULT_SMOOTH_QUANT_CONFIG,
+    DEFAULT_W_FP8_A_FP8_OFP8_PER_TENSOR_CONFIG,
+    DEFAULT_W_FP8_A_FP8_PER_TENSOR_CONFIG,
+    DEFAULT_W_INT4_PER_CHANNEL_CONFIG,
+    DEFAULT_W_INT4_PER_GROUP_SYM_CONFIG,
+    DEFAULT_W_INT4_PER_TENSOR_CONFIG,
+    DEFAULT_W_INT8_A_INT8_PER_TENSOR_CONFIG,
+    DEFAULT_W_INT8_A_INT8_PER_TENSOR_DYNAMIC_CONFIG,
+    DEFAULT_W_UINT4_A_BFLOAT16_PER_GROUP_CONFIG,
+    DEFAULT_W_UINT4_PER_GROUP_CONFIG,
+    FP8_PER_TENSOR_SPEC,
+)
+
+from ..quark.algo_config_constants import ALGO_CONFIG_PARAMS
+
+
+QuarkQuantizationConfig = Config
+
+__all__ = [
+    "QuarkQuantizationConfig",
+    "QuantizationSpec",
+    "SmoothQuantConfig",
+    "AWQConfig",
+    "GPTQConfig",
+    "AlgoConfig",
+]
+
+
+class KVCacheDType(Enum):
+    FP8 = "fp8"
+
+
+class AutoQuantizationConfig:
+    @staticmethod
+    def _apply_fp8_config(layer_quant_config):
+        KV_CACHE_CFG = {
+
"*.v_proj": QuantizationConfig( + input_tensors=FP8_PER_TENSOR_SPEC, output_tensors=FP8_PER_TENSOR_SPEC, weight=FP8_PER_TENSOR_SPEC + ), + "*.k_proj": QuantizationConfig( + input_tensors=FP8_PER_TENSOR_SPEC, output_tensors=FP8_PER_TENSOR_SPEC, weight=FP8_PER_TENSOR_SPEC + ), + } + return {**layer_quant_config, **KV_CACHE_CFG} + + @staticmethod + def _validate_kv_cache_dtype(kv_cache_dtype: str): + if kv_cache_dtype.lower() != "fp8": + raise ValueError(f"Invalid value for kv_cache_dtype: {kv_cache_dtype}. Expected 'fp8'.") + + @staticmethod + def _validate_model_type(model_type: str): + SUPPORTED_MODEL_TYPES = ["llama", "mistral", "opt", "qwen2"] + + if model_type not in SUPPORTED_MODEL_TYPES: + raise ValueError( + f"Invalid value for model_type for AutoQuantizationConfig: {model_type}. Expected one of {SUPPORTED_MODEL_TYPES}." + ) + + return model_type + + @staticmethod + def w_fp8_a_fp8( + layer_type_quant_config: Dict[str, QuantizationConfig] = {}, + layer_quant_config: Dict[str, QuantizationConfig] = {}, + exclude: List[str] = [], + pre_quant_opt_config: Optional[QuantizationConfig] = None, + kv_cache_dtype: Optional[KVCacheDType] = None, + ): + if kv_cache_dtype: + AutoQuantizationConfig._validate_kv_cache_dtype(kv_cache_dtype) + layer_quant_config = AutoQuantizationConfig._apply_fp8_config(layer_quant_config) + + return QuarkQuantizationConfig( + global_quant_config=DEFAULT_W_FP8_A_FP8_PER_TENSOR_CONFIG, + layer_type_quant_config=layer_type_quant_config, + layer_quant_config=layer_quant_config, + exclude=exclude, + pre_quant_opt_config=pre_quant_opt_config, + ) + + @staticmethod + def w_fp8_a_fp8_o_fp8( + layer_type_quant_config: Dict[str, QuantizationConfig] = {}, + layer_quant_config: Dict[str, QuantizationConfig] = {}, + exclude: List[str] = [], + pre_quant_opt_config: Optional[QuantizationConfig] = None, + kv_cache_dtype: Optional[KVCacheDType] = None, + ): + if kv_cache_dtype: + AutoQuantizationConfig._validate_kv_cache_dtype(kv_cache_dtype) + 
layer_quant_config = AutoQuantizationConfig._apply_fp8_config(layer_quant_config) + + return QuarkQuantizationConfig( + global_quant_config=DEFAULT_W_FP8_A_FP8_OFP8_PER_TENSOR_CONFIG, + layer_type_quant_config=layer_type_quant_config, + layer_quant_config=layer_quant_config, + exclude=exclude, + pre_quant_opt_config=pre_quant_opt_config, + ) + + @staticmethod + def w_int4_per_tensor( + layer_type_quant_config: Dict[str, QuantizationConfig] = {}, + layer_quant_config: Dict[str, QuantizationConfig] = {}, + exclude: List[str] = [], + pre_quant_opt_config: Optional[QuantizationConfig] = None, + kv_cache_dtype: Optional[KVCacheDType] = None, + ): + if kv_cache_dtype: + AutoQuantizationConfig._validate_kv_cache_dtype(kv_cache_dtype) + layer_quant_config = AutoQuantizationConfig._apply_fp8_config(layer_quant_config) + + return QuarkQuantizationConfig( + global_quant_config=DEFAULT_W_INT4_PER_TENSOR_CONFIG, + layer_type_quant_config=layer_type_quant_config, + layer_quant_config=layer_quant_config, + exclude=exclude, + pre_quant_opt_config=pre_quant_opt_config, + ) + + @staticmethod + def w_int4_per_channel_sym( + layer_type_quant_config: Dict[str, QuantizationConfig] = {}, + layer_quant_config: Dict[str, QuantizationConfig] = {}, + exclude: List[str] = [], + pre_quant_opt_config: Optional[QuantizationConfig] = None, + kv_cache_dtype: Optional[KVCacheDType] = None, + ): + if kv_cache_dtype: + AutoQuantizationConfig._validate_kv_cache_dtype(kv_cache_dtype) + layer_quant_config = AutoQuantizationConfig._apply_fp8_config(layer_quant_config) + + return QuarkQuantizationConfig( + global_quant_config=DEFAULT_W_INT4_PER_CHANNEL_CONFIG, + layer_type_quant_config=layer_type_quant_config, + layer_quant_config=layer_quant_config, + exclude=exclude, + pre_quant_opt_config=pre_quant_opt_config, + ) + + @staticmethod + def w_int4_per_group_sym( + layer_type_quant_config: Dict[str, QuantizationConfig] = {}, + layer_quant_config: Dict[str, QuantizationConfig] = {}, + exclude: List[str] = 
[], + pre_quant_opt_config: Optional[QuantizationConfig] = None, + kv_cache_dtype: Optional[KVCacheDType] = None, + ): + if kv_cache_dtype: + AutoQuantizationConfig._validate_kv_cache_dtype(kv_cache_dtype) + layer_quant_config = AutoQuantizationConfig._apply_fp8_config(layer_quant_config) + + return QuarkQuantizationConfig( + global_quant_config=DEFAULT_W_INT4_PER_GROUP_SYM_CONFIG, + layer_type_quant_config=layer_type_quant_config, + layer_quant_config=layer_quant_config, + exclude=exclude, + pre_quant_opt_config=pre_quant_opt_config, + ) + + @staticmethod + def w_uint4_per_group_asym_awq( + model_type: str, + layer_type_quant_config: Dict[str, QuantizationConfig] = {}, + layer_quant_config: Dict[str, QuantizationConfig] = {}, + exclude: List[str] = [], + pre_quant_opt_config: Optional[QuantizationConfig] = None, + kv_cache_dtype: Optional[KVCacheDType] = None, + ): + if kv_cache_dtype: + AutoQuantizationConfig._validate_kv_cache_dtype(kv_cache_dtype) + layer_quant_config = AutoQuantizationConfig._apply_fp8_config(layer_quant_config) + + model_type = AutoQuantizationConfig._validate_model_type(model_type) + algo_config_info = ALGO_CONFIG_PARAMS[model_type] + + quant_config = DEFAULT_AWQ_CONFIG + quant_config.algo_config.scaling_layers = algo_config_info["scaling_layers"] + quant_config.algo_config.model_decoder_layers = algo_config_info["model_decoder_layers"] + quant_config.algo_config.embedding_layers = algo_config_info["embedding_layers"] + return QuarkQuantizationConfig( + global_quant_config=quant_config.global_quant_config, + layer_type_quant_config=layer_type_quant_config, + layer_quant_config=layer_quant_config, + exclude=exclude, + algo_config=quant_config.algo_config, + pre_quant_opt_config=pre_quant_opt_config, + ) + + @staticmethod + def w_uint4_per_group_asym_smoothquant( + model_type: str, + layer_type_quant_config: Dict[str, QuantizationConfig] = {}, + layer_quant_config: Dict[str, QuantizationConfig] = {}, + exclude: List[str] = [], + 
pre_quant_opt_config: Optional[QuantizationConfig] = None, + kv_cache_dtype: Optional[KVCacheDType] = None, + ): + if kv_cache_dtype: + AutoQuantizationConfig._validate_kv_cache_dtype(kv_cache_dtype) + layer_quant_config = AutoQuantizationConfig._apply_fp8_config(layer_quant_config) + + model_type = AutoQuantizationConfig._validate_model_type(model_type) + algo_config_info = ALGO_CONFIG_PARAMS[model_type] + + quant_config = DEFAULT_SMOOTH_QUANT_CONFIG + quant_config.algo_config.scaling_layers = algo_config_info["scaling_layers"] + quant_config.algo_config.model_decoder_layers = algo_config_info["model_decoder_layers"] + quant_config.algo_config.embedding_layers = algo_config_info["embedding_layers"] + return QuarkQuantizationConfig( + global_quant_config=quant_config.global_quant_config, + layer_type_quant_config=layer_type_quant_config, + layer_quant_config=layer_quant_config, + exclude=exclude, + algo_config=quant_config.algo_config, + pre_quant_opt_config=pre_quant_opt_config, + ) + + @staticmethod + def w_uint4_per_group_asym_gptq( + model_type: str, + layer_type_quant_config: Dict[str, QuantizationConfig] = {}, + layer_quant_config: Dict[str, QuantizationConfig] = {}, + exclude: List[str] = [], + pre_quant_opt_config: Optional[QuantizationConfig] = None, + kv_cache_dtype: Optional[KVCacheDType] = None, + ): + if kv_cache_dtype: + AutoQuantizationConfig._validate_kv_cache_dtype(kv_cache_dtype) + layer_quant_config = AutoQuantizationConfig._apply_fp8_config(layer_quant_config) + + model_type = AutoQuantizationConfig._validate_model_type(model_type) + algo_config_info = ALGO_CONFIG_PARAMS[model_type] + + quant_config = DEFAULT_GPTQ_CONFIG + quant_config.algo_config.inside_layer_modules = algo_config_info["inside_layer_modules"] + quant_config.algo_config.model_decoder_layers = algo_config_info["model_decoder_layers"] + quant_config.algo_config.embedding_layers = algo_config_info["embedding_layers"] + return QuarkQuantizationConfig( + 
global_quant_config=quant_config.global_quant_config, + layer_type_quant_config=layer_type_quant_config, + layer_quant_config=layer_quant_config, + exclude=exclude, + algo_config=quant_config.algo_config, + pre_quant_opt_config=pre_quant_opt_config, + ) + + @staticmethod + def w_uint4_per_group_asym( + layer_type_quant_config: Dict[str, QuantizationConfig] = {}, + layer_quant_config: Dict[str, QuantizationConfig] = {}, + exclude: List[str] = [], + pre_quant_opt_config: Optional[QuantizationConfig] = None, + kv_cache_dtype: Optional[KVCacheDType] = None, + ): + if kv_cache_dtype: + AutoQuantizationConfig._validate_kv_cache_dtype(kv_cache_dtype) + layer_quant_config = AutoQuantizationConfig._apply_fp8_config(layer_quant_config) + + return QuarkQuantizationConfig( + global_quant_config=DEFAULT_W_UINT4_PER_GROUP_CONFIG, + layer_type_quant_config=layer_type_quant_config, + layer_quant_config=layer_quant_config, + exclude=exclude, + pre_quant_opt_config=pre_quant_opt_config, + ) + + @staticmethod + def w_uint4_a_bfloat16_per_group_asym( + layer_type_quant_config: Dict[str, QuantizationConfig] = {}, + layer_quant_config: Dict[str, QuantizationConfig] = {}, + exclude: List[str] = [], + pre_quant_opt_config: Optional[QuantizationConfig] = None, + kv_cache_dtype: Optional[KVCacheDType] = None, + ): + if kv_cache_dtype: + AutoQuantizationConfig._validate_kv_cache_dtype(kv_cache_dtype) + layer_quant_config = AutoQuantizationConfig._apply_fp8_config(layer_quant_config) + + return QuarkQuantizationConfig( + global_quant_config=DEFAULT_W_UINT4_A_BFLOAT16_PER_GROUP_CONFIG, + layer_type_quant_config=layer_type_quant_config, + layer_quant_config=layer_quant_config, + exclude=exclude, + pre_quant_opt_config=pre_quant_opt_config, + ) + + @staticmethod + def w_int8_a_int8_per_tensor_sym( + layer_type_quant_config: Dict[str, QuantizationConfig] = {}, + layer_quant_config: Dict[str, QuantizationConfig] = {}, + exclude: List[str] = [], + pre_quant_opt_config: Optional[QuantizationConfig] 
= None, + kv_cache_dtype: Optional[KVCacheDType] = None, + ): + if kv_cache_dtype: + AutoQuantizationConfig._validate_kv_cache_dtype(kv_cache_dtype) + layer_quant_config = AutoQuantizationConfig._apply_fp8_config(layer_quant_config) + + return QuarkQuantizationConfig( + global_quant_config=DEFAULT_W_INT8_A_INT8_PER_TENSOR_CONFIG, + layer_type_quant_config=layer_type_quant_config, + layer_quant_config=layer_quant_config, + exclude=exclude, + pre_quant_opt_config=pre_quant_opt_config, + ) + + @staticmethod + def w_int8_a_int8_per_tensor_sym_dynamic( + layer_type_quant_config: Dict[str, QuantizationConfig] = {}, + layer_quant_config: Dict[str, QuantizationConfig] = {}, + exclude: List[str] = [], + pre_quant_opt_config: Optional[QuantizationConfig] = None, + kv_cache_dtype: Optional[KVCacheDType] = None, + ): + if kv_cache_dtype: + AutoQuantizationConfig._validate_kv_cache_dtype(kv_cache_dtype) + layer_quant_config = AutoQuantizationConfig._apply_fp8_config(layer_quant_config) + + return QuarkQuantizationConfig( + global_quant_config=DEFAULT_W_INT8_A_INT8_PER_TENSOR_DYNAMIC_CONFIG, + layer_type_quant_config=layer_type_quant_config, + layer_quant_config=layer_quant_config, + exclude=exclude, + pre_quant_opt_config=pre_quant_opt_config, + ) + + @staticmethod + def float16( + layer_type_quant_config: Dict[str, QuantizationConfig] = {}, + layer_quant_config: Dict[str, QuantizationConfig] = {}, + exclude: List[str] = [], + pre_quant_opt_config: Optional[QuantizationConfig] = None, + kv_cache_dtype: Optional[KVCacheDType] = None, + ): + if kv_cache_dtype: + AutoQuantizationConfig._validate_kv_cache_dtype(kv_cache_dtype) + layer_quant_config = AutoQuantizationConfig._apply_fp8_config(layer_quant_config) + + return QuarkQuantizationConfig( + global_quant_config=DEFAULT_FLOAT16_CONFIG, + layer_type_quant_config=layer_type_quant_config, + layer_quant_config=layer_quant_config, + exclude=exclude, + pre_quant_opt_config=pre_quant_opt_config, + ) diff --git 
class QuarkQuantizer(OptimumQuantizer):
    """
    Handles the quantization process for models shared on huggingface.co/models.

    The instance keeps the model, its `config.model_type`, the dtype of its first parameter and a
    Quark `ModelQuantizer` built from the supplied quantization config.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        quantization_config: QuarkQuantizationConfig,
        model_name_or_path: Optional[str],
    ):
        """
        Args:
            model (`torch.nn.Module`):
                Model to quantize. It must expose `config.model_type` and have at least one parameter.
            quantization_config (`QuarkQuantizationConfig`):
                Configuration handed to Quark's `ModelQuantizer`.
            model_name_or_path (`Optional[str]`):
                Identifier or path the model was loaded from; stored for reference only.
        """
        super().__init__()
        self.model_name_or_path = model_name_or_path

        self.model = model
        self.model_type = model.config.model_type
        # The dtype of the first parameter is taken as the model dtype.
        self.model_dtype = next(model.parameters()).dtype

        # Initialize the quantizer
        self.quantizer = ModelQuantizer(quantization_config)

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path: str,
        quantization_config: QuarkQuantizationConfig,
        subfolder: str = "",
        revision: Optional[str] = None,
        cache_dir: Optional[str] = None,
        trust_remote_code: bool = False,
        force_download: bool = False,
        local_files_only: bool = False,
        use_auth_token: Optional[Union[bool, str]] = None,
        device_map: Optional[Union[Dict, str, torch.device]] = None,
        **model_kwargs,
    ):
        """
        Loads the QuarkQuantizer and model.

        Args:
            model_name_or_path (`Union[str, Path]`):
                Can be either the model id of a model repo on the Hugging Face Hub, or a path to a local directory
                containing a model.
            quantization_config (`QuarkQuantizationConfig`):
                Quantization configuration forwarded to the constructor.
            subfolder (`str`, defaults to `""`):
                In case the model files are located inside a subfolder of the model directory / repo on the Hugging
                Face Hub, you can specify the subfolder name here.
            revision (`Optional[str]`, *optional*, defaults to `None`):
                Revision is the specific model version to use. It can be a branch name, a tag name, or a commit id.
            cache_dir (`Optional[str]`, *optional*):
                Path to a directory in which a downloaded pretrained model weights have been cached if the standard
                cache should not be used.
            trust_remote_code (`bool`, defaults to `False`):
                Allows to use custom code for the modeling hosted in the model repository. This option should only be
                set for repositories you trust and in which you have read the code, as it will execute on your local
                machine arbitrary code present in the model repository.
            force_download (`bool`, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            local_files_only (`Optional[bool]`, defaults to `False`):
                Whether or not to only look at local files (i.e., do not try to download the model).
            use_auth_token (`Optional[Union[bool, str]]`, defaults to `None`):
                The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
                when running `transformers-cli login` (stored in `~/.huggingface`).
            device_map (`Optional[Union[Dict, str, torch.device]]`, defaults to `None`):
                A dict or one of the placement strategy names (`"auto"`, `"balanced"`, `"balanced_low_0"`,
                `"sequential"`) is forwarded as `device_map`. Any other value is treated as a single target device
                and forwarded as `device`.
            **model_kwargs:
                Additional keyword arguments forwarded to `TasksManager.get_model_from_task`.

        Returns:
            `QuarkQuantizer`: A quantizer wrapping the loaded model, set to evaluation mode.
        """

        # TODO: fix
        # task = TasksManager.infer_task_from_model(model_name_or_path)
        task = "text-generation"

        # Strategy names are not device identifiers and must stay in `device_map`.
        placement_strategies = ("auto", "balanced", "balanced_low_0", "sequential")

        device = None
        if not isinstance(device_map, dict) and device_map not in placement_strategies:
            device = device_map
            device_map = None

        model = TasksManager.get_model_from_task(
            task,
            model_name_or_path,
            subfolder=subfolder,
            revision=revision,
            cache_dir=cache_dir,
            use_auth_token=use_auth_token,
            local_files_only=local_files_only,
            force_download=force_download,
            trust_remote_code=trust_remote_code,
            device_map=device_map,
            device=device,
            framework="pt",
            **model_kwargs,
        ).eval()

        return cls(model, quantization_config, model_name_or_path)

    def quantize(
        self,
        dataloader: Optional[Union[DataLoader, Dataset]] = None,
    ) -> torch.nn.Module:
        """
        Quantize the wrapped model and store the result in `self.model`.

        Args:
            dataloader (`Optional[Union[DataLoader[torch.Tensor], DataLoader[List[Dict[str, torch.Tensor]]],
                DataLoader[Dict[str, torch.Tensor]]]]`, defaults to `None`):
                The DataLoader providing data that the quantization process will use for calibration. This can be a
                simple DataLoader returning tensors, or a more complex structure returning either a list of
                dictionaries or a dictionary of tensors.

        Returns:
            torch.nn.Module: Quantized model

        Raises:
            ValueError: If no `dataloader` is given while the quantizer's `is_all_dynamic` flag is false.
        """
        if not self.quantizer.is_all_dynamic and dataloader is None:
            raise ValueError("A calibration dataset is required for the quantization method.")

        self.model = self.quantizer.quantize_model(self.model, dataloader)

        return self.model

    def save_pretrained(self, save_directory: str, no_weight_matrix_merge=False):
        """
        Save the quantized model to the specified directory.

        Args:
            save_directory (`str`):
                Directory to save the quantized model to.
            no_weight_matrix_merge (`bool`, defaults to `False`):
                If `True`, export with `EMPTY_EXPORTER_CONFIG` instead of `DEFAULT_EXPORTER_CONFIG`
                (presumably this disables weight-matrix merging in the dump; confirm against Quark).

        Raises:
            ValueError: If the model type is not `llama`.
        """
        if self.model_type != "llama":
            raise ValueError(
                f"Only models with model type `llama` can be saved, but got `{self.model_type}`."
            )
        model = self.quantizer.freeze(self.model)

        with torch.inference_mode():
            export_config = EMPTY_EXPORTER_CONFIG if no_weight_matrix_merge else DEFAULT_EXPORTER_CONFIG

            exporter = ModelExporter(config=export_config, export_dir=save_directory)
            exporter.export_model_info(model, self.model_type, self.model_dtype, export_type="vllm-adopt")

        # Keep the Hugging Face config next to the exported artifacts.
        model.config.save_pretrained(save_directory)

    def get_calibration_data(
        self,
        dataset_name: str,
        num_samples: int = 100,
        dataset_config_name: Optional[str] = None,
        dataset_split: Optional[str] = None,
        preprocess_function: Optional[Callable] = None,
        preprocess_batch: bool = True,
        seed: int = 2016,
        token: Optional[Union[bool, str]] = None,
        batch_size: int = 1,
    ) -> Union[
        DataLoader[torch.Tensor], DataLoader[List[Dict[str, torch.Tensor]]], DataLoader[Dict[str, torch.Tensor]]
    ]:
        """
        Creates the calibration `DataLoader` to use for the post-training static quantization calibration step.

        Args:
            dataset_name (`str`):
                The dataset repository name on the Hugging Face Hub or path to a local directory containing data files
                to load to use for the calibration step.
            num_samples (`int`, defaults to 100):
                The maximum number of samples composing the calibration dataset. If `None`, the dataset is neither
                shuffled nor truncated.
            dataset_config_name (`Optional[str]`, defaults to `None`):
                The name of the dataset configuration.
            dataset_split (`Optional[str]`, defaults to `None`):
                Which split of the dataset to use to perform the calibration step.
                NOTE(review): the sampling step calls `len()` and `.select()` on the loaded object, which presumably
                requires a single split to be named here; confirm.
            preprocess_function (`Optional[Callable]`, defaults to `None`):
                Processing function to apply to each example after loading dataset.
            preprocess_batch (`bool`, defaults to `True`):
                Whether the `preprocess_function` should be batched.
            seed (`int`, defaults to 2016):
                The random seed to use when shuffling the calibration dataset.
            token (`Optional[Union[bool,str]]`, defaults to `None`):
                The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
                when running `huggingface-cli login` (stored in `huggingface_hub.constants.HF_TOKEN_PATH`).
            batch_size (`int`, defaults to 1):
                Batch size of the returned `DataLoader`.

        Returns:
            A non-shuffling `torch.utils.data.DataLoader` over the (optionally sampled and preprocessed)
            calibration dataset.
        """
        calib_dataset = load_dataset(
            dataset_name,
            name=dataset_config_name,
            split=dataset_split,
            token=token,
        )

        if num_samples is not None:
            # Never request more rows than the dataset holds.
            num_samples = min(num_samples, len(calib_dataset))
            calib_dataset = calib_dataset.shuffle(seed=seed).select(range(num_samples))

        if preprocess_function is not None:
            processed_calib_dataset = calib_dataset.map(preprocess_function, batched=preprocess_batch)
        else:
            processed_calib_dataset = calib_dataset

        dataloader = DataLoader(processed_calib_dataset, batch_size=batch_size, shuffle=False)

        return dataloader
+tokenized_outputs = tokenizer(text, return_tensors="pt") +calib_dataloader = DataLoader(tokenized_outputs['input_ids']) + +config = AutoQuantizationConfig.from_quant_scheme("w_fp8_a_fp8") +config.dataset = calib_dataloader +model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", quantization_config=config) + +from pdb import set_trace; set_trace() +# quant_model = quantizer.quantize(calib_dataloader) +# quantizer.save_pretrained("quantized_model_quantizer") \ No newline at end of file diff --git a/setup.py b/setup.py index 4eedceb7..41b3c145 100644 --- a/setup.py +++ b/setup.py @@ -77,7 +77,12 @@ author_email="hardware@huggingface.co", license="MIT", packages=find_namespace_packages(include=["optimum*"]), - entry_points={"console_scripts": ["amdrun=optimum.amd.cli:amdrun"]}, + entry_points={ + "console_scripts": ["amdrun=optimum.amd.cli:amdrun"], + "hf_quantizers": [ + "quark = optimum.amd.quantizers:QuarkPlugin" + ] + }, install_requires=INSTALL_REQUIRE, extras_require=EXTRAS_REQUIRE, package_data={"optimum": ["amd/ryzenai/configs/*.json"]},