@@ -225,7 +225,7 @@ bool promotionImprovesCoalescing(
     auto depth = marker->scheduleDepth(root);
     auto activePoints = activeDomainPoints(root, mapping);
     auto localAccesses = originalAccesses.intersect_domain(activePoints);
-    auto schedule = prefixSchedule(root, marker);
+    auto schedule = prefixSchedule<Prefix>(root, marker);
     auto scheduledAccesses = localAccesses.apply_domain(schedule);
     for (auto access : isl::UnionAsVector<isl::union_map>(scheduledAccesses)) {
       auto scheduleSpace = access.get_space().domain();
@@ -262,6 +262,8 @@ isl::union_set collectMappingsTo(const Scop& scop) {
   return mapping;
 }
 
+struct Unrolled;
+
 /*
  * Check that only unrolled loops may appear in access subscripts.
  * Because the scoping point can be above a branching tree, descend into each
@@ -292,11 +294,12 @@ isl::union_set collectMappingsTo(const Scop& scop) {
  * different references may have different values, but all of them remain
  * independent of non-unrolled loop iterators.
  */
+template <typename Outer>
 bool accessSubscriptsAreUnrolledLoops(
     const TensorReferenceGroup& group,
     const detail::ScheduleTree* root,
     const detail::ScheduleTree* scope,
-    isl::multi_union_pw_aff outerSchedule) {
+    isl::MultiUnionPwAff<Statement, Outer> outerSchedule) {
   using namespace detail;
 
   auto nodes = ScheduleTree::collect(scope);
@@ -315,7 +318,7 @@ bool accessSubscriptsAreUnrolledLoops(
 
     auto unrolledDims = isl::union_pw_aff_list(leaf->ctx_, 1);
     for (auto node : ancestors) {
-      auto band = node->as<detail::ScheduleTreeBand>();
+      auto band = node->template as<detail::ScheduleTreeBand>();
       if (!band) {
         continue;
       }
@@ -331,8 +334,9 @@ bool accessSubscriptsAreUnrolledLoops(
     }
 
     auto space =
-        subdomain.get_space().add_unnamed_tuple_ui(unrolledDims.size());
-    auto unrolledDimsMupa = isl::multi_union_pw_aff(space, unrolledDims);
+        subdomain.get_space().template add_unnamed_tuple_ui<Unrolled>(unrolledDims.size());
+    auto unrolledDimsMupa = isl::MultiUnionPwAff<Statement, Unrolled>(
+        space, isl::UnionPwAffListOn<Statement>(unrolledDims));
 
     // It is possible that no loops are unrolled, in which case
     // unrolledDimsMupa is zero-dimensional and needs an explicit domain
@@ -341,10 +345,11 @@ bool accessSubscriptsAreUnrolledLoops(
         unrolledDimsMupa.intersect_domain(group.originalAccesses().domain());
 
     auto accesses = group.originalAccesses();
-    auto schedule = outerSchedule.flat_range_product(unrolledDimsMupa);
-    accesses = accesses.apply_domain(isl::union_map::from(schedule));
+    auto schedule = outerSchedule.range_product(unrolledDimsMupa);
+    auto scheduleMap = schedule.toUnionMap();
+    auto scheduledAccesses = accesses.apply_domain(scheduleMap);
 
-    if (!accesses.is_single_valued()) {
+    if (!scheduledAccesses.is_single_valued()) {
       return false;
     }
   }
@@ -364,23 +369,25 @@ bool accessSubscriptsAreUnrolledLoops(
  * thread associated to a given pair of tensor element and outer schedule
  * iteration.
  */
+template <typename Outer>
 bool isPromotableToRegistersBelow(
     const TensorReferenceGroup& group,
     const detail::ScheduleTree* root,
     const detail::ScheduleTree* scope,
-    isl::multi_union_pw_aff outer,
-    isl::multi_union_pw_aff thread) {
+    isl::MultiUnionPwAff<Statement, Outer> outer,
+    isl::MultiUnionPwAff<Statement, Thread> thread) {
   if (!accessSubscriptsAreUnrolledLoops(
-          group, root, scope, outer.flat_range_product(thread))) {
+          group, root, scope, outer.range_product(thread))) {
     return false;
   }
 
   auto originalAccesses = group.originalAccesses();
-  auto map = isl::union_map::from(outer);
-  map = map.range_product(originalAccesses);
-  map = map.apply_domain(isl::union_map::from(thread));
+  auto outerMap = isl::UnionMap<Statement, Outer>::from(outer);
+  auto pair = outerMap.range_product(originalAccesses);
+  auto threadMap = isl::UnionMap<Statement, Thread>::from(thread);
+  auto threadToPair = pair.apply_domain(threadMap);
 
-  return map.is_injective();
+  return threadToPair.is_injective();
 }
 
 /*
@@ -653,15 +660,15 @@ void promoteToRegistersBelow(MappedScop& mscop, detail::ScheduleTree* scope) {
   auto blockSchedule = mscop.blockMappingSchedule(mscop.schedule());
 
   // Pure affine schedule without (mapping) filters.
-  auto partialSchedMupa = partialScheduleMupa(root, scope);
+  auto partialSchedMupa = partialScheduleMupa<Scope>(root, scope);
   // Schedule with block mapping filter.
   auto partialSched =
       isl::union_map::from(partialSchedMupa).intersect_domain(blockMapping);
   // The following promotion validity and profitability checks need to be
   // performed with respect to the block mapping, so append the block schedule.
   // If the partial schedule contains it already, it will just end up with
   // identical dimensions without affecting the result of the checks.
-  partialSchedMupa = partialSchedMupa.flat_range_product(blockSchedule);
+  auto partialSchedBlockMupa = partialSchedMupa.range_product(blockSchedule);
 
   for (auto& tensorGroups : groupMap) {
     auto tensorId = tensorGroups.first;
@@ -675,11 +682,11 @@ void promoteToRegistersBelow(MappedScop& mscop, detail::ScheduleTree* scope) {
         continue;
       }
       if (!isPromotableToRegistersBelow(
-              *group, root, scope, partialSchedMupa, threadSchedule)) {
+              *group, root, scope, partialSchedBlockMupa, threadSchedule)) {
        continue;
       }
       // Check reuse within threads.
-      auto schedule = partialSchedMupa.flat_range_product(threadSchedule);
+      auto schedule = partialSchedBlockMupa.flat_range_product(threadSchedule);
       if (!hasReuseWithin(*group, schedule)) {
         continue;
       }
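
The recurring pattern in these hunks is the switch from untyped isl objects (isl::multi_union_pw_aff, isl::union_map) to wrappers parameterized by domain and range tags (isl::MultiUnionPwAff<Statement, Outer>, isl::UnionMap<Statement, Thread>), so that compositions such as apply_domain, range_product, and the final is_injective check are verified against the intended spaces at compile time. Below is a minimal self-contained sketch of that phantom-tag idea only; TypedMap, its members, and the tag structs are invented for illustration and are not the real isl or Tensor Comprehensions types.

// Toy sketch of phantom domain/range tags on a relation wrapper.
// Not the real isl or Tensor Comprehensions API.
#include <iostream>
#include <map>
#include <set>

struct Statement {};  // hypothetical tag: statement instances
struct Thread {};     // hypothetical tag: thread identifiers
struct Tensor {};     // hypothetical tag: tensor elements

// A relation whose domain and range "spaces" are tracked as template tags,
// so mismatched compositions fail to compile instead of failing at runtime.
template <typename Domain, typename Range>
struct TypedMap {
  std::map<int, int> rel;  // stand-in for the underlying untyped relation

  // Mirror of apply_domain: remap the domain of (Domain -> Range) through a
  // (Domain -> NewDomain) relation, producing (NewDomain -> Range).
  template <typename NewDomain>
  TypedMap<NewDomain, Range> apply_domain(
      const TypedMap<Domain, NewDomain>& other) const {
    TypedMap<NewDomain, Range> result;
    for (const auto& [d, r] : rel) {
      auto it = other.rel.find(d);
      if (it != other.rel.end()) {
        result.rel[it->second] = r;
      }
    }
    return result;
  }

  // Mirror of is_injective: no two domain points map to the same range point.
  bool is_injective() const {
    std::set<int> seen;
    for (const auto& kv : rel) {
      if (!seen.insert(kv.second).second) {
        return false;
      }
    }
    return true;
  }
};

int main() {
  // Accesses: statement instances to tensor elements.
  TypedMap<Statement, Tensor> accesses{{{0, 100}, {1, 101}}};
  // Thread mapping: statement instances to thread identifiers.
  TypedMap<Statement, Thread> threadMap{{{0, 0}, {1, 1}}};

  // OK: tags line up, the result relates threads to tensor elements.
  auto threadToTensor = accesses.apply_domain(threadMap);
  std::cout << std::boolalpha << threadToTensor.is_injective() << "\n";

  // accesses.apply_domain(threadToTensor);  // would not compile: tag mismatch
}

In the same spirit, the diff swaps flat_range_product for range_product in the typed code paths, presumably so that the result keeps a structured pair range type instead of a flattened, untagged one.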