Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions tc/core/halide2isl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ using namespace tc::polyhedral::detail;

SymbolTable makeSymbolTable(const tc2halide::HalideComponents& components) {
// const Stmt& s) {
// Collect and categorize all the Variable symbols
// Collect and categorize all the Halide Variable symbols as reduction
// or index variables
class BuildSymbolTable : public IRVisitor {
using IRVisitor::visit;
std::set<std::string> included;
Expand All @@ -59,7 +60,7 @@ SymbolTable makeSymbolTable(const tc2halide::HalideComponents& components) {

components.stmt.accept(&builder);
// Get params from components.params which contain everything declared in
// tcdef. However, the 0-D tensors are registered as both params and inputs,
// TC Def. However, the 0-D tensors are registered as both params and inputs,
// filter those out.
for (auto kvp : components.params) {
bool skip = false;
Expand Down Expand Up @@ -202,7 +203,7 @@ std::vector<isl::aff> makeIslAffBoundsFromExpr(
std::vector<isl::aff> result;
// We cannot span multiple constraints if a modulo operation is involved.
// x > max(a,b) % C is not equivalent to (x > a % C && x > b % C).
auto lhs = makeIslAffBoundsFromExpr(space, e, false, false);
auto lhs = makeIslAffBoundsFromExpr(space, op->a, false, false);
CHECK_EQ(lhs.size(), 1u);
if (const int64_t* b = as_const_int(op->b)) {
return {lhs[0].mod(isl::val(space.get_ctx(), *b))};
Expand Down
96 changes: 54 additions & 42 deletions tc/core/tc2halide.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,43 +62,45 @@ Type translateScalarType(int tcType) {
}
}

// Translate the TC def input params to corresponding Halide components.
// params, inputs will be populated here.
void translateParam(
const lang::Param& p,
map<string, Parameter>* params,
vector<ImageParam>* inputs) {
// Check if the param has already been converted to halide components.
if (params->find(p.ident().name()) != params->end()) {
return;
} else {
lang::TensorType type = p.tensorType();
int dimensions = (int)type.dims().size();
ImageParam imageParam(
translateScalarType(type.scalarType()), dimensions, p.ident().name());
inputs->push_back(imageParam);
vector<Expr> dims;
for (auto d_ : type.dims()) {
if (d_->kind() == lang::TK_IDENT) {
auto d = lang::Ident(d_);
auto it = params->find(d.name());
Parameter p;
if (it != params->end()) {
p = it->second;
} else {
p = Parameter(Int(32), false, 0, d.name(), true);
(*params)[d.name()] = p;
}
dims.push_back(Variable::make(Int(32), p.name(), p));
}
lang::TensorType type = p.tensorType();
Copy link
Contributor Author

@prigoyal prigoyal Apr 30, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there is no need for else here since we return from the if condition if satisfied. cc @abadams, hope this is okay?

int dimensions = (int)type.dims().size();
ImageParam imageParam(
translateScalarType(type.scalarType()), dimensions, p.ident().name());
inputs->push_back(imageParam);
vector<Expr> dims;
for (auto d_ : type.dims()) {
if (d_->kind() == lang::TK_IDENT) {
auto d = lang::Ident(d_);
auto it = params->find(d.name());
Parameter p;
if (it != params->end()) {
p = it->second;
} else {
CHECK(d_->kind() == lang::TK_CONST);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This hoisting looks like a behavior change to me, could you please explain why it is necessary or correlated with this commit message?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for the review. I am not sure I understand. There is no code change to this section except removing unnecessary else after line 74.

line 88 - 98 in the green section matches line 85 - 95 in red section. I think that github alignment for code changes is not exactly good and hence causing confusion.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

diff computation alogrithm is not perfect; reviewing changes per-commit helps

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah yes, you're right, I see there are 2 columns of numbers. I learned something today! Sorry for the noise here

int32_t value = lang::Const(d_).value();
dims.push_back(Expr(value));
p = Parameter(Int(32), false, 0, d.name(), true);
(*params)[d.name()] = p;
}
dims.push_back(Variable::make(Int(32), p.name(), p));
} else {
CHECK(d_->kind() == lang::TK_CONST);
int32_t value = lang::Const(d_).value();
dims.push_back(Expr(value));
}
}

for (int i = 0; i < imageParam.dimensions(); i++) {
imageParam.dim(i).set_bounds(0, dims[i]);
}
(*params)[imageParam.name()] = imageParam.parameter();
for (int i = 0; i < imageParam.dimensions(); i++) {
imageParam.dim(i).set_bounds(0, dims[i]);
}
(*params)[imageParam.name()] = imageParam.parameter();
}

void translateOutput(
Expand Down Expand Up @@ -156,6 +158,8 @@ Expr translateExpr(
return t(0) * t(1);
case '/':
return t(0) / t(1);
case '%':
return t(0) % t(1);
case lang::TK_MIN:
return min(t(0), t(1));
case lang::TK_MAX:
Expand Down Expand Up @@ -488,22 +492,25 @@ Expr reductionUpdate(Expr e) {
return Call::make(e.type(), kReductionUpdate, {e}, Call::Intrinsic);
}

// Translate a single TC comprehension/statement to Halide components: funcs,
// bounds, reductions.
//
// Note that the function definitions created by translateComprehension may
// contain kReductionUpdate intrinsics. These may have to be removed
// in order to be able to apply internal Halide analysis passes on them.
void translateComprehension(
const lang::Comprehension& c,
const lang::Comprehension& comprehension,
const map<string, Parameter>& params,
bool throwWarnings,
map<string, Function>* funcs,
FunctionBounds* bounds) {
Function f;
auto it = funcs->find(c.ident().name());
auto it = funcs->find(comprehension.ident().name());
if (it != funcs->end()) {
f = it->second;
} else {
f = Function(c.ident().name());
(*funcs)[c.ident().name()] = f;
f = Function(comprehension.ident().name());
(*funcs)[comprehension.ident().name()] = f;
}
// Function is the internal Halide IR type for a pipeline
// stage. Func is the front-end class that wraps it. Here it's
Expand All @@ -512,7 +519,7 @@ void translateComprehension(

vector<Var> lhs;
vector<Expr> lhs_as_exprs;
for (lang::Ident id : c.indices()) {
for (lang::Ident id : comprehension.indices()) {
lhs.push_back(Var(id.name()));
lhs_as_exprs.push_back(lhs.back());
}
Expand All @@ -521,17 +528,17 @@ void translateComprehension(
// in the future we may consider using Halide Let bindings when they
// are supported later
map<string, Expr> lets;
for (auto wc : c.whereClauses()) {
for (auto wc : comprehension.whereClauses()) {
if (wc->kind() == lang::TK_LET) {
auto let = lang::Let(wc);
lets[let.name().name()] = translateExpr(let.rhs(), params, *funcs, lets);
}
}

Expr rhs = translateExpr(c.rhs(), params, *funcs, lets);
Expr rhs = translateExpr(comprehension.rhs(), params, *funcs, lets);

std::vector<Expr> all_exprs;
for (auto wc : c.whereClauses()) {
for (auto wc : comprehension.whereClauses()) {
if (wc->kind() == lang::TK_EXISTS) {
all_exprs.push_back(
translateExpr(lang::Exists(wc).exp(), params, *funcs, lets));
Expand All @@ -555,7 +562,7 @@ void translateComprehension(
// values (2) +=!, TK_PLUS_EQ_B which first sets the tensor to the identity
// for the reduction and then applies the reduction.
bool should_zero = false;
switch (c.assignment()->kind()) {
switch (comprehension.assignment()->kind()) {
case lang::TK_PLUS_EQ_B:
should_zero = true; // fallthrough
case lang::TK_PLUS_EQ:
Expand Down Expand Up @@ -587,12 +594,13 @@ void translateComprehension(
case '=':
break;
default:
throw lang::ErrorReport(c) << "Unimplemented reduction "
<< c.assignment()->range().text() << "\n";
throw lang::ErrorReport(comprehension)
<< "Unimplemented reduction "
<< comprehension.assignment()->range().text() << "\n";
}

// Tag reductions as such
if (c.assignment()->kind() != '=') {
if (comprehension.assignment()->kind() != '=') {
rhs = reductionUpdate(rhs);
}

Expand Down Expand Up @@ -632,7 +640,7 @@ void translateComprehension(
Scope<Interval> solution;

// Put anything explicitly specified with a 'where' class in the solution
for (auto constraint_ : c.whereClauses()) {
for (auto constraint_ : comprehension.whereClauses()) {
if (constraint_->kind() != lang::TK_RANGE_CONSTRAINT)
continue;
auto constraint = lang::RangeConstraint(constraint_);
Expand All @@ -653,7 +661,8 @@ void translateComprehension(

// Infer the rest
all_exprs.push_back(rhs);
forwardBoundsInference(all_exprs, *bounds, c, throwWarnings, &solution);
forwardBoundsInference(
all_exprs, *bounds, comprehension, throwWarnings, &solution);

// TODO: What if subsequent updates have incompatible bounds
// (e.g. an in-place stencil)?. The .bound directive will use the
Expand All @@ -664,7 +673,7 @@ void translateComprehension(

for (Var v : lhs) {
if (!solution.contains(v.name())) {
throw lang::ErrorReport(c)
throw lang::ErrorReport(comprehension)
<< "Free variable " << v
<< " was not solved in range inference. May not be used right-hand side";
}
Expand All @@ -688,7 +697,7 @@ void translateComprehension(
for (size_t i = 0; i < unbound.size(); i++) {
auto v = unbound[unbound.size() - 1 - i];
if (!solution.contains(v->name)) {
throw lang::ErrorReport(c)
throw lang::ErrorReport(comprehension)
<< "Free variable " << v << " is unconstrained. "
<< "Use a 'where' clause to set its range.";
}
Expand Down Expand Up @@ -736,6 +745,7 @@ void translateComprehension(
stage.reorder(loop_nest);
}

// Translate a semantically checked TC def to HalideComponents struct.
HalideComponents translateDef(const lang::Def& def, bool throwWarnings) {
map<string, Function> funcs;
HalideComponents components;
Expand Down Expand Up @@ -895,6 +905,8 @@ translate(isl::ctx ctx, const lang::TreeRef& treeRef, bool throwWarnings) {
lang::Def(lang::Sema().checkFunction(treeRef)), throwWarnings);
}

// NOTE: there is no guarantee here that the tc string has only one def. It
// could have many defs. Only first def will be converted in that case.
HalideComponents
translate(isl::ctx ctx, const std::string& tc, bool throwWarnings) {
LOG_IF(INFO, tc::FLAGS_debug_halide) << tc;
Expand Down
4 changes: 2 additions & 2 deletions tc/core/tc2halide.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ namespace tc2halide {
// of the input and output tensors. We do not explicitly enumerate the
// scalar params.
struct HalideComponents {
lang::TreeRef
def; // post-semantic analaysis tree, used for later error reporting
// post-semantic analaysis tree, used for later error reporting
lang::TreeRef def;
Halide::Internal::Stmt stmt;
std::vector<Halide::ImageParam> inputs;
std::map<std::string, Halide::Internal::Parameter> params;
Expand Down
4 changes: 2 additions & 2 deletions tc/lang/lexer.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ namespace lang {
_(TK_LET, "let", "") \
_(TK_EXISTS, "exists", "exists")

static const char* valid_single_char_tokens = "+-*/()[]?:,={}><!";
static const char* valid_single_char_tokens = "+-*/()[]?:,={}><!%";

enum TokenKind {
// we use characters to represent themselves so skip all valid characters
Expand Down Expand Up @@ -137,7 +137,7 @@ struct SharedParserData {
{TK_AND},
{'>', '<', TK_LE, TK_GE, TK_EQ, TK_NE},
{'+', '-'},
{'*', '/'},
{'*', '/', '%'},
};
std::vector<std::vector<int>> unary_ops = {
{'-', '!'},
Expand Down
Loading