-
Couldn't load subscription status.
- Fork 213
Add % operator and propagate it from parser -> halide -> isl #348
Changes from all commits
e143ff2
71539d1
8d0781c
fb956bd
6e9da0d
b36f466
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -62,43 +62,45 @@ Type translateScalarType(int tcType) { | |
| } | ||
| } | ||
|
|
||
| // Translate the TC def input params to corresponding Halide components. | ||
| // params, inputs will be populated here. | ||
| void translateParam( | ||
| const lang::Param& p, | ||
| map<string, Parameter>* params, | ||
| vector<ImageParam>* inputs) { | ||
| // Check if the param has already been converted to halide components. | ||
| if (params->find(p.ident().name()) != params->end()) { | ||
| return; | ||
| } else { | ||
| lang::TensorType type = p.tensorType(); | ||
| int dimensions = (int)type.dims().size(); | ||
| ImageParam imageParam( | ||
| translateScalarType(type.scalarType()), dimensions, p.ident().name()); | ||
| inputs->push_back(imageParam); | ||
| vector<Expr> dims; | ||
| for (auto d_ : type.dims()) { | ||
| if (d_->kind() == lang::TK_IDENT) { | ||
| auto d = lang::Ident(d_); | ||
| auto it = params->find(d.name()); | ||
| Parameter p; | ||
| if (it != params->end()) { | ||
| p = it->second; | ||
| } else { | ||
| p = Parameter(Int(32), false, 0, d.name(), true); | ||
| (*params)[d.name()] = p; | ||
| } | ||
| dims.push_back(Variable::make(Int(32), p.name(), p)); | ||
| } | ||
| lang::TensorType type = p.tensorType(); | ||
| int dimensions = (int)type.dims().size(); | ||
| ImageParam imageParam( | ||
| translateScalarType(type.scalarType()), dimensions, p.ident().name()); | ||
| inputs->push_back(imageParam); | ||
| vector<Expr> dims; | ||
| for (auto d_ : type.dims()) { | ||
| if (d_->kind() == lang::TK_IDENT) { | ||
| auto d = lang::Ident(d_); | ||
| auto it = params->find(d.name()); | ||
| Parameter p; | ||
| if (it != params->end()) { | ||
| p = it->second; | ||
| } else { | ||
| CHECK(d_->kind() == lang::TK_CONST); | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This hoisting looks like a behavior change to me, could you please explain why it is necessary or correlated with this commit message? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. thanks for the review. I am not sure I understand. There is no code change to this section except removing unnecessary else after line 74. line 88 - 98 in the green section matches line 85 - 95 in red section. I think that github alignment for code changes is not exactly good and hence causing confusion. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. diff computation alogrithm is not perfect; reviewing changes per-commit helps There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah yes, you're right, I see there are 2 columns of numbers. I learned something today! Sorry for the noise here |
||
| int32_t value = lang::Const(d_).value(); | ||
| dims.push_back(Expr(value)); | ||
| p = Parameter(Int(32), false, 0, d.name(), true); | ||
| (*params)[d.name()] = p; | ||
| } | ||
| dims.push_back(Variable::make(Int(32), p.name(), p)); | ||
| } else { | ||
| CHECK(d_->kind() == lang::TK_CONST); | ||
| int32_t value = lang::Const(d_).value(); | ||
| dims.push_back(Expr(value)); | ||
| } | ||
| } | ||
|
|
||
| for (int i = 0; i < imageParam.dimensions(); i++) { | ||
| imageParam.dim(i).set_bounds(0, dims[i]); | ||
| } | ||
| (*params)[imageParam.name()] = imageParam.parameter(); | ||
| for (int i = 0; i < imageParam.dimensions(); i++) { | ||
| imageParam.dim(i).set_bounds(0, dims[i]); | ||
| } | ||
| (*params)[imageParam.name()] = imageParam.parameter(); | ||
| } | ||
|
|
||
| void translateOutput( | ||
|
|
@@ -156,6 +158,8 @@ Expr translateExpr( | |
| return t(0) * t(1); | ||
| case '/': | ||
| return t(0) / t(1); | ||
| case '%': | ||
| return t(0) % t(1); | ||
| case lang::TK_MIN: | ||
| return min(t(0), t(1)); | ||
| case lang::TK_MAX: | ||
|
|
@@ -488,22 +492,25 @@ Expr reductionUpdate(Expr e) { | |
| return Call::make(e.type(), kReductionUpdate, {e}, Call::Intrinsic); | ||
| } | ||
|
|
||
| // Translate a single TC comprehension/statement to Halide components: funcs, | ||
| // bounds, reductions. | ||
| // | ||
| // Note that the function definitions created by translateComprehension may | ||
| // contain kReductionUpdate intrinsics. These may have to be removed | ||
| // in order to be able to apply internal Halide analysis passes on them. | ||
| void translateComprehension( | ||
| const lang::Comprehension& c, | ||
| const lang::Comprehension& comprehension, | ||
| const map<string, Parameter>& params, | ||
| bool throwWarnings, | ||
| map<string, Function>* funcs, | ||
| FunctionBounds* bounds) { | ||
| Function f; | ||
| auto it = funcs->find(c.ident().name()); | ||
| auto it = funcs->find(comprehension.ident().name()); | ||
| if (it != funcs->end()) { | ||
| f = it->second; | ||
| } else { | ||
| f = Function(c.ident().name()); | ||
| (*funcs)[c.ident().name()] = f; | ||
| f = Function(comprehension.ident().name()); | ||
| (*funcs)[comprehension.ident().name()] = f; | ||
| } | ||
| // Function is the internal Halide IR type for a pipeline | ||
| // stage. Func is the front-end class that wraps it. Here it's | ||
|
|
@@ -512,7 +519,7 @@ void translateComprehension( | |
|
|
||
| vector<Var> lhs; | ||
| vector<Expr> lhs_as_exprs; | ||
| for (lang::Ident id : c.indices()) { | ||
| for (lang::Ident id : comprehension.indices()) { | ||
| lhs.push_back(Var(id.name())); | ||
| lhs_as_exprs.push_back(lhs.back()); | ||
| } | ||
|
|
@@ -521,17 +528,17 @@ void translateComprehension( | |
| // in the future we may consider using Halide Let bindings when they | ||
| // are supported later | ||
| map<string, Expr> lets; | ||
| for (auto wc : c.whereClauses()) { | ||
| for (auto wc : comprehension.whereClauses()) { | ||
| if (wc->kind() == lang::TK_LET) { | ||
| auto let = lang::Let(wc); | ||
| lets[let.name().name()] = translateExpr(let.rhs(), params, *funcs, lets); | ||
| } | ||
| } | ||
|
|
||
| Expr rhs = translateExpr(c.rhs(), params, *funcs, lets); | ||
| Expr rhs = translateExpr(comprehension.rhs(), params, *funcs, lets); | ||
|
|
||
| std::vector<Expr> all_exprs; | ||
| for (auto wc : c.whereClauses()) { | ||
| for (auto wc : comprehension.whereClauses()) { | ||
| if (wc->kind() == lang::TK_EXISTS) { | ||
| all_exprs.push_back( | ||
| translateExpr(lang::Exists(wc).exp(), params, *funcs, lets)); | ||
|
|
@@ -555,7 +562,7 @@ void translateComprehension( | |
| // values (2) +=!, TK_PLUS_EQ_B which first sets the tensor to the identity | ||
| // for the reduction and then applies the reduction. | ||
| bool should_zero = false; | ||
| switch (c.assignment()->kind()) { | ||
| switch (comprehension.assignment()->kind()) { | ||
| case lang::TK_PLUS_EQ_B: | ||
| should_zero = true; // fallthrough | ||
| case lang::TK_PLUS_EQ: | ||
|
|
@@ -587,12 +594,13 @@ void translateComprehension( | |
| case '=': | ||
| break; | ||
| default: | ||
| throw lang::ErrorReport(c) << "Unimplemented reduction " | ||
| << c.assignment()->range().text() << "\n"; | ||
| throw lang::ErrorReport(comprehension) | ||
| << "Unimplemented reduction " | ||
| << comprehension.assignment()->range().text() << "\n"; | ||
| } | ||
|
|
||
| // Tag reductions as such | ||
| if (c.assignment()->kind() != '=') { | ||
| if (comprehension.assignment()->kind() != '=') { | ||
| rhs = reductionUpdate(rhs); | ||
| } | ||
|
|
||
|
|
@@ -632,7 +640,7 @@ void translateComprehension( | |
| Scope<Interval> solution; | ||
|
|
||
| // Put anything explicitly specified with a 'where' class in the solution | ||
| for (auto constraint_ : c.whereClauses()) { | ||
| for (auto constraint_ : comprehension.whereClauses()) { | ||
| if (constraint_->kind() != lang::TK_RANGE_CONSTRAINT) | ||
| continue; | ||
| auto constraint = lang::RangeConstraint(constraint_); | ||
|
|
@@ -653,7 +661,8 @@ void translateComprehension( | |
|
|
||
| // Infer the rest | ||
| all_exprs.push_back(rhs); | ||
| forwardBoundsInference(all_exprs, *bounds, c, throwWarnings, &solution); | ||
| forwardBoundsInference( | ||
| all_exprs, *bounds, comprehension, throwWarnings, &solution); | ||
|
|
||
| // TODO: What if subsequent updates have incompatible bounds | ||
| // (e.g. an in-place stencil)?. The .bound directive will use the | ||
|
|
@@ -664,7 +673,7 @@ void translateComprehension( | |
|
|
||
| for (Var v : lhs) { | ||
| if (!solution.contains(v.name())) { | ||
| throw lang::ErrorReport(c) | ||
| throw lang::ErrorReport(comprehension) | ||
| << "Free variable " << v | ||
| << " was not solved in range inference. May not be used right-hand side"; | ||
| } | ||
|
|
@@ -688,7 +697,7 @@ void translateComprehension( | |
| for (size_t i = 0; i < unbound.size(); i++) { | ||
| auto v = unbound[unbound.size() - 1 - i]; | ||
| if (!solution.contains(v->name)) { | ||
| throw lang::ErrorReport(c) | ||
| throw lang::ErrorReport(comprehension) | ||
| << "Free variable " << v << " is unconstrained. " | ||
| << "Use a 'where' clause to set its range."; | ||
| } | ||
|
|
@@ -736,6 +745,7 @@ void translateComprehension( | |
| stage.reorder(loop_nest); | ||
| } | ||
|
|
||
| // Translate a semantically checked TC def to HalideComponents struct. | ||
| HalideComponents translateDef(const lang::Def& def, bool throwWarnings) { | ||
| map<string, Function> funcs; | ||
| HalideComponents components; | ||
|
|
@@ -895,6 +905,8 @@ translate(isl::ctx ctx, const lang::TreeRef& treeRef, bool throwWarnings) { | |
| lang::Def(lang::Sema().checkFunction(treeRef)), throwWarnings); | ||
| } | ||
|
|
||
| // NOTE: there is no guarantee here that the tc string has only one def. It | ||
| // could have many defs. Only first def will be converted in that case. | ||
| HalideComponents | ||
| translate(isl::ctx ctx, const std::string& tc, bool throwWarnings) { | ||
| LOG_IF(INFO, tc::FLAGS_debug_halide) << tc; | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
there is no need for
elsehere since we return from theifcondition if satisfied. cc @abadams, hope this is okay?