Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 0f3f64b

Browse files
Merge pull request #457 from facebookresearch/pr/update
tc2halide: properly tag reduction updates
2 parents (12e5ae6 + 6e93aab); commit 0f3f64b

File tree

7 files changed: 82 additions and 54 deletions.

7 files changed: 82 additions and 54 deletions.

tc/core/halide2isl.cc

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -482,14 +482,6 @@ std::vector<Reduction> findReductions(const Stmt& s) {
482482
class FindReductions : public IRVisitor {
483483
using IRVisitor::visit;
484484

485-
bool isReductionInit(const Provide* op) {
486-
if (const Call* call = op->values[0].as<Call>()) {
487-
return call->is_intrinsic(tc2halide::kReductionInit);
488-
} else {
489-
return false;
490-
}
491-
}
492-
493485
bool isReductionUpdate(const Provide* op) {
494486
if (const Call* call = op->values[0].as<Call>()) {
495487
return call->is_intrinsic(tc2halide::kReductionUpdate);

tc/core/halide_utils.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,7 @@ std::string halideCodegenC(const Stmt& stmt) {
123123
using IRPrinter::visit;
124124

125125
void visit(const Call* op) override {
126-
if (op->is_intrinsic(tc2halide::kReductionInit) ||
127-
op->is_intrinsic(tc2halide::kReductionUpdate)) {
126+
if (op->is_intrinsic(tc2halide::kReductionUpdate)) {
128127
op->args[0].accept(this);
129128
} else if (
130129
op->call_type == Call::Halide || op->call_type == Call::Image) {

tc/core/polyhedral/codegen_llvm.cc

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -234,9 +234,7 @@ class CodeGen_TC : public Halide::Internal::CodeGen_X86 {
234234
auto addr = builder->CreateInBoundsGEP(baseAddr, args);
235235
value = builder->CreateLoad(addr);
236236
return;
237-
} else if (
238-
call->is_intrinsic(tc2halide::kReductionInit) ||
239-
call->is_intrinsic(tc2halide::kReductionUpdate)) {
237+
} else if (call->is_intrinsic(tc2halide::kReductionUpdate)) {
240238
call->args[0].accept(this);
241239
return;
242240
} else {

tc/core/polyhedral/cuda/codegen.cc

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -550,9 +550,7 @@ void emitHalideExpr(
550550
op->call_type == Halide::Internal::Call::CallType::Image) {
551551
tc::polyhedral::detail::emitMappedTensorAccess(
552552
op->name, op, op->args, context);
553-
} else if (
554-
op->is_intrinsic(tc2halide::kReductionInit) ||
555-
op->is_intrinsic(tc2halide::kReductionUpdate)) {
553+
} else if (op->is_intrinsic(tc2halide::kReductionUpdate)) {
556554
op->args[0].accept(this);
557555
} else {
558556
IRPrinter::visit(op);

tc/core/tc2halide.cc

Lines changed: 7 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -484,21 +484,19 @@ void forwardBoundsInference(
484484
}
485485
}
486486

487-
Expr reductionInit(Expr e) {
488-
return Call::make(e.type(), kReductionInit, {e}, Call::Intrinsic);
489-
}
490-
491487
Expr reductionUpdate(Expr e) {
492488
return Call::make(e.type(), kReductionUpdate, {e}, Call::Intrinsic);
493489
}
494490

491+
// Note that the function definitions created by translateComprehension may
492+
// contain kReductionUpdate intrinsics. These may have to be removed
493+
// in order to be able to apply internal Halide analysis passes on them.
495494
void translateComprehension(
496495
const lang::Comprehension& c,
497496
const map<string, Parameter>& params,
498497
bool throwWarnings,
499498
map<string, Function>* funcs,
500-
FunctionBounds* bounds,
501-
vector<Function>* reductions) {
499+
FunctionBounds* bounds) {
502500
Function f;
503501
auto it = funcs->find(c.ident().name());
504502
if (it != funcs->end()) {
@@ -593,8 +591,9 @@ void translateComprehension(
593591
<< c.assignment()->range().text() << "\n";
594592
}
595593

594+
// Tag reductions as such
596595
if (c.assignment()->kind() != '=') {
597-
reductions->push_back(f);
596+
rhs = reductionUpdate(rhs);
598597
}
599598

600599
// Bind any scalar params on the rhs to their parameter objects.
@@ -743,13 +742,12 @@ HalideComponents translateDef(const lang::Def& def, bool throwWarnings) {
743742
components.def = def;
744743
FunctionBounds bounds;
745744

746-
vector<Function> reductions;
747745
for (auto p : def.params()) {
748746
translateParam(p, &components.params, &components.inputs);
749747
}
750748
for (auto c : def.statements()) {
751749
translateComprehension(
752-
c, components.params, throwWarnings, &funcs, &bounds, &reductions);
750+
c, components.params, throwWarnings, &funcs, &bounds);
753751
}
754752
vector<Function> outputs;
755753
for (auto p : def.returns()) {
@@ -807,33 +805,6 @@ HalideComponents translateDef(const lang::Def& def, bool throwWarnings) {
807805
s = uniquify_variable_names(s);
808806
s = simplify(s);
809807

810-
// Tag reductions as such
811-
for (const Function& f : reductions) {
812-
class TagReduction : public IRMutator2 {
813-
using IRMutator2::visit;
814-
bool found_init = false;
815-
Stmt visit(const Provide* op) override {
816-
if (op->name == f.name()) {
817-
if (found_init) {
818-
return Provide::make(
819-
op->name, {reductionUpdate(op->values[0])}, op->args);
820-
} else {
821-
found_init = true;
822-
return Provide::make(
823-
op->name, {reductionInit(op->values[0])}, op->args);
824-
}
825-
} else {
826-
return op;
827-
}
828-
}
829-
const Function& f;
830-
831-
public:
832-
TagReduction(const Function& f) : f(f) {}
833-
} tagReduction(f);
834-
s = tagReduction.mutate(s);
835-
}
836-
837808
// Trim ProducerConsumer annotations. TC doesn't use them.
838809
class RemoveProducerConsumer : public IRMutator2 {
839810
using IRMutator2::visit;

tc/core/tc2halide.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,7 @@ struct HalideComponents {
4040
};
4141

4242
// For TC reductions, the right-hand-sides of the corresponding
43-
// Provide nodes are tagged with intrinsics with the following names.
44-
Halide::Internal::Call::ConstString kReductionInit = "ReductionInit";
43+
// Provide nodes are tagged with intrinsics with the following name.
4544
Halide::Internal::Call::ConstString kReductionUpdate = "ReductionUpdate";
4645

4746
// Translate a TC parse tree into equivalent Halide imperative IR with

test/test_cuda_mapper.cc

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -861,6 +861,77 @@ def fun(float(N) I) -> (O) {
861861
EXPECT_TRUE(code.find("O[0] = (O") != std::string::npos);
862862
}
863863

864+
struct ReductionTest : public PolyhedralMapperTest {
865+
static CudaMappingOptions reductionTestMappingOptions() {
866+
return DefaultOptions()
867+
.outerScheduleFusionStrategy(tc::FusionStrategy::Preserve3Coincident)
868+
.outerScheduleAllowSkewing(false)
869+
.outerSchedulePositiveOrthant(true)
870+
.intraTileScheduleFusionStrategy(tc::FusionStrategy::Min)
871+
.intraTileScheduleAllowSkewing(false)
872+
.intraTileSchedulePositiveOrthant(true)
873+
.fixParametersBeforeScheduling(false)
874+
.tile(18, 32)
875+
.unroll(16)
876+
.tileImperfectlyNested(false)
877+
.matchLibraryCalls(true)
878+
.mapToThreads({512})
879+
.mapToBlocks({16384})
880+
.useSharedMemory(true)
881+
.usePrivateMemory(false)
882+
.unrollCopyShared(true);
883+
}
884+
885+
void Check(const string& tc) {
886+
auto code = codegenMapped(tc, reductionTestMappingOptions());
887+
using tc::code::cuda::kCUBReductionName;
888+
EXPECT_TRUE(code.find(kCUBReductionName) != std::string::npos);
889+
}
890+
};
891+
892+
/*
893+
* Check that a reduction library call is produced when the reduction
894+
* instruction is before an instruction modifying the same tensor.
895+
*/
896+
TEST_F(ReductionTest, BeforeInstruction) {
897+
Check(R"TC(
898+
def fun(float(N, K) I) -> (O) {
899+
O(n) +=! I(n, r_n)
900+
O(n) = O(n) / (K)
901+
}
902+
)TC");
903+
}
904+
905+
/*
906+
* Check that a reduction library call is produced when the reduction
907+
* instruction is after an instruction modifying the same tensor.
908+
*/
909+
TEST_F(ReductionTest, AfterInstruction) {
910+
Check(R"TC(
911+
def fun(float(N, K) I, float(N) O0) -> (O) {
912+
O(n) = 0.0 where n in 0:N
913+
O(n) += O0(n)
914+
O(n) += I(n, r_n)
915+
}
916+
)TC");
917+
}
918+
919+
/*
920+
* Check that a reduction library call is produced when the reduction
921+
* instruction is placed after an instruction modifying the same tensor and
922+
* before an instruction modifying the same tensor.
923+
*/
924+
TEST_F(ReductionTest, BetweenInstructions) {
925+
Check(R"TC(
926+
def fun(float(N, K) I, float(N) O0) -> (O) {
927+
O(n) = 0.0 where n in 0:N
928+
O(n) += O0(n)
929+
O(n) += I(n, r_n)
930+
O(n) = O(n) / (K)
931+
}
932+
)TC");
933+
}
934+
864935
static const string kTcMM = R"TC(
865936
def fun(float(M, K) A, float(K, N) B) -> (C) {
866937
C(m, n) +=! A(m, r_k) * B(r_k, n)

0 commit comments

Comments (0)