Skip to content

[DeviceMSAN] Fix false negative report due to __spirv_GroupAsyncCopy #18216

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
May 21, 2025
47 changes: 47 additions & 0 deletions libdevice/sanitizer/msan_rtl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,16 @@ inline void __msan_exit() {
__devicelib_exit();
}

// This function is only used for shadow propagation
template <typename T>
void GroupAsyncCopy(uptr Dest, uptr Src, size_t NumElements, size_t Stride) {
auto DestPtr = (__SYCL_GLOBAL__ T *)Dest;
auto SrcPtr = (const __SYCL_GLOBAL__ T *)Src;
for (size_t i = 0; i < NumElements; i++) {
DestPtr[i] = SrcPtr[i * Stride];
}
}

} // namespace

#define MSAN_MAYBE_WARNING(type, size) \
Expand Down Expand Up @@ -589,4 +599,41 @@ __msan_set_private_base(__SYCL_PRIVATE__ void *ptr) {
MSAN_DEBUG(__spirv_ocl_printf(__msan_print_private_base, sid, ptr));
}

static __SYCL_CONSTANT__ const char __msan_print_strided_copy_unsupport_type[] =
"[kernel] __msan_unpoison_strided_copy: unsupported type(%d)\n";

DEVICE_EXTERN_C_NOINLINE void
__msan_unpoison_strided_copy(uptr dest, uint32_t dest_as, uptr src,
uint32_t src_as, uint32_t element_size,
uptr counts, uptr stride) {
if (!GetMsanLaunchInfo)
return;

MSAN_DEBUG(__spirv_ocl_printf(__msan_print_func_beg,
"__msan_unpoison_strided_copy"));

uptr shadow_dest = (uptr)__msan_get_shadow(dest, dest_as);
uptr shadow_src = (uptr)__msan_get_shadow(src, src_as);

switch (element_size) {
case 1:
GroupAsyncCopy<int8_t>(shadow_dest, shadow_src, counts, stride);
break;
case 2:
GroupAsyncCopy<int16_t>(shadow_dest, shadow_src, counts, stride);
break;
case 4:
GroupAsyncCopy<int32_t>(shadow_dest, shadow_src, counts, stride);
break;
case 8:
GroupAsyncCopy<int64_t>(shadow_dest, shadow_src, counts, stride);
break;
default:
__spirv_ocl_printf(__msan_print_strided_copy_unsupport_type, element_size);
}

MSAN_DEBUG(__spirv_ocl_printf(__msan_print_func_end,
"__msan_unpoison_strided_copy"));
}

#endif // __SPIR__ || __SPIRV__
106 changes: 103 additions & 3 deletions llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -806,6 +806,8 @@ class MemorySanitizerOnSpirv {
void initializeKernelCallerMap(Function *F);

private:
friend struct MemorySanitizerVisitor;

Module &M;
LLVMContext &C;
const DataLayout &DL;
Expand Down Expand Up @@ -833,6 +835,7 @@ class MemorySanitizerOnSpirv {
FunctionCallee MsanBarrierFunc;
FunctionCallee MsanUnpoisonStackFunc;
FunctionCallee MsanSetPrivateBaseFunc;
FunctionCallee MsanUnpoisonStridedCopyFunc;
};

} // end anonymous namespace
Expand Down Expand Up @@ -899,14 +902,14 @@ void MemorySanitizerOnSpirv::initializeCallbacks() {
M.getOrInsertFunction("__msan_unpoison_shadow_static_local",
IRB.getVoidTy(), IntptrTy, IntptrTy);

// __asan_poison_shadow_dynamic_local(
// __msan_poison_shadow_dynamic_local(
// uptr ptr,
// uint32_t num_args
// )
MsanPoisonShadowDynamicLocalFunc = M.getOrInsertFunction(
"__msan_poison_shadow_dynamic_local", IRB.getVoidTy(), IntptrTy, Int32Ty);

// __asan_unpoison_shadow_dynamic_local(
// __msan_unpoison_shadow_dynamic_local(
// uptr ptr,
// uint32_t num_args
// )
Expand All @@ -930,6 +933,18 @@ void MemorySanitizerOnSpirv::initializeCallbacks() {
MsanSetPrivateBaseFunc =
M.getOrInsertFunction("__msan_set_private_base", IRB.getVoidTy(),
PointerType::get(C, kSpirOffloadPrivateAS));

// __msan_unpoison_strided_copy(
// uptr dest, uint32_t dest_as,
// uptr src, uint32_t src_as,
// uint32_t element_size,
// uptr counts,
// uptr stride
// )
MsanUnpoisonStridedCopyFunc = M.getOrInsertFunction(
"__msan_unpoison_strided_copy", IRB.getVoidTy(), IntptrTy,
IRB.getInt32Ty(), IntptrTy, IRB.getInt32Ty(), IRB.getInt32Ty(),
IRB.getInt64Ty(), IRB.getInt64Ty());
}

// Handle global variables:
Expand Down Expand Up @@ -1833,7 +1848,8 @@ static void setNoSanitizedMetadataSPIR(Instruction &I) {
}
} else {
auto FuncName = Func->getName();
if (FuncName.contains("__spirv_"))
if (FuncName.contains("__spirv_") &&
!FuncName.contains("__spirv_GroupAsyncCopy"))
I.setNoSanitizeMetadata();
}
}
Expand All @@ -1843,6 +1859,55 @@ static void setNoSanitizedMetadataSPIR(Instruction &I) {
I.setNoSanitizeMetadata();
}

