@@ -420,8 +420,6 @@ ur_result_t urEnqueueKernelLaunch(
420420 // / [out][optional] return an event object that identifies this
421421 // / particular kernel execution instance.
422422 ur_event_handle_t *phEvent) {
423- auto pfnKernelLaunch = getContext ()->urDdiTable .Enqueue .pfnKernelLaunch ;
424-
425423 getContext ()->logger .debug (" ==== urEnqueueKernelLaunch" );
426424
427425 USMLaunchInfo LaunchInfo (GetContext (hQueue), GetDevice (hQueue),
@@ -431,22 +429,14 @@ ur_result_t urEnqueueKernelLaunch(
431429
432430 UR_CALL (getMsanInterceptor ()->preLaunchKernel (hKernel, hQueue, LaunchInfo));
433431
434- ur_event_handle_t hEvent{};
435- ur_result_t result =
436- pfnKernelLaunch (hQueue, hKernel, workDim, pGlobalWorkOffset,
437- pGlobalWorkSize, LaunchInfo.LocalWorkSize .data (),
438- numEventsInWaitList, phEventWaitList, &hEvent);
439-
440- if (result == UR_RESULT_SUCCESS) {
441- UR_CALL (
442- getMsanInterceptor ()->postLaunchKernel (hKernel, hQueue, LaunchInfo));
443- }
432+ UR_CALL (getContext ()->urDdiTable .Enqueue .pfnKernelLaunch (
433+ hQueue, hKernel, workDim, pGlobalWorkOffset, pGlobalWorkSize,
434+ LaunchInfo.LocalWorkSize .data (), numEventsInWaitList, phEventWaitList,
435+ phEvent));
444436
445- if (phEvent) {
446- *phEvent = hEvent;
447- }
437+ UR_CALL (getMsanInterceptor ()->postLaunchKernel (hKernel, hQueue, LaunchInfo));
448438
449- return result ;
439+ return UR_RESULT_SUCCESS ;
450440}
451441
452442// /////////////////////////////////////////////////////////////////////////////
@@ -1323,6 +1313,58 @@ ur_result_t urEnqueueMemUnmap(
13231313 return UR_RESULT_SUCCESS;
13241314}
13251315
// /////////////////////////////////////////////////////////////////////////////
// / @brief Intercept function for urEnqueueCooperativeKernelLaunchExp
// /
// / Wraps the underlying pfnCooperativeKernelLaunchExp DDI entry with the MSAN
// / interceptor's preLaunchKernel / postLaunchKernel hooks.
// /
// / NOTE(review): UR_CALL is defined outside this view; it appears to return
// / early from this function with the failing result - confirm against the
// / macro definition. If so, postLaunchKernel only runs after a successful
// / launch.
1316+ ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp (
1317+ // / [in] handle of the queue object
1318+ ur_queue_handle_t hQueue,
1319+ // / [in] handle of the kernel object
1320+ ur_kernel_handle_t hKernel,
1321+ // / [in] number of dimensions, from 1 to 3, to specify the global and
1322+ // / work-group work-items
1323+ uint32_t workDim,
1324+ // / [in] pointer to an array of workDim unsigned values that specify the
1325+ // / offset used to calculate the global ID of a work-item
1326+ const size_t *pGlobalWorkOffset,
1327+ // / [in] pointer to an array of workDim unsigned values that specify the
1328+ // / number of global work-items in workDim that will execute the kernel
1329+ // / function
1330+ const size_t *pGlobalWorkSize,
1331+ // / [in][optional] pointer to an array of workDim unsigned values that
1332+ // / specify the number of local work-items forming a work-group that will
1333+ // / execute the kernel function.
1334+ // / If nullptr, the runtime implementation will choose the work-group size.
1335+ const size_t *pLocalWorkSize,
1336+ // / [in] size of the event wait list
1337+ uint32_t numEventsInWaitList,
1338+ // / [in][optional][range(0, numEventsInWaitList)] pointer to a list of
1339+ // / events that must be complete before the kernel execution.
1340+ // / If nullptr, numEventsInWaitList must be 0, indicating that there are
1341+ // / no events to wait on.
1342+ const ur_event_handle_t *phEventWaitList,
1343+ // / [out][optional][alloc] return an event object that identifies this
1344+ // / particular kernel execution instance. If phEventWaitList and phEvent
1345+ // / are not NULL, phEvent must not refer to an element of the
1346+ // / phEventWaitList array.
1347+ ur_event_handle_t *phEvent) {
1348+
// Trace entry into the interceptor at debug level.
1349+ getContext ()->logger .debug (" ==== urEnqueueCooperativeKernelLaunchExp" );
1350+
// Per-launch state built from the queue's context and device plus the
// caller's global size, local size, global offset and dimension count.
1351+ USMLaunchInfo LaunchInfo (GetContext (hQueue), GetDevice (hQueue),
1352+ pGlobalWorkSize, pLocalWorkSize, pGlobalWorkOffset,
1353+ workDim);
1354+ UR_CALL (LaunchInfo.initialize ());
1355+
// Interceptor hook that runs before the kernel is handed to the next layer.
1356+ UR_CALL (getMsanInterceptor ()->preLaunchKernel (hKernel, hQueue, LaunchInfo));
1357+
// Forward to the underlying implementation. The local work size passed down
// is LaunchInfo.LocalWorkSize, not the caller's pLocalWorkSize.
// NOTE(review): presumably LaunchInfo holds a resolved size when
// pLocalWorkSize is nullptr - confirm in USMLaunchInfo / preLaunchKernel.
// phEvent is forwarded unchanged, so the underlying call fills it directly.
1358+ UR_CALL (getContext ()->urDdiTable .EnqueueExp .pfnCooperativeKernelLaunchExp (
1359+ hQueue, hKernel, workDim, pGlobalWorkOffset, pGlobalWorkSize,
1360+ LaunchInfo.LocalWorkSize .data (), numEventsInWaitList, phEventWaitList,
1361+ phEvent));
1362+
// Interceptor hook that runs after the launch has been enqueued.
1363+ UR_CALL (getMsanInterceptor ()->postLaunchKernel (hKernel, hQueue, LaunchInfo));
1364+
1365+ return UR_RESULT_SUCCESS;
1366+ }
1367+
13261368// /////////////////////////////////////////////////////////////////////////////
13271369// / @brief Intercept function for urKernelRetain
13281370ur_result_t urKernelRetain (
@@ -1912,6 +1954,25 @@ ur_result_t urCheckVersion(ur_api_version_t version) {
19121954 return UR_RESULT_SUCCESS;
19131955}
19141956
1957+ // /////////////////////////////////////////////////////////////////////////////
1958+ // / @brief Exported function for filling application's EnqueueExp table
1959+ // / with current process' addresses
1960+ // /
// / @details Only the pfnCooperativeKernelLaunchExp entry is written; it is
// / pointed at the MSAN urEnqueueCooperativeKernelLaunchExp wrapper. No other
// / entry of the table is touched by this function.
// /
1961+ // / @returns
1962+ // / - ::UR_RESULT_SUCCESS
1963+ // / - ::UR_RESULT_ERROR_INVALID_NULL_POINTER
1964+ __urdlllocal ur_result_t UR_APICALL urGetEnqueueExpProcAddrTable (
1965+ // / [in,out] pointer to table of DDI function pointers
1966+ ur_enqueue_exp_dditable_t *pDdiTable) {
// A null table cannot be filled in; report it without dereferencing.
1967+ if (nullptr == pDdiTable) {
1968+ return UR_RESULT_ERROR_INVALID_NULL_POINTER;
1969+ }
1970+
// Route cooperative kernel launches through the MSAN wrapper.
1971+ pDdiTable->pfnCooperativeKernelLaunchExp =
1972+ ur_sanitizer_layer::msan::urEnqueueCooperativeKernelLaunchExp;
1973+ return UR_RESULT_SUCCESS;
1974+ }
1975+
19151976} // namespace msan
19161977
19171978ur_result_t initMsanDDITable (ur_dditable_t *dditable) {
@@ -1966,6 +2027,11 @@ ur_result_t initMsanDDITable(ur_dditable_t *dditable) {
19662027 result = ur_sanitizer_layer::msan::urGetUSMProcAddrTable (&dditable->USM );
19672028 }
19682029
2030+ if (UR_RESULT_SUCCESS == result) {
2031+ result = ur_sanitizer_layer::msan::urGetEnqueueExpProcAddrTable (
2032+ &dditable->EnqueueExp );
2033+ }
2034+
19692035 if (result != UR_RESULT_SUCCESS) {
19702036 getContext ()->logger .error (" Initialize MSAN DDI table failed: {}" , result);
19712037 }
0 commit comments