Skip to content

Commit 4cc9cda

Browse files
[NFC][SYCL] Pass queue_impl by raw ptr in commands.hpp
Continuation of the refactoring efforts in #18715, #18748, #18830, #18907, and #18983.
1 parent acbabef commit 4cc9cda

17 files changed

+200 additions, -185 deletions (lines changed)

sycl/source/detail/cg.hpp

Lines changed: 2 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -725,14 +725,10 @@ class CGHostTask : public CG {
725725
std::shared_ptr<detail::context_impl> MContext;
726726
std::vector<ArgDesc> MArgs;
727727

728-
CGHostTask(std::shared_ptr<HostTask> HostTask,
729-
std::shared_ptr<detail::queue_impl> Queue,
728+
CGHostTask(std::shared_ptr<HostTask> HostTask, detail::queue_impl *Queue,
730729
std::shared_ptr<detail::context_impl> Context,
731730
std::vector<ArgDesc> Args, CG::StorageInitHelper CGData,
732-
CGType Type, detail::code_location loc = {})
733-
: CG(Type, std::move(CGData), std::move(loc)),
734-
MHostTask(std::move(HostTask)), MQueue(Queue), MContext(Context),
735-
MArgs(std::move(Args)) {}
731+
CGType Type, detail::code_location loc = {});
736732
};
737733

738734
} // namespace detail

sycl/source/detail/graph_impl.hpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -288,7 +288,7 @@ class node_impl : public std::enable_shared_from_this<node_impl> {
288288

289289
return std::make_unique<sycl::detail::CGHostTask>(
290290
sycl::detail::CGHostTask(
291-
std::move(HostTaskSPtr), CommandGroupPtr->MQueue,
291+
std::move(HostTaskSPtr), CommandGroupPtr->MQueue.get(),
292292
CommandGroupPtr->MContext, std::move(NewArgs), std::move(Data),
293293
CommandGroupPtr->getType(), Loc));
294294
}

