Skip to content

Commit f10396e

Browse files
committed
[UR][SYCL] Add support for zeCommandListAppendLaunchKernelWithArguments()
Signed-off-by: Lukasz Dorau <[email protected]>
1 parent ef305e2 commit f10396e

File tree

6 files changed

+187
-7
lines changed

6 files changed

+187
-7
lines changed

unified-runtime/source/adapters/level_zero/platform.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -526,6 +526,16 @@ ur_result_t ur_platform_handle_t_::initialize() {
526526
ZeMemGetPitchFor2dImageExt.Supported =
527527
ZeMemGetPitchFor2dImageExt.zeMemGetPitchFor2dImage != nullptr;
528528

529+
ZE_CALL_NOCHECK(zeDriverGetExtensionFunctionAddress,
530+
(ZeDriver, "zeCommandListAppendLaunchKernelWithArguments",
531+
reinterpret_cast<void **>(
532+
&ZeCommandListAppendLaunchKernelWithArgumentsExt
533+
.zeCommandListAppendLaunchKernelWithArguments)));
534+
535+
ZeCommandListAppendLaunchKernelWithArgumentsExt.Supported =
536+
ZeCommandListAppendLaunchKernelWithArgumentsExt
537+
.zeCommandListAppendLaunchKernelWithArguments != nullptr;
538+
529539
return UR_RESULT_SUCCESS;
530540
}
531541

unified-runtime/source/adapters/level_zero/platform.hpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,4 +162,12 @@ struct ur_platform_handle_t_ : ur::handle_base<ur::level_zero::ddi_getter>,
162162
ze_device_handle_t, size_t, size_t,
163163
unsigned int, size_t *);
164164
} ZeMemGetPitchFor2dImageExt;
165+
166+
struct ZeCommandListAppendLaunchKernelWithArgumentsExtension {
167+
bool Supported = false;
168+
ze_result_t (*zeCommandListAppendLaunchKernelWithArguments)(
169+
ze_command_list_handle_t, ze_kernel_handle_t, const ze_group_count_t,
170+
const ze_group_size_t, void **, void *, ze_event_handle_t, uint32_t,
171+
ze_event_handle_t *);
172+
} ZeCommandListAppendLaunchKernelWithArgumentsExt;
165173
};

unified-runtime/source/adapters/level_zero/v2/command_list_manager.cpp

Lines changed: 124 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1039,7 +1039,7 @@ ur_result_t ur_command_list_manager::releaseSubmittedKernels() {
10391039
return UR_RESULT_SUCCESS;
10401040
}
10411041

