-
Notifications
You must be signed in to change notification settings - Fork 107
Fix module.to("meta") for models with plain Tensors
#1688
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -22,6 +22,32 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from auto_round.utils import get_lm_head_name, get_module | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _safe_to_meta(module: torch.nn.Module) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Move a module to meta device, handling plain Tensors in ``_parameters``. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| PyTorch's ``module.to("meta")`` asserts ``isinstance(param, Parameter)`` for | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| every entry in ``_parameters``. Some third-party modules (e.g. fla's | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ``FusedRMSNormGated`` used by Qwen3-Next) store plain ``torch.Tensor`` | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| objects there, causing ``AssertionError``. This helper moves both | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ``Parameter`` and plain ``Tensor`` entries safely. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| See https://github.com/intel/auto-round/issues/1499 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for child in module.children(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _safe_to_meta(child) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for name, p in list(module._parameters.items()): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if p is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if isinstance(p, torch.nn.Parameter): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| module._parameters[name] = torch.nn.Parameter( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| torch.empty_like(p, device="meta"), requires_grad=p.requires_grad | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| module._parameters[name] = torch.empty_like(p, device="meta") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for name, b in list(module._buffers.items()): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if b is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| module._buffers[name] = torch.empty_like(b, device="meta") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+32
to
+48
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ``Parameter`` and plain ``Tensor`` entries safely. | |
| See https://github.com/intel/auto-round/issues/1499 | |
| """ | |
| for child in module.children(): | |
| _safe_to_meta(child) | |
| for name, p in list(module._parameters.items()): | |
| if p is not None: | |
| if isinstance(p, torch.nn.Parameter): | |
| module._parameters[name] = torch.nn.Parameter( | |
| torch.empty_like(p, device="meta"), requires_grad=p.requires_grad | |
| ) | |
| else: | |
| module._parameters[name] = torch.empty_like(p, device="meta") | |
| for name, b in list(module._buffers.items()): | |
| if b is not None: | |
| module._buffers[name] = torch.empty_like(b, device="meta") | |
| ``Parameter`` and plain ``Tensor`` entries safely while preserving object | |
| identity for tied/shared weights. | |
| See https://github.com/intel/auto-round/issues/1499 | |
| """ | |
| for child in module.children(): | |
| _safe_to_meta(child) | |
| with torch.no_grad(): | |
| for name, p in list(module._parameters.items()): | |
| if p is not None: | |
| p.data = torch.empty_like(p, device="meta") | |
| for name, b in list(module._buffers.items()): | |
| if b is not None: | |
| b.data = torch.empty_like(b, device="meta") |
Copilot
AI
Apr 15, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This change switches one .to("meta") call site to _safe_to_meta, but ShardWriter.finalize() still calls lm_head_module.to("meta") when tie_word_embeddings is true (auto_round/compressors/shard_writer.py around line ~231). That remaining call can hit the same AssertionError for third-party modules with plain tensors in _parameters, so it should also use _safe_to_meta for consistency with the PR’s goal.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For non-Parameter entries in module.parameters, torch.empty_like(...) will default to requires_grad=False (even if the original tensor had requires_grad=True). If you keep supporting plain tensors here, preserve their autograd flag (e.g., by passing requires_grad=p.requires_grad where supported or reapplying requires_grad afterward).