Kernel with atomics for col reductions

angeloskath committed Aug 23, 2024
1 parent 684e11c commit 6002d77

Showing 3 changed files with 182 additions and 0 deletions.
23 changes: 23 additions & 0 deletions mlx/backend/metal/kernels/reduce.metal
@@ -163,6 +163,25 @@ instantiate_all_reduce(sumbool_, bool, uint32_t, Sum<uint32_t>)
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);

#define instantiate_col_reduce_atomic_tile(name, itype, otype, op, dim, bm, bn) \
template [[host_name("colAtomic" #dim "_" #bm "_" #bn "_reduce_" #name)]] \
[[kernel]] void col_reduce_atomics<itype, otype, op, dim, bm, bn>( \
const device itype* in [[buffer(0)]], \
device mlx_atomic<otype>* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant int* shape [[buffer(4)]], \
const constant size_t* strides [[buffer(5)]], \
const constant int& ndim [[buffer(6)]], \
const constant int* reduce_shape [[buffer(7)]], \
const constant size_t* reduce_strides [[buffer(8)]], \
const constant int& reduce_ndim [[buffer(9)]], \
const constant size_t& non_col_reductions [[buffer(10)]], \
uint3 gid [[threadgroup_position_in_grid]], \
uint3 gsize [[threadgroups_per_grid]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);

#define instantiate_col_reduce_looped(name, itype, otype, op, dim) \
instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, 8, 128) \
instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, 32, 32)
@@ -182,7 +201,11 @@ instantiate_all_reduce(sumbool_, bool, uint32_t, Sum<uint32_t>)
#define instantiate_same_col_reduce_helper(name, tname, type, op) \
instantiate_col_reduce_general(name##tname, type, type, op<type>)

#define instantiate_same_col_reduce_atomics_helper(name, tname, type, op) \
instantiate_col_reduce_atomic_tile(name##tname, type, type, op<type>, 1, 8, 128)

instantiate_reduce_ops(instantiate_same_col_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_same_col_reduce_atomics_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_same_col_reduce_helper, instantiate_reduce_helper_64b)

instantiate_col_reduce_general(sumbool_, bool, uint32_t, Sum<uint32_t>)
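
Note: the new instantiate_col_reduce_atomic_tile macro registers each specialization under a host name of the form colAtomic<dim>_<bm>_<bn>_reduce_<op><type>, which the dispatch code in reduce.cpp below reassembles when it looks the kernel up. A minimal illustration (the "sum" and "float32" tags are assumptions matching that dispatch code, not part of this diff):

#include <iostream>
#include <sstream>
#include <string>

int main() {
  std::string op_name = "sum";       // assumed op tag
  std::string type_name = "float32"; // assumed result of type_to_name(in)
  std::ostringstream kname;
  kname << "colAtomic1_8_128_reduce_" << op_name << type_name;
  std::cout << kname.str() << std::endl; // prints colAtomic1_8_128_reduce_sumfloat32
  return 0;
}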
105 changes: 105 additions & 0 deletions mlx/backend/metal/kernels/reduction/reduce_col.h
@@ -335,3 +335,108 @@ template <
}
}
}

template <
typename T,
typename U,
typename Op,
int NDIMS = 0,
int BM = 8,
int BN = 128>
[[kernel]] void col_reduce_atomics(
const device T* in [[buffer(0)]],
device mlx_atomic<U>* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant int* shape [[buffer(4)]],
const constant size_t* strides [[buffer(5)]],
const constant int& ndim [[buffer(6)]],
const constant int* reduce_shape [[buffer(7)]],
const constant size_t* reduce_strides [[buffer(8)]],
const constant int& reduce_ndim [[buffer(9)]],
const constant size_t& non_col_reductions [[buffer(10)]],
uint3 gid [[threadgroup_position_in_grid]],
uint3 gsize [[threadgroups_per_grid]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
Op op;
constexpr int n_loops = 4;
constexpr int blocksize = n_loops * BM;
constexpr int n_simdgroups = 4;
constexpr short tgp_size = n_simdgroups * simd_size;
constexpr short n_reads = (BM * BN) / tgp_size;
constexpr short n_read_blocks = BN / n_reads;

threadgroup U shared_vals[BN * BM];
U totals[n_reads];
looped_elem_to_loc<NDIMS> loop;
const device T* row;

for (int i = 0; i < n_reads; i++) {
totals[i] = Op::init;
}

// Map the thread's linear id to a (column, row) offset inside the BM x BN tile.
short lid = simd_group_id * simd_size + simd_lane_id;
short2 offset((lid % n_read_blocks) * n_reads, lid / n_read_blocks);
size_t column = BN * gid.x + offset.x;
bool safe = column + n_reads <= reduction_stride;

// Decode the output index and the reduction block handled by this threadgroup
// from the flattened y/z grid position.
size_t total = non_col_reductions * reduction_size;
size_t blocks = (total + blocksize - 1) / blocksize;
size_t full_idx = gid.y + gsize.y * size_t(gid.z);
size_t out_idx = full_idx / blocks;
size_t row_idx = (full_idx % blocks) * blocksize;
total = min(total, row_idx + blocksize);
size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim);
in += in_idx + column;

// Accumulate this thread's columns over the rows assigned to this block.
loop.next(row_idx + offset.y, reduce_shape, reduce_strides);
for (size_t r = row_idx + offset.y; r < total; r += BM) {
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);

if (safe) {
for (int i = 0; i < n_reads; i++) {
totals[i] = op(static_cast<U>(row[i]), totals[i]);
}
} else {
U vals[n_reads];
for (int i = 0; i < n_reads; i++) {
vals[i] =
(column + i < reduction_stride) ? static_cast<U>(row[i]) : op.init;
}
for (int i = 0; i < n_reads; i++) {
totals[i] = op(vals[i], totals[i]);
}
}

loop.next(BM, reduce_shape, reduce_strides);
}

// Reduce the per-row partials across the threadgroup through shared memory.
short x_block = offset.x / n_reads;
for (int i = 0; i < n_reads; i++) {
shared_vals[x_block * BM * n_reads + i * BM + offset.y] = totals[i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (offset.y == 0) {
for (int i = 0; i < n_reads; i++) {
for (int j = 1; j < BM; j++) {
totals[i] =
op(shared_vals[x_block * BM * n_reads + i * BM + j], totals[i]);
}
}
}

// Atomically accumulate this block's partial results into the output.
if (offset.y == 0) {
out_idx = out_idx * reduction_stride + column;
if (safe) {
for (int i = 0; i < n_reads; i++) {
op.atomic_update(out, totals[i], out_idx + i);
}
} else {
for (int i = 0; column + i < reduction_stride; i++) {
op.atomic_update(out, totals[i], out_idx + i);
}
}
}
}
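
For orientation, col_reduce_atomics assigns each of the 128 threads in a threadgroup a run of n_reads contiguous output columns and one row offset within the BM x BN tile; several threadgroups cover the same columns and merge their partial results through atomic updates on the output. The following standalone sketch (an illustration using the default tile constants, not MLX code) prints that thread-to-tile mapping:

#include <cstdio>

int main() {
  constexpr int BM = 8, BN = 128;
  constexpr int simd_size = 32, n_simdgroups = 4;
  constexpr int tgp_size = n_simdgroups * simd_size;  // 128 threads per threadgroup
  constexpr int n_reads = (BM * BN) / tgp_size;       // 8 contiguous columns per thread
  constexpr int n_read_blocks = BN / n_reads;         // 16 column blocks per tile row

  for (int lid = 0; lid < tgp_size; lid++) {
    int col = (lid % n_read_blocks) * n_reads;  // first column this thread reads
    int row = lid / n_read_blocks;              // row offset within the BM tile
    std::printf("thread %3d -> columns [%3d, %3d), row %d\n",
                lid, col, col + n_reads, row);
  }
  return 0;
}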
54 changes: 54 additions & 0 deletions mlx/backend/metal/reduce.cpp
@@ -543,6 +543,46 @@ void strided_reduce_looped(
compute_encoder.dispatchThreads(grid_dims, group_dims);
}

void strided_reduce_atomics(
const array& in,
array& out,
const std::string& op_name,
ColReduceArgs& args,
CommandEncoder& compute_encoder,
metal::Device& d,
const Stream& s) {
// Prepare the arguments for the kernel
args.reduce_shape.push_back(args.reduction_size);
args.reduce_strides.push_back(args.reduction_stride);
args.reduce_ndim++;

// Figure out the grid dims
int BN = 128;
int BM = 8;
int blocksize = BM * 4;
int threadgroup_size = 4 * 32;
int blocks = (args.non_col_reductions * args.reduction_size + blocksize - 1) /
blocksize;
auto out_grid_size = output_grid_for_col_reduce(out, args);
MTL::Size grid_dims(
threadgroup_size * ((args.reduction_stride + BN - 1) / BN),
out_grid_size.width * blocks,
out_grid_size.height);
MTL::Size group_dims(threadgroup_size, 1, 1);

// Set the kernel
std::ostringstream kname;
kname << "colAtomic1_8_128_reduce_" << op_name << type_to_name(in);
auto kernel = get_reduce_kernel(d, kname.str(), op_name, in, out);
compute_encoder->setComputePipelineState(kernel);

// Launch
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1);
args.encode(compute_encoder);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}

void strided_reduce_general_dispatch(
const array& in,
array& out,
@@ -560,6 +600,20 @@ void strided_reduce_general_dispatch(
return strided_reduce_small(in, out, op_name, args, compute_encoder, d, s);
}

// Heuristic: estimate how much parallelism is available across the output
// before resorting to cross-threadgroup atomics.
int col_reduce_parallelization = 1;
for (int i = 0; i < out.ndim(); i++) {
if (out.strides()[i] > args.reduction_stride) {
col_reduce_parallelization *= out.shape(i);
}
}
if (in.itemsize() == 4 && col_reduce_parallelization < 8 &&
args.reduce_ndim == 0 &&
args.reduction_size / args.reduction_stride > 1) {
// Seed the output so the atomic kernel can accumulate into it.
init_reduce(out, op_name, compute_encoder, d, s);
return strided_reduce_atomics(
in, out, op_name, args, compute_encoder, d, s);
}

return strided_reduce_looped(in, out, op_name, args, compute_encoder, d, s);
}
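
The new branch routes 4-byte-element column reductions with little output-level parallelism and a tall, single contiguous reduction axis through the atomic kernel, after init_reduce has seeded the output so partial blocks can accumulate into it. A hypothetical walk-through of the condition (plain C++, not MLX code; the shape (4096, 64) reduced over axis 0 and the derived argument values are assumptions):

#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  const std::size_t itemsize = 4;           // float32
  const std::size_t reduction_size = 4096;  // rows reduced per output element
  const std::size_t reduction_stride = 64;  // contiguous output columns
  const int reduce_ndim = 0;                // single contiguous reduction axis

  // Contiguous output of shape (64).
  std::vector<std::size_t> out_shape = {64};
  std::vector<std::size_t> out_strides = {1};

  std::size_t col_reduce_parallelization = 1;
  for (std::size_t i = 0; i < out_shape.size(); i++) {
    if (out_strides[i] > reduction_stride) {
      col_reduce_parallelization *= out_shape[i];
    }
  }

  const bool use_atomics = itemsize == 4 && col_reduce_parallelization < 8 &&
      reduce_ndim == 0 && reduction_size / reduction_stride > 1;
  std::printf("use atomic column reduction: %s\n", use_atomics ? "yes" : "no");
  return 0;
}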
