Skip to content

[UR][OpenCL][CUDA][HIP][L0][L0v2][Offload] Refactor reference counting in UR across all adapters into a new common UR_ReferenceCounter class #18823

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 5 commits into
base: sycl
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions unified-runtime/source/adapters/cuda/adapter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ urAdapterGet(uint32_t NumEntries, ur_adapter_handle_t *phAdapters,
std::call_once(InitFlag,
[=]() { ur::cuda::adapter = new ur_adapter_handle_t_; });

ur::cuda::adapter->RefCount++;
ur::cuda::adapter->getRefCounter().increment();
*phAdapters = ur::cuda::adapter;
}

Expand All @@ -78,13 +78,13 @@ urAdapterGet(uint32_t NumEntries, ur_adapter_handle_t *phAdapters,
}

/// Takes an additional reference on the global CUDA adapter
/// (ur::cuda::adapter). The handle parameter is unnamed and not consulted;
/// the global instance is always the one that is retained.
///
/// \return UR_RESULT_SUCCESS unconditionally.
UR_APIEXPORT ur_result_t UR_APICALL urAdapterRetain(ur_adapter_handle_t) {
ur::cuda::adapter->RefCount++;
ur::cuda::adapter->getRefCounter().increment();

return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL urAdapterRelease(ur_adapter_handle_t) {
if (--ur::cuda::adapter->RefCount == 0) {
if (ur::cuda::adapter->getRefCounter().decrement() == 0) {
delete ur::cuda::adapter;
}
return UR_RESULT_SUCCESS;
Expand All @@ -108,7 +108,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGetInfo(ur_adapter_handle_t,
case UR_ADAPTER_INFO_BACKEND:
return ReturnValue(UR_BACKEND_CUDA);
case UR_ADAPTER_INFO_REFERENCE_COUNT:
return ReturnValue(ur::cuda::adapter->RefCount.load());
return ReturnValue(ur::cuda::adapter->getRefCounter().getCount());
case UR_ADAPTER_INFO_VERSION:
return ReturnValue(uint32_t{1});
default:
Expand Down
10 changes: 7 additions & 3 deletions unified-runtime/source/adapters/cuda/adapter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,29 @@
#ifndef UR_CUDA_ADAPTER_HPP_INCLUDED
#define UR_CUDA_ADAPTER_HPP_INCLUDED

#include "common/ur_ref_counter.hpp"
#include "logger/ur_logger.hpp"
#include "platform.hpp"
#include "tracing.hpp"
#include <ur_api.h>

#include <atomic>
#include <memory>

/// Adapter handle for the CUDA backend. Holds the tracing context, a logger
/// reference, the owned platform object and the adapter's reference count.
struct ur_adapter_handle_t_ : ur::cuda::handle_base {
// NOTE(review): this counter is initialised to 0. Confirm that a
// default-constructed UR_ReferenceCounter starts from the value that
// urAdapterGet/urAdapterRelease expect — TODO confirm.
std::atomic<uint32_t> RefCount = 0;
struct cuda_tracing_context_t_ *TracingCtx = nullptr;
logger::Logger &logger;
// The adapter owns its platform object through this unique_ptr.
std::unique_ptr<ur_platform_handle_t_> Platform;
ur_adapter_handle_t_();
~ur_adapter_handle_t_();
// Copy construction is disabled.
ur_adapter_handle_t_(const ur_adapter_handle_t_ &) = delete;

// Gives callers mutable access to the adapter's reference counter.
UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; }

private:
// Reference count state; reachable only through getRefCounter().
UR_ReferenceCounter RefCounter;
};

// Keep the global namespace'd
// Keep the global adapter handle namespaced
namespace ur::cuda {
extern ur_adapter_handle_t adapter;
} // namespace ur::cuda
Expand Down
8 changes: 4 additions & 4 deletions unified-runtime/source/adapters/cuda/command_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ ur_exp_command_buffer_handle_t_::ur_exp_command_buffer_handle_t_(
bool IsInOrder)
: handle_base(), Context(Context), Device(Device), IsUpdatable(IsUpdatable),
IsInOrder(IsInOrder), CudaGraph{nullptr}, CudaGraphExec{nullptr},
RefCount{1}, NextSyncPoint{0} {
NextSyncPoint{0} {
urContextRetain(Context);
}

Expand Down Expand Up @@ -380,13 +380,13 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferCreateExp(

/// Takes an additional reference on the given command-buffer.
///
/// \param hCommandBuffer Command-buffer whose reference count is incremented.
///        It is dereferenced without a null check.
/// \return UR_RESULT_SUCCESS unconditionally.
UR_APIEXPORT ur_result_t UR_APICALL
urCommandBufferRetainExp(ur_exp_command_buffer_handle_t hCommandBuffer) {
hCommandBuffer->incrementReferenceCount();
hCommandBuffer->getRefCounter().increment();
return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL
urCommandBufferReleaseExp(ur_exp_command_buffer_handle_t hCommandBuffer) {
if (hCommandBuffer->decrementReferenceCount() == 0) {
if (hCommandBuffer->getRefCounter().decrement() == 0) {
// Ref count has reached zero, release the created commands
for (auto &Command : hCommandBuffer->CommandHandles) {
commandHandleDestroy(Command);
Expand Down Expand Up @@ -1476,7 +1476,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferGetInfoExp(

switch (propName) {
case UR_EXP_COMMAND_BUFFER_INFO_REFERENCE_COUNT:
return ReturnValue(hCommandBuffer->getReferenceCount());
return ReturnValue(hCommandBuffer->getRefCounter().getCount());
case UR_EXP_COMMAND_BUFFER_INFO_DESCRIPTOR: {
ur_exp_command_buffer_desc_t Descriptor{};
Descriptor.stype = UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_DESC;
Expand Down
10 changes: 6 additions & 4 deletions unified-runtime/source/adapters/cuda/command_buffer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
#include <ur_api.h>
#include <ur_print.hpp>

#include "common/ur_ref_counter.hpp"
#include "context.hpp"
#include "logger/ur_logger.hpp"

#include <cuda.h>
#include <memory>
#include <unordered_set>
Expand Down Expand Up @@ -173,9 +175,7 @@ struct ur_exp_command_buffer_handle_t_ : ur::cuda::handle_base {
return SyncPoint;
}

uint32_t incrementReferenceCount() noexcept { return ++RefCount; }
uint32_t decrementReferenceCount() noexcept { return --RefCount; }
uint32_t getReferenceCount() const noexcept { return RefCount; }
UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; }

// UR context associated with this command-buffer
ur_context_handle_t Context;
Expand All @@ -191,7 +191,6 @@ struct ur_exp_command_buffer_handle_t_ : ur::cuda::handle_base {
CUgraphExec CudaGraphExec = nullptr;
// Atomic variable counting the number of references to this command_buffer;
// using std::atomic prevents data races when incrementing/decrementing.
std::atomic_uint32_t RefCount;

// Ordered map of sync_points to ur_events, so that we can find the last
// node added to an in-order command-buffer.
Expand All @@ -203,4 +202,7 @@ struct ur_exp_command_buffer_handle_t_ : ur::cuda::handle_base {
// Handles to individual commands in the command-buffer
std::vector<std::unique_ptr<ur_exp_command_buffer_command_handle_t_>>
CommandHandles;

private:
UR_ReferenceCounter RefCounter;
};
8 changes: 4 additions & 4 deletions unified-runtime/source/adapters/cuda/context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urContextGetInfo(
return ReturnValue(hContext->getDevices().data(),
hContext->getDevices().size());
case UR_CONTEXT_INFO_REFERENCE_COUNT:
return ReturnValue(hContext->getReferenceCount());
return ReturnValue(hContext->getRefCounter().getCount());
case UR_CONTEXT_INFO_USM_MEMCPY2D_SUPPORT:
// 2D USM memcpy is supported.
return ReturnValue(true);
Expand All @@ -83,7 +83,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urContextGetInfo(

/// Releases one reference to \p hContext. While other references remain the
/// call returns UR_RESULT_SUCCESS immediately; once the count drops to zero
/// execution falls through to the extended deleters and the teardown below.
UR_APIEXPORT ur_result_t UR_APICALL
urContextRelease(ur_context_handle_t hContext) {
// The release must actually drop the reference: use decrement(), not
// getCount(). Merely reading the count would leave it above zero forever,
// so the context would never be torn down.
if (hContext->decrementReferenceCount() > 0) {
if (hContext->getRefCounter().decrement() > 0) {
return UR_RESULT_SUCCESS;
}
hContext->invokeExtendedDeleters();
Expand All @@ -94,9 +94,9 @@ urContextRelease(ur_context_handle_t hContext) {

/// Takes an additional reference on the given context.
///
/// \param hContext Context whose reference count is incremented. The asserts
///        check that the count is still above zero before the increment.
/// \return UR_RESULT_SUCCESS unconditionally.
UR_APIEXPORT ur_result_t UR_APICALL
urContextRetain(ur_context_handle_t hContext) {
assert(hContext->getReferenceCount() > 0);
assert(hContext->getRefCounter().getCount() > 0);

hContext->incrementReferenceCount();
hContext->getRefCounter().increment();
return UR_RESULT_SUCCESS;
}

Expand Down
12 changes: 4 additions & 8 deletions unified-runtime/source/adapters/cuda/context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@
#include <memory>
#include <ur_api.h>

#include <atomic>
#include <mutex>
#include <set>
#include <vector>

#include "adapter.hpp"
#include "common.hpp"
#include "common/ur_ref_counter.hpp"
#include "device.hpp"
#include "umf_helpers.hpp"

Expand Down Expand Up @@ -88,15 +88,14 @@ struct ur_context_handle_t_ : ur::cuda::handle_base {
};

std::vector<ur_device_handle_t> Devices;
std::atomic_uint32_t RefCount;

// UMF CUDA memory provider and pool for the host memory
// (UMF_MEMORY_TYPE_HOST)
umf_memory_provider_handle_t MemoryProviderHost = nullptr;
umf_memory_pool_handle_t MemoryPoolHost = nullptr;

ur_context_handle_t_(const ur_device_handle_t *Devs, uint32_t NumDevices)
: handle_base(), Devices{Devs, Devs + NumDevices}, RefCount{1} {
: handle_base(), Devices{Devs, Devs + NumDevices} {
// Create UMF CUDA memory provider for the host memory
// (UMF_MEMORY_TYPE_HOST) from any device (Devices[0] is used here, because
// it is guaranteed to exist).
Expand Down Expand Up @@ -140,11 +139,7 @@ struct ur_context_handle_t_ : ur::cuda::handle_base {
return std::distance(Devices.begin(), It);
}

uint32_t incrementReferenceCount() noexcept { return ++RefCount; }

uint32_t decrementReferenceCount() noexcept { return --RefCount; }

uint32_t getReferenceCount() const noexcept { return RefCount; }
UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; }

void addPool(ur_usm_pool_handle_t Pool);

Expand All @@ -156,6 +151,7 @@ struct ur_context_handle_t_ : ur::cuda::handle_base {
std::mutex Mutex;
std::vector<deleter_data> ExtendedDeleters;
std::set<ur_usm_pool_handle_t> PoolHandles;
UR_ReferenceCounter RefCounter;
};

namespace {
Expand Down
2 changes: 1 addition & 1 deletion unified-runtime/source/adapters/cuda/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -593,7 +593,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(ur_device_handle_t hDevice,
return ReturnValue("CUDA");
}
case UR_DEVICE_INFO_REFERENCE_COUNT: {
return ReturnValue(hDevice->getReferenceCount());
return ReturnValue(hDevice->getRefCounter().getCount());
}
case UR_DEVICE_INFO_VERSION: {
std::stringstream SS;
Expand Down
9 changes: 6 additions & 3 deletions unified-runtime/source/adapters/cuda/device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include <umf/memory_provider.h>

#include "common.hpp"
#include "common/ur_ref_counter.hpp"

struct ur_device_handle_t_ : ur::cuda::handle_base {
private:
Expand All @@ -23,7 +24,6 @@ struct ur_device_handle_t_ : ur::cuda::handle_base {
native_type CuDevice;
CUcontext CuContext;
CUevent EvBase; // CUDA event used as base counter
std::atomic_uint32_t RefCount;
ur_platform_handle_t Platform;
uint32_t DeviceIndex;

Expand All @@ -42,7 +42,7 @@ struct ur_device_handle_t_ : ur::cuda::handle_base {
ur_device_handle_t_(native_type cuDevice, CUcontext cuContext, CUevent evBase,
ur_platform_handle_t platform, uint32_t DevIndex)
: handle_base(), CuDevice(cuDevice), CuContext(cuContext), EvBase(evBase),
RefCount{1}, Platform(platform), DeviceIndex{DevIndex} {
Platform(platform), DeviceIndex{DevIndex} {
UR_CHECK_ERROR(cuDeviceGetAttribute(
&MaxRegsPerBlock, CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_BLOCK,
cuDevice));
Expand Down Expand Up @@ -136,7 +136,7 @@ struct ur_device_handle_t_ : ur::cuda::handle_base {

CUcontext getNativeContext() const noexcept { return CuContext; };

uint32_t getReferenceCount() const noexcept { return RefCount; }
UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; }

ur_platform_handle_t getPlatform() const noexcept { return Platform; };

Expand Down Expand Up @@ -178,6 +178,9 @@ struct ur_device_handle_t_ : ur::cuda::handle_base {
// (UMF_MEMORY_TYPE_SHARED)
umf_memory_provider_handle_t MemoryProviderShared;
umf_memory_pool_handle_t MemoryPoolShared;

private:
UR_ReferenceCounter RefCounter;
};

int getAttribute(ur_device_handle_t Device, CUdevice_attribute Attribute);
10 changes: 4 additions & 6 deletions unified-runtime/source/adapters/cuda/event.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEventGetInfo(ur_event_handle_t hEvent,
case UR_EVENT_INFO_COMMAND_TYPE:
return ReturnValue(hEvent->getCommandType());
case UR_EVENT_INFO_REFERENCE_COUNT:
return ReturnValue(hEvent->getReferenceCount());
return ReturnValue(hEvent->getRefCounter().getCount());
case UR_EVENT_INFO_COMMAND_EXECUTION_STATUS:
return ReturnValue(hEvent->getExecutionStatus());
case UR_EVENT_INFO_CONTEXT:
Expand Down Expand Up @@ -248,9 +248,7 @@ urEventWait(uint32_t numEvents, const ur_event_handle_t *phEventWaitList) {
}

UR_APIEXPORT ur_result_t UR_APICALL urEventRetain(ur_event_handle_t hEvent) {
const auto RefCount = hEvent->incrementReferenceCount();

if (RefCount == 0) {
if (hEvent->getRefCounter().increment() == 0) {
return UR_RESULT_ERROR_OUT_OF_RESOURCES;
}

Expand All @@ -260,12 +258,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urEventRetain(ur_event_handle_t hEvent) {
UR_APIEXPORT ur_result_t UR_APICALL urEventRelease(ur_event_handle_t hEvent) {
// double delete or someone is messing with the ref count.
// either way, cannot safely proceed.
if (hEvent->getReferenceCount() == 0) {
if (hEvent->getRefCounter().getCount() == 0) {
return UR_RESULT_ERROR_INVALID_EVENT;
}

// decrement ref count. If it is 0, delete the event.
if (hEvent->decrementReferenceCount() == 0) {
if (hEvent->getRefCounter().decrement() == 0) {
std::unique_ptr<ur_event_handle_t_> event_ptr{hEvent};
ur_result_t Result = UR_RESULT_ERROR_INVALID_EVENT;
try {
Expand Down
9 changes: 3 additions & 6 deletions unified-runtime/source/adapters/cuda/event.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <ur/ur.hpp>

#include "common.hpp"
#include "common/ur_ref_counter.hpp"
#include "queue.hpp"

/// UR Event mapping to CUevent
Expand Down Expand Up @@ -59,16 +60,12 @@ struct ur_event_handle_t_ : ur::cuda::handle_base {
ur_command_t getCommandType() const noexcept { return CommandType; }
ur_context_handle_t getContext() const noexcept { return Context; };
uint32_t getEventID() const noexcept { return EventID; }

// Reference counting.
uint32_t getReferenceCount() const noexcept { return RefCount; }
uint32_t incrementReferenceCount() { return ++RefCount; }
uint32_t decrementReferenceCount() { return --RefCount; }
UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; }

private:
ur_command_t CommandType; // The type of command associated with event.

std::atomic_uint32_t RefCount{1}; // Event reference count.
UR_ReferenceCounter RefCounter;

bool HasOwnership{true}; // Signifies if event owns the native type.
bool HasProfiling{false}; // Signifies if event has profiling information.
Expand Down
12 changes: 7 additions & 5 deletions unified-runtime/source/adapters/cuda/kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,20 +127,22 @@ urKernelGetGroupInfo(ur_kernel_handle_t hKernel, ur_device_handle_t hDevice,
}

/// Takes an additional reference on the given kernel.
///
/// \param hKernel Kernel whose reference count is incremented.
/// \return UR_RESULT_SUCCESS after the increment. The UR_ASSERT checks pair a
///         zero count with UR_RESULT_ERROR_INVALID_KERNEL (presumably
///         returned by the macro on failure — confirm against UR_ASSERT).
UR_APIEXPORT ur_result_t UR_APICALL urKernelRetain(ur_kernel_handle_t hKernel) {
UR_ASSERT(hKernel->getReferenceCount() > 0u, UR_RESULT_ERROR_INVALID_KERNEL);
UR_ASSERT(hKernel->getRefCounter().getCount() > 0u,
UR_RESULT_ERROR_INVALID_KERNEL);

hKernel->incrementReferenceCount();
hKernel->getRefCounter().increment();
return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL
urKernelRelease(ur_kernel_handle_t hKernel) {
// double delete or someone is messing with the ref count.
// either way, cannot safely proceed.
UR_ASSERT(hKernel->getReferenceCount() != 0, UR_RESULT_ERROR_INVALID_KERNEL);
UR_ASSERT(hKernel->getRefCounter().getCount() != 0,
UR_RESULT_ERROR_INVALID_KERNEL);

// decrement ref count. If it is 0, delete the kernel.
if (hKernel->decrementReferenceCount() == 0) {
if (hKernel->getRefCounter().decrement() == 0) {
// no internal cuda resources to clean up. Just delete it.
delete hKernel;
return UR_RESULT_SUCCESS;
Expand Down Expand Up @@ -248,7 +250,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelGetInfo(ur_kernel_handle_t hKernel,
case UR_KERNEL_INFO_NUM_ARGS:
return ReturnValue(hKernel->getNumArgs());
case UR_KERNEL_INFO_REFERENCE_COUNT:
return ReturnValue(hKernel->getReferenceCount());
return ReturnValue(hKernel->getRefCounter().getCount());
case UR_KERNEL_INFO_CONTEXT:
return ReturnValue(hKernel->getContext());
case UR_KERNEL_INFO_PROGRAM:
Expand Down
Loading
Loading