Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 14 additions & 8 deletions mlx/backend/metal/kernels/logsumexp.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,33 @@ template <typename T, typename AccT = float, int N_READS = 4>
const device T* in,
device T* out,
constant int& axis_size,
uint gid [[threadgroup_position_in_grid]],
uint _lid [[thread_position_in_threadgroup]],
uint2 gid [[threadgroup_position_in_grid]],
uint2 tid [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]],
uint2 _lid [[thread_position_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
int lid = _lid;
int lid = _lid.x;

constexpr int SIMD_SIZE = 32;
constexpr int elem_per_group = SIMD_SIZE * 32 * N_READS;

threadgroup AccT local_max[SIMD_SIZE];
threadgroup AccT local_normalizer[SIMD_SIZE];

AccT ld[N_READS];

in += gid * size_t(axis_size) + lid * N_READS;
if (lid * N_READS + N_READS <= axis_size) {
const int axis_offset = tid.y * elem_per_group;
in += gid.x * size_t(axis_size) + lid * N_READS + axis_offset;
if (axis_offset + lid * N_READS + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
ld[i] = AccT(in[i]);
}
} else {
for (int i = 0; i < N_READS; i++) {
ld[i] =
((lid * N_READS + i) < axis_size) ? AccT(in[i]) : Limits<AccT>::min;
ld[i] = ((axis_offset + lid * N_READS + i) < axis_size)
? AccT(in[i])
: Limits<AccT>::min;
}
}
if (simd_group_id == 0) {
Expand Down Expand Up @@ -55,6 +60,7 @@ template <typename T, typename AccT = float, int N_READS = 4>
maxval = local_max[0];

// Compute exp(x_i - maxval) and store the partial sums in local_normalizer
out += gid.x * grid_dim.y + tid.y;
AccT normalizer = 0;
for (int i = 0; i < N_READS; i++) {
normalizer += fast::exp(ld[i] - maxval);
Expand All @@ -67,7 +73,7 @@ template <typename T, typename AccT = float, int N_READS = 4>
if (simd_group_id == 0) {
normalizer = simd_sum(local_normalizer[simd_lane_id]);
if (simd_lane_id == 0) {
out[gid] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval);
out[0] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval);
}
}
}
Expand Down
26 changes: 24 additions & 2 deletions mlx/backend/metal/logsumexp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,15 +62,37 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
const int n_reads = 4;
const int looped_limit = LOGSUMEXP_LOOPED_LIMIT;

std::string kernel_name = (axis_size > looped_limit) ? "looped_" : "block_";
bool split = n_rows < 4 && axis_size > 4 * looped_limit;
bool looped = !split && axis_size > looped_limit;
std::string kernel_name = looped ? "looped_" : "block_";
kernel_name += "logsumexp_";
kernel_name += type_to_name(out);

auto kernel = get_logsumexp_kernel(d, kernel_name, out);
auto& compute_encoder = d.get_command_encoder(s.index);
if (split) {
auto tmp_size = ceildiv(axis_size, looped_limit);
auto tmp_shape = Shape{n_rows, static_cast<int>(tmp_size)};
array tmp(tmp_shape, in.dtype(), nullptr, {});
tmp.set_data(allocator::malloc(tmp.nbytes()));
size_t threadgroup_size = 1024;
assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup());
size_t n_threads = n_rows * threadgroup_size;
auto grid_dims = MTL::Size(n_threads, tmp_size, 1);
auto group_dims = MTL::Size(threadgroup_size, 1, 1);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(tmp, 1);
compute_encoder.set_bytes(axis_size, 2);
compute_encoder.dispatch_threads(grid_dims, group_dims);
d.add_temporary(tmp, s.index);
in = tmp;
axis_size = tmp_size;
}

{
MTL::Size grid_dims, group_dims;
if (axis_size <= looped_limit) {
if (!looped) {
size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads;
size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size;
size_t threadgroup_size = simd_size * simds_needed;
Expand Down
4 changes: 4 additions & 0 deletions python/tests/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -760,6 +760,10 @@ def logsumexp(x, axes=None):
x = mx.broadcast_to(mx.random.uniform(shape=(2, 1, 8)), (2, 2, 8))
self.assertTrue(mx.allclose(mx.logsumexp(x), logsumexp(x)))

# Even larger
x = mx.random.uniform(shape=(4 * 4096 + 3,))
self.assertTrue(mx.allclose(mx.logsumexp(x), logsumexp(x)))

def test_mean(self):
x = mx.array(
[
Expand Down