Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 43 additions & 4 deletions c_src/emlx_nif.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -911,14 +911,53 @@ NIF(isclose) {
NIF(item) {
TENSOR_PARAM(0, t);
mlx::core::eval(*t);
auto dtype_kind = mlx::core::kindof(t->dtype());

if (dtype_kind == mlx::core::Dtype::Kind::u ||
dtype_kind == mlx::core::Dtype::Kind::i ||
dtype_kind == mlx::core::Dtype::Kind::b) {
// Fix for MLX scalar layout bug: Use the correct type when calling item<T>()
// to avoid reading wrong number of bytes from potentially invalid memory
// layouts.
auto dtype = t->dtype();

// Handle integer and boolean types with proper dtype matching
if (dtype == mlx::core::bool_) {
bool value = t->item<bool>();
return nx::nif::ok(env, nx::nif::make(env, static_cast<int64_t>(value)));
} else if (dtype == mlx::core::uint8) {
uint8_t value = t->item<uint8_t>();
return nx::nif::ok(env, nx::nif::make(env, static_cast<int64_t>(value)));
} else if (dtype == mlx::core::uint16) {
uint16_t value = t->item<uint16_t>();
return nx::nif::ok(env, nx::nif::make(env, static_cast<int64_t>(value)));
} else if (dtype == mlx::core::uint32) {
uint32_t value = t->item<uint32_t>();
return nx::nif::ok(env, nx::nif::make(env, static_cast<int64_t>(value)));
} else if (dtype == mlx::core::uint64) {
uint64_t value = t->item<uint64_t>();
return nx::nif::ok(env, nx::nif::make(env, static_cast<int64_t>(value)));
} else if (dtype == mlx::core::int8) {
int8_t value = t->item<int8_t>();
return nx::nif::ok(env, nx::nif::make(env, static_cast<int64_t>(value)));
} else if (dtype == mlx::core::int16) {
int16_t value = t->item<int16_t>();
return nx::nif::ok(env, nx::nif::make(env, static_cast<int64_t>(value)));
} else if (dtype == mlx::core::int32) {
int32_t value = t->item<int32_t>();
return nx::nif::ok(env, nx::nif::make(env, static_cast<int64_t>(value)));
} else if (dtype == mlx::core::int64) {
int64_t value = t->item<int64_t>();
return nx::nif::ok(env, nx::nif::make(env, value));
} else if (dtype == mlx::core::float16 || dtype == mlx::core::bfloat16) {
// MLX handles float16/bfloat16 conversion internally
float value = t->item<float>();
return nx::nif::ok(env, nx::nif::make(env, static_cast<double>(value)));
} else if (dtype == mlx::core::float32) {
float value = t->item<float>();
return nx::nif::ok(env, nx::nif::make(env, static_cast<double>(value)));
} else if (dtype == mlx::core::complex64) {
// Complex types need special handling - not supported via item()
return nx::nif::error(env,
"Complex scalar extraction not supported via item()");
} else {
// Fallback for any other types
double value = t->item<double>();
return nx::nif::ok(env, nx::nif::make(env, value));
}
Expand Down
135 changes: 135 additions & 0 deletions test/emlx_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,139 @@ defmodule EMLXTest do
assert_equal(right, Nx.tensor(-1))
end)
end

describe "scalar item extraction (MLX layout bug fix)" do
# Tests for the fix in emlx_nif.cpp:item()
# The bug: MLX creates scalars with invalid memory layout after slice→squeeze
# The fix: Call item<T>() with the correct dtype instead of always using int64/double

# Helper to call EMLX.item() directly (bypasses any Elixir workarounds)
defp item_direct(tensor) do
{_device, ref} = EMLX.Backend.from_nx(tensor)
EMLX.item({:cpu, ref})
end

test "extracts int32 scalar from slice→squeeze" do
array = Nx.iota({1000}, type: :s32)

# Test various indices that previously failed
for idx <- [0, 1, 100, 500, 900, 951, 998, 999] do
sliced = Nx.slice_along_axis(array, idx, 1, axis: 0)
scalar = Nx.squeeze(sliced, axes: [0])
value = item_direct(scalar)

assert value == idx, "Expected #{idx}, got #{value} for int32 scalar"
end
end

test "extracts int8 scalar correctly" do
array = Nx.iota({128}, type: :s8)

for idx <- [0, 1, 50, 100, 127] do
scalar = array |> Nx.slice_along_axis(idx, 1, axis: 0) |> Nx.squeeze(axes: [0])
assert item_direct(scalar) == idx
end
end

test "extracts int16 scalar correctly" do
array = Nx.iota({1000}, type: :s16)

for idx <- [0, 1, 500, 999] do
scalar = array |> Nx.slice_along_axis(idx, 1, axis: 0) |> Nx.squeeze(axes: [0])
assert item_direct(scalar) == idx
end
end

test "extracts int64 scalar correctly" do
array = Nx.iota({100}, type: :s64)

for idx <- [0, 1, 50, 99] do
scalar = array |> Nx.slice_along_axis(idx, 1, axis: 0) |> Nx.squeeze(axes: [0])
assert item_direct(scalar) == idx
end
end

test "extracts uint8 scalar correctly" do
array = Nx.iota({200}, type: :u8)

for idx <- [0, 1, 100, 199] do
scalar = array |> Nx.slice_along_axis(idx, 1, axis: 0) |> Nx.squeeze(axes: [0])
assert item_direct(scalar) == idx
end
end

test "extracts uint16 scalar correctly" do
array = Nx.iota({1000}, type: :u16)

for idx <- [0, 1, 500, 999] do
scalar = array |> Nx.slice_along_axis(idx, 1, axis: 0) |> Nx.squeeze(axes: [0])
assert item_direct(scalar) == idx
end
end

test "extracts uint32 scalar correctly" do
array = Nx.iota({1000}, type: :u32)

for idx <- [0, 1, 500, 951, 999] do
scalar = array |> Nx.slice_along_axis(idx, 1, axis: 0) |> Nx.squeeze(axes: [0])
assert item_direct(scalar) == idx
end
end

test "extracts uint64 scalar correctly" do
array = Nx.iota({100}, type: :u64)

for idx <- [0, 1, 50, 99] do
scalar = array |> Nx.slice_along_axis(idx, 1, axis: 0) |> Nx.squeeze(axes: [0])
assert item_direct(scalar) == idx
end
end

test "extracts float32 scalar correctly" do
array = Nx.iota({100}, type: :f32)

for idx <- [0, 1, 50, 99] do
scalar = array |> Nx.slice_along_axis(idx, 1, axis: 0) |> Nx.squeeze(axes: [0])
assert_in_delta item_direct(scalar), idx * 1.0, 1.0e-6
end
end

test "extracts boolean scalar correctly" do
# Create array [0, 1, 0, 1, ...] as uint8
array = Nx.tensor([0, 1, 0, 1, 0, 1, 0, 1, 0, 1], type: :u8)

for idx <- [0, 1, 2, 3] do
scalar = array |> Nx.slice_along_axis(idx, 1, axis: 0) |> Nx.squeeze(axes: [0])
expected = rem(idx, 2)
assert item_direct(scalar) == expected
end
end

test "direct scalar creation works (baseline)" do
# Ensure direct scalar creation still works
scalar = Nx.tensor(951, type: :s32)
assert item_direct(scalar) == 951
end

test "negative values work correctly" do
array = Nx.tensor([-100, -50, 0, 50, 100], type: :s32)

for {expected, idx} <- Enum.with_index([-100, -50, 0, 50, 100]) do
scalar = array |> Nx.slice_along_axis(idx, 1, axis: 0) |> Nx.squeeze(axes: [0])
assert item_direct(scalar) == expected
end
end

test "edge values for int32" do
# Test boundary values
max_val = 2_147_483_647
min_val = -2_147_483_648
array = Nx.tensor([min_val, -1, 0, 1, max_val], type: :s32)

for {expected, idx} <- Enum.with_index([min_val, -1, 0, 1, max_val]) do
scalar = array |> Nx.slice_along_axis(idx, 1, axis: 0) |> Nx.squeeze(axes: [0])
assert item_direct(scalar) == expected
end
end
end
end