diff --git a/sycl/source/detail/graph/graph_impl.cpp b/sycl/source/detail/graph/graph_impl.cpp index 5928cec31f784..580607ba3b73d 100644 --- a/sycl/source/detail/graph/graph_impl.cpp +++ b/sycl/source/detail/graph/graph_impl.cpp @@ -409,22 +409,19 @@ void graph_impl::markCGMemObjs( } } -std::shared_ptr graph_impl::add(nodes_range Deps) { - const std::shared_ptr &NodeImpl = std::make_shared(); - - MNodeStorage.push_back(NodeImpl); +node_impl &graph_impl::add(nodes_range Deps) { + node_impl &NodeImpl = createNode(); addDepsToNode(NodeImpl, Deps); // Add an event associated with this explicit node for mixed usage addEventForNode(sycl::detail::event_impl::create_completed_host_event(), - *NodeImpl); + NodeImpl); return NodeImpl; } -std::shared_ptr -graph_impl::add(std::function CGF, - const std::vector &Args, - std::vector> &Deps) { +node_impl &graph_impl::add(std::function CGF, + const std::vector &Args, + nodes_range Deps) { (void)Args; #ifdef __INTEL_PREVIEW_BREAKING_CHANGES detail::handler_impl HandlerImpl{*this}; @@ -435,7 +432,8 @@ graph_impl::add(std::function CGF, // Pass the node deps to the handler so they are available when processing the // CGF, need for async_malloc nodes. - Handler.impl->MNodeDeps = Deps; + for (node_impl &N : Deps) + Handler.impl->MNodeDeps.push_back(N.shared_from_this()); #if XPTI_ENABLE_INSTRUMENTATION // Save code location if one was set in TLS. @@ -471,12 +469,12 @@ graph_impl::add(std::function CGF, : ext::oneapi::experimental::detail::getNodeTypeFromCG( Handler.getType()); - auto NodeImpl = + node_impl &NodeImpl = this->add(NodeType, std::move(Handler.impl->MGraphNodeCG), Deps); // Add an event associated with this explicit node for mixed usage addEventForNode(sycl::detail::event_impl::create_completed_host_event(), - *NodeImpl); + NodeImpl); // Retrieve any dynamic parameters which have been registered in the CGF and // register the actual nodes with them. @@ -489,16 +487,15 @@ graph_impl::add(std::function CGF, } for (auto &[DynamicParam, ArgIndex] : DynamicParams) { - DynamicParam->registerNode(NodeImpl, ArgIndex); + DynamicParam->registerNode(NodeImpl.shared_from_this(), ArgIndex); } return NodeImpl; } -std::shared_ptr -graph_impl::add(node_type NodeType, - std::shared_ptr CommandGroup, - nodes_range Deps) { +node_impl &graph_impl::add(node_type NodeType, + std::shared_ptr CommandGroup, + nodes_range Deps) { // A unique set of dependencies obtained by checking requirements and events std::set UniqueDeps = getCGEdges(CommandGroup); @@ -506,9 +503,7 @@ graph_impl::add(node_type NodeType, // Track and mark the memory objects being used by the graph. markCGMemObjs(CommandGroup); - const std::shared_ptr &NodeImpl = - std::make_shared(NodeType, std::move(CommandGroup)); - MNodeStorage.push_back(NodeImpl); + node_impl &NodeImpl = createNode(NodeType, std::move(CommandGroup)); // Add any deps determined from requirements and events into the dependency // list @@ -516,17 +511,16 @@ graph_impl::add(node_type NodeType, addDepsToNode(NodeImpl, UniqueDeps); if (NodeType == node_type::async_free) { - auto AsyncFreeCG = - static_cast(NodeImpl->MCommandGroup.get()); + auto AsyncFreeCG = static_cast(NodeImpl.MCommandGroup.get()); // If this is an async free node mark that it is now available for reuse, // and pass the async free node for tracking. - MGraphMemPool.markAllocationAsAvailable(AsyncFreeCG->getPtr(), *NodeImpl); + MGraphMemPool.markAllocationAsAvailable(AsyncFreeCG->getPtr(), NodeImpl); } return NodeImpl; } -std::shared_ptr +node_impl & graph_impl::add(std::shared_ptr &DynCGImpl, nodes_range Deps) { // Set of Dependent nodes based on CG event and accessor dependencies. @@ -550,15 +544,14 @@ graph_impl::add(std::shared_ptr &DynCGImpl, const auto &ActiveKernel = DynCGImpl->getActiveCG(); node_type NodeType = ext::oneapi::experimental::detail::getNodeTypeFromCG(DynCGImpl->MCGType); - std::shared_ptr NodeImpl = - add(NodeType, ActiveKernel, Deps); + detail::node_impl &NodeImpl = add(NodeType, ActiveKernel, Deps); // Add an event associated with this explicit node for mixed usage addEventForNode(sycl::detail::event_impl::create_completed_host_event(), - *NodeImpl); + NodeImpl); // Track the dynamic command-group used inside the node object - DynCGImpl->MNodes.push_back(NodeImpl); + DynCGImpl->MNodes.push_back(NodeImpl.shared_from_this()); return NodeImpl; } @@ -651,7 +644,7 @@ void graph_impl::makeEdge(std::shared_ptr Src, bool DestWasGraphRoot = Dest->MPredecessors.size() == 0; // We need to add the edges first before checking for cycles - Src->registerSuccessor(Dest); + Src->registerSuccessor(*Dest); bool DestLostRootStatus = DestWasGraphRoot && Dest->MPredecessors.size() == 1; if (DestLostRootStatus) { @@ -1264,7 +1257,7 @@ void exec_graph_impl::duplicateNodes() { // Look through all the original node successors, find their copies and // register those as successors with the current copied node for (node_impl &NextNode : OriginalNode->successors()) { - auto Successor = NodesMap.at(NextNode.shared_from_this()); + node_impl &Successor = *NodesMap.at(NextNode.shared_from_this()); NodeCopy->registerSuccessor(Successor); } } @@ -1306,7 +1299,8 @@ void exec_graph_impl::duplicateNodes() { auto NodeCopy = NewSubgraphNodes[i]; for (node_impl &NextNode : SubgraphNode->successors()) { - auto Successor = SubgraphNodesMap.at(NextNode.shared_from_this()); + node_impl &Successor = + *SubgraphNodesMap.at(NextNode.shared_from_this()); NodeCopy->registerSuccessor(Successor); } } @@ -1340,7 +1334,7 @@ void exec_graph_impl::duplicateNodes() { // Add all input nodes from the subgraph as successors for this node // instead for (auto &Input : Inputs) { - PredNode.registerSuccessor(Input); + PredNode.registerSuccessor(*Input); } } @@ -1359,7 +1353,7 @@ void exec_graph_impl::duplicateNodes() { // Add all Output nodes from the subgraph as predecessors for this node // instead for (auto &Output : Outputs) { - Output->registerSuccessor(SuccNode.shared_from_this()); + Output->registerSuccessor(SuccNode); } } @@ -1840,38 +1834,25 @@ node modifiable_command_graph::addImpl(dynamic_command_group &DynCGF, "dynamic command-group."); } - std::vector> DepImpls; - for (auto &D : Deps) { - DepImpls.push_back(sycl::detail::getSyclObjImpl(D)); - } - graph_impl::WriteLock Lock(impl->MMutex); - std::shared_ptr NodeImpl = impl->add(DynCGFImpl, DepImpls); - return sycl::detail::createSyclObjFromImpl(std::move(NodeImpl)); + detail::node_impl &NodeImpl = impl->add(DynCGFImpl, Deps); + return sycl::detail::createSyclObjFromImpl(NodeImpl); } node modifiable_command_graph::addImpl(const std::vector &Deps) { impl->throwIfGraphRecordingQueue("Explicit API \"Add()\" function"); - std::vector> DepImpls; - for (auto &D : Deps) { - DepImpls.push_back(sycl::detail::getSyclObjImpl(D)); - } graph_impl::WriteLock Lock(impl->MMutex); - std::shared_ptr NodeImpl = impl->add(DepImpls); - return sycl::detail::createSyclObjFromImpl(std::move(NodeImpl)); + detail::node_impl &NodeImpl = impl->add(Deps); + return sycl::detail::createSyclObjFromImpl(NodeImpl); } node modifiable_command_graph::addImpl(std::function CGF, const std::vector &Deps) { impl->throwIfGraphRecordingQueue("Explicit API \"Add()\" function"); - std::vector> DepImpls; - for (auto &D : Deps) { - DepImpls.push_back(sycl::detail::getSyclObjImpl(D)); - } - std::shared_ptr NodeImpl = impl->add(CGF, {}, DepImpls); - return sycl::detail::createSyclObjFromImpl(std::move(NodeImpl)); + detail::node_impl &NodeImpl = impl->add(CGF, {}, Deps); + return sycl::detail::createSyclObjFromImpl(NodeImpl); } void modifiable_command_graph::addGraphLeafDependencies(node Node) { diff --git a/sycl/source/detail/graph/graph_impl.hpp b/sycl/source/detail/graph/graph_impl.hpp index 85a35a93cb2c2..69e8835001652 100644 --- a/sycl/source/detail/graph/graph_impl.hpp +++ b/sycl/source/detail/graph/graph_impl.hpp @@ -147,30 +147,30 @@ class graph_impl : public std::enable_shared_from_this { /// @param CommandGroup The CG which stores all information for this node. /// @param Deps Dependencies of the created node. /// @return Created node in the graph. - std::shared_ptr add(node_type NodeType, - std::shared_ptr CommandGroup, - nodes_range Deps); + node_impl &add(node_type NodeType, + std::shared_ptr CommandGroup, + nodes_range Deps); /// Create a CGF node in the graph. /// @param CGF Command-group function to create node with. /// @param Args Node arguments. /// @param Deps Dependencies of the created node. /// @return Created node in the graph. - std::shared_ptr add(std::function CGF, - const std::vector &Args, - std::vector> &Deps); + node_impl &add(std::function CGF, + const std::vector &Args, + nodes_range Deps); /// Create an empty node in the graph. /// @param Deps List of predecessor nodes. /// @return Created node in the graph. - std::shared_ptr add(nodes_range Deps); + node_impl &add(nodes_range Deps); /// Create a dynamic command-group node in the graph. /// @param DynCGImpl Dynamic command-group used to create node. /// @param Deps List of predecessor nodes. /// @return Created node in the graph. - std::shared_ptr - add(std::shared_ptr &DynCGImpl, nodes_range Deps); + node_impl &add(std::shared_ptr &DynCGImpl, + nodes_range Deps); /// Add a queue to the set of queues which are currently recording to this /// graph. @@ -511,6 +511,12 @@ class graph_impl : public std::enable_shared_from_this { } private: + template node_impl &createNode(Ts &&...Args) { + MNodeStorage.push_back( + std::make_shared(std::forward(Args)...)); + return *MNodeStorage.back(); + } + /// Check the graph for cycles by performing a depth-first search of the /// graph. If a node is visited more than once in a given path through the /// graph, a cycle is present and the search ends immediately. @@ -525,13 +531,13 @@ class graph_impl : public std::enable_shared_from_this { /// added as a root node. /// @param Node The node to add deps for /// @param Deps List of dependent nodes - void addDepsToNode(const std::shared_ptr &Node, nodes_range Deps) { + void addDepsToNode(node_impl &Node, nodes_range Deps) { for (node_impl &N : Deps) { N.registerSuccessor(Node); - this->removeRoot(*Node); + this->removeRoot(Node); } - if (Node->MPredecessors.empty()) { - this->addRoot(*Node); + if (Node.MPredecessors.empty()) { + this->addRoot(Node); } } diff --git a/sycl/source/detail/graph/node_impl.hpp b/sycl/source/detail/graph/node_impl.hpp index ea5482db3c60b..aed90d3f04906 100644 --- a/sycl/source/detail/graph/node_impl.hpp +++ b/sycl/source/detail/graph/node_impl.hpp @@ -15,6 +15,8 @@ #include // for CGType #include // for kernel_param_kind_t +#include // for node + #include #include #include @@ -27,8 +29,6 @@ inline namespace _V1 { namespace ext { namespace oneapi { namespace experimental { -// Forward declarations -class node; namespace detail { // Forward declarations @@ -122,27 +122,27 @@ class node_impl : public std::enable_shared_from_this { /// Add successor to the node. /// @param Node Node to add as a successor. - void registerSuccessor(const std::shared_ptr &Node) { + void registerSuccessor(node_impl &Node) { if (std::find_if(MSuccessors.begin(), MSuccessors.end(), - [Node](const std::weak_ptr &Ptr) { - return Ptr.lock() == Node; + [&Node](const std::weak_ptr &Ptr) { + return Ptr.lock().get() == &Node; }) != MSuccessors.end()) { return; } - MSuccessors.push_back(Node); - Node->registerPredecessor(shared_from_this()); + MSuccessors.push_back(Node.weak_from_this()); + Node.registerPredecessor(*this); } /// Add predecessor to the node. /// @param Node Node to add as a predecessor. - void registerPredecessor(const std::shared_ptr &Node) { + void registerPredecessor(node_impl &Node) { if (std::find_if(MPredecessors.begin(), MPredecessors.end(), [&Node](const std::weak_ptr &Ptr) { - return Ptr.lock() == Node; + return Ptr.lock().get() == &Node; }) != MPredecessors.end()) { return; } - MPredecessors.push_back(Node); + MPredecessors.push_back(Node.weak_from_this()); } /// Construct an empty node. @@ -764,12 +764,14 @@ class node_impl : public std::enable_shared_from_this { struct nodes_deref_impl { template static node_impl &dereference(T &Elem) { - if constexpr (std::is_same_v, - std::weak_ptr>) { + using Ty = std::decay_t; + if constexpr (std::is_same_v>) { // This assumes that weak_ptr doesn't actually manage lifetime and // the object is guaranteed to be alive (which seems to be the // assumption across all graph code). return *Elem.lock(); + } else if constexpr (std::is_same_v) { + return *getSyclObjImpl(Elem); } else { return *Elem; } @@ -791,7 +793,7 @@ using nodes_iterator = nodes_iterator_impl< // std::set>, std::set, // - std::list>; + std::list, std::vector>; class nodes_range : public iterator_range { private: diff --git a/sycl/source/handler.cpp b/sycl/source/handler.cpp index 1e729b2100eb6..af7390c6ca7ec 100644 --- a/sycl/source/handler.cpp +++ b/sycl/source/handler.cpp @@ -886,13 +886,13 @@ event handler::finalize() { // In-order queues create implicit linear dependencies between nodes. // Find the last node added to the graph from this queue, so our new // node can set it as a predecessor. - std::vector> - Deps; + std::vector Deps; if (ext::oneapi::experimental::detail::node_impl *DependentNode = GraphImpl->getLastInorderNode(Queue)) { - Deps.push_back(DependentNode->shared_from_this()); + Deps.push_back(DependentNode); } - NodeImpl = GraphImpl->add(NodeType, std::move(CommandGroup), Deps); + NodeImpl = GraphImpl->add(NodeType, std::move(CommandGroup), Deps) + .shared_from_this(); // If we are recording an in-order queue remember the new node, so it // can be used as a dependency for any more nodes recorded from this @@ -902,13 +902,13 @@ event handler::finalize() { ext::oneapi::experimental::detail::node_impl *LastBarrierRecordedFromQueue = GraphImpl->getBarrierDep(Queue->weak_from_this()); - std::vector> - Deps; + std::vector Deps; if (LastBarrierRecordedFromQueue) { - Deps.push_back(LastBarrierRecordedFromQueue->shared_from_this()); + Deps.push_back(LastBarrierRecordedFromQueue); } - NodeImpl = GraphImpl->add(NodeType, std::move(CommandGroup), Deps); + NodeImpl = GraphImpl->add(NodeType, std::move(CommandGroup), Deps) + .shared_from_this(); if (NodeImpl->MCGType == sycl::detail::CGType::Barrier) { GraphImpl->setBarrierDep(Queue->weak_from_this(), *NodeImpl);