Skip to content

Commit

Permalink
adjusting executor test for reduce scatter
Browse files Browse the repository at this point in the history
  • Loading branch information
caiomcbr committed Feb 4, 2025
1 parent dd80593 commit e9e504e
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 3 deletions.
13 changes: 11 additions & 2 deletions python/test/executor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,17 +145,26 @@ def build_bufs(
nelems_input = nelems if in_place else nelems // num_ranks
else:
nelems_input = nelems
nelems_output = nelems

if "reducescatter" in collective:
assert (nelems % num_ranks) == 0, "nelems %d not multiple of num_ranks %d" % (nelems, num_ranks)
nelems_output = nelems // num_ranks
else:
nelems_output = nelems

result_buf = GpuBuffer(nelems_output, dtype=dtype)
if in_place:
if "allgather" in collective:
input_buf = cp.split(result_buf, num_ranks)[rank]
elif "reducescatter" in collective:
input_buf = GpuBuffer(nelems_input, dtype=dtype)
result_buf = cp.split(input_buf, num_ranks)[rank]
else:
input_buf = result_buf
else:
input_buf = GpuBuffer(nelems_input, dtype=dtype)
test_buf = cp.zeros(nelems_output, dtype=dtype)

test_buf = cp.zeros(nelems, dtype=dtype)

return input_buf, result_buf, test_buf

Expand Down
2 changes: 1 addition & 1 deletion python/test/executor_test_verifier.cu
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ TEST_DATA_ALL_REDUCE(int32, int)
} \
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_elems; i += blockDim.x * gridDim.x) { \
if (i >= offset && i < offset + nem_elems_per_rank) { \
assert(abs(float(result_buf[i]) - float(test_buf[i])) < 1e-3 * num_ranks); \
assert(abs(float(result_buf[i - offset]) - float(test_buf[i])) < 1e-3 * num_ranks); \
} \
} \
}
Expand Down

0 comments on commit e9e504e

Please sign in to comment.