Skip to content

[NFC][SYCL][Graph] Update some maps to use raw node_impl * #19334

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions sycl/source/detail/async_alloc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ std::vector<std::shared_ptr<detail::node_impl>> getDepGraphNodes(
// If this is being recorded from an in-order queue we need to get the last
// in-order node if any, since this will later become a dependency of the
// node being processed here.
if (const auto &LastInOrderNode = Graph->getLastInorderNode(Queue);
if (detail::node_impl *LastInOrderNode = Graph->getLastInorderNode(Queue);
LastInOrderNode) {
DepNodes.push_back(LastInOrderNode);
DepNodes.push_back(LastInOrderNode->shared_from_this());
}
return DepNodes;
}
Expand Down
35 changes: 17 additions & 18 deletions sycl/source/detail/graph/graph_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ void exec_graph_impl::makePartitions() {
const std::shared_ptr<partition> &Partition = std::make_shared<partition>();
for (auto &Node : MNodeStorage) {
if (Node->MPartitionNum == i) {
MPartitionNodes[Node] = PartitionFinalNum;
MPartitionNodes[Node.get()] = PartitionFinalNum;
if (isPartitionRoot(Node)) {
Partition->MRoots.insert(Node);
if (Node->MCGType == CGType::CodeplayHostTask) {
Expand Down Expand Up @@ -290,8 +290,7 @@ void exec_graph_impl::makePartitions() {
for (auto const &Root : Partition->MRoots) {
auto RootNode = Root.lock();
for (node_impl &NodeDep : RootNode->predecessors()) {
auto &Predecessor =
MPartitions[MPartitionNodes[NodeDep.shared_from_this()]];
auto &Predecessor = MPartitions[MPartitionNodes[&NodeDep]];
Partition->MPredecessors.push_back(Predecessor.get());
Predecessor->MSuccessors.push_back(Partition.get());
}
Expand Down Expand Up @@ -610,8 +609,7 @@ bool graph_impl::checkForCycles() {
return CycleFound;
}

std::shared_ptr<node_impl>
graph_impl::getLastInorderNode(sycl::detail::queue_impl *Queue) {
node_impl *graph_impl::getLastInorderNode(sycl::detail::queue_impl *Queue) {
if (!Queue) {
assert(0 ==
MInorderQueueMap.count(std::weak_ptr<sycl::detail::queue_impl>{}));
Expand All @@ -624,8 +622,8 @@ graph_impl::getLastInorderNode(sycl::detail::queue_impl *Queue) {
}

void graph_impl::setLastInorderNode(sycl::detail::queue_impl &Queue,
std::shared_ptr<node_impl> Node) {
MInorderQueueMap[Queue.weak_from_this()] = std::move(Node);
node_impl &Node) {
MInorderQueueMap[Queue.weak_from_this()] = &Node;
}

void graph_impl::makeEdge(std::shared_ptr<node_impl> Src,
Expand Down Expand Up @@ -728,9 +726,9 @@ void exec_graph_impl::findRealDeps(
} else {
auto CurrentNodePtr = CurrentNode.shared_from_this();
// Verify if CurrentNode belongs to the same partition
if (MPartitionNodes[CurrentNodePtr] == ReferencePartitionNum) {
if (MPartitionNodes[&CurrentNode] == ReferencePartitionNum) {
// Verify that the sync point has actually been set for this node.
auto SyncPoint = MSyncPoints.find(CurrentNodePtr);
auto SyncPoint = MSyncPoints.find(&CurrentNode);
assert(SyncPoint != MSyncPoints.end() &&
"No sync point has been set for node dependency.");
// Check if the dependency has already been added.
Expand All @@ -749,7 +747,7 @@ exec_graph_impl::enqueueNodeDirect(const sycl::context &Ctx,
std::shared_ptr<node_impl> Node) {
std::vector<ur_exp_command_buffer_sync_point_t> Deps;
for (node_impl &N : Node->predecessors()) {
findRealDeps(Deps, N, MPartitionNodes[Node]);
findRealDeps(Deps, N, MPartitionNodes[Node.get()]);
}
ur_exp_command_buffer_sync_point_t NewSyncPoint;
ur_exp_command_buffer_command_handle_t NewCommand = 0;
Expand Down Expand Up @@ -782,7 +780,7 @@ exec_graph_impl::enqueueNodeDirect(const sycl::context &Ctx,
Deps, &NewSyncPoint, MIsUpdatable ? &NewCommand : nullptr, nullptr);

if (MIsUpdatable) {
MCommandMap[Node] = NewCommand;
MCommandMap[Node.get()] = NewCommand;
}

if (Res != UR_RESULT_SUCCESS) {
Expand All @@ -805,7 +803,7 @@ exec_graph_impl::enqueueNode(ur_exp_command_buffer_handle_t CommandBuffer,

std::vector<ur_exp_command_buffer_sync_point_t> Deps;
for (node_impl &N : Node->predecessors()) {
findRealDeps(Deps, N, MPartitionNodes[Node]);
findRealDeps(Deps, N, MPartitionNodes[Node.get()]);
}

sycl::detail::EventImplPtr Event =
Expand All @@ -814,7 +812,7 @@ exec_graph_impl::enqueueNode(ur_exp_command_buffer_handle_t CommandBuffer,
/*EventNeeded=*/true, CommandBuffer, Deps);

if (MIsUpdatable) {
MCommandMap[Node] = Event->getCommandBufferCommand();
MCommandMap[Node.get()] = Event->getCommandBufferCommand();
}

return Event->getSyncPoint();
Expand All @@ -830,7 +828,8 @@ void exec_graph_impl::buildRequirements() {
Node->MCommandGroup->getRequirements().begin(),
Node->MCommandGroup->getRequirements().end());

std::shared_ptr<partition> &Partition = MPartitions[MPartitionNodes[Node]];
std::shared_ptr<partition> &Partition =
MPartitions[MPartitionNodes[Node.get()]];

Partition->MRequirements.insert(
Partition->MRequirements.end(),
Expand Down Expand Up @@ -877,10 +876,10 @@ void exec_graph_impl::createCommandBuffers(
Node->MCommandGroup.get())
->MStreams.size() ==
0) {
MSyncPoints[Node] =
MSyncPoints[Node.get()] =
enqueueNodeDirect(MContext, DeviceImpl, OutCommandBuffer, Node);
} else {
MSyncPoints[Node] = enqueueNode(OutCommandBuffer, Node);
MSyncPoints[Node.get()] = enqueueNode(OutCommandBuffer, Node);
}
}

Expand Down Expand Up @@ -1726,7 +1725,7 @@ void exec_graph_impl::populateURKernelUpdateStructs(
auto ExecNode = MIDCache.find(Node->MID);
assert(ExecNode != MIDCache.end() && "Node ID was not found in ID cache");

auto Command = MCommandMap.find(ExecNode->second);
auto Command = MCommandMap.find(ExecNode->second.get());
assert(Command != MCommandMap.end());
UpdateDesc.hCommand = Command->second;

Expand Down Expand Up @@ -1756,7 +1755,7 @@ exec_graph_impl::getURUpdatableNodes(

auto ExecNode = MIDCache.find(Node->MID);
assert(ExecNode != MIDCache.end() && "Node ID was not found in ID cache");
auto PartitionIndex = MPartitionNodes.find(ExecNode->second);
auto PartitionIndex = MPartitionNodes.find(ExecNode->second.get());
assert(PartitionIndex != MPartitionNodes.end());
PartitionedNodes[PartitionIndex->second].push_back(Node);
}
Expand Down
25 changes: 10 additions & 15 deletions sycl/source/detail/graph/graph_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -294,14 +294,12 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
/// @param Queue In-order queue to find the last node added to the graph from.
/// @return Last node in this graph added from \p Queue recording, or nullptr
/// if none.
std::shared_ptr<node_impl>
getLastInorderNode(sycl::detail::queue_impl *Queue);
node_impl *getLastInorderNode(sycl::detail::queue_impl *Queue);

/// Track the last node added to this graph from an in-order queue.
/// @param Queue In-order queue to register \p Node for.
/// @param Node Last node that was added to this graph from \p Queue.
void setLastInorderNode(sycl::detail::queue_impl &Queue,
std::shared_ptr<node_impl> Node);
void setLastInorderNode(sycl::detail::queue_impl &Queue, node_impl &Node);

/// Prints the contents of the graph to a text file in DOT format.
/// @param FilePath Path to the output file.
Expand Down Expand Up @@ -465,15 +463,14 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
/// @param[in] Queue The queue the barrier was recorded from.
/// @param[in] BarrierNodeImpl The created barrier node.
void setBarrierDep(std::weak_ptr<sycl::detail::queue_impl> Queue,
std::shared_ptr<node_impl> BarrierNodeImpl) {
MBarrierDependencyMap[Queue] = BarrierNodeImpl;
node_impl &BarrierNodeImpl) {
MBarrierDependencyMap[Queue] = &BarrierNodeImpl;
}

/// Get the last barrier node that was submitted to the queue.
/// @param[in] Queue The queue to find the last barrier node of. A null
/// pointer is returned if no barrier node has been recorded to the queue.
std::shared_ptr<node_impl>
getBarrierDep(std::weak_ptr<sycl::detail::queue_impl> Queue) {
node_impl *getBarrierDep(std::weak_ptr<sycl::detail::queue_impl> Queue) {
return MBarrierDependencyMap[Queue];
}

Expand Down Expand Up @@ -553,7 +550,7 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
/// Map for every in-order queue that's recorded a node to the graph, what
/// the last node added was. We can use this to create new edges on the last
/// node if any more nodes are added to the graph from the queue.
std::map<std::weak_ptr<sycl::detail::queue_impl>, std::shared_ptr<node_impl>,
std::map<std::weak_ptr<sycl::detail::queue_impl>, node_impl *,
std::owner_less<std::weak_ptr<sycl::detail::queue_impl>>>
MInorderQueueMap;
/// Controls whether we skip the cycle checks in makeEdge, set by the presence
Expand All @@ -568,7 +565,7 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {

/// Mapping from queues to barrier nodes. For each queue the last barrier
/// node recorded to the graph from the queue is stored.
std::map<std::weak_ptr<sycl::detail::queue_impl>, std::shared_ptr<node_impl>,
std::map<std::weak_ptr<sycl::detail::queue_impl>, node_impl *,
std::owner_less<std::weak_ptr<sycl::detail::queue_impl>>>
MBarrierDependencyMap;
/// Graph memory pool for handling graph-owned memory allocations for this
Expand Down Expand Up @@ -886,14 +883,13 @@ class exec_graph_impl {
std::shared_ptr<graph_impl> MGraphImpl;
/// Map of nodes in the exec graph to the sync point representing their
/// execution in the command graph.
std::unordered_map<std::shared_ptr<node_impl>,
ur_exp_command_buffer_sync_point_t>
std::unordered_map<node_impl *, ur_exp_command_buffer_sync_point_t>
MSyncPoints;
/// Sycl queue impl ptr associated with this graph.
std::shared_ptr<sycl::detail::queue_impl> MQueueImpl;
/// Map of nodes in the exec graph to the partition number to which they
/// belong.
std::unordered_map<std::shared_ptr<node_impl>, int> MPartitionNodes;
std::unordered_map<node_impl *, int> MPartitionNodes;
/// Device associated with this executable graph.
sycl::device MDevice;
/// Context associated with this executable graph.
Expand All @@ -909,8 +905,7 @@ class exec_graph_impl {
/// Storage for copies of nodes from the original modifiable graph.
std::vector<std::shared_ptr<node_impl>> MNodeStorage;
/// Map of nodes to their associated UR command handles.
std::unordered_map<std::shared_ptr<node_impl>,
ur_exp_command_buffer_command_handle_t>
std::unordered_map<node_impl *, ur_exp_command_buffer_command_handle_t>
MCommandMap;
/// List of partition without any predecessors in this exec graph.
std::vector<std::weak_ptr<partition>> MRootPartitions;
Expand Down
16 changes: 9 additions & 7 deletions sycl/source/handler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -888,28 +888,30 @@ event handler::finalize() {
// node can set it as a predecessor.
std::vector<std::shared_ptr<ext::oneapi::experimental::detail::node_impl>>
Deps;
if (auto DependentNode = GraphImpl->getLastInorderNode(Queue)) {
Deps.push_back(std::move(DependentNode));
if (ext::oneapi::experimental::detail::node_impl *DependentNode =
GraphImpl->getLastInorderNode(Queue)) {
Deps.push_back(DependentNode->shared_from_this());
}
NodeImpl = GraphImpl->add(NodeType, std::move(CommandGroup), Deps);

// If we are recording an in-order queue remember the new node, so it
// can be used as a dependency for any more nodes recorded from this
// queue.
GraphImpl->setLastInorderNode(*Queue, NodeImpl);
GraphImpl->setLastInorderNode(*Queue, *NodeImpl);
} else {
auto LastBarrierRecordedFromQueue =
GraphImpl->getBarrierDep(Queue->weak_from_this());
ext::oneapi::experimental::detail::node_impl
*LastBarrierRecordedFromQueue =
GraphImpl->getBarrierDep(Queue->weak_from_this());
std::vector<std::shared_ptr<ext::oneapi::experimental::detail::node_impl>>
Deps;

if (LastBarrierRecordedFromQueue) {
Deps.push_back(LastBarrierRecordedFromQueue);
Deps.push_back(LastBarrierRecordedFromQueue->shared_from_this());
}
NodeImpl = GraphImpl->add(NodeType, std::move(CommandGroup), Deps);

if (NodeImpl->MCGType == sycl::detail::CGType::Barrier) {
GraphImpl->setBarrierDep(Queue->weak_from_this(), NodeImpl);
GraphImpl->setBarrierDep(Queue->weak_from_this(), *NodeImpl);
}
}

Expand Down
Loading