diff --git a/.gitmodules b/.gitmodules index 40a4a0f..d3fbe2b 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,6 +1,6 @@ [submodule "deps/mlx"] path = deps/mlx - url = https://github.com/ml-explore/mlx + url = https://github.com/robert-johansson/mlx [submodule "deps/kizunapi"] path = deps/kizunapi - url = https://github.com/photoionization/kizunapi + url = https://github.com/robert-johansson/kizunapi diff --git a/deps/kizunapi b/deps/kizunapi index b8d0622..0071f6c 160000 --- a/deps/kizunapi +++ b/deps/kizunapi @@ -1 +1 @@ -Subproject commit b8d06226897a0cfe42a6efab39c413efd35b2276 +Subproject commit 0071f6c30b2f7ce9b3a3ac13f0d17ab464f21073 diff --git a/deps/mlx b/deps/mlx index b529515..3b30ffc 160000 --- a/deps/mlx +++ b/deps/mlx @@ -1 +1 @@ -Subproject commit b529515eb158edd0919746ce4e545fe0879d6437 +Subproject commit 3b30ffc5e8c16d116b57458b0d91e75a29f98bf5 diff --git a/node_mlx.node.d.ts b/node_mlx.node.d.ts index 2000c8a..1e7efa5 100644 --- a/node_mlx.node.d.ts +++ b/node_mlx.node.d.ts @@ -240,6 +240,10 @@ declare module '*node_mlx.node' { function notEqual(a: ScalarOrArray, b: ScalarOrArray, s?: StreamOrDevice): array; function erf(array: ScalarOrArray, s?: StreamOrDevice): array; function erfinv(array: ScalarOrArray, s?: StreamOrDevice): array; + function lgamma(array: ScalarOrArray, s?: StreamOrDevice): array; + function digamma(array: ScalarOrArray, s?: StreamOrDevice): array; + function besselI0e(array: ScalarOrArray, s?: StreamOrDevice): array; + function besselI1e(array: ScalarOrArray, s?: StreamOrDevice): array; function exp(array: ScalarOrArray, s?: StreamOrDevice): array; function expm1(array: ScalarOrArray, s?: StreamOrDevice): array; function expandDims(array: ScalarOrArray, dims: number | number[], s?: StreamOrDevice): array; @@ -378,6 +382,7 @@ declare module '*node_mlx.node' { function tidy(func: () => U): U; function dispose(...args: unknown[]): void; function getWrappersCount(): number; + function sweepDeadArrays(): number; // Metal. 
namespace metal { diff --git a/src/array.cc b/src/array.cc index 8edbb55..213c928 100644 --- a/src/array.cc +++ b/src/array.cc @@ -483,11 +483,17 @@ napi_value Tidy(napi_env env, std::function func) { // 3. The JS object is marked as dead, but the finalizer has not run. // We have to unbind the JS object in 1, and only delete array in 1 // and 3. napi_value value; if (instance_data->GetWrapper(a, &value)) napi_remove_wrap(env, value, nullptr); - if (instance_data->DeleteWrapper(a)) + if (instance_data->DeleteWrapper(a)) { + int64_t ext = ki::internal::ExternalMemorySize::Get(a); /* Get() dereferences a, so query it only after DeleteWrapper() has confirmed the array is still alive; otherwise a may already have been freed. */ + if (ext > 0) { + int64_t adjusted; + napi_adjust_external_memory(env, -ext, &adjusted); + } delete a; + } } return result; }, @@ -504,8 +510,13 @@ void Dispose(const ki::Arguments& args) { TreeVisit(args.Env(), args[i], [instance_data](napi_env env, napi_value value) { if (auto a = ki::FromNodeTo(env, value); a) { + int64_t ext = ki::internal::ExternalMemorySize::Get(a.value()); napi_remove_wrap(env, value, nullptr); instance_data->DeleteWrapper(a.value()); + if (ext > 0) { + int64_t adjusted; + napi_adjust_external_memory(env, -ext, &adjusted); + } delete a.value(); } return napi_value(); @@ -518,6 +529,25 @@ size_t GetWrappersCount(napi_env env) { return ki::InstanceData::Get(env)->GetWrappersCount(); } +// Synchronously sweep dead array wrappers. +// Finds arrays whose JS wrappers have been GC'd but whose deferred finalizers +// haven't run yet, and immediately frees the native Metal buffers. +// Returns the number of arrays swept. 
+size_t SweepDeadArrays(napi_env env) { /* Deletes every pointer returned by CollectDeadWrappers() and returns how many there were. */ + ki::InstanceData* instance_data = ki::InstanceData::Get(env); + auto dead_ptrs = instance_data->CollectDeadWrappers(); + for (void* ptr : dead_ptrs) { + mx::array* a = static_cast(ptr); /* NOTE(review): the cast target looks stripped from this patch text; presumably a static_cast to mx::array pointer - confirm against the commit. */ + int64_t ext = ki::internal::ExternalMemorySize::Get(a); /* Read the cost before the delete below invalidates a. */ + if (ext > 0) { + int64_t adjusted; + napi_adjust_external_memory(env, -ext, &adjusted); /* Subtracts this array's cost from the external memory reported to the JS engine; the adjusted total is not used. */ + } + delete a; + } + return dead_ptrs.size(); +} + } // namespace namespace ki { @@ -808,5 +838,6 @@ void InitArray(napi_env env, napi_value exports) { ki::Set(env, exports, "tidy", &Tidy, "dispose", &Dispose, - "getWrappersCount", &GetWrappersCount); + "getWrappersCount", &GetWrappersCount, + "sweepDeadArrays", &SweepDeadArrays); } diff --git a/src/array.h b/src/array.h index 901e2a5..f465422 100644 --- a/src/array.h +++ b/src/array.h @@ -65,6 +65,28 @@ struct Type : public AllowPassByValue { napi_value value); }; +namespace internal { + +// Report external memory for mx::array to enable GC pressure signaling. +// MLX arrays hold Metal GPU buffers that are invisible to the JS GC. +// Without this, the GC doesn't know about GPU memory pressure and doesn't +// collect array wrappers fast enough, causing Metal resource exhaustion. +template<> +struct ExternalMemorySize { /* NOTE(review): the specialization argument looks stripped from this patch text; presumably mx::array - confirm against the commit. */ + static int64_t Get(mx::array* a) { + // Metal has a hard limit of 499K buffer allocations. We must create + // enough external memory pressure to force the GC to collect array + // wrappers before hitting it. Report 1MB per array as the minimum + // external cost — this is much larger than the actual data size but + // necessary to trigger sufficiently aggressive GC for GPU resources. + size_t n = a->nbytes(); /* The return statement below yields the larger of n and min_cost. */ + constexpr int64_t min_cost = 1024 * 1024; // 1MB + return static_cast(n) > min_cost ? 
static_cast(n) : min_cost; + } +}; + +} // namespace internal + } // namespace ki #endif // SRC_ARRAY_H_ diff --git a/src/fast.cc b/src/fast.cc index 6cfe1e0..24aa539 100644 --- a/src/fast.cc +++ b/src/fast.cc @@ -42,16 +42,16 @@ mx::array ScaledDotProductAttention( throw std::invalid_argument(msg.str()); } return mx::fast::scaled_dot_product_attention( - queries, keys, values, scale, mask_str, {}, s); + queries, keys, values, scale, mask_str, {}, {}, s); } else { auto mask_arr = std::get(mask); return mx::fast::scaled_dot_product_attention( - queries, keys, values, scale, "", {mask_arr}, s); + queries, keys, values, scale, "", {mask_arr}, {}, s); } } else { return mx::fast::scaled_dot_product_attention( - queries, keys, values, scale, "", {}, s); + queries, keys, values, scale, "", {}, {}, s); } } diff --git a/src/fft.cc b/src/fft.cc index 808889a..51cccca 100644 --- a/src/fft.cc +++ b/src/fft.cc @@ -32,7 +32,7 @@ std::function FFTNOpWrapper(const char* name, mx::array(*func1)(const mx::array&, - const std::vector&, + const mx::Shape&, const std::vector&, mx::StreamOrDevice), mx::array(*func2)(const mx::array&, @@ -45,16 +45,17 @@ FFTNOpWrapper(const char* name, std::optional> axes, mx::StreamOrDevice s) { if (n && axes) { - return mx::fft::fftn(a, std::move(*n), std::move(*axes), s); + mx::Shape shape_n(n->begin(), n->end()); + return func1(a, shape_n, std::move(*axes), s); } else if (axes) { - return mx::fft::fftn(a, std::move(*axes), s); + return func2(a, std::move(*axes), s); } else if (n) { std::ostringstream msg; msg << "[" << name << "] " << "`axes` should not be `None` if `s` is not `None`."; throw std::invalid_argument(msg.str()); } else { - return mx::fft::fftn(a, s); + return func3(a, s); } }; } @@ -66,7 +67,7 @@ std::function FFT2OpWrapper(const char* name, mx::array(*func1)(const mx::array&, - const std::vector&, + const mx::Shape&, const std::vector&, mx::StreamOrDevice), mx::array(*func2)(const mx::array&, diff --git a/src/indexing.cc 
b/src/indexing.cc index e5d7286..b61dc12 100644 --- a/src/indexing.cc +++ b/src/indexing.cc @@ -563,7 +563,7 @@ ScatterResult ScatterArgsNDimentional(const mx::array* a, a->shape().begin() + non_none_indices, a->shape().end()); up = mx::reshape(std::move(up), std::move(up_reshape)); - mx::Shape axes(arr_indices.size(), 0); + std::vector axes(arr_indices.size(), 0); std::iota(axes.begin(), axes.end(), 0); return {std::move(arr_indices), std::move(up), std::move(axes)}; } diff --git a/src/metal.cc b/src/metal.cc index 91f79ae..777f1b4 100644 --- a/src/metal.cc +++ b/src/metal.cc @@ -1,4 +1,14 @@ #include "src/bindings.h" +#include "mlx/backend/gpu/device_info.h" + +namespace metal_ops { + +const std::unordered_map>& +DeviceInfo() { /* Zero-argument wrapper that InitMetal binds as deviceInfo; forwards to mx::gpu::device_info(0). NOTE(review): 0 is presumably the device index - confirm against the MLX header. */ + return mx::gpu::device_info(0); +} + +} // namespace metal_ops void InitMetal(napi_env env, napi_value exports) { napi_value metal = ki::CreateObject(env); @@ -8,5 +18,5 @@ void InitMetal(napi_env env, napi_value exports) { "isAvailable", &mx::metal::is_available, "startCapture", &mx::metal::start_capture, "stopCapture", &mx::metal::stop_capture, - "deviceInfo", &mx::metal::device_info); + "deviceInfo", &metal_ops::DeviceInfo); } diff --git a/src/ops.cc b/src/ops.cc index aad9d05..7707dd4 100644 --- a/src/ops.cc +++ b/src/ops.cc @@ -191,7 +191,7 @@ mx::array Full(std::variant shape, ScalarOrArray vals, std::optional dtype, mx::StreamOrDevice s) { - return mx::full(PutIntoVector(std::move(shape)), + return mx::full(PutIntoShape(std::move(shape)), ToArray(std::move(vals), std::move(dtype)), s); } @@ -199,13 +199,13 @@ mx::array Full(std::variant shape, mx::array Zeros(std::variant shape, std::optional dtype, mx::StreamOrDevice s) { - return mx::zeros(PutIntoVector(std::move(shape)), dtype.value_or(mx::float32), s); + return mx::zeros(PutIntoShape(std::move(shape)), dtype.value_or(mx::float32), s); } mx::array Ones(std::variant shape, std::optional dtype, mx::StreamOrDevice s) { - return mx::ones(PutIntoVector(std::move(shape)), 
dtype.value_or(mx::float32), s); + return mx::ones(PutIntoShape(std::move(shape)), dtype.value_or(mx::float32), s); } mx::array Eye(int n, @@ -303,8 +303,9 @@ std::vector Split(const mx::array& a, if (auto i = std::get_if(&indices); i) { return mx::split(a, *i, axis.value_or(0), s); } else { - return mx::split(a, std::move(std::get>(indices)), - axis.value_or(0), s); + auto& v = std::get>(indices); + mx::Shape shape_indices(v.begin(), v.end()); + return mx::split(a, std::move(shape_indices), axis.value_or(0), s); } } @@ -544,7 +545,7 @@ mx::array ConvTranspose1d( mx::StreamOrDevice s) { return mx::conv_transpose1d(input, weight, stride.value_or(1), padding.value_or(0), dilation.value_or(1), - groups.value_or(1), s); + /*output_padding=*/0, groups.value_or(1), s); } mx::array ConvTranspose2d( @@ -574,7 +575,7 @@ mx::array ConvTranspose2d( dilation_pair = std::move(*p); return mx::conv_transpose2d(input, weight, stride_pair, padding_pair, - dilation_pair, groups.value_or(1), s); + dilation_pair, {0, 0}, groups.value_or(1), s); } mx::array ConvTranspose3d( @@ -604,7 +605,7 @@ mx::array ConvTranspose3d( dilation_tuple = std::move(*p); return mx::conv_transpose3d(input, weight, stride_tuple, padding_tuple, - dilation_tuple, groups.value_or(1), s); + dilation_tuple, {0, 0, 0}, groups.value_or(1), s); } mx::array ConvGeneral( @@ -789,6 +790,10 @@ void InitOps(napi_env env, napi_value exports) { "expm1", &mx::expm1, "erf", &mx::erf, "erfinv", &mx::erfinv, + "lgamma", &mx::lgamma, + "digamma", &mx::digamma, + "besselI0e", &mx::bessel_i0e, + "besselI1e", &mx::bessel_i1e, "sin", &mx::sin, "cos", &mx::cos, "tan", &mx::tan, @@ -811,7 +816,9 @@ void InitOps(napi_env env, napi_value exports) { "stopGradient", &mx::stop_gradient, "sigmoid", &mx::sigmoid, "power", BinOpWrapper(&mx::power), - "arange", &ops::ARange, + "arange", &ops::ARange); + + ki::Set(env, exports, "linspace", &ops::Linspace, "kron", &mx::kron, "take", &ops::Take, @@ -848,12 +855,16 @@ void InitOps(napi_env env, 
napi_value exports) { "min", DimOpWrapper(&mx::min), "max", DimOpWrapper(&mx::max), "logcumsumexp", CumOpWrapper(&mx::logcumsumexp), - "logsumexp", DimOpWrapper(&mx::logsumexp), + "logsumexp", DimOpWrapper(&mx::logsumexp)); + + ki::Set(env, exports, "mean", DimOpWrapper(&mx::mean), "variance", &ops::Var, "std", &ops::Std, "split", &ops::Split, - "argmin", &ops::ArgMin, + "argmin", &ops::ArgMin); + + ki::Set(env, exports, "argmax", &ops::ArgMax, "sort", &ops::Sort, "argsort", &ops::ArgSort, @@ -864,7 +875,9 @@ void InitOps(napi_env env, napi_value exports) { "blockMaskedMM", &mx::block_masked_mm, "gatherMM", &mx::gather_mm, "gatherQMM", &mx::gather_qmm, - "softmax", &ops::Softmax, + "softmax", &ops::Softmax); + + ki::Set(env, exports, "concatenate", &ops::Concatenate, "concat", &ops::Concatenate, "stack", &ops::Stack, @@ -876,7 +889,9 @@ void InitOps(napi_env env, napi_value exports) { "cumsum", CumOpWrapper(&mx::cumsum), "cumprod", CumOpWrapper(&mx::cumprod), "cummax", CumOpWrapper(&mx::cummax), - "cummin", CumOpWrapper(&mx::cummin), + "cummin", CumOpWrapper(&mx::cummin)); + + ki::Set(env, exports, "conj", &mx::conjugate, "conjugate", &mx::conjugate, "convolve", &ops::Convolve, @@ -912,7 +927,9 @@ void InitOps(napi_env env, napi_value exports) { "bitwiseXor", BinOpWrapper(&mx::bitwise_xor), "leftShift", BinOpWrapper(&mx::left_shift), "rightShift", BinOpWrapper(&mx::right_shift), - "view", &mx::view, + "view", &mx::view); + + ki::Set(env, exports, "hadamardTransform", &mx::hadamard_transform, "einsumPath", &mx::einsum_path, "einsum", &mx::einsum, diff --git a/src/utils.cc b/src/utils.cc index b827c53..f309c48 100644 --- a/src/utils.cc +++ b/src/utils.cc @@ -1,7 +1,7 @@ #include "src/array.h" #include "src/utils.h" -mx::Shape PutIntoVector(std::variant shape) { +mx::Shape PutIntoShape(std::variant shape) { if (auto i = std::get_if(&shape); i) return {*i}; return std::move(std::get(shape)); diff --git a/src/utils.h b/src/utils.h index 8e9fd41..3cf2fb2 100644 --- 
a/src/utils.h +++ b/src/utils.h @@ -8,12 +8,45 @@ namespace mx = mlx::core; +// Teach kizunapi how to serialize/deserialize SmallVector (used for Shape +// and other types in MLX >= 0.26). Mirrors the std::vector specialization. +namespace ki { + +template +struct Type> { /* NOTE(review): the template parameter lists look stripped from this patch text; presumably a partial specialization of Type for mlx::core::SmallVector of T - confirm against the commit. */ + static constexpr const char* name = "Array"; + static napi_status ToNode(napi_env env, + const mlx::core::SmallVector& vec, + napi_value* result) { /* Creates a JS array of vec.size() elements and fills it via ConvertToNode; returns the first non-ok status, else napi_ok. */ + napi_status s = napi_create_array_with_length(env, vec.size(), result); + if (s != napi_ok) return s; + for (size_t i = 0; i < vec.size(); ++i) { + napi_value el; + s = ConvertToNode(env, vec[i], &el); + if (s != napi_ok) return s; + s = napi_set_element(env, *result, i, el); + if (s != napi_ok) return s; + } + return napi_ok; + } + static std::optional> FromNode( + napi_env env, napi_value value) { /* Yields std::nullopt when the intermediate vector conversion fails. */ + // Read as std::vector then convert to SmallVector. + auto vec = Type>::FromNode(env, value); + if (!vec) return std::nullopt; + return mlx::core::SmallVector(vec->begin(), vec->end()); + } +}; + +} // namespace ki + using OptionalAxes = std::variant>; using ScalarOrArray = std::variant; -// Read args into a vector of types. -template -bool ReadArgs(ki::Arguments* args, std::vector* results) { +// Read args into a container of types (vector or SmallVector). +template +bool ReadArgs(ki::Arguments* args, Container* results) { + using T = typename Container::value_type; while (args->RemainingsLength() > 0) { std::optional a = args->GetNext(); if (!a) { @@ -45,8 +78,15 @@ void DefineToString(napi_env env, napi_value prototype) { symbol, ki::MemberFunction(&ToString)); } +// If input is one int, put it into a Shape, otherwise just return the Shape. +mx::Shape PutIntoShape(std::variant shape); + // If input is one int, put it into a vector, otherwise just return the vector. 
-std::vector PutIntoVector(std::variant> shape); +inline std::vector PutIntoVector(std::variant> v) { /* Returns a one-element vector when v holds a single int, otherwise moves the held vector out. Defined inline in this header because this patch renames the definition in src/utils.cc to PutIntoShape. */ + if (auto i = std::get_if(&v); i) + return {*i}; + return std::move(std::get>(v)); +} // Get axis arg from js value. std::vector GetReduceAxes(OptionalAxes value, int dims);