From 877b0e638a922b824365645f289b2721e5b69534 Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Fri, 2 Aug 2024 14:23:28 -0700 Subject: [PATCH] Do not merge multiple operators in a single node (#228) --- ark/api/model_test.cpp | 418 +++++++++++--------------------- ark/api/planner.cpp | 97 ++++---- ark/model/model_graph_impl.cpp | 250 +++---------------- ark/model/model_graph_impl.hpp | 7 - ark/model/model_json.cpp | 10 +- ark/model/model_node.hpp | 5 +- ark/ops/ops_all_reduce_test.cpp | 58 ----- ark/ops/ops_identity_test.cpp | 39 ++- ark/ops/ops_sharding_test.cpp | 46 ++-- docs/model_file.md | 29 +-- examples/tutorial/model.json | 231 +++++++++--------- python/unittest/test_model.py | 5 +- 12 files changed, 409 insertions(+), 786 deletions(-) diff --git a/ark/api/model_test.cpp b/ark/api/model_test.cpp index a9d332a97..1845774e3 100644 --- a/ark/api/model_test.cpp +++ b/ark/api/model_test.cpp @@ -21,7 +21,7 @@ ark::unittest::State test_model_basics() { // | // TensorOp --> t1 --+ // | - // TensorOp --> tx --+ (tx is the output reference, hidden from the code) + // TensorOp --> tx --+ (tx is a write_tensor, hidden from the code) // ark::Tensor t0 = model.tensor({1}, ark::FP32); @@ -31,25 +31,25 @@ ark::unittest::State test_model_basics() { UNITTEST_TRUE(model.verify()); UNITTEST_FALSE(model.compressed()); - // OpNode graph (parentheses indicate a OpNode): + // OpNode graph: // - // (AddOp,) + // AddOp // compressed = model.compress(); UNITTEST_TRUE(compressed.verify()); UNITTEST_TRUE(compressed.compressed()); - UNITTEST_EQ(compressed.nodes().size(), 1); - auto node = compressed.nodes().front(); - UNITTEST_EQ(node->ops.size(), 1); - UNITTEST_EQ(node->ops[0]->result_tensors()[0], t2.ref()); - UNITTEST_EQ(node->ops[0]->read_tensors()[0], t0.ref()); - UNITTEST_EQ(node->ops[0]->read_tensors()[1], t1.ref()); - UNITTEST_EQ(node->consumers.size(), 0); - UNITTEST_EQ(node->producers.size(), 0); + auto nodes = compressed.nodes(); + UNITTEST_EQ(nodes.size(), 1); + + UNITTEST_EQ(nodes[0]->op->result_tensors()[0], t2.ref()); + UNITTEST_EQ(nodes[0]->op->read_tensors()[0], t0.ref()); + UNITTEST_EQ(nodes[0]->op->read_tensors()[1], t1.ref()); + UNITTEST_EQ(nodes[0]->consumers.size(), 0); + UNITTEST_EQ(nodes[0]->producers.size(), 0); - // Test a chain of Ops that share an input tensor. + // Test a chain of Ops that share a read_tensor. // Model graph: // // TensorOp --> t0 --+--> AddOp --> t2 ------+--> AddOp --> t3 @@ -58,70 +58,78 @@ ark::unittest::State test_model_basics() { // | | // TensorOp --> tx --+ TensorOp --> ty --+ // - // (tx and ty are output references, hidden from the code) + // (tx and ty are write_tensors, hidden from the code) // ark::Tensor t3 = model.add(t2, t1); UNITTEST_TRUE(model.verify()); - // OpNode graph (parentheses indicate a OpNode): + // OpNode graph: // - // (AddOp,AddOp,) + // AddOp --> AddOp // compressed = model.compress(); UNITTEST_TRUE(compressed.verify()); - UNITTEST_EQ(compressed.nodes().size(), 1); - node = compressed.nodes().front(); + nodes = compressed.nodes(); + UNITTEST_EQ(nodes.size(), 2); + + UNITTEST_EQ(nodes[0]->op->result_tensors()[0], t2.ref()); + UNITTEST_EQ(nodes[0]->op->read_tensors()[0], t0.ref()); + UNITTEST_EQ(nodes[0]->op->read_tensors()[1], t1.ref()); + UNITTEST_EQ(nodes[1]->op->result_tensors()[0], t3.ref()); + UNITTEST_EQ(nodes[1]->op->read_tensors()[0], t2.ref()); + UNITTEST_EQ(nodes[1]->op->read_tensors()[1], t1.ref()); - UNITTEST_EQ(node->ops[0]->result_tensors()[0], t2.ref()); - UNITTEST_EQ(node->ops[0]->read_tensors()[0], t0.ref()); - UNITTEST_EQ(node->ops[0]->read_tensors()[1], t1.ref()); - UNITTEST_EQ(node->ops[1]->result_tensors()[0], t3.ref()); - UNITTEST_EQ(node->ops[1]->read_tensors()[0], t2.ref()); - UNITTEST_EQ(node->ops[1]->read_tensors()[1], t1.ref()); - UNITTEST_EQ(node->consumers.size(), 0); - UNITTEST_EQ(node->producers.size(), 0); + UNITTEST_EQ(nodes[0]->consumers.size(), 1); + UNITTEST_EQ(nodes[0]->producers.size(), 0); + UNITTEST_EQ(nodes[1]->consumers.size(), 0); + UNITTEST_EQ(nodes[1]->producers.size(), 1); - // Test a chain of Ops without shared input tensors. + // Test a chain of Ops without shared read_tensors. // Model graph (omit leftmost part): // // ... ----+--> AddOp --> t3 ----+-> ReluOp --> t4 // ... | | // ... ----+ TensorOp --> tz --+ // ... | - // ... --+ (tz is the output reference, hidden from the code) + // ... --+ (tz is a write_tensor, hidden from the code) // ark::Tensor t4 = model.relu(t3); UNITTEST_TRUE(model.verify()); - // OpNode graph (parentheses indicate a OpNode): + // OpNode graph: // - // (AddOp,AddOp,ReluOp,) + // AddOp --> AddOp --> ReluOp // compressed = model.compress(); UNITTEST_TRUE(compressed.verify()); - UNITTEST_EQ(compressed.nodes().size(), 1); - - node = compressed.nodes().front(); - - UNITTEST_EQ(node->ops[0]->result_tensors()[0], t2.ref()); - UNITTEST_EQ(node->ops[0]->read_tensors()[0], t0.ref()); - UNITTEST_EQ(node->ops[0]->read_tensors()[1], t1.ref()); - UNITTEST_EQ(node->ops[1]->result_tensors()[0], t3.ref()); - UNITTEST_EQ(node->ops[1]->read_tensors()[0], t2.ref()); - UNITTEST_EQ(node->ops[1]->read_tensors()[1], t1.ref()); - UNITTEST_EQ(node->ops[2]->result_tensors()[0], t4.ref()); - UNITTEST_EQ(node->ops[2]->read_tensors()[0], t3.ref()); - UNITTEST_EQ(node->consumers.size(), 0); - UNITTEST_EQ(node->producers.size(), 0); - - // Test a chain of Ops that use the output from the same previous Op. + + nodes = compressed.nodes(); + UNITTEST_EQ(nodes.size(), 3); + + UNITTEST_EQ(nodes[0]->op->result_tensors()[0], t2.ref()); + UNITTEST_EQ(nodes[0]->op->read_tensors()[0], t0.ref()); + UNITTEST_EQ(nodes[0]->op->read_tensors()[1], t1.ref()); + UNITTEST_EQ(nodes[1]->op->result_tensors()[0], t3.ref()); + UNITTEST_EQ(nodes[1]->op->read_tensors()[0], t2.ref()); + UNITTEST_EQ(nodes[1]->op->read_tensors()[1], t1.ref()); + UNITTEST_EQ(nodes[2]->op->result_tensors()[0], t4.ref()); + UNITTEST_EQ(nodes[2]->op->read_tensors()[0], t3.ref()); + + UNITTEST_EQ(nodes[0]->consumers.size(), 1); + UNITTEST_EQ(nodes[0]->producers.size(), 0); + UNITTEST_EQ(nodes[1]->consumers.size(), 1); + UNITTEST_EQ(nodes[1]->producers.size(), 1); + UNITTEST_EQ(nodes[2]->consumers.size(), 0); + UNITTEST_EQ(nodes[2]->producers.size(), 1); + + // Test a chain of Ops that use the result_tensor from the same previous Op. // Model graph (omit leftmost part): // // ... +---- (this is t2) -------------------------+--> AddOp --> t5 @@ -132,32 +140,29 @@ ark::unittest::State test_model_basics() { // ... | TensorOp --> tw --+ // ... --+ // - // (tz and tw are output references, hidden from the code) + // (tz and tw are write_tensors, hidden from the code) // ark::Tensor t5 = model.add(t2, t4); UNITTEST_TRUE(model.verify()); - // OpNode graph (parentheses indicate a OpNode): + // OpNode graph: // - // (AddOp,AddOp,ReluOp,AddOp,) + // AddOp --> AddOp --> ReluOp --> AddOp // compressed = model.compress(); UNITTEST_TRUE(compressed.verify()); - auto nodes = compressed.nodes(); - UNITTEST_EQ(nodes.size(), 1); + nodes = compressed.nodes(); + UNITTEST_EQ(nodes.size(), 4); - auto nodes_iter = nodes.begin(); - node = *(nodes_iter++); - // UNITTEST_EQ(node->get_name(), "add;add_1;relu;add_2;"); - UNITTEST_EQ(node->ops[0]->result_tensors()[0], t2.ref()); - UNITTEST_EQ(node->ops[1]->result_tensors()[0], t3.ref()); - UNITTEST_EQ(node->ops[2]->result_tensors()[0], t4.ref()); - UNITTEST_EQ(node->ops[3]->result_tensors()[0], t5.ref()); + UNITTEST_EQ(nodes[0]->op->result_tensors()[0], t2.ref()); + UNITTEST_EQ(nodes[1]->op->result_tensors()[0], t3.ref()); + UNITTEST_EQ(nodes[2]->op->result_tensors()[0], t4.ref()); + UNITTEST_EQ(nodes[3]->op->result_tensors()[0], t5.ref()); - // Test an Op that uses outputs from multiple previous Ops. + // Test an Op that uses result_tensors from multiple previous Ops. // Model graph (omit leftmost part): // // ... ----- (this is t2) --+--> AddOp --> t5 @@ -174,7 +179,7 @@ ark::unittest::State test_model_basics() { // | // TensorOp --> tu --+ // - // (tw and tu are output references, hidden from the code) + // (tw and tu are write_tensors, hidden from the code) // ark::Tensor t6 = model.tensor({1}, ark::FP32); @@ -183,34 +188,27 @@ ark::unittest::State test_model_basics() { ark::Tensor t9 = model.add(t5, t8); UNITTEST_TRUE(model.verify()); - // OpNode graph (parentheses indicate a OpNode): + // OpNode graph: // - // (AddOp,AddOp,ReluOp,AddOp,) --+ - // | - // (AddOp,) --+--> (AddOp,) + // AddOp --> AddOp --> ReluOp --> AddOp --+ + // | + // AddOp --+--> AddOp // compressed = model.compress(); UNITTEST_TRUE(compressed.verify()); nodes = compressed.nodes(); - UNITTEST_EQ(nodes.size(), 3); + UNITTEST_EQ(nodes.size(), 6); - nodes_iter = nodes.begin(); - node = *(nodes_iter++); - // UNITTEST_EQ(node->get_name(), "add;add_1;relu;add_2;"); - UNITTEST_EQ(node->ops[0]->result_tensors()[0], t2.ref()); - UNITTEST_EQ(node->ops[1]->result_tensors()[0], t3.ref()); - UNITTEST_EQ(node->ops[2]->result_tensors()[0], t4.ref()); - UNITTEST_EQ(node->ops[3]->result_tensors()[0], t5.ref()); - node = *(nodes_iter++); - // UNITTEST_EQ(node->get_name(), "add_3;"); - UNITTEST_EQ(node->ops[0]->result_tensors()[0], t8.ref()); - node = *(nodes_iter++); - // UNITTEST_EQ(node->get_name(), "add_4;"); - UNITTEST_EQ(node->ops[0]->result_tensors()[0], t9.ref()); - - // Test an Op that uses a single input tensor for multiple inputs. + UNITTEST_EQ(nodes[0]->op->result_tensors()[0], t2.ref()); + UNITTEST_EQ(nodes[1]->op->result_tensors()[0], t3.ref()); + UNITTEST_EQ(nodes[2]->op->result_tensors()[0], t4.ref()); + UNITTEST_EQ(nodes[3]->op->result_tensors()[0], t5.ref()); + UNITTEST_EQ(nodes[4]->op->result_tensors()[0], t8.ref()); + UNITTEST_EQ(nodes[5]->op->result_tensors()[0], t9.ref()); + + // Test an Op that uses a single tensor for multiple inputs. // Model graph (omit leftmost part): // // ... ----- (this is t2) --+--> AddOp --> t5 @@ -234,46 +232,37 @@ ark::unittest::State test_model_basics() { // | // TensorOp --> tv -----------+ // - // (tw, tu, and tv are output references, hidden from the code) + // (tw, tu, and tv are write_tensors, hidden from the code) // ark::Tensor t10 = model.tensor({1}, ark::FP32); ark::Tensor t11 = model.add(t10, t10); UNITTEST_TRUE(model.verify()); - // OpNode graph (parentheses indicate a OpNode): + // OpNode graph: // - // (AddOp,AddOp,ReluOp,AddOp,) --+ - // | - // (AddOp,) --+--> (AddOp,) + // AddOp --> AddOp --> ReluOp --> AddOp --+ + // | + // AddOp --+--> AddOp // - // (AddOp,) + // AddOp // compressed = model.compress(); UNITTEST_TRUE(compressed.verify()); nodes = compressed.nodes(); - UNITTEST_EQ(nodes.size(), 4); + UNITTEST_EQ(nodes.size(), 7); - nodes_iter = nodes.begin(); - node = *(nodes_iter++); - // UNITTEST_EQ(node->get_name(), "add;add_1;relu;add_2;"); - UNITTEST_EQ(node->ops[0]->result_tensors()[0], t2.ref()); - UNITTEST_EQ(node->ops[1]->result_tensors()[0], t3.ref()); - UNITTEST_EQ(node->ops[2]->result_tensors()[0], t4.ref()); - UNITTEST_EQ(node->ops[3]->result_tensors()[0], t5.ref()); - node = *(nodes_iter++); - // UNITTEST_EQ(node->get_name(), "add_3;"); - UNITTEST_EQ(node->ops[0]->result_tensors()[0], t8.ref()); - node = *(nodes_iter++); - // UNITTEST_EQ(node->get_name(), "add_4;"); - UNITTEST_EQ(node->ops[0]->result_tensors()[0], t9.ref()); - node = *(nodes_iter++); - // UNITTEST_EQ(node->get_name(), "add_5;"); - UNITTEST_EQ(node->ops[0]->result_tensors()[0], t11.ref()); - - // Test using previous Ops' outputs from multiple different Ops. + UNITTEST_EQ(nodes[0]->op->result_tensors()[0], t2.ref()); + UNITTEST_EQ(nodes[1]->op->result_tensors()[0], t3.ref()); + UNITTEST_EQ(nodes[2]->op->result_tensors()[0], t4.ref()); + UNITTEST_EQ(nodes[3]->op->result_tensors()[0], t5.ref()); + UNITTEST_EQ(nodes[4]->op->result_tensors()[0], t8.ref()); + UNITTEST_EQ(nodes[5]->op->result_tensors()[0], t9.ref()); + UNITTEST_EQ(nodes[6]->op->result_tensors()[0], t11.ref()); + + // Test using previous Ops' result_tensors from multiple different Ops. // Model graph (omit leftmost part): // // ... ----- (this is t2) --+--> AddOp --> t5 @@ -297,46 +286,35 @@ ark::unittest::State test_model_basics() { // | // TensorOp --> tv -----------+ // - // (tw, tu, and tv are output references, hidden from the code) + // (tw, tu, and tv are write_tensors, hidden from the code) // ark::Tensor t12 = model.add(t5, t8); UNITTEST_TRUE(model.verify()); - // OpNode graph (parentheses indicate a OpNode): + // OpNode graph: // - // (AddOp,AddOp,ReluOp,AddOp,) --+--> (AddOp,) - // | - // (AddOp,) --+--> (AddOp,) + // AddOp --> AddOp --> ReluOp --> AddOp --+--> AddOp + // | + // AddOp --+--> AddOp // - // (AddOp,) + // AddOp // compressed = model.compress(); UNITTEST_TRUE(compressed.verify()); nodes = compressed.nodes(); - UNITTEST_EQ(nodes.size(), 5); - - nodes_iter = nodes.begin(); - node = *(nodes_iter++); - // UNITTEST_EQ(node->get_name(), "add;add_1;relu;add_2;"); - UNITTEST_EQ(node->ops[0]->result_tensors()[0], t2.ref()); - UNITTEST_EQ(node->ops[1]->result_tensors()[0], t3.ref()); - UNITTEST_EQ(node->ops[2]->result_tensors()[0], t4.ref()); - UNITTEST_EQ(node->ops[3]->result_tensors()[0], t5.ref()); - node = *(nodes_iter++); - // UNITTEST_EQ(node->get_name(), "add_3;"); - UNITTEST_EQ(node->ops[0]->result_tensors()[0], t8.ref()); - node = *(nodes_iter++); - // UNITTEST_EQ(node->get_name(), "add_4;"); - UNITTEST_EQ(node->ops[0]->result_tensors()[0], t9.ref()); - node = *(nodes_iter++); - // UNITTEST_EQ(node->get_name(), "add_5;"); - UNITTEST_EQ(node->ops[0]->result_tensors()[0], t11.ref()); - node = *(nodes_iter++); - // UNITTEST_EQ(node->get_name(), "add_6;"); - UNITTEST_EQ(node->ops[0]->result_tensors()[0], t12.ref()); + UNITTEST_EQ(nodes.size(), 8); + + UNITTEST_EQ(nodes[0]->op->result_tensors()[0], t2.ref()); + UNITTEST_EQ(nodes[1]->op->result_tensors()[0], t3.ref()); + UNITTEST_EQ(nodes[2]->op->result_tensors()[0], t4.ref()); + UNITTEST_EQ(nodes[3]->op->result_tensors()[0], t5.ref()); + UNITTEST_EQ(nodes[4]->op->result_tensors()[0], t8.ref()); + UNITTEST_EQ(nodes[5]->op->result_tensors()[0], t9.ref()); + UNITTEST_EQ(nodes[6]->op->result_tensors()[0], t11.ref()); + UNITTEST_EQ(nodes[7]->op->result_tensors()[0], t12.ref()); return ark::unittest::SUCCESS; } @@ -353,166 +331,66 @@ ark::unittest::State test_model_dependent_inputs() { ark::Tensor x4 = m.mul(x2, x3); ark::Tensor y = m.add(x0, x4); - auto compressed = m.compress(); - auto nodes = compressed.nodes(); - UNITTEST_EQ(nodes.size(), 4); - auto nodes_iter = nodes.begin(); - auto node = (nodes_iter++)->get(); - UNITTEST_EQ(node->ops.size(), 4); - UNITTEST_EQ(node->ops[1]->result_tensors()[0], x0.ref()); - UNITTEST_EQ(node->ops[3]->result_tensors()[0], x1.ref()); - UNITTEST_EQ(node->consumers.size(), 3); - UNITTEST_EQ(node->producers.size(), 0); - node = (nodes_iter++)->get(); - UNITTEST_EQ(node->ops.size(), 1); - UNITTEST_EQ(node->ops[0]->result_tensors()[0], x2.ref()); - UNITTEST_EQ(node->ops[0]->read_tensors()[0], ones.ref()); - UNITTEST_EQ(node->ops[0]->read_tensors()[1], x1.ref()); - UNITTEST_EQ(node->consumers.size(), 1); - UNITTEST_EQ(node->producers.size(), 1); - node = (nodes_iter++)->get(); - UNITTEST_EQ(node->ops.size(), 1); - UNITTEST_EQ(node->ops[0]->result_tensors()[0], x3.ref()); - UNITTEST_EQ(node->ops[0]->read_tensors()[0], ones.ref()); - UNITTEST_EQ(node->ops[0]->read_tensors()[1], x1.ref()); - UNITTEST_EQ(node->consumers.size(), 1); - UNITTEST_EQ(node->producers.size(), 1); - node = (nodes_iter++)->get(); - UNITTEST_EQ(node->ops.size(), 2); - UNITTEST_EQ(node->ops[0]->result_tensors()[0], x4.ref()); - UNITTEST_EQ(node->ops[0]->read_tensors()[0], x2.ref()); - UNITTEST_EQ(node->ops[0]->read_tensors()[1], x3.ref()); - UNITTEST_EQ(node->ops[1]->result_tensors()[0], y.ref()); - UNITTEST_EQ(node->ops[1]->read_tensors()[0], x0.ref()); - UNITTEST_EQ(node->ops[1]->read_tensors()[1], x4.ref()); - UNITTEST_EQ(node->consumers.size(), 0); - UNITTEST_EQ(node->producers.size(), 3); - - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_model_noop() { - ark::Model model; - model.tensor({1}, ark::FP32); - model.tensor({1}, ark::FP32); - model.tensor({1}, ark::FP32); - - UNITTEST_TRUE(model.verify()); - - auto compressed = model.compress(); - UNITTEST_TRUE(compressed.verify()); - UNITTEST_EQ(compressed.nodes().size(), 0); - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_model_identity() { - // OpNode graph (parentheses indicate a OpNode): + // OpNode graph: // - // (Relu,) --+ - // | - // (Relu,) --+--> (Relu,) + // x0 x1 x2 x4 + // MulOp -> MulOp -+-> MulOp -> MulOp -+-> MulOp -+-> MulOp -+-> AddOp + // | | | | + // | +-> MulOp -+ x3 | + // +-----------------------------------------+ // - ark::Model model; - ark::Tensor t0 = model.tensor({1}, ark::FP32); - ark::Tensor t1 = model.tensor({1}, ark::FP32); - ark::Tensor t2 = model.tensor({1}, ark::FP32); + auto compressed = m.compress(); + auto nodes = compressed.nodes(); + UNITTEST_EQ(nodes.size(), 8); - ark::Tensor r0 = model.relu(t0); - ark::Tensor r1 = model.relu(t1); - ark::Tensor t3 = model.identity(t2, {r0, r1}); + UNITTEST_EQ(nodes[1]->op->result_tensors()[0], x0.ref()); + UNITTEST_EQ(nodes[1]->consumers.size(), 2); + UNITTEST_EQ(nodes[1]->producers.size(), 1); - ark::Tensor t4 = model.relu(t3); - UNITTEST_TRUE(model.verify()); + UNITTEST_EQ(nodes[3]->op->result_tensors()[0], x1.ref()); + UNITTEST_EQ(nodes[3]->consumers.size(), 2); + UNITTEST_EQ(nodes[3]->producers.size(), 1); - auto compressed = model.compress(); - UNITTEST_TRUE(compressed.verify()); - auto nodes = compressed.nodes(); - UNITTEST_EQ(nodes.size(), 3); - auto nodes_iter = nodes.begin(); + UNITTEST_EQ(nodes[4]->op->result_tensors()[0], x2.ref()); + UNITTEST_EQ(nodes[4]->consumers.size(), 1); + UNITTEST_EQ(nodes[4]->producers.size(), 1); - auto node = (nodes_iter++)->get(); - UNITTEST_EQ(node->ops[0]->result_tensors()[0], r0.ref()); - UNITTEST_EQ(node->producers.size(), 0UL); - UNITTEST_EQ(node->consumers.size(), 1UL); + UNITTEST_EQ(nodes[5]->op->result_tensors()[0], x3.ref()); + UNITTEST_EQ(nodes[5]->consumers.size(), 1); + UNITTEST_EQ(nodes[5]->producers.size(), 1); - node = (nodes_iter++)->get(); - UNITTEST_EQ(node->ops[0]->result_tensors()[0], r1.ref()); - UNITTEST_EQ(node->producers.size(), 0UL); - UNITTEST_EQ(node->consumers.size(), 1UL); + UNITTEST_EQ(nodes[6]->op->result_tensors()[0], x4.ref()); + UNITTEST_EQ(nodes[6]->consumers.size(), 1); + UNITTEST_EQ(nodes[6]->producers.size(), 2); - node = (nodes_iter++)->get(); - UNITTEST_EQ(node->ops[0]->result_tensors()[0], t4.ref()); - UNITTEST_EQ(node->producers.size(), 2UL); - UNITTEST_EQ(node->consumers.size(), 0UL); + UNITTEST_EQ(nodes[7]->op->result_tensors()[0], y.ref()); + UNITTEST_EQ(nodes[7]->consumers.size(), 0); + UNITTEST_EQ(nodes[7]->producers.size(), 2); return ark::unittest::SUCCESS; } -ark::unittest::State test_model_sharding() { - // OpNode graph (parentheses indicate a OpNode): - // - // (Relu,) --+ - // | - // (Relu,) --+ - // | - // (Relu,) --+--> (Relu,) - // - +ark::unittest::State test_model_noop() { ark::Model model; - ark::Tensor t0 = model.tensor({3}, ark::FP32); - - std::vector vec = model.sharding(t0, 0, 1); - UNITTEST_EQ(vec.size(), 3UL); - - ark::Tensor t1 = vec[0]; - ark::Tensor t2 = vec[1]; - ark::Tensor t3 = vec[2]; - - ark::Tensor r0 = model.relu(t1); - ark::Tensor r1 = model.relu(t2); - ark::Tensor r2 = model.relu(t3); - - ark::Tensor t4 = model.identity(t0, {r0, r1, r2}); + model.tensor({1}, ark::FP32); + model.tensor({1}, ark::FP32); + model.tensor({1}, ark::FP32); - ark::Tensor t5 = model.relu(t4); UNITTEST_TRUE(model.verify()); auto compressed = model.compress(); UNITTEST_TRUE(compressed.verify()); - auto nodes = compressed.nodes(); - UNITTEST_EQ(nodes.size(), 4); - auto nodes_iter = nodes.begin(); - - auto node = (nodes_iter++)->get(); - UNITTEST_EQ(node->ops[0]->result_tensors()[0], r0.ref()); - UNITTEST_EQ(node->producers.size(), 0UL); - UNITTEST_EQ(node->consumers.size(), 1UL); - - node = (nodes_iter++)->get(); - UNITTEST_EQ(node->ops[0]->result_tensors()[0], r1.ref()); - UNITTEST_EQ(node->producers.size(), 0UL); - UNITTEST_EQ(node->consumers.size(), 1UL); - - node = (nodes_iter++)->get(); - UNITTEST_EQ(node->ops[0]->result_tensors()[0], r2.ref()); - UNITTEST_EQ(node->producers.size(), 0UL); - UNITTEST_EQ(node->consumers.size(), 1UL); - - node = (nodes_iter++)->get(); - UNITTEST_EQ(node->ops[0]->result_tensors()[0], t5.ref()); - UNITTEST_EQ(node->producers.size(), 3UL); - UNITTEST_EQ(node->consumers.size(), 0UL); - + UNITTEST_EQ(compressed.nodes().size(), 0); return ark::unittest::SUCCESS; } ark::unittest::State test_model_cumulate() { - // OpNode graph (parentheses indicate a OpNode): + // OpNode graph: // - // (Relu,) --+ (Relu,) --+ - // | | - // (Relu,Add,) --+--> (Add,) --+--> (Add,) + // Relu --+ Relu --+ + // | | + // Relu --> Add --+--> Add --+--> Add // ark::Model model; @@ -528,10 +406,10 @@ ark::unittest::State test_model_cumulate() { auto compressed = model.compress(); auto nodes = compressed.nodes(); - UNITTEST_EQ(nodes.size(), 5); + UNITTEST_EQ(nodes.size(), 6); auto last_node = nodes.back().get(); - UNITTEST_EQ(last_node->ops[0]->result_tensors()[0], cumulate.ref()); + UNITTEST_EQ(last_node->op->result_tensors()[0], cumulate.ref()); UNITTEST_EQ(last_node->producers.size(), 2); UNITTEST_EQ(last_node->consumers.size(), 0); @@ -542,8 +420,6 @@ int main() { UNITTEST(test_model_basics); UNITTEST(test_model_dependent_inputs); UNITTEST(test_model_noop); - UNITTEST(test_model_identity); - UNITTEST(test_model_sharding); UNITTEST(test_model_cumulate); return 0; } diff --git a/ark/api/planner.cpp b/ark/api/planner.cpp index 5c9d09f2e..e7e9ba96b 100644 --- a/ark/api/planner.cpp +++ b/ark/api/planner.cpp @@ -62,58 +62,57 @@ std::string DefaultPlanner::Impl::plan(bool pretty) const { size_t max_num_processors = 1; size_t next_node_id = 0; for (const auto &node : model_.nodes()) { - for (const auto &op : node->ops) { - if (op->is_virtual()) continue; - - Json task_info; - task_info["Id"] = next_node_id++; - - Json config; - if (!config_rules_.empty()) { - const std::string op_str = op->serialize().dump(); - for (auto &rule : config_rules_) { - auto config_str = rule(op_str, gpu_info.arch->name()); - if (!config_str.empty()) { - config = Json::parse(config_str); - break; - } + const auto &op = node->op; + if (op->is_virtual()) continue; + + Json task_info; + task_info["Id"] = next_node_id++; + + Json config; + if (!config_rules_.empty()) { + const std::string op_str = op->serialize().dump(); + for (auto &rule : config_rules_) { + auto config_str = rule(op_str, gpu_info.arch->name()); + if (!config_str.empty()) { + config = Json::parse(config_str); + break; } } - if (config.empty()) { - config = op->default_config(gpu_info.arch); - } - check_config_field(op, config, "NumWarps"); - check_config_field(op, config, "NumTasks"); - check_config_field(op, config, "SramBytes"); - size_t num_warps = config["NumWarps"]; - size_t num_tasks = config["NumTasks"]; - size_t sram_bytes = config["SramBytes"]; - task_info["NumWarps"] = num_warps; - task_info["SramBytes"] = sram_bytes; - - max_num_warps = std::max(max_num_warps, num_warps); - - task_info["Ops"] = Json::array(); - task_info["Ops"].push_back(op->serialize()); - task_info["Ops"][0]["Config"] = config; - task_infos.push_back(task_info); - - Json resource_group; - size_t num_processors = std::min(num_sm, num_tasks); - max_num_processors = std::max(max_num_processors, num_processors); - resource_group["ProcessorRange"] = {0, num_processors}; - resource_group["WarpRange"] = {0, num_warps}; - resource_group["SramRange"] = {0, sram_bytes}; - resource_group["TaskGroups"] = {{{"TaskId", task_info["Id"]}, - {"TaskRange", {0, num_tasks}}, - {"Granularity", 1}}}; - - Json processor_group; - processor_group["ProcessorRange"] = {0, num_processors}; - processor_group["ResourceGroups"] = Json::array(); - processor_group["ResourceGroups"].push_back(resource_group); - processor_groups.push_back(processor_group); } + if (config.empty()) { + config = op->default_config(gpu_info.arch); + } + check_config_field(op, config, "NumWarps"); + check_config_field(op, config, "NumTasks"); + check_config_field(op, config, "SramBytes"); + size_t num_warps = config["NumWarps"]; + size_t num_tasks = config["NumTasks"]; + size_t sram_bytes = config["SramBytes"]; + task_info["NumWarps"] = num_warps; + task_info["SramBytes"] = sram_bytes; + + max_num_warps = std::max(max_num_warps, num_warps); + + task_info["Ops"] = Json::array(); + task_info["Ops"].push_back(op->serialize()); + task_info["Ops"][0]["Config"] = config; + task_infos.push_back(task_info); + + Json resource_group; + size_t num_processors = std::min(num_sm, num_tasks); + max_num_processors = std::max(max_num_processors, num_processors); + resource_group["ProcessorRange"] = {0, num_processors}; + resource_group["WarpRange"] = {0, num_warps}; + resource_group["SramRange"] = {0, sram_bytes}; + resource_group["TaskGroups"] = {{{"TaskId", task_info["Id"]}, + {"TaskRange", {0, num_tasks}}, + {"Granularity", 1}}}; + + Json processor_group; + processor_group["ProcessorRange"] = {0, num_processors}; + processor_group["ResourceGroups"] = Json::array(); + processor_group["ResourceGroups"].push_back(resource_group); + processor_groups.push_back(processor_group); } Json plan; diff --git a/ark/model/model_graph_impl.cpp b/ark/model/model_graph_impl.cpp index 17410d23f..356867cf2 100644 --- a/ark/model/model_graph_impl.cpp +++ b/ark/model/model_graph_impl.cpp @@ -24,7 +24,7 @@ ModelGraph::Impl &ModelGraph::Impl::operator=(const ModelGraph::Impl &other) { nodes_.clear(); for (const auto &node : other.nodes_) { ModelNodeRef new_node = std::make_shared(); - new_node->ops = node->ops; + new_node->op = node->op; node_map.emplace(node, new_node); nodes_.push_back(new_node); } @@ -67,44 +67,41 @@ ModelGraph::Impl &ModelGraph::Impl::operator=(const ModelGraph::Impl &other) { void ModelGraph::Impl::compress_nodes() { if (!compressed_) { this->recursive_remove_virtual_nodes(); - this->recursive_merge_nodes(); compressed_ = true; } } bool ModelGraph::Impl::verify() const { for (auto &node : nodes_) { - if (node->ops.size() == 0) { - LOG(DEBUG, "node has no ops"); + if (node->op == nullptr) { + LOG(DEBUG, "node has no op"); return false; } - for (auto &op : node->ops) { - if (op_to_node_.find(op) == op_to_node_.end()) { - LOG(DEBUG, "op has not been added to the graph"); + if (op_to_node_.find(node->op) == op_to_node_.end()) { + LOG(DEBUG, "op has not been added to the graph"); + return false; + } + if (op_to_node_.at(node->op) != node) { + LOG(DEBUG, "op is not in the correct node"); + return false; + } + node->op->verify(); + for (auto &tns : node->op->result_tensors()) { + if (tensor_to_producer_op_.find(tns) == + tensor_to_producer_op_.end()) { + LOG(DEBUG, "result tensor has not been produced by any op"); return false; } - if (op_to_node_.at(op) != node) { - LOG(DEBUG, "op is not in the correct node"); + if (tensor_to_producer_op_.at(tns) != node->op) { + LOG(DEBUG, "result tensor has been produced by another op"); return false; } - op->verify(); - for (auto &tns : op->result_tensors()) { - if (tensor_to_producer_op_.find(tns) == - tensor_to_producer_op_.end()) { - LOG(DEBUG, "result tensor has not been produced by any op"); - return false; - } - if (tensor_to_producer_op_.at(tns) != op) { - LOG(DEBUG, "result tensor has been produced by another op"); - return false; - } - } - for (auto &tns : op->input_tensors()) { - if (tensor_to_producer_op_.find(tns) == - tensor_to_producer_op_.end()) { - LOG(DEBUG, "input tensor has not been produced by any op"); - return false; - } + } + for (auto &tns : node->op->input_tensors()) { + if (tensor_to_producer_op_.find(tns) == + tensor_to_producer_op_.end()) { + LOG(DEBUG, "input tensor has not been produced by any op"); + return false; } } for (auto &producer : node->producers) { @@ -153,7 +150,7 @@ ModelNodeRef ModelGraph::Impl::add_op(ModelOpRef op) { } ModelNodeRef node = std::make_shared(); - node->ops.push_back(op); + node->op = op; op_to_node_[op] = node; for (auto &tns : op->input_tensors()) { @@ -194,42 +191,14 @@ void ModelGraph::Impl::remove_node(ModelNodeRef node) { producer->consumers.push_back(consumer); } } - for (auto &op : node->ops) { - auto it = op_to_node_.find(op); - if (it == op_to_node_.end()) { - ERR(ModelError, "unexpected error"); - } - if (it->second == node) { - op_to_node_.erase(it); - } + auto it2 = op_to_node_.find(node->op); + if (it2 == op_to_node_.end()) { + ERR(ModelError, "unexpected error"); } - nodes_.erase(it); -} - -bool ModelGraph::Impl::depends_on(ModelNodeRef node1, - ModelNodeRef node2) const { - if (node1 == node2) { - return false; - } - std::set seen_nodes; - std::vector boundary_nodes; - boundary_nodes.emplace_back(node1); - while (boundary_nodes.size() > 0) { - std::vector new_boundary_nodes; - for (auto &boundary_node : boundary_nodes) { - if (boundary_node == node2) { - return true; - } - for (auto &producer : boundary_node->producers) { - if (seen_nodes.find(producer) != seen_nodes.end()) { - continue; - } - new_boundary_nodes.emplace_back(producer); - } - } - boundary_nodes = new_boundary_nodes; + if (it2->second == node) { + op_to_node_.erase(it2); } - return false; + nodes_.erase(it); } void ModelGraph::Impl::recursive_remove_virtual_nodes() { @@ -252,10 +221,8 @@ void ModelGraph::Impl::recursive_remove_virtual_nodes( MODEL_GRAPH_DEBUG("remove virtual nodes"); std::vector new_boundary_nodes; for (auto &boundary_node : boundary_nodes) { - if (boundary_node->ops.size() == 0) { + if (boundary_node->op == nullptr) { ERR(ModelError, "unexpected error: empty node"); - } else if (boundary_node->ops.size() > 1) { - ERR(ModelError, "unexpected error: multiple ops in node"); } MODEL_GRAPH_DEBUG(" boundary node"); MODEL_GRAPH_DEBUG(" node: ", to_json(boundary_node).dump()); @@ -284,7 +251,7 @@ void ModelGraph::Impl::recursive_remove_virtual_nodes( to_json(producer).dump()); new_boundary_nodes.emplace_back(producer); } - if (boundary_node->ops[0]->is_virtual()) { + if (boundary_node->op->is_virtual()) { MODEL_GRAPH_DEBUG(" remove node: ", to_json(boundary_node).dump()); // Remove this node from the graph. @@ -297,152 +264,6 @@ void ModelGraph::Impl::recursive_remove_virtual_nodes( this->recursive_remove_virtual_nodes(seen_nodes, new_boundary_nodes); } -void ModelGraph::Impl::recursive_merge_nodes() { - std::vector leaf_nodes; - for (auto &node : nodes_) { - if (node->consumers.empty()) { - leaf_nodes.emplace_back(node); - } - } - UniqueList seen_nodes; - this->recursive_merge_nodes(seen_nodes, leaf_nodes); -} - -void ModelGraph::Impl::recursive_merge_nodes( - UniqueList &seen_nodes, - const std::vector &boundary_nodes) { - if (boundary_nodes.size() == 0) { - return; - } - MODEL_GRAPH_DEBUG("merge ops"); - std::vector new_boundary_nodes; - for (auto &boundary_node : boundary_nodes) { - MODEL_GRAPH_DEBUG(" boundary node"); - MODEL_GRAPH_DEBUG(" node: ", to_json(boundary_node).dump()); - if (boundary_node->producers.size() == 0) { - // This node is a root. - seen_nodes.push_back(boundary_node); - MODEL_GRAPH_DEBUG(" root"); - continue; - } - // Add all producers of this node to the next boundary. - for (auto &producer : boundary_node->producers) { - // Exception: if any consumer of the producer (rather than the - // current boundary_node) is unseen, we should not add the producer - // to the next boundary. - bool should_add = true; - for (auto &consumer : producer->consumers) { - if (consumer == boundary_node) { - continue; - } - if (!seen_nodes.contains(consumer)) { - should_add = false; - break; - } - } - if (!should_add) { - continue; - } - if (seen_nodes.contains(producer)) { - ERR(ModelError, - "unexpected error: circular dependency detected"); - } - new_boundary_nodes.emplace_back(producer); - } - ModelNodeRef merge_candidate; - if (boundary_node->producers.size() > 1) { - // This node has multiple producers. We can merge only if one - // producer depends on all other producers. - for (auto &producer : boundary_node->producers) { - bool depends_on_all = true; - for (auto &other_producer : boundary_node->producers) { - if (other_producer == producer) { - continue; - } - if (!this->depends_on(producer, other_producer)) { - depends_on_all = false; - break; - } - } - if (depends_on_all) { - merge_candidate = producer; - break; - } - } - if (!merge_candidate) { - // At least one producer does not depend on others. - // Cannot merge. - seen_nodes.push_back(boundary_node); - MODEL_GRAPH_DEBUG(" multiple producers"); - continue; - } - } else { - // This node has only one producer. - merge_candidate = *(boundary_node->producers.begin()); - } - if (merge_candidate->consumers.size() == 0) { - ERR(ModelError, "unexpected error: graph is incomplete"); - } - if (merge_candidate->consumers.size() > 1) { - // The candidate has multiple consumers. We can merge only if all - // other consumers depend on the current boundary_node. - bool depends_on_one = true; - for (auto &consumer : merge_candidate->consumers) { - if (consumer == boundary_node) { - continue; - } - if (!this->depends_on(consumer, boundary_node)) { - depends_on_one = false; - break; - } - } - if (!depends_on_one) { - // At least one consumer does not depend on the boundary_node. - // Cannot merge. - seen_nodes.push_back(boundary_node); - MODEL_GRAPH_DEBUG(" multiple consumers"); - continue; - } - } - // We can merge the two nodes. - // Merge `boundary_node` into `merge_candidate`. - MODEL_GRAPH_DEBUG(" merge: ", to_json(merge_candidate).dump(), " -> ", - to_json(boundary_node).dump()); - auto &ops = boundary_node->ops; - merge_candidate->ops.insert(merge_candidate->ops.end(), ops.begin(), - ops.end()); - for (auto &op : ops) { - op_to_node_[op] = merge_candidate; - } - for (auto &consumer : boundary_node->consumers) { - consumer->producers.erase(boundary_node); - consumer->producers.push_back(merge_candidate); - merge_candidate->consumers.push_back(consumer); - } - for (auto &producer : boundary_node->producers) { - if (producer == merge_candidate) { - continue; - } - producer->consumers.erase(boundary_node); - producer->consumers.push_back(merge_candidate); - merge_candidate->producers.push_back(producer); - } - merge_candidate->consumers.erase(boundary_node); - - // Remove `boundary_node` from `nodes_`. - auto it = nodes_.find(boundary_node); - if (it == nodes_.end()) { - ERR(ModelError, "unexpected error"); - } - nodes_.erase(it); - - // Since producer is already in the next boundary and boundary_node is - // merged into producer, we don't need to add anything to - // seen_nodes here. - } - this->recursive_merge_nodes(seen_nodes, new_boundary_nodes); -} - Json ModelGraph::Impl::to_json(const ModelNodeRef &node) const { Json j; j["Id"] = nodes_.index(node); @@ -454,10 +275,7 @@ Json ModelGraph::Impl::to_json(const ModelNodeRef &node) const { for (auto consumer : node->consumers) { j["ConsumerNodeIds"].emplace_back(nodes_.index(consumer)); } - j["Ops"] = Json::array(); - for (auto op : node->ops) { - j["Ops"].emplace_back(op->serialize()); - } + j["Op"] = node->op->serialize(); return j; } diff --git a/ark/model/model_graph_impl.hpp b/ark/model/model_graph_impl.hpp index 6c109b51e..ae4718310 100644 --- a/ark/model/model_graph_impl.hpp +++ b/ark/model/model_graph_impl.hpp @@ -65,19 +65,12 @@ class ModelGraph::Impl { void remove_node(ModelNodeRef node); - bool depends_on(ModelNodeRef node1, ModelNodeRef node2) const; - void recursive_remove_virtual_nodes(); void recursive_remove_virtual_nodes( UniqueList &seen_nodes, const std::vector &boundary_nodes); - void recursive_merge_nodes(); - - void recursive_merge_nodes(UniqueList &seen_nodes, - const std::vector &boundary_nodes); - Json to_json(const ModelNodeRef &node) const; /// The list of @ref ModelNode in the graph. diff --git a/ark/model/model_json.cpp b/ark/model/model_json.cpp index 0057ef0aa..e7bb8df18 100644 --- a/ark/model/model_json.cpp +++ b/ark/model/model_json.cpp @@ -78,13 +78,11 @@ static void verfiy_format_op(const Json &json, bool need_config) { static void verify_format_node(const Json &json) { const std::vector required_fields = {"Id", "ProducerNodeIds", - "ConsumerNodeIds", "Ops"}; + "ConsumerNodeIds", "Op"}; const std::vector array_fields = {"ProducerNodeIds", - "ConsumerNodeIds", "Ops"}; + "ConsumerNodeIds"}; verify_format_json("NodeJson", json, required_fields, array_fields); - for (const auto &op : json.at("Ops")) { - verfiy_format_op(op, false); - } + verfiy_format_op(json.at("Op"), false); } static void verify_format_model(const Json &json) { @@ -210,7 +208,7 @@ static std::stringstream &dump_pretty_object(const Json &json, std::string ModelJson::dump_pretty(int indent, int indent_step) const { std::stringstream ss; - dump_pretty_object(*this, "", 5, ss, indent, indent_step) << "\n"; + dump_pretty_object(*this, "", 4, ss, indent, indent_step) << "\n"; return ss.str(); } diff --git a/ark/model/model_node.hpp b/ark/model/model_node.hpp index 7838ca120..264e891e2 100644 --- a/ark/model/model_node.hpp +++ b/ark/model/model_node.hpp @@ -17,9 +17,8 @@ class ModelNode { public: ModelNode() = default; - /// The list of @ref Op that this @ref ModelNode contains. Sorted in the - /// execution order. - std::vector ops; + /// @ref Op that this @ref ModelNode represents. + ModelOpRef op; /// The list of @ref ModelNode that depends on this @ref ModelNode. UniqueList consumers; diff --git a/ark/ops/ops_all_reduce_test.cpp b/ark/ops/ops_all_reduce_test.cpp index 9e2c6f675..c73426f24 100644 --- a/ark/ops/ops_all_reduce_test.cpp +++ b/ark/ops/ops_all_reduce_test.cpp @@ -5,63 +5,6 @@ #include "model/model_op.hpp" #include "ops_test_common.hpp" -ark::unittest::State test_all_reduce_model() { - // OpNode graph (parentheses indicate a OpNode): - // - // +--> (S,SD,R,) --+--> (S,SD,R,) --+ - // | | | - // (S,SD,R,) --+--> (Add,) +--> (Add,) +--> (Add,) - // | ^ | ^ - // | | | | - // +---------------+ +--------------+ - - ark::Model model; - ark::Tensor input = model.tensor({1}, ark::FP32); - ark::Tensor output = model.all_reduce(input, 0, 4); - - UNITTEST_TRUE(model.verify()); - - auto compressed = model.compress(); - auto nodes = compressed.nodes(); - UNITTEST_EQ(nodes.size(), 6); - - auto nodes_iter = nodes.begin(); - auto node = *(nodes_iter++); - // UNITTEST_EQ(node->get_name(), "send;send_done;recv;"); - UNITTEST_EQ(node->producers.size(), 0); - UNITTEST_EQ(node->consumers.size(), 2); - - // UNITTEST_EQ(node->consumers[0]->get_name(), "add;"); - UNITTEST_EQ(node->consumers[0]->consumers.size(), 1); - // UNITTEST_EQ((*(node->consumers[0]->consumers.begin()))->get_name(), - // "add_1;"); - - // UNITTEST_EQ(node->consumers[1]->get_name(), - // "send_1;send_done_1;recv_1;"); - UNITTEST_EQ(node->consumers[1]->producers.size(), 1); - UNITTEST_EQ(node->consumers[1]->consumers.size(), 2); - - node = node->consumers[1]; - - // UNITTEST_EQ(node->consumers[0]->get_name(), "add_1;"); - UNITTEST_EQ(node->consumers[0]->producers.size(), 2); - UNITTEST_EQ(node->consumers[0]->consumers.size(), 1); - // UNITTEST_EQ((*(node->consumers[0]->consumers.begin()))->get_name(), - // "add_2;"); - - // UNITTEST_EQ(node->consumers[1]->get_name(), - // "send_2;send_done_2;recv_2;"); - UNITTEST_EQ(node->consumers[1]->producers.size(), 1); - UNITTEST_EQ(node->consumers[1]->consumers.size(), 1); - // UNITTEST_EQ((*(node->consumers[1]->consumers.begin()))->get_name(), - // "add_2;"); - UNITTEST_EQ( - (*(node->consumers[1]->consumers.begin()))->ops[0]->result_tensors()[0], - output.ref()); - - return ark::unittest::SUCCESS; -} - template void baseline_all_reduce(std::vector &outputs, const std::vector &output_shapes, @@ -165,7 +108,6 @@ ark::unittest::State test_all_reduce_8gpus() { } int main() { - // UNITTEST(test_all_reduce_model); UNITTEST(test_all_reduce_4gpus); UNITTEST(test_all_reduce_8gpus); return 0; diff --git a/ark/ops/ops_identity_test.cpp b/ark/ops/ops_identity_test.cpp index 6e395af62..a6e49c9c0 100644 --- a/ark/ops/ops_identity_test.cpp +++ b/ark/ops/ops_identity_test.cpp @@ -6,12 +6,12 @@ #include "model/model_op.hpp" #include "ops_test_common.hpp" -ark::unittest::State test_identity_model() { - // OpNode graph (parentheses indicate a OpNode): +ark::unittest::State test_ops_identity_model() { + // OpNode graph: // - // (Relu,) --+ - // | - // (Relu,) --+--> (Relu,) + // ReluOp --+ + // | + // ReluOp --+--> ReluOp // ark::Model model; @@ -27,29 +27,26 @@ ark::unittest::State test_identity_model() { UNITTEST_TRUE(model.verify()); auto compressed = model.compress(); + UNITTEST_TRUE(compressed.verify()); auto nodes = compressed.nodes(); UNITTEST_EQ(nodes.size(), 3); - auto nodes_iter = nodes.begin(); - auto node = *(nodes_iter++); - UNITTEST_EQ(node->ops[0]->result_tensors()[0], r0.ref()); - UNITTEST_EQ(node->producers.size(), 0); - UNITTEST_EQ(node->consumers.size(), 1); + UNITTEST_EQ(nodes[0]->op->result_tensors()[0], r0.ref()); + UNITTEST_EQ(nodes[0]->producers.size(), 0); + UNITTEST_EQ(nodes[0]->consumers.size(), 1); - node = *(nodes_iter++); - UNITTEST_EQ(node->ops[0]->result_tensors()[0], r1.ref()); - UNITTEST_EQ(node->producers.size(), 0); - UNITTEST_EQ(node->consumers.size(), 1); + UNITTEST_EQ(nodes[1]->op->result_tensors()[0], r1.ref()); + UNITTEST_EQ(nodes[1]->producers.size(), 0); + UNITTEST_EQ(nodes[1]->consumers.size(), 1); - node = *(nodes_iter++); - UNITTEST_EQ(node->ops[0]->result_tensors()[0], t4.ref()); - UNITTEST_EQ(node->producers.size(), 2); - UNITTEST_EQ(node->consumers.size(), 0); + UNITTEST_EQ(nodes[2]->op->result_tensors()[0], t4.ref()); + UNITTEST_EQ(nodes[2]->producers.size(), 2); + UNITTEST_EQ(nodes[2]->consumers.size(), 0); return ark::unittest::SUCCESS; } -ark::unittest::State test_identity() { +ark::unittest::State test_ops_identity() { ark::Model model; // float buf[2][3][4][5]; ark::Tensor tns0 = model.tensor({2, 3, 4, 5}, ark::FP32); @@ -81,7 +78,7 @@ ark::unittest::State test_identity() { } int main() { - UNITTEST(test_identity_model); - UNITTEST(test_identity); + UNITTEST(test_ops_identity_model); + UNITTEST(test_ops_identity); return 0; } diff --git a/ark/ops/ops_sharding_test.cpp b/ark/ops/ops_sharding_test.cpp index 5ba270baa..bd1c3a806 100644 --- a/ark/ops/ops_sharding_test.cpp +++ b/ark/ops/ops_sharding_test.cpp @@ -7,14 +7,14 @@ #include "model/model_op.hpp" #include "unittest/unittest_utils.h" -ark::unittest::State test_model_op_sharding() { - // OpNode graph (parentheses indicate a OpNode): +ark::unittest::State test_ops_sharding_model() { + // OpNode graph: // - // (Relu,) --+ - // | - // (Relu,) --+ - // | - // (Relu,) --+--> (Relu,) + // ReluOp --+ + // | + // ReluOp --+ + // | + // ReluOp --+--> ReluOp // ark::Model model; @@ -37,34 +37,30 @@ ark::unittest::State test_model_op_sharding() { UNITTEST_TRUE(model.verify()); auto compressed = model.compress(); + UNITTEST_TRUE(compressed.verify()); auto nodes = compressed.nodes(); UNITTEST_EQ(nodes.size(), 4); - auto nodes_iter = nodes.begin(); - auto node = *(nodes_iter++); - UNITTEST_EQ(node->ops[0]->result_tensors()[0], r0.ref()); - UNITTEST_EQ(node->producers.size(), 0); - UNITTEST_EQ(node->consumers.size(), 1); + UNITTEST_EQ(nodes[0]->op->result_tensors()[0], r0.ref()); + UNITTEST_EQ(nodes[0]->producers.size(), 0); + UNITTEST_EQ(nodes[0]->consumers.size(), 1); - node = *(nodes_iter++); - UNITTEST_EQ(node->ops[0]->result_tensors()[0], r1.ref()); - UNITTEST_EQ(node->producers.size(), 0); - UNITTEST_EQ(node->consumers.size(), 1); + UNITTEST_EQ(nodes[1]->op->result_tensors()[0], r1.ref()); + UNITTEST_EQ(nodes[1]->producers.size(), 0); + UNITTEST_EQ(nodes[1]->consumers.size(), 1); - node = *(nodes_iter++); - UNITTEST_EQ(node->ops[0]->result_tensors()[0], r2.ref()); - UNITTEST_EQ(node->producers.size(), 0); - UNITTEST_EQ(node->consumers.size(), 1); + UNITTEST_EQ(nodes[2]->op->result_tensors()[0], r2.ref()); + UNITTEST_EQ(nodes[2]->producers.size(), 0); + UNITTEST_EQ(nodes[2]->consumers.size(), 1); - node = *(nodes_iter++); - UNITTEST_EQ(node->ops[0]->result_tensors()[0], t5.ref()); - UNITTEST_EQ(node->producers.size(), 3); - UNITTEST_EQ(node->consumers.size(), 0); + UNITTEST_EQ(nodes[3]->op->result_tensors()[0], t5.ref()); + UNITTEST_EQ(nodes[3]->producers.size(), 3); + UNITTEST_EQ(nodes[3]->consumers.size(), 0); return ark::unittest::SUCCESS; } int main() { - UNITTEST(test_model_op_sharding); + UNITTEST(test_ops_sharding_model); return 0; } diff --git a/docs/model_file.md b/docs/model_file.md index 078d627d7..614c83bae 100644 --- a/docs/model_file.md +++ b/docs/model_file.md @@ -11,18 +11,17 @@ See an example model file: [Example 1](../examples/tutorial/model.json). - Id (Int) - ProducerNodeIds (Array of Int) - ConsumerNodeIds (Array of Int) - - Ops (Array of Op) - - Op (Object) - - Type (String) - - Name (String) - - IsVirtual (Boolean) - - ReadTensors (Array of Tensor) - - Tensor (Object, details below) - - WriteTensors (Array of Tensor) - - Tensor (Object, details below) - - ResultTensors (Array of Tensor) - - Tensor (Object, details below) - - Args (Object, structure depends on Op Type) + - Op (Object) + - Type (String) + - Name (String) + - IsVirtual (Boolean) + - ReadTensors (Array of Tensor) + - Tensor (Object, details below) + - WriteTensors (Array of Tensor) + - Tensor (Object, details below) + - ResultTensors (Array of Tensor) + - Tensor (Object, details below) + - Args (Object, structure depends on Op Type) A `Tensor` object has the following structure: @@ -49,11 +48,9 @@ An `Args` object has a flexible structure depending on the type of `Op`, which w ## Node -A `Node` object describes a node in the computation graph. +A `Node` object describes a node in the computation graph. A node consists of an operator (`Op`) that describes computation (or communication) task of the node. -A node consists of an array of one or more operators (`Op`s). The operators in a node are supposed to be executed in the order that appears in the array, but they may not have hard dependencies in between, i.e., a later operator's computation may depend only on a part of a previous operator's result. For example, if an element-wise operator is followed by another element-wise operator with the same shape of data, each element of the later operator depends only on a single element from the earlier operator. This poses possibility of operator fusion when we design an execution plan of this model. - -Each node may produce or consume tensors. Produced tensors are those that appear in an operator's `ResultTensors` array, while consumed tensors are those that appear in an operator's `ReadTensors` or `WriteTensors` array. Each node has a unique ID, and declares `ProducerNodeIds` and `ConsumerNodeIds` to describe dependencies between nodes. `ProducerNodeIds` lists IDs of all nodes that produce tensors consumed by this node. Similarly, `ConsumerNodeIds` lists IDs of all nodes that consume tensors produced by this node. +Each node may produce or consume tensors. Produced tensors are those that appear in the operator's `ResultTensors` array, while consumed tensors are those that appear in the operator's `ReadTensors` or `WriteTensors` array. Each node has a unique ID, and declares `ProducerNodeIds` and `ConsumerNodeIds` to describe dependencies between nodes. `ProducerNodeIds` lists IDs of all nodes that produce tensors consumed by this node. Similarly, `ConsumerNodeIds` lists IDs of all nodes that consume tensors produced by this node. ## Op diff --git a/examples/tutorial/model.json b/examples/tutorial/model.json index 1bc9233a5..c2b88bbd0 100644 --- a/examples/tutorial/model.json +++ b/examples/tutorial/model.json @@ -5,127 +5,136 @@ { "Id": 0, "ProducerNodeIds": [], - "ConsumerNodeIds": [2], - "Ops": [ - { - "Type": "Matmul", - "Name": "matmul", - "IsVirtual": false, - "ReadTensors": [ - {"Id":0,"DataType":"FP16","Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096],"Buffer":{"Id":0,"Rank":-1,"SendTags":[],"RecvTags":[]}}, - {"Id":1,"DataType":"FP16","Shape":[11008,4096],"Strides":[11008,4096],"Offsets":[0,0],"PaddedShape":[11008,4096],"Buffer":{"Id":1,"Rank":-1,"SendTags":[],"RecvTags":[]}} - ], - "WriteTensors": [ - {"Id":4,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[]}} - ], - "ResultTensors": [ - {"Id":5,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[]}} - ], - "Args": { - "TransposeInput": {"BOOL":false}, - "TransposeOther": {"BOOL":true} - } - }, - { - "Type": "Sigmoid", - "Name": "sigmoid", - "IsVirtual": false, - "ReadTensors": [ - {"Id":5,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[]}} - ], - "WriteTensors": [ - {"Id":6,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":5,"Rank":-1,"SendTags":[],"RecvTags":[]}} - ], - "ResultTensors": [ - {"Id":7,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":5,"Rank":-1,"SendTags":[],"RecvTags":[]}} - ], - "Args": {} - }, - { - "Type": "Mul", - "Name": "mul", - "IsVirtual": false, - "ReadTensors": [ - {"Id":5,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[]}}, - {"Id":7,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":5,"Rank":-1,"SendTags":[],"RecvTags":[]}} - ], - "WriteTensors": [ - {"Id":8,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":6,"Rank":-1,"SendTags":[],"RecvTags":[]}} - ], - "ResultTensors": [ - {"Id":9,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":6,"Rank":-1,"SendTags":[],"RecvTags":[]}} - ], - "Args": {} + "ConsumerNodeIds": [1,2], + "Op": { + "Type": "Matmul", + "Name": "matmul", + "IsVirtual": false, + "ReadTensors": [ + {"Id":0,"DataType":"FP16","Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096],"Buffer":{"Id":0,"Rank":-1,"SendTags":[],"RecvTags":[]}}, + {"Id":1,"DataType":"FP16","Shape":[11008,4096],"Strides":[11008,4096],"Offsets":[0,0],"PaddedShape":[11008,4096],"Buffer":{"Id":1,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "WriteTensors": [ + {"Id":4,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "ResultTensors": [ + {"Id":5,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "Args": { + "TransposeInput": {"BOOL":false}, + "TransposeOther": {"BOOL":true} } - ] + } }, { "Id": 1, - "ProducerNodeIds": [], + "ProducerNodeIds": [0], "ConsumerNodeIds": [2], - "Ops": [ - { - "Type": "Matmul", - "Name": "matmul_1", - "IsVirtual": false, - "ReadTensors": [ - {"Id":0,"DataType":"FP16","Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096],"Buffer":{"Id":0,"Rank":-1,"SendTags":[],"RecvTags":[]}}, - {"Id":3,"DataType":"FP16","Shape":[11008,4096],"Strides":[11008,4096],"Offsets":[0,0],"PaddedShape":[11008,4096],"Buffer":{"Id":3,"Rank":-1,"SendTags":[],"RecvTags":[]}} - ], - "WriteTensors": [ - {"Id":10,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[]}} - ], - "ResultTensors": [ - {"Id":11,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[]}} - ], - "Args": { - "TransposeInput": {"BOOL":false}, - "TransposeOther": {"BOOL":true} - } - } - ] + "Op": { + "Type": "Sigmoid", + "Name": "sigmoid", + "IsVirtual": false, + "ReadTensors": [ + {"Id":5,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "WriteTensors": [ + {"Id":6,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":5,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "ResultTensors": [ + {"Id":7,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":5,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "Args": {} + } }, { "Id": 2, - "ProducerNodeIds": [1,0], + "ProducerNodeIds": [0,1], + "ConsumerNodeIds": [4], + "Op": { + "Type": "Mul", + "Name": "mul", + "IsVirtual": false, + "ReadTensors": [ + {"Id":5,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[]}}, + {"Id":7,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":5,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "WriteTensors": [ + {"Id":8,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":6,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "ResultTensors": [ + {"Id":9,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":6,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "Args": {} + } + }, + { + "Id": 3, + "ProducerNodeIds": [], + "ConsumerNodeIds": [4], + "Op": { + "Type": "Matmul", + "Name": "matmul_1", + "IsVirtual": false, + "ReadTensors": [ + {"Id":0,"DataType":"FP16","Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096],"Buffer":{"Id":0,"Rank":-1,"SendTags":[],"RecvTags":[]}}, + {"Id":3,"DataType":"FP16","Shape":[11008,4096],"Strides":[11008,4096],"Offsets":[0,0],"PaddedShape":[11008,4096],"Buffer":{"Id":3,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "WriteTensors": [ + {"Id":10,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "ResultTensors": [ + {"Id":11,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "Args": { + "TransposeInput": {"BOOL":false}, + "TransposeOther": {"BOOL":true} + } + } + }, + { + "Id": 4, + "ProducerNodeIds": [2,3], + "ConsumerNodeIds": [5], + "Op": { + "Type": "Mul", + "Name": "mul_1", + "IsVirtual": false, + "ReadTensors": [ + {"Id":9,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":6,"Rank":-1,"SendTags":[],"RecvTags":[]}}, + {"Id":11,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "WriteTensors": [ + {"Id":12,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "ResultTensors": [ + {"Id":13,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "Args": {} + } + }, + { + "Id": 5, + "ProducerNodeIds": [4], "ConsumerNodeIds": [], - "Ops": [ - { - "Type": "Mul", - "Name": "mul_1", - "IsVirtual": false, - "ReadTensors": [ - {"Id":9,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":6,"Rank":-1,"SendTags":[],"RecvTags":[]}}, - {"Id":11,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[]}} - ], - "WriteTensors": [ - {"Id":12,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]}} - ], - "ResultTensors": [ - {"Id":13,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]}} - ], - "Args": {} - }, - { - "Type": "Matmul", - "Name": "matmul_2", - "IsVirtual": false, - "ReadTensors": [ - {"Id":13,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]}}, - {"Id":2,"DataType":"FP16","Shape":[4096,11008],"Strides":[4096,11008],"Offsets":[0,0],"PaddedShape":[4096,11008],"Buffer":{"Id":2,"Rank":-1,"SendTags":[],"RecvTags":[]}} - ], - "WriteTensors": [ - {"Id":14,"DataType":"FP16","Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096],"Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[]}} - ], - "ResultTensors": [ - {"Id":15,"DataType":"FP16","Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096],"Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[]}} - ], - "Args": { - "TransposeInput": {"BOOL":false}, - "TransposeOther": {"BOOL":true} - } + "Op": { + "Type": "Matmul", + "Name": "matmul_2", + "IsVirtual": false, + "ReadTensors": [ + {"Id":13,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]}}, + {"Id":2,"DataType":"FP16","Shape":[4096,11008],"Strides":[4096,11008],"Offsets":[0,0],"PaddedShape":[4096,11008],"Buffer":{"Id":2,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "WriteTensors": [ + {"Id":14,"DataType":"FP16","Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096],"Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "ResultTensors": [ + {"Id":15,"DataType":"FP16","Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096],"Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "Args": { + "TransposeInput": {"BOOL":false}, + "TransposeOther": {"BOOL":true} } - ] + } } ] } diff --git a/python/unittest/test_model.py b/python/unittest/test_model.py index 3da87795c..da8ae399a 100644 --- a/python/unittest/test_model.py +++ b/python/unittest/test_model.py @@ -17,9 +17,8 @@ def test_model(): assert m_json.get("Nodes", None) is not None assert len(m_json["Nodes"]) == 1 - assert m_json["Nodes"][0].get("Ops", None) is not None - assert len(m_json["Nodes"][0]["Ops"]) == 1 - assert m_json["Nodes"][0]["Ops"][0].get("Type", None) == "Add" + assert m_json["Nodes"][0].get("Op", None) is not None + assert m_json["Nodes"][0]["Op"].get("Type", None) == "Add" ark.Model.reset()