From 301cabb1724c86c071da3e551b4b696efff9393b Mon Sep 17 00:00:00 2001 From: xudoyuan Date: Mon, 8 Jun 2026 09:37:53 +0000 Subject: [PATCH] Use flyc.compile for MOE kernel fast dispatch MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace exe(*args) with flyc.compile-based caching: first call compiles and executes via flyc.compile(), subsequent calls use the cached CompiledFunction for ~5µs dispatch instead of full JitFunction overhead. Co-Authored-By: Claude Opus 4 --- aiter/ops/flydsl/moe_kernels.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/aiter/ops/flydsl/moe_kernels.py b/aiter/ops/flydsl/moe_kernels.py index f1c12cd6b5..f5aa2627d1 100644 --- a/aiter/ops/flydsl/moe_kernels.py +++ b/aiter/ops/flydsl/moe_kernels.py @@ -625,11 +625,18 @@ def _s2_args_std( def _run_compiled(exe, args): - """Call the JitFunction with the given args. - JitFunction.__call__ handles compilation caching internally. + """First call: JIT-compile via flyc.compile (compiles + executes + returns CompiledFunction). + Subsequent calls: fast dispatch via the cached CompiledFunction. """ + import flydsl.compiler as flyc + + cf = getattr(exe, "_cf", None) + if cf is not None: + cf(*args) + return try: - exe(*args) + cf = flyc.compile(exe, *args) + exe._cf = cf except Exception: # JitFunction.__call__ leaks ir.Context on compilation failure, # causing all subsequent JitFunction calls to take a wrong code path