From e143ff2e54cabdbf18ba5f029d8ac705003c2b03 Mon Sep 17 00:00:00 2001 From: prigoyal Date: Tue, 1 May 2018 12:01:47 -0700 Subject: [PATCH 1/6] add some documentation to sema/lexer/parser and tc2halide --- tc/core/halide2isl.cc | 5 +++-- tc/core/tc2halide.cc | 9 +++++++++ tc/core/tc2halide.h | 4 ++-- tc/lang/sema.h | 20 ++++++++++++++++++++ 4 files changed, 34 insertions(+), 4 deletions(-) diff --git a/tc/core/halide2isl.cc b/tc/core/halide2isl.cc index 1319b97a5..4d4755110 100644 --- a/tc/core/halide2isl.cc +++ b/tc/core/halide2isl.cc @@ -34,7 +34,8 @@ using namespace tc::polyhedral::detail; SymbolTable makeSymbolTable(const tc2halide::HalideComponents& components) { // const Stmt& s) { - // Collect and categorize all the Variable symbols + // Collect and categorize all the Halide Variable symbols as reduction + // or index variables class BuildSymbolTable : public IRVisitor { using IRVisitor::visit; std::set included; @@ -59,7 +60,7 @@ SymbolTable makeSymbolTable(const tc2halide::HalideComponents& components) { components.stmt.accept(&builder); // Get params from components.params which contain everything declared in - // tcdef. However, the 0-D tensors are registered as both params and inputs, + // TC Def. However, the 0-D tensors are registered as both params and inputs, // filter those out. for (auto kvp : components.params) { bool skip = false; diff --git a/tc/core/tc2halide.cc b/tc/core/tc2halide.cc index 5dd216407..2cd8d79b1 100644 --- a/tc/core/tc2halide.cc +++ b/tc/core/tc2halide.cc @@ -62,10 +62,13 @@ Type translateScalarType(int tcType) { } } +// Translate the TC def input params to corresponding Halide components. +// params, inputs will be populated here. void translateParam( const lang::Param& p, map* params, vector* inputs) { + // Check if the param has already been converted to halide components. if (params->find(p.ident().name()) != params->end()) { return; } else { @@ -488,6 +491,9 @@ Expr reductionUpdate(Expr e) { return Call::make(e.type(), kReductionUpdate, {e}, Call::Intrinsic); } +// Translate a single TC comprehension/statement to Halide components: funcs, +// bounds, reductions. +// // Note that the function definitions created by translateComprehension may // contain kReductionUpdate intrinsics. These may have to be removed // in order to be able to apply internal Halide analysis passes on them. @@ -736,6 +742,7 @@ void translateComprehension( stage.reorder(loop_nest); } +// Translate a semantically checked TC def to HalideComponents struct. HalideComponents translateDef(const lang::Def& def, bool throwWarnings) { map funcs; HalideComponents components; @@ -895,6 +902,8 @@ translate(isl::ctx ctx, const lang::TreeRef& treeRef, bool throwWarnings) { lang::Def(lang::Sema().checkFunction(treeRef)), throwWarnings); } +// NOTE: there is no guarantee here that the tc string has only one def. It +// could have many defs. Only first def will be converted in that case. HalideComponents translate(isl::ctx ctx, const std::string& tc, bool throwWarnings) { LOG_IF(INFO, tc::FLAGS_debug_halide) << tc; diff --git a/tc/core/tc2halide.h b/tc/core/tc2halide.h index 1c2c33248..aab822072 100644 --- a/tc/core/tc2halide.h +++ b/tc/core/tc2halide.h @@ -27,8 +27,8 @@ namespace tc2halide { // of the input and output tensors. We do not explicitly enumerate the // scalar params. struct HalideComponents { - lang::TreeRef - def; // post-semantic analaysis tree, used for later error reporting + // post-semantic analaysis tree, used for later error reporting + lang::TreeRef def; Halide::Internal::Stmt stmt; std::vector inputs; std::map params; diff --git a/tc/lang/sema.h b/tc/lang/sema.h index 3b84afe75..4330ec545 100644 --- a/tc/lang/sema.h +++ b/tc/lang/sema.h @@ -339,6 +339,21 @@ struct Sema { throw ErrorReport(exp) << "NYI - semantic checking for " << exp; } } + // This is the entry function for semantic analysis. It is called by + // tc2halide to associate type with each node of the tree and to also make + // sure that the tree is sematically correct. For example: a variable + // may not be input two times. Parser only verifies for the syntax but does + // not check the semantics. + // + // It converts the TK_APPLY nodes to TK_ACCESS or TK_BUILT_IN + // + // The reduction variables are deduced and the objects are created for them + // and they are appended to the tree + // + // Type checking is also done by small amount of code + // + // The method 'withType' can be used to associate the type with a given node + // TreeRef checkFunction(TreeRef func_) { auto func = Def(func_); auto params_ = @@ -350,6 +365,10 @@ struct Sema { } } + // Everything has to be input or output. Keep track of the variables that + // are either input/output. We will check that the statements have variables + // from this list. If not, then throw error that temporaries are not yet + // implemented. for (auto p : func.params()) { nonTemporaries.insert(p.ident().name()); inputParameters.insert(p.ident().name()); @@ -437,6 +456,7 @@ struct Sema { return checkRangeConstraint(RangeConstraint(ref)); } } + // Semantic checking for the statements/comprehensions in a TC Def. TreeRef checkStmt(TreeRef stmt_) { auto stmt = Comprehension(stmt_); From 71539d19303f5e3ce133071b3b8b22ea3c05608b Mon Sep 17 00:00:00 2001 From: prigoyal Date: Thu, 3 May 2018 07:26:34 -0700 Subject: [PATCH 2/6] remove unnecessary else in if-else given early exit the 'if' condition is checked and 'return' happens if the condition is met. Using 'else' is not needed --- tc/core/tc2halide.cc | 51 ++++++++++++++++++++++---------------------- 1 file changed, 25 insertions(+), 26 deletions(-) diff --git a/tc/core/tc2halide.cc b/tc/core/tc2halide.cc index 2cd8d79b1..0146a806f 100644 --- a/tc/core/tc2halide.cc +++ b/tc/core/tc2halide.cc @@ -71,37 +71,36 @@ void translateParam( // Check if the param has already been converted to halide components. if (params->find(p.ident().name()) != params->end()) { return; - } else { - lang::TensorType type = p.tensorType(); - int dimensions = (int)type.dims().size(); - ImageParam imageParam( - translateScalarType(type.scalarType()), dimensions, p.ident().name()); - inputs->push_back(imageParam); - vector dims; - for (auto d_ : type.dims()) { - if (d_->kind() == lang::TK_IDENT) { - auto d = lang::Ident(d_); - auto it = params->find(d.name()); - Parameter p; - if (it != params->end()) { - p = it->second; - } else { - p = Parameter(Int(32), false, 0, d.name(), true); - (*params)[d.name()] = p; - } - dims.push_back(Variable::make(Int(32), p.name(), p)); + } + lang::TensorType type = p.tensorType(); + int dimensions = (int)type.dims().size(); + ImageParam imageParam( + translateScalarType(type.scalarType()), dimensions, p.ident().name()); + inputs->push_back(imageParam); + vector dims; + for (auto d_ : type.dims()) { + if (d_->kind() == lang::TK_IDENT) { + auto d = lang::Ident(d_); + auto it = params->find(d.name()); + Parameter p; + if (it != params->end()) { + p = it->second; } else { - CHECK(d_->kind() == lang::TK_CONST); - int32_t value = lang::Const(d_).value(); - dims.push_back(Expr(value)); + p = Parameter(Int(32), false, 0, d.name(), true); + (*params)[d.name()] = p; } + dims.push_back(Variable::make(Int(32), p.name(), p)); + } else { + CHECK(d_->kind() == lang::TK_CONST); + int32_t value = lang::Const(d_).value(); + dims.push_back(Expr(value)); } + } - for (int i = 0; i < imageParam.dimensions(); i++) { - imageParam.dim(i).set_bounds(0, dims[i]); - } - (*params)[imageParam.name()] = imageParam.parameter(); + for (int i = 0; i < imageParam.dimensions(); i++) { + imageParam.dim(i).set_bounds(0, dims[i]); } + (*params)[imageParam.name()] = imageParam.parameter(); } void translateOutput( From 8d0781caf42278dedb9b54c72b6aea38532711f7 Mon Sep 17 00:00:00 2001 From: prigoyal Date: Thu, 3 May 2018 07:31:30 -0700 Subject: [PATCH 3/6] improve code readability proper variable naming and whitelining between functions --- tc/core/tc2halide.cc | 34 +++++++++++++++++---------------- tc/lang/sema.h | 45 +++++++++++++++++++++++++++++++++++++------- 2 files changed, 56 insertions(+), 23 deletions(-) diff --git a/tc/core/tc2halide.cc b/tc/core/tc2halide.cc index 0146a806f..c5eed3367 100644 --- a/tc/core/tc2halide.cc +++ b/tc/core/tc2halide.cc @@ -497,18 +497,18 @@ Expr reductionUpdate(Expr e) { // contain kReductionUpdate intrinsics. These may have to be removed // in order to be able to apply internal Halide analysis passes on them. void translateComprehension( - const lang::Comprehension& c, + const lang::Comprehension& comprehension, const map& params, bool throwWarnings, map* funcs, FunctionBounds* bounds) { Function f; - auto it = funcs->find(c.ident().name()); + auto it = funcs->find(comprehension.ident().name()); if (it != funcs->end()) { f = it->second; } else { - f = Function(c.ident().name()); - (*funcs)[c.ident().name()] = f; + f = Function(comprehension.ident().name()); + (*funcs)[comprehension.ident().name()] = f; } // Function is the internal Halide IR type for a pipeline // stage. Func is the front-end class that wraps it. Here it's @@ -517,7 +517,7 @@ void translateComprehension( vector lhs; vector lhs_as_exprs; - for (lang::Ident id : c.indices()) { + for (lang::Ident id : comprehension.indices()) { lhs.push_back(Var(id.name())); lhs_as_exprs.push_back(lhs.back()); } @@ -526,17 +526,17 @@ void translateComprehension( // in the future we may consider using Halide Let bindings when they // are supported later map lets; - for (auto wc : c.whereClauses()) { + for (auto wc : comprehension.whereClauses()) { if (wc->kind() == lang::TK_LET) { auto let = lang::Let(wc); lets[let.name().name()] = translateExpr(let.rhs(), params, *funcs, lets); } } - Expr rhs = translateExpr(c.rhs(), params, *funcs, lets); + Expr rhs = translateExpr(comprehension.rhs(), params, *funcs, lets); std::vector all_exprs; - for (auto wc : c.whereClauses()) { + for (auto wc : comprehension.whereClauses()) { if (wc->kind() == lang::TK_EXISTS) { all_exprs.push_back( translateExpr(lang::Exists(wc).exp(), params, *funcs, lets)); @@ -560,7 +560,7 @@ void translateComprehension( // values (2) +=!, TK_PLUS_EQ_B which first sets the tensor to the identity // for the reduction and then applies the reduction. bool should_zero = false; - switch (c.assignment()->kind()) { + switch (comprehension.assignment()->kind()) { case lang::TK_PLUS_EQ_B: should_zero = true; // fallthrough case lang::TK_PLUS_EQ: @@ -592,12 +592,13 @@ void translateComprehension( case '=': break; default: - throw lang::ErrorReport(c) << "Unimplemented reduction " - << c.assignment()->range().text() << "\n"; + throw lang::ErrorReport(comprehension) + << "Unimplemented reduction " + << comprehension.assignment()->range().text() << "\n"; } // Tag reductions as such - if (c.assignment()->kind() != '=') { + if (comprehension.assignment()->kind() != '=') { rhs = reductionUpdate(rhs); } @@ -637,7 +638,7 @@ void translateComprehension( Scope solution; // Put anything explicitly specified with a 'where' class in the solution - for (auto constraint_ : c.whereClauses()) { + for (auto constraint_ : comprehension.whereClauses()) { if (constraint_->kind() != lang::TK_RANGE_CONSTRAINT) continue; auto constraint = lang::RangeConstraint(constraint_); @@ -658,7 +659,8 @@ void translateComprehension( // Infer the rest all_exprs.push_back(rhs); - forwardBoundsInference(all_exprs, *bounds, c, throwWarnings, &solution); + forwardBoundsInference( + all_exprs, *bounds, comprehension, throwWarnings, &solution); // TODO: What if subsequent updates have incompatible bounds // (e.g. an in-place stencil)?. The .bound directive will use the @@ -669,7 +671,7 @@ void translateComprehension( for (Var v : lhs) { if (!solution.contains(v.name())) { - throw lang::ErrorReport(c) + throw lang::ErrorReport(comprehension) << "Free variable " << v << " was not solved in range inference. May not be used right-hand side"; } @@ -693,7 +695,7 @@ void translateComprehension( for (size_t i = 0; i < unbound.size(); i++) { auto v = unbound[unbound.size() - 1 - i]; if (!solution.contains(v->name)) { - throw lang::ErrorReport(c) + throw lang::ErrorReport(comprehension) << "Free variable " << v << " is unconstrained. " << "Use a 'where' clause to set its range."; } diff --git a/tc/lang/sema.h b/tc/lang/sema.h index 4330ec545..dabec7330 100644 --- a/tc/lang/sema.h +++ b/tc/lang/sema.h @@ -166,6 +166,7 @@ struct Sema { } return expr_to_type.at(ref); } + // associate a type with this expression TreeRef withType(TreeRef expr, TreeRef type) { auto inserted = expr_to_type.emplace(expr, type).second; @@ -179,6 +180,7 @@ struct Sema { } return TensorType(typ); } + TreeRef matchAllTypes(TreeRef list, TreeRef matched_type = nullptr) { for (auto e : list->trees()) { if (!matched_type) @@ -188,6 +190,7 @@ struct Sema { } return matched_type; } + TreeRef expectIntegral(TreeRef e) { if (TypeInfo(typeOfExpr(e)).code() == TypeInfo::Float) { throw ErrorReport(e) << " expected integral type but found " @@ -195,16 +198,19 @@ struct Sema { } return e; } + void expectBool(TreeRef anchor, int token) { if (token != TK_BOOL) { throw ErrorReport(anchor) << "expected boolean but found " << kindToString(token); } } + TreeRef expectBool(TreeRef exp) { expectBool(exp, typeOfExpr(exp)->kind()); return exp; } + TreeRef lookupVarOrCreateIndex(Ident ident) { TreeRef type = lookup(ident, false); if (!type) { @@ -216,6 +222,7 @@ struct Sema { } return type; } + TreeRef checkExp(TreeRef exp, bool allow_access) { switch (exp->kind()) { case TK_APPLY: { @@ -339,6 +346,7 @@ struct Sema { throw ErrorReport(exp) << "NYI - semantic checking for " << exp; } } + // This is the entry function for semantic analysis. It is called by // tc2halide to associate type with each node of the tree and to also make // sure that the tree is sematically correct. For example: a variable @@ -352,7 +360,7 @@ struct Sema { // // Type checking is also done by small amount of code // - // The method 'withType' can be used to associate the type with a given node + // The method 'withType' is used to associate the type with a given node // TreeRef checkFunction(TreeRef func_) { auto func = Def(func_); @@ -385,21 +393,27 @@ struct Sema { Def::create(func.range(), func.name(), params_, returns_, statements_); return r; } + TreeRef indexType(TreeRef anchor) { - return c(TK_INT32, anchor->range(), {}); + return createCompound(TK_INT32, anchor->range(), {}); } + TreeRef dimType(TreeRef anchor) { return indexType(anchor); } + TreeRef floatType(TreeRef anchor) { - return c(TK_FLOAT, anchor->range(), {}); + return createCompound(TK_FLOAT, anchor->range(), {}); } + TreeRef boolType(TreeRef anchor) { - return c(TK_BOOL, anchor->range(), {}); + return createCompound(TK_BOOL, anchor->range(), {}); } + void checkDim(Ident dim) { insert(env, dim, dimType(dim), false); } + TreeRef checkTensorType(TreeRef type) { auto tt = TensorType(type); for (const auto& d : tt.dims()) { @@ -409,6 +423,7 @@ struct Sema { } return type; } + TreeRef checkParam(TreeRef param) { auto p = Param(param); TreeRef type_ = checkTensorType(p.type()); @@ -416,11 +431,13 @@ struct Sema { live_input_names.insert(p.ident().name()); return param; } + TreeRef checkReturn(TreeRef ret) { auto r = Param(ret); TreeRef real_type = lookup(env, r.ident(), true); return ret; } + TreeRef checkList(TreeRef list, std::function fn) { TC_ASSERT(list, list->kind() == TK_LIST); TreeList r; @@ -429,6 +446,7 @@ struct Sema { } return List::create(list->range(), std::move(r)); } + TreeRef checkRangeConstraint(RangeConstraint rc) { // RCs are checked _before_ the rhs of the TC, so // it is possible the index is not in the environment yet @@ -441,11 +459,13 @@ struct Sema { auto e = expectIntegral(checkExp(rc.end(), false)); return RangeConstraint::create(rc.range(), rc.ident(), s, e); } + TreeRef checkLet(Let l) { auto rhs = checkExp(l.rhs(), true); insert(let_env, l.name(), typeOfExpr(rhs), true); return Let::create(l.range(), l.name(), rhs); } + TreeRef checkWhereClause(TreeRef ref) { if (ref->kind() == TK_LET) { return checkLet(Let(ref)); @@ -456,6 +476,7 @@ struct Sema { return checkRangeConstraint(RangeConstraint(ref)); } } + // Semantic checking for the statements/comprehensions in a TC Def. TreeRef checkStmt(TreeRef stmt_) { auto stmt = Comprehension(stmt_); @@ -467,11 +488,13 @@ struct Sema { insert(index_env, index, typ, true); } - // make dimension variables for each dimension of the output tensor + // check that the input is not used for output - inputs are immutable std::string name = stmt.ident().name(); if (inputParameters.count(name) > 0) { throw ErrorReport(stmt_) << "TC inputs are immutable"; } + + // make dimension variables for each dimension of the output tensor TreeList output_indices; int n = stmt.indices().size(); for (int i = 0; i < n; ++i) { @@ -578,6 +601,7 @@ struct Sema { return result; } + static bool isUninitializedReductionOperation(TreeRef assignment) { switch (assignment->kind()) { case TK_PLUS_EQ: @@ -589,6 +613,7 @@ struct Sema { return false; } } + bool isNotInplace(TreeRef assignment) { switch (assignment->kind()) { case TK_PLUS_EQ_B: @@ -600,6 +625,7 @@ struct Sema { return false; } } + std::string dumpEnv() { std::stringstream ss; std::vector> elems(env.begin(), env.end()); @@ -618,6 +644,7 @@ struct Sema { private: using Env = std::unordered_map; + void insert(Env& the_env, Ident ident, TreeRef value, bool must_be_undefined) { std::string name = ident.name(); @@ -630,6 +657,7 @@ struct Sema { throw ErrorReport(ident) << name << " already defined"; } } + TreeRef lookup(Ident ident, bool required) { TreeRef v = lookup(index_env, ident, false); if (!v) @@ -638,6 +666,7 @@ struct Sema { v = lookup(env, ident, required); return v; } + TreeRef lookup(Env& the_env, Ident ident, bool required) { std::string name = ident.name(); auto it = the_env.find(name); @@ -647,10 +676,12 @@ struct Sema { } return it == the_env.end() ? nullptr : it->second; } - TreeRef c(int kind, const SourceRange& range, TreeList&& trees) { + + TreeRef createCompound(int kind, const SourceRange& range, TreeList&& trees) { return Compound::create(kind, range, std::move(trees)); } - TreeRef s(const std::string& s) { + + TreeRef createString(const std::string& s) { return String::create(s); } From fb956bd6b6a45d353ffd224ab5af201f05932f59 Mon Sep 17 00:00:00 2001 From: prigoyal Date: Thu, 3 May 2018 07:32:49 -0700 Subject: [PATCH 4/6] fix segfault in halide2isl for mod op change makeIslAffBoundsFromExpr(..., e, ...) to makeIslAffBoundsFromExpr(..., op->a, ...) to avoid infinite recursion leading to segfault --- tc/core/halide2isl.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tc/core/halide2isl.cc b/tc/core/halide2isl.cc index 4d4755110..c99b221e6 100644 --- a/tc/core/halide2isl.cc +++ b/tc/core/halide2isl.cc @@ -203,7 +203,7 @@ std::vector makeIslAffBoundsFromExpr( std::vector result; // We cannot span multiple constraints if a modulo operation is involved. // x > max(a,b) % C is not equivalent to (x > a % C && x > b % C). - auto lhs = makeIslAffBoundsFromExpr(space, e, false, false); + auto lhs = makeIslAffBoundsFromExpr(space, op->a, false, false); CHECK_EQ(lhs.size(), 1u); if (const int64_t* b = as_const_int(op->b)) { return {lhs[0].mod(isl::val(space.get_ctx(), *b))}; From 6e9da0d687595229449ad5cf9c4217f0ec7c029f Mon Sep 17 00:00:00 2001 From: prigoyal Date: Thu, 3 May 2018 07:35:49 -0700 Subject: [PATCH 5/6] add support for modulo operator support % operator: propagate it from parser to halide to isl and add unit tests --- tc/core/tc2halide.cc | 2 ++ tc/lang/lexer.h | 4 ++-- tc/lang/sema.h | 1 + tc/lang/test_expected/math.expected | 4 +++- test/cuda/test_corner_cases.cc | 12 +++++++++++ test/test_cuda_mapper.cc | 33 +++++++++++++++++++++++++++++ test/test_lang.cc | 2 +- 7 files changed, 54 insertions(+), 4 deletions(-) diff --git a/tc/core/tc2halide.cc b/tc/core/tc2halide.cc index c5eed3367..76adea3d5 100644 --- a/tc/core/tc2halide.cc +++ b/tc/core/tc2halide.cc @@ -158,6 +158,8 @@ Expr translateExpr( return t(0) * t(1); case '/': return t(0) / t(1); + case '%': + return t(0) % t(1); case lang::TK_MIN: return min(t(0), t(1)); case lang::TK_MAX: diff --git a/tc/lang/lexer.h b/tc/lang/lexer.h index d9bbcd9b2..9e40a3092 100644 --- a/tc/lang/lexer.h +++ b/tc/lang/lexer.h @@ -87,7 +87,7 @@ namespace lang { _(TK_LET, "let", "") \ _(TK_EXISTS, "exists", "exists") -static const char* valid_single_char_tokens = "+-*/()[]?:,={}>', '<', TK_LE, TK_GE, TK_EQ, TK_NE}, {'+', '-'}, - {'*', '/'}, + {'*', '/', '%'}, }; std::vector> unary_ops = { {'-', '!'}, diff --git a/tc/lang/sema.h b/tc/lang/sema.h index dabec7330..8c54b57be 100644 --- a/tc/lang/sema.h +++ b/tc/lang/sema.h @@ -293,6 +293,7 @@ struct Sema { case '-': case '*': case '/': + case '%': case TK_MIN: case TK_MAX: { auto nexp = diff --git a/tc/lang/test_expected/math.expected b/tc/lang/test_expected/math.expected index 8b21f02b1..700d8d7d8 100644 --- a/tc/lang/test_expected/math.expected +++ b/tc/lang/test_expected/math.expected @@ -1,7 +1,9 @@ (- (+ (+ - (- (const 3 (int32))) + (% + (- (const 3 (int32))) + (const 2 (int32))) (* (const 4 (int32)) (const 5 (int32)))) diff --git a/test/cuda/test_corner_cases.cc b/test/cuda/test_corner_cases.cc index 1c0aab1d6..5c0aa2fe8 100644 --- a/test/cuda/test_corner_cases.cc +++ b/test/cuda/test_corner_cases.cc @@ -331,6 +331,18 @@ TEST(TestCornerCases, E25) { {b}); } +TEST(TestCornerCases, E26) { + auto a = I(); + auto b = I(); + auto r = I(1); + Succeed( + "def f(int32 a, int32 b) -> (c) { c(i) = int32(a % b) where i in 0:1 }", + {a, b}, + {r}); + auto e = at::Scalar(a).toInt() % at::Scalar(b).toInt(); + CHECK_EQ(at::Scalar(r[0]).toInt(), e); +} + int main(int argc, char** argv) { ::testing::InitGoogleTest(&argc, argv); ::gflags::ParseCommandLineFlags(&argc, &argv, true); diff --git a/test/test_cuda_mapper.cc b/test/test_cuda_mapper.cc index 915815403..0eccb60bc 100644 --- a/test/test_cuda_mapper.cc +++ b/test/test_cuda_mapper.cc @@ -1118,6 +1118,39 @@ TEST_F(PolyhedralMapperTest, EmptyMappingFilter) { mscop->codegen(specializedName); } +TEST_F(PolyhedralMapperTest, ModulusConstantRHS) { + string tc = R"TC( +def fun(float(N) a) -> (b) { b(i) = a(i % 3) where i in 0:N } +)TC"; + // This triggers tc2halide conversion and should not throw. + auto scop = Prepare(tc); + for (auto r : scop->reads.wrap().get_set_list()) { + auto read = r.unwrap(); + // skip irrelevant reads, if any + if (read.range().get_tuple_name() != std::string("a")) { + continue; + } + EXPECT_EQ(r.get_stride(0), 3); + } +} + +TEST_F(PolyhedralMapperTest, ModulusVariableRHS) { + string tc = R"TC( +def local_sparse_convolution(float(N, C, H, W) I, float(O, KC, KH, KW) W1) -> (O1) { + O1(n, o, h, w) +=! I(n, kc % c, h + kh, w + kw) * W1(o, kc, kh, kw) where c in 1:C +} +)TC"; + // This triggers tc2halide conversion and should not throw. + auto scop = Prepare(tc); + for (auto r : scop->reads.range().get_set_list()) { + // skip irrelevant reads, if any + if (r.get_tuple_name() != std::string("I")) { + continue; + } + EXPECT_TRUE(r.plain_is_universe()); + } +} + int main(int argc, char** argv) { ::testing::InitGoogleTest(&argc, argv); ::gflags::ParseCommandLineFlags(&argc, &argv, true); diff --git a/test/test_lang.cc b/test/test_lang.cc index 4596d6ad0..c1c2b1902 100644 --- a/test/test_lang.cc +++ b/test/test_lang.cc @@ -232,7 +232,7 @@ int main(int argc, char** argv) { ASSERT(s->tree(0)->stringValue() == "min"); } { - std::string stuff = "-3+4*5+7-a"; + std::string stuff = "-3%2+4*5+7-a"; Parser p(stuff); auto r = p.parseExp(); std::stringstream ss; From b36f466c180ffd729f506caef7c3610a6a3d7680 Mon Sep 17 00:00:00 2001 From: prigoyal Date: Thu, 3 May 2018 07:36:40 -0700 Subject: [PATCH 6/6] sema createString cleanup dead function and not used anywhere so cleaning it up --- tc/lang/sema.h | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tc/lang/sema.h b/tc/lang/sema.h index 8c54b57be..03c7a9385 100644 --- a/tc/lang/sema.h +++ b/tc/lang/sema.h @@ -682,10 +682,6 @@ struct Sema { return Compound::create(kind, range, std::move(trees)); } - TreeRef createString(const std::string& s) { - return String::create(s); - } - std::vector reduction_variables; // per-statement Env index_env; // per-statement Env let_env; // per-statement, used for where i =