@@ -33,6 +33,21 @@ using std::vector;
3333
3434namespace {
3535
36+ using FunctionBounds = map<Function, map<string, Interval>, Function::Compare>;
37+
38+ struct TranslationUnit {
39+ HalideComponents components;
40+ Scope<Interval> enclosingLoopIndices;
41+ map<string, Function> funcs;
42+ FunctionBounds bounds;
43+ bool throwWarnings;
44+ };
45+
46+ void translateFor (const lang::For& f, TranslationUnit* tu);
47+ void translateComprehension (
48+ const lang::Comprehension& comprehension,
49+ TranslationUnit* tu);
50+
3651Type translateScalarType (int tcType) {
3752 switch (tcType) {
3853 case lang::TK_BOOL:
@@ -264,8 +279,6 @@ vector<const Variable*> unboundVariables(const vector<Var>& lhs, Expr rhs) {
264279 return finder.result ;
265280}
266281
267- typedef map<Function, map<string, Interval>, Function::Compare> FunctionBounds;
268-
269282void forwardBoundsInference (
270283 const std::vector<Expr>& exprs,
271284 const FunctionBounds& bounds,
@@ -500,6 +513,29 @@ Expr reductionUpdate(Expr e) {
500513 return Call::make (e.type (), kReductionUpdate , {e}, Call::Intrinsic);
501514}
502515
516+ void translateStatement (const lang::TreeRef& stmt, TranslationUnit* tu) {
517+ if (stmt->kind () == lang::TK_COMPREHENSION) {
518+ translateComprehension (lang::Comprehension (stmt), tu);
519+ } else {
520+ CHECK_EQ (stmt->kind (), lang::TK_FOR);
521+ translateFor (lang::For (stmt), tu);
522+ }
523+ }
524+
525+ void translateFor (const lang::For& f, TranslationUnit* pTU) {
526+ const map<string, Parameter>& params = pTU->components .params ;
527+ auto constraint = lang::RangeConstraint (f.rangeConstraint ());
528+ Interval i;
529+ const map<string, Expr> lets;
530+ i.min = translateExpr (constraint.start (), params, pTU->funcs , lets);
531+ i.max = translateExpr (constraint.end (), params, pTU->funcs , lets) - 1 ;
532+ pTU->enclosingLoopIndices .push (f.index ().name (), i);
533+ for (auto stm : f.statements ()) {
534+ translateStatement (stm, pTU);
535+ }
536+ pTU->enclosingLoopIndices .pop (f.index ().name ());
537+ }
538+
503539// Translate a single TC comprehension/statement to Halide components: funcs,
504540// bounds, reductions.
505541//
@@ -508,10 +544,11 @@ Expr reductionUpdate(Expr e) {
508544// in order to be able to apply internal Halide analysis passes on them.
509545void translateComprehension (
510546 const lang::Comprehension& comprehension,
511- const map<string, Parameter>& params,
512- bool throwWarnings,
513- map<string, Function>* funcs,
514- FunctionBounds* bounds) {
547+ TranslationUnit* pTU) {
548+ const map<string, Parameter>& params = pTU->components .params ;
549+ bool throwWarnings = pTU->throwWarnings ;
550+ map<string, Function>* funcs = &pTU->funcs ;
551+ FunctionBounds* bounds = &pTU->bounds ;
515552 Function f;
516553 auto it = funcs->find (comprehension.ident ().name ());
517554 if (it != funcs->end ()) {
@@ -647,6 +684,12 @@ void translateComprehension(
647684 // demand).
648685 Scope<Interval> solution;
649686
687+ // Copy information from enclosing "for" loops
688+ for (auto entry = pTU->enclosingLoopIndices .cbegin ();
689+ entry != pTU->enclosingLoopIndices .cend ();
690+ ++entry) {
691+ solution.push (entry.name (), entry.value ());
692+ }
650693 // Put anything explicitly specified with a 'where' class in the solution
651694 for (auto constraint_ : comprehension.whereClauses ()) {
652695 if (constraint_->kind () != lang::TK_RANGE_CONSTRAINT)
@@ -656,6 +699,11 @@ void translateComprehension(
656699 i.min = translateExpr (constraint.start (), params, *funcs, lets);
657700 i.max = translateExpr (constraint.end (), params, *funcs, lets) - 1 ;
658701
702+ if (solution.contains (constraint.ident ().name ())) {
703+ throw lang::ErrorReport (constraint_)
704+ << " Multiple range constraints per index NYI" ;
705+ }
706+
659707 // TODO: In the future we'll want to make any non-trivial bounds
660708 // into hidden scalar parameters, and just pass variables to the
661709 // polyhedral layer instead of potentially complex
@@ -755,25 +803,20 @@ void translateComprehension(
755803
756804// Translate a semantically checked TC def to HalideComponents struct.
757805HalideComponents translateDef (const lang::Def& def, bool throwWarnings) {
758- map<string, Function> funcs;
759- HalideComponents components;
760- components.def = def;
761- FunctionBounds bounds;
806+ TranslationUnit tu;
807+ tu.components .def = def;
808+ tu.throwWarnings = throwWarnings;
762809
763810 for (auto p : def.params ()) {
764- translateParam (p, &components.params , &components.inputs );
811+ translateParam (p, &tu. components .params , &tu. components .inputs );
765812 }
766- for (auto c : def.statements ()) {
767- translateComprehension (
768- lang::Comprehension (c),
769- components.params ,
770- throwWarnings,
771- &funcs,
772- &bounds);
813+ // Semantically valid TCs include at most one outer sequential loop for now
814+ for (auto stm : def.statements ()) {
815+ translateStatement (stm, &tu);
773816 }
774817 vector<Function> outputs;
775818 for (auto p : def.returns ()) {
776- translateOutput (p, funcs, &outputs);
819+ translateOutput (p, tu. funcs , &outputs);
777820 }
778821
779822 // Now apply an extremely simplified version of Halide lowering
@@ -804,11 +847,12 @@ HalideComponents translateDef(const lang::Def& def, bool throwWarnings) {
804847 // used in the pipelines we construct here, so just make a host target.
805848 Target target (" host" );
806849 Stmt s = schedule_functions (outputs, fused_groups, env, target, any_memoized);
850+ LOG_IF (ERROR, tc::FLAGS_debug_halide) << s;
807851 // we insert these to allow for inplace mutation of in/out tensors
808852 s = remove_undef (s);
809853 // Apply forward bounds inference results. This replaces the usual Halide
810854 // bounds inference.
811- for (auto p : bounds) {
855+ for (auto p : tu. bounds ) {
812856 const Function& f = p.first ;
813857 for (auto b : p.second ) {
814858 const string& var = b.first ;
@@ -893,20 +937,20 @@ HalideComponents translateDef(const lang::Def& def, bool throwWarnings) {
893937 };
894938 s = SubstituteAllLets ().mutate (s);
895939
896- components.stmt = s;
940+ tu. components .stmt = s;
897941
898942 for (Function f : outputs) {
899943 OutputImageParam o = Func (f).output_buffers ()[0 ];
900944 // Apply forward bounds inference results to the output buffers.
901- const auto & b = bounds[f];
945+ const auto & b = tu. bounds [f];
902946 for (int i = 0 ; i < o.dimensions (); i++) {
903947 const Interval& bound = b.at (f.args ()[i]);
904948 o.dim (i).set_bounds (bound.min , simplify (bound.max - bound.min + 1 ));
905949 }
906- components.outputs .push_back (o);
950+ tu. components .outputs .push_back (o);
907951 }
908952
909- return components;
953+ return tu. components ;
910954}
911955} // namespace
912956
0 commit comments