Feat: JAX integration via XLA FFI custom calls #226
Conversation
If it's convenient for you, could you please share your testing steps, starting from the Docker build? Maybe we can incorporate it into the CI testing process.
Pull request overview
Adds a Torch-free JAX integration path for MORI by exposing EP dispatch/combine via XLA FFI custom calls, and updates build/Python packaging to optionally include these bindings.
Changes:
- Introduce XLA FFI pybind bindings (`mori_ep` instantiate/execute handlers + type info) gated behind `BUILD_XLA_FFI_OPS`.
- Add a JAX Python module (`mori.jax`) that registers the FFI target and provides a JAX-facing `EpDispatchCombineOp`.
- Add config packing/unpacking helpers on `EpDispatchCombineConfig` for passing the config through XLA attributes.
Reviewed changes
Copilot reviewed 19 out of 19 changed files in this pull request and generated 7 comments.
| File | Description |
|---|---|
| CMakeLists.txt | Adds BUILD_XLA_FFI_OPS option and build flags for the XLA FFI path. |
| src/pybind/CMakeLists.txt | Conditionally builds pybind_xla_ffi_ops.cpp and includes vendored XLA FFI headers. |
| src/pybind/mori.cpp | Registers XLA FFI ops into the main pybind module when enabled. |
| src/pybind/mori.hpp | Declares RegisterXLAFFIOps behind BUILD_XLA_FFI_OPS. |
| src/pybind/pybind_ops.cpp | Exposes config packing + max token helpers to Python. |
| src/pybind/pybind_xla_ffi_ops.cpp | Implements XLA FFI state/handlers and a global handle cache. |
| include/mori/ops/dispatch_combine/dispatch_combine.hpp | Declares config pack/unpack APIs. |
| src/ops/dispatch_combine/dispatch_combine.cpp | Implements config pack/unpack. |
| python/mori/__init__.py | Adds jax as a lazy-imported submodule. |
| python/mori/jax/__init__.py | Registers the mori_ep FFI type/target with JAX on import. |
| python/mori/jax/ops.py | Adds JAX wrappers for shmem init/finalize and EP dispatch/combine calls. |
| python/mori/ops/__init__.py | Gates torch-based ops imports behind a torch import check. |
| python/mori/io/__init__.py | Gates torch-dependent IO engine import behind a torch import check. |
| tests/cpp/CMakeLists.txt | Restricts an MPI-dependent test target to WITH_MPI. |
| examples/ops/dispatch_combine/test_dispatch_combine_jax.py | Adds a JAX example script for dispatch/combine. |
| 3rdparty/xla_ffi/** | Vendors XLA FFI headers and a version note for building the custom call interface. |
Pull request overview
Copilot reviewed 19 out of 19 changed files in this pull request and generated 8 comments.
jhchouuu
left a comment
LGTM
And we would like to add a CI test to ensure that future work will not break the integrated use of JAX. So, could you please provide the test procedure?
Finally, thanks again for your excellent work and contribution to mori! @pemeliya @i-chaochen
Yes, I think I can move my dispatch_combine_test_jax.py into the test folder and adapt it for pytest execution. But this also needs a Dockerfile to build jax 0.9.1, so I need some time to get it working.
Thanks. It would be best if there were a Dockerfile, or a Docker image that can be pulled.
Force-pushed d676c22 to a620001; squashed commits:
- …mbineHandles moved to using int32 array attr to encode epdispatchcombineconfig
- using one FFI call to cover all MORI EP functions
- getting rid of pjrt registration
- major update, first structure for jax interface
- almost getting dispatch/combine ops working
- getting the jax dispatch-combine test almost working
- update
- almost working dispatch/combine test
- fixing combine op update
- added FFI headers from 0.9.0 branch
- build script update
- update test fixes
- build update
- 0.9.1 xla update
- fixes after 0.9.1 xla update
- fixing cmake
- added runtime check
- improved run script
- removed test files
- updated formatting
- some fixes
- fixes in xla ffi ops
Pull request overview
Copilot reviewed 19 out of 19 changed files in this pull request and generated 6 comments.
```cmake
option(BUILD_PYBINDS "Whether to build mori python bindings" ON)
option(BUILD_XLA_FFI_OPS "Whether to build mori xla ffi python bindings" OFF)
option(BUILD_UMBP "Whether to build UMBP (Unified Memory/Bandwidth Pool)" OFF)
```
PR description says enabling BUILD_XLA_FFI_OPS builds a separate libmori_xla_ffi_ops.so, but the current CMake changes only add pybind_xla_ffi_ops.cpp into the existing mori_pybinds target (producing libmori_pybinds.so). Either update the PR description to match the actual build artifact, or adjust the build to produce the documented standalone shared library.
Pull request overview
Copilot reviewed 18 out of 18 changed files in this pull request and generated 5 comments.
Pull request overview
Copilot reviewed 19 out of 19 changed files in this pull request and generated 4 comments.
```cmake
if(BUILD_XLA_FFI_OPS)
  if(NOT BUILD_OPS_DEVICE)
    message(FATAL_ERROR "BUILD_XLA_FFI_OPS=ON requires BUILD_OPS_DEVICE=ON")
  endif()
  add_compile_definitions(BUILD_XLA_FFI_OPS)
  add_compile_options(-Wno-deprecated-literal-operator)
endif()
```
The PR description mentions building a separate libmori_xla_ffi_ops.so, but with the current CMake changes BUILD_XLA_FFI_OPS just adds pybind_xla_ffi_ops.cpp into the existing mori_pybinds shared library (producing libmori_pybinds.so). Either update the PR description/docs to match, or adjust the build to actually emit a dedicated libmori_xla_ffi_ops.so target if that separation is required.
```python
_LAZY_SUBMODULES = {
    "cpp",
    "ops",
    "jax",
    "shmem",
```
Adding jax to the lazy submodule list does not currently make the JAX path torch-free: mori.jax imports mori.cpp, and python/mori/cpp/__init__.py still attempts import torch on module import (and will import it if installed). If the goal is that import mori.jax... never triggers a torch import, consider removing/moving the torch import from mori.cpp’s package init (or splitting out a minimal torch-free binding loader used by JAX).
```python
import jax  # noqa: F401
from mori.cpp import mori_ep_handler, mori_ep_type_info, preload_kernels
from .ops import *
```
mori.jax imports symbols from mori.cpp at import time, which (via python/mori/cpp/__init__.py) attempts to import torch. This contradicts the stated goal of a torch-free JAX path when torch is installed. To keep JAX usage independent of torch, load the pybind module without importing torch (e.g., refactor mori.cpp’s init to only import torch in torch-specific code paths, or provide a separate lightweight loader module for JAX).
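For reference, a minimal sketch of the module-level `__getattr__` lazy-import pattern (PEP 562) that keeps a heavy dependency like torch out of `import mori.jax`. The package is built at runtime purely for illustration; `demo_pkg` and its submodule handling are hypothetical stand-ins, not the actual mori package layout:

```python
import sys
import types

# Stand-in package built at runtime to illustrate the pattern; in the real
# package this logic would live in mori/__init__.py and call
# importlib.import_module on first access.
pkg = types.ModuleType("demo_pkg")
_LAZY_SUBMODULES = {"jax", "ops"}
_loaded = {}

def _lazy_getattr(name):
    # Per PEP 562, a __getattr__ in the module's dict is called when normal
    # attribute lookup on the module fails.
    if name in _LAZY_SUBMODULES:
        # Only materialize the submodule on first access.
        return _loaded.setdefault(name, types.ModuleType(f"demo_pkg.{name}"))
    raise AttributeError(f"module 'demo_pkg' has no attribute {name!r}")

pkg.__getattr__ = _lazy_getattr
sys.modules["demo_pkg"] = pkg

import demo_pkg  # touches no submodule, so nothing heavy is imported yet
lazy_jax = demo_pkg.jax  # submodule created only on this first access
print(lazy_jax.__name__)
```

The key point for this review thread is that laziness at the top level is not enough: if `mori.jax` itself eagerly imports a module whose package init imports torch, the torch import still fires on first access.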
```cpp
static constexpr int32_t EP_CONFIG_I32_VERSION = 1;

std::vector<int32_t> EpDispatchCombineConfig::ToPackedI32Array() const {
  return {
      EP_CONFIG_I32_VERSION,
      rank,
      worldSize,
      hiddenDim,
      scaleDim,
      scaleTypeSize,
      maxTokenTypeSize,
      maxNumInpTokenPerRank,
      numExpertPerRank,
      numExpertPerToken,
      warpNumPerBlock,
      blockNum,
      static_cast<int32_t>(useExternalInpBuffer),
      static_cast<int32_t>(kernelType),
      gpuPerNode,
      rdmaBlockNum,
      numQpPerPe,
      static_cast<int32_t>(quantType),
      static_cast<int32_t>(enableSdma),
  };
}

EpDispatchCombineConfig EpDispatchCombineConfig::FromPackedI32Array(const int32_t* packed,
                                                                    size_t size) {
  // Runtime check to ensure the size of the packed array is correct
  if (size - 1 != kPackedI32Len) {
    throw std::runtime_error("EpDispatchCombineConfig i32 decode failed: invalid size");
  }
  if (packed == nullptr || packed[0] != EP_CONFIG_I32_VERSION) {
    throw std::runtime_error("EpDispatchCombineConfig i32 decode failed: unsupported version");
  }

  EpDispatchCombineConfig cfg;
  cfg.rank = packed[1];
  cfg.worldSize = packed[2];
  cfg.hiddenDim = packed[3];
  cfg.scaleDim = packed[4];
  cfg.scaleTypeSize = packed[5];
  cfg.maxTokenTypeSize = packed[6];
  cfg.maxNumInpTokenPerRank = packed[7];
  cfg.numExpertPerRank = packed[8];
  cfg.numExpertPerToken = packed[9];
  cfg.warpNumPerBlock = packed[10];
  cfg.blockNum = packed[11];
  cfg.useExternalInpBuffer = (packed[12] != 0);
  cfg.kernelType = static_cast<KernelType>(packed[13]);
  cfg.gpuPerNode = packed[14];
  cfg.rdmaBlockNum = packed[15];
  cfg.numQpPerPe = packed[16];
  cfg.quantType = static_cast<QuantType>(packed[17]);
  cfg.enableSdma = (packed[18] != 0);
  return cfg;
}
```
ToPackedI32Array/FromPackedI32Array introduce a new serialized config format that is now relied on by the XLA FFI path (ep_config attribute). There are existing pytest-based dispatch/combine tests, but none appear to cover pack/unpack round-tripping, version mismatch, or size mismatch. Adding a small unit test for this serialization would help prevent silent ABI drift between Python and C++ (and make future config changes safer).
|
/fix-precommit
✅ Auto-fixed and committed. All issues resolved automatically; no manual action needed. 🎉
Motivation
Add JAX support to MORI via XLA FFI custom calls, enabling JAX users to use mori kernels and shmem without any torch dependency.
Technical Details
XLA FFI Framework
- EP `dispatch/combine` exposed as XLA FFI custom calls
- `BUILD_XLA_FFI_OPS` cmake option (default OFF), builds `libmori_xla_ffi_ops.so`

Torch-Free JAX Path
- Submodules in `mori/__init__.py` are now lazy-imported via `__getattr__`, so `mori.jax` imports do not trigger `import torch`