@@ -62,43 +62,45 @@ Type translateScalarType(int tcType) {
6262  }
6363}
6464
65+ //  translate the TC def input params to corresponding Halide components.
66+ //  params, inputs will be populated here
6567void  translateParam (
6668    const  lang::Param& p,
6769    map<string, Parameter>* params,
6870    vector<ImageParam>* inputs) {
71+   //  check if the param is already converted to halide components
6972  if  (params->find (p.ident ().name ()) != params->end ()) {
7073    return ;
71-   } else  {
72-     lang::TensorType type = p.tensorType ();
73-     int  dimensions = (int )type.dims ().size ();
74-     ImageParam imageParam (
75-         translateScalarType (type.scalarType ()), dimensions, p.ident ().name ());
76-     inputs->push_back (imageParam);
77-     vector<Expr> dims;
78-     for  (auto  d_ : type.dims ()) {
79-       if  (d_->kind () == lang::TK_IDENT) {
80-         auto  d = lang::Ident (d_);
81-         auto  it = params->find (d.name ());
82-         Parameter p;
83-         if  (it != params->end ()) {
84-           p = it->second ;
85-         } else  {
86-           p = Parameter (Int (32 ), false , 0 , d.name (), true );
87-           (*params)[d.name ()] = p;
88-         }
89-         dims.push_back (Variable::make (Int (32 ), p.name (), p));
74+   }
75+   lang::TensorType type = p.tensorType ();
76+   int  dimensions = (int )type.dims ().size ();
77+   ImageParam imageParam (
78+       translateScalarType (type.scalarType ()), dimensions, p.ident ().name ());
79+   inputs->push_back (imageParam);
80+   vector<Expr> dims;
81+   for  (auto  d_ : type.dims ()) {
82+     if  (d_->kind () == lang::TK_IDENT) {
83+       auto  d = lang::Ident (d_);
84+       auto  it = params->find (d.name ());
85+       Parameter p;
86+       if  (it != params->end ()) {
87+         p = it->second ;
9088      } else  {
91-         CHECK (d_->kind () == lang::TK_CONST);
92-         int32_t  value = lang::Const (d_).value ();
93-         dims.push_back (Expr (value));
89+         p = Parameter (Int (32 ), false , 0 , d.name (), true );
90+         (*params)[d.name ()] = p;
9491      }
92+       dims.push_back (Variable::make (Int (32 ), p.name (), p));
93+     } else  {
94+       CHECK (d_->kind () == lang::TK_CONST);
95+       int32_t  value = lang::Const (d_).value ();
96+       dims.push_back (Expr (value));
9597    }
98+   }
9699
97-     for  (int  i = 0 ; i < imageParam.dimensions (); i++) {
98-       imageParam.dim (i).set_bounds (0 , dims[i]);
99-     }
100-     (*params)[imageParam.name ()] = imageParam.parameter ();
100+   for  (int  i = 0 ; i < imageParam.dimensions (); i++) {
101+     imageParam.dim (i).set_bounds (0 , dims[i]);
101102  }
103+   (*params)[imageParam.name ()] = imageParam.parameter ();
102104}
103105
104106void  translateOutput (
@@ -156,6 +158,8 @@ Expr translateExpr(
156158      return  t (0 ) * t (1 );
157159    case  ' /' 
158160      return  t (0 ) / t (1 );
161+     case  ' %' 
162+       return  t (0 ) % t (1 );
159163    case  lang::TK_MIN:
160164      return  min (t (0 ), t (1 ));
161165    case  lang::TK_MAX:
@@ -492,20 +496,25 @@ Expr reductionUpdate(Expr e) {
492496  return  Call::make (e.type (), kReductionUpdate , {e}, Call::Intrinsic);
493497}
494498
499+ //  translate a single TC comprehension/statement to Halide component.
500+ //  funcs, bounds, reductions will be populated
495501void  translateComprehension (
496-     const  lang::Comprehension& c ,
502+     const  lang::Comprehension& comprehension ,
497503    const  map<string, Parameter>& params,
498504    bool  throwWarnings,
499505    map<string, Function>* funcs,
500506    FunctionBounds* bounds,
501507    vector<Function>* reductions) {
508+   //  Function is the internal Halide IR type for a pipeline
509+   //  stage. Func is the front-end class that wraps it. Here it's
510+   //  convenient to use both. Why? what is not exposed in Func?
502511  Function f;
503-   auto  it = funcs->find (c .ident ().name ());
512+   auto  it = funcs->find (comprehension .ident ().name ());
504513  if  (it != funcs->end ()) {
505514    f = it->second ;
506515  } else  {
507-     f = Function (c .ident ().name ());
508-     (*funcs)[c .ident ().name ()] = f;
516+     f = Function (comprehension .ident ().name ());
517+     (*funcs)[comprehension .ident ().name ()] = f;
509518  }
510519  //  Function is the internal Halide IR type for a pipeline
511520  //  stage. Func is the front-end class that wraps it. Here it's
@@ -514,7 +523,7 @@ void translateComprehension(
514523
515524  vector<Var> lhs;
516525  vector<Expr> lhs_as_exprs;
517-   for  (lang::Ident id : c .indices ()) {
526+   for  (lang::Ident id : comprehension .indices ()) {
518527    lhs.push_back (Var (id.name ()));
519528    lhs_as_exprs.push_back (lhs.back ());
520529  }
@@ -523,17 +532,17 @@ void translateComprehension(
523532  //  in the future we may consider using Halide Let bindings when they
524533  //  are supported later
525534  map<string, Expr> lets;
526-   for  (auto  wc : c .whereClauses ()) {
535+   for  (auto  wc : comprehension .whereClauses ()) {
527536    if  (wc->kind () == lang::TK_LET) {
528537      auto  let = lang::Let (wc);
529538      lets[let.name ().name ()] = translateExpr (let.rhs (), params, *funcs, lets);
530539    }
531540  }
532541
533-   Expr rhs = translateExpr (c .rhs (), params, *funcs, lets);
542+   Expr rhs = translateExpr (comprehension .rhs (), params, *funcs, lets);
534543
535544  std::vector<Expr> all_exprs;
536-   for  (auto  wc : c .whereClauses ()) {
545+   for  (auto  wc : comprehension .whereClauses ()) {
537546    if  (wc->kind () == lang::TK_EXISTS) {
538547      all_exprs.push_back (
539548          translateExpr (lang::Exists (wc).exp (), params, *funcs, lets));
@@ -557,7 +566,7 @@ void translateComprehension(
557566  //  values (2) +=!, TK_PLUS_EQ_B which first sets the tensor to the identity
558567  //  for the reduction and then applies the reduction.
559568  bool  should_zero = false ;
560-   switch  (c .assignment ()->kind ()) {
569+   switch  (comprehension .assignment ()->kind ()) {
561570    case  lang::TK_PLUS_EQ_B:
562571      should_zero = true ; //  fallthrough
563572    case  lang::TK_PLUS_EQ:
@@ -589,11 +598,12 @@ void translateComprehension(
589598    case  ' =' 
590599      break ;
591600    default :
592-       throw  lang::ErrorReport (c) << " Unimplemented reduction " 
593-                                  << c.assignment ()->range ().text () << " \n " 
601+       throw  lang::ErrorReport (comprehension)
602+           << " Unimplemented reduction " 
603+           << comprehension.assignment ()->range ().text () << " \n " 
594604  }
595605
596-   if  (c .assignment ()->kind () != ' =' 
606+   if  (comprehension .assignment ()->kind () != ' =' 
597607    reductions->push_back (f);
598608  }
599609
@@ -633,7 +643,7 @@ void translateComprehension(
633643  Scope<Interval> solution;
634644
635645  //  Put anything explicitly specified with a 'where' class in the solution
636-   for  (auto  constraint_ : c .whereClauses ()) {
646+   for  (auto  constraint_ : comprehension .whereClauses ()) {
637647    if  (constraint_->kind () != lang::TK_RANGE_CONSTRAINT)
638648      continue ;
639649    auto  constraint = lang::RangeConstraint (constraint_);
@@ -654,7 +664,8 @@ void translateComprehension(
654664
655665  //  Infer the rest
656666  all_exprs.push_back (rhs);
657-   forwardBoundsInference (all_exprs, *bounds, c, throwWarnings, &solution);
667+   forwardBoundsInference (
668+       all_exprs, *bounds, comprehension, throwWarnings, &solution);
658669
659670  //  TODO: What if subsequent updates have incompatible bounds
660671  //  (e.g. an in-place stencil)?. The .bound directive will use the
@@ -665,7 +676,7 @@ void translateComprehension(
665676
666677  for  (Var v : lhs) {
667678    if  (!solution.contains (v.name ())) {
668-       throw  lang::ErrorReport (c )
679+       throw  lang::ErrorReport (comprehension )
669680          << " Free variable " 
670681          << "  was not solved in range inference. May not be used right-hand side" 
671682    }
@@ -689,7 +700,7 @@ void translateComprehension(
689700    for  (size_t  i = 0 ; i < unbound.size (); i++) {
690701      auto  v = unbound[unbound.size () - 1  - i];
691702      if  (!solution.contains (v->name )) {
692-         throw  lang::ErrorReport (c )
703+         throw  lang::ErrorReport (comprehension )
693704            << " Free variable " "  is unconstrained. " 
694705            << " Use a 'where' clause to set its range." 
695706      }
@@ -737,6 +748,7 @@ void translateComprehension(
737748  stage.reorder (loop_nest);
738749}
739750
751+ //  translate a semantically checked TC def to Halide components struct
740752HalideComponents translateDef (const  lang::Def& def, bool  throwWarnings) {
741753  map<string, Function> funcs;
742754  HalideComponents components;
@@ -956,6 +968,8 @@ translate(isl::ctx ctx, const lang::TreeRef& treeRef, bool throwWarnings) {
956968      lang::Def (lang::Sema ().checkFunction (treeRef)), throwWarnings);
957969}
958970
971+ //  NOTE: there is no guarantee here that the tc string has only one def. It
972+ //  could have many defs. Only first def will be converted in that case.
959973HalideComponents
960974translate (isl::ctx ctx, const  std::string& tc, bool  throwWarnings) {
961975  LOG_IF (INFO, tc::FLAGS_debug_halide) << tc;
0 commit comments