Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[submodule "deps/mlx"]
path = deps/mlx
url = https://github.com/ml-explore/mlx
url = https://github.com/robert-johansson/mlx
[submodule "deps/kizunapi"]
path = deps/kizunapi
url = https://github.com/photoionization/kizunapi
url = https://github.com/robert-johansson/kizunapi
2 changes: 1 addition & 1 deletion deps/kizunapi
2 changes: 1 addition & 1 deletion deps/mlx
Submodule mlx updated 689 files
5 changes: 5 additions & 0 deletions node_mlx.node.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,10 @@ declare module '*node_mlx.node' {
function notEqual(a: ScalarOrArray, b: ScalarOrArray, s?: StreamOrDevice): array;
function erf(array: ScalarOrArray, s?: StreamOrDevice): array;
function erfinv(array: ScalarOrArray, s?: StreamOrDevice): array;
function lgamma(array: ScalarOrArray, s?: StreamOrDevice): array;
function digamma(array: ScalarOrArray, s?: StreamOrDevice): array;
function besselI0e(array: ScalarOrArray, s?: StreamOrDevice): array;
function besselI1e(array: ScalarOrArray, s?: StreamOrDevice): array;
function exp(array: ScalarOrArray, s?: StreamOrDevice): array;
function expm1(array: ScalarOrArray, s?: StreamOrDevice): array;
function expandDims(array: ScalarOrArray, dims: number | number[], s?: StreamOrDevice): array;
Expand Down Expand Up @@ -378,6 +382,7 @@ declare module '*node_mlx.node' {
function tidy<U>(func: () => U): U;
function dispose(...args: unknown[]): void;
function getWrappersCount(): number;
function sweepDeadArrays(): number;

// Metal.
namespace metal {
Expand Down
35 changes: 33 additions & 2 deletions src/array.cc
Original file line number Diff line number Diff line change
Expand Up @@ -483,11 +483,17 @@ napi_value Tidy(napi_env env, std::function<napi_value()> func) {
// 3. The JS object is marked as dead, but the finalizer has not run.
// We have to unbind the JS object in 1, and only delete array in 1
// and 3.
int64_t ext = ki::internal::ExternalMemorySize<mx::array>::Get(a);
napi_value value;
if (instance_data->GetWrapper<mx::array>(a, &value))
napi_remove_wrap(env, value, nullptr);
if (instance_data->DeleteWrapper<mx::array>(a))
if (instance_data->DeleteWrapper<mx::array>(a)) {
if (ext > 0) {
int64_t adjusted;
napi_adjust_external_memory(env, -ext, &adjusted);
}
delete a;
}
}
return result;
},
Expand All @@ -504,8 +510,13 @@ void Dispose(const ki::Arguments& args) {
TreeVisit(args.Env(), args[i],
[instance_data](napi_env env, napi_value value) {
if (auto a = ki::FromNodeTo<mx::array*>(env, value); a) {
int64_t ext = ki::internal::ExternalMemorySize<mx::array>::Get(a.value());
napi_remove_wrap(env, value, nullptr);
instance_data->DeleteWrapper<mx::array>(a.value());
if (ext > 0) {
int64_t adjusted;
napi_adjust_external_memory(env, -ext, &adjusted);
}
delete a.value();
}
return napi_value();
Expand All @@ -518,6 +529,25 @@ size_t GetWrappersCount(napi_env env) {
return ki::InstanceData::Get(env)->GetWrappersCount();
}

// Eagerly reclaims arrays whose JS wrappers are already dead.
//
// Asks the per-environment instance data for every mx::array it reports as a
// dead wrapper (per the original note: the JS object has been GC'd but its
// deferred finalizer has not run yet) and deletes each native array right
// away instead of waiting for that finalizer.
//
// For every array the external-memory amount reported by
// ExternalMemorySize<mx::array> is subtracted from the engine's accounting
// before the array is deleted. The size is read first because it is computed
// from the array itself.
//
// Returns how many arrays were collected by this call.
size_t SweepDeadArrays(napi_env env) {
  const auto dead =
      ki::InstanceData::Get(env)->CollectDeadWrappers<mx::array>();
  for (void* raw : dead) {
    auto* arr = static_cast<mx::array*>(raw);
    // Only adjust the counter when something was actually reported.
    if (const int64_t reported =
            ki::internal::ExternalMemorySize<mx::array>::Get(arr);
        reported > 0) {
      int64_t total_after;  // Required out-parameter; value is unused.
      napi_adjust_external_memory(env, -reported, &total_after);
    }
    delete arr;
  }
  return dead.size();
}

} // namespace

namespace ki {
Expand Down Expand Up @@ -808,5 +838,6 @@ void InitArray(napi_env env, napi_value exports) {
ki::Set(env, exports,
"tidy", &Tidy,
"dispose", &Dispose,
"getWrappersCount", &GetWrappersCount);
"getWrappersCount", &GetWrappersCount,
"sweepDeadArrays", &SweepDeadArrays);
}
22 changes: 22 additions & 0 deletions src/array.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,28 @@ struct Type<mx::array> : public AllowPassByValue<mx::array> {
napi_value value);
};

namespace internal {

// External-memory accounting for mx::array, used to signal GC pressure.
// MLX arrays hold Metal GPU buffers that the JS garbage collector cannot see,
// so without a reported cost the GC does not collect array wrappers quickly
// enough and Metal resources can run out.
template<>
struct ExternalMemorySize<mx::array> {
  // Returns the number of bytes to report for `a`: its nbytes(), but never
  // less than 1 MiB.
  static int64_t Get(mx::array* a) {
    // Per the original author's note, Metal has a hard limit of 499K buffer
    // allocations. The floor is deliberately far larger than a small array's
    // real footprint so that the GC is pushed to collect wrappers before
    // that limit is reached.
    constexpr int64_t kFloorBytes = 1024 * 1024;  // 1MB
    const int64_t actual_bytes = static_cast<int64_t>(a->nbytes());
    if (actual_bytes > kFloorBytes)
      return actual_bytes;
    return kFloorBytes;
  }
};

} // namespace internal

} // namespace ki

#endif // SRC_ARRAY_H_
6 changes: 3 additions & 3 deletions src/fast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,16 @@ mx::array ScaledDotProductAttention(
throw std::invalid_argument(msg.str());
}
return mx::fast::scaled_dot_product_attention(
queries, keys, values, scale, mask_str, {}, s);
queries, keys, values, scale, mask_str, {}, {}, s);
} else {
auto mask_arr = std::get<mx::array>(mask);
return mx::fast::scaled_dot_product_attention(
queries, keys, values, scale, "", {mask_arr}, s);
queries, keys, values, scale, "", {mask_arr}, {}, s);
}

} else {
return mx::fast::scaled_dot_product_attention(
queries, keys, values, scale, "", {}, s);
queries, keys, values, scale, "", {}, {}, s);
}
}

