-
Notifications
You must be signed in to change notification settings - Fork 14.6k
[Offload] Add olWaitEvents #150036
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Offload] Add olWaitEvents #150036
Conversation
@llvm/pr-subscribers-offload Author: Ross Brunton (RossBrunton) Changes: Not to be confused with olSyncQueue, which used to be called olWaitQueue. Full diff: https://github.com/llvm/llvm-project/pull/150036.diff 6 Files Affected:
diff --git a/offload/liboffload/API/Queue.td b/offload/liboffload/API/Queue.td
index 19327cdab4254..43c723de54510 100644
--- a/offload/liboffload/API/Queue.td
+++ b/offload/liboffload/API/Queue.td
@@ -41,6 +41,23 @@ def : Function {
let returns = [];
}
+// Declares olWaitQueue: takes a queue handle plus a list of `NumEvents` event
+// handles, and reports OL_ERRC_INVALID_NULL_HANDLE for a NULL event handle.
+def : Function {
+  let name = "olWaitQueue";
+  let desc = "Make any future work submitted to this queue wait until the provided events are complete.";
+  let details = [
+    "All events in `Events` must complete before the queue is unblocked.",
+    "The input events can be from any queue on any device provided by the same platform as `Queue`.",
+  ];
+  let params = [
+    Param<"ol_queue_handle_t", "Queue", "handle of the queue", PARAM_IN>,
+    Param<"ol_event_handle_t *", "Events", "list of `NumEvents` events to wait for", PARAM_IN>,
+    Param<"size_t", "NumEvents", "size of `Events`", PARAM_IN>,
+  ];
+  let returns = [
+    Return<"OL_ERRC_INVALID_NULL_HANDLE", ["Any event handle in the list is NULL"]>,
+  ];
+}
+
def : Enum {
let name = "ol_queue_info_t";
let desc = "Supported queue info.";
diff --git a/offload/liboffload/src/OffloadImpl.cpp b/offload/liboffload/src/OffloadImpl.cpp
index d93e4f1db58a7..c155a6b85387c 100644
--- a/offload/liboffload/src/OffloadImpl.cpp
+++ b/offload/liboffload/src/OffloadImpl.cpp
@@ -500,6 +500,28 @@ Error olSyncQueue_impl(ol_queue_handle_t Queue) {
return Error::success();
}
+/// Makes \p Queue wait on each of the \p NumEvents handles in \p Events.
+///
+/// Events are processed in list order. A NULL entry stops processing with an
+/// INVALID_NULL_HANDLE error; the first error reported by the device's
+/// waitEvent is returned unchanged. Returns success once the whole list has
+/// been walked.
+Error olWaitQueue_impl(ol_queue_handle_t Queue, ol_event_handle_t *Events,
+                       size_t NumEvents) {
+  auto *TargetDevice = Queue->Device->Device;
+
+  for (size_t Idx = 0; Idx != NumEvents; ++Idx) {
+    ol_event_handle_t Pending = Events[Idx];
+    if (!Pending)
+      return Plugin::error(ErrorCode::INVALID_NULL_HANDLE,
+                           "olWaitQueue asked to wait on a NULL event");
+
+    // Only events that belong to a different queue need an explicit wait;
+    // an event created on this very queue is skipped.
+    if (Pending->Queue != Queue) {
+      if (auto Err =
+              TargetDevice->waitEvent(Pending->EventInfo, Queue->AsyncInfo))
+        return Err;
+    }
+  }
+
+  return Error::success();
+}
+
Error olGetQueueInfoImplDetail(ol_queue_handle_t Queue,
ol_queue_info_t PropName, size_t PropSize,
void *PropValue, size_t *PropSizeRet) {
diff --git a/offload/unittests/OffloadAPI/CMakeLists.txt b/offload/unittests/OffloadAPI/CMakeLists.txt
index f09cfc6bb0876..2621eaeb64e82 100644
--- a/offload/unittests/OffloadAPI/CMakeLists.txt
+++ b/offload/unittests/OffloadAPI/CMakeLists.txt
@@ -39,7 +39,8 @@ add_offload_unittest("queue"
queue/olSyncQueue.cpp
queue/olDestroyQueue.cpp
queue/olGetQueueInfo.cpp
- queue/olGetQueueInfoSize.cpp)
+ queue/olGetQueueInfoSize.cpp
+ queue/olWaitQueue.cpp)
add_offload_unittest("symbol"
symbol/olGetSymbol.cpp
diff --git a/offload/unittests/OffloadAPI/device_code/CMakeLists.txt b/offload/unittests/OffloadAPI/device_code/CMakeLists.txt
index 11c8ccbd6c7c5..0e4695ee9969f 100644
--- a/offload/unittests/OffloadAPI/device_code/CMakeLists.txt
+++ b/offload/unittests/OffloadAPI/device_code/CMakeLists.txt
@@ -8,6 +8,7 @@ add_offload_test_device_code(localmem_static.c localmem_static)
add_offload_test_device_code(global.c global)
add_offload_test_device_code(global_ctor.c global_ctor)
add_offload_test_device_code(global_dtor.c global_dtor)
+add_offload_test_device_code(sequence.c sequence)
add_custom_target(offload_device_binaries DEPENDS
foo.bin
@@ -19,5 +20,6 @@ add_custom_target(offload_device_binaries DEPENDS
global.bin
global_ctor.bin
global_dtor.bin
+ sequence.bin
)
set(OFFLOAD_TEST_DEVICE_CODE_PATH ${CMAKE_CURRENT_BINARY_DIR} PARENT_SCOPE)
diff --git a/offload/unittests/OffloadAPI/device_code/sequence.c b/offload/unittests/OffloadAPI/device_code/sequence.c
new file mode 100644
index 0000000000000..22504086ffa38
--- /dev/null
+++ b/offload/unittests/OffloadAPI/device_code/sequence.c
@@ -0,0 +1,11 @@
+#include <gpuintrin.h>
+#include <stdint.h>
+
+// Writes one element of a Fibonacci-style sequence into `inout`.
+// Slots 0 and 1 are seeded with 0 and 1; any later slot is the sum of the two
+// slots before it, so those two slots must already hold their values when
+// this launch runs.
+__gpu_kernel void sequence(uint32_t idx, uint32_t *inout) {
+  if (idx < 2) {
+    // For the two seed slots the stored value equals the index itself.
+    inout[idx] = idx;
+    return;
+  }
+
+  inout[idx] = inout[idx - 1] + inout[idx - 2];
+}
diff --git a/offload/unittests/OffloadAPI/queue/olWaitQueue.cpp b/offload/unittests/OffloadAPI/queue/olWaitQueue.cpp
new file mode 100644
index 0000000000000..fdf272dafa911
--- /dev/null
+++ b/offload/unittests/OffloadAPI/queue/olWaitQueue.cpp
@@ -0,0 +1,148 @@
+//===------- Offload API tests - olWaitQueue ------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "../common/Fixtures.hpp"
+#include <OffloadAPI.h>
+#include <gtest/gtest.h>
+
+/// Fixture for the olWaitQueue tests: loads the "sequence" device program,
+/// looks up its kernel, and prepares a one-work-item launch configuration.
+struct olWaitQueueTest : OffloadProgramTest {
+  // Handle to the "sequence" kernel, filled in by SetUp().
+  ol_symbol_handle_t Kernel = nullptr;
+  // Launch configuration reused by every test: 1 dimension, a single group
+  // of size {1, 1, 1}, and no dynamic shared memory.
+  ol_kernel_launch_size_args_t LaunchArgs{};
+
+  void SetUp() override {
+    RETURN_ON_FATAL_FAILURE(OffloadProgramTest::SetUpWith("sequence"));
+    ASSERT_SUCCESS(
+        olGetSymbol(Program, "sequence", OL_SYMBOL_KIND_KERNEL, &Kernel));
+
+    LaunchArgs.DynSharedMemory = 0;
+    LaunchArgs.NumGroups = {1, 1, 1};
+    LaunchArgs.GroupSize = {1, 1, 1};
+    LaunchArgs.Dimensions = 1;
+  }
+
+  void TearDown() override {
+    RETURN_ON_FATAL_FAILURE(OffloadProgramTest::TearDown());
+  }
+};
+OFFLOAD_TESTS_INSTANTIATE_DEVICE_FIXTURE(olWaitQueueTest);
+
+// Launches NUM_KERNELS `sequence` kernels, each on its own queue, where every
+// queue first waits on the event of the previous launch. After syncing on the
+// last event the buffer must hold the full sequence.
+TEST_P(olWaitQueueTest, Success) {
+  constexpr size_t NUM_KERNELS = 16;
+  ol_queue_handle_t Queues[NUM_KERNELS];
+  ol_event_handle_t Events[NUM_KERNELS];
+
+  void *Mem;
+  ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED,
+                            NUM_KERNELS * sizeof(uint32_t), &Mem));
+  struct {
+    uint32_t Idx;
+    void *Mem;
+  } Args{0, Mem};
+
+  for (size_t I = 0; I < NUM_KERNELS; I++) {
+    Args.Idx = I;
+
+    ASSERT_SUCCESS(olCreateQueue(Device, &Queues[I]));
+
+    // Kernel I reads the slots written by earlier kernels, so its queue must
+    // wait for the previous launch.
+    if (I > 0)
+      ASSERT_SUCCESS(olWaitQueue(Queues[I], &Events[I - 1], 1));
+
+    ASSERT_SUCCESS(olLaunchKernel(Queues[I], Device, Kernel, &Args,
+                                  sizeof(Args), &LaunchArgs, &Events[I]));
+  }
+
+  ASSERT_SUCCESS(olSyncEvent(Events[NUM_KERNELS - 1]));
+
+  uint32_t *Data = static_cast<uint32_t *>(Mem);
+  // Pin the seed values written by the first two launches. Without this an
+  // all-zero buffer (no kernel ever ran) would satisfy the recurrence below.
+  ASSERT_EQ(Data[0], 0u);
+  ASSERT_EQ(Data[1], 1u);
+  for (uint32_t i = 2; i < NUM_KERNELS; i++) {
+    ASSERT_EQ(Data[i], Data[i - 1] + Data[i - 2]);
+  }
+}
+
+// Same sequence computation as the multi-queue test, but every kernel is
+// launched on one queue and that queue is asked to wait on the event of its
+// own previous launch.
+TEST_P(olWaitQueueTest, SuccessSingleQueue) {
+  constexpr size_t NUM_KERNELS = 16;
+  ol_queue_handle_t Queue;
+  ol_event_handle_t Events[NUM_KERNELS];
+
+  ASSERT_SUCCESS(olCreateQueue(Device, &Queue));
+
+  void *Mem;
+  ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED,
+                            NUM_KERNELS * sizeof(uint32_t), &Mem));
+  struct {
+    uint32_t Idx;
+    void *Mem;
+  } Args{0, Mem};
+
+  for (size_t I = 0; I < NUM_KERNELS; I++) {
+    Args.Idx = I;
+
+    // The awaited event belongs to the same queue that is asked to wait.
+    if (I > 0)
+      ASSERT_SUCCESS(olWaitQueue(Queue, &Events[I - 1], 1));
+
+    ASSERT_SUCCESS(olLaunchKernel(Queue, Device, Kernel, &Args, sizeof(Args),
+                                  &LaunchArgs, &Events[I]));
+  }
+
+  ASSERT_SUCCESS(olSyncEvent(Events[NUM_KERNELS - 1]));
+
+  uint32_t *Data = static_cast<uint32_t *>(Mem);
+  // Pin the seed values written by the first two launches. Without this an
+  // all-zero buffer (no kernel ever ran) would satisfy the recurrence below.
+  ASSERT_EQ(Data[0], 0u);
+  ASSERT_EQ(Data[1], 1u);
+  for (uint32_t i = 2; i < NUM_KERNELS; i++) {
+    ASSERT_EQ(Data[i], Data[i - 1] + Data[i - 2]);
+  }
+}
+
+// One queue per kernel; queue I waits on the events of ALL earlier launches
+// (Events[0..I-1]) in a single olWaitQueue call before its kernel is launched.
+TEST_P(olWaitQueueTest, SuccessMultipleEvents) {
+  constexpr size_t NUM_KERNELS = 16;
+  ol_queue_handle_t Queues[NUM_KERNELS];
+  ol_event_handle_t Events[NUM_KERNELS];
+
+  void *Mem;
+  ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED,
+                            NUM_KERNELS * sizeof(uint32_t), &Mem));
+  struct {
+    uint32_t Idx;
+    void *Mem;
+  } Args{0, Mem};
+
+  for (size_t I = 0; I < NUM_KERNELS; I++) {
+    Args.Idx = I;
+
+    ASSERT_SUCCESS(olCreateQueue(Device, &Queues[I]));
+
+    // Pass the first I events, i.e. every launch made so far.
+    if (I > 0)
+      ASSERT_SUCCESS(olWaitQueue(Queues[I], Events, I));
+
+    ASSERT_SUCCESS(olLaunchKernel(Queues[I], Device, Kernel, &Args,
+                                  sizeof(Args), &LaunchArgs, &Events[I]));
+  }
+
+  ASSERT_SUCCESS(olSyncEvent(Events[NUM_KERNELS - 1]));
+
+  uint32_t *Data = static_cast<uint32_t *>(Mem);
+  // Pin the seed values written by the first two launches. Without this an
+  // all-zero buffer (no kernel ever ran) would satisfy the recurrence below.
+  ASSERT_EQ(Data[0], 0u);
+  ASSERT_EQ(Data[1], 1u);
+  for (uint32_t i = 2; i < NUM_KERNELS; i++) {
+    ASSERT_EQ(Data[i], Data[i - 1] + Data[i - 2]);
+  }
+}
+
+// A null queue handle must be rejected with OL_ERRC_INVALID_NULL_HANDLE.
+TEST_P(olWaitQueueTest, InvalidNullQueue) {
+  // Only the address of this handle is passed; its value is never set here.
+  ol_event_handle_t Placeholder;
+  ASSERT_ERROR(OL_ERRC_INVALID_NULL_HANDLE,
+               olWaitQueue(nullptr, &Placeholder, 1));
+}
+
+// A null `Events` pointer with a non-zero count must be rejected with
+// OL_ERRC_INVALID_NULL_POINTER.
+TEST_P(olWaitQueueTest, InvalidNullEvent) {
+  ol_queue_handle_t WaitingQueue;
+  ASSERT_SUCCESS(olCreateQueue(Device, &WaitingQueue));
+  ASSERT_ERROR(OL_ERRC_INVALID_NULL_POINTER,
+               olWaitQueue(WaitingQueue, nullptr, 1));
+}
+
+// A valid `Events` pointer whose single entry is a null handle must be
+// rejected with OL_ERRC_INVALID_NULL_HANDLE.
+TEST_P(olWaitQueueTest, InvalidNullInnerEvent) {
+  ol_queue_handle_t WaitingQueue;
+  ASSERT_SUCCESS(olCreateQueue(Device, &WaitingQueue));
+
+  ol_event_handle_t NullEvent = nullptr;
+  ASSERT_ERROR(OL_ERRC_INVALID_NULL_HANDLE,
+               olWaitQueue(WaitingQueue, &NullEvent, 1));
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Simple enough
ASSERT_SUCCESS(olCreateQueue(Device, &Queues[I])); | ||
|
||
if (I > 0) | ||
ASSERT_SUCCESS(olWaitQueue(Queues[I], Events, I)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
While looking at HSA documentation, I noticed that the primitive they use to implement this can accept up to 4(?) signals to wait on. Likewise, OpenCL functions accept lists of dependent events rather than just one. I know this isn't supported by PluginInterface yet, but I figured I'd have it available in the offload interface in case it gets implemented in the future.
✅ With the latest revision this PR passed the C/C++ code formatter. |
I think the name sort of implies waiting on another queue rather than the queue waiting on events. IMO a name like olEnqueueWaits or olWaitEvents would be clearer. Just a nitpick though, otherwise LGTM. |
If it destroys the queue it could be like 'Finalize' or something. |
Not to be confused with olSyncQueue, which used to be called olWaitQueue until #150023. This function causes a queue to wait until all the provided events have completed before running any future scheduled work.
ef58e6a
to
9539bfe
Compare
Renamed it to `olWaitEvents`. @jhuber6 are you okay with this name? |
This function causes a queue to wait until all the provided events have completed before running any future scheduled work.