1042-
ur_result_t ur_command_list_manager::appendKernelLaunchWithArgsExp(
1042+
ur_result_t ur_command_list_manager::appendKernelLaunchWithArgsExpOld(
10431043
ur_kernel_handle_t hKernel, uint32_t workDim,
10441044
const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
10451045
const size_t *pLocalWorkSize, uint32_t numArgs,
@@ -1048,8 +1048,6 @@ ur_result_t ur_command_list_manager::appendKernelLaunchWithArgsExp(
10481048
const ur_kernel_launch_property_t *launchPropList,
10491049
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
10501050
ur_event_handle_t phEvent) {
1051-
TRACK_SCOPE_LATENCY(
1052-
"ur_queue_immediate_in_order_t::enqueueKernelLaunchWithArgsExp");
10531051
{
10541052
std::scoped_lock<ur_shared_mutex> guard(hKernel->Mutex);
10551053
for (uint32_t argIndex = 0; argIndex < numArgs; argIndex++) {
@@ -1095,3 +1093,126 @@ ur_result_t ur_command_list_manager::appendKernelLaunchWithArgsExp(
10951093

10961094
return UR_RESULT_SUCCESS;
10971095
}
1096+
1097+
ur_result_t ur_command_list_manager::appendKernelLaunchWithArgsExpNew(
1098+
ur_kernel_handle_t hKernel, uint32_t workDim,
1099+
const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
1100+
const size_t *pLocalWorkSize, uint32_t numArgs,
1101+
const ur_exp_kernel_arg_properties_t *pArgs,
1102+
uint32_t numPropsInLaunchPropList,
1103+
const ur_kernel_launch_property_t *launchPropList,
1104+
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
1105+
ur_event_handle_t phEvent) {
1106+
1107+
// TODO: remove memory allocation
1108+
1109+
// kernelMemObj contains kernel memory objects that
1110+
// UR_EXP_KERNEL_ARG_TYPE_MEM_OBJ kernelArgs pointers point to
1111+
std::vector<void *> kernelMemObj(numArgs, nullptr);
1112+
std::vector<void *> kernelArgs(numArgs, nullptr);
1113+
std::scoped_lock<ur_shared_mutex> Lock(hKernel->Mutex);
1114+
1115+
for (uint32_t argIndex = 0; argIndex < numArgs; argIndex++) {
1116+
switch (pArgs[argIndex].type) {
1117+
case UR_EXP_KERNEL_ARG_TYPE_LOCAL:
1118+
kernelArgs[argIndex] = (void *)&pArgs[argIndex].size;
1119+
break;
1120+
case UR_EXP_KERNEL_ARG_TYPE_VALUE:
1121+
kernelArgs[argIndex] = (void *)pArgs[argIndex].value.value;
1122+
break;
1123+
case UR_EXP_KERNEL_ARG_TYPE_POINTER:
1124+
kernelArgs[argIndex] = (void *)&pArgs[argIndex].value.pointer;
1125+
break;
1126+
case UR_EXP_KERNEL_ARG_TYPE_MEM_OBJ:
1127+
// prepareForSubmission() will save zePtr in kernelMemObj[argIndex]
1128+
kernelArgs[argIndex] = &kernelMemObj[argIndex];
1129+
UR_CALL(hKernel->addPendingMemoryAllocation(
1130+
{pArgs[argIndex].value.memObjTuple.hMem,
1131+
ur_mem_buffer_t::device_access_mode_t::read_write,
1132+
pArgs[argIndex].index}));
1133+
break;
1134+
case UR_EXP_KERNEL_ARG_TYPE_SAMPLER: {
1135+
kernelArgs[argIndex] = &pArgs[argIndex].value.sampler->ZeSampler;
1136+
break;
1137+
}
1138+
default:
1139+
return UR_RESULT_ERROR_INVALID_ENUMERATION;
1140+
}
1141+
}
1142+
1143+
for (uint32_t propIndex = 0; propIndex < numPropsInLaunchPropList;
1144+
propIndex++) {
1145+
if (launchPropList[propIndex].id != UR_KERNEL_LAUNCH_PROPERTY_ID_IGNORE) {
1146+
// We don't support any other properties.
1147+
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
1148+
}
1149+
}
1150+
1151+
UR_ASSERT(hKernel->getProgramHandle(), UR_RESULT_ERROR_INVALID_NULL_POINTER);
1152+
UR_ASSERT(workDim > 0, UR_RESULT_ERROR_INVALID_WORK_DIMENSION);
1153+
UR_ASSERT(workDim < 4, UR_RESULT_ERROR_INVALID_WORK_DIMENSION);
1154+
1155+
ze_kernel_handle_t hZeKernel = hKernel->getZeHandle(hDevice.get());
1156+
1157+
ze_group_count_t zeThreadGroupDimensions{1, 1, 1};
1158+
uint32_t WG[3]{};
1159+
UR_CALL(calculateKernelWorkDimensions(hZeKernel, hDevice.get(),
1160+
zeThreadGroupDimensions, WG, workDim,
1161+
pGlobalWorkSize, pLocalWorkSize));
1162+
1163+
ze_group_size_t groupSize = {WG[0], WG[1], WG[2]};
1164+
1165+
auto zeSignalEvent = getSignalEvent(phEvent, UR_COMMAND_KERNEL_LAUNCH);
1166+
auto waitListView = getWaitListView(phEventWaitList, numEventsInWaitList);
1167+
1168+
UR_CALL(hKernel->prepareForSubmission(
1169+
hContext.get(), hDevice.get(), pGlobalWorkOffset, workDim, WG[0], WG[1],
1170+
WG[2], getZeCommandList(), waitListView, &kernelArgs));
1171+
1172+
{
1173+
TRACK_SCOPE_LATENCY("ur_command_list_manager::"
1174+
"zeCommandListAppendLaunchKernelWithArguments");
1175+
ZE2UR_CALL(hContext->getPlatform()
1176+
->ZeCommandListAppendLaunchKernelWithArgumentsExt
1177+
.zeCommandListAppendLaunchKernelWithArguments,
1178+
(getZeCommandList(), hZeKernel, zeThreadGroupDimensions,
1179+
groupSize, kernelArgs.data(), nullptr, zeSignalEvent,
1180+
waitListView.num, waitListView.handles));
1181+
}
1182+
1183+
recordSubmittedKernel(hKernel);
1184+
1185+
postSubmit(hZeKernel, pGlobalWorkOffset);
1186+
1187+
return UR_RESULT_SUCCESS;
1188+
}
1189+
1190+
ur_result_t ur_command_list_manager::appendKernelLaunchWithArgsExp(
1191+
ur_kernel_handle_t hKernel, uint32_t workDim,
1192+
const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
1193+
const size_t *pLocalWorkSize, uint32_t numArgs,
1194+
const ur_exp_kernel_arg_properties_t *pArgs,
1195+
uint32_t numPropsInLaunchPropList,
1196+
const ur_kernel_launch_property_t *launchPropList,
1197+
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
1198+
ur_event_handle_t phEvent) {
1199+
TRACK_SCOPE_LATENCY(
1200+
"ur_queue_immediate_in_order_t::enqueueKernelLaunchWithArgsExp");
1201+
1202+
UR_ASSERT(hKernel, UR_RESULT_ERROR_INVALID_NULL_HANDLE);
1203+
1204+
if (hContext->getPlatform()
1205+
->ZeCommandListAppendLaunchKernelWithArgumentsExt.Supported) {
1206+
return appendKernelLaunchWithArgsExpNew(
1207+
hKernel, workDim, pGlobalWorkOffset, pGlobalWorkSize, pLocalWorkSize,
1208+
numArgs, pArgs, numPropsInLaunchPropList, launchPropList,
1209+
numEventsInWaitList, phEventWaitList, phEvent);
1210+
} else {
1211+
return appendKernelLaunchWithArgsExpOld(
1212+
hKernel, workDim, pGlobalWorkOffset, pGlobalWorkSize, pLocalWorkSize,
1213+
numArgs, pArgs, numPropsInLaunchPropList, launchPropList,
1214+
numEventsInWaitList, phEventWaitList, phEvent);
1215+
}
1216+
1217+
return UR_RESULT_SUCCESS;
1218+
}

unified-runtime/source/adapters/level_zero/v2/command_list_manager.hpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,26 @@ struct ur_command_list_manager {
244244
ur_event_handle_t phEvent);
245245

246246
private:
247+
ur_result_t appendKernelLaunchWithArgsExpOld(
248+
ur_kernel_handle_t hKernel, uint32_t workDim,
249+
const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
250+
const size_t *pLocalWorkSize, uint32_t numArgs,
251+
const ur_exp_kernel_arg_properties_t *pArgs,
252+
uint32_t numPropsInLaunchPropList,
253+
const ur_kernel_launch_property_t *launchPropList,
254+
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
255+
ur_event_handle_t phEvent);
256+
257+
ur_result_t appendKernelLaunchWithArgsExpNew(
258+
ur_kernel_handle_t hKernel, uint32_t workDim,
259+
const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
260+
const size_t *pLocalWorkSize, uint32_t numArgs,
261+
const ur_exp_kernel_arg_properties_t *pArgs,
262+
uint32_t numPropsInLaunchPropList,
263+
const ur_kernel_launch_property_t *launchPropList,
264+
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
265+
ur_event_handle_t phEvent);
266+
247267
ur_result_t appendGenericCommandListsExp(
248268
uint32_t numCommandLists, ze_command_list_handle_t *phCommandLists,
249269
ur_event_handle_t phEvent, uint32_t numEventsInWaitList,

unified-runtime/source/adapters/level_zero/v2/kernel.cpp

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,8 @@ ur_result_t ur_kernel_handle_t_::prepareForSubmission(
271271
ur_context_handle_t hContext, ur_device_handle_t hDevice,
272272
const size_t *pGlobalWorkOffset, uint32_t workDim, uint32_t groupSizeX,
273273
uint32_t groupSizeY, uint32_t groupSizeZ,
274-
ze_command_list_handle_t commandList, wait_list_view &waitListView) {
274+
ze_command_list_handle_t commandList, wait_list_view &waitListView,
275+
std::vector<void *> *kernelArgs) {
275276
auto &deviceKernelOpt = deviceKernels[deviceIndex(hDevice)];
276277
if (!deviceKernelOpt.has_value())
277278
return UR_RESULT_ERROR_INVALID_KERNEL;
@@ -299,8 +300,25 @@ ur_result_t ur_kernel_handle_t_::prepareForSubmission(
299300
zePtr = reinterpret_cast<void *>(hImage->getZeImage());
300301
}
301302
}
302-
// Set the argument only on this device's kernel.
303-
UR_CALL(deviceKernel.setArgPointer(pending.argIndex, zePtr));
303+
304+
// kernelArgs must be non-nullptr in the path of
305+
// zeCommandListAppendLaunchKernelWithArguments()
306+
if (kernelArgs) {
307+
// zeCommandListAppendLaunchKernelWithArguments()
308+
// (==CommandListCoreFamily<gfxCoreFamily>::appendLaunchKernelWithArguments())
309+
// calls setArgumentValue(i, argSize, argValue) for all arguments on its
310+
// own so do not call it here, but save the zePtr pointer in kernelArgs
311+
// for this future call.
312+
if (pending.argIndex > kernelArgs->size() - 1) {
313+
return UR_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_INDEX;
314+
}
315+
UR_ASSERT((*kernelArgs)[pending.argIndex] != nullptr,
316+
UR_RESULT_ERROR_INVALID_NULL_POINTER);
317+
*((void **)(*kernelArgs)[pending.argIndex]) = zePtr;
318+
} else {
319+
// Set the argument only on this device's kernel.
320+
UR_CALL(deviceKernel.setArgPointer(pending.argIndex, zePtr));
321+
}
304322
}
305323
pending_allocations.clear();
306324

unified-runtime/source/adapters/level_zero/v2/kernel.hpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,13 +96,16 @@ struct ur_kernel_handle_t_ : ur_object {
9696

9797
// Set all required values for the kernel before submission (including pending
9898
// memory allocations).
99+
// The kernelArgs argument must be non-nullptr
100+
// in the path of zeCommandListAppendLaunchKernelWithArguments()
99101
ur_result_t prepareForSubmission(ur_context_handle_t hContext,
100102
ur_device_handle_t hDevice,
101103
const size_t *pGlobalWorkOffset,
102104
uint32_t workDim, uint32_t groupSizeX,
103105
uint32_t groupSizeY, uint32_t groupSizeZ,
104106
ze_command_list_handle_t cmdList,
105-
wait_list_view &waitListView);
107+
wait_list_view &waitListView,
108+
std::vector<void *> *kernelArgs = nullptr);
106109

107110
// Get context of the kernel.
108111
ur_context_handle_t getContext() const { return hProgram->Context; }

0 commit comments

Comments
 (0)