Skip to content

Commit 5046a53

Browse files
committed
fix(frontend): 简化 lower to computation 逻辑,不要跳过任何边的映射
Signed-off-by: YdrMaster <[email protected]>
1 parent c500e1d commit 5046a53

File tree

1 file changed

+19
-31
lines changed

1 file changed

+19
-31
lines changed

src/06frontend/src/graph.cc

Lines changed: 19 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -175,42 +175,30 @@ namespace refactor::frontend {
175175

176176
std::vector<computation::Node> nodes(_internal.nodes.size());
177177
std::vector<computation::Edge> edges(_internal.edges.size());
178+
179+
auto fn = [&edges, this](auto i) {
180+
if (edges[i].tensor) {
181+
return;
182+
}
183+
auto const &[tensor, name] = _internal.edges[i];
184+
computation::Shape shape(tensor->shape.size());
185+
std::transform(std::execution::unseq,
186+
tensor->shape.begin(), tensor->shape.end(), shape.begin(),
187+
[](auto const &dim) { return dim.value(); });
188+
auto layout = shape.size() == 4 ? computation::LayoutType::NCHW : computation::LayoutType::Others;
189+
edges[i].tensor = computation::Tensor::share(tensor->dataType, std::move(shape), layout, tensor->data);
190+
edges[i].name = name;
191+
};
192+
178193
std::transform(_internal.topology.begin(), _internal.topology.end(), nodes.begin(),
179-
[&edges, this](auto const &nodeRef) {
194+
[&fn, this](auto const &nodeRef) {
180195
auto const &[op, name] = _internal.nodes[nodeRef.idx];
196+
std::for_each(nodeRef.inputs.begin(), nodeRef.inputs.end(), fn);
197+
std::for_each(nodeRef.outputs.begin(), nodeRef.outputs.end(), fn);
181198
auto constant = std::all_of(std::execution::unseq,
182199
nodeRef.outputs.begin(), nodeRef.outputs.end(),
183200
[this](auto i) { return _internal.edges[i].tensor->data; });
184-
if (constant) {
185-
return computation::Node{nullptr, name};
186-
}
187-
auto fn = [&edges, &nodeRef, this](auto i) {
188-
if (edges[i].tensor) {
189-
return;
190-
}
191-
auto const &[tensor, name] = _internal.edges[i];
192-
computation::Shape shape(tensor->shape.size());
193-
std::transform(std::execution::unseq,
194-
tensor->shape.begin(), tensor->shape.end(), shape.begin(),
195-
[](auto const &dim) { return dim.value(); });
196-
auto layout = shape.size() == 4 ? computation::LayoutType::NCHW : computation::LayoutType::Others;
197-
edges[i].tensor = computation::Tensor::share(tensor->dataType, std::move(shape), layout, tensor->data);
198-
edges[i].name = name;
199-
};
200-
auto op_ = op->lower(TensorRefs(_internal.edges, nodeRef.inputs));
201-
auto valueDependentInputs = op->valueDependentInputs();
202-
auto it = valueDependentInputs.begin();
203-
for (auto i : range0_(nodeRef.inputs.size())) {
204-
auto input = nodeRef.inputs[i];
205-
if (it != valueDependentInputs.end() && i == *it) {
206-
edges[input].name = _internal.edges[input].name;
207-
++it;
208-
continue;
209-
}
210-
fn(input);
211-
}
212-
std::for_each(std::execution::unseq, nodeRef.outputs.begin(), nodeRef.outputs.end(), fn);
213-
return computation::Node{std::move(op_), name};
201+
return computation::Node{constant ? nullptr : op->lower(TensorRefs(_internal.edges, nodeRef.inputs)), name};
214202
});
215203

216204
auto const endTime = high_resolution_clock::now();

0 commit comments

Comments
 (0)