sycl/source/detail/queue_impl.hpp

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -650,9 +650,12 @@ class queue_impl : public std::enable_shared_from_this<queue_impl> {
650650
// for in order ones.
651651
void revisitUnenqueuedCommandsState(const EventImplPtr &CompletedHostTask);
652652

653-
static ContextImplPtr getContext(const QueueImplPtr &Queue) {
653+
static ContextImplPtr getContext(queue_impl *Queue) {
654654
return Queue ? Queue->getContextImplPtr() : nullptr;
655655
}
656+
static ContextImplPtr getContext(const QueueImplPtr &Queue) {
657+
return getContext(Queue.get());
658+
}
656659

657660
// Must be called under MMutex protection
658661
void doUnenqueuedCommandCleanup(

sycl/source/detail/scheduler/commands.cpp

Lines changed: 62 additions & 56 deletions
Original file line number | Diff line number | Diff line change
@@ -127,11 +127,14 @@ static unsigned long long getQueueID(const std::shared_ptr<queue_impl> &Queue) {
127127
}
128128
#endif
129129

130-
static context_impl *getContext(const QueueImplPtr &Queue) {
130+
static context_impl *getContext(queue_impl *Queue) {
131131
if (Queue)
132132
return &Queue->getContextImpl();
133133
return nullptr;
134134
}
135+
static context_impl *getContext(const std::shared_ptr<queue_impl> &Queue) {
136+
return getContext(Queue.get());
137+
}
135138

136139
#ifdef __SYCL_ENABLE_GNU_DEMANGLING
137140
struct DemangleHandle {
@@ -510,7 +513,7 @@ void Command::waitForPreparedHostEvents() const {
510513
HostEvent->waitInternal();
511514
}
512515

513-
void Command::waitForEvents(QueueImplPtr Queue,
516+
void Command::waitForEvents(queue_impl *Queue,
514517
std::vector<EventImplPtr> &EventImpls,
515518
ur_event_handle_t &Event) {
516519
#ifndef NDEBUG
@@ -566,12 +569,12 @@ void Command::waitForEvents(QueueImplPtr Queue,
566569
/// references to event_impl class members because Command
567570
/// should not outlive the event connected to it.
568571
Command::Command(
569-
CommandType Type, QueueImplPtr Queue,
572+
CommandType Type, queue_impl *Queue,
570573
ur_exp_command_buffer_handle_t CommandBuffer,
571574
const std::vector<ur_exp_command_buffer_sync_point_t> &SyncPoints)
572-
: MQueue(std::move(Queue)),
573-
MEvent(MQueue ? detail::event_impl::create_device_event(*MQueue)
574-
: detail::event_impl::create_incomplete_host_event()),
575+
: MQueue(Queue ? Queue->shared_from_this() : nullptr),
576+
MEvent(Queue ? detail::event_impl::create_device_event(*Queue)
577+
: detail::event_impl::create_incomplete_host_event()),
575578
MPreparedDepsEvents(MEvent->getPreparedDepsEvents()),
576579
MPreparedHostDepsEvents(MEvent->getPreparedHostDepsEvents()), MType(Type),
577580
MCommandBuffer(CommandBuffer), MSyncPointDeps(SyncPoints) {
@@ -1034,7 +1037,7 @@ void Command::copySubmissionCodeLocation() {
10341037
#endif
10351038
}
10361039

1037-
AllocaCommandBase::AllocaCommandBase(CommandType Type, QueueImplPtr Queue,
1040+
AllocaCommandBase::AllocaCommandBase(CommandType Type, queue_impl *Queue,
10381041
Requirement Req,
10391042
AllocaCommandBase *LinkedAllocaCmd,
10401043
bool IsConst)
@@ -1077,10 +1080,10 @@ bool AllocaCommandBase::supportsPostEnqueueCleanup() const { return false; }
10771080

10781081
bool AllocaCommandBase::readyForCleanup() const { return false; }
10791082

1080-
AllocaCommand::AllocaCommand(QueueImplPtr Queue, Requirement Req,
1083+
AllocaCommand::AllocaCommand(queue_impl *Queue, Requirement Req,
10811084
bool InitFromUserData,
10821085
AllocaCommandBase *LinkedAllocaCmd, bool IsConst)
1083-
: AllocaCommandBase(CommandType::ALLOCA, std::move(Queue), std::move(Req),
1086+
: AllocaCommandBase(CommandType::ALLOCA, Queue, std::move(Req),
10841087
LinkedAllocaCmd, IsConst),
10851088
MInitFromUserData(InitFromUserData) {
10861089
// Node event must be created before the dependent edge is added to this
@@ -1115,7 +1118,7 @@ ur_result_t AllocaCommand::enqueueImp() {
11151118

11161119
if (!MQueue) {
11171120
// Do not need to make allocation if we have a linked device allocation
1118-
Command::waitForEvents(MQueue, EventImpls, UREvent);
1121+
Command::waitForEvents(MQueue.get(), EventImpls, UREvent);
11191122
MEvent->setHandle(UREvent);
11201123

11211124
return UR_RESULT_SUCCESS;
@@ -1155,12 +1158,11 @@ void AllocaCommand::printDot(std::ostream &Stream) const {
11551158
}
11561159
}
11571160

1158-
AllocaSubBufCommand::AllocaSubBufCommand(QueueImplPtr Queue, Requirement Req,
1161+
AllocaSubBufCommand::AllocaSubBufCommand(queue_impl *Queue, Requirement Req,
11591162
AllocaCommandBase *ParentAlloca,
11601163
std::vector<Command *> &ToEnqueue,
11611164
std::vector<Command *> &ToCleanUp)
1162-
: AllocaCommandBase(CommandType::ALLOCA_SUB_BUF, std::move(Queue),
1163-
std::move(Req),
1165+
: AllocaCommandBase(CommandType::ALLOCA_SUB_BUF, Queue, std::move(Req),
11641166
/*LinkedAllocaCmd*/ nullptr, /*IsConst*/ false),
11651167
MParentAlloca(ParentAlloca) {
11661168
// Node event must be created before the dependent edge
@@ -1241,8 +1243,8 @@ void AllocaSubBufCommand::printDot(std::ostream &Stream) const {
12411243
}
12421244
}
12431245

1244-
ReleaseCommand::ReleaseCommand(QueueImplPtr Queue, AllocaCommandBase *AllocaCmd)
1245-
: Command(CommandType::RELEASE, std::move(Queue)), MAllocaCmd(AllocaCmd) {
1246+
ReleaseCommand::ReleaseCommand(queue_impl *Queue, AllocaCommandBase *AllocaCmd)
1247+
: Command(CommandType::RELEASE, Queue), MAllocaCmd(AllocaCmd) {
12461248
emitInstrumentationDataProxy();
12471249
}
12481250

@@ -1295,9 +1297,9 @@ ur_result_t ReleaseCommand::enqueueImp() {
12951297
}
12961298

12971299
if (NeedUnmap) {
1298-
const QueueImplPtr &Queue = CurAllocaIsHost
1299-
? MAllocaCmd->MLinkedAllocaCmd->getQueue()
1300-
: MAllocaCmd->getQueue();
1300+
queue_impl *Queue = CurAllocaIsHost
1301+
? MAllocaCmd->MLinkedAllocaCmd->getQueue()
1302+
: MAllocaCmd->getQueue();
13011303

13021304
assert(Queue);
13031305

@@ -1328,7 +1330,7 @@ ur_result_t ReleaseCommand::enqueueImp() {
13281330
}
13291331
ur_event_handle_t UREvent = nullptr;
13301332
if (SkipRelease)
1331-
Command::waitForEvents(MQueue, EventImpls, UREvent);
1333+
Command::waitForEvents(MQueue.get(), EventImpls, UREvent);
13321334
else {
13331335
if (auto Result = callMemOpHelper(
13341336
MemoryManager::release, getContext(MQueue),
@@ -1366,11 +1368,10 @@ bool ReleaseCommand::supportsPostEnqueueCleanup() const { return false; }
13661368
bool ReleaseCommand::readyForCleanup() const { return false; }
13671369

13681370
MapMemObject::MapMemObject(AllocaCommandBase *SrcAllocaCmd, Requirement Req,
1369-
void **DstPtr, QueueImplPtr Queue,
1371+
void **DstPtr, queue_impl *Queue,
13701372
access::mode MapMode)
1371-
: Command(CommandType::MAP_MEM_OBJ, std::move(Queue)),
1372-
MSrcAllocaCmd(SrcAllocaCmd), MSrcReq(std::move(Req)), MDstPtr(DstPtr),
1373-
MMapMode(MapMode) {
1373+
: Command(CommandType::MAP_MEM_OBJ, Queue), MSrcAllocaCmd(SrcAllocaCmd),
1374+
MSrcReq(std::move(Req)), MDstPtr(DstPtr), MMapMode(MapMode) {
13741375
emitInstrumentationDataProxy();
13751376
}
13761377

@@ -1430,9 +1431,9 @@ void MapMemObject::printDot(std::ostream &Stream) const {
14301431
}
14311432

14321433
UnMapMemObject::UnMapMemObject(AllocaCommandBase *DstAllocaCmd, Requirement Req,
1433-
void **SrcPtr, QueueImplPtr Queue)
1434-
: Command(CommandType::UNMAP_MEM_OBJ, std::move(Queue)),
1435-
MDstAllocaCmd(DstAllocaCmd), MDstReq(std::move(Req)), MSrcPtr(SrcPtr) {
1434+
void **SrcPtr, queue_impl *Queue)
1435+
: Command(CommandType::UNMAP_MEM_OBJ, Queue), MDstAllocaCmd(DstAllocaCmd),
1436+
MDstReq(std::move(Req)), MSrcPtr(SrcPtr) {
14361437
emitInstrumentationDataProxy();
14371438
}
14381439

@@ -1516,11 +1517,11 @@ MemCpyCommand::MemCpyCommand(Requirement SrcReq,
15161517
AllocaCommandBase *SrcAllocaCmd,
15171518
Requirement DstReq,
15181519
AllocaCommandBase *DstAllocaCmd,
1519-
QueueImplPtr SrcQueue, QueueImplPtr DstQueue)
1520-
: Command(CommandType::COPY_MEMORY, std::move(DstQueue)),
1521-
MSrcQueue(SrcQueue), MSrcReq(std::move(SrcReq)),
1522-
MSrcAllocaCmd(SrcAllocaCmd), MDstReq(std::move(DstReq)),
1523-
MDstAllocaCmd(DstAllocaCmd) {
1520+
queue_impl *SrcQueue, queue_impl *DstQueue)
1521+
: Command(CommandType::COPY_MEMORY, DstQueue),
1522+
MSrcQueue(SrcQueue ? SrcQueue->shared_from_this() : nullptr),
1523+
MSrcReq(std::move(SrcReq)), MSrcAllocaCmd(SrcAllocaCmd),
1524+
MDstReq(std::move(DstReq)), MDstAllocaCmd(DstAllocaCmd) {
15241525
if (MSrcQueue) {
15251526
MEvent->setContextImpl(MSrcQueue->getContextImplPtr());
15261527
}
@@ -1652,7 +1653,7 @@ ur_result_t UpdateHostRequirementCommand::enqueueImp() {
16521653
waitForPreparedHostEvents();
16531654
std::vector<EventImplPtr> EventImpls = MPreparedDepsEvents;
16541655
ur_event_handle_t UREvent = nullptr;
1655-
Command::waitForEvents(MQueue, EventImpls, UREvent);
1656+
Command::waitForEvents(MQueue.get(), EventImpls, UREvent);
16561657
MEvent->setHandle(UREvent);
16571658

16581659
assert(MSrcAllocaCmd && "Expected valid alloca command");
@@ -1689,11 +1690,11 @@ void UpdateHostRequirementCommand::printDot(std::ostream &Stream) const {
16891690
MemCpyCommandHost::MemCpyCommandHost(Requirement SrcReq,
16901691
AllocaCommandBase *SrcAllocaCmd,
16911692
Requirement DstReq, void **DstPtr,
1692-
QueueImplPtr SrcQueue,
1693-
QueueImplPtr DstQueue)
1694-
: Command(CommandType::COPY_MEMORY, std::move(DstQueue)),
1695-
MSrcQueue(SrcQueue), MSrcReq(std::move(SrcReq)),
1696-
MSrcAllocaCmd(SrcAllocaCmd), MDstReq(std::move(DstReq)), MDstPtr(DstPtr) {
1693+
queue_impl *SrcQueue, queue_impl *DstQueue)
1694+
: Command(CommandType::COPY_MEMORY, DstQueue),
1695+
MSrcQueue(SrcQueue ? SrcQueue->shared_from_this() : nullptr),
1696+
MSrcReq(std::move(SrcReq)), MSrcAllocaCmd(SrcAllocaCmd),
1697+
MDstReq(std::move(DstReq)), MDstPtr(DstPtr) {
16971698
if (MSrcQueue) {
16981699
MEvent->setContextImpl(MSrcQueue->getContextImplPtr());
16991700
}
@@ -1735,7 +1736,7 @@ ContextImplPtr MemCpyCommandHost::getWorkerContext() const {
17351736
}
17361737

17371738
ur_result_t MemCpyCommandHost::enqueueImp() {
1738-
const QueueImplPtr &Queue = MWorkerQueue;
1739+
queue_impl *Queue = MWorkerQueue.get();
17391740
waitForPreparedHostEvents();
17401741
std::vector<EventImplPtr> EventImpls = MPreparedDepsEvents;
17411742
std::vector<ur_event_handle_t> RawEvents = getUrEvents(EventImpls);
@@ -1774,7 +1775,7 @@ EmptyCommand::EmptyCommand() : Command(CommandType::EMPTY_TASK, nullptr) {
17741775
ur_result_t EmptyCommand::enqueueImp() {
17751776
waitForPreparedHostEvents();
17761777
ur_event_handle_t UREvent = nullptr;
1777-
waitForEvents(MQueue, MPreparedDepsEvents, UREvent);
1778+
waitForEvents(MQueue.get(), MPreparedDepsEvents, UREvent);
17781779
MEvent->setHandle(UREvent);
17791780
return UR_RESULT_SUCCESS;
17801781
}
@@ -1858,9 +1859,9 @@ void MemCpyCommandHost::printDot(std::ostream &Stream) const {
18581859
}
18591860

18601861
UpdateHostRequirementCommand::UpdateHostRequirementCommand(
1861-
QueueImplPtr Queue, Requirement Req, AllocaCommandBase *SrcAllocaCmd,
1862+
queue_impl *Queue, Requirement Req, AllocaCommandBase *SrcAllocaCmd,
18621863
void **DstPtr)
1863-
: Command(CommandType::UPDATE_REQUIREMENT, std::move(Queue)),
1864+
: Command(CommandType::UPDATE_REQUIREMENT, Queue),
18641865
MSrcAllocaCmd(SrcAllocaCmd), MDstReq(std::move(Req)), MDstPtr(DstPtr) {
18651866

18661867
emitInstrumentationDataProxy();
@@ -1956,11 +1957,10 @@ static std::string_view cgTypeToString(detail::CGType Type) {
19561957
}
19571958

19581959
ExecCGCommand::ExecCGCommand(
1959-
std::unique_ptr<detail::CG> CommandGroup, QueueImplPtr Queue,
1960+
std::unique_ptr<detail::CG> CommandGroup, queue_impl *Queue,
19601961
bool EventNeeded, ur_exp_command_buffer_handle_t CommandBuffer,
19611962
const std::vector<ur_exp_command_buffer_sync_point_t> &Dependencies)
1962-
: Command(CommandType::RUN_CG, std::move(Queue), CommandBuffer,
1963-
Dependencies),
1963+
: Command(CommandType::RUN_CG, Queue, CommandBuffer, Dependencies),
19641964
MEventNeeded(EventNeeded), MCommandGroup(std::move(CommandGroup)) {
19651965
if (MCommandGroup->getType() == detail::CGType::CodeplayHostTask) {
19661966
MEvent->setSubmittedQueue(
@@ -2777,20 +2777,18 @@ void enqueueImpKernel(
27772777
}
27782778
}
27792779

2780-
ur_result_t enqueueReadWriteHostPipe(const QueueImplPtr &Queue,
2780+
ur_result_t enqueueReadWriteHostPipe(queue_impl &Queue,
27812781
const std::string &PipeName, bool blocking,
27822782
void *ptr, size_t size,
27832783
std::vector<ur_event_handle_t> &RawEvents,
27842784
detail::event_impl *OutEventImpl,
27852785
bool read) {
2786-
assert(Queue &&
2787-
"ReadWrite host pipe submissions should have an associated queue");
27882786
detail::HostPipeMapEntry *hostPipeEntry =
27892787
ProgramManager::getInstance().getHostPipeEntry(PipeName);
27902788

27912789
ur_program_handle_t Program = nullptr;
2792-
device Device = Queue->get_device();
2793-
ContextImplPtr ContextImpl = Queue->getContextImplPtr();
2790+
device Device = Queue.get_device();
2791+
ContextImplPtr ContextImpl = Queue.getContextImplPtr();
27942792
std::optional<ur_program_handle_t> CachedProgram =
27952793
ContextImpl->getProgramForHostPipe(Device, hostPipeEntry);
27962794
if (CachedProgram)
@@ -2799,17 +2797,16 @@ ur_result_t enqueueReadWriteHostPipe(const QueueImplPtr &Queue,
27992797
// If there was no cached program, build one.
28002798
device_image_plain devImgPlain =
28012799
ProgramManager::getInstance().getDeviceImageFromBinaryImage(
2802-
hostPipeEntry->getDevBinImage(), Queue->get_context(),
2803-
Queue->get_device());
2800+
hostPipeEntry->getDevBinImage(), Queue.get_context(), Device);
28042801
device_image_plain BuiltImage = ProgramManager::getInstance().build(
28052802
std::move(devImgPlain), {std::move(Device)}, {});
28062803
Program = getSyclObjImpl(BuiltImage)->get_ur_program_ref();
28072804
}
28082805
assert(Program && "Program for this hostpipe is not compiled.");
28092806

2810-
const AdapterPtr &Adapter = Queue->getAdapter();
2807+
const AdapterPtr &Adapter = Queue.getAdapter();
28112808

2812-
ur_queue_handle_t ur_q = Queue->getHandleRef();
2809+
ur_queue_handle_t ur_q = Queue.getHandleRef();
28132810
ur_result_t Error;
28142811

28152812
ur_event_handle_t UREvent = nullptr;
@@ -3667,7 +3664,7 @@ ur_result_t ExecCGCommand::enqueueImpQueue() {
36673664
if (!EventImpl) {
36683665
EventImpl = MEvent.get();
36693666
}
3670-
return enqueueReadWriteHostPipe(MQueue, pipeName, blocking, hostPtr,
3667+
return enqueueReadWriteHostPipe(*MQueue, pipeName, blocking, hostPtr,
36713668
typeSize, RawEvents, EventImpl, read);
36723669
}
36733670
case CGType::ExecCommandBuffer: {
@@ -3802,7 +3799,7 @@ bool ExecCGCommand::readyForCleanup() const {
38023799
}
38033800

38043801
UpdateCommandBufferCommand::UpdateCommandBufferCommand(
3805-
QueueImplPtr Queue,
3802+
queue_impl *Queue,
38063803
ext::oneapi::experimental::detail::exec_graph_impl *Graph,
38073804
std::vector<std::shared_ptr<ext::oneapi::experimental::detail::node_impl>>
38083805
Nodes)
@@ -3813,7 +3810,7 @@ ur_result_t UpdateCommandBufferCommand::enqueueImp() {
38133810
waitForPreparedHostEvents();
38143811
std::vector<EventImplPtr> EventImpls = MPreparedDepsEvents;
38153812
ur_event_handle_t UREvent = nullptr;
3816-
Command::waitForEvents(MQueue, EventImpls, UREvent);
3813+
Command::waitForEvents(MQueue.get(), EventImpls, UREvent);
38173814
MEvent->setHandle(UREvent);
38183815

38193816
auto CheckAndFindAlloca = [](Requirement *Req, const DepDesc &Dep) {
@@ -3885,6 +3882,15 @@ void UpdateCommandBufferCommand::printDot(std::ostream &Stream) const {
38853882
void UpdateCommandBufferCommand::emitInstrumentationData() {}
38863883
bool UpdateCommandBufferCommand::producesPiEvent() const { return false; }
38873884

3885+
CGHostTask::CGHostTask(std::shared_ptr<HostTask> HostTask,
3886+
detail::queue_impl *Queue,
3887+
std::shared_ptr<detail::context_impl> Context,
3888+
std::vector<ArgDesc> Args, CG::StorageInitHelper CGData,
3889+
CGType Type, detail::code_location loc)
3890+
: CG(Type, std::move(CGData), std::move(loc)),
3891+
MHostTask(std::move(HostTask)),
3892+
MQueue(Queue ? Queue->shared_from_this() : nullptr), MContext(Context),
3893+
MArgs(std::move(Args)) {}
38883894
} // namespace detail
38893895
} // namespace _V1
38903896
} // namespace sycl

0 commit comments

Comments
 (0)