From 65e6acdf45772b9771b156ee215d617484eb65aa Mon Sep 17 00:00:00 2001 From: Robert Johansson Date: Sun, 22 Feb 2026 21:08:49 +0100 Subject: [PATCH 1/8] Update to MLX 0.30.6 Bump MLX submodule from v0.25.0 to v0.30.6 and fix all API changes: - Add SmallVector kizunapi type specialization (Shape changed from std::vector to SmallVector in MLX >= 0.26) - Add PutIntoShape helper, keep PutIntoVector for std::vector uses - Update FFT wrapper function pointer types for Shape parameter - Add output_padding parameter to conv_transpose1d/2d/3d - Add sinks parameter to scaled_dot_product_attention calls - Move device_info from metal:: to gpu:: namespace - Split large ki::Set calls to stay within template argument limits Co-Authored-By: Claude Opus 4.6 --- deps/mlx | 2 +- src/fast.cc | 6 +++--- src/fft.cc | 11 ++++++----- src/indexing.cc | 2 +- src/metal.cc | 12 +++++++++++- src/ops.cc | 41 +++++++++++++++++++++++++++-------------- src/utils.cc | 2 +- src/utils.h | 48 ++++++++++++++++++++++++++++++++++++++++++++---- 8 files changed, 94 insertions(+), 30 deletions(-) diff --git a/deps/mlx b/deps/mlx index b529515e..185b06d9 160000 --- a/deps/mlx +++ b/deps/mlx @@ -1 +1 @@ -Subproject commit b529515eb158edd0919746ce4e545fe0879d6437 +Subproject commit 185b06d9efc1c869540eccfb5baff853fff3659d diff --git a/src/fast.cc b/src/fast.cc index 6cfe1e07..24aa539a 100644 --- a/src/fast.cc +++ b/src/fast.cc @@ -42,16 +42,16 @@ mx::array ScaledDotProductAttention( throw std::invalid_argument(msg.str()); } return mx::fast::scaled_dot_product_attention( - queries, keys, values, scale, mask_str, {}, s); + queries, keys, values, scale, mask_str, {}, {}, s); } else { auto mask_arr = std::get(mask); return mx::fast::scaled_dot_product_attention( - queries, keys, values, scale, "", {mask_arr}, s); + queries, keys, values, scale, "", {mask_arr}, {}, s); } } else { return mx::fast::scaled_dot_product_attention( - queries, keys, values, scale, "", {}, s); + queries, keys, values, scale, "", {}, {}, s); } } diff --git a/src/fft.cc b/src/fft.cc index 808889ab..51cccca2 100644 --- a/src/fft.cc +++ b/src/fft.cc @@ -32,7 +32,7 @@ std::function FFTNOpWrapper(const char* name, mx::array(*func1)(const mx::array&, - const std::vector&, + const mx::Shape&, const std::vector&, mx::StreamOrDevice), mx::array(*func2)(const mx::array&, @@ -45,16 +45,17 @@ FFTNOpWrapper(const char* name, std::optional> axes, mx::StreamOrDevice s) { if (n && axes) { - return mx::fft::fftn(a, std::move(*n), std::move(*axes), s); + mx::Shape shape_n(n->begin(), n->end()); + return func1(a, shape_n, std::move(*axes), s); } else if (axes) { - return mx::fft::fftn(a, std::move(*axes), s); + return func2(a, std::move(*axes), s); } else if (n) { std::ostringstream msg; msg << "[" << name << "] " << "`axes` should not be `None` if `s` is not `None`."; throw std::invalid_argument(msg.str()); } else { - return mx::fft::fftn(a, s); + return func3(a, s); } }; } @@ -66,7 +67,7 @@ std::function FFT2OpWrapper(const char* name, mx::array(*func1)(const mx::array&, - const std::vector&, + const mx::Shape&, const std::vector&, mx::StreamOrDevice), mx::array(*func2)(const mx::array&, diff --git a/src/indexing.cc b/src/indexing.cc index e5d7286d..b61dc129 100644 --- a/src/indexing.cc +++ b/src/indexing.cc @@ -563,7 +563,7 @@ ScatterResult ScatterArgsNDimentional(const mx::array* a, a->shape().begin() + non_none_indices, a->shape().end()); up = mx::reshape(std::move(up), std::move(up_reshape)); - mx::Shape axes(arr_indices.size(), 0); + std::vector axes(arr_indices.size(), 0); std::iota(axes.begin(), axes.end(), 0); return {std::move(arr_indices), std::move(up), std::move(axes)}; } diff --git a/src/metal.cc b/src/metal.cc index 91f79aef..777f1b48 100644 --- a/src/metal.cc +++ b/src/metal.cc @@ -1,4 +1,14 @@ #include "src/bindings.h" +#include "mlx/backend/gpu/device_info.h" + +namespace metal_ops { + +const std::unordered_map>& +DeviceInfo() { + return mx::gpu::device_info(0); +} + +} // namespace metal_ops void InitMetal(napi_env env, napi_value exports) { napi_value metal = ki::CreateObject(env); @@ -8,5 +18,5 @@ void InitMetal(napi_env env, napi_value exports) { "isAvailable", &mx::metal::is_available, "startCapture", &mx::metal::start_capture, "stopCapture", &mx::metal::stop_capture, - "deviceInfo", &mx::metal::device_info); + "deviceInfo", &metal_ops::DeviceInfo); } diff --git a/src/ops.cc b/src/ops.cc index aad9d052..769f7020 100644 --- a/src/ops.cc +++ b/src/ops.cc @@ -191,7 +191,7 @@ mx::array Full(std::variant shape, ScalarOrArray vals, std::optional dtype, mx::StreamOrDevice s) { - return mx::full(PutIntoVector(std::move(shape)), + return mx::full(PutIntoShape(std::move(shape)), ToArray(std::move(vals), std::move(dtype)), s); } @@ -199,13 +199,13 @@ mx::array Full(std::variant shape, mx::array Zeros(std::variant shape, std::optional dtype, mx::StreamOrDevice s) { - return mx::zeros(PutIntoVector(std::move(shape)), dtype.value_or(mx::float32), s); + return mx::zeros(PutIntoShape(std::move(shape)), dtype.value_or(mx::float32), s); } mx::array Ones(std::variant shape, std::optional dtype, mx::StreamOrDevice s) { - return mx::ones(PutIntoVector(std::move(shape)), dtype.value_or(mx::float32), s); + return mx::ones(PutIntoShape(std::move(shape)), dtype.value_or(mx::float32), s); } mx::array Eye(int n, @@ -303,8 +303,9 @@ std::vector Split(const mx::array& a, if (auto i = std::get_if(&indices); i) { return mx::split(a, *i, axis.value_or(0), s); } else { - return mx::split(a, std::move(std::get>(indices)), - axis.value_or(0), s); + auto& v = std::get>(indices); + mx::Shape shape_indices(v.begin(), v.end()); + return mx::split(a, std::move(shape_indices), axis.value_or(0), s); } } @@ -544,7 +545,7 @@ mx::array ConvTranspose1d( mx::StreamOrDevice s) { return mx::conv_transpose1d(input, weight, stride.value_or(1), padding.value_or(0), dilation.value_or(1), - groups.value_or(1), s); + /*output_padding=*/0, groups.value_or(1), s); } mx::array ConvTranspose2d( @@ -574,7 +575,7 @@ mx::array ConvTranspose2d( dilation_pair = std::move(*p); return mx::conv_transpose2d(input, weight, stride_pair, padding_pair, - dilation_pair, groups.value_or(1), s); + dilation_pair, {0, 0}, groups.value_or(1), s); } mx::array ConvTranspose3d( @@ -604,7 +605,7 @@ mx::array ConvTranspose3d( dilation_tuple = std::move(*p); return mx::conv_transpose3d(input, weight, stride_tuple, padding_tuple, - dilation_tuple, groups.value_or(1), s); + dilation_tuple, {0, 0, 0}, groups.value_or(1), s); } mx::array ConvGeneral( @@ -811,7 +812,9 @@ void InitOps(napi_env env, napi_value exports) { "stopGradient", &mx::stop_gradient, "sigmoid", &mx::sigmoid, "power", BinOpWrapper(&mx::power), - "arange", &ops::ARange, + "arange", &ops::ARange); + + ki::Set(env, exports, "linspace", &ops::Linspace, "kron", &mx::kron, "take", &ops::Take, @@ -848,12 +851,16 @@ void InitOps(napi_env env, napi_value exports) { "min", DimOpWrapper(&mx::min), "max", DimOpWrapper(&mx::max), "logcumsumexp", CumOpWrapper(&mx::logcumsumexp), - "logsumexp", DimOpWrapper(&mx::logsumexp), + "logsumexp", DimOpWrapper(&mx::logsumexp)); + + ki::Set(env, exports, "mean", DimOpWrapper(&mx::mean), "variance", &ops::Var, "std", &ops::Std, "split", &ops::Split, - "argmin", &ops::ArgMin, + "argmin", &ops::ArgMin); + + ki::Set(env, exports, "argmax", &ops::ArgMax, "sort", &ops::Sort, "argsort", &ops::ArgSort, @@ -864,7 +871,9 @@ void InitOps(napi_env env, napi_value exports) { "blockMaskedMM", &mx::block_masked_mm, "gatherMM", &mx::gather_mm, "gatherQMM", &mx::gather_qmm, - "softmax", &ops::Softmax, + "softmax", &ops::Softmax); + + ki::Set(env, exports, "concatenate", &ops::Concatenate, "concat", &ops::Concatenate, "stack", &ops::Stack, @@ -876,7 +885,9 @@ void InitOps(napi_env env, napi_value exports) { "cumsum", CumOpWrapper(&mx::cumsum), "cumprod", CumOpWrapper(&mx::cumprod), "cummax", CumOpWrapper(&mx::cummax), - "cummin", CumOpWrapper(&mx::cummin), + "cummin", CumOpWrapper(&mx::cummin)); + + ki::Set(env, exports, "conj", &mx::conjugate, "conjugate", &mx::conjugate, "convolve", &ops::Convolve, @@ -912,7 +923,9 @@ void InitOps(napi_env env, napi_value exports) { "bitwiseXor", BinOpWrapper(&mx::bitwise_xor), "leftShift", BinOpWrapper(&mx::left_shift), "rightShift", BinOpWrapper(&mx::right_shift), - "view", &mx::view, + "view", &mx::view); + + ki::Set(env, exports, "hadamardTransform", &mx::hadamard_transform, "einsumPath", &mx::einsum_path, "einsum", &mx::einsum, diff --git a/src/utils.cc b/src/utils.cc index b827c534..f309c48d 100644 --- a/src/utils.cc +++ b/src/utils.cc @@ -1,7 +1,7 @@ #include "src/array.h" #include "src/utils.h" -mx::Shape PutIntoVector(std::variant shape) { +mx::Shape PutIntoShape(std::variant shape) { if (auto i = std::get_if(&shape); i) return {*i}; return std::move(std::get(shape)); diff --git a/src/utils.h b/src/utils.h index 8e9fd41d..3cf2fb28 100644 --- a/src/utils.h +++ b/src/utils.h @@ -8,12 +8,45 @@ namespace mx = mlx::core; +// Teach kizunapi how to serialize/deserialize SmallVector (used for Shape +// and other types in MLX >= 0.26). Mirrors the std::vector specialization. +namespace ki { + +template +struct Type> { + static constexpr const char* name = "Array"; + static napi_status ToNode(napi_env env, + const mlx::core::SmallVector& vec, + napi_value* result) { + napi_status s = napi_create_array_with_length(env, vec.size(), result); + if (s != napi_ok) return s; + for (size_t i = 0; i < vec.size(); ++i) { + napi_value el; + s = ConvertToNode(env, vec[i], &el); + if (s != napi_ok) return s; + s = napi_set_element(env, *result, i, el); + if (s != napi_ok) return s; + } + return napi_ok; + } + static std::optional> FromNode( + napi_env env, napi_value value) { + // Read as std::vector then convert to SmallVector. + auto vec = Type>::FromNode(env, value); + if (!vec) return std::nullopt; + return mlx::core::SmallVector(vec->begin(), vec->end()); + } +}; + +} // namespace ki + using OptionalAxes = std::variant>; using ScalarOrArray = std::variant; -// Read args into a vector of types. -template -bool ReadArgs(ki::Arguments* args, std::vector* results) { +// Read args into a container of types (vector or SmallVector). +template +bool ReadArgs(ki::Arguments* args, Container* results) { + using T = typename Container::value_type; while (args->RemainingsLength() > 0) { std::optional a = args->GetNext(); if (!a) { @@ -45,8 +78,15 @@ void DefineToString(napi_env env, napi_value prototype) { symbol, ki::MemberFunction(&ToString)); } +// If input is one int, put it into a Shape, otherwise just return the Shape. +mx::Shape PutIntoShape(std::variant shape); + // If input is one int, put it into a vector, otherwise just return the vector. -std::vector PutIntoVector(std::variant> shape); +inline std::vector PutIntoVector(std::variant> v) { + if (auto i = std::get_if(&v); i) + return {*i}; + return std::move(std::get>(v)); +} // Get axis arg from js value. std::vector GetReduceAxes(OptionalAxes value, int dims); From 18192133b4e1c94b4601c2ca89f738bfb2b27eb9 Mon Sep 17 00:00:00 2001 From: Robert Johansson Date: Tue, 24 Feb 2026 22:37:20 +0100 Subject: [PATCH 2/8] Fix MLX compile crash on large computation graphs Update deps/mlx with fix for compile_fuse broadcast split_one bug that caused "unordered_map::at: key not found" on compiled functions with ~100+ operations. This is an upstream MLX bug (v0.29.4+). Co-Authored-By: Claude Opus 4.6 --- deps/mlx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deps/mlx b/deps/mlx index 185b06d9..65cefdef 160000 --- a/deps/mlx +++ b/deps/mlx @@ -1 +1 @@ -Subproject commit 185b06d9efc1c869540eccfb5baff853fff3659d +Subproject commit 65cefdef40bce82a9db3279cedfa04ea93e33ed7 From 43976b62ee18ae6acfa6940b5fbdbb8e6d6f41d4 Mon Sep 17 00:00:00 2001 From: Robert Johansson Date: Tue, 24 Feb 2026 23:19:13 +0100 Subject: [PATCH 3/8] Fix MLX compile crash with proper broadcast split handling Update MLX submodule with improved compile_fuse fix that preserves the broadcast fusion optimization while fixing the aliasing bug that caused unordered_map::at crashes on large computation graphs. Co-Authored-By: Claude Opus 4.6 --- deps/mlx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deps/mlx b/deps/mlx index 65cefdef..a6d40e4a 160000 --- a/deps/mlx +++ b/deps/mlx @@ -1 +1 @@ -Subproject commit 65cefdef40bce82a9db3279cedfa04ea93e33ed7 +Subproject commit a6d40e4aa54b71b73e266dc9ec53b87db4b5dd3a From 44bb8cb49c20d56f9d696a19def96c71d3d5da49 Mon Sep 17 00:00:00 2001 From: Robert Johansson Date: Fri, 27 Feb 2026 08:50:07 +0100 Subject: [PATCH 4/8] Update MLX submodule to official main (includes merged fix #3166) Points deps/mlx to ml-explore/mlx main (c8536f52) which includes the merged compile_fuse broadcast split fix from PR #3166, plus newer upstream fixes (Metal event leak, conv3d overflow, fence sync). Replaces the local branch commits (65cefdef, a6d40e4a) which are now superseded by the upstream merge. Co-Authored-By: Claude Opus 4.6 --- deps/mlx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deps/mlx b/deps/mlx index a6d40e4a..c8536f52 160000 --- a/deps/mlx +++ b/deps/mlx @@ -1 +1 @@ -Subproject commit a6d40e4aa54b71b73e266dc9ec53b87db4b5dd3a +Subproject commit c8536f5248bf06e20a8ed1706c57a6687d2e7c64 From e4aeb03c4f2d287147d3191cf5c821498200a6ff Mon Sep 17 00:00:00 2001 From: Robert Johansson Date: Fri, 27 Feb 2026 16:12:37 +0100 Subject: [PATCH 5/8] Expose lgamma and digamma ops from MLX Update MLX submodule to include native lgamma/digamma kernels and add Node.js bindings for both operations. Co-Authored-By: Claude Opus 4.6 --- deps/mlx | 2 +- node_mlx.node.d.ts | 2 ++ src/ops.cc | 2 ++ 3 files changed, 5 insertions(+), 1 deletion(-) diff --git a/deps/mlx b/deps/mlx index c8536f52..a0a30ba8 160000 --- a/deps/mlx +++ b/deps/mlx @@ -1 +1 @@ -Subproject commit c8536f5248bf06e20a8ed1706c57a6687d2e7c64 +Subproject commit a0a30ba8e70ee08db3b53abbea2664cfc39d4170 diff --git a/node_mlx.node.d.ts b/node_mlx.node.d.ts index 2000c8ae..650b0fa3 100644 --- a/node_mlx.node.d.ts +++ b/node_mlx.node.d.ts @@ -240,6 +240,8 @@ declare module '*node_mlx.node' { function notEqual(a: ScalarOrArray, b: ScalarOrArray, s?: StreamOrDevice): array; function erf(array: ScalarOrArray, s?: StreamOrDevice): array; function erfinv(array: ScalarOrArray, s?: StreamOrDevice): array; + function lgamma(array: ScalarOrArray, s?: StreamOrDevice): array; + function digamma(array: ScalarOrArray, s?: StreamOrDevice): array; function exp(array: ScalarOrArray, s?: StreamOrDevice): array; function expm1(array: ScalarOrArray, s?: StreamOrDevice): array; function expandDims(array: ScalarOrArray, dims: number | number[], s?: StreamOrDevice): array; diff --git a/src/ops.cc b/src/ops.cc index 769f7020..5c1923ed 100644 --- a/src/ops.cc +++ b/src/ops.cc @@ -790,6 +790,8 @@ void InitOps(napi_env env, napi_value exports) { "expm1", &mx::expm1, "erf", &mx::erf, "erfinv", &mx::erfinv, + "lgamma", &mx::lgamma, + "digamma", &mx::digamma, "sin", &mx::sin, "cos", &mx::cos, "tan", &mx::tan, From b4ce814c05c61126f10f7a1a1d9082b9d7c7bfb0 Mon Sep 17 00:00:00 2001 From: Robert Johansson Date: Wed, 4 Mar 2026 21:23:53 +0100 Subject: [PATCH 6/8] Point deps/mlx to robert-johansson/mlx fork, add bessel bindings - Update deps/mlx submodule URL to robert-johansson/mlx (genmlx branch) with lgamma, digamma, bessel_i0e, bessel_i1e ops - Add besselI0e/besselI1e bindings in ops.cc and type declarations Co-Authored-By: Claude Opus 4.6 --- .gitmodules | 2 +- deps/mlx | 2 +- node_mlx.node.d.ts | 2 ++ src/ops.cc | 2 ++ 4 files changed, 6 insertions(+), 2 deletions(-) diff --git a/.gitmodules b/.gitmodules index 40a4a0fa..2e0fdd9e 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,6 +1,6 @@ [submodule "deps/mlx"] path = deps/mlx - url = https://github.com/ml-explore/mlx + url = https://github.com/robert-johansson/mlx [submodule "deps/kizunapi"] path = deps/kizunapi url = https://github.com/photoionization/kizunapi diff --git a/deps/mlx b/deps/mlx index a0a30ba8..3b30ffc5 160000 --- a/deps/mlx +++ b/deps/mlx @@ -1 +1 @@ -Subproject commit a0a30ba8e70ee08db3b53abbea2664cfc39d4170 +Subproject commit 3b30ffc5e8c16d116b57458b0d91e75a29f98bf5 diff --git a/node_mlx.node.d.ts b/node_mlx.node.d.ts index 650b0fa3..fcaf69a7 100644 --- a/node_mlx.node.d.ts +++ b/node_mlx.node.d.ts @@ -242,6 +242,8 @@ declare module '*node_mlx.node' { function erfinv(array: ScalarOrArray, s?: StreamOrDevice): array; function lgamma(array: ScalarOrArray, s?: StreamOrDevice): array; function digamma(array: ScalarOrArray, s?: StreamOrDevice): array; + function besselI0e(array: ScalarOrArray, s?: StreamOrDevice): array; + function besselI1e(array: ScalarOrArray, s?: StreamOrDevice): array; function exp(array: ScalarOrArray, s?: StreamOrDevice): array; function expm1(array: ScalarOrArray, s?: StreamOrDevice): array; function expandDims(array: ScalarOrArray, dims: number | number[], s?: StreamOrDevice): array; diff --git a/src/ops.cc b/src/ops.cc index 5c1923ed..7707dd46 100644 --- a/src/ops.cc +++ b/src/ops.cc @@ -792,6 +792,8 @@ void InitOps(napi_env env, napi_value exports) { "erfinv", &mx::erfinv, "lgamma", &mx::lgamma, "digamma", &mx::digamma, + "besselI0e", &mx::bessel_i0e, + "besselI1e", &mx::bessel_i1e, "sin", &mx::sin, "cos", &mx::cos, "tan", &mx::tan, From 7293697cd4c7d2cbc7674e1391395d3c03a9d345 Mon Sep 17 00:00:00 2001 From: Robert Johansson Date: Sat, 7 Mar 2026 23:54:50 +0100 Subject: [PATCH 7/8] Add GC pressure signaling for Metal GPU buffers Report external memory (min 1MB per array) via napi_adjust_external_memory so the JS GC knows about Metal GPU buffer pressure. This makes GC run earlier, reducing the chance of hitting Metal's 499K allocation limit. - Point kizunapi submodule to robert-johansson fork with ExternalMemorySize trait - Specialize ExternalMemorySize for mx::array (1MB minimum cost) - Add napi_adjust_external_memory calls in Tidy and Dispose paths Co-Authored-By: Claude Opus 4.6 --- .gitmodules | 2 +- deps/kizunapi | 2 +- src/array.cc | 13 ++++++++++++- src/array.h | 22 ++++++++++++++++++++++ 4 files changed, 36 insertions(+), 3 deletions(-) diff --git a/.gitmodules b/.gitmodules index 2e0fdd9e..d3fbe2b7 100644 --- a/.gitmodules +++ b/.gitmodules @@ -3,4 +3,4 @@ url = https://github.com/robert-johansson/mlx [submodule "deps/kizunapi"] path = deps/kizunapi - url = https://github.com/photoionization/kizunapi + url = https://github.com/robert-johansson/kizunapi diff --git a/deps/kizunapi b/deps/kizunapi index b8d06226..dd06d3eb 160000 --- a/deps/kizunapi +++ b/deps/kizunapi @@ -1 +1 @@ -Subproject commit b8d06226897a0cfe42a6efab39c413efd35b2276 +Subproject commit dd06d3eb6f3ea44291fb781b54f6b89240176d85 diff --git a/src/array.cc b/src/array.cc index 8edbb558..12c05be5 100644 --- a/src/array.cc +++ b/src/array.cc @@ -483,11 +483,17 @@ napi_value Tidy(napi_env env, std::function func) { // 3. The JS object is marked as dead, but the finalizer has not run. // We have to unbind the JS object in 1, and only delete array in 1 // and 3. + int64_t ext = ki::internal::ExternalMemorySize::Get(a); napi_value value; if (instance_data->GetWrapper(a, &value)) napi_remove_wrap(env, value, nullptr); - if (instance_data->DeleteWrapper(a)) + if (instance_data->DeleteWrapper(a)) { + if (ext > 0) { + int64_t adjusted; + napi_adjust_external_memory(env, -ext, &adjusted); + } delete a; + } } return result; }, @@ -504,8 +510,13 @@ void Dispose(const ki::Arguments& args) { TreeVisit(args.Env(), args[i], [instance_data](napi_env env, napi_value value) { if (auto a = ki::FromNodeTo(env, value); a) { + int64_t ext = ki::internal::ExternalMemorySize::Get(a.value()); napi_remove_wrap(env, value, nullptr); instance_data->DeleteWrapper(a.value()); + if (ext > 0) { + int64_t adjusted; + napi_adjust_external_memory(env, -ext, &adjusted); + } delete a.value(); } return napi_value(); diff --git a/src/array.h b/src/array.h index 901e2a59..f4654224 100644 --- a/src/array.h +++ b/src/array.h @@ -65,6 +65,28 @@ struct Type : public AllowPassByValue { napi_value value); }; +namespace internal { + +// Report external memory for mx::array to enable GC pressure signaling. +// MLX arrays hold Metal GPU buffers that are invisible to the JS GC. +// Without this, the GC doesn't know about GPU memory pressure and doesn't +// collect array wrappers fast enough, causing Metal resource exhaustion. +template<> +struct ExternalMemorySize { + static int64_t Get(mx::array* a) { + // Metal has a hard limit of 499K buffer allocations. We must create + // enough external memory pressure to force the GC to collect array + // wrappers before hitting it. Report 1MB per array as the minimum + // external cost — this is much larger than the actual data size but + // necessary to trigger sufficiently aggressive GC for GPU resources. + size_t n = a->nbytes(); + constexpr int64_t min_cost = 1024 * 1024; // 1MB + return static_cast(n) > min_cost ? static_cast(n) : min_cost; + } +}; + +} // namespace internal + } // namespace ki #endif // SRC_ARRAY_H_ From 5491416cb43746abd8d7b972a21c84f7f57b889d Mon Sep 17 00:00:00 2001 From: Robert Johansson Date: Sun, 8 Mar 2026 02:27:16 +0100 Subject: [PATCH 8/8] Add sweepDeadArrays() for synchronous Metal buffer cleanup Adds a native function that bypasses the deferred N-API finalizer queue by synchronously walking the wrapper registry and freeing arrays whose JS wrappers have been GC'd. This is critical for synchronous inference loops where the event loop never yields and deferred finalizers never run. Includes kizunapi changes: - CollectDeadWrappers() in InstanceData - ExternalMemorySize reporting on AllowPassByValue path - Double-free guard in finalizer callbacks Co-Authored-By: Claude Opus 4.6 --- deps/kizunapi | 2 +- node_mlx.node.d.ts | 1 + src/array.cc | 22 +++++++++++++++++++++- 3 files changed, 23 insertions(+), 2 deletions(-) diff --git a/deps/kizunapi b/deps/kizunapi index dd06d3eb..0071f6c3 160000 --- a/deps/kizunapi +++ b/deps/kizunapi @@ -1 +1 @@ -Subproject commit dd06d3eb6f3ea44291fb781b54f6b89240176d85 +Subproject commit 0071f6c30b2f7ce9b3a3ac13f0d17ab464f21073 diff --git a/node_mlx.node.d.ts b/node_mlx.node.d.ts index fcaf69a7..1e7efa59 100644 --- a/node_mlx.node.d.ts +++ b/node_mlx.node.d.ts @@ -382,6 +382,7 @@ declare module '*node_mlx.node' { function tidy(func: () => U): U; function dispose(...args: unknown[]): void; function getWrappersCount(): number; + function sweepDeadArrays(): number; // Metal. namespace metal { diff --git a/src/array.cc b/src/array.cc index 12c05be5..213c9281 100644 --- a/src/array.cc +++ b/src/array.cc @@ -529,6 +529,25 @@ size_t GetWrappersCount(napi_env env) { return ki::InstanceData::Get(env)->GetWrappersCount(); } +// Synchronously sweep dead array wrappers. +// Finds arrays whose JS wrappers have been GC'd but whose deferred finalizers +// haven't run yet, and immediately frees the native Metal buffers. +// Returns the number of arrays swept. +size_t SweepDeadArrays(napi_env env) { + ki::InstanceData* instance_data = ki::InstanceData::Get(env); + auto dead_ptrs = instance_data->CollectDeadWrappers(); + for (void* ptr : dead_ptrs) { + mx::array* a = static_cast(ptr); + int64_t ext = ki::internal::ExternalMemorySize::Get(a); + if (ext > 0) { + int64_t adjusted; + napi_adjust_external_memory(env, -ext, &adjusted); + } + delete a; + } + return dead_ptrs.size(); +} + } // namespace namespace ki { @@ -819,5 +838,6 @@ void InitArray(napi_env env, napi_value exports) { ki::Set(env, exports, "tidy", &Tidy, "dispose", &Dispose, - "getWrappersCount", &GetWrappersCount); + "getWrappersCount", &GetWrappersCount, + "sweepDeadArrays", &SweepDeadArrays); }