Feat: JAX integration via XLA FFI custom calls #226

Merged
jhchouuu merged 12 commits into main from pemeliya/jax_integration_upstream
Mar 31, 2026

@pemeliya
Contributor

Motivation

Add JAX support to MORI via XLA FFI custom calls, enabling JAX users to use mori kernels and shmem without any torch dependency.

Technical Details

XLA FFI Framework

  • FFI handlers for dispatch / combine
  • New BUILD_XLA_FFI_OPS cmake option (default OFF), builds libmori_xla_ffi_ops.so

Torch-Free JAX Path

  • All submodules in mori/__init__.py are now lazy-imported via __getattr__, so import mori.jax.* does not trigger import torch
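For readers unfamiliar with the mechanism, here is a minimal sketch of lazy submodule loading through a module-level `__getattr__` (PEP 562). It is not MORI's actual `mori/__init__.py`: `make_lazy_package`, `is_loaded`, and `demo_pkg` are hypothetical names, and the standard-library modules `colorsys` and `wave` stand in for real submodules.

```python
import importlib
import sys
import types


def make_lazy_package(name, lazy_map):
    """Build a module whose listed attributes are imported on first access."""
    pkg = types.ModuleType(name)

    def _lazy_getattr(attr):
        # Called only when normal attribute lookup on the module fails.
        target = lazy_map.get(attr)
        if target is None:
            raise AttributeError(f"module {name!r} has no attribute {attr!r}")
        module = importlib.import_module(target)
        setattr(pkg, attr, module)  # cache, so later lookups skip this hook
        return module

    # A "__getattr__" entry in the module namespace is what PEP 562 looks for.
    pkg.__getattr__ = _lazy_getattr
    return pkg


def is_loaded(module_name):
    """Report whether a module has already been imported in this process."""
    return module_name in sys.modules


# Hypothetical package: nothing in the map is imported until it is accessed.
demo = make_lazy_package("demo_pkg", {"colors": "colorsys", "waves": "wave"})
```

Accessing `demo.colors` imports `colorsys` at that moment and caches it on the package; names outside the map raise `AttributeError` as usual.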

@jhchouuu
Collaborator

If it's convenient for you, could you please share your testing steps, starting from the Docker build? Maybe we can incorporate it into the CI testing process.


Copilot AI left a comment

Pull request overview

Adds a Torch-free JAX integration path for MORI by exposing EP dispatch/combine via XLA FFI custom calls, and updates build/Python packaging to optionally include these bindings.

Changes:

  • Introduce XLA FFI pybind bindings (mori_ep instantiate/execute handlers + type info) gated behind BUILD_XLA_FFI_OPS.
  • Add JAX Python module (mori.jax) that registers the FFI target and provides a JAX-facing EpDispatchCombineOp.
  • Add config packing/unpacking helpers on EpDispatchCombineConfig for passing config through XLA attributes.

Reviewed changes

Copilot reviewed 19 out of 19 changed files in this pull request and generated 7 comments.

| File | Description |
| --- | --- |
| CMakeLists.txt | Adds BUILD_XLA_FFI_OPS option and build flags for the XLA FFI path. |
| src/pybind/CMakeLists.txt | Conditionally builds pybind_xla_ffi_ops.cpp and includes vendored XLA FFI headers. |
| src/pybind/mori.cpp | Registers XLA FFI ops into the main pybind module when enabled. |
| src/pybind/mori.hpp | Declares RegisterXLAFFIOps behind BUILD_XLA_FFI_OPS. |
| src/pybind/pybind_ops.cpp | Exposes config packing + max token helpers to Python. |
| src/pybind/pybind_xla_ffi_ops.cpp | Implements XLA FFI state/handlers and a global handle cache. |
| include/mori/ops/dispatch_combine/dispatch_combine.hpp | Declares config pack/unpack APIs. |
| src/ops/dispatch_combine/dispatch_combine.cpp | Implements config pack/unpack. |
| python/mori/__init__.py | Adds jax as a lazy-imported submodule. |
| python/mori/jax/__init__.py | Registers the mori_ep FFI type/target with JAX on import. |
| python/mori/jax/ops.py | Adds JAX wrappers for shmem init/finalize and EP dispatch/combine calls. |
| python/mori/ops/__init__.py | Gates torch-based ops imports behind a torch import check. |
| python/mori/io/__init__.py | Gates torch-dependent IO engine import behind a torch import check. |
| tests/cpp/CMakeLists.txt | Restricts an MPI-dependent test target to WITH_MPI. |
| examples/ops/dispatch_combine/test_dispatch_combine_jax.py | Adds a JAX example script for dispatch/combine. |
| 3rdparty/xla_ffi/** | Vendors XLA FFI headers and a version note for building the custom call interface. |


Copilot AI left a comment

Pull request overview

Copilot reviewed 19 out of 19 changed files in this pull request and generated 8 comments.


Collaborator

@jhchouuu jhchouuu left a comment

LGTM
We would also like to add a CI test to ensure that future work does not break the JAX integration. Could you please provide the test procedure?
Finally, thanks again for your excellent work and contribution to mori! @pemeliya @i-chaochen

@pemeliya
Contributor Author

Yes, I think I can move my dispatch_combine_test_jax.py into the test folder and adapt it for pytest execution. This also needs a Dockerfile that builds jax 0.9.1, so I need some time to get it working.

@jhchouuu
Collaborator

Thanks. It would be best to have a Dockerfile, or a Docker image that can be pulled.

…mbineHandles

moved to using int32 array attr to encode epdispatchcombineconfig

using one FFI call to cover all MORI EP functions

getting rid of pjrt registration

major update, first structure for jax interface

almost getting dispatch/combine ops working

getting the jax dispatch-combine test almost working

update

almost working dispatch/combine test

fixing combine op

update

added FFI headers from 0.9.0 branch

build script update

update

test fixes

build update

0.9.1 xla update

fixes after 0.9.1 xla update

fixing cmake

added runtime check

improved run script

removed test files

updated formatting

some fixes

fixes in xla ffi ops
@pemeliya pemeliya force-pushed the pemeliya/jax_integration_upstream branch from d676c22 to a620001 Compare March 27, 2026 13:29

Copilot AI left a comment

Pull request overview

Copilot reviewed 19 out of 19 changed files in this pull request and generated 6 comments.



Comment on lines 33 to 35
option(BUILD_PYBINDS "Whether to build mori python bindings" ON)
option(BUILD_XLA_FFI_OPS "Whether to build mori xla ffi python bindings" OFF)
option(BUILD_UMBP "Whether to build UMBP (Unified Memory/Bandwidth Pool)" OFF)

Copilot AI Mar 30, 2026

PR description says enabling BUILD_XLA_FFI_OPS builds a separate libmori_xla_ffi_ops.so, but the current CMake changes only add pybind_xla_ffi_ops.cpp into the existing mori_pybinds target (producing libmori_pybinds.so). Either update the PR description to match the actual build artifact, or adjust the build to produce the documented standalone shared library.


Copilot AI left a comment

Pull request overview

Copilot reviewed 18 out of 18 changed files in this pull request and generated 5 comments.



Copilot AI left a comment

Pull request overview

Copilot reviewed 19 out of 19 changed files in this pull request and generated 4 comments.



Comment on lines +168 to +174
if(BUILD_XLA_FFI_OPS)
  if (NOT BUILD_OPS_DEVICE)
    message(FATAL_ERROR "BUILD_XLA_FFI_OPS=ON requires BUILD_OPS_DEVICE=ON")
  endif()
  add_compile_definitions(BUILD_XLA_FFI_OPS)
  add_compile_options(-Wno-deprecated-literal-operator)
endif()

Copilot AI Mar 30, 2026

The PR description mentions building a separate libmori_xla_ffi_ops.so, but with the current CMake changes BUILD_XLA_FFI_OPS just adds pybind_xla_ffi_ops.cpp into the existing mori_pybinds shared library (producing libmori_pybinds.so). Either update the PR description/docs to match, or adjust the build to actually emit a dedicated libmori_xla_ffi_ops.so target if that separation is required.

Comment on lines 25 to 29
_LAZY_SUBMODULES = {
    "cpp",
    "ops",
    "jax",
    "shmem",

Copilot AI Mar 30, 2026

Adding jax to the lazy submodule list does not currently make the JAX path torch-free: mori.jax imports mori.cpp, and python/mori/cpp/__init__.py still attempts import torch on module import (and will import it if installed). If the goal is that import mori.jax... never triggers a torch import, consider removing/moving the torch import from mori.cpp’s package init (or splitting out a minimal torch-free binding loader used by JAX).
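One way to read the suggestion above is to resolve the optional dependency only inside the code paths that need it. The sketch below shows that pattern in isolation; `OptionalDependency` is a hypothetical helper, not part of MORI, and the standard-library module `colorsys` stands in for torch so that the example runs anywhere.

```python
import importlib
import sys


class OptionalDependency:
    """Resolve a module on first use instead of at package import time."""

    def __init__(self, name):
        self.name = name
        self._module = None

    @property
    def loaded(self):
        """True once get() has successfully imported the module."""
        return self._module is not None

    def get(self):
        # Import lazily; a missing dependency fails here, in the code path
        # that needs it, rather than when the package itself is imported.
        if self._module is None:
            try:
                self._module = importlib.import_module(self.name)
            except ImportError as exc:
                raise RuntimeError(
                    f"{self.name} is required for this code path"
                ) from exc
        return self._module


def imported_modules():
    """Names of all modules imported so far, for import-side-effect checks."""
    return set(sys.modules)
```

A regression check can then assert that the dependency's name is absent from `imported_modules()` after importing the torch-free path, and present only after `get()` is called.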

Comment on lines +23 to +25
import jax # noqa: F401
from mori.cpp import mori_ep_handler, mori_ep_type_info, preload_kernels
from .ops import *

Copilot AI Mar 30, 2026

mori.jax imports symbols from mori.cpp at import time, which (via python/mori/cpp/__init__.py) attempts to import torch. This contradicts the stated goal of a torch-free JAX path when torch is installed. To keep JAX usage independent of torch, load the pybind module without importing torch (e.g., refactor mori.cpp’s init to only import torch in torch-specific code paths, or provide a separate lightweight loader module for JAX).

Comment on lines +42 to +98
static constexpr int32_t EP_CONFIG_I32_VERSION = 1;

std::vector<int32_t> EpDispatchCombineConfig::ToPackedI32Array() const {
  return {
      EP_CONFIG_I32_VERSION,
      rank,
      worldSize,
      hiddenDim,
      scaleDim,
      scaleTypeSize,
      maxTokenTypeSize,
      maxNumInpTokenPerRank,
      numExpertPerRank,
      numExpertPerToken,
      warpNumPerBlock,
      blockNum,
      static_cast<int32_t>(useExternalInpBuffer),
      static_cast<int32_t>(kernelType),
      gpuPerNode,
      rdmaBlockNum,
      numQpPerPe,
      static_cast<int32_t>(quantType),
      static_cast<int32_t>(enableSdma),
  };
}

EpDispatchCombineConfig EpDispatchCombineConfig::FromPackedI32Array(const int32_t* packed,
                                                                    size_t size) {
  // Runtime check to ensure the size of the packed array is correct
  if (size - 1 != kPackedI32Len) {
    throw std::runtime_error("EpDispatchCombineConfig i32 decode failed: invalid size");
  }
  if (packed == nullptr || packed[0] != EP_CONFIG_I32_VERSION) {
    throw std::runtime_error("EpDispatchCombineConfig i32 decode failed: unsupported version");
  }

  EpDispatchCombineConfig cfg;
  cfg.rank = packed[1];
  cfg.worldSize = packed[2];
  cfg.hiddenDim = packed[3];
  cfg.scaleDim = packed[4];
  cfg.scaleTypeSize = packed[5];
  cfg.maxTokenTypeSize = packed[6];
  cfg.maxNumInpTokenPerRank = packed[7];
  cfg.numExpertPerRank = packed[8];
  cfg.numExpertPerToken = packed[9];
  cfg.warpNumPerBlock = packed[10];
  cfg.blockNum = packed[11];
  cfg.useExternalInpBuffer = (packed[12] != 0);
  cfg.kernelType = static_cast<KernelType>(packed[13]);
  cfg.gpuPerNode = packed[14];
  cfg.rdmaBlockNum = packed[15];
  cfg.numQpPerPe = packed[16];
  cfg.quantType = static_cast<QuantType>(packed[17]);
  cfg.enableSdma = (packed[18] != 0);
  return cfg;
}

Copilot AI Mar 30, 2026

ToPackedI32Array/FromPackedI32Array introduce a new serialized config format that is now relied on by the XLA FFI path (ep_config attribute). There are existing pytest-based dispatch/combine tests, but none appear to cover pack/unpack round-tripping, version mismatch, or size mismatch. Adding a small unit test for this serialization would help prevent silent ABI drift between Python and C++ (and make future config changes safer).
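As a rough illustration of the three cases named above (round trip, version mismatch, size mismatch), the sketch below models the packed layout in pure Python. It is an assumption-laden mirror of the C++ shown in the diff, not the pybind API: `pack` and `unpack` are hypothetical helpers, enums are treated as plain integers, and the field order is copied from `ToPackedI32Array`.

```python
EP_CONFIG_I32_VERSION = 1

# Field order copied from ToPackedI32Array in the diff (slot 0 is the version).
FIELDS = (
    "rank", "worldSize", "hiddenDim", "scaleDim", "scaleTypeSize",
    "maxTokenTypeSize", "maxNumInpTokenPerRank", "numExpertPerRank",
    "numExpertPerToken", "warpNumPerBlock", "blockNum",
    "useExternalInpBuffer", "kernelType", "gpuPerNode", "rdmaBlockNum",
    "numQpPerPe", "quantType", "enableSdma",
)
BOOL_FIELDS = {"useExternalInpBuffer", "enableSdma"}


def pack(cfg):
    """Encode a config dict as [version, field1, ..., field18]."""
    return [EP_CONFIG_I32_VERSION] + [int(cfg[name]) for name in FIELDS]


def unpack(packed):
    """Decode a packed list, rejecting wrong sizes and unknown versions."""
    if len(packed) != len(FIELDS) + 1:
        raise ValueError("decode failed: invalid size")
    if packed[0] != EP_CONFIG_I32_VERSION:
        raise ValueError("decode failed: unsupported version")
    return {
        name: (value != 0) if name in BOOL_FIELDS else value
        for name, value in zip(FIELDS, packed[1:])
    }


# Example config with arbitrary illustrative values (not taken from the PR).
example = {name: index for index, name in enumerate(FIELDS)}
example["useExternalInpBuffer"] = True
example["enableSdma"] = False
```

A real test would call the bindings exposed from pybind_ops.cpp instead of these stand-ins, but the assertions would have the same shape: `unpack(pack(cfg)) == cfg`, plus one failing case per validation branch.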

@jhchouuu
Collaborator

/fix-precommit

@github-actions
Contributor

✅ Auto-fixed and committed

The following files were modified:

tests/python/ops/test_dispatch_combine_jax.py

🎉 All issues resolved automatically — no manual action needed.

@jhchouuu jhchouuu merged commit 5ccc86d into main Mar 31, 2026
6 checks passed
3 participants