[Offload] Make olLaunchKernel test thread safe #149497

Status: Open. Wants to merge 7 commits into base: main.
6 changes: 6 additions & 0 deletions offload/include/Shared/APITypes.h
@@ -21,6 +21,7 @@

 #include <cstddef>
 #include <cstdint>
+#include <mutex>

extern "C" {

@@ -75,6 +76,11 @@ struct __tgt_async_info {
   /// should be freed after finalization.
   llvm::SmallVector<void *, 2> AssociatedAllocations;

+  /// Mutex to guard access to AssociatedAllocations and the Queue.
+  /// This is only used for liboffload and should be ignored in libomptarget
+  /// code.
Contributor: Is this actually ignored for libomptarget? I think you're unconditionally acquiring the mutex.

+  std::mutex Mutex;

/// The kernel launch environment used to issue a kernel. Stored here to
/// ensure it is a valid location while the transfer to the device is
/// happening.
16 changes: 6 additions & 10 deletions offload/liboffload/src/OffloadImpl.cpp
@@ -206,7 +206,7 @@ Error initPlugins(OffloadContext &Context) {
 }

 Error olInit_impl() {
-  std::lock_guard<std::mutex> Lock{OffloadContextValMutex};
+  std::lock_guard<std::mutex> Lock(OffloadContextValMutex);

if (isOffloadInitialized()) {
OffloadContext::get().RefCount++;
@@ -224,7 +224,7 @@ Error olInit_impl() {
}

 Error olShutDown_impl() {
-  std::lock_guard<std::mutex> Lock{OffloadContextValMutex};
+  std::lock_guard<std::mutex> Lock(OffloadContextValMutex);

if (--OffloadContext::get().RefCount != 0)
return Error::success();
@@ -487,16 +487,12 @@ Error olSyncQueue_impl(ol_queue_handle_t Queue) {
   // Host plugin doesn't have a queue set so it's not safe to call synchronize
   // on it, but we have nothing to synchronize in that situation anyway.
   if (Queue->AsyncInfo->Queue) {
-    if (auto Err = Queue->Device->Device->synchronize(Queue->AsyncInfo))
+    // We don't need to release the queue and we would like the ability for
+    // other offload threads to submit work concurrently, so pass "false" here.
+    if (auto Err = Queue->Device->Device->synchronize(Queue->AsyncInfo, false))
Contributor: Please indicate with a comment what the "false" is doing.

Contributor: This code assumes other threads will not release the queue from that async info, right?

Author: Correct, although as far as I know, liboffload doesn't do that, and that feels reasonable as a thing to mark as undefined.

       return Err;
   }

-  // Recreate the stream resource so the queue can be reused
-  // TODO: Would be easier for the synchronization to (optionally) not release
-  // it to begin with.
-  if (auto Res = Queue->Device->Device->initAsyncInfo(&Queue->AsyncInfo))
-    return Res;

return Error::success();
}

@@ -727,7 +723,7 @@ Error olGetSymbol_impl(ol_program_handle_t Program, const char *Name,
ol_symbol_kind_t Kind, ol_symbol_handle_t *Symbol) {
auto &Device = Program->Image->getDevice();

-  std::lock_guard<std::mutex> Lock{Program->SymbolListMutex};
+  std::lock_guard<std::mutex> Lock(Program->SymbolListMutex);

switch (Kind) {
case OL_SYMBOL_KIND_KERNEL: {
4 changes: 2 additions & 2 deletions offload/libomptarget/interface.cpp
@@ -116,7 +116,7 @@ targetData(ident_t *Loc, int64_t DeviceId, int32_t ArgNum, void **ArgsBase,
TargetDataFuncPtrTy TargetDataFunction, const char *RegionTypeMsg,
const char *RegionName) {
assert(PM && "Runtime not initialized");
-  static_assert(std::is_convertible_v<TargetAsyncInfoTy, AsyncInfoTy>,
+  static_assert(std::is_convertible_v<TargetAsyncInfoTy &, AsyncInfoTy &>,
"TargetAsyncInfoTy must be convertible to AsyncInfoTy.");

TIMESCOPE_WITH_DETAILS_AND_IDENT("Runtime: Data Copy",
@@ -311,7 +311,7 @@ static inline int targetKernel(ident_t *Loc, int64_t DeviceId, int32_t NumTeams,
int32_t ThreadLimit, void *HostPtr,
KernelArgsTy *KernelArgs) {
assert(PM && "Runtime not initialized");
-  static_assert(std::is_convertible_v<TargetAsyncInfoTy, AsyncInfoTy>,
+  static_assert(std::is_convertible_v<TargetAsyncInfoTy &, AsyncInfoTy &>,
"Target AsyncInfoTy must be convertible to AsyncInfoTy.");
DP("Entering target region for device %" PRId64 " with entry point " DPxMOD
"\n",
25 changes: 12 additions & 13 deletions offload/plugins-nextgen/amdgpu/src/rtl.cpp
@@ -2232,16 +2232,11 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
   /// Get the stream of the asynchronous info structure or get a new one.
   Error getStream(AsyncInfoWrapperTy &AsyncInfoWrapper,
                   AMDGPUStreamTy *&Stream) {
-    // Get the stream (if any) from the async info.
-    Stream = AsyncInfoWrapper.getQueueAs<AMDGPUStreamTy *>();
-    if (!Stream) {
-      // There was no stream; get an idle one.
-      if (auto Err = AMDGPUStreamManager.getResource(Stream))
-        return Err;
-
-      // Modify the async info's stream.
-      AsyncInfoWrapper.setQueueAs<AMDGPUStreamTy *>(Stream);
-    }
+    auto WrapperStream =
+        AsyncInfoWrapper.getOrInitQueue<AMDGPUStreamTy *>(AMDGPUStreamManager);
+    if (!WrapperStream)
+      return WrapperStream.takeError();
+    Stream = *WrapperStream;
return Plugin::success();
}

@@ -2296,7 +2291,8 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
}

   /// Synchronize current thread with the pending operations on the async info.
-  Error synchronizeImpl(__tgt_async_info &AsyncInfo) override {
+  Error synchronizeImpl(__tgt_async_info &AsyncInfo,
+                        bool ReleaseQueue) override {
AMDGPUStreamTy *Stream =
reinterpret_cast<AMDGPUStreamTy *>(AsyncInfo.Queue);
assert(Stream && "Invalid stream");
@@ -2307,8 +2303,11 @@
     // Once the stream is synchronized, return it to stream pool and reset
     // AsyncInfo. This is to make sure the synchronization only works for its
     // own tasks.
-    AsyncInfo.Queue = nullptr;
-    return AMDGPUStreamManager.returnResource(Stream);
+    if (ReleaseQueue) {
+      AsyncInfo.Queue = nullptr;
+      return AMDGPUStreamManager.returnResource(Stream);
+    }
+    return Plugin::success();
}

/// Query for the completion of the pending operations on the async info.
25 changes: 22 additions & 3 deletions offload/plugins-nextgen/common/include/PluginInterface.h
@@ -59,6 +59,7 @@ struct GenericPluginTy;
struct GenericKernelTy;
struct GenericDeviceTy;
struct RecordReplayTy;
+template <typename ResourceRef> class GenericDeviceResourceManagerTy;

/// Class that wraps the __tgt_async_info to simply its usage. In case the
/// object is constructed without a valid __tgt_async_info, the object will use
@@ -93,6 +94,20 @@ struct AsyncInfoWrapperTy {
AsyncInfoPtr->Queue = Queue;
}

+  /// Get the queue, using the provided resource manager to initialise it if
+  /// it doesn't exist.
+  template <typename Ty, typename RMTy>
+  Expected<Ty>
+  getOrInitQueue(GenericDeviceResourceManagerTy<RMTy> &ResourceManager) {
+    std::lock_guard<std::mutex> Lock(AsyncInfoPtr->Mutex);
+    if (!AsyncInfoPtr->Queue) {
+      if (auto Err = ResourceManager.getResource(
+              *reinterpret_cast<Ty *>(&AsyncInfoPtr->Queue)))
+        return Err;
+    }
+    return getQueueAs<Ty>();
+  }

/// Synchronize with the __tgt_async_info's pending operations if it's the
/// internal async info. The error associated to the asynchronous operations
/// issued in this queue must be provided in \p Err. This function will update
@@ -104,6 +119,7 @@
/// Register \p Ptr as an associated allocation that is freed after
/// finalization.
void freeAllocationAfterSynchronization(void *Ptr) {
+    std::lock_guard<std::mutex> AllocationGuard{AsyncInfoPtr->Mutex};
Contributor: Still some instances with the {...} syntax. I think you have another one in the code.
AsyncInfoPtr->AssociatedAllocations.push_back(Ptr);
}

@@ -793,9 +809,12 @@ struct GenericDeviceTy : public DeviceAllocatorTy {
Error setupRPCServer(GenericPluginTy &Plugin, DeviceImageTy &Image);

   /// Synchronize the current thread with the pending operations on the
-  /// __tgt_async_info structure.
-  Error synchronize(__tgt_async_info *AsyncInfo);
-  virtual Error synchronizeImpl(__tgt_async_info &AsyncInfo) = 0;
+  /// __tgt_async_info structure. If ReleaseQueue is false, then the
+  /// underlying queue will not be released. In this case, additional
+  /// work may be submitted to the queue whilst a synchronize is running.
+  Error synchronize(__tgt_async_info *AsyncInfo, bool ReleaseQueue = true);
+  virtual Error synchronizeImpl(__tgt_async_info &AsyncInfo,
+                                bool ReleaseQueue) = 0;

/// Invokes any global constructors on the device if present and is required
/// by the target.
23 changes: 15 additions & 8 deletions offload/plugins-nextgen/common/src/PluginInterface.cpp
@@ -1329,18 +1329,25 @@ Error PinnedAllocationMapTy::unlockUnmappedHostBuffer(void *HstPtr) {
   return eraseEntry(*Entry);
 }

-Error GenericDeviceTy::synchronize(__tgt_async_info *AsyncInfo) {
-  if (!AsyncInfo || !AsyncInfo->Queue)
-    return Plugin::error(ErrorCode::INVALID_ARGUMENT,
-                         "invalid async info queue");
+Error GenericDeviceTy::synchronize(__tgt_async_info *AsyncInfo,
+                                   bool ReleaseQueue) {
+  SmallVector<void *, 2> AllocsToDelete{};
+  {
+    std::lock_guard<std::mutex> AllocationGuard{AsyncInfo->Mutex};

-  if (auto Err = synchronizeImpl(*AsyncInfo))
-    return Err;
+    if (!AsyncInfo || !AsyncInfo->Queue)
+      return Plugin::error(ErrorCode::INVALID_ARGUMENT,
+                           "invalid async info queue");
+
+    if (auto Err = synchronizeImpl(*AsyncInfo, ReleaseQueue))
+      return Err;
+
+    std::swap(AllocsToDelete, AsyncInfo->AssociatedAllocations);
+  }

-  for (auto *Ptr : AsyncInfo->AssociatedAllocations)
+  for (auto *Ptr : AllocsToDelete)
     if (auto Err = dataDelete(Ptr, TargetAllocTy::TARGET_ALLOC_DEVICE))
       return Err;
-  AsyncInfo->AssociatedAllocations.clear();

   return Plugin::success();
 }
32 changes: 15 additions & 17 deletions offload/plugins-nextgen/cuda/src/rtl.cpp
@@ -522,16 +522,11 @@ struct CUDADeviceTy : public GenericDeviceTy {

   /// Get the stream of the asynchronous info structure or get a new one.
   Error getStream(AsyncInfoWrapperTy &AsyncInfoWrapper, CUstream &Stream) {
-    // Get the stream (if any) from the async info.
-    Stream = AsyncInfoWrapper.getQueueAs<CUstream>();
-    if (!Stream) {
-      // There was no stream; get an idle one.
-      if (auto Err = CUDAStreamManager.getResource(Stream))
-        return Err;
-
-      // Modify the async info's stream.
-      AsyncInfoWrapper.setQueueAs<CUstream>(Stream);
-    }
+    auto WrapperStream =
+        AsyncInfoWrapper.getOrInitQueue<CUstream>(CUDAStreamManager);
+    if (!WrapperStream)
+      return WrapperStream.takeError();
+    Stream = *WrapperStream;
return Plugin::success();
}

@@ -642,17 +637,20 @@ struct CUDADeviceTy : public GenericDeviceTy {
}

   /// Synchronize current thread with the pending operations on the async info.
-  Error synchronizeImpl(__tgt_async_info &AsyncInfo) override {
+  Error synchronizeImpl(__tgt_async_info &AsyncInfo,
+                        bool ReleaseQueue) override {
     CUstream Stream = reinterpret_cast<CUstream>(AsyncInfo.Queue);
     CUresult Res;
     Res = cuStreamSynchronize(Stream);

-    // Once the stream is synchronized, return it to stream pool and reset
-    // AsyncInfo. This is to make sure the synchronization only works for its
-    // own tasks.
-    AsyncInfo.Queue = nullptr;
-    if (auto Err = CUDAStreamManager.returnResource(Stream))
-      return Err;
+    // Once the stream is synchronized and we want to release the queue, return
+    // it to stream pool and reset AsyncInfo. This is to make sure the
+    // synchronization only works for its own tasks.
+    if (ReleaseQueue) {
+      AsyncInfo.Queue = nullptr;
Contributor: When does the queue get unset/released for liboffload queues?

Author: When the device is de-inited, all streams in the stream manager are deinited and dropped. For liboffload specifically, since devices are not cleared, this happens during the final liboffload olShutDown call.

Contributor: But shouldn't the queue be released when olDestroyQueue is called?

+      if (auto Err = CUDAStreamManager.returnResource(Stream))
+        return Err;
+    }

return Plugin::check(Res, "error in cuStreamSynchronize: %s");
}
3 changes: 2 additions & 1 deletion offload/plugins-nextgen/host/src/rtl.cpp
@@ -297,7 +297,8 @@ struct GenELF64DeviceTy : public GenericDeviceTy {

   /// All functions are already synchronous. No need to do anything on this
   /// synchronization function.
-  Error synchronizeImpl(__tgt_async_info &AsyncInfo) override {
+  Error synchronizeImpl(__tgt_async_info &AsyncInfo,
+                        bool ReleaseQueue) override {
return Plugin::success();
}

18 changes: 18 additions & 0 deletions offload/unittests/OffloadAPI/common/Fixtures.hpp
@@ -9,6 +9,7 @@
 #include <OffloadAPI.h>
 #include <OffloadPrint.hpp>
 #include <gtest/gtest.h>
+#include <thread>

#include "Environment.hpp"

@@ -57,6 +58,23 @@ inline std::string SanitizeString(const std::string &Str) {
return NewStr;
}

+template <typename Fn> inline void threadify(Fn body) {
+  std::vector<std::thread> Threads;
+  for (size_t I = 0; I < 20; I++) {
+    Threads.emplace_back(
+        [&body](size_t I) {
+          std::string ScopeMsg{"Thread #"};
+          ScopeMsg.append(std::to_string(I));
+          SCOPED_TRACE(ScopeMsg);
+          body(I);
+        },
+        I);
+  }
+  for (auto &T : Threads) {
+    T.join();
+  }
+}

struct OffloadTest : ::testing::Test {
ol_device_handle_t Host = TestEnvironment::getHostDevice();
};
23 changes: 23 additions & 0 deletions offload/unittests/OffloadAPI/kernel/olLaunchKernel.cpp
@@ -104,6 +104,29 @@ TEST_P(olLaunchKernelFooTest, Success) {
ASSERT_SUCCESS(olMemFree(Mem));
}

+TEST_P(olLaunchKernelFooTest, SuccessThreaded) {
Author: I'd love to be able to add an OFFLOAD_TEST_THREADED_P macro so that you'd get threaded and non-threaded tests "for free" without copy-pasting the test body. But I can't think of a good way of actually implementing that with gtest; anyone have any ideas?

+  threadify([&](size_t) {
+    void *Mem;
+    ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED,
+                              LaunchArgs.GroupSize.x * sizeof(uint32_t), &Mem));
+    struct {
+      void *Mem;
+    } Args{Mem};
+
+    ASSERT_SUCCESS(olLaunchKernel(Queue, Device, Kernel, &Args, sizeof(Args),
+                                  &LaunchArgs, nullptr));
+
+    ASSERT_SUCCESS(olWaitQueue(Queue));
+
+    uint32_t *Data = (uint32_t *)Mem;
+    for (uint32_t i = 0; i < 64; i++) {
+      ASSERT_EQ(Data[i], i);
+    }
+
+    ASSERT_SUCCESS(olMemFree(Mem));
+  });
+}

TEST_P(olLaunchKernelNoArgsTest, Success) {
ASSERT_SUCCESS(
olLaunchKernel(Queue, Device, Kernel, nullptr, 0, &LaunchArgs));