67 changes: 66 additions & 1 deletion c_src/emlx_nif.cpp
@@ -1491,6 +1491,67 @@ NIF(as_strided) {
TENSOR(mlx::core::as_strided(*t, to_shape(shape), to_strides(strides), offset, device));
}

// ============================================================================
// Quantization Operations (for 4-bit model support)
// ============================================================================

// quantized_matmul - Multiplies x with a quantized weight matrix w
// This is the key operation for efficient 4-bit inference
// MLX API: quantized_matmul(x, w, scales, biases, transpose, group_size, bits, stream)
NIF(quantized_matmul) {
TENSOR_PARAM(0, x); // Input tensor [batch, seq, hidden]
TENSOR_PARAM(1, w); // Quantized weights [out, in/8] (uint32 packed)
TENSOR_PARAM(2, scales); // Scales [out, in/group_size] (bfloat16)
TENSOR_PARAM(3, biases); // Biases [out, in/group_size] (bfloat16)
PARAM(4, bool, transpose);
PARAM(5, int, group_size);
PARAM(6, int, bits);
DEVICE_PARAM(7, device);

TENSOR(mlx::core::quantized_matmul(
*x, *w, *scales, *biases, transpose, group_size, bits, "affine", device));
}

// dequantize - Converts quantized weights back to float
// Useful for debugging and verification
// MLX API: dequantize(w, scales, biases, group_size, bits, stream)
NIF(dequantize) {
TENSOR_PARAM(0, w); // Quantized weights (uint32 packed)
TENSOR_PARAM(1, scales); // Scales (bfloat16)
TENSOR_PARAM(2, biases); // Biases (bfloat16)
PARAM(3, int, group_size);
PARAM(4, int, bits);
DEVICE_PARAM(5, device);

TENSOR(mlx::core::dequantize(*w, *scales, *biases, group_size, bits, "affine", std::nullopt, std::nullopt, device));
}

// quantize - Quantizes a float tensor to packed format
// Returns tuple of {weights, scales, biases}
// MLX API: quantize(w, group_size, bits, stream) -> tuple<array, array, array>
NIF(quantize) {
TENSOR_PARAM(0, w); // Float weights to quantize
PARAM(1, int, group_size);
PARAM(2, int, bits);
DEVICE_PARAM(3, device);

try {
auto result = mlx::core::quantize(*w, group_size, bits, "affine", std::nullopt, device);

ERL_NIF_TERM result_tuple[3];
result_tuple[0] = create_tensor_resource(env, result[0]);
result_tuple[1] = create_tensor_resource(env, result[1]);
result_tuple[2] = create_tensor_resource(env, result[2]);

return nx::nif::ok(env, enif_make_tuple3(env, result_tuple[0], result_tuple[1], result_tuple[2]));
}
CATCH()
}

ASYNC_NIF(quantized_matmul)
ASYNC_NIF(dequantize)
ASYNC_NIF(quantize)

// Build a sliding window view of a padded tensor.
// padded: [...] of ndim n; window/strides: per-axis lists of length n.
// Returns a view of shape [o0,...,on-1, w0,...,wn-1] where
@@ -1997,7 +2058,11 @@ static ErlNifFunc nif_funcs[] = {

// ── Worker control NIFs.
{"command_queue_new", 1, command_queue_new},
{"command_queue_synchronize", 1, command_queue_synchronize}};
{"command_queue_synchronize", 1, command_queue_synchronize},
// Quantization operations (async — must run on a worker thread)
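// Each arity counts the leading worker/queue ref consumed by the async
// wrapper plus the NIF's own parameters (e.g. quantized_matmul: worker +
// 4 tensors + transpose + group_size + bits + device = 9).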
{"quantized_matmul", 9, quantized_matmul_async},
{"dequantize", 7, dequantize_async},
{"quantize", 5, quantize_async}};

ERL_NIF_INIT(Elixir.EMLX.NIF, nif_funcs, load, NULL, upgrade, NULL)

253 changes: 251 additions & 2 deletions lib/emlx.ex
@@ -323,6 +323,231 @@ defmodule EMLX do
defvalue scalar_type(tensor)
defvalue shape(tensor)

## Quantization operations (for 4-bit model support)

@doc """
Performs quantized matrix multiplication.

This is the key operation for efficient 4-bit inference. It multiplies `x` with
quantized weights `w` (packed as uint32), using scales and biases for
dequantization during the computation.

## Parameters
- `x` - Input tensor (e.g., {batch, seq, hidden})
- `w` - Quantized weights as uint32 (8 int4 values packed per uint32)
- `scales` - Per-group scale factors (bfloat16)
- `biases` - Per-group zero points (bfloat16)
- `transpose` - Whether to transpose weights (default: true)
- `group_size` - Number of weights per scale/bias group (default: 64)
- `bits` - Quantization bits (default: 4)
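
## Example

A minimal sketch (shapes and values illustrative). The weights are packed
with `quantize/3` first, then multiplied against a dense activation:

    x = EMLX.Backend.from_nx(Nx.iota({1, 4, 64}, type: :f32))
    w = EMLX.Backend.from_nx(Nx.iota({128, 64}, type: :f32))
    {wq, scales, biases} = EMLX.quantize(w, 64, 4)

    # With transpose: true the result should have shape {1, 4, 128}
    out = EMLX.quantized_matmul(x, wq, scales, biases, true, 64, 4)
    EMLX.to_nx(out)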
"""
@mlx_function {:quantized_matmul, 9}
def quantized_matmul(
{dev_x, ref_x} = _tensor_x,
{dev_w, ref_w} = _tensor_w,
{dev_s, ref_s} = _tensor_scales,
{dev_b, ref_b} = _tensor_biases,
transpose \\ true,
group_size \\ 64,
bits \\ 4
)
when is_tensor(dev_x, ref_x) and is_tensor(dev_w, ref_w) and
is_tensor(dev_s, ref_s) and is_tensor(dev_b, ref_b) do
device = merge_device(merge_device(dev_x, dev_w), merge_device(dev_s, dev_b))
{worker, effective_device} = resolve_worker(device)

job_ref =
EMLX.NIF.quantized_matmul(
worker,
ref_x,
ref_w,
ref_s,
ref_b,
transpose,
group_size,
bits,
effective_device
)
|> unwrap!()

await_worker(job_ref) |> wrap_tensor(effective_device)
end

@doc """
Dequantizes packed weights to floating point.

Converts quantized weights back to their original floating point representation.
Useful for debugging and verification.

## Parameters
- `w` - Quantized weights as uint32 (packed int4 values)
- `scales` - Per-group scale factors
- `biases` - Per-group zero points
- `group_size` - Number of weights per group (must match the value used when quantizing, typically 64)
- `bits` - Quantization bits (must match the value used when quantizing, typically 4)
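
## Example

A sketch of a quantize/dequantize round trip (shapes illustrative; the
reconstruction is approximate, not bit-exact):

    w = EMLX.Backend.from_nx(Nx.iota({128, 64}, type: :f32))
    {wq, scales, biases} = EMLX.quantize(w, 64, 4)
    w_approx = EMLX.dequantize(wq, scales, biases, 64, 4)
    EMLX.to_nx(w_approx)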
"""
@mlx_function {:dequantize, 7}
def dequantize(
{dev_w, ref_w} = _tensor_w,
{dev_s, ref_s} = _tensor_scales,
{dev_b, ref_b} = _tensor_biases,
group_size,
bits
)
when is_tensor(dev_w, ref_w) and is_tensor(dev_s, ref_s) and is_tensor(dev_b, ref_b) do
device = merge_device(dev_w, merge_device(dev_s, dev_b))
{worker, effective_device} = resolve_worker(device)

job_ref =
EMLX.NIF.dequantize(worker, ref_w, ref_s, ref_b, group_size, bits, effective_device)
|> unwrap!()

await_worker(job_ref) |> wrap_tensor(effective_device)
end

@doc """
Quantizes a floating point tensor to packed format.

Returns a tuple of `{quantized_weights, scales, biases}` where:
- `quantized_weights` - Packed uint32 tensor (8 int4 values per uint32)
- `scales` - Per-group scale factors
- `biases` - Per-group zero points

## Parameters
- `w` - Float tensor to quantize
- `group_size` - Number of weights per group (typically 64)
- `bits` - Quantization bits (typically 4)
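
## Example

A minimal sketch (shape illustrative; the last dimension must be divisible
by `group_size`):

    w = EMLX.Backend.from_nx(Nx.iota({128, 64}, type: :f32))
    {wq, scales, biases} = EMLX.quantize(w, 64, 4)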
"""
@mlx_function {:quantize, 5}
def quantize({dev_w, ref_w}, group_size, bits)
when is_tensor(dev_w, ref_w) do
device = dev_w
{worker, effective_device} = resolve_worker(device)

{weights_ref, scales_ref, biases_ref} =
EMLX.NIF.quantize(worker, ref_w, group_size, bits, effective_device)
|> unwrap!()
|> await_worker()

{{effective_device, weights_ref}, {effective_device, scales_ref},
{effective_device, biases_ref}}
end

@doc """
Quantize a dense 2-D `Nx.Tensor` and return an annotated quantized tensor.

The returned tensor carries the original logical shape and type (e.g.
`{:s, 4}`). Its backend stores the packed uint32 data and a
`EMLX.Quantization.Config` with scales, biases, `group_size`, and `bits`.

## Options

* `:type` — storage type: `{:s, 2}`, `{:s, 4}` (default), or `{:s, 8}`.
* `:group_size` — 32, 64, or 128 (default 64). Must evenly divide the last
dimension of `tensor`.
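
## Example

A sketch (shape and values illustrative):

    w = Nx.iota({128, 64}, type: :f32)
    qw = EMLX.quantize(w, type: {:s, 4}, group_size: 64)

    # The logical shape and type are preserved on the returned tensor,
    # while the backend holds the packed uint32 data.
    Nx.shape(qw)  #=> {128, 64}
    Nx.type(qw)   #=> {:s, 4}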
"""
@spec quantize(Nx.Tensor.t(), keyword()) :: Nx.Tensor.t()
def quantize(%Nx.Tensor{} = tensor, opts) when is_list(opts) do
type = Keyword.get(opts, :type, {:s, 4})
{_, bits} = type
group_size = Keyword.get(opts, :group_size, 64)

unless Nx.rank(tensor) == 2 do
raise ArgumentError,
"EMLX.quantize/2 requires a rank-2 tensor, got rank #{Nx.rank(tensor)}"
end

{_out_features, in_features} = Nx.shape(tensor)

unless rem(in_features, group_size) == 0 do
raise ArgumentError,
"EMLX.quantize/2 requires the last dimension (#{in_features}) " <>
"to be divisible by group_size (#{group_size})"
end

device_ref = EMLX.Backend.from_nx(tensor)
{weight_ref, scales_ref, biases_ref} = EMLX.quantize(device_ref, group_size, bits)

scales = EMLX.Backend.to_nx(scales_ref)
biases = EMLX.Backend.to_nx(biases_ref)

config = %EMLX.Quantization.Config{
scales: scales,
biases: biases,
group_size: group_size,
bits: bits
}

weight_shape = EMLX.shape(weight_ref)
template = Nx.template(Nx.shape(tensor), type)

%Nx.Tensor{
template
| data: %EMLX.Backend{
ref: weight_ref,
shape: weight_shape,
type: {:u, 32},
quantization_config: config
}
}
end

@doc """
Dequantize a quantized `Nx.Tensor` (created by `EMLX.quantize/2`) to a
dense float tensor by calling `mx::dequantize`.
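
## Example

A sketch (assuming `qw` was produced by `EMLX.quantize/2`; the recovered
values are approximate):

    qw = EMLX.quantize(Nx.iota({128, 64}, type: :f32), type: {:s, 4})
    w_approx = EMLX.dequantize(qw)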
"""
@spec dequantize(Nx.Tensor.t()) :: Nx.Tensor.t()
def dequantize(
%Nx.Tensor{
data: %EMLX.Backend{ref: weight_ref, quantization_config: cfg}
} = _qw
)
when not is_nil(cfg) do
EMLX.dequantize(
weight_ref,
EMLX.Backend.from_nx(cfg.scales),
EMLX.Backend.from_nx(cfg.biases),
cfg.group_size,
cfg.bits
)
|> EMLX.Backend.to_nx()
end

@doc """
Run `activation @ dequantize(qw)` using `mx::quantized_matmul`.

`qw` must be a quantized tensor produced by `EMLX.quantize/2`, and the
activation must be dense. Raises `ArgumentError` if `qw` is not quantized
or if both arguments are quantized.
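
## Example

A sketch (shapes illustrative):

    w = Nx.iota({128, 64}, type: :f32)
    qw = EMLX.quantize(w, type: {:s, 4}, group_size: 64)
    x = Nx.iota({1, 4, 64}, type: :f32)

    # Equivalent, up to quantization error, to multiplying x by the
    # transposed dense weights; result should have shape {1, 4, 128}.
    EMLX.quantized_matmul(x, qw)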
"""
@spec quantized_matmul(Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()
def quantized_matmul(%Nx.Tensor{} = activation, %Nx.Tensor{} = qw) do
cfg = qw.data.quantization_config

if is_nil(cfg) do
raise ArgumentError,
"EMLX.quantized_matmul/2: second argument must be a quantized tensor"
end

if not is_nil(activation.data.quantization_config) do
raise ArgumentError,
"EMLX.quantized_matmul/2 requires a dense activation as the first " <>
"argument; got two quantized tensors. Dequantize one of them first."
end

result =
EMLX.quantized_matmul(
EMLX.Backend.from_nx(activation),
qw.data.ref,
EMLX.Backend.from_nx(cfg.scales),
EMLX.Backend.from_nx(cfg.biases),
true,
cfg.group_size,
cfg.bits
)

EMLX.Backend.to_nx(result)
end

def to_blob({device, ref} = tensor) when is_tensor(device, ref) do
# Eval first so the underlying MLX array is materialised; then ask the
# worker for the contiguous-copy + zero-copy resource binary. Both
@@ -506,6 +731,18 @@ defmodule EMLX do
deftensor squeeze(tensor, axes)
defvalue strides(tensor)

@doc """
Converts an EMLX device ref back to an Nx.Tensor.

## Example

result_ref = EMLX.some_operation(input)
result_tensor = EMLX.to_nx(result_ref)
"""
def to_nx({device, ref} = device_ref) when is_atom(device) and is_reference(ref) do
EMLX.Backend.to_nx(device_ref)
end

@doc """
Returns the scalar value of a 0-d tensor as a number.

@@ -558,7 +795,7 @@
@impl Nx.Defn.Compiler
def __partitions_options__(opts) do
n = Keyword.get(opts, :max_concurrency, 1)
device = Keyword.get(opts, :device, :gpu)
device = Keyword.get(opts, :device, default_device())

# Allocate one CommandQueue (and its OS thread) per partition. This runs
# inside Nx.Serving's GenServer init/1 — queues are owned by module_state.
@@ -570,10 +807,22 @@

@impl Nx.Defn.Compiler
def __to_backend__(opts) do
device = Keyword.get(opts, :device, :gpu)
device = Keyword.get(opts, :device, default_device())
{EMLX.Backend, device: device}
end

@doc """
Returns the default MLX device for this process.

Reads `:default_device` from the `:emlx` application environment, falling
back to `:gpu`. Override in tests or config via:

Application.put_env(:emlx, :default_device, :cpu)
"""
def default_device do
Application.get_env(:emlx, :default_device, :gpu)
end

# Splits opts into {emlx_compiler_opts, rest_opts}. The rest_opts are
# forwarded to Nx.Defn.Evaluator; EMLX-specific keys are consumed here.
defp split_compiler_opts(opts) do