[REFACTOR] major refactor for configs #478
Code Review
This pull request refactors the codebase by restructuring and modularizing model configurations, parameter utilities, MLP/MoE blocks, and attention mechanisms into cleaner directory structures, updating all corresponding imports and tests. The review feedback highlights a critical runtime crash risk in experts.py due to duplicate registration of the lm_engine::bincount custom operator and compute_bincount helper. Additionally, several __init__ methods violate PEP 484 by annotating their return type as the class itself instead of None, and potential NameError exceptions should be mitigated by importing conditionally loaded functions locally within their usage blocks.
    # TODO add support for combileable bincount in PyTorch directly
    @torch.library.custom_op("lm_engine::bincount", mutates_args={})
    def bincount(x: torch.Tensor, minlength: int) -> torch.Tensor:
        return x.bincount(minlength=minlength).to(torch.uint32)


    @bincount.register_fake
    def _(x: torch.Tensor, minlength: int) -> torch.Tensor:
        return torch.empty(minlength, device=x.device, dtype=torch.uint32)


    def compute_bincount(x: torch.Tensor, size: int, use_continuous_count: bool) -> torch.Tensor:
        if use_continuous_count:
            count = continuous_count(x, bins=size)
        else:
            count = bincount(x, minlength=size)

        return count
This file registers the custom operator lm_engine::bincount and defines the helper function compute_bincount a second time; both are already defined in lm_engine/modeling_utils/mlp_blocks/moe/utils.py. Because experts.py and utils.py are both imported when the MoE module is used, the second registration of the same operator will crash at import time (ValueError: Operator lm_engine::bincount already registered). Neither function is used in this file, so please remove both.
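The failure mode described above can be illustrated without PyTorch. The sketch below is not the torch.library API: OP_REGISTRY, register_op and define_bincount are hypothetical names for a toy registry that, like a global operator registry, rejects a second definition under the same qualified name. The exception type PyTorch itself raises may differ.

```python
from typing import Callable

# Hypothetical stand-in for a process-wide operator registry (not torch.library).
OP_REGISTRY: dict[str, Callable] = {}


def register_op(name: str) -> Callable[[Callable], Callable]:
    """Register a function under a qualified name, refusing duplicates."""

    def decorator(fn: Callable) -> Callable:
        if name in OP_REGISTRY:
            raise ValueError(f"Operator {name} already registered")
        OP_REGISTRY[name] = fn
        return fn

    return decorator


def define_bincount() -> Callable:
    """Mimics a module body that registers the operator when it is imported."""

    @register_op("lm_engine::bincount")
    def bincount(values: list[int], minlength: int) -> list[int]:
        counts = [0] * max(minlength, max(values, default=-1) + 1)
        for v in values:
            counts[v] += 1
        return counts

    return bincount


# First definition (playing the role of utils.py) succeeds.
bincount = define_bincount()

# Second definition (playing the role of experts.py) is rejected.
try:
    define_bincount()
    duplicate_error = None
except ValueError as exc:
    duplicate_error = str(exc)
```

Keeping a single definition in one module and importing it elsewhere avoids the second registration entirely, which is what the comment asks for.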
    def __init__(
        self, num_experts: int, in_features: int, out_features: int, add_bias: bool, std: float | None = None
    ) -> ParameterizedExperts:
The __init__ method should be annotated with -> None rather than the class type -> ParameterizedExperts. PEP 484 specifies that __init__ be annotated with a None return type, and static analysis tools such as mypy or pyright will flag any other annotation.
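As a standalone illustration of the annotation rule above, the following minimal sketch uses two hypothetical classes (Experts and BadInit, neither taken from the PR). It shows the -> None annotation as seen through typing.get_type_hints, and that Python itself rejects an __init__ that actually returns a value.

```python
import typing


class Experts:
    # __init__ initializes the instance in place, so its return type is None.
    def __init__(self, num_experts: int, add_bias: bool) -> None:
        self.num_experts = num_experts
        self.add_bias = add_bias


class BadInit:
    # Deliberately wrong: returning a non-None value from __init__
    # raises TypeError when the class is instantiated.
    def __init__(self):
        return self


experts = Experts(num_experts=8, add_bias=False)
init_hints = typing.get_type_hints(Experts.__init__)

try:
    BadInit()
    bad_init_error = None
except TypeError as exc:
    bad_init_error = str(exc)
```

The instance is produced by the call Experts(...), not by the return value of __init__, which is why the class name never belongs in that annotation.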
Suggested change:

    def __init__(
        self, num_experts: int, in_features: int, out_features: int, add_bias: bool, std: float | None = None
    ) -> None:

    def __init__(
        self, num_experts: int, in_features: int, out_features: int, add_bias: bool, std: float | None = None
    ) -> ColumnParallelExperts:
The __init__ method should be annotated with -> None rather than the class type -> ColumnParallelExperts. PEP 484 specifies that __init__ be annotated with a None return type, and static analysis tools such as mypy or pyright will flag any other annotation.
Suggested change:

    def __init__(
        self, num_experts: int, in_features: int, out_features: int, add_bias: bool, std: float | None = None
    ) -> None:

    def __init__(
        self, num_experts: int, in_features: int, out_features: int, add_bias: bool, std: float | None = None
    ) -> RowParallelExperts:
The __init__ method should be annotated with -> None rather than the class type -> RowParallelExperts. PEP 484 specifies that __init__ be annotated with a None return type, and static analysis tools such as mypy or pyright will flag any other annotation.
Suggested change:

    def __init__(
        self, num_experts: int, in_features: int, out_features: int, add_bias: bool, std: float | None = None
    ) -> None:

          use_depth_scaled_init: bool,
          use_padding_free_transformer: bool = False,
          sequence_parallel: bool = False,
    -     ) -> Attention:
    +     ) -> SoftmaxAttention:
The __init__ method should be annotated with -> None rather than the class type -> SoftmaxAttention. PEP 484 specifies that __init__ be annotated with a None return type, and static analysis tools such as mypy or pyright will flag any other annotation.
Suggested change:

        use_depth_scaled_init: bool,
        use_padding_free_transformer: bool = False,
        sequence_parallel: bool = False,
    ) -> None:

    def compute_bincount(x: torch.Tensor, size: int, use_continuous_count: bool) -> torch.Tensor:
        if use_continuous_count:
            count = continuous_count(x, bins=size)
        else:
            count = bincount(x, minlength=size)

        return count
To prevent a potential NameError when continuous_count was not imported because the module-level import is conditional, import continuous_count locally inside the if block and raise a clear error if the package is missing.
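The pattern recommended here can be sketched in isolation. The example below is a simplified, pure-Python illustration and not the project's code: load_optional, count_values and the module name hypothetical_kernels_pkg are assumptions introduced only for this sketch. The optional dependency is imported at the point of use, and a missing package surfaces as an explicit error instead of a NameError.

```python
import importlib
from collections import Counter


def load_optional(module_name: str, attr: str):
    """Import `attr` from an optional module, failing with a clear message."""
    try:
        module = importlib.import_module(module_name)
    except ImportError as exc:
        raise RuntimeError(
            f"{attr} is not available. Ensure {module_name} is installed."
        ) from exc
    return getattr(module, attr)


def count_values(
    values: list[int],
    size: int,
    use_optional_kernel: bool,
    optional_module: str = "hypothetical_kernels_pkg",  # assumed name, not a real package
) -> list[int]:
    if use_optional_kernel:
        # Imported only on the branch that needs it.
        kernel = load_optional(optional_module, "continuous_count")
        return kernel(values, size)
    # Fallback path that has no optional dependency.
    counts = Counter(values)
    return [counts.get(i, 0) for i in range(size)]


fallback_counts = count_values([0, 2, 2], size=4, use_optional_kernel=False)

try:
    count_values([0, 2, 2], size=4, use_optional_kernel=True)
    missing_kernel_error = None
except RuntimeError as exc:
    missing_kernel_error = str(exc)
```

Chaining with `from exc` keeps the original ImportError in the traceback, which helps when the package is installed but fails to import for another reason.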
Suggested change:

    def compute_bincount(x: torch.Tensor, size: int, use_continuous_count: bool) -> torch.Tensor:
        if use_continuous_count:
            try:
                from xma import continuous_count
            except ImportError:
                raise RuntimeError("continuous_count is not available. Ensure xma is installed.")
            count = continuous_count(x, bins=size)
        else:
            count = bincount(x, minlength=size)

        return count

    if is_kernel_allowed(Kernel.scattermoe):
        x = wait_for_ACT(x, wait_in_forward=True, wait_in_backward=False)

        x = up_projection_experts(
            x=x,
            expert_weights=weight.permute(0, 2, 1),
            k=num_experts_per_token,
            sorted_expert_idxs=sorted_expert_idxs,
            sorted_scattered_idxs=sorted_scattered_idxs,
            expert_offsets=expert_offsets,
        )
To prevent a potential NameError when up_projection_experts was not imported because the module-level import is conditional, import up_projection_experts locally inside the if block and raise a clear error if the package is missing.
Suggested change:

    if is_kernel_allowed(Kernel.scattermoe):
        try:
            from xma.layers.moe import up_projection_experts
        except ImportError:
            raise RuntimeError("scattermoe kernel is allowed but xma.layers.moe is not available.")
        x = wait_for_ACT(x, wait_in_forward=True, wait_in_backward=False)

        x = up_projection_experts(
            x=x,
            expert_weights=weight.permute(0, 2, 1),
            k=num_experts_per_token,
            sorted_expert_idxs=sorted_expert_idxs,
            sorted_scattered_idxs=sorted_scattered_idxs,
            expert_offsets=expert_offsets,
        )

    if is_kernel_allowed(Kernel.scattermoe):
        x = wait_for_ACT(x, wait_in_forward=True, wait_in_backward=False)

        x = down_projection_experts(
            x=x,
            expert_weights=weight.permute(0, 2, 1),
            k=num_experts_per_token,
            sorted_expert_idxs=sorted_expert_idxs,
            sorted_scattered_idxs=sorted_scattered_idxs,
            expert_offsets=expert_offsets,
            router_weights=router_weights,
        )
To prevent a potential NameError when down_projection_experts was not imported because the module-level import is conditional, import down_projection_experts locally inside the if block and raise a clear error if the package is missing.
Suggested change:

    if is_kernel_allowed(Kernel.scattermoe):
        try:
            from xma.layers.moe import down_projection_experts
        except ImportError:
            raise RuntimeError("scattermoe kernel is allowed but xma.layers.moe is not available.")
        x = wait_for_ACT(x, wait_in_forward=True, wait_in_backward=False)

        x = down_projection_experts(
            x=x,
            expert_weights=weight.permute(0, 2, 1),
            k=num_experts_per_token,
            sorted_expert_idxs=sorted_expert_idxs,
            sorted_scattered_idxs=sorted_scattered_idxs,
            expert_offsets=expert_offsets,
            router_weights=router_weights,
        )
No description provided.