@@ -1039,7 +1039,7 @@ ur_result_t ur_command_list_manager::releaseSubmittedKernels() {
10391039 return UR_RESULT_SUCCESS;
10401040}
10411041
1042- ur_result_t ur_command_list_manager::appendKernelLaunchWithArgsExp (
1042+ ur_result_t ur_command_list_manager::appendKernelLaunchWithArgsExpOld (
10431043 ur_kernel_handle_t hKernel, uint32_t workDim,
10441044 const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
10451045 const size_t *pLocalWorkSize, uint32_t numArgs,
@@ -1048,8 +1048,6 @@ ur_result_t ur_command_list_manager::appendKernelLaunchWithArgsExp(
10481048 const ur_kernel_launch_property_t *launchPropList,
10491049 uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
10501050 ur_event_handle_t phEvent) {
1051- TRACK_SCOPE_LATENCY (
1052- " ur_queue_immediate_in_order_t::enqueueKernelLaunchWithArgsExp" );
10531051 {
10541052 std::scoped_lock<ur_shared_mutex> guard (hKernel->Mutex );
10551053 for (uint32_t argIndex = 0 ; argIndex < numArgs; argIndex++) {
@@ -1095,3 +1093,126 @@ ur_result_t ur_command_list_manager::appendKernelLaunchWithArgsExp(
10951093
10961094 return UR_RESULT_SUCCESS;
10971095}
1096+
1097+ ur_result_t ur_command_list_manager::appendKernelLaunchWithArgsExpNew (
1098+ ur_kernel_handle_t hKernel, uint32_t workDim,
1099+ const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
1100+ const size_t *pLocalWorkSize, uint32_t numArgs,
1101+ const ur_exp_kernel_arg_properties_t *pArgs,
1102+ uint32_t numPropsInLaunchPropList,
1103+ const ur_kernel_launch_property_t *launchPropList,
1104+ uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
1105+ ur_event_handle_t phEvent) {
1106+
1107+ // TODO: remove memory allocation
1108+
1109+ // kernelMemObj contains kernel memory objects that
1110+ // UR_EXP_KERNEL_ARG_TYPE_MEM_OBJ kernelArgs pointers point to
1111+ std::vector<void *> kernelMemObj (numArgs, nullptr );
1112+ std::vector<void *> kernelArgs (numArgs, nullptr );
1113+ std::scoped_lock<ur_shared_mutex> Lock (hKernel->Mutex );
1114+
1115+ for (uint32_t argIndex = 0 ; argIndex < numArgs; argIndex++) {
1116+ switch (pArgs[argIndex].type ) {
1117+ case UR_EXP_KERNEL_ARG_TYPE_LOCAL:
1118+ kernelArgs[argIndex] = (void *)&pArgs[argIndex].size ;
1119+ break ;
1120+ case UR_EXP_KERNEL_ARG_TYPE_VALUE:
1121+ kernelArgs[argIndex] = (void *)pArgs[argIndex].value .value ;
1122+ break ;
1123+ case UR_EXP_KERNEL_ARG_TYPE_POINTER:
1124+ kernelArgs[argIndex] = (void *)&pArgs[argIndex].value .pointer ;
1125+ break ;
1126+ case UR_EXP_KERNEL_ARG_TYPE_MEM_OBJ:
1127+ // prepareForSubmission() will save zePtr in kernelMemObj[argIndex]
1128+ kernelArgs[argIndex] = &kernelMemObj[argIndex];
1129+ UR_CALL (hKernel->addPendingMemoryAllocation (
1130+ {pArgs[argIndex].value .memObjTuple .hMem ,
1131+ ur_mem_buffer_t ::device_access_mode_t ::read_write,
1132+ pArgs[argIndex].index }));
1133+ break ;
1134+ case UR_EXP_KERNEL_ARG_TYPE_SAMPLER: {
1135+ kernelArgs[argIndex] = &pArgs[argIndex].value .sampler ->ZeSampler ;
1136+ break ;
1137+ }
1138+ default :
1139+ return UR_RESULT_ERROR_INVALID_ENUMERATION;
1140+ }
1141+ }
1142+
1143+ for (uint32_t propIndex = 0 ; propIndex < numPropsInLaunchPropList;
1144+ propIndex++) {
1145+ if (launchPropList[propIndex].id != UR_KERNEL_LAUNCH_PROPERTY_ID_IGNORE) {
1146+ // We don't support any other properties.
1147+ return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
1148+ }
1149+ }
1150+
1151+ UR_ASSERT (hKernel->getProgramHandle (), UR_RESULT_ERROR_INVALID_NULL_POINTER);
1152+ UR_ASSERT (workDim > 0 , UR_RESULT_ERROR_INVALID_WORK_DIMENSION);
1153+ UR_ASSERT (workDim < 4 , UR_RESULT_ERROR_INVALID_WORK_DIMENSION);
1154+
1155+ ze_kernel_handle_t hZeKernel = hKernel->getZeHandle (hDevice.get ());
1156+
1157+ ze_group_count_t zeThreadGroupDimensions{1 , 1 , 1 };
1158+ uint32_t WG[3 ]{};
1159+ UR_CALL (calculateKernelWorkDimensions (hZeKernel, hDevice.get (),
1160+ zeThreadGroupDimensions, WG, workDim,
1161+ pGlobalWorkSize, pLocalWorkSize));
1162+
1163+ ze_group_size_t groupSize = {WG[0 ], WG[1 ], WG[2 ]};
1164+
1165+ auto zeSignalEvent = getSignalEvent (phEvent, UR_COMMAND_KERNEL_LAUNCH);
1166+ auto waitListView = getWaitListView (phEventWaitList, numEventsInWaitList);
1167+
1168+ UR_CALL (hKernel->prepareForSubmission (
1169+ hContext.get (), hDevice.get (), pGlobalWorkOffset, workDim, WG[0 ], WG[1 ],
1170+ WG[2 ], getZeCommandList (), waitListView, &kernelArgs));
1171+
1172+ {
1173+ TRACK_SCOPE_LATENCY (" ur_command_list_manager::"
1174+ " zeCommandListAppendLaunchKernelWithArguments" );
1175+ ZE2UR_CALL (hContext->getPlatform ()
1176+ ->ZeCommandListAppendLaunchKernelWithArgumentsExt
1177+ .zeCommandListAppendLaunchKernelWithArguments ,
1178+ (getZeCommandList (), hZeKernel, zeThreadGroupDimensions,
1179+ groupSize, kernelArgs.data (), nullptr , zeSignalEvent,
1180+ waitListView.num , waitListView.handles ));
1181+ }
1182+
1183+ recordSubmittedKernel (hKernel);
1184+
1185+ postSubmit (hZeKernel, pGlobalWorkOffset);
1186+
1187+ return UR_RESULT_SUCCESS;
1188+ }
1189+
1190+ ur_result_t ur_command_list_manager::appendKernelLaunchWithArgsExp (
1191+ ur_kernel_handle_t hKernel, uint32_t workDim,
1192+ const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
1193+ const size_t *pLocalWorkSize, uint32_t numArgs,
1194+ const ur_exp_kernel_arg_properties_t *pArgs,
1195+ uint32_t numPropsInLaunchPropList,
1196+ const ur_kernel_launch_property_t *launchPropList,
1197+ uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
1198+ ur_event_handle_t phEvent) {
1199+ TRACK_SCOPE_LATENCY (
1200+ " ur_queue_immediate_in_order_t::enqueueKernelLaunchWithArgsExp" );
1201+
1202+ UR_ASSERT (hKernel, UR_RESULT_ERROR_INVALID_NULL_HANDLE);
1203+
1204+ if (hContext->getPlatform ()
1205+ ->ZeCommandListAppendLaunchKernelWithArgumentsExt .Supported ) {
1206+ return appendKernelLaunchWithArgsExpNew (
1207+ hKernel, workDim, pGlobalWorkOffset, pGlobalWorkSize, pLocalWorkSize,
1208+ numArgs, pArgs, numPropsInLaunchPropList, launchPropList,
1209+ numEventsInWaitList, phEventWaitList, phEvent);
1210+ } else {
1211+ return appendKernelLaunchWithArgsExpOld (
1212+ hKernel, workDim, pGlobalWorkOffset, pGlobalWorkSize, pLocalWorkSize,
1213+ numArgs, pArgs, numPropsInLaunchPropList, launchPropList,
1214+ numEventsInWaitList, phEventWaitList, phEvent);
1215+ }
1216+
1217+ return UR_RESULT_SUCCESS;
1218+ }
0 commit comments