@@ -225,7 +225,7 @@ bool promotionImprovesCoalescing(
225225 auto depth = marker->scheduleDepth (root);
226226 auto activePoints = activeDomainPoints (root, mapping);
227227 auto localAccesses = originalAccesses.intersect_domain (activePoints);
228- auto schedule = prefixSchedule (root, marker);
228+ auto schedule = prefixSchedule<Prefix> (root, marker);
229229 auto scheduledAccesses = localAccesses.apply_domain (schedule);
230230 for (auto access : isl::UnionAsVector<isl::union_map>(scheduledAccesses)) {
231231 auto scheduleSpace = access.get_space ().domain ();
@@ -262,6 +262,8 @@ isl::union_set collectMappingsTo(const Scop& scop) {
262262 return mapping;
263263}
264264
265+ struct Unrolled ;
266+
265267/*
266268 * Check that only unrolled loops may appear in access subscripts.
267269 * Because the scoping point can be above a branching tree, descend into each
@@ -292,11 +294,12 @@ isl::union_set collectMappingsTo(const Scop& scop) {
292294 * different references may have different values, but all of them remain
293295 * independent of non-unrolled loop iterators.
294296 */
297+ template <typename Outer>
295298bool accessSubscriptsAreUnrolledLoops (
296299 const TensorReferenceGroup& group,
297300 const detail::ScheduleTree* root,
298301 const detail::ScheduleTree* scope,
299- isl::multi_union_pw_aff outerSchedule) {
302+ isl::MultiUnionPwAff<Statement, Outer> outerSchedule) {
300303 using namespace detail ;
301304
302305 auto nodes = ScheduleTree::collect (scope);
@@ -315,7 +318,7 @@ bool accessSubscriptsAreUnrolledLoops(
315318
316319 auto unrolledDims = isl::union_pw_aff_list (leaf->ctx_ , 1 );
317320 for (auto node : ancestors) {
318- auto band = node->as <detail::ScheduleTreeBand>();
321+ auto band = node->template as <detail::ScheduleTreeBand>();
319322 if (!band) {
320323 continue ;
321324 }
@@ -331,8 +334,9 @@ bool accessSubscriptsAreUnrolledLoops(
331334 }
332335
333336 auto space =
334- subdomain.get_space ().add_unnamed_tuple_ui (unrolledDims.size ());
335- auto unrolledDimsMupa = isl::multi_union_pw_aff (space, unrolledDims);
337+ subdomain.get_space ().template add_unnamed_tuple_ui <Unrolled>(unrolledDims.size ());
338+ auto unrolledDimsMupa = isl::MultiUnionPwAff<Statement, Unrolled>(
339+ space, isl::UnionPwAffListOn<Statement>(unrolledDims));
336340
337341 // It is possible that no loops are unrolled, in which case
338342 // unrolledDimsMupa is zero-dimensional and needs an explicit domain
@@ -341,10 +345,11 @@ bool accessSubscriptsAreUnrolledLoops(
341345 unrolledDimsMupa.intersect_domain (group.originalAccesses ().domain ());
342346
343347 auto accesses = group.originalAccesses ();
344- auto schedule = outerSchedule.flat_range_product (unrolledDimsMupa);
345- accesses = accesses.apply_domain (isl::union_map::from (schedule));
348+ auto schedule = outerSchedule.range_product (unrolledDimsMupa);
349+ auto scheduleMap = schedule.toUnionMap ();
350+ auto scheduledAccesses = accesses.apply_domain (scheduleMap);
346351
347- if (!accesses .is_single_valued ()) {
352+ if (!scheduledAccesses .is_single_valued ()) {
348353 return false ;
349354 }
350355 }
@@ -364,23 +369,25 @@ bool accessSubscriptsAreUnrolledLoops(
364369 * thread associated to a given pair of tensor element and outer schedule
365370 * iteration.
366371 */
372+ template <typename Outer>
367373bool isPromotableToRegistersBelow (
368374 const TensorReferenceGroup& group,
369375 const detail::ScheduleTree* root,
370376 const detail::ScheduleTree* scope,
371- isl::multi_union_pw_aff outer,
372- isl::multi_union_pw_aff thread) {
377+ isl::MultiUnionPwAff<Statement, Outer> outer,
378+ isl::MultiUnionPwAff<Statement, Thread> thread) {
373379 if (!accessSubscriptsAreUnrolledLoops (
374- group, root, scope, outer.flat_range_product (thread))) {
380+ group, root, scope, outer.range_product (thread))) {
375381 return false ;
376382 }
377383
378384 auto originalAccesses = group.originalAccesses ();
379- auto map = isl::union_map::from (outer);
380- map = map.range_product (originalAccesses);
381- map = map.apply_domain (isl::union_map::from (thread));
385+ auto outerMap = isl::UnionMap<Statement, Outer>::from (outer);
386+ auto pair = outerMap.range_product (originalAccesses);
387+ auto threadMap = isl::UnionMap<Statement, Thread>::from (thread);
388+ auto threadToPair = pair.apply_domain (threadMap);
382389
383- return map .is_injective ();
390+ return threadToPair .is_injective ();
384391}
385392
386393/*
@@ -653,15 +660,15 @@ void promoteToRegistersBelow(MappedScop& mscop, detail::ScheduleTree* scope) {
653660 auto blockSchedule = mscop.blockMappingSchedule (mscop.schedule ());
654661
655662 // Pure affine schedule without (mapping) filters.
656- auto partialSchedMupa = partialScheduleMupa (root, scope);
663+ auto partialSchedMupa = partialScheduleMupa<Scope> (root, scope);
657664 // Schedule with block mapping filter.
658665 auto partialSched =
659666 isl::union_map::from (partialSchedMupa).intersect_domain (blockMapping);
660667 // The following promotion validity and profitability checks need to be
661668 // performed with respect to the block mapping, so append the block schedule.
662669 // If the partial schedule contains it already, it will just end up with
663670 // identical dimensions without affecting the result of the checks.
664- partialSchedMupa = partialSchedMupa.flat_range_product (blockSchedule);
671+ auto partialSchedBlockMupa = partialSchedMupa.range_product (blockSchedule);
665672
666673 for (auto & tensorGroups : groupMap) {
667674 auto tensorId = tensorGroups.first ;
@@ -675,11 +682,11 @@ void promoteToRegistersBelow(MappedScop& mscop, detail::ScheduleTree* scope) {
675682 continue ;
676683 }
677684 if (!isPromotableToRegistersBelow (
678- *group, root, scope, partialSchedMupa , threadSchedule)) {
685+ *group, root, scope, partialSchedBlockMupa , threadSchedule)) {
679686 continue ;
680687 }
681688 // Check reuse within threads.
682- auto schedule = partialSchedMupa .flat_range_product (threadSchedule);
689+ auto schedule = partialSchedBlockMupa .flat_range_product (threadSchedule);
683690 if (!hasReuseWithin (*group, schedule)) {
684691 continue ;
685692 }
0 commit comments