Expand Down
11 changes: 6 additions & 5 deletions src/fft.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ std::function<mx::array(const mx::array& a,
mx::StreamOrDevice s)>
FFTNOpWrapper(const char* name,
mx::array(*func1)(const mx::array&,
const std::vector<int>&,
const mx::Shape&,
const std::vector<int>&,
mx::StreamOrDevice),
mx::array(*func2)(const mx::array&,
Expand All @@ -45,16 +45,17 @@ FFTNOpWrapper(const char* name,
std::optional<std::vector<int>> axes,
mx::StreamOrDevice s) {
if (n && axes) {
return mx::fft::fftn(a, std::move(*n), std::move(*axes), s);
mx::Shape shape_n(n->begin(), n->end());
return func1(a, shape_n, std::move(*axes), s);
} else if (axes) {
return mx::fft::fftn(a, std::move(*axes), s);
return func2(a, std::move(*axes), s);
} else if (n) {
std::ostringstream msg;
msg << "[" << name << "] "
<< "`axes` should not be `None` if `s` is not `None`.";
throw std::invalid_argument(msg.str());
} else {
return mx::fft::fftn(a, s);
return func3(a, s);
}
};
}
Expand All @@ -66,7 +67,7 @@ std::function<mx::array(const mx::array& a,
mx::StreamOrDevice s)>
FFT2OpWrapper(const char* name,
mx::array(*func1)(const mx::array&,
const std::vector<int>&,
const mx::Shape&,
const std::vector<int>&,
mx::StreamOrDevice),
mx::array(*func2)(const mx::array&,
Expand Down
2 changes: 1 addition & 1 deletion src/indexing.cc
Original file line number Diff line number Diff line change
Expand Up @@ -563,7 +563,7 @@ ScatterResult ScatterArgsNDimentional(const mx::array* a,
a->shape().begin() + non_none_indices, a->shape().end());
up = mx::reshape(std::move(up), std::move(up_reshape));

mx::Shape axes(arr_indices.size(), 0);
std::vector<int> axes(arr_indices.size(), 0);
std::iota(axes.begin(), axes.end(), 0);
return {std::move(arr_indices), std::move(up), std::move(axes)};
}
Expand Down
12 changes: 11 additions & 1 deletion src/metal.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,14 @@
#include "src/bindings.h"
#include "mlx/backend/gpu/device_info.h"

namespace metal_ops {

// Zero-argument adapter around mx::gpu::device_info, registered as the
// `deviceInfo` export. Always queries index 0 and hands back the map
// reference that MLX returns.
const std::unordered_map<std::string, std::variant<std::string, size_t>>&
DeviceInfo() {
  constexpr int kDeviceIndex = 0;
  return mx::gpu::device_info(kDeviceIndex);
}

}  // namespace metal_ops

void InitMetal(napi_env env, napi_value exports) {
napi_value metal = ki::CreateObject(env);
Expand All @@ -8,5 +18,5 @@ void InitMetal(napi_env env, napi_value exports) {
"isAvailable", &mx::metal::is_available,
"startCapture", &mx::metal::start_capture,
"stopCapture", &mx::metal::stop_capture,
"deviceInfo", &mx::metal::device_info);
"deviceInfo", &metal_ops::DeviceInfo);
}
45 changes: 31 additions & 14 deletions src/ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -191,21 +191,21 @@ mx::array Full(std::variant<int, mx::Shape> shape,
ScalarOrArray vals,
std::optional<mx::Dtype> dtype,
mx::StreamOrDevice s) {
return mx::full(PutIntoVector(std::move(shape)),
return mx::full(PutIntoShape(std::move(shape)),
ToArray(std::move(vals), std::move(dtype)),
s);
}

mx::array Zeros(std::variant<int, mx::Shape> shape,
std::optional<mx::Dtype> dtype,
mx::StreamOrDevice s) {
return mx::zeros(PutIntoVector(std::move(shape)), dtype.value_or(mx::float32), s);
return mx::zeros(PutIntoShape(std::move(shape)), dtype.value_or(mx::float32), s);
}

mx::array Ones(std::variant<int, mx::Shape> shape,
std::optional<mx::Dtype> dtype,
mx::StreamOrDevice s) {
return mx::ones(PutIntoVector(std::move(shape)), dtype.value_or(mx::float32), s);
return mx::ones(PutIntoShape(std::move(shape)), dtype.value_or(mx::float32), s);
}

mx::array Eye(int n,
Expand Down Expand Up @@ -303,8 +303,9 @@ std::vector<mx::array> Split(const mx::array& a,
if (auto i = std::get_if<int>(&indices); i) {
return mx::split(a, *i, axis.value_or(0), s);
} else {
return mx::split(a, std::move(std::get<std::vector<int>>(indices)),
axis.value_or(0), s);
auto& v = std::get<std::vector<int>>(indices);
mx::Shape shape_indices(v.begin(), v.end());
return mx::split(a, std::move(shape_indices), axis.value_or(0), s);
}
}

Expand Down Expand Up @@ -544,7 +545,7 @@ mx::array ConvTranspose1d(
mx::StreamOrDevice s) {
return mx::conv_transpose1d(input, weight, stride.value_or(1),
padding.value_or(0), dilation.value_or(1),
groups.value_or(1), s);
/*output_padding=*/0, groups.value_or(1), s);
}

mx::array ConvTranspose2d(
Expand Down Expand Up @@ -574,7 +575,7 @@ mx::array ConvTranspose2d(
dilation_pair = std::move(*p);

return mx::conv_transpose2d(input, weight, stride_pair, padding_pair,
dilation_pair, groups.value_or(1), s);
dilation_pair, {0, 0}, groups.value_or(1), s);
}

mx::array ConvTranspose3d(
Expand Down Expand Up @@ -604,7 +605,7 @@ mx::array ConvTranspose3d(
dilation_tuple = std::move(*p);

return mx::conv_transpose3d(input, weight, stride_tuple, padding_tuple,
dilation_tuple, groups.value_or(1), s);
dilation_tuple, {0, 0, 0}, groups.value_or(1), s);
}

mx::array ConvGeneral(
Expand Down Expand Up @@ -789,6 +790,10 @@ void InitOps(napi_env env, napi_value exports) {
"expm1", &mx::expm1,
"erf", &mx::erf,
"erfinv", &mx::erfinv,
"lgamma", &mx::lgamma,
"digamma", &mx::digamma,
"besselI0e", &mx::bessel_i0e,
"besselI1e", &mx::bessel_i1e,
"sin", &mx::sin,
"cos", &mx::cos,
"tan", &mx::tan,
Expand All @@ -811,7 +816,9 @@ void InitOps(napi_env env, napi_value exports) {
"stopGradient", &mx::stop_gradient,
"sigmoid", &mx::sigmoid,
"power", BinOpWrapper(&mx::power),
"arange", &ops::ARange,
"arange", &ops::ARange);

ki::Set(env, exports,
"linspace", &ops::Linspace,
"kron", &mx::kron,
"take", &ops::Take,
Expand Down Expand Up @@ -848,12 +855,16 @@ void InitOps(napi_env env, napi_value exports) {
"min", DimOpWrapper(&mx::min),
"max", DimOpWrapper(&mx::max),
"logcumsumexp", CumOpWrapper(&mx::logcumsumexp),
"logsumexp", DimOpWrapper(&mx::logsumexp),
"logsumexp", DimOpWrapper(&mx::logsumexp));

ki::Set(env, exports,
"mean", DimOpWrapper(&mx::mean),
"variance", &ops::Var,
"std", &ops::Std,
"split", &ops::Split,
"argmin", &ops::ArgMin,
"argmin", &ops::ArgMin);

ki::Set(env, exports,
"argmax", &ops::ArgMax,
"sort", &ops::Sort,
"argsort", &ops::ArgSort,
Expand All @@ -864,7 +875,9 @@ void InitOps(napi_env env, napi_value exports) {
"blockMaskedMM", &mx::block_masked_mm,
"gatherMM", &mx::gather_mm,
"gatherQMM", &mx::gather_qmm,
"softmax", &ops::Softmax,
"softmax", &ops::Softmax);

ki::Set(env, exports,
"concatenate", &ops::Concatenate,
"concat", &ops::Concatenate,
"stack", &ops::Stack,
Expand All @@ -876,7 +889,9 @@ void InitOps(napi_env env, napi_value exports) {
"cumsum", CumOpWrapper(&mx::cumsum),
"cumprod", CumOpWrapper(&mx::cumprod),
"cummax", CumOpWrapper(&mx::cummax),
"cummin", CumOpWrapper(&mx::cummin),
"cummin", CumOpWrapper(&mx::cummin));

ki::Set(env, exports,
"conj", &mx::conjugate,
"conjugate", &mx::conjugate,
"convolve", &ops::Convolve,
Expand Down Expand Up @@ -912,7 +927,9 @@ void InitOps(napi_env env, napi_value exports) {
"bitwiseXor", BinOpWrapper(&mx::bitwise_xor),
"leftShift", BinOpWrapper(&mx::left_shift),
"rightShift", BinOpWrapper(&mx::right_shift),
"view", &mx::view,
"view", &mx::view);

ki::Set(env, exports,
"hadamardTransform", &mx::hadamard_transform,
"einsumPath", &mx::einsum_path,
"einsum", &mx::einsum,
Expand Down
2 changes: 1 addition & 1 deletion src/utils.cc
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#include "src/array.h"
#include "src/utils.h"

mx::Shape PutIntoVector(std::variant<int, mx::Shape> shape) {
mx::Shape PutIntoShape(std::variant<int, mx::Shape> shape) {
if (auto i = std::get_if<int>(&shape); i)
return {*i};
return std::move(std::get<mx::Shape>(shape));
Expand Down
Loading