Skip to content

Commit

Permalink
Small column long row specialization
Browse files Browse the repository at this point in the history
  • Loading branch information
angeloskath committed Aug 23, 2024
1 parent f06d8c2 commit 2a32d76
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 23 deletions.
75 changes: 57 additions & 18 deletions mlx/backend/metal/kernels/reduction/reduce_col.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,8 @@ template <
looped_elem_to_loc<NDIMS> loop;
const device T* row;

// Case 1:
// reduction_stride is small, reduction_size is small and non_col_reductions
// is small. Each thread computes reduction_stride outputs.
if (reduction_size * non_col_reductions < 64) {
// Case 1: Small row small column
if (reduction_size * non_col_reductions < 64 && reduction_stride < 32) {
U totals[31];
for (int i = 0; i < 31; i++) {
totals[i] = Op::init;
Expand Down Expand Up @@ -71,10 +69,55 @@ template <
}
}

// Case 2:
// Reduction stride is small but everything else can be big. We loop both
// across reduction size and non_col_reductions. Each simdgroup produces
// N_READS outputs.
// Case 2: Long row small column
else if (reduction_size * non_col_reductions < 32) {
U totals[N_READS];
for (int i = 0; i < N_READS; i++) {
totals[i] = Op::init;
}

short size = reduction_size;
size_t offset = size_t(tid.x) * N_READS;
bool safe = offset + N_READS <= reduction_stride;
short extra = reduction_stride - offset;

size_t out_idx = tid.y + tsize.z * size_t(tid.z);
in += elem_to_loc(out_idx, shape, strides, ndim) + offset;

for (uint r = 0; r < non_col_reductions; r++) {
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);

if (safe) {
for (short i = 0; i < size; i++) {
for (short j = 0; j < N_READS; j++) {
totals[j] =
op(static_cast<U>(row[i * reduction_stride + j]), totals[j]);
}
}
} else {
for (short i = 0; i < size; i++) {
for (short j = 0; j < extra; j++) {
totals[j] =
op(static_cast<U>(row[i * reduction_stride + j]), totals[j]);
}
}
}

loop.next(reduce_shape, reduce_strides);
}
out += out_idx * reduction_stride + offset;
if (safe) {
for (short i = 0; i < N_READS; i++) {
out[i] = totals[i];
}
} else {
for (short i = 0; i < extra; i++) {
out[i] = totals[i];
}
}
}

// Case 3: Long row medium column
else {
threadgroup U shared_vals[1024];
U totals[N_READS];
Expand Down Expand Up @@ -147,17 +190,13 @@ template <
/**
* Our approach is the following simple looped approach:
* 1. Each thread keeps running totals for BN / n_simdgroups outputs.
* 2. Load a tile BM, BN in shared memory.
* 3. Add the values from shared memory to the current running totals.
* Neighboring threads access different rows (transposed acces).
* 4. Move ahead to the next tile until the M axis is exhausted.
* 5. Move ahead to the next non column reduction
* 6. Simd reduce the running totals
* 2. Load a tile BM, BN in registers and accumulate in the running totals
* 3. Move ahead by BM steps until the column axis and the non column
* reductions are exhausted.
* 6. If BM == 32 then transpose in SM and simd reduce the running totals.
* Otherwise write in shared memory and BN threads accumulate the running
* totals with a loop.
* 7. Write them to the output
*
* The kernel becomes verbose because we support all kinds of OOB checks. For
* instance if we choose that reduction_stride must be larger than BN then we
* can get rid of half the kernel.
*/
template <
typename T,
Expand Down
20 changes: 16 additions & 4 deletions mlx/backend/metal/reduce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -458,14 +458,25 @@ void strided_reduce_small(
// Figure out the grid dims
MTL::Size grid_dims, group_dims;

// Case 1: everything is small so launch one thread per col reduce
if (args.reduction_size * args.non_col_reductions < 64) {
// Case 1: Small row small column
if (args.reduction_size * args.non_col_reductions < 64 &&
args.reduction_stride < 32) {
grid_dims = output_grid_for_col_reduce(out, args);
int threadgroup_size = (grid_dims.width > 128) ? 128 : grid_dims.width;
group_dims = MTL::Size(threadgroup_size, 1, 1);
}

// Case 2: Reduction in the simdgroup
// Case 2: Long row small column
else if (args.reduction_size * args.non_col_reductions < 32) {
auto out_grid_dims = output_grid_for_col_reduce(out, args);
int threads_x =
(args.reduction_stride + REDUCE_N_READS - 1) / REDUCE_N_READS;
int threadgroup_x = std::min(threads_x, 128);
grid_dims = MTL::Size(threads_x, out_grid_dims.width, out_grid_dims.height);
group_dims = MTL::Size(threadgroup_x, 1, 1);
}

// Case 3: Long row medium column
else {
args.reduce_shape.push_back(args.reduction_size);
args.reduce_strides.push_back(args.reduction_stride);
Expand Down Expand Up @@ -544,7 +555,8 @@ void strided_reduce_general_dispatch(
// Prepare the arguments for the kernel
ColReduceArgs args(in, plan, axes);

if (args.reduction_stride < 32) {
if (args.reduction_stride < 32 ||
args.reduction_size * args.non_col_reductions < 32) {
return strided_reduce_small(in, out, op_name, args, compute_encoder, d, s);
}

Expand Down
3 changes: 2 additions & 1 deletion mlx/backend/metal/reduce.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ void all_reduce_dispatch(
const std::string& op_name,
CommandEncoder& compute_encoder,
metal::Device& d,
const Stream& s);
const Stream& s,
std::vector<array>& copies);

void row_reduce_general_dispatch(
const array& in,
Expand Down

0 comments on commit 2a32d76

Please sign in to comment.