// This is not a general-purpose function, but a helper for demangling
// "__spirv_GroupAsyncCopy" function name
static int getTypeSizeFromManglingName(StringRef Name) {
auto GetTypeSize = [](const char C) {
switch (C) {
case 'a': // signed char
case 'c': // char
return 1;
case 's': // short
return 2;
case 'f': // float
case 'i': // int
return 4;
case 'd': // double
case 'l': // long
return 8;
default:
return 0;
}
};

// Name should always be long enough since it has other unmeaningful chars,
// it should have at least 6 chars, such as "Dv16_d"
if (Name.size() < 6)
return 0;

// 1. Basic type
if (Name[0] != 'D')
return GetTypeSize(Name[0]);

// 2. Vector type

// Drop "Dv"
assert(Name[0] == 'D' && Name[1] == 'v' &&
"Invalid mangling name for vector type");
Name = Name.drop_front(2);

// Vector length
assert(isDigit(Name[0]) && "Invalid mangling name for vector type");
int Len = std::stoi(Name.str());
Name = Name.drop_front(Len >= 10 ? 2 : 1);

assert(Name[0] == '_' && "Invalid mangling name for vector type");
Name = Name.drop_front(1);

int Size = GetTypeSize(Name[0]);
return Len * Size;
}

namespace {

/// Helper class to attach debug information of the given instruction onto new
Expand Down Expand Up @@ -6395,6 +6460,41 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
VAHelper->visitCallBase(CB, IRB);
}

if (SpirOrSpirv) {
auto *Func = CB.getCalledFunction();
if (Func) {
auto FuncName = Func->getName();
if (FuncName.contains("__spirv_GroupAsyncCopy")) {
// clang-format off
// Handle functions like "_Z22__spirv_GroupAsyncCopyiPU3AS3dPU3AS1dllP13__spirv_Event",
// its demangled name is "__spirv_GroupAsyncCopy(int, double AS3* dst, double AS1* src, long, long, __spirv_Event*)"
// The type of "src" and "dst" should always be same.
// clang-format on

auto *Dest = CB.getArgOperand(1);
auto *Src = CB.getArgOperand(2);
auto *NumElements = CB.getArgOperand(3);
auto *Stride = CB.getArgOperand(4);

// Skip "_Z22__spirv_GroupAsyncCopyiPU3AS3" (33 char), get the size of
// parameter type directly
const size_t kManglingPrefixLength = 33;
int ElementSize = getTypeSizeFromManglingName(
FuncName.substr(kManglingPrefixLength));
assert(ElementSize != 0 &&
"Unsupported __spirv_GroupAsyncCopy element type");

IRB.CreateCall(
MS.Spirv.MsanUnpoisonStridedCopyFunc,
{IRB.CreatePointerCast(Dest, MS.Spirv.IntptrTy),
IRB.getInt32(Dest->getType()->getPointerAddressSpace()),
IRB.CreatePointerCast(Src, MS.Spirv.IntptrTy),
IRB.getInt32(Src->getType()->getPointerAddressSpace()),
IRB.getInt32(ElementSize), NumElements, Stride});
}
}
}

// Now, get the shadow for the RetVal.
if (!CB.getType()->isSized())
return;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
; RUN: opt < %s -passes=msan -msan-instrumentation-with-call-threshold=0 -msan-eager-checks=1 -msan-poison-stack-with-call=1 -S | FileCheck %s

target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
target triple = "spir64-unknown-unknown"

declare spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyiPU3AS3iPU3AS1immP13__spirv_Event(i32, ptr addrspace(3), ptr addrspace(1), i64, i64, target("spirv.Event")) nounwind
declare dso_local spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS1iPU3AS3Kimm9ocl_event(i32, ptr addrspace(1), ptr addrspace(3), i64, i64, target("spirv.Event"))
declare dso_local spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS3Dv4_aPU3AS1KS_mm9ocl_event(i32, ptr addrspace(3), ptr addrspace(1), i64, i64, target("spirv.Event"))

define spir_kernel void @kernel(ptr addrspace(3) %_arg_localAcc, ptr addrspace(1) %_arg_globalAcc) sanitize_memory {
entry:
; CHECK: @__msan_barrier()
; CHECK: [[REG1:%[0-9]+]] = ptrtoint ptr addrspace(3) %_arg_localAcc to i64
; CHECK-NEXT: [[REG2:%[0-9]+]] = ptrtoint ptr addrspace(1) %_arg_globalAcc to i64
; CHECK-NEXT: call void @__msan_unpoison_strided_copy(i64 [[REG1]], i32 3, i64 [[REG2]], i32 1, i32 4, i64 512, i64 1)
%copy = call spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyiPU3AS3iPU3AS1immP13__spirv_Event(i32 2, ptr addrspace(3) %_arg_localAcc, ptr addrspace(1) %_arg_globalAcc, i64 512, i64 1, target("spirv.Event") zeroinitializer)

; CHECK: __msan_unpoison_strided_copy
%copy2 = call spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS1iPU3AS3Kimm9ocl_event(i32 2, ptr addrspace(1) %_arg_globalAcc, ptr addrspace(3) %_arg_localAcc, i64 512, i64 1, target("spirv.Event") zeroinitializer)
; CHECK: __msan_unpoison_strided_copy
%copy3 = call spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS3Dv4_aPU3AS1KS_mm9ocl_event(i32 2, ptr addrspace(3) %_arg_localAcc, ptr addrspace(1) %_arg_globalAcc, i64 512, i64 1, target("spirv.Event") zeroinitializer)
ret void
}
Loading