diff --git a/aiter/ops/flydsl/moe_kernels.py b/aiter/ops/flydsl/moe_kernels.py index f1c12cd6b5..f5aa2627d1 100644 --- a/aiter/ops/flydsl/moe_kernels.py +++ b/aiter/ops/flydsl/moe_kernels.py @@ -625,11 +625,18 @@ def _s2_args_std( def _run_compiled(exe, args): - """Call the JitFunction with the given args. - JitFunction.__call__ handles compilation caching internally. + """First call: JIT-compile via flyc.compile (compiles + executes + returns CompiledFunction). + Subsequent calls: fast dispatch via the cached CompiledFunction. """ + import flydsl.compiler as flyc + + cf = getattr(exe, "_cf", None) + if cf is not None: + cf(*args) + return try: - exe(*args) + cf = flyc.compile(exe, *args) + exe._cf = cf except Exception: # JitFunction.__call__ leaks ir.Context on compilation failure, # causing all subsequent JitFunction calls to take a wrong code path