Implement communication-avoiding work assignment for generator kernels in Command Graph #330

Draft: wants to merge 8 commits into master
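
For context, a minimal sketch of the kind of kernel this change classifies as a "generator kernel", mirroring the "init" task in the new test below. It assumes the user-facing celerity::queue / accessor API (distr_queue in older releases); the queue and buffer names are illustrative and not part of this diff. Such a task discard-writes the entire buffer through celerity::access::one_to_one, so every node can produce its assigned chunk locally and no data transfers are required afterwards.

// Illustrative only: a generator kernel as detected by is_generator_kernel() below.
// write_only + no_init corresponds to the discard_write access mode the check looks for.
celerity::queue q;
celerity::buffer<float, 2> buf{celerity::range<2>{256, 256}};
q.submit([&](celerity::handler& cgh) {
	celerity::accessor acc{buf, cgh, celerity::access::one_to_one{}, celerity::write_only, celerity::no_init};
	cgh.parallel_for(buf.get_range(), [=](celerity::item<2> item) { acc[item.get_id()] = 0.f; });
});

With this classification, generate_distributed_commands can emit one execution command per locally assigned chunk and mark the produced subrange as fresh, so no push/await-push commands are generated for the buffer's initialization.
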
3 changes: 3 additions & 0 deletions include/command_graph_generator.h
@@ -93,6 +93,8 @@ class command_graph_generator {
std::optional<reduction_info> pending_reduction;

std::string debug_name;

const task* generator_task = nullptr;
};

struct host_object_state {
@@ -111,6 +113,7 @@
};

public:
bool is_generator_kernel(const task& tsk) const;
struct policy_set {
error_policy uninitialized_read_error = error_policy::panic;
error_policy overlapping_write_error = error_policy::panic;
15 changes: 15 additions & 0 deletions include/range_mapper.h
@@ -13,6 +13,12 @@


namespace celerity {

// Forward-declaration so we can detect whether the functor is one_to_one
namespace access {
struct one_to_one;
}

namespace detail {

template <typename Functor, int BufferDims, int KernelDims>
@@ -85,6 +91,8 @@
virtual region<3> map_3(const chunk<3>& chnk) const = 0;

virtual ~range_mapper_base() = default;

/// Returns true if the range mapper's functor is known to be celerity::access::one_to_one. Defaults to false; overridden in the typed subclass below.
virtual bool is_one_to_one() const { return false; }
};

template <int BufferDims, typename Functor>
@@ -107,6 +115,13 @@
region<3> map_3(const chunk<2>& chnk) const override { return map<3>(chnk); }
region<3> map_3(const chunk<3>& chnk) const override { return map<3>(chnk); }

// Override is_one_to_one() to detect whether the functor is specifically celerity::access::one_to_one:
bool is_one_to_one() const override {
// If the Functor is celerity::access::one_to_one, return true
if constexpr(std::is_same_v<Functor, celerity::access::one_to_one>) { return true; }
return false;
}

private:
Functor m_rmfn;
range<BufferDims> m_buffer_size;
9 changes: 9 additions & 0 deletions include/task.h
@@ -64,6 +64,15 @@ namespace detail {
/// Returns a set of bounding boxes, one for each accessed region, that must be allocated contiguously.
box_vector<3> compute_required_contiguous_boxes(const buffer_id bid, const box<3>& execution_range) const;

// Retrieves the range mapper associated with a specific buffer ID, or nullptr if not found.
const range_mapper_base* get_range_mapper(buffer_id search_bid) const {
for(const auto& ba : m_accesses) {
if(ba.bid == search_bid) { return ba.range_mapper.get(); }
}
return nullptr; // Not found
}


private:
std::vector<buffer_access> m_accesses;
std::unordered_set<buffer_id> m_accessed_buffers; ///< Cached set of buffer ids found in m_accesses
50 changes: 48 additions & 2 deletions src/command_graph_generator.cc
@@ -73,6 +73,31 @@ bool is_topologically_sorted(Iterator begin, Iterator end) {
return true;
}

bool command_graph_generator::is_generator_kernel(const task& tsk) const {
if(tsk.get_type() != task_type::device_compute) return false;

// Must not have a hint that modifies splitting:
if(tsk.get_hint<experimental::hints::split_1d>() != nullptr || tsk.get_hint<experimental::hints::split_2d>() != nullptr) { return false; }

// Must have exactly one buffer access
const auto& bam = tsk.get_buffer_access_map();
if(bam.get_num_accesses() != 1) return false;

// That single access must be discard_write
const auto [bid, mode] = bam.get_nth_access(0);
if(mode != access_mode::discard_write) return false;

// Must produce exactly the entire buffer
const auto full_box = box(subrange({}, tsk.get_global_size()));
if(bam.get_task_produced_region(bid) != full_box) return false;

// Confirm the *range mapper* is truly one_to_one:
const auto rm = bam.get_range_mapper(bid);
if(rm == nullptr || !rm->is_one_to_one()) { return false; }

return true;
}

std::vector<const command*> command_graph_generator::build_task(const task& tsk) {
const auto epoch_to_prune_before = m_epoch_for_new_commands;
batch current_batch;
@@ -455,6 +480,29 @@ void command_graph_generator::update_local_buffer_fresh_regions(const task& tsk,
}

void command_graph_generator::generate_distributed_commands(batch& current_batch, const task& tsk) {
// If this is a generator kernel, generate the local execution commands directly and skip the usual requirement analysis and transfer generation below.
if(is_generator_kernel(tsk)) {
// Identify which buffer is discard-written
const auto [gen_bid, _] = tsk.get_buffer_access_map().get_nth_access(0);
auto& bstate = m_buffers.at(gen_bid);
const auto chunks = split_task_and_assign_chunks(tsk);

// Create a command for each chunk that belongs to our local node.
for(const auto& a_chunk : chunks) {
if(a_chunk.executed_on != m_local_nid) continue;
auto* cmd = create_command<execution_command>(current_batch, &tsk, subrange<3>{a_chunk.chnk}, false,
[&](const auto& record_debug_info) { record_debug_info(tsk, [this](const buffer_id bid) { return m_buffers.at(bid).debug_name; }); });

// Mark that subrange as freshly written, so there’s no “uninitialized read” later.
box<3> write_box(a_chunk.chnk.offset, a_chunk.chnk.offset + a_chunk.chnk.range);
region<3> written_region{write_box};
bstate.local_last_writer.update_region(written_region, cmd);
bstate.initialized_region = region_union(bstate.initialized_region, written_region);
}
// Return here so we skip the normal command-generation logic below.
return;
}

const auto chunks = split_task_and_assign_chunks(tsk);
const auto chunks_with_requirements = compute_per_chunk_requirements(tsk, chunks);

@@ -521,8 +569,6 @@ void command_graph_generator::generate_distributed_commands(batch& current_batch

if(!produced.empty()) {
generate_anti_dependencies(tsk, bid, buffer.local_last_writer, produced, cmd);

// Update last writer
buffer.local_last_writer.update_region(produced, cmd);
buffer.replicated_regions.update_region(produced, node_bitset{});

59 changes: 59 additions & 0 deletions test/command_graph_general_tests.cc
@@ -437,3 +437,62 @@
"range mapper for this write access or constrain the split via experimental::constrain_split to make the access non-overlapping.");
}
}

TEST_CASE("results form generator kernels are never communicated between nodes", "[command_graph_generator][owner-computes]") {
const bool split_2d = GENERATE(values({0, 1}));
CAPTURE(split_2d);

const size_t num_nodes = 4;
cdag_test_context cctx(num_nodes); // 4 nodes, so we can get a true 2D work assignment for the timestep kernel
auto buf = cctx.create_buffer<2>({256, 256}); // a 256x256 buffer

const auto tid_init = cctx.device_compute(buf.get_range()) //
.discard_write(buf, celerity::access::one_to_one())
.name("init")
.submit();
const auto tid_ts0 = cctx.device_compute(buf.get_range()) //
.hint_if(split_2d, experimental::hints::split_2d())
.read_write(buf, celerity::access::one_to_one())
.name("timestep 0")
.submit();
const auto tid_ts1 = cctx.device_compute(buf.get_range()) //
.hint_if(split_2d, experimental::hints::split_2d())
.read_write(buf, celerity::access::one_to_one())
.name("timestep 1")
.submit();

CHECK(cctx.query<execution_command_record>().count_per_node() == 3); // one for each task above
CHECK(cctx.query<push_command_record>().total_count() == 0);
CHECK(cctx.query<await_push_command_record>().total_count() == 0);

const auto inits = cctx.query<execution_command_record>(tid_init);
const auto ts0s = cctx.query<execution_command_record>(tid_ts0);
const auto ts1s = cctx.query<execution_command_record>(tid_ts1);
CHECK(inits.count_per_node() == 1);
CHECK(ts0s.count_per_node() == 1);
CHECK(ts1s.count_per_node() == 1);

for(node_id nid = 0; nid < num_nodes; ++nid) {
const auto n_init = inits.on(nid);
REQUIRE(n_init->accesses.size() == 1);

const auto generate = n_init->accesses.front();
CHECK(generate.bid == buf.get_id());
CHECK(generate.mode == access_mode::discard_write);

const auto n_ts0 = ts0s.on(nid);
CHECK(n_ts0.predecessors().contains(n_init));
REQUIRE(n_ts0->accesses.size() == 1);

const auto consume = n_ts0->accesses.front();
CHECK(consume.bid == buf.get_id());
CHECK(consume.mode == access_mode::read_write);

// generator kernel "init" has generated exactly the buffer subrange that is consumed by "timestep 0"
CHECK(consume.req == generate.req);

const auto n_ts1 = ts1s.on(nid);
CHECK(n_ts1.predecessors().contains(n_ts0));
CHECK_FALSE(n_ts1.predecessors().contains(n_init));
}
}