Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit e75086d

Browse files
Merge pull request #578 from facebookresearch/pr/pre-template
prepare for templated isl types
2 parents 03512da + 9904d30 commit e75086d

File tree

9 files changed

+100
-32
lines changed

9 files changed

+100
-32
lines changed

tc/core/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ add_library(
3434
polyhedral/scop.cc
3535
polyhedral/separation.cc
3636
polyhedral/unroll.cc
37+
polyhedral/utils.cc
3738
)
3839
target_include_directories(tc_core PUBLIC ${LLVM_INCLUDE_DIRS})
3940
target_link_libraries(

tc/core/halide2isl.cc

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "tc/core/polyhedral/schedule_isl_conversion.h"
2525
#include "tc/core/polyhedral/schedule_transforms.h"
2626
#include "tc/core/polyhedral/schedule_tree.h"
27+
#include "tc/core/polyhedral/utils.h"
2728
#include "tc/core/tc2halide.h"
2829

2930
namespace tc {
@@ -259,7 +260,8 @@ isl::map extractAccess(
259260

260261
isl::space paramSpace = domain.paramSpace;
261262
isl::id tensorID(paramSpace.get_ctx(), tensor);
262-
auto tensorSpace = paramSpace.add_named_tuple_id_ui(tensorID, args.size());
263+
auto tensorTuple = constructTensorTuple(paramSpace, tensorID, args.size());
264+
auto tensorSpace = tensorTuple.get_space();
263265

264266
// Start with a totally unconstrained set - every point in
265267
// the allocation could be accessed.
@@ -275,9 +277,10 @@ isl::map extractAccess(
275277
// The coordinate written to in the range ...
276278
auto rangePoint = identity.get_aff(i);
277279
// ... equals the coordinate accessed as a function of the parameters.
278-
auto domainPoint = halide2isl::makeIslAffFromExpr(tensorSpace, args[i]);
280+
auto domainPoint = halide2isl::makeIslAffFromExpr(paramSpace, args[i]);
279281
if (!domainPoint.is_null()) {
280-
access = access.intersect(isl::pw_aff(domainPoint).eq_set(rangePoint));
282+
domainPoint = domainPoint.unbind_params_insert_domain(tensorTuple);
283+
access = access.intersect(domainPoint.eq_set(rangePoint));
281284
}
282285
}
283286

tc/core/polyhedral/cuda/mapped_scop.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -573,7 +573,7 @@ Scop::SyncLevel MappedScop::findBestSync(
573573
auto contextSt = scop_->scheduleRoot()->children()[0];
574574
auto contextElem = contextSt->as<detail::ScheduleTreeContext>();
575575
TC_CHECK(nullptr != contextElem);
576-
dependences = dependences.intersect_params(contextElem->context_);
576+
dependences = dependences.intersect_params(contextElem->context_.params());
577577

578578
if (dependences.is_subset(dependences.eq_at(domainToThread))) {
579579
return Scop::SyncLevel::None;

tc/core/polyhedral/memory_promotion.cc

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include "tc/core/polyhedral/exceptions.h"
2626
#include "tc/core/polyhedral/schedule_tree.h"
2727
#include "tc/core/polyhedral/scop.h"
28+
#include "tc/core/polyhedral/utils.h"
2829
#include "tc/external/isl.h"
2930

3031
namespace tc {
@@ -409,17 +410,20 @@ namespace {
409410
// context of the scop.
410411
isl::set tensorElementsSet(const Scop& scop, isl::id tensorId) {
411412
auto halideParameter = scop.findArgument(tensorId).parameter();
412-
auto space = scop.domain().get_space().params();
413+
auto space = scop.domain().get_space();
413414
auto nDim = halideParameter.dimensions();
414-
space = space.add_named_tuple_id_ui(tensorId, nDim);
415+
auto tensorTuple = constructTensorTuple(space, tensorId, nDim);
416+
auto tensorSpace = tensorTuple.get_space();
415417

416-
auto tensorElements = isl::set::universe(space);
417-
auto identity = isl::multi_aff::identity(space.range().map_from_set());
418+
auto tensorElements = isl::set::universe(tensorSpace);
419+
auto identity = isl::multi_aff::identity(tensorSpace.map_from_set());
418420
for (int i = 0; i < nDim; ++i) {
419421
auto minAff = halide2isl::makeIslAffFromExpr(
420422
space, halideParameter.min_constraint(i));
421423
auto extentAff = halide2isl::makeIslAffFromExpr(
422424
space, halideParameter.extent_constraint(i));
425+
minAff = minAff.unbind_params_insert_domain(tensorTuple);
426+
extentAff = extentAff.unbind_params_insert_domain(tensorTuple);
423427
auto aff = identity.get_aff(i);
424428
tensorElements = tensorElements & (minAff <= isl::aff_set(aff)) &
425429
(isl::aff_set(aff) < (minAff + extentAff));

tc/core/polyhedral/scop.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ struct Scop {
136136
isl::set makeContext(
137137
const std::unordered_map<std::string, T>& sizes =
138138
std::unordered_map<std::string, T>()) const {
139-
auto s = domain().get_space().params();
139+
auto s = domain().get_space();
140140
return makeSpecializationSet(s, sizes);
141141
}
142142

tc/core/polyhedral/utils.cc

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
/**
2+
* Copyright (c) 2018, Facebook, Inc.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
#include "tc/core/polyhedral/utils.h"
17+
18+
namespace tc {
19+
namespace polyhedral {
20+
21+
/* Construct a tuple representing the tensor with identifier "tensorId" and
22+
* dimension "dim" from the parameter space "paramSpace",
23+
* without any specific names for the indices, from the perspective
24+
* of the user.
25+
* Since some names are required, use names of the form "__tc_tensor_arg*".
26+
*/
27+
isl::multi_id
28+
constructTensorTuple(isl::space paramSpace, isl::id tensorId, size_t dim) {
29+
auto tensorSpace = paramSpace.add_named_tuple_id_ui(tensorId, dim);
30+
isl::id_list tensorArgs(paramSpace.get_ctx(), 0);
31+
for (size_t i = 0; i < dim; ++i) {
32+
auto name = std::string("__tc_tensor_arg") + std::to_string(i);
33+
tensorArgs = tensorArgs.add(isl::id(paramSpace.get_ctx(), name));
34+
}
35+
return isl::multi_id(tensorSpace, tensorArgs);
36+
}
37+
38+
} // namespace polyhedral
39+
} // namespace tc

tc/core/polyhedral/utils.h

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
/**
2+
* Copyright (c) 2018, Facebook, Inc.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
#pragma once
17+
18+
#include "tc/external/isl.h"
19+
20+
namespace tc {
21+
namespace polyhedral {
22+
23+
/* Construct a tuple representing the tensor with identifier "tensorId" and
24+
* dimension "dim" from the parameter space "paramSpace",
25+
* without any specific names for the indices.
26+
*/
27+
isl::multi_id
28+
constructTensorTuple(isl::space paramSpace, isl::id tensorId, size_t dim);
29+
30+
} // namespace polyhedral
31+
} // namespace tc

tc/external/detail/islpp-inl.h

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,9 @@ inline isl::aff operator/(isl::aff A, int i) {
4141
return A.scale_down(isl::val(A.get_ctx(), i));
4242
}
4343

44-
inline isl::aff operator+(int i, isl::aff A) {
45-
isl::ctx ctx = A.get_ctx();
46-
return A + isl::val(ctx, i);
44+
template <typename T>
45+
inline isl::aff operator+(int i, T A) {
46+
return A.add_constant_si(i);
4747
}
4848

4949
inline isl::aff operator+(isl::aff A, isl::val v) {
@@ -55,10 +55,6 @@ inline isl::aff operator+(isl::val v, isl::aff A) {
5555
return A + v;
5656
}
5757

58-
inline isl::aff operator+(isl::aff A, isl::aff B) {
59-
return A.add(B);
60-
}
61-
6258
inline isl::aff operator+(isl::aff A, int i) {
6359
return i + A;
6460
}
@@ -195,10 +191,6 @@ inline isl::multi_aff operator/(isl::multi_aff left, isl::multi_val right) {
195191
///////////////////////////////////////////////////////////////////////////////
196192
// Operations on isl::set and isl::union_set
197193
///////////////////////////////////////////////////////////////////////////////
198-
inline isl::set operator&(isl::set S1, isl::set S2) {
199-
return S1.intersect(S2);
200-
}
201-
202194
inline isl::union_set operator&(isl::union_set S1, isl::set S2) {
203195
return S1.intersect(isl::union_set(S2));
204196
}
@@ -207,10 +199,6 @@ inline isl::union_set operator&(isl::set S1, isl::union_set S2) {
207199
return S2 & S1;
208200
}
209201

210-
inline isl::union_set operator&(isl::union_set S1, isl::union_set S2) {
211-
return S1.intersect(S2);
212-
}
213-
214202
///////////////////////////////////////////////////////////////////////////////
215203
// Operations on isl::set and isl::point
216204
///////////////////////////////////////////////////////////////////////////////

tc/external/detail/islpp.h

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,21 @@ namespace isl {
3535
// the official ISL C++ bindings.
3636
//
3737

38+
template <typename T>
39+
inline T operator+(T a, T b) {
40+
return a.add(b);
41+
}
42+
3843
template <typename T>
3944
inline T operator-(T a, T b) {
4045
return a.sub(b);
4146
}
4247

48+
template <typename T>
49+
inline T operator&(T S1, T S2) {
50+
return S1.intersect(S2);
51+
}
52+
4353
inline isl::val operator*(isl::val l, isl::val r) {
4454
return l.mul(r);
4555
}
@@ -52,10 +62,6 @@ inline isl::val operator*(long i, isl::val v) {
5262
return v * i;
5363
}
5464

55-
inline isl::val operator+(isl::val l, isl::val r) {
56-
return l.add(r);
57-
}
58-
5965
inline isl::val operator+(isl::val v, long i) {
6066
return v.add(isl::val(v.get_ctx(), i));
6167
}
@@ -122,8 +128,6 @@ isl::aff operator*(isl::val V, isl::aff A);
122128

123129
isl::aff operator/(isl::aff A, int i);
124130

125-
isl::aff operator+(int i, isl::aff A);
126-
isl::aff operator+(isl::aff A, isl::aff B);
127131
isl::aff operator+(isl::aff A, int i);
128132
isl::aff operator+(isl::aff A, isl::val v);
129133
isl::aff operator+(isl::val v, isl::aff A);
@@ -184,10 +188,8 @@ isl::multi_aff operator/(isl::multi_aff left, isl::multi_val right);
184188
///////////////////////////////////////////////////////////////////////////////
185189
// Operations on isl::set and isl::union_set
186190
///////////////////////////////////////////////////////////////////////////////
187-
isl::set operator&(isl::set S1, isl::set S2);
188191
isl::union_set operator&(isl::union_set S1, isl::set S2);
189192
isl::union_set operator&(isl::set S1, isl::union_set S2);
190-
isl::union_set operator&(isl::union_set S1, isl::union_set S2);
191193

192194
///////////////////////////////////////////////////////////////////////////////
193195
// Operations on isl::set and isl::point

0 commit comments

Comments
 (0)