Skip to content

Commit f2fa176

Browse files
authored
[UR][SYCL] Add urUSMContextMemcpyExp API to enable device global support. (#17268)
Also adds a path that uses the new API to avoid the workaround introduced in #16565.
1 parent 7b9490b commit f2fa176

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

45 files changed

+821
-14
lines changed

sycl/source/detail/device_global_map_entry.cpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,48 @@ DeviceGlobalMapEntry::getOrAllocateDeviceGlobalUSM(queue_impl &QueueImpl) {
8787
return NewAlloc;
8888
}
8989

90+
// Context-only overload: looks up, or creates and initializes, the USM
// allocation backing this device_global. Entries are keyed on the context
// together with the first device that the context reports. Initialization is
// done through MemoryManager::context_copy_usm, so no queue is involved.
DeviceGlobalUSMMem &
DeviceGlobalMapEntry::getOrAllocateDeviceGlobalUSM(const context &Context) {
  assert(!MIsDeviceImageScopeDecorated &&
         "USM allocations should not be acquired for device_global with "
         "device_image_scope property.");

  const std::shared_ptr<context_impl> &ContextImpl = getSyclObjImpl(Context);
  const std::shared_ptr<device_impl> &DeviceImpl =
      getSyclObjImpl(ContextImpl->getDevices().front());
  auto *const RawContext = ContextImpl.get();
  auto *const RawDevice = DeviceImpl.get();

  // The map is guarded for the whole lookup/insert/initialize sequence.
  std::lock_guard<std::mutex> Guard(MDeviceToUSMPtrMapMutex);

  // Fast path: an entry for this device/context pair is already present.
  if (auto Existing = MDeviceToUSMPtrMap.find({RawDevice, RawContext});
      Existing != MDeviceToUSMPtrMap.end())
    return Existing->second;

  // Slow path: obtain a fresh device USM allocation sized by
  // MDeviceGlobalTSize and record it in the map.
  void *FreshUSMPtr = detail::usm::alignedAllocInternal(
      0, MDeviceGlobalTSize, RawContext, RawDevice, sycl::usm::alloc::device);

  auto Inserted = MDeviceToUSMPtrMap.emplace(
      std::piecewise_construct, std::forward_as_tuple(RawDevice, RawContext),
      std::forward_as_tuple(FreshUSMPtr));
  assert(Inserted.second &&
         "USM allocation for device and context already happened.");
  DeviceGlobalUSMMem &Allocation = Inserted.first->second;

  // Members are laid out in declaration order, and the member holding the
  // device_global's initial contents is declared right after its usm_ptr
  // member. Stepping sizeof(MDeviceGlobalPtr) bytes past MDeviceGlobalPtr
  // therefore yields the address of the initial value, which is
  // zero-initialized when the device_global was constructed without one.
  const uintptr_t InitValueAddr =
      reinterpret_cast<uintptr_t>(MDeviceGlobalPtr) + sizeof(MDeviceGlobalPtr);
  const void *InitValueSrc = reinterpret_cast<const void *>(InitValueAddr);
  MemoryManager::context_copy_usm(InitValueSrc, ContextImpl,
                                  MDeviceGlobalTSize, Allocation.MPtr);

  ContextImpl->addAssociatedDeviceGlobal(MDeviceGlobalPtr);
  return Allocation;
}
131+
90132
void DeviceGlobalMapEntry::removeAssociatedResources(
91133
const context_impl *CtxImpl) {
92134
std::lock_guard<std::mutex> Lock{MDeviceToUSMPtrMapMutex};

sycl/source/detail/device_global_map_entry.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,11 @@ struct DeviceGlobalMapEntry {
111111
// Gets or allocates USM memory for a device_global.
112112
DeviceGlobalUSMMem &getOrAllocateDeviceGlobalUSM(queue_impl &QueueImpl);
113113

114+
// This overload allows the allocation to be initialized without a queue. The
115+
// UR adapter in use must report true for
116+
// UR_DEVICE_INFO_USM_CONTEXT_MEMCPY_SUPPORT_EXP to take advantage of this.
117+
DeviceGlobalUSMMem &getOrAllocateDeviceGlobalUSM(const context &Context);
118+
114119
// Removes resources for device_globals associated with the context.
115120
void removeAssociatedResources(const context_impl *CtxImpl);
116121

sycl/source/detail/kernel_bundle_impl.hpp

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -665,13 +665,21 @@ class kernel_bundle_impl {
665665
"'device_image_scope' property");
666666
}
667667

668-
// TODO: Add context-only initialization via `urUSMContextMemcpyExp` instead
669-
// of using a throw-away queue.
670-
queue InitQueue{MContext, Dev};
671-
auto &USMMem =
672-
Entry->getOrAllocateDeviceGlobalUSM(*getSyclObjImpl(InitQueue));
673-
InitQueue.wait_and_throw();
674-
return USMMem.getPtr();
668+
const auto &DeviceImpl = getSyclObjImpl(Dev);
669+
bool SupportContextMemcpy = false;
670+
DeviceImpl->getAdapter()->call<UrApiKind::urDeviceGetInfo>(
671+
DeviceImpl->getHandleRef(),
672+
UR_DEVICE_INFO_USM_CONTEXT_MEMCPY_SUPPORT_EXP,
673+
sizeof(SupportContextMemcpy), &SupportContextMemcpy, nullptr);
674+
if (SupportContextMemcpy) {
675+
return Entry->getOrAllocateDeviceGlobalUSM(MContext).getPtr();
676+
} else {
677+
queue InitQueue{MContext, Dev};
678+
auto &USMMem =
679+
Entry->getOrAllocateDeviceGlobalUSM(*getSyclObjImpl(InitQueue));
680+
InitQueue.wait_and_throw();
681+
return USMMem.getPtr();
682+
}
675683
}
676684

677685
size_t ext_oneapi_get_device_global_size(const std::string &Name) const {

sycl/source/detail/memory_manager.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -891,6 +891,16 @@ void MemoryManager::copy_usm(const void *SrcMem, queue_impl &SrcQueue,
891891
DepEvents.data(), OutEvent);
892892
}
893893

894+
void MemoryManager::context_copy_usm(const void *SrcMem, ContextImplPtr Context,
895+
size_t Len, void *DstMem) {
896+
if (!SrcMem || !DstMem)
897+
throw exception(make_error_code(errc::invalid),
898+
"NULL pointer argument in memory copy operation.");
899+
const AdapterPtr &Adapter = Context->getAdapter();
900+
Adapter->call<UrApiKind::urUSMContextMemcpyExp>(Context->getHandleRef(),
901+
DstMem, SrcMem, Len);
902+
}
903+
894904
void MemoryManager::fill_usm(void *Mem, queue_impl &Queue, size_t Length,
895905
const std::vector<unsigned char> &Pattern,
896906
std::vector<ur_event_handle_t> DepEvents,

sycl/source/detail/memory_manager.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,9 @@ class MemoryManager {
140140
void *DstMem, std::vector<ur_event_handle_t> DepEvents,
141141
ur_event_handle_t *OutEvent);
142142

143+
static void context_copy_usm(const void *SrcMem, ContextImplPtr Context,
144+
size_t Len, void *DstMem);
145+
143146
static void fill_usm(void *DstMem, queue_impl &Queue, size_t Len,
144147
const std::vector<unsigned char> &Pattern,
145148
std::vector<ur_event_handle_t> DepEvents,

unified-runtime/include/ur_api.h

Lines changed: 54 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

unified-runtime/include/ur_api_funcs.def

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

unified-runtime/include/ur_ddi.h

Lines changed: 6 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

unified-runtime/include/ur_print.h

Lines changed: 10 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

unified-runtime/include/ur_print.hpp

Lines changed: 52 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)