@@ -175,42 +175,30 @@ namespace refactor::frontend {
175175
176176 std::vector<computation::Node> nodes (_internal.nodes .size ());
177177 std::vector<computation::Edge> edges (_internal.edges .size ());
178+
179+ auto fn = [&edges, this ](auto i) {
180+ if (edges[i].tensor ) {
181+ return ;
182+ }
183+ auto const &[tensor, name] = _internal.edges [i];
184+ computation::Shape shape (tensor->shape .size ());
185+ std::transform (std::execution::unseq,
186+ tensor->shape .begin (), tensor->shape .end (), shape.begin (),
187+ [](auto const &dim) { return dim.value (); });
188+ auto layout = shape.size () == 4 ? computation::LayoutType::NCHW : computation::LayoutType::Others;
189+ edges[i].tensor = computation::Tensor::share (tensor->dataType , std::move (shape), layout, tensor->data );
190+ edges[i].name = name;
191+ };
192+
178193 std::transform (_internal.topology .begin (), _internal.topology .end (), nodes.begin (),
179- [&edges , this ](auto const &nodeRef) {
194+ [&fn , this ](auto const &nodeRef) {
180195 auto const &[op, name] = _internal.nodes [nodeRef.idx ];
196+ std::for_each (nodeRef.inputs .begin (), nodeRef.inputs .end (), fn);
197+ std::for_each (nodeRef.outputs .begin (), nodeRef.outputs .end (), fn);
181198 auto constant = std::all_of (std::execution::unseq,
182199 nodeRef.outputs .begin (), nodeRef.outputs .end (),
183200 [this ](auto i) { return _internal.edges [i].tensor ->data ; });
184- if (constant) {
185- return computation::Node{nullptr , name};
186- }
187- auto fn = [&edges, &nodeRef, this ](auto i) {
188- if (edges[i].tensor ) {
189- return ;
190- }
191- auto const &[tensor, name] = _internal.edges [i];
192- computation::Shape shape (tensor->shape .size ());
193- std::transform (std::execution::unseq,
194- tensor->shape .begin (), tensor->shape .end (), shape.begin (),
195- [](auto const &dim) { return dim.value (); });
196- auto layout = shape.size () == 4 ? computation::LayoutType::NCHW : computation::LayoutType::Others;
197- edges[i].tensor = computation::Tensor::share (tensor->dataType , std::move (shape), layout, tensor->data );
198- edges[i].name = name;
199- };
200- auto op_ = op->lower (TensorRefs (_internal.edges , nodeRef.inputs ));
201- auto valueDependentInputs = op->valueDependentInputs ();
202- auto it = valueDependentInputs.begin ();
203- for (auto i : range0_ (nodeRef.inputs .size ())) {
204- auto input = nodeRef.inputs [i];
205- if (it != valueDependentInputs.end () && i == *it) {
206- edges[input].name = _internal.edges [input].name ;
207- ++it;
208- continue ;
209- }
210- fn (input);
211- }
212- std::for_each (std::execution::unseq, nodeRef.outputs .begin (), nodeRef.outputs .end (), fn);
213- return computation::Node{std::move (op_), name};
201+ return computation::Node{constant ? nullptr : op->lower (TensorRefs (_internal.edges , nodeRef.inputs )), name};
214202 });
215203
216204 auto const endTime = high_resolution_clock::now ();
0 commit comments