Skip to content

[REFACTOR] major refactor for configs#478

Merged
mayank31398 merged 26 commits into
mainfrom
cl
Jun 24, 2026
Merged

[REFACTOR] major refactor for configs#478
mayank31398 merged 26 commits into
mainfrom
cl

Conversation

@mayank31398

Copy link
Copy Markdown
Collaborator

No description provided.

Signed-off-by: Mayank Mishra <mayank31398@gmail.com>
Signed-off-by: Mayank Mishra <mayank31398@gmail.com>
Signed-off-by: Mayank Mishra <mayank31398@gmail.com>
Signed-off-by: Mayank Mishra <mayank31398@gmail.com>
Signed-off-by: Mayank Mishra <mayank31398@gmail.com>
Signed-off-by: Mayank Mishra <mayank31398@gmail.com>
Signed-off-by: Mayank Mishra <mayank31398@gmail.com>
Signed-off-by: Mayank Mishra <mayank31398@gmail.com>
Signed-off-by: Mayank Mishra <mayank31398@gmail.com>
Signed-off-by: Mayank Mishra <mayank31398@gmail.com>
Signed-off-by: Mayank Mishra <mayank31398@gmail.com>
Signed-off-by: Mayank Mishra <mayank31398@gmail.com>
Signed-off-by: Mayank Mishra <mayank31398@gmail.com>
Signed-off-by: Mayank Mishra <mayank31398@gmail.com>
Signed-off-by: Mayank Mishra <mayank31398@gmail.com>

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request refactors the codebase by restructuring and modularizing model configurations, parameter utilities, MLP/MoE blocks, and attention mechanisms into cleaner directory structures, updating all corresponding imports and tests. The review feedback highlights a critical runtime crash risk in experts.py due to duplicate registration of the lm_engine::bincount custom operator and compute_bincount helper. Additionally, several __init__ methods violate PEP 484 by annotating their return type as the class itself instead of None, and potential NameError exceptions should be mitigated by importing conditionally loaded functions locally within their usage blocks.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment on lines +29 to +46
# TODO add support for combileable bincount in PyTorch directly
@torch.library.custom_op("lm_engine::bincount", mutates_args={})
def bincount(x: torch.Tensor, minlength: int) -> torch.Tensor:
return x.bincount(minlength=minlength).to(torch.uint32)


@bincount.register_fake
def _(x: torch.Tensor, minlength: int) -> torch.Tensor:
return torch.empty(minlength, device=x.device, dtype=torch.uint32)


def compute_bincount(x: torch.Tensor, size: int, use_continuous_count: bool) -> torch.Tensor:
if use_continuous_count:
count = continuous_count(x, bins=size)
else:
count = bincount(x, minlength=size)

return count

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

This file defines a duplicate registration of the custom operator lm_engine::bincount and the helper function compute_bincount, both of which are already defined in lm_engine/modeling_utils/mlp_blocks/moe/utils.py. Since both experts.py and utils.py are imported when using the MoE module, registering the same custom operator twice will raise a runtime crash (ValueError: Operator lm_engine::bincount already registered). Additionally, these functions are completely unused in this file. Please remove them entirely.

Comment on lines +60 to +62
def __init__(
self, num_experts: int, in_features: int, out_features: int, add_bias: bool, std: float | None = None
) -> ParameterizedExperts:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The __init__ method should be annotated with -> None instead of the class type -> ParameterizedExperts. Returning anything other than None from __init__ is a type checking violation under PEP 484 and will be flagged by static analysis tools like mypy or pyright.

    def __init__(\n        self, num_experts: int, in_features: int, out_features: int, add_bias: bool, std: float | None = None\n    ) -> None:

Comment on lines +92 to +94
def __init__(
self, num_experts: int, in_features: int, out_features: int, add_bias: bool, std: float | None = None
) -> ColumnParallelExperts:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The __init__ method should be annotated with -> None instead of the class type -> ColumnParallelExperts. Returning anything other than None from __init__ is a type checking violation under PEP 484 and will be flagged by static analysis tools like mypy or pyright.

    def __init__(\n        self, num_experts: int, in_features: int, out_features: int, add_bias: bool, std: float | None = None\n    ) -> None:

Comment on lines +164 to +166
def __init__(
self, num_experts: int, in_features: int, out_features: int, add_bias: bool, std: float | None = None
) -> RowParallelExperts:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The __init__ method should be annotated with -> None instead of the class type -> RowParallelExperts. Returning anything other than None from __init__ is a type checking violation under PEP 484 and will be flagged by static analysis tools like mypy or pyright.

    def __init__(\n        self, num_experts: int, in_features: int, out_features: int, add_bias: bool, std: float | None = None\n    ) -> None:

