[v3-backport] Add Mooncake extension for ArrayPartition cotangents#577
Merged
ChrisRackauckas merged 2 commits intoSciML:v3-backportfrom Apr 12, 2026
Conversation
When an upstream ChainRules-based adjoint (e.g. SciMLSensitivity's
`_concrete_solve_adjoint` for an ODE whose state is an `ArrayPartition`,
such as the one produced by `SecondOrderODEProblem`) returns a
parameter / state cotangent as an `ArrayPartition`, Mooncake's
`@from_chainrules` / `@from_rrule` accumulator looks for an
`increment_and_get_rdata!` method matching
(FData{NamedTuple{(:x,), Tuple{Tuple{Vector, …}}}}, NoRData, ArrayPartition)
There isn't a default method registered for this combination, so the
call falls through to the generic error path:
ArgumentError: The fdata type Mooncake.FData{@NamedTuple{x::Tuple{Vector{Float32}, Vector{Float32}}}},
rdata type Mooncake.NoRData, and tangent type
RecursiveArrayTools.ArrayPartition{Float32, Tuple{Vector{Float32}, Vector{Float32}}}
combination is not supported with @from_chainrules or @from_rrule.
Add the missing dispatch via a new `RecursiveArrayToolsMooncakeExt`
weak-dep extension. An `ArrayPartition`'s only field is `x::Tuple` of
inner arrays, so the FData layout is `FData{@NamedTuple{x::Tuple{...}}}`
and the inner tuple positions line up with `t.x`. Walk the tuple
element-by-element and forward each leaf to the existing
`increment_and_get_rdata!` for the leaf's array type, which does the
actual in-place accumulation. Returns `Mooncake.NoRData()` to match the
no-rdata convention used by the equivalent ComponentArrays dispatch
(SciML/ComponentArrays.jl#350 / SciML#351).
Tested end-to-end against the SciMLSensitivity neural-ODE
`SecondOrderODEProblem` tutorial (via SciML/SciMLSensitivity.jl#1422,
which adds the matching `df_iip`/`df_oop` cotangent unwrap on the
SciMLSensitivity side): with both PRs applied, the Lux + `ArrayPartition`
training loop now runs under
`OPT.AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true))`.
Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
Add Mooncake to [extras] and [targets.test] so the new
`RecursiveArrayToolsMooncakeExt` is actually loaded and exercised in
the test suite, and add test/mooncake.jl as a direct unit test for the
new `Mooncake.increment_and_get_rdata!(::FData{@NamedTuple{x::T}},
::NoRData, ::ArrayPartition{P, T})` dispatch: constructs a matching
FData and ArrayPartition, calls `increment_and_get_rdata!`, and checks
that (a) the in-place accumulation on each inner-array leaf is correct
and (b) the method returns `NoRData()`. Also exercises a three-way
Float32 ArrayPartition to cover a different eltype and arity. Register
the testset in runtests.jl under the Core group.
Backport of SciML#575 to v3-backport.
Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Backport of #575 to the v3 maintenance line.
Adds a new
RecursiveArrayToolsMooncakeExtweak-dep extension that registers aMooncake.increment_and_get_rdata!(::FData{@NamedTuple{x::T}}, ::NoRData, ::ArrayPartition{P, T})method so Mooncake's@from_chainrules/@from_rruleaccumulator can handle anArrayPartitioncotangent returned by an upstream ChainRule (e.g. SciMLSensitivity's_concrete_solve_adjointfor aSecondOrderODEProblem). Without this, the call falls through to:Changes
ext/RecursiveArrayToolsMooncakeExt.jl— new extension. Walksf.data.xandt.xin lockstep and forwards each leaf to the existing per-arrayincrement_and_get_rdata!, mirroring the ComponentArrays dispatch (Add friendly_tangent_cache function to Mooncake ComponentArrays.jl#350 / ci: explicitly specify token for codecov #351). ReturnsMooncake.NoRData().Project.toml— Mooncake added to[weakdeps],[extensions],[compat](0.5), and to[extras]/[targets.test]so the extension is exercised in CI.test/mooncake.jl— direct unit test for the new dispatch: Float64 two-partition and Float32 three-partition cases, checking in-place accumulation per leaf and theNoRData()return. Wired into the Core testset inruntests.jl.Test plan
Local
Pkg.teston Julia 1.10.11 (v3-backportProject.toml, clean checkout of this branch):Aqua (including
test_stale_deps) passes; full Core testset green.Related
df_iip/df_oopincrement_and_get_rdata!plumbing forComponentArray🤖 Generated with Claude Code