Skip to content

Commit 0211dd1

Browse files
[NFCI][SYCL] Refactor reduction-handler interactions (#18794)
1) Move `reduction::withAuxHandler` -> `HandlerAccess::postProcess`, as that's what it's really about. 2) Add comments describing what I see as issues with the previous/current implementation plus minor fixes of what I could address easily. 3) Added `HandlerAccess:preProcess` instead of using `addCounterInit` introduced in #13659. The original idea behind `withAuxHandler` and reduction implementations in general is to decouple them from SYCL RT internals as much as possible and `addCounterInit` was a step in the exact opposite direction. --------- Co-authored-by: Sergey Semenov <[email protected]>
1 parent 0dff0ff commit 0211dd1

File tree

6 files changed

+107
-33
lines changed

6 files changed

+107
-33
lines changed

sycl/include/sycl/handler.hpp

Lines changed: 92 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3426,8 +3426,6 @@ class __SYCL_EXPORT handler {
34263426
friend class detail::reduction_impl_algo;
34273427

34283428
friend inline void detail::reduction::finalizeHandler(handler &CGH);
3429-
template <class FunctorTy>
3430-
friend void detail::reduction::withAuxHandler(handler &CGH, FunctorTy Func);
34313429

34323430
template <typename KernelName, detail::reduction::strategy Strategy, int Dims,
34333431
typename PropertiesT, typename... RestT>
@@ -3870,6 +3868,8 @@ class __SYCL_EXPORT handler {
38703868
void setKernelNameBasedCachePtr(
38713869
detail::KernelNameBasedCacheT *KernelNameBasedCachePtr);
38723870

3871+
queue getQueue();
3872+
38733873
protected:
38743874
/// Registers event dependencies in this command group.
38753875
void depends_on(const detail::EventImplPtr &Event);
@@ -3889,6 +3889,96 @@ class HandlerAccess {
38893889
kernel Kernel) {
38903890
Handler.parallel_for_impl(Range, Props, Kernel);
38913891
}
3892+
3893+
template <typename T, typename> struct dependent {
3894+
using type = T;
3895+
};
3896+
template <typename T>
3897+
using dependent_queue_t = typename dependent<queue, T>::type;
3898+
template <typename T>
3899+
using dependent_handler_t = typename dependent<handler, T>::type;
3900+
3901+
// pre/postProcess are used only for reductions right now, but the
3902+
// abstractions they provide aren't reduction-specific. The main problem they
3903+
// solve is
3904+
//
3905+
// # User code
3906+
// q.submit([&](handler &cgh) {
3907+
// set_dependencies(cgh);
3908+
// enqueue_whatever(cgh);
3909+
// }); // single submission
3910+
//
3911+
// that needs to be implemented as multiple enqueues involving
3912+
// pre-/post-processing internally. SYCL prohibits recursive submits from
3913+
// inside control group function object (lambda above) so we resort to a
3914+
// somewhat hacky way of creating multiple `handler`s and manual finalization
3915+
// of them (instead of the one in `queue::submit`).
3916+
//
3917+
// Overloads with `queue &q` are provided in case the caller has it created
3918+
// already to avoid unnecessary reference count increments associated with
3919+
// `handler::getQueue()`.
3920+
template <class FunctorTy>
3921+
static void preProcess(handler &CGH, dependent_queue_t<FunctorTy> &q,
3922+
FunctorTy Func) {
3923+
bool EventNeeded = !q.is_in_order();
3924+
handler AuxHandler(getSyclObjImpl(q), EventNeeded);
3925+
AuxHandler.copyCodeLoc(CGH);
3926+
std::forward<FunctorTy>(Func)(AuxHandler);
3927+
auto E = AuxHandler.finalize();
3928+
assert(!CGH.MIsFinalized &&
3929+
"Can't do pre-processing if the command has been enqueued already!");
3930+
if (EventNeeded)
3931+
CGH.depends_on(E);
3932+
}
3933+
template <class FunctorTy>
3934+
static void preProcess(dependent_handler_t<FunctorTy> &CGH,
3935+
FunctorTy &&Func) {
3936+
preProcess(CGH, CGH.getQueue(), std::forward<FunctorTy>(Func));
3937+
}
3938+
template <class FunctorTy>
3939+
static void postProcess(dependent_handler_t<FunctorTy> &CGH,
3940+
FunctorTy &&Func) {
3941+
// The "hacky" `handler`s manipulation mentioned above and implemented here
3942+
// is far from perfect. A better approach would be
3943+
//
3944+
// bool OrigNeedsEvent = CGH.needsEvent()
3945+
// assert(CGH.not_finalized/enqueued());
3946+
// if (!InOrderQueue)
3947+
// CGH.setNeedsEvent()
3948+
//
3949+
// handler PostProcessHandler(Queue, OrigNeedsEvent)
3950+
// auto E = CGH.finalize(); // enqueue original or current last
3951+
// // post-process
3952+
// if (!InOrder)
3953+
// PostProcessHandler.depends_on(E)
3954+
//
3955+
// swap_impls(CGH, PostProcessHandler)
3956+
// return; // queue::submit finalizes PostProcessHandler and returns its
3957+
// // event if necessary.
3958+
//
3959+
// Still hackier than "real" `queue::submit` but at least somewhat sane.
3960+
// That, however hasn't been tried yet and we have an even hackier approach
3961+
// copied from what's been done in an old reductions implementation before
3962+
// eventless submission work has started. Not sure how feasible the approach
3963+
// above is at this moment.
3964+
3965+
// This `finalize` is wrong (at least logically) if
3966+
// `assert(!CGH.eventNeeded())`
3967+
auto E = CGH.finalize();
3968+
dependent_queue_t<FunctorTy> Queue = CGH.getQueue();
3969+
bool InOrder = Queue.is_in_order();
3970+
// Cannot use `CGH.eventNeeded()` alone as there might be subsequent
3971+
// `postProcess` calls and we cannot address them properly similarly to the
3972+
// `finalize` issue described above. `swap_impls` suggested above might be
3973+
// able to handle this scenario naturally.
3974+
handler AuxHandler(getSyclObjImpl(Queue), CGH.eventNeeded() || !InOrder);
3975+
if (!InOrder)
3976+
AuxHandler.depends_on(E);
3977+
AuxHandler.copyCodeLoc(CGH);
3978+
std::forward<FunctorTy>(Func)(AuxHandler);
3979+
CGH.MLastEvent = AuxHandler.finalize();
3980+
return;
3981+
}
38923982
};
38933983
} // namespace detail
38943984

