@@ -484,21 +484,19 @@ void forwardBoundsInference(
484484 }
485485}
486486
487- Expr reductionInit (Expr e) {
488- return Call::make (e.type (), kReductionInit , {e}, Call::Intrinsic);
489- }
490-
491487Expr reductionUpdate (Expr e) {
492488 return Call::make (e.type (), kReductionUpdate , {e}, Call::Intrinsic);
493489}
494490
491+ // Note that the function definitions created by translateComprehension may
492+ // contain kReductionUpdate intrinsics. These may have to be removed
493+ // in order to be able to apply internal Halide analysis passes on them.
495494void translateComprehension (
496495 const lang::Comprehension& c,
497496 const map<string, Parameter>& params,
498497 bool throwWarnings,
499498 map<string, Function>* funcs,
500- FunctionBounds* bounds,
501- vector<Function>* reductions) {
499+ FunctionBounds* bounds) {
502500 Function f;
503501 auto it = funcs->find (c.ident ().name ());
504502 if (it != funcs->end ()) {
@@ -593,8 +591,9 @@ void translateComprehension(
593591 << c.assignment ()->range ().text () << " \n " ;
594592 }
595593
594+ // Tag reductions as such
596595 if (c.assignment ()->kind () != ' =' ) {
597- reductions-> push_back (f );
596+ rhs = reductionUpdate (rhs );
598597 }
599598
600599 // Bind any scalar params on the rhs to their parameter objects.
@@ -743,13 +742,12 @@ HalideComponents translateDef(const lang::Def& def, bool throwWarnings) {
743742 components.def = def;
744743 FunctionBounds bounds;
745744
746- vector<Function> reductions;
747745 for (auto p : def.params ()) {
748746 translateParam (p, &components.params , &components.inputs );
749747 }
750748 for (auto c : def.statements ()) {
751749 translateComprehension (
752- c, components.params , throwWarnings, &funcs, &bounds, &reductions );
750+ c, components.params , throwWarnings, &funcs, &bounds);
753751 }
754752 vector<Function> outputs;
755753 for (auto p : def.returns ()) {
@@ -807,33 +805,6 @@ HalideComponents translateDef(const lang::Def& def, bool throwWarnings) {
807805 s = uniquify_variable_names (s);
808806 s = simplify (s);
809807
810- // Tag reductions as such
811- for (const Function& f : reductions) {
812- class TagReduction : public IRMutator2 {
813- using IRMutator2::visit;
814- bool found_init = false ;
815- Stmt visit (const Provide* op) override {
816- if (op->name == f.name ()) {
817- if (found_init) {
818- return Provide::make (
819- op->name , {reductionUpdate (op->values [0 ])}, op->args );
820- } else {
821- found_init = true ;
822- return Provide::make (
823- op->name , {reductionInit (op->values [0 ])}, op->args );
824- }
825- } else {
826- return op;
827- }
828- }
829- const Function& f;
830-
831- public:
832- TagReduction (const Function& f) : f(f) {}
833- } tagReduction (f);
834- s = tagReduction.mutate (s);
835- }
836-
837808 // Trim ProducerConsumer annotations. TC doesn't use them.
838809 class RemoveProducerConsumer : public IRMutator2 {
839810 using IRMutator2::visit;
0 commit comments