Comment on lines 45 to +48
use_depth_scaled_init: bool,
use_padding_free_transformer: bool = False,
sequence_parallel: bool = False,
) -> Attention:
) -> SoftmaxAttention:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The __init__ method should be annotated with -> None instead of the class type -> SoftmaxAttention. Returning anything other than None from __init__ is a type checking violation under PEP 484 and will be flagged by static analysis tools like mypy or pyright.

        use_depth_scaled_init: bool,\n        use_padding_free_transformer: bool = False,\n        sequence_parallel: bool = False,\n    ) -> None:

Comment on lines +27 to +33
def compute_bincount(x: torch.Tensor, size: int, use_continuous_count: bool) -> torch.Tensor:
if use_continuous_count:
count = continuous_count(x, bins=size)
else:
count = bincount(x, minlength=size)

return count

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To prevent potential NameError exceptions if continuous_count is not successfully imported due to conditional module-level imports, import continuous_count locally inside the if block with a clear error message if the package is missing.

def compute_bincount(x: torch.Tensor, size: int, use_continuous_count: bool) -> torch.Tensor:\n    if use_continuous_count:\n        try:\n            from xma import continuous_count\n        except ImportError:\n            raise RuntimeError(\"continuous_count is not available. Ensure xma is installed.\")\n        count = continuous_count(x, bins=size)\n    else:\n        count = bincount(x, minlength=size)\n\n    return count

Comment on lines +137 to +147
if is_kernel_allowed(Kernel.scattermoe):
x = wait_for_ACT(x, wait_in_forward=True, wait_in_backward=False)

x = up_projection_experts(
x=x,
expert_weights=weight.permute(0, 2, 1),
k=num_experts_per_token,
sorted_expert_idxs=sorted_expert_idxs,
sorted_scattered_idxs=sorted_scattered_idxs,
expert_offsets=expert_offsets,
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To prevent potential NameError exceptions if up_projection_experts is not successfully imported due to conditional module-level imports, import up_projection_experts locally inside the if block with a clear error message if the package is missing.

        if is_kernel_allowed(Kernel.scattermoe):\n            try:\n                from xma.layers.moe import up_projection_experts\n            except ImportError:\n                raise RuntimeError(\"scattermoe kernel is allowed but xma.layers.moe is not available.\")\n            x = wait_for_ACT(x, wait_in_forward=True, wait_in_backward=False)\n\n            x = up_projection_experts(\n                x=x,\n                expert_weights=weight.permute(0, 2, 1),\n                k=num_experts_per_token,\n                sorted_expert_idxs=sorted_expert_idxs,\n                sorted_scattered_idxs=sorted_scattered_idxs,\n                expert_offsets=expert_offsets,\n            )

Comment on lines +212 to +223
if is_kernel_allowed(Kernel.scattermoe):
x = wait_for_ACT(x, wait_in_forward=True, wait_in_backward=False)

x = down_projection_experts(
x=x,
expert_weights=weight.permute(0, 2, 1),
k=num_experts_per_token,
sorted_expert_idxs=sorted_expert_idxs,
sorted_scattered_idxs=sorted_scattered_idxs,
expert_offsets=expert_offsets,
router_weights=router_weights,
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To prevent potential NameError exceptions if down_projection_experts is not successfully imported due to conditional module-level imports, import down_projection_experts locally inside the if block with a clear error message if the package is missing.

        if is_kernel_allowed(Kernel.scattermoe):\n            try:\n                from xma.layers.moe import down_projection_experts\n            except ImportError:\n                raise RuntimeError(\"scattermoe kernel is allowed but xma.layers.moe is not available.\")\n            x = wait_for_ACT(x, wait_in_forward=True, wait_in_backward=False)\n\n            x = down_projection_experts(\n                x=x,\n                expert_weights=weight.permute(0, 2, 1),\n                k=num_experts_per_token,\n                sorted_expert_idxs=sorted_expert_idxs,\n                sorted_scattered_idxs=sorted_scattered_idxs,\n                expert_offsets=expert_offsets,\n                router_weights=router_weights,\n            )

Signed-off-by: Mayank Mishra <mayank31398@gmail.com>
Signed-off-by: Mayank Mishra <mayank31398@gmail.com>
Signed-off-by: Mayank Mishra <mayank31398@gmail.com>
Signed-off-by: Mayank Mishra <mayank31398@gmail.com>
Signed-off-by: Mayank Mishra <mayank31398@gmail.com>
Signed-off-by: Mayank Mishra <mayank31398@gmail.com>
Signed-off-by: Mayank Mishra <mayank31398@gmail.com>
Signed-off-by: Mayank Mishra <mayank31398@gmail.com>
Signed-off-by: Mayank Mishra <mayank31398@gmail.com>
Signed-off-by: Mayank Mishra <mayank31398@gmail.com>
Signed-off-by: Mayank Mishra <mayank31398@gmail.com>
@mayank31398 mayank31398 merged commit 9d97903 into main Jun 24, 2026
2 checks passed
@mayank31398 mayank31398 deleted the cl branch June 24, 2026 19:59
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant