Skip to content

Commit

Permalink
Do not merge multiple operators in a single node (#228)
Browse files Browse the repository at this point in the history
  • Loading branch information
chhwang authored Aug 2, 2024
1 parent 7a676fe commit 877b0e6
Show file tree
Hide file tree
Showing 12 changed files with 409 additions and 786 deletions.
418 changes: 147 additions & 271 deletions ark/api/model_test.cpp

Large diffs are not rendered by default.

97 changes: 48 additions & 49 deletions ark/api/planner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,58 +62,57 @@ std::string DefaultPlanner::Impl::plan(bool pretty) const {
size_t max_num_processors = 1;
size_t next_node_id = 0;
for (const auto &node : model_.nodes()) {
for (const auto &op : node->ops) {
if (op->is_virtual()) continue;

Json task_info;
task_info["Id"] = next_node_id++;

Json config;
if (!config_rules_.empty()) {
const std::string op_str = op->serialize().dump();
for (auto &rule : config_rules_) {
auto config_str = rule(op_str, gpu_info.arch->name());
if (!config_str.empty()) {
config = Json::parse(config_str);
break;
}
const auto &op = node->op;
if (op->is_virtual()) continue;

Json task_info;
task_info["Id"] = next_node_id++;

Json config;
if (!config_rules_.empty()) {
const std::string op_str = op->serialize().dump();
for (auto &rule : config_rules_) {
auto config_str = rule(op_str, gpu_info.arch->name());
if (!config_str.empty()) {
config = Json::parse(config_str);
break;
}
}
if (config.empty()) {
config = op->default_config(gpu_info.arch);
}
check_config_field(op, config, "NumWarps");
check_config_field(op, config, "NumTasks");
check_config_field(op, config, "SramBytes");
size_t num_warps = config["NumWarps"];
size_t num_tasks = config["NumTasks"];
size_t sram_bytes = config["SramBytes"];
task_info["NumWarps"] = num_warps;
task_info["SramBytes"] = sram_bytes;

max_num_warps = std::max(max_num_warps, num_warps);

task_info["Ops"] = Json::array();
task_info["Ops"].push_back(op->serialize());
task_info["Ops"][0]["Config"] = config;
task_infos.push_back(task_info);

Json resource_group;
size_t num_processors = std::min(num_sm, num_tasks);
max_num_processors = std::max(max_num_processors, num_processors);
resource_group["ProcessorRange"] = {0, num_processors};
resource_group["WarpRange"] = {0, num_warps};
resource_group["SramRange"] = {0, sram_bytes};
resource_group["TaskGroups"] = {{{"TaskId", task_info["Id"]},
{"TaskRange", {0, num_tasks}},
{"Granularity", 1}}};

Json processor_group;
processor_group["ProcessorRange"] = {0, num_processors};
processor_group["ResourceGroups"] = Json::array();
processor_group["ResourceGroups"].push_back(resource_group);
processor_groups.push_back(processor_group);
}
if (config.empty()) {
config = op->default_config(gpu_info.arch);
}
check_config_field(op, config, "NumWarps");
check_config_field(op, config, "NumTasks");
check_config_field(op, config, "SramBytes");
size_t num_warps = config["NumWarps"];
size_t num_tasks = config["NumTasks"];
size_t sram_bytes = config["SramBytes"];
task_info["NumWarps"] = num_warps;
task_info["SramBytes"] = sram_bytes;

max_num_warps = std::max(max_num_warps, num_warps);

task_info["Ops"] = Json::array();
task_info["Ops"].push_back(op->serialize());
task_info["Ops"][0]["Config"] = config;
task_infos.push_back(task_info);

Json resource_group;
size_t num_processors = std::min(num_sm, num_tasks);
max_num_processors = std::max(max_num_processors, num_processors);
resource_group["ProcessorRange"] = {0, num_processors};
resource_group["WarpRange"] = {0, num_warps};
resource_group["SramRange"] = {0, sram_bytes};
resource_group["TaskGroups"] = {{{"TaskId", task_info["Id"]},
{"TaskRange", {0, num_tasks}},
{"Granularity", 1}}};

Json processor_group;
processor_group["ProcessorRange"] = {0, num_processors};
processor_group["ResourceGroups"] = Json::array();
processor_group["ResourceGroups"].push_back(resource_group);
processor_groups.push_back(processor_group);
}

Json plan;
Expand Down
Loading

0 comments on commit 877b0e6

Please sign in to comment.