Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 5 additions & 7 deletions c_src/emlx_nif.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -275,9 +275,6 @@ NIF(to_blob) {
int limit = 0;
bool has_received_limit = (argc == 2);

// Evaluate to ensure data is available
t->eval();

if (has_received_limit) {
PARAM(1, int, param_limit);
limit = param_limit;
Expand Down Expand Up @@ -1051,7 +1048,7 @@ static ErlNifFunc nif_funcs[] = {
{"slice", 5, slice},
{"slice_update", 5, slice_update},
{"squeeze", 3, squeeze},
{"item", 1, item},
{"item", 1, item, ERL_NIF_DIRTY_JOB_CPU_BOUND},
{"all", 4, all},
{"any", 4, any},
{"sum", 4, sum},
Expand All @@ -1067,8 +1064,8 @@ static ErlNifFunc nif_funcs[] = {
{"shape", 1, shape},
{"reshape", 3, reshape},
{"astype", 3, astype},
{"to_blob", 1, to_blob},
{"to_blob", 2, to_blob},
{"to_blob", 1, to_blob, ERL_NIF_DIRTY_JOB_CPU_BOUND},
{"to_blob", 2, to_blob, ERL_NIF_DIRTY_JOB_CPU_BOUND},
{"from_blob", 4, from_blob},
{"scalar_tensor", 3, scalar_tensor},
{"ones", 3, ones},
Expand Down Expand Up @@ -1154,7 +1151,8 @@ static ErlNifFunc nif_funcs[] = {
{"tri_inv", 3, tri_inv},
{"set_compile", 1, set_compile},
{"compile", 2, compile, ERL_NIF_DIRTY_JOB_CPU_BOUND},
{"call_compiled", 2, call_compiled, ERL_NIF_DIRTY_JOB_CPU_BOUND}};
{"call_compiled_cpu", 2, call_compiled, ERL_NIF_DIRTY_JOB_CPU_BOUND},
{"call_compiled_gpu", 2, call_compiled, ERL_NIF_DIRTY_JOB_IO_BOUND}};

// Update the NIF initialization
ERL_NIF_INIT(Elixir.EMLX.NIF, nif_funcs, load, NULL, upgrade, NULL)
24 changes: 19 additions & 5 deletions lib/emlx.ex
Original file line number Diff line number Diff line change
Expand Up @@ -255,11 +255,21 @@ defmodule EMLX do
deftensor clip(tensor, tensor_min, tensor_max)

## Dirty non-tensor return values
defvalue to_blob(tensor)
defvalue to_blob(tensor, limit)
defvalue scalar_type(tensor)
defvalue shape(tensor)

def to_blob({device, ref} = tensor) when is_tensor(device, ref) do
# Two-step to_blob: eval on main scheduler, then copy on dirty scheduler
eval(tensor)
EMLX.NIF.to_blob(ref) |> unwrap!()
end

def to_blob({device, ref} = tensor, limit) when is_tensor(device, ref) do
# Two-step to_blob: eval on main scheduler, then copy on dirty scheduler
eval(tensor)
EMLX.NIF.to_blob(ref, limit) |> unwrap!()
end

defp unwrap!(:ok), do: :ok
defp unwrap!({:ok, result}), do: result
defp unwrap!({:error, error}), do: raise(EMLX.NIFError, List.to_string(error))
Expand Down Expand Up @@ -305,7 +315,6 @@ defmodule EMLX do
defp merge_device(_, _), do: :cpu

defvalue deallocate(tensor_ref)

defvalue eval(tensor)

deftensor slice(tensor, starts, stops, strides)
Expand Down Expand Up @@ -409,9 +418,14 @@ defmodule EMLX do
cached_fun
end

nif_result =
case device do
:cpu -> EMLX.NIF.call_compiled_cpu(compiled_fun, nif_args)
:gpu -> EMLX.NIF.call_compiled_gpu(compiled_fun, nif_args)
end

results =
compiled_fun
|> EMLX.NIF.call_compiled(nif_args)
nif_result
|> unwrap!()
|> Enum.map(fn ref ->
EMLX.Backend.to_nx({device, ref})
Expand Down
15 changes: 14 additions & 1 deletion lib/emlx/nif.ex
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,20 @@ defmodule EMLX.NIF do
:erlang.nif_error(:nif_not_loaded)
end

def call_compiled(_compiled_fun, _args) do
# Device-specific NIFs for dirty scheduler optimization
def call_compiled_cpu(_compiled_fun, _args) do
:erlang.nif_error(:nif_not_loaded)
end

def call_compiled_gpu(_compiled_fun, _args) do
:erlang.nif_error(:nif_not_loaded)
end

def to_blob(_tensor) do
:erlang.nif_error(:nif_not_loaded)
end

def to_blob(_tensor, _limit) do
:erlang.nif_error(:nif_not_loaded)
end
end