Add iree_metal to platforms with buffer donation support#2
Open
Add iree_metal to platforms with buffer donation support#2
Conversation
PiperOrigin-RevId: 841972235
…` instead PiperOrigin-RevId: 841977281
PiperOrigin-RevId: 841983057
PiperOrigin-RevId: 842010889
PiperOrigin-RevId: 842041847
…ommit/91b3f740b75d1d932a12fb0886338f84f856a453 PiperOrigin-RevId: 842096427
After getting rid of hints, these constructors are no longer necessary. PiperOrigin-RevId: 842112040
…other layouts. PiperOrigin-RevId: 842145211
PiperOrigin-RevId: 842264374
PiperOrigin-RevId: 842272730
PiperOrigin-RevId: 842272894
Recently we have found the need to evolve the serialization of `jax.export.Exported`. E.g., in jax-ml#33942 we have added a 32-bit representation for `nr_devices`. This introduced a compatiblity bug that was found by usersm and fixed in in jax-ml#33685. Here we add backwards compatibility tests. See the description in the `export_serialization_back_compat_test.py` module docstring. Note that this is separate from our previous set of backwards compatibility tests for the lowering of custom calls (in `export_back_compat_test.py`). However, we reuse some of the same ideas, and we use the same directory for storing saved old serializations.
Introduces a variant of Get in AttributeMap that returns the value variant as is. PiperOrigin-RevId: 842283537
PiperOrigin-RevId: 842284947
…gration guide Co-authored-by: Matthew Johnson <mattjj@google.com> PiperOrigin-RevId: 842298387
…tribute PiperOrigin-RevId: 842307512
They still run in CI. PiperOrigin-RevId: 842313199
PiperOrigin-RevId: 842326801
PiperOrigin-RevId: 842328162
PiperOrigin-RevId: 842328715
PiperOrigin-RevId: 842344094
PiperOrigin-RevId: 842347956
PiperOrigin-RevId: 842354532
This allows doing things like dynamic indexing of Refs using just regular scalars from outside the kernel. PiperOrigin-RevId: 842429684
PiperOrigin-RevId: 846781195
PiperOrigin-RevId: 846787187
…sfer_batching PiperOrigin-RevId: 846791377
PiperOrigin-RevId: 846791567
The NCCL version can be chosen via `HERMETIC_NCCL_VERSION` env var. See docs [here](https://github.com/google-ml-infra/rules_ml_toolchain/blob/main/gpu/README.md#environment-variables-controlling-the-hermetic-cudacudnnnvshmem-versions). PiperOrigin-RevId: 846797606
PiperOrigin-RevId: 846799030
PiperOrigin-RevId: 846803402
….9.0. PiperOrigin-RevId: 846811504
PiperOrigin-RevId: 846830895
The following modules are removed: - `jax.lib.xla_bridge` - `jax.lib.xla_client` - `jax.lib.xla_extension` All contents of these submodules were deprecated and removed as of JAX v0.8.0; the modules themselves have been raising warnings on import since this release. PiperOrigin-RevId: 846831442
PiperOrigin-RevId: 846841818
PiperOrigin-RevId: 846849490
Updates LLVM usage to match [7d381f2a5634](llvm/llvm-project@7d381f2a5634) PiperOrigin-RevId: 846858892
This has been deprecated since JAX v0.8.0; after this change the flag still exists, but setting it raises a warning and otherwise has no effect. It will be removed in JAX v0.10.0. PiperOrigin-RevId: 846877149
PiperOrigin-RevId: 846909748
PiperOrigin-RevId: 846930629
PiperOrigin-RevId: 846986071
…ommit/66dbbf501ffd74f83c6a5d8fc201c756b1198d64 PiperOrigin-RevId: 847021722
…ommit/08872f587d442a05802cbdb052e8c9e6e87423f4 PiperOrigin-RevId: 847325574
…ommit/7521349eccb22a780f50fa5f6f09dbaa1d09f470 PiperOrigin-RevId: 847636908
PiperOrigin-RevId: 847665237
PiperOrigin-RevId: 847713430
PiperOrigin-RevId: 847801848
PiperOrigin-RevId: 847871002
PiperOrigin-RevId: 847908521
…ommit/d1635d1c99de225c8029d82e56c4dd03f90b013f PiperOrigin-RevId: 848049355
The IREE Metal PJRT plugin now supports buffer donation, which allows JAX to reuse input buffers for outputs when the input is no longer needed. This optimization reduces memory allocation overhead. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
…tion - Register iree_metal as an experimental plugin in xla_bridge.py - Add IREE Metal lowering for spsolve that emits custom_call to iree_spsolve, routed to BaSpaCho sparse solver on Metal GPUs Co-developed-by: Claude Code v2.1.39 (claude-opus-4-6)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
iree_metalto_platforms_with_donationlist injax/_src/interpreters/mlir.pyThe IREE Metal PJRT plugin now supports buffer donation, which allows JAX to reuse input buffers for outputs when the input is no longer needed. This optimization reduces memory allocation overhead and improves performance for workloads that can donate buffers.
Background
Buffer donation is an optimization where JAX can inform the runtime that an input buffer can be reused for an output, avoiding extra memory allocations. The IREE Metal PJRT plugin implements:
Test plan
🤖 Generated with Claude Code