@@ -493,55 +493,82 @@ struct Sema {
493493
494494 // Semantic checking for the statements/comprehensions in a TC Def.
495495 TreeRef checkStmt (TreeRef stmt_) {
496- auto stmt = Comprehension (stmt_);
496+ if (stmt_->kind () == TK_COMPREHENSION) {
497+ return checkComprehension (Comprehension (stmt_));
498+ }
499+ return checkFor (For (stmt_));
500+ }
497501
502+ TreeRef checkFor (For f) {
503+ if (lookup (f.index (), false )) {
504+ throw ErrorReport (f) << " For loop index already defined" ;
505+ }
506+ TreeList stmts;
507+ for (auto s : f.statements ()) {
508+ if (s->kind () != TK_COMPREHENSION) {
509+ throw ErrorReport (s) << " Nested \" for\" loops NYI" ;
510+ }
511+ stmts.push_back (checkComprehension (Comprehension (s)));
512+ }
513+ // Check the range constraint after all statements
514+ // This way we don't need extra state to track indices coming from loops
515+ // that may have already been defined.
516+ checkRangeConstraint (f.rangeConstraint ());
517+ return For::create (
518+ f.range (),
519+ f.index (),
520+ f.rangeConstraint (),
521+ List::create (f.range (), std::move (stmts)));
522+ }
523+
524+ TreeRef checkComprehension (Comprehension comp) {
498525 // register index variables (non-reductions)
499- for (const auto & index : stmt .indices ()) {
526+ for (const auto & index : comp .indices ()) {
500527 std::string idx = index.name ();
501528 auto typ = indexType (index);
502529 insert (index_env, index, typ, true );
503530 }
504531
505532 // check that the input is not used for output - inputs are immutable
506- std::string name = stmt .ident ().name ();
533+ std::string name = comp .ident ().name ();
507534 if (inputParameters.count (name) > 0 ) {
508- throw ErrorReport (stmt_ ) << " TC inputs are immutable" ;
535+ throw ErrorReport (comp ) << " TC inputs are immutable" ;
509536 }
510537
511538 // make dimension variables for each dimension of the output tensor
512539 TreeList output_indices;
513- int n = stmt .indices ().size ();
540+ int n = comp .indices ().size ();
514541 for (int i = 0 ; i < n; ++i) {
515542 auto new_var =
516- Ident::create (stmt .range (), name + " ." + std::to_string (i));
543+ Ident::create (comp .range (), name + " ." + std::to_string (i));
517544 output_indices.push_back (new_var);
518545 }
519546
520547 // where clauses are checked _before_ the rhs because they
521548 // introduce let bindings that are in scope for the rhs
522- auto where_clauses_ = stmt .whereClauses ().map (
549+ auto where_clauses_ = comp .whereClauses ().map (
523550 [&](TreeRef rc) { return checkWhereClause (rc); });
524551
525- TreeRef rhs_ = checkExp (stmt .rhs (), true );
552+ TreeRef rhs_ = checkExp (comp .rhs (), true );
526553 TreeRef scalar_type = typeOfExpr (rhs_);
527554
528555 // if this statement will be returned and it is annotated in the return list
529556 // with a type (e.g. float(A,B)) then force the tensor to be that type
530557 // and check that the number of dimensions are consistent
531- auto output_annotation = annotated_output_types.find (stmt .ident ().name ());
558+ auto output_annotation = annotated_output_types.find (comp .ident ().name ());
532559 if (output_annotation != annotated_output_types.end ()) {
533560 auto tt = TensorType (output_annotation->second );
534561 auto matched_type = match_types (scalar_type, tt.scalarTypeTree ());
535562 if (tt.scalarTypeTree ()->kind () != matched_type->kind ()) {
536- throw ErrorReport (stmt )
563+ throw ErrorReport (comp )
537564 << " attempting to assign type "
538565 << kindToString (scalar_type->kind ()) << " to narrower type "
539566 << kindToString (tt.scalarTypeTree ()->kind ())
540567 << " without an explicit cast" ;
541568 }
542- if (tt.dims ().size () != stmt .indices ().size ()) {
543- throw ErrorReport (stmt )
544- << " tensor defined with " << stmt .indices ().size ()
569+ if (tt.dims ().size () != comp .indices ().size ()) {
570+ throw ErrorReport (comp )
571+ << " tensor defined with " << comp .indices ().size ()
545572 << " dimensions but declared as an output with " << tt.dims ().size ()
546573 << " dimensions." ;
547574 }
@@ -550,33 +577,33 @@ struct Sema {
550577 // After checking rhs and before creating lhs, we check if it is a reduction
551578 // without initialization (i.e., reduction operator without "!" suffix, and
552579 // lhs not defined previously).
553- if (isUninitializedReductionOperation (stmt .assignment ()) &&
554- nullptr == lookup (stmt .ident (), false )) {
555- ErrorReport err (stmt );
556- std::string tk = kindToToken (stmt .assignment ()->kind ());
557- err << " Reduction without initialization. If " << stmt .ident ().name ()
580+ if (isUninitializedReductionOperation (comp .assignment ()) &&
581+ nullptr == lookup (comp .ident (), false )) {
582+ ErrorReport err (comp );
583+ std::string tk = kindToToken (comp .assignment ()->kind ());
584+ err << " Reduction without initialization. If " << comp .ident ().name ()
558585 << " is not pre-initialized before calling the TC function,"
559586 << " consider using the !-suffixed reduction operator " << tk
560587 << " ! instead of " << tk;
561588 warn (err);
562589 }
563590
564591 auto type = TensorType::create (
565- stmt .range (),
592+ comp .range (),
566593 scalar_type,
567- List::create (stmt .range (), std::move (output_indices)));
568- insert (env, stmt .ident (), type, false );
594+ List::create (comp .range (), std::move (output_indices)));
595+ insert (env, comp .ident (), type, false );
569596
570597 // if we redefined an input, it is no longer valid for range expressions
571- live_input_names.erase (stmt .ident ().name ());
598+ live_input_names.erase (comp .ident ().name ());
572599
573- auto equivalent_statement_ = stmt .equivalent ().map ([&](Equivalent eq) {
600+ auto equivalent_statement_ = comp .equivalent ().map ([&](Equivalent eq) {
574601 auto indices_ = eq.accesses ().map (
575602 [&](TreeRef index) { return checkExp (index, true ); });
576603 return Equivalent::create (eq.range (), eq.name (), indices_);
577604 });
578605
579- TreeRef assignment = stmt .assignment ();
606+ TreeRef assignment = comp .assignment ();
580607 // For semantic consistency we allow overwriting reductions like +=!
581608 // to be used in the language when there are no actual reduction dimensions.
582609 // Later compile stages assume that there is at least one reduction
@@ -586,26 +613,26 @@ struct Sema {
586613 assignment = Compound::create (' =' , assignment->range (), {});
587614 }
588615
589- if (reduction_variables.size () > 0 && stmt .assignment ()->kind () == ' =' ) {
590- throw ErrorReport (stmt ) << " this statement includes reduction variable '"
616+ if (reduction_variables.size () > 0 && comp .assignment ()->kind () == ' =' ) {
617+ throw ErrorReport (comp ) << " this statement includes reduction variable '"
591618 << Ident (reduction_variables.back ()).name ()
592619 << " ' but does not specify a reduction." ;
593620 }
594621 TreeRef reduction_variable_list =
595- List::create (stmt .ident ().range (), std::move (reduction_variables));
622+ List::create (comp .ident ().range (), std::move (reduction_variables));
596623 TreeRef result = Comprehension::create (
597- stmt .range (),
598- stmt .ident (),
599- stmt .indices (),
600- stmt .assignment (),
624+ comp .range (),
625+ comp .ident (),
626+ comp .indices (),
627+ comp .assignment (),
601628 rhs_,
602629 where_clauses_,
603630 equivalent_statement_,
604631 reduction_variable_list);
605632
606- if (nonTemporaries.count (stmt .ident ().name ()) == 0 ) {
607- throw ErrorReport (stmt )
608- << stmt .ident ().name ()
633+ if (nonTemporaries.count (comp .ident ().name ()) == 0 ) {
634+ throw ErrorReport (comp )
635+ << comp .ident ().name ()
609636 << " is not listed as an input or output to this function. Temporaries tensors are not yet implemented" ;
610637 }
611638
0 commit comments