From c5d7b77f65b86fad4e184dd565bbcaa01e9be82f Mon Sep 17 00:00:00 2001 From: silveroxides Date: Sat, 18 Apr 2026 13:03:07 +0200 Subject: [PATCH 1/3] wire int8_linear to ck dispatch (triton->eager), add TODO guards utils/eager_quantization.py: int8_linear now tries ck.int8_linear first, which routes through torch.ops.comfy_kitchen.int8_linear -> registry -> triton backend if available, else eager. Falls back to local chunked torch.int8_mm path on ImportError or any runtime failure. unified_ops.py: add TODO #2 comment on broken-state else branch in _load_from_state_dict (is_tensorwise + no ck = raw int8 + null layout). Add TODO #3 comment on missing 3D input reshape guard in else branch of forward_comfy_cast_weights. --- pyproject.toml | 2 +- unified_ops.py | 27 +++++++++++++++++++++++++-- utils/eager_quantization.py | 23 +++++++++++++++++------ 3 files changed, 43 insertions(+), 9 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8f3dd61..91e396c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "ComfyUI-QuantOps" description = "Extended quantization layouts for ComfyUI (INT8, row/block-wise FP8)" -version = "1.7.2" +version = "1.8.0" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" diff --git a/unified_ops.py b/unified_ops.py index bfdcf21..29586a9 100644 --- a/unified_ops.py +++ b/unified_ops.py @@ -173,11 +173,14 @@ def _load_from_state_dict( if is_tensorwise and _HAS_TENSORWISE_INT8_LAYOUT: self.layout_type = "TensorWiseINT8Layout" + _orig_dtype_str = layer_conf.get("orig_dtype", "torch.bfloat16") if layer_conf else "torch.bfloat16" + _DTYPE_MAP = {"torch.bfloat16": torch.bfloat16, "torch.float16": torch.float16, "torch.float32": torch.float32} + _orig_dtype = _DTYPE_MAP.get(_orig_dtype_str, torch.bfloat16) layout_params = TensorWiseINT8Layout.Params( scale=scale.to(torch.float32) if scale is not None else None, - orig_dtype=torch.bfloat16, + orig_dtype=_orig_dtype, orig_shape=tuple(weight_tensor.shape), is_weight=True, ) @@ -205,6 +208,14 @@ def _load_from_state_dict( requires_grad=False, ) else: + # TODO (#2 — medium severity, low risk): this branch fires when + # is_tensorwise=True but _HAS_TENSORWISE_INT8_LAYOUT=False (ck absent). + # Result: raw int8 tensor stored with is_quantized=True, layout_type=None. + # That is a broken state — forward() will hit F.linear with raw int8 weight. + # Fix: degrade to BlockWiseINT8Layout if _HAS_INT8_LAYOUT, else set + # is_quantized=False and log a warning. Not patching now because ck is + # effectively required for tensorwise; if ck import failed the checkpoint + # is already unrunnable regardless. self.weight = torch.nn.Parameter( weight_tensor, requires_grad=False ) @@ -463,7 +474,19 @@ def forward_comfy_cast_weights(self, input): ) else: - # Default trigger for QuantizedTensor dispatch -> layout-specific handler + # Default trigger for QuantizedTensor dispatch -> layout-specific handler. + # TensorWiseINT8Layout and BlockWiseINT8Layout land here — aten.linear + # dispatch in comfy_kitchen handles the actual matmul. + # + # TODO (#3 — low-medium severity, medium risk): this else branch has no 3D + # input reshape guard, unlike all the explicit elif branches above. ComfyUI + # transformer attention layers pass [batch, seq, hidden] (3D). F.linear + # handles 3D natively so it works, but ck dispatch handlers may not. If + # tensorwise inference produces wrong shapes on 3D inputs, add the standard + # tensor_3d guard here (reshape -1,hidden before linear, reshape back after). + # Not patching now — risk of breaking currently-working layouts that fall + # through to this branch (e.g. RowWiseFP8, BlockWiseFP8 if aten dispatch + # handles them here too). out = torch.nn.functional.linear(input, weight, bias) else: diff --git a/utils/eager_quantization.py b/utils/eager_quantization.py index cb5b6bb..1b2cde2 100644 --- a/utils/eager_quantization.py +++ b/utils/eager_quantization.py @@ -35,13 +35,24 @@ def int8_linear( bias: Optional[torch.Tensor] = None, out_dtype: torch.dtype = torch.bfloat16, ) -> torch.Tensor: - """INT8 linear layer using torch.int8_mm for direct quantized matmul. - - Uses native torch.int8_mm which avoids materializing large float32 intermediates - and handles scaling more efficiently than manual int32 -> float32 conversion. - - Ported from comfy-kitchen eager backend with OOM fixes. + """INT8 linear layer. Delegates to comfy_kitchen.int8_linear (triton->eager) + when available, falls back to local torch.int8_mm chunked path. + + ck.int8_linear signature matches exactly: + (x, weight, weight_scale, bias=None, out_dtype=None) + weight: [N, K] int8, weight_scale: scalar float32, out_dtype defaults bfloat16. """ + # Prefer comfy_kitchen dispatch (triton -> eager via registry). + # ck.int8_linear routes through torch.ops.comfy_kitchen.int8_linear which + # goes through the registry with priority ["cuda", "triton", "eager"]. + # cuda backend has no int8_linear, so triton wins if available, else eager. + try: + import comfy_kitchen as ck + return ck.int8_linear(x, weight, weight_scale, bias, out_dtype) + except (ImportError, Exception): + pass + + # --- Local fallback: chunked torch.int8_mm path (OOM-safe) --- orig_shape = x.shape x_2d = x.reshape(-1, x.shape[-1]) From 0c1e432b66624d83d6cdc53ca0e5b46eb815a955 Mon Sep 17 00:00:00 2001 From: silveroxides Date: Sat, 18 Apr 2026 18:40:35 +0200 Subject: [PATCH 2/3] Fix exception --- utils/eager_quantization.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/utils/eager_quantization.py b/utils/eager_quantization.py index 1b2cde2..5935a1e 100644 --- a/utils/eager_quantization.py +++ b/utils/eager_quantization.py @@ -49,8 +49,11 @@ def int8_linear( try: import comfy_kitchen as ck return ck.int8_linear(x, weight, weight_scale, bias, out_dtype) - except (ImportError, Exception): + except ImportError: pass + except Exception as e: + import logging + logging.warning(f"ComfyUI-QuantOps: ck.int8_linear failed, falling back to local path: {e}") # --- Local fallback: chunked torch.int8_mm path (OOM-safe) --- orig_shape = x.shape From 2bbe18796835fd6c4e281eb399975ce1e1d6920e Mon Sep 17 00:00:00 2001 From: silveroxides Date: Sat, 18 Apr 2026 18:41:38 +0200 Subject: [PATCH 3/3] Unwrap QuantizedTensor if weight arrived still wrapped in int8 tensorwise fallback --- utils/eager_quantization.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/utils/eager_quantization.py b/utils/eager_quantization.py index 5935a1e..7cd3d87 100644 --- a/utils/eager_quantization.py +++ b/utils/eager_quantization.py @@ -56,6 +56,19 @@ def int8_linear( logging.warning(f"ComfyUI-QuantOps: ck.int8_linear failed, falling back to local path: {e}") # --- Local fallback: chunked torch.int8_mm path (OOM-safe) --- + # Unwrap QuantizedTensor if weight arrived still wrapped (defensive). + try: + from comfy.quant_ops import QuantizedTensor + if isinstance(weight, QuantizedTensor): + weight_scale = weight._params.scale + weight = weight._qdata + except ImportError: + pass + + # Ensure weight is raw int8 and contiguous before torch.int8_mm. + if not weight.is_contiguous(): + weight = weight.contiguous() + orig_shape = x.shape x_2d = x.reshape(-1, x.shape[-1])