Fix module.to("meta") for models with plain Tensors#1688
PyTorch's `module.to("meta")` asserts `isinstance(param, Parameter)` for every entry in `_parameters`. Third-party modules like fla's `FusedRMSNormGated` (used by Qwen3-Next) store plain `torch.Tensor` objects there, causing an `AssertionError` during shard_writer offloading.

Add a `_safe_to_meta()` helper that manually moves both `Parameter` and plain `Tensor` entries to the meta device without triggering the assertion. Replace all 6 `.to("meta")` call sites in `shard_writer.py` and `base.py`.

Closes #1499

Signed-off-by: yiliu30 <yi4.liu@intel.com>
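The failure mode can be reproduced without the full model. A minimal sketch, where `PlainTensorModule` is a hypothetical stand-in for modules like fla's `FusedRMSNormGated` that place a plain tensor in `_parameters`:

```python
import torch

class PlainTensorModule(torch.nn.Module):
    """Hypothetical stand-in for third-party modules that register
    a plain Tensor (not a Parameter) in _parameters."""

    def __init__(self):
        super().__init__()
        # Mutate the _parameters dict directly, bypassing
        # Module.__setattr__ (which would reject a plain Tensor).
        self._parameters["weight"] = torch.ones(4)

m = PlainTensorModule()
try:
    # Module._apply() cannot use in-place set_data for a meta
    # conversion, so it falls into the branch that asserts
    # isinstance(param, Parameter) -- and a plain Tensor fails it.
    m.to("meta")
    raised = False
except AssertionError:
    raised = True
```
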
Azure Pipelines: Successfully started running 6 pipeline(s). 1 pipeline(s) require an authorized user to comment /azp run to run.
Pull request overview
Fixes failures when offloading modules to the meta device in cases where third-party torch.nn.Module implementations store plain torch.Tensor objects inside module._parameters (triggering an AssertionError in module.to("meta")).
Changes:
- Added `_safe_to_meta()` helper to move modules (including plain tensors in `_parameters`) onto `meta` without hitting PyTorch's assertion.
- Replaced several `.to("meta")` call sites in shard-writing / quantization flows with `_safe_to_meta()`.
- Added a regression test covering a module that injects a plain tensor into `_parameters`.
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| `auto_round/compressors/shard_writer.py` | Introduces `_safe_to_meta()` and uses it during shard-writer offload-to-meta. |
| `auto_round/compressors/base.py` | Switches immediate-save/offload logic to use `_safe_to_meta()` instead of `.to("meta")`. |
| `test/test_cpu/utils/test_shard_writer.py` | Adds regression coverage for modules with plain tensors in `_parameters`. |
```python
    ``Parameter`` and plain ``Tensor`` entries safely.

    See https://github.com/intel/auto-round/issues/1499
    """
    for child in module.children():
        _safe_to_meta(child)
    for name, p in list(module._parameters.items()):
        if p is not None:
            if isinstance(p, torch.nn.Parameter):
                module._parameters[name] = torch.nn.Parameter(
                    torch.empty_like(p, device="meta"), requires_grad=p.requires_grad
                )
            else:
                module._parameters[name] = torch.empty_like(p, device="meta")
    for name, b in list(module._buffers.items()):
        if b is not None:
            module._buffers[name] = torch.empty_like(b, device="meta")
```
`_safe_to_meta` replaces entries in `module._parameters` with brand-new `Tensor`/`Parameter` objects. This differs from `torch.nn.Module.to`/`_apply` (which mutates `.data` in place) and can break shared-weight bindings (e.g., tied embeddings) that the codebase explicitly tries to preserve (see `auto_round/utils/model.py:40-65`). Consider switching to an in-place update (under `torch.no_grad`) that preserves object identity, e.g. update `p.data` / `b.data` to a meta `empty_like`, rather than reassigning dict entries.
Suggested change:

```diff
-    ``Parameter`` and plain ``Tensor`` entries safely.
+    ``Parameter`` and plain ``Tensor`` entries safely while preserving object
+    identity for tied/shared weights.

     See https://github.com/intel/auto-round/issues/1499
     """
     for child in module.children():
         _safe_to_meta(child)
-    for name, p in list(module._parameters.items()):
-        if p is not None:
-            if isinstance(p, torch.nn.Parameter):
-                module._parameters[name] = torch.nn.Parameter(
-                    torch.empty_like(p, device="meta"), requires_grad=p.requires_grad
-                )
-            else:
-                module._parameters[name] = torch.empty_like(p, device="meta")
-    for name, b in list(module._buffers.items()):
-        if b is not None:
-            module._buffers[name] = torch.empty_like(b, device="meta")
+    with torch.no_grad():
+        for name, p in list(module._parameters.items()):
+            if p is not None:
+                p.data = torch.empty_like(p, device="meta")
+        for name, b in list(module._buffers.items()):
+            if b is not None:
+                b.data = torch.empty_like(b, device="meta")
```
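The identity concern is easy to demonstrate with a tiny tied-weight sketch (the modules below are illustrative, not from the codebase): reassigning the `_parameters` dict entry, as the original helper does, replaces only one side of the tie while the other module silently keeps the old object.

```python
import torch

emb = torch.nn.Embedding(8, 4)
head = torch.nn.Linear(4, 8, bias=False)
head.weight = emb.weight          # tie: both modules share one Parameter object
assert head.weight is emb.weight

# What the original helper does: swap in a brand-new Parameter object.
emb._parameters["weight"] = torch.nn.Parameter(
    torch.empty_like(emb.weight, device="meta"),
    requires_grad=emb.weight.requires_grad,
)

# The tie is broken: head still holds the old CPU tensor.
assert head.weight is not emb.weight
assert emb.weight.device.type == "meta"
assert head.weight.device.type == "cpu"
```
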
```python
                module._parameters[name] = torch.nn.Parameter(
                    torch.empty_like(p, device="meta"), requires_grad=p.requires_grad
                )
            else:
                module._parameters[name] = torch.empty_like(p, device="meta")
```
For non-`Parameter` entries in `module._parameters`, `torch.empty_like(...)` defaults to `requires_grad=False` (even if the original tensor had `requires_grad=True`). If you keep supporting plain tensors here, preserve their autograd flag (e.g., by passing `requires_grad=p.requires_grad` where supported, or reapplying `requires_grad` afterward).
Suggested change:

```diff
-                module._parameters[name] = torch.empty_like(p, device="meta")
+                module._parameters[name] = torch.empty_like(p, device="meta").requires_grad_(p.requires_grad)
```
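The flag drop is easy to verify: `torch.empty_like` does not propagate `requires_grad` from its input, and reapplying it afterwards restores the original behavior.

```python
import torch

t = torch.ones(3, requires_grad=True)

# empty_like creates a fresh leaf tensor with requires_grad=False
# by default, regardless of the source tensor's flag.
dropped = torch.empty_like(t, device="meta")
assert dropped.requires_grad is False

# Reapplying the flag, as the suggestion does, preserves it.
kept = torch.empty_like(t, device="meta").requires_grad_(t.requires_grad)
assert kept.requires_grad is True
```
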
```diff
                 and isinstance(module, torch.nn.Module)
                 and all(f"{module_path}.{k}" in self._all_saved for k in module.state_dict().keys())
             ):
-                module.to("meta")
+                _safe_to_meta(module)
```
|
|
This change switches one `.to("meta")` call site to `_safe_to_meta`, but `ShardWriter.finalize()` still calls `lm_head_module.to("meta")` when `tie_word_embeddings` is true (`auto_round/compressors/shard_writer.py` around line ~231). That remaining call can hit the same `AssertionError` for third-party modules with plain tensors in `_parameters`, so it should also use `_safe_to_meta` for consistency with the PR's goal.
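A sketch of the remaining fix this comment asks for (the exact line position is approximate; the `lm_head_module` name is taken from the comment above, not from the diff):

```diff
-            lm_head_module.to("meta")
+            _safe_to_meta(lm_head_module)
```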
|
/azp run Unit-Test-CUDA-AutoRound

Azure Pipelines: Successfully started running 1 pipeline(s).
Fix #1499