Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 7 additions & 12 deletions auto_round/compressors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@

from auto_round import envs
from auto_round.auto_scheme.gen_auto_scheme import AutoScheme
from auto_round.compressors.shard_writer import shard_writer
from auto_round.compressors.shard_writer import _safe_to_meta, shard_writer
from auto_round.compressors.utils import (
IndexSampler,
block_forward,
Expand Down Expand Up @@ -62,12 +62,7 @@
from auto_round.logger import logger
from auto_round.modeling.fused_moe.replace_modules import materialize_model_, safe_to_cpu_
from auto_round.modeling.unfused_moe import apply_model_monkey_patches
from auto_round.schemes import (
QuantizationScheme,
_handle_special_schemes,
get_gguf_scheme,
preset_name_to_scheme,
)
from auto_round.schemes import QuantizationScheme, _handle_special_schemes, get_gguf_scheme, preset_name_to_scheme
from auto_round.sign_sgd import SignSGD
from auto_round.special_model_handler import get_predefined_ignore_layers, update_module
from auto_round.utils import (
Expand Down Expand Up @@ -1319,7 +1314,7 @@ def _quantize_layer_via_rtn(self, name: str, dtype: torch.dtype = None, to_cpu=T
# Free RAM immediately: the data is now in the shard-writer buffer
# (and will be flushed to disk). Keeping it also in the model tree
# causes linear RAM growth for large models.
m.to("meta")
_safe_to_meta(m)

def _immediate_pack(self, name: str):
if not self.is_immediate_packing:
Expand Down Expand Up @@ -1444,8 +1439,8 @@ def _quantize_rtn(self) -> tuple[torch.nn.Module, dict[str, Any]]:
if self.is_immediate_saving:
shard_writer(self, name=m.global_name)
copied_m = get_module(self.model, m.global_name)
copied_m.to("meta")
m.to("meta")
_safe_to_meta(copied_m)
_safe_to_meta(m)
# Move remaining GPU tensors to CPU; offload to disk if low_cpu_mem_usage.
# This mirrors _quantize_via_rtn_blockwise's post-block cleanup.
if not self.is_immediate_saving:
Expand All @@ -1454,7 +1449,7 @@ def _quantize_rtn(self) -> tuple[torch.nn.Module, dict[str, Any]]:
# Save once at block scope to capture tensors that are not saved
# in per-layer branch (e.g., custom module-level params/buffers).
shard_writer(self, name=block_name)
block.to("meta")
_safe_to_meta(block)
if self.low_cpu_mem_usage and not self.is_immediate_saving:
self._offloader(self.model, block_name)
clear_memory(device_list=self.device_list)
Expand Down Expand Up @@ -1494,7 +1489,7 @@ def _quantize_rtn(self) -> tuple[torch.nn.Module, dict[str, Any]]:
):
set_module(self.model, n, copy.deepcopy(m))
shard_writer(self, name=n)
m.to("meta")
_safe_to_meta(m)

# Convert remaining fp8
convert_module_to_hp_if_necessary(self.model, self.amp_dtype, self.device)
Expand Down
28 changes: 27 additions & 1 deletion auto_round/compressors/shard_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,32 @@
from auto_round.utils import get_lm_head_name, get_module


def _safe_to_meta(module: torch.nn.Module) -> None:
"""Move a module to meta device, handling plain Tensors in ``_parameters``.

PyTorch's ``module.to("meta")`` asserts ``isinstance(param, Parameter)`` for
every entry in ``_parameters``. Some third-party modules (e.g. fla's
``FusedRMSNormGated`` used by Qwen3-Next) store plain ``torch.Tensor``
objects there, causing ``AssertionError``. This helper moves both
``Parameter`` and plain ``Tensor`` entries safely.

See https://github.com/intel/auto-round/issues/1499
"""
for child in module.children():
_safe_to_meta(child)
for name, p in list(module._parameters.items()):
if p is not None:
if isinstance(p, torch.nn.Parameter):
module._parameters[name] = torch.nn.Parameter(
torch.empty_like(p, device="meta"), requires_grad=p.requires_grad
)
else:
module._parameters[name] = torch.empty_like(p, device="meta")
Copy link

Copilot AI Apr 15, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For non-Parameter entries in module.parameters, torch.empty_like(...) will default to requires_grad=False (even if the original tensor had requires_grad=True). If you keep supporting plain tensors here, preserve their autograd flag (e.g., by passing requires_grad=p.requires_grad where supported or reapplying requires_grad afterward).

Suggested change
module._parameters[name] = torch.empty_like(p, device="meta")
module._parameters[name] = torch.empty_like(p, device="meta").requires_grad_(p.requires_grad)

Copilot uses AI. Check for mistakes.
for name, b in list(module._buffers.items()):
if b is not None:
module._buffers[name] = torch.empty_like(b, device="meta")
Comment on lines +32 to +48
Copy link

Copilot AI Apr 15, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_safe_to_meta replaces entries in module._parameters with brand new Tensor/Parameter objects. This differs from torch.nn.Module.to/_apply (which mutates .data in-place) and can break shared-weight bindings (e.g., tied embeddings) that the codebase explicitly tries to preserve (see auto_round/utils/model.py:40-65). Consider switching to an in-place update (under torch.no_grad) that preserves object identity, e.g. update p.data / b.data to a meta empty_like, rather than reassigning dict entries.

Suggested change
``Parameter`` and plain ``Tensor`` entries safely.
See https://github.com/intel/auto-round/issues/1499
"""
for child in module.children():
_safe_to_meta(child)
for name, p in list(module._parameters.items()):
if p is not None:
if isinstance(p, torch.nn.Parameter):
module._parameters[name] = torch.nn.Parameter(
torch.empty_like(p, device="meta"), requires_grad=p.requires_grad
)
else:
module._parameters[name] = torch.empty_like(p, device="meta")
for name, b in list(module._buffers.items()):
if b is not None:
module._buffers[name] = torch.empty_like(b, device="meta")
``Parameter`` and plain ``Tensor`` entries safely while preserving object
identity for tied/shared weights.
See https://github.com/intel/auto-round/issues/1499
"""
for child in module.children():
_safe_to_meta(child)
with torch.no_grad():
for name, p in list(module._parameters.items()):
if p is not None:
p.data = torch.empty_like(p, device="meta")
for name, b in list(module._buffers.items()):
if b is not None:
b.data = torch.empty_like(b, device="meta")

Copilot uses AI. Check for mistakes.


class ShardWriter:
"""
Handles shard-saving of model parameters to disk with memory management.
Expand Down Expand Up @@ -195,7 +221,7 @@ def _offload_to_meta(self, saved_params):
and isinstance(module, torch.nn.Module)
and all(f"{module_path}.{k}" in self._all_saved for k in module.state_dict().keys())
):
module.to("meta")
_safe_to_meta(module)

Comment on lines 221 to 225
Copy link

Copilot AI Apr 15, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change switches one .to("meta") call site to _safe_to_meta, but ShardWriter.finalize() still calls lm_head_module.to("meta") when tie_word_embeddings is true (auto_round/compressors/shard_writer.py around line ~231). That remaining call can hit the same AssertionError for third-party modules with plain tensors in _parameters, so it should also use _safe_to_meta for consistency with the PR’s goal.

Copilot uses AI. Check for mistakes.
def finalize(self):
"""Saves remaining weights, renames files, and writes the index JSON."""
Expand Down
32 changes: 32 additions & 0 deletions test/test_cpu/utils/test_shard_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,35 @@ def test_finalize_skips_lm_head_when_tie_word_embeddings_true(tmp_path):
assert "transformer_blocks.0.linear.weight" in saved_tensors
assert "lm_head.weight" not in saved_tensors, "lm_head must be skipped when tied"
assert model.lm_head.weight.device.type == "meta"


class _BlockWithPlainTensor(torch.nn.Module):
"""Simulates third-party modules that store plain Tensors in _parameters."""

def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(4, 4)
# Inject a plain Tensor (not nn.Parameter) into _parameters,
# mimicking what some fla/CUDA extension modules do.
self._parameters["plain_weight"] = torch.zeros(4)


def test_finalize_handles_module_with_plain_tensor_in_parameters(tmp_path):
    """Regression test for issue #1499: _offload_to_meta must tolerate a
    module that carries a plain torch.Tensor inside _parameters."""
    model = _DiffusionStyleModel()
    # Swap in a block that injects a raw Tensor into _parameters.
    model.transformer_blocks[0] = _BlockWithPlainTensor()
    writer = ShardWriter(_RounderStub(model, str(tmp_path)))

    writer.save_module(model.transformer_blocks[0], "transformer_blocks.0")
    # Previously raised AssertionError inside Module.to("meta").
    writer.finalize()

    state = torch.load(os.path.join(tmp_path, "model.bin"), map_location="cpu")
    assert "transformer_blocks.0.linear.weight" in state
    assert "transformer_blocks.0.plain_weight" in state
    # Offloading must have moved the whole block onto the meta device.
    offloaded = model.transformer_blocks[0]
    assert offloaded.linear.weight.device.type == "meta"
    assert offloaded._parameters["plain_weight"].device.type == "meta"
Loading