sycl/include/sycl/reduction.hpp

Lines changed: 10 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -834,10 +834,6 @@ using __sycl_init_mem_for =
834834
std::conditional_t<std::is_same_v<KernelName, auto_name>, auto_name,
835835
reduction::InitMemKrn<KernelName>>;
836836

837-
__SYCL_EXPORT void
838-
addCounterInit(handler &CGH, std::shared_ptr<sycl::detail::queue_impl> &Queue,
839-
std::shared_ptr<int> &Counter);
840-
841837
template <typename T, class BinaryOperation, int Dims, size_t Extent,
842838
bool ExplicitIdentity, typename RedOutVar>
843839
class reduction_impl_algo {
@@ -995,7 +991,7 @@ class reduction_impl_algo {
995991
accessor Mem{*Buf, CGH};
996992
Func(Mem);
997993

998-
reduction::withAuxHandler(CGH, [&](handler &CopyHandler) {
994+
HandlerAccess::postProcess(CGH, [&](handler &CopyHandler) {
999995
// MSVC (19.32.31329) has problems compiling the line below when used
1000996
// as a host compiler in c++17 mode (but not in c++latest)
1001997
// accessor Mem{*Buf, CopyHandler};
@@ -1071,19 +1067,16 @@ class reduction_impl_algo {
10711067
// On discrete (vs. integrated) GPUs it's faster to initialize memory with an
10721068
// extra kernel than copy it from the host.
10731069
auto getGroupsCounterAccDiscrete(handler &CGH) {
1074-
queue q = createSyclObjFromImpl<queue>(CGH.MQueue);
1075-
device Dev = q.get_device();
1070+
queue q = CGH.getQueue();
10761071
auto Deleter = [=](auto *Ptr) { free(Ptr, q); };
10771072

10781073
std::shared_ptr<int> Counter(malloc_device<int>(1, q), Deleter);
10791074
CGH.addReduction(Counter);
10801075

1081-
#ifdef __INTEL_PREVIEW_BREAKING_CHANGES
1082-
std::shared_ptr<detail::queue_impl> Queue(CGH.MQueue);
1083-
#else
1084-
std::shared_ptr<detail::queue_impl> &Queue = CGH.MQueue;
1085-
#endif
1086-
addCounterInit(CGH, Queue, Counter);
1076+
HandlerAccess::preProcess(CGH, q,
1077+
[Counter = Counter.get()](handler &AuxHandler) {
1078+
AuxHandler.memset(Counter, 0, sizeof(int));
1079+
});
10871080

10881081
return Counter.get();
10891082
}
@@ -1178,20 +1171,6 @@ auto make_reduction(RedOutVar RedVar, RestTy &&...Rest) {
11781171

11791172
namespace reduction {
11801173
inline void finalizeHandler(handler &CGH) { CGH.finalize(); }
1181-
template <class FunctorTy> void withAuxHandler(handler &CGH, FunctorTy Func) {
1182-
#ifdef __INTEL_PREVIEW_BREAKING_CHANGES
1183-
detail::EventImplPtr E = CGH.finalize();
1184-
#else
1185-
event E = CGH.finalize();
1186-
#endif
1187-
handler AuxHandler(CGH.MQueue, CGH.eventNeeded());
1188-
if (!createSyclObjFromImpl<queue>(CGH.MQueue).is_in_order())
1189-
AuxHandler.depends_on(E);
1190-
AuxHandler.copyCodeLoc(CGH);
1191-
Func(AuxHandler);
1192-
CGH.MLastEvent = AuxHandler.finalize();
1193-
return;
1194-
}
11951174
} // namespace reduction
11961175

11971176
// This method is used for implementation of parallel_for accepting 1 reduction.
@@ -1785,7 +1764,7 @@ struct NDRangeReduction<
17851764
"the reduction.");
17861765
size_t NWorkItems = NDRange.get_group_range().size();
17871766
while (NWorkItems > 1) {
1788-
reduction::withAuxHandler(CGH, [&](handler &AuxHandler) {
1767+
HandlerAccess::postProcess(CGH, [&](handler &AuxHandler) {
17891768
size_t NElements = Reduction::num_elements;
17901769
size_t NWorkGroups;
17911770
size_t WGSize = reduComputeWGSize(NWorkItems, MaxWGSize, NWorkGroups);
@@ -1837,7 +1816,7 @@ struct NDRangeReduction<
18371816
} // end while (NWorkItems > 1)
18381817

18391818
if constexpr (Reduction::is_usm) {
1840-
reduction::withAuxHandler(CGH, [&](handler &CopyHandler) {
1819+
HandlerAccess::postProcess(CGH, [&](handler &CopyHandler) {
18411820
reduSaveFinalResultToUserMem<KernelName>(CopyHandler, Redu);
18421821
});
18431822
}
@@ -1969,7 +1948,7 @@ template <> struct NDRangeReduction<reduction::strategy::basic> {
19691948
size_t WGSize = reduComputeWGSize(NWorkItems, MaxWGSize, NWorkGroups);
19701949

19711950
auto Rest = [&](auto KernelTag) {
1972-
reduction::withAuxHandler(CGH, [&](handler &AuxHandler) {
1951+
HandlerAccess::postProcess(CGH, [&](handler &AuxHandler) {
19731952
// We can deduce IsOneWG from the tag type.
19741953
constexpr bool IsOneWG =
19751954
std::is_same_v<std::remove_reference_t<decltype(KernelTag)>,
@@ -2650,7 +2629,7 @@ template <> struct NDRangeReduction<reduction::strategy::multi> {
26502629

26512630
size_t NWorkItems = NDRange.get_group_range().size();
26522631
while (NWorkItems > 1) {
2653-
reduction::withAuxHandler(CGH, [&](handler &AuxHandler) {
2632+
HandlerAccess::postProcess(CGH, [&](handler &AuxHandler) {
26542633
NWorkItems = reduAuxCGFunc<KernelName, decltype(KernelFunc)>(
26552634
AuxHandler, NWorkItems, MaxWGSize, ReduTuple, ReduIndices);
26562635
});

sycl/source/detail/reduction.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,7 @@ __SYCL_EXPORT size_t reduGetPreferredWGSize(std::shared_ptr<queue_impl> &Queue,
177177
return reduGetMaxWGSize(Queue, LocalMemBytesPerWorkItem);
178178
}
179179

180+
#ifndef __INTEL_PREVIEW_BREAKING_CHANGES
180181
__SYCL_EXPORT void
181182
addCounterInit(handler &CGH, std::shared_ptr<sycl::detail::queue_impl> &Queue,
182183
std::shared_ptr<int> &Counter) {
@@ -189,6 +190,7 @@ addCounterInit(handler &CGH, std::shared_ptr<sycl::detail::queue_impl> &Queue,
189190
EventImpl->setHandle(UREvent);
190191
CGH.depends_on(createSyclObjFromImpl<event>(EventImpl));
191192
}
193+
#endif
192194

193195
__SYCL_EXPORT void verifyReductionProps(const property_list &Props) {
194196
auto CheckDataLessProperties = [](int PropertyKind) {

sycl/source/handler.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2417,5 +2417,6 @@ void handler::copyCodeLoc(const handler &other) {
24172417
impl->MIsTopCodeLoc = other.impl->MIsTopCodeLoc;
24182418
}
24192419

2420+
queue handler::getQueue() { return createSyclObjFromImpl<queue>(MQueue); }
24202421
} // namespace _V1
24212422
} // namespace sycl

sycl/test/abi/sycl_symbols_linux.dump

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3608,6 +3608,7 @@ _ZN4sycl3_V17handler6memcpyEPvPKvm
36083608
_ZN4sycl3_V17handler6memsetEPvim
36093609
_ZN4sycl3_V17handler7setTypeENS0_6detail6CGTypeE
36103610
_ZN4sycl3_V17handler8finalizeEv
3611+
_ZN4sycl3_V17handler8getQueueEv
36113612
_ZN4sycl3_V17handler8prefetchEPKvm
36123613
_ZN4sycl3_V17handler9clearArgsEv
36133614
_ZN4sycl3_V17handler9fill_implEPvPKvmm

sycl/test/abi/sycl_symbols_windows.dump

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4103,6 +4103,7 @@
41034103
?getPtr@SampledImageAccessorBaseHost@detail@_V1@sycl@@QEBAPEAXXZ
41044104
?getPtr@UnsampledImageAccessorBaseHost@detail@_V1@sycl@@QEAAPEAXXZ
41054105
?getPtr@UnsampledImageAccessorBaseHost@detail@_V1@sycl@@QEBAPEAXXZ
4106+
?getQueue@handler@_V1@sycl@@AEAA?AVqueue@23@XZ
41064107
?getRowPitch@image_plain@detail@_V1@sycl@@IEBA_KXZ
41074108
?getSampler@SampledImageAccessorBaseHost@detail@_V1@sycl@@QEBA?AUimage_sampler@34@XZ
41084109
?getSampler@image_plain@detail@_V1@sycl@@IEBA?AUimage_sampler@34@XZ

0 commit comments

Comments
 (0)