Skip to content

[NFCI][SYCL][Graph] Refactor graph_impl::add #19351

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: sycl
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 34 additions & 52 deletions sycl/source/detail/graph/graph_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -409,22 +409,19 @@ void graph_impl::markCGMemObjs(
}
}

std::shared_ptr<node_impl> graph_impl::add(nodes_range Deps) {
const std::shared_ptr<node_impl> &NodeImpl = std::make_shared<node_impl>();

MNodeStorage.push_back(NodeImpl);
node_impl &graph_impl::add(nodes_range Deps) {
node_impl &NodeImpl = createNode();

addDepsToNode(NodeImpl, Deps);
// Add an event associated with this explicit node for mixed usage
addEventForNode(sycl::detail::event_impl::create_completed_host_event(),
*NodeImpl);
NodeImpl);
return NodeImpl;
}

std::shared_ptr<node_impl>
graph_impl::add(std::function<void(handler &)> CGF,
const std::vector<sycl::detail::ArgDesc> &Args,
std::vector<std::shared_ptr<node_impl>> &Deps) {
node_impl &graph_impl::add(std::function<void(handler &)> CGF,
const std::vector<sycl::detail::ArgDesc> &Args,
nodes_range Deps) {
(void)Args;
#ifdef __INTEL_PREVIEW_BREAKING_CHANGES
detail::handler_impl HandlerImpl{*this};
Expand All @@ -435,7 +432,9 @@ graph_impl::add(std::function<void(handler &)> CGF,

// Pass the node deps to the handler so they are available when processing the
// CGF, need for async_malloc nodes.
Handler.impl->MNodeDeps = Deps;
Handler.impl->MNodeDeps.clear();
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is to preserve the existing semantics. I'm not sure it is truly necessary; maybe there is a guarantee that it must be empty. And if not, why are we dropping the old deps?

for (node_impl &N : Deps)
Handler.impl->MNodeDeps.push_back(N.shared_from_this());

#if XPTI_ENABLE_INSTRUMENTATION
// Save code location if one was set in TLS.
Expand Down Expand Up @@ -471,12 +470,12 @@ graph_impl::add(std::function<void(handler &)> CGF,
: ext::oneapi::experimental::detail::getNodeTypeFromCG(
Handler.getType());

auto NodeImpl =
node_impl &NodeImpl =
this->add(NodeType, std::move(Handler.impl->MGraphNodeCG), Deps);

// Add an event associated with this explicit node for mixed usage
addEventForNode(sycl::detail::event_impl::create_completed_host_event(),
*NodeImpl);
NodeImpl);

// Retrieve any dynamic parameters which have been registered in the CGF and
// register the actual nodes with them.
Expand All @@ -489,44 +488,40 @@ graph_impl::add(std::function<void(handler &)> CGF,
}

for (auto &[DynamicParam, ArgIndex] : DynamicParams) {
DynamicParam->registerNode(NodeImpl, ArgIndex);
DynamicParam->registerNode(NodeImpl.shared_from_this(), ArgIndex);
}

return NodeImpl;
}

std::shared_ptr<node_impl>
graph_impl::add(node_type NodeType,
std::shared_ptr<sycl::detail::CG> CommandGroup,
nodes_range Deps) {
node_impl &graph_impl::add(node_type NodeType,
std::shared_ptr<sycl::detail::CG> CommandGroup,
nodes_range Deps) {

// A unique set of dependencies obtained by checking requirements and events
std::set<std::shared_ptr<node_impl>> UniqueDeps = getCGEdges(CommandGroup);

// Track and mark the memory objects being used by the graph.
markCGMemObjs(CommandGroup);

const std::shared_ptr<node_impl> &NodeImpl =
std::make_shared<node_impl>(NodeType, std::move(CommandGroup));
MNodeStorage.push_back(NodeImpl);
node_impl &NodeImpl = createNode(NodeType, std::move(CommandGroup));

// Add any deps determined from requirements and events into the dependency
// list
addDepsToNode(NodeImpl, Deps);
addDepsToNode(NodeImpl, UniqueDeps);

if (NodeType == node_type::async_free) {
auto AsyncFreeCG =
static_cast<CGAsyncFree *>(NodeImpl->MCommandGroup.get());
auto AsyncFreeCG = static_cast<CGAsyncFree *>(NodeImpl.MCommandGroup.get());
// If this is an async free node mark that it is now available for reuse,
// and pass the async free node for tracking.
MGraphMemPool.markAllocationAsAvailable(AsyncFreeCG->getPtr(), *NodeImpl);
MGraphMemPool.markAllocationAsAvailable(AsyncFreeCG->getPtr(), NodeImpl);
}

return NodeImpl;
}

std::shared_ptr<node_impl>
node_impl &
graph_impl::add(std::shared_ptr<dynamic_command_group_impl> &DynCGImpl,
nodes_range Deps) {
// Set of Dependent nodes based on CG event and accessor dependencies.
Expand All @@ -551,15 +546,14 @@ graph_impl::add(std::shared_ptr<dynamic_command_group_impl> &DynCGImpl,
const auto &ActiveKernel = DynCGImpl->getActiveCG();
node_type NodeType =
ext::oneapi::experimental::detail::getNodeTypeFromCG(DynCGImpl->MCGType);
std::shared_ptr<detail::node_impl> NodeImpl =
add(NodeType, ActiveKernel, Deps);
detail::node_impl &NodeImpl = add(NodeType, ActiveKernel, Deps);

// Add an event associated with this explicit node for mixed usage
addEventForNode(sycl::detail::event_impl::create_completed_host_event(),
*NodeImpl);
NodeImpl);

// Track the dynamic command-group used inside the node object
DynCGImpl->MNodes.push_back(NodeImpl);
DynCGImpl->MNodes.push_back(NodeImpl.shared_from_this());

return NodeImpl;
}
Expand Down Expand Up @@ -652,7 +646,7 @@ void graph_impl::makeEdge(std::shared_ptr<node_impl> Src,
bool DestWasGraphRoot = Dest->MPredecessors.size() == 0;

// We need to add the edges first before checking for cycles
Src->registerSuccessor(Dest);
Src->registerSuccessor(*Dest);

bool DestLostRootStatus = DestWasGraphRoot && Dest->MPredecessors.size() == 1;
if (DestLostRootStatus) {
Expand Down Expand Up @@ -1265,7 +1259,7 @@ void exec_graph_impl::duplicateNodes() {
// Look through all the original node successors, find their copies and
// register those as successors with the current copied node
for (node_impl &NextNode : OriginalNode->successors()) {
auto Successor = NodesMap.at(NextNode.shared_from_this());
node_impl &Successor = *NodesMap.at(NextNode.shared_from_this());
NodeCopy->registerSuccessor(Successor);
}
}
Expand Down Expand Up @@ -1307,7 +1301,8 @@ void exec_graph_impl::duplicateNodes() {
auto NodeCopy = NewSubgraphNodes[i];

for (node_impl &NextNode : SubgraphNode->successors()) {
auto Successor = SubgraphNodesMap.at(NextNode.shared_from_this());
node_impl &Successor =
*SubgraphNodesMap.at(NextNode.shared_from_this());
NodeCopy->registerSuccessor(Successor);
}
}
Expand Down Expand Up @@ -1341,7 +1336,7 @@ void exec_graph_impl::duplicateNodes() {
// Add all input nodes from the subgraph as successors for this node
// instead
for (auto &Input : Inputs) {
PredNode.registerSuccessor(Input);
PredNode.registerSuccessor(*Input);
}
}

Expand All @@ -1360,7 +1355,7 @@ void exec_graph_impl::duplicateNodes() {
// Add all Output nodes from the subgraph as predecessors for this node
// instead
for (auto &Output : Outputs) {
Output->registerSuccessor(SuccNode.shared_from_this());
Output->registerSuccessor(SuccNode);
}
}

Expand Down Expand Up @@ -1843,38 +1838,25 @@ node modifiable_command_graph::addImpl(dynamic_command_group &DynCGF,
"dynamic command-group.");
}

std::vector<std::shared_ptr<detail::node_impl>> DepImpls;
for (auto &D : Deps) {
DepImpls.push_back(sycl::detail::getSyclObjImpl(D));
}

graph_impl::WriteLock Lock(impl->MMutex);
std::shared_ptr<detail::node_impl> NodeImpl = impl->add(DynCGFImpl, DepImpls);
return sycl::detail::createSyclObjFromImpl<node>(std::move(NodeImpl));
detail::node_impl &NodeImpl = impl->add(DynCGFImpl, Deps);
return sycl::detail::createSyclObjFromImpl<node>(NodeImpl);
}

node modifiable_command_graph::addImpl(const std::vector<node> &Deps) {
impl->throwIfGraphRecordingQueue("Explicit API \"Add()\" function");
std::vector<std::shared_ptr<detail::node_impl>> DepImpls;
for (auto &D : Deps) {
DepImpls.push_back(sycl::detail::getSyclObjImpl(D));
}

graph_impl::WriteLock Lock(impl->MMutex);
std::shared_ptr<detail::node_impl> NodeImpl = impl->add(DepImpls);
return sycl::detail::createSyclObjFromImpl<node>(std::move(NodeImpl));
detail::node_impl &NodeImpl = impl->add(Deps);
return sycl::detail::createSyclObjFromImpl<node>(NodeImpl);
}

node modifiable_command_graph::addImpl(std::function<void(handler &)> CGF,
const std::vector<node> &Deps) {
impl->throwIfGraphRecordingQueue("Explicit API \"Add()\" function");
std::vector<std::shared_ptr<detail::node_impl>> DepImpls;
for (auto &D : Deps) {
DepImpls.push_back(sycl::detail::getSyclObjImpl(D));
}

std::shared_ptr<detail::node_impl> NodeImpl = impl->add(CGF, {}, DepImpls);
return sycl::detail::createSyclObjFromImpl<node>(std::move(NodeImpl));
detail::node_impl &NodeImpl = impl->add(CGF, {}, Deps);
return sycl::detail::createSyclObjFromImpl<node>(NodeImpl);
}

void modifiable_command_graph::addGraphLeafDependencies(node Node) {
Expand Down
32 changes: 19 additions & 13 deletions sycl/source/detail/graph/graph_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,30 +147,30 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
/// @param CommandGroup The CG which stores all information for this node.
/// @param Deps Dependencies of the created node.
/// @return Created node in the graph.
std::shared_ptr<node_impl> add(node_type NodeType,
std::shared_ptr<sycl::detail::CG> CommandGroup,
nodes_range Deps);
node_impl &add(node_type NodeType,
std::shared_ptr<sycl::detail::CG> CommandGroup,
nodes_range Deps);

/// Create a CGF node in the graph.
/// @param CGF Command-group function to create node with.
/// @param Args Node arguments.
/// @param Deps Dependencies of the created node.
/// @return Created node in the graph.
std::shared_ptr<node_impl> add(std::function<void(handler &)> CGF,
const std::vector<sycl::detail::ArgDesc> &Args,
std::vector<std::shared_ptr<node_impl>> &Deps);
node_impl &add(std::function<void(handler &)> CGF,
const std::vector<sycl::detail::ArgDesc> &Args,
nodes_range Deps);

/// Create an empty node in the graph.
/// @param Deps List of predecessor nodes.
/// @return Created node in the graph.
std::shared_ptr<node_impl> add(nodes_range Deps);
node_impl &add(nodes_range Deps);

/// Create a dynamic command-group node in the graph.
/// @param DynCGImpl Dynamic command-group used to create node.
/// @param Deps List of predecessor nodes.
/// @return Created node in the graph.
std::shared_ptr<node_impl>
add(std::shared_ptr<dynamic_command_group_impl> &DynCGImpl, nodes_range Deps);
node_impl &add(std::shared_ptr<dynamic_command_group_impl> &DynCGImpl,
nodes_range Deps);

/// Add a queue to the set of queues which are currently recording to this
/// graph.
Expand Down Expand Up @@ -511,6 +511,12 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
}

private:
template <typename... Ts> node_impl &createNode(Ts &&...Args) {
MNodeStorage.push_back(
std::make_shared<node_impl>(std::forward<Ts>(Args)...));
return *MNodeStorage.back();
Comment on lines +515 to +517
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure whether node_impl creation is under a mutex. If races are possible, then we might need to change this to:

   auto Ptr = make_shared();
   node_impl &Res = *Ptr;
   MNodeStorage.push_back(std::move(Ptr));
   return Res;

If that's the case, then the nodes_range over std::vector<node> optimization needs to be examined for race conditions as well.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, node_impl creation should always be under a mutex. I found a potential race and also a gap in our e2e tests while investigating this: #19379

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

> Yes node_impl creation should always be under a mutex. Found a potential race and also gap in our e2e tests while investigating this: #19379

@reble, are you trying to investigate this more (or maybe something else) before continuing with the review of this PR?

}

/// Check the graph for cycles by performing a depth-first search of the
/// graph. If a node is visited more than once in a given path through the
/// graph, a cycle is present and the search ends immediately.
Expand All @@ -525,13 +531,13 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
/// added as a root node.
/// @param Node The node to add deps for
/// @param Deps List of dependent nodes
void addDepsToNode(const std::shared_ptr<node_impl> &Node, nodes_range Deps) {
void addDepsToNode(node_impl &Node, nodes_range Deps) {
for (node_impl &N : Deps) {
N.registerSuccessor(Node);
this->removeRoot(*Node);
this->removeRoot(Node);
}
if (Node->MPredecessors.empty()) {
this->addRoot(*Node);
if (Node.MPredecessors.empty()) {
this->addRoot(Node);
}
}

Expand Down
34 changes: 17 additions & 17 deletions sycl/source/detail/graph/node_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
#include <sycl/detail/cg_types.hpp> // for CGType
#include <sycl/detail/kernel_desc.hpp> // for kernel_param_kind_t

#include <sycl/ext/oneapi/experimental/graph/node.hpp> // for node

#include <cstring>
#include <fstream>
#include <iomanip>
Expand All @@ -26,8 +28,6 @@ inline namespace _V1 {
namespace ext {
namespace oneapi {
namespace experimental {
// Forward declarations
class node;

namespace detail {
// Forward declarations
Expand Down Expand Up @@ -121,27 +121,27 @@ class node_impl : public std::enable_shared_from_this<node_impl> {

/// Add successor to the node.
/// @param Node Node to add as a successor.
void registerSuccessor(const std::shared_ptr<node_impl> &Node) {
void registerSuccessor(node_impl &Node) {
if (std::find_if(MSuccessors.begin(), MSuccessors.end(),
[Node](const std::weak_ptr<node_impl> &Ptr) {
return Ptr.lock() == Node;
[&Node](const std::weak_ptr<node_impl> &Ptr) {
return Ptr.lock().get() == &Node;
}) != MSuccessors.end()) {
return;
}
MSuccessors.push_back(Node);
Node->registerPredecessor(shared_from_this());
MSuccessors.push_back(Node.weak_from_this());
Node.registerPredecessor(*this);
}

/// Add predecessor to the node.
/// @param Node Node to add as a predecessor.
void registerPredecessor(const std::shared_ptr<node_impl> &Node) {
void registerPredecessor(node_impl &Node) {
if (std::find_if(MPredecessors.begin(), MPredecessors.end(),
[&Node](const std::weak_ptr<node_impl> &Ptr) {
return Ptr.lock() == Node;
return Ptr.lock().get() == &Node;
}) != MPredecessors.end()) {
return;
}
MPredecessors.push_back(Node);
MPredecessors.push_back(Node.weak_from_this());
}

/// Construct an empty node.
Expand Down Expand Up @@ -777,7 +777,7 @@ class nodes_range {
//
std::set<std::shared_ptr<node_impl>>, std::set<node_impl *>,
//
std::list<node_impl *>>;
std::list<node_impl *>, std::vector<node>>;

storage_iter Begin;
storage_iter End;
Expand All @@ -786,10 +786,8 @@ class nodes_range {
public:
nodes_range(const nodes_range &Other) = default;

template <
typename ContainerTy,
typename = std::enable_if_t<!std::is_same_v<nodes_range, ContainerTy>>>
nodes_range(ContainerTy &Container)
template <typename ContainerTy>
nodes_range(const ContainerTy &Container)
: Begin{Container.begin()}, End{Container.end()}, Size{Container.size()} {
}

Expand All @@ -815,12 +813,14 @@ class nodes_range {
return std::visit(
[](auto &&It) -> node_impl & {
auto &Elem = *It;
if constexpr (std::is_same_v<std::decay_t<decltype(Elem)>,
std::weak_ptr<node_impl>>) {
using Ty = std::decay_t<decltype(Elem)>;
if constexpr (std::is_same_v<Ty, std::weak_ptr<node_impl>>) {
// This assumes that weak_ptr doesn't actually manage lifetime and
// the object is guaranteed to be alive (which seems to be the
// assumption across all graph code).
return *Elem.lock();
} else if constexpr (std::is_same_v<Ty, node>) {
return *getSyclObjImpl(Elem);
} else {
return *Elem;
}
Expand Down
Loading
Loading