
Fix module.to("meta") for models with plain Tensors#1688

Open
yiliu30 wants to merge 2 commits into main from fix-to-meta-assertion-error-1499

Conversation

@yiliu30
Contributor

@yiliu30 yiliu30 commented Apr 15, 2026

Fix #1499

Description

PyTorch's module.to("meta") asserts that every entry in _parameters is a torch.nn.Parameter, so third-party modules such as fla's FusedRMSNormGated (used by Qwen3-Next), which store plain torch.Tensor objects there, fail with an AssertionError during shard-writer offloading. This PR adds a _safe_to_meta() helper that moves both Parameter and plain Tensor entries to the meta device without triggering the assertion, and switches the affected .to("meta") call sites in shard_writer.py and base.py to use it.

Type of Change

  • Bug fix
  • New feature
  • Documentation update
  • Performance improvement
  • Code refactoring
  • Other (please specify):

Related Issues

Fixes or relates to #

Checklist Before Submitting

  • My code has been tested locally.
  • Documentation has been updated as needed.
  • New or updated tests are included where applicable.

…#1499)

PyTorch's module.to("meta") asserts isinstance(param, Parameter) for
every entry in _parameters.  Third-party modules like fla's
FusedRMSNormGated (used by Qwen3-Next) store plain torch.Tensor objects
there, causing AssertionError during shard_writer offloading.

Add _safe_to_meta() helper that manually moves both Parameter and plain
Tensor entries to meta device without triggering the assertion.  Replace
all 6 .to("meta") call sites in shard_writer.py and base.py.

Closes #1499

Signed-off-by: yiliu30 <yi4.liu@intel.com>
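The failure and the fix can be sketched end to end. The `_safe_to_meta()` below mirrors the helper added in this PR; `PlainTensorNorm` is a made-up stand-in for fla's `FusedRMSNormGated`, not code from either repository.

```python
# Sketch of the bug and fix. _safe_to_meta() mirrors the helper added in this
# PR; PlainTensorNorm is a made-up stand-in for fla's FusedRMSNormGated.
import torch


def _safe_to_meta(module: torch.nn.Module) -> None:
    """Move a module to the meta device without asserting that every
    _parameters entry is a torch.nn.Parameter."""
    for child in module.children():
        _safe_to_meta(child)
    for name, p in list(module._parameters.items()):
        if p is not None:
            if isinstance(p, torch.nn.Parameter):
                module._parameters[name] = torch.nn.Parameter(
                    torch.empty_like(p, device="meta"), requires_grad=p.requires_grad
                )
            else:
                module._parameters[name] = torch.empty_like(p, device="meta")
    for name, b in list(module._buffers.items()):
        if b is not None:
            module._buffers[name] = torch.empty_like(b, device="meta")


class PlainTensorNorm(torch.nn.Module):
    """Stores a plain Tensor (not a Parameter) in _parameters, like
    fla's FusedRMSNormGated."""

    def __init__(self, dim: int):
        super().__init__()
        self._parameters["weight"] = torch.ones(dim)  # plain Tensor on purpose


m = PlainTensorNorm(8)
try:
    m.to("meta")  # recent PyTorch hits `assert isinstance(param, Parameter)`
except Exception as exc:
    print(f"to('meta') failed: {type(exc).__name__}")

_safe_to_meta(m)  # moves the plain-Tensor entry without the assertion
print(m._parameters["weight"].device)  # meta
```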
Copilot AI review requested due to automatic review settings April 15, 2026 08:49
@azure-pipelines

Azure Pipelines:
Successfully started running 6 pipeline(s).
1 pipeline(s) require an authorized user to comment /azp run to run.

Contributor

Copilot AI left a comment


Pull request overview

Fixes failures when offloading modules to the meta device in cases where third-party torch.nn.Module implementations store plain torch.Tensor objects inside module._parameters (triggering an AssertionError in module.to("meta")).

Changes:

  • Added _safe_to_meta() helper to move modules (including plain tensors in _parameters) onto meta without hitting PyTorch’s assertion.
  • Replaced several .to("meta") call sites in shard-writing / quantization flows with _safe_to_meta().
  • Added a regression test covering a module that injects a plain tensor into _parameters.

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.

File                                       Description
auto_round/compressors/shard_writer.py     Introduces _safe_to_meta() and uses it during shard-writer offload-to-meta.
auto_round/compressors/base.py             Switches immediate-save/offload logic to use _safe_to_meta() instead of .to("meta").
test/test_cpu/utils/test_shard_writer.py   Adds regression coverage for modules with plain tensors in _parameters.

Comment on lines +32 to +48

    ``Parameter`` and plain ``Tensor`` entries safely.

    See https://github.com/intel/auto-round/issues/1499
    """
    for child in module.children():
        _safe_to_meta(child)
    for name, p in list(module._parameters.items()):
        if p is not None:
            if isinstance(p, torch.nn.Parameter):
                module._parameters[name] = torch.nn.Parameter(
                    torch.empty_like(p, device="meta"), requires_grad=p.requires_grad
                )
            else:
                module._parameters[name] = torch.empty_like(p, device="meta")
    for name, b in list(module._buffers.items()):
        if b is not None:
            module._buffers[name] = torch.empty_like(b, device="meta")

Copilot AI Apr 15, 2026


_safe_to_meta replaces entries in module._parameters with brand new Tensor/Parameter objects. This differs from torch.nn.Module.to/_apply (which mutates .data in-place) and can break shared-weight bindings (e.g., tied embeddings) that the codebase explicitly tries to preserve (see auto_round/utils/model.py:40-65). Consider switching to an in-place update (under torch.no_grad) that preserves object identity, e.g. update p.data / b.data to a meta empty_like, rather than reassigning dict entries.

Suggested change

-    ``Parameter`` and plain ``Tensor`` entries safely.
-    See https://github.com/intel/auto-round/issues/1499
-    """
-    for child in module.children():
-        _safe_to_meta(child)
-    for name, p in list(module._parameters.items()):
-        if p is not None:
-            if isinstance(p, torch.nn.Parameter):
-                module._parameters[name] = torch.nn.Parameter(
-                    torch.empty_like(p, device="meta"), requires_grad=p.requires_grad
-                )
-            else:
-                module._parameters[name] = torch.empty_like(p, device="meta")
-    for name, b in list(module._buffers.items()):
-        if b is not None:
-            module._buffers[name] = torch.empty_like(b, device="meta")
+    ``Parameter`` and plain ``Tensor`` entries safely while preserving object
+    identity for tied/shared weights.
+    See https://github.com/intel/auto-round/issues/1499
+    """
+    for child in module.children():
+        _safe_to_meta(child)
+    with torch.no_grad():
+        for name, p in list(module._parameters.items()):
+            if p is not None:
+                p.data = torch.empty_like(p, device="meta")
+        for name, b in list(module._buffers.items()):
+            if b is not None:
+                b.data = torch.empty_like(b, device="meta")

                    torch.empty_like(p, device="meta"), requires_grad=p.requires_grad
                )
            else:
                module._parameters[name] = torch.empty_like(p, device="meta")

Copilot AI Apr 15, 2026


For non-Parameter entries in module._parameters, torch.empty_like(...) will default to requires_grad=False (even if the original tensor had requires_grad=True). If you keep supporting plain tensors here, preserve their autograd flag (e.g., by passing requires_grad=p.requires_grad where supported or reapplying requires_grad afterward).

Suggested change

-                module._parameters[name] = torch.empty_like(p, device="meta")
+                module._parameters[name] = torch.empty_like(p, device="meta").requires_grad_(p.requires_grad)
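A quick check of the behavior described in that comment: torch.empty_like, like other factory functions, defaults to requires_grad=False regardless of the source tensor's flag.

```python
# torch.empty_like() does not copy the source tensor's requires_grad flag.
import torch

t = torch.ones(3, requires_grad=True)

clone_default = torch.empty_like(t, device="meta")
print(clone_default.requires_grad)  # False: the flag is not inherited

# Preserve it explicitly, as the suggested change does:
clone_fixed = torch.empty_like(t, device="meta").requires_grad_(t.requires_grad)
print(clone_fixed.requires_grad)  # True
```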
Comment on lines 199 to 203

            and isinstance(module, torch.nn.Module)
            and all(f"{module_path}.{k}" in self._all_saved for k in module.state_dict().keys())
        ):
-           module.to("meta")
+           _safe_to_meta(module)


Copilot AI Apr 15, 2026


This change switches one .to("meta") call site to _safe_to_meta, but ShardWriter.finalize() still calls lm_head_module.to("meta") when tie_word_embeddings is true (auto_round/compressors/shard_writer.py around line ~231). That remaining call can hit the same AssertionError for third-party modules with plain tensors in _parameters, so it should also use _safe_to_meta for consistency with the PR’s goal.

@azure-pipelines

Azure Pipelines:
Successfully started running 6 pipeline(s).
1 pipeline(s) require an authorized user to comment /azp run to run.

@yiliu30
Contributor Author

yiliu30 commented Apr 15, 2026

/azp run Unit-Test-CUDA-AutoRound

@azure-pipelines

Azure Pipelines:
Successfully started running 1 pipeline(s).

@yiliu30
Contributor Author

yiliu30 commented Apr 15, 2026

/azp run Unit-Test-CUDA-AutoRound

@azure-pipelines

Azure Pipelines:
Successfully started running 1 pipeline(s).

@yiliu30
Contributor Author

yiliu30 commented Apr 15, 2026

/azp run Unit-Test-CUDA-AutoRound

@azure-pipelines

Azure Pipelines:
Successfully started running 1 pipeline(s).


Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Bug]: Qwen3-Next-80B-A3B-Instruct-int4-mixed-AutoRound vLLM inference failed

2 participants