Skip to content
47 changes: 47 additions & 0 deletions libdevice/sanitizer/msan_rtl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,16 @@ inline void __msan_exit() {
__devicelib_exit();
}

// Shadow-propagation helper. Gathers NumElements values of type T starting at
// Src, stepping Stride elements between consecutive reads, and stores them
// contiguously starting at Dest. The stride is applied to the source side
// only. Both addresses are reinterpreted as __SYCL_GLOBAL__ pointers.
template <typename T>
void GroupAsyncCopy(uptr Dest, uptr Src, size_t NumElements, size_t Stride) {
  auto Out = (__SYCL_GLOBAL__ T *)Dest;
  auto In = (const __SYCL_GLOBAL__ T *)Src;
  size_t SrcIdx = 0;
  for (size_t DstIdx = 0; DstIdx != NumElements; ++DstIdx) {
    Out[DstIdx] = In[SrcIdx];
    SrcIdx += Stride;
  }
}

} // namespace

#define MSAN_MAYBE_WARNING(type, size) \
Expand Down Expand Up @@ -589,4 +599,41 @@ __msan_set_private_base(__SYCL_PRIVATE__ void *ptr) {
MSAN_DEBUG(__spirv_ocl_printf(__msan_print_private_base, sid, ptr));
}

static __SYCL_CONSTANT__ const char __msan_print_strided_copy_unsupport_type[] =
"[kernel] __msan_unpoison_strided_copy: unsupported type(%d)\n";

DEVICE_EXTERN_C_NOINLINE void
__msan_unpoison_strided_copy(uptr dest, uint32_t dest_as, uptr src,
uint32_t src_as, uint32_t element_size,
uptr counts, uptr stride) {
if (!GetMsanLaunchInfo)
return;

MSAN_DEBUG(__spirv_ocl_printf(__msan_print_func_beg,
"__msan_unpoison_strided_copy"));

uptr shadow_dest = (uptr)__msan_get_shadow(dest, dest_as);
uptr shadow_src = (uptr)__msan_get_shadow(src, src_as);

switch (element_size) {
case 1:
GroupAsyncCopy<int8_t>(shadow_dest, shadow_src, counts, stride);
break;
case 2:
GroupAsyncCopy<int16_t>(shadow_dest, shadow_src, counts, stride);
break;
case 4:
GroupAsyncCopy<int32_t>(shadow_dest, shadow_src, counts, stride);
break;
case 8:
GroupAsyncCopy<int64_t>(shadow_dest, shadow_src, counts, stride);
break;
default:
__spirv_ocl_printf(__msan_print_strided_copy_unsupport_type, element_size);
}

MSAN_DEBUG(__spirv_ocl_printf(__msan_print_func_end,
"__msan_unpoison_strided_copy"));
}

#endif // __SPIR__ || __SPIRV__
106 changes: 103 additions & 3 deletions llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -806,6 +806,8 @@ class MemorySanitizerOnSpirv {
void initializeKernelCallerMap(Function *F);

private:
friend struct MemorySanitizerVisitor;

Module &M;
LLVMContext &C;
const DataLayout &DL;
Expand Down Expand Up @@ -833,6 +835,7 @@ class MemorySanitizerOnSpirv {
FunctionCallee MsanBarrierFunc;
FunctionCallee MsanUnpoisonStackFunc;
FunctionCallee MsanSetPrivateBaseFunc;
FunctionCallee MsanUnpoisonStridedCopyFunc;
};

} // end anonymous namespace
Expand Down Expand Up @@ -899,14 +902,14 @@ void MemorySanitizerOnSpirv::initializeCallbacks() {
M.getOrInsertFunction("__msan_unpoison_shadow_static_local",
IRB.getVoidTy(), IntptrTy, IntptrTy);

// __asan_poison_shadow_dynamic_local(
// __msan_poison_shadow_dynamic_local(
// uptr ptr,
// uint32_t num_args
// )
MsanPoisonShadowDynamicLocalFunc = M.getOrInsertFunction(
"__msan_poison_shadow_dynamic_local", IRB.getVoidTy(), IntptrTy, Int32Ty);

// __asan_unpoison_shadow_dynamic_local(
// __msan_unpoison_shadow_dynamic_local(
// uptr ptr,
// uint32_t num_args
// )
Expand All @@ -930,6 +933,18 @@ void MemorySanitizerOnSpirv::initializeCallbacks() {
MsanSetPrivateBaseFunc =
M.getOrInsertFunction("__msan_set_private_base", IRB.getVoidTy(),
PointerType::get(C, kSpirOffloadPrivateAS));

// __msan_unpoison_strided_copy(
// uptr dest, uint32_t dest_as,
// uptr src, uint32_t src_as,
// uint32_t element_size,
// uptr counts,
// uptr stride
// )
MsanUnpoisonStridedCopyFunc = M.getOrInsertFunction(
"__msan_unpoison_strided_copy", IRB.getVoidTy(), IntptrTy,
IRB.getInt32Ty(), IntptrTy, IRB.getInt32Ty(), IRB.getInt32Ty(),
IRB.getInt64Ty(), IRB.getInt64Ty());
}

// Handle global variables:
Expand Down Expand Up @@ -1833,7 +1848,8 @@ static void setNoSanitizedMetadataSPIR(Instruction &I) {
}
} else {
auto FuncName = Func->getName();
if (FuncName.contains("__spirv_"))
if (FuncName.contains("__spirv_") &&
!FuncName.contains("__spirv_GroupAsyncCopy"))
I.setNoSanitizeMetadata();
}
}
Expand All @@ -1843,6 +1859,55 @@ static void setNoSanitizedMetadataSPIR(Instruction &I) {
I.setNoSanitizeMetadata();
}

// This is not a general-purpose demangler, but a helper for decoding the
// element type of a mangled "__spirv_GroupAsyncCopy" function name.
//
// \p Name must start at the mangled pointee type of the destination pointer,
// e.g. "iPU3AS1i..." (int) or "Dv4_aPU3AS1KS_..." (vector of 4 signed char).
//
// Returns the size in bytes of the element type (for vectors: length times
// scalar size), or 0 if the name is too short, malformed, or uses a type
// code that is not supported.
static int getTypeSizeFromManglingName(StringRef Name) {
  // Itanium ABI single-letter builtin type codes.
  auto GetTypeSize = [](const char C) {
    switch (C) {
    case 'a': // signed char
    case 'c': // char
    case 'h': // unsigned char
      return 1;
    case 's': // short
    case 't': // unsigned short
      return 2;
    case 'f': // float
    case 'i': // int
    case 'j': // unsigned int
      return 4;
    case 'd': // double
    case 'l': // long
    case 'm': // unsigned long
    case 'x': // long long
    case 'y': // unsigned long long
      return 8;
    default:
      return 0;
    }
  };

  // Size of the scalar type encoded at the beginning of S, including the
  // multi-character encodings of 16-bit floating point types.
  auto GetScalarSize = [&GetTypeSize](StringRef S) {
    if (S.empty())
      return 0;
    if (S.consume_front("Dh") || S.consume_front("DF16_")) // half / _Float16
      return 2;
    return GetTypeSize(S[0]);
  };

  // Name should always be long enough since the element type is followed by
  // the remaining parameters; it has at least 6 chars, such as "Dv16_d".
  if (Name.size() < 6)
    return 0;

  // 1. Scalar type
  if (!Name.consume_front("Dv"))
    return GetScalarSize(Name);

  // 2. Vector type: "Dv" <length> "_" <element type>
  // consumeInteger() returns true on failure and, on success, strips the
  // digits it parsed, so any number of digits is handled without throwing.
  unsigned Len = 0;
  if (Name.consumeInteger(10, Len) || !Name.consume_front("_"))
    return 0;

  return static_cast<int>(Len) * GetScalarSize(Name);
}

namespace {

/// Helper class to attach debug information of the given instruction onto new
Expand Down Expand Up @@ -6395,6 +6460,41 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
VAHelper->visitCallBase(CB, IRB);
}

if (SpirOrSpirv) {
auto *Func = CB.getCalledFunction();
if (Func) {
auto FuncName = Func->getName();
if (FuncName.contains("__spirv_GroupAsyncCopy")) {
// clang-format off
// Handle functions like "_Z22__spirv_GroupAsyncCopyiPU3AS3dPU3AS1dllP13__spirv_Event",
// its demangled name is "__spirv_GroupAsyncCopy(int, double AS3* dst, double AS1* src, long, long, __spirv_Event*)"
// The type of "src" and "dst" should always be same.
// clang-format on

auto *Dest = CB.getArgOperand(1);
auto *Src = CB.getArgOperand(2);
auto *NumElements = CB.getArgOperand(3);
auto *Stride = CB.getArgOperand(4);

// Skip "_Z22__spirv_GroupAsyncCopyiPU3AS3" (33 char), get the size of
// parameter type directly
const size_t kManglingPrefixLength = 33;
int ElementSize = getTypeSizeFromManglingName(
FuncName.substr(kManglingPrefixLength));
assert(ElementSize != 0 &&
"Unsupported __spirv_GroupAsyncCopy element type");

IRB.CreateCall(
MS.Spirv.MsanUnpoisonStridedCopyFunc,
{IRB.CreatePointerCast(Dest, MS.Spirv.IntptrTy),
IRB.getInt32(Dest->getType()->getPointerAddressSpace()),
IRB.CreatePointerCast(Src, MS.Spirv.IntptrTy),
IRB.getInt32(Src->getType()->getPointerAddressSpace()),
IRB.getInt32(ElementSize), NumElements, Stride});
}
}
}

// Now, get the shadow for the RetVal.
if (!CB.getType()->isSized())
return;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
; RUN: opt < %s -passes=msan -msan-instrumentation-with-call-threshold=0 -msan-eager-checks=1 -msan-poison-stack-with-call=1 -S | FileCheck %s

target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
target triple = "spir64-unknown-unknown"

declare spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyiPU3AS3iPU3AS1immP13__spirv_Event(i32, ptr addrspace(3), ptr addrspace(1), i64, i64, target("spirv.Event")) nounwind
declare dso_local spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS1iPU3AS3Kimm9ocl_event(i32, ptr addrspace(1), ptr addrspace(3), i64, i64, target("spirv.Event"))
declare dso_local spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS3Dv4_aPU3AS1KS_mm9ocl_event(i32, ptr addrspace(3), ptr addrspace(1), i64, i64, target("spirv.Event"))

; Kernel with three "__spirv_GroupAsyncCopy" calls using differently mangled
; callees. For each of them the test expects a call to
; "__msan_unpoison_strided_copy" in the instrumented output.
define spir_kernel void @kernel(ptr addrspace(3) %_arg_localAcc, ptr addrspace(1) %_arg_globalAcc) sanitize_memory {
entry:
; First copy: "int" elements, destination in addrspace(3), source in
; addrspace(1). The expected runtime call receives both pointers converted to
; i64, their address spaces (3 and 1), the element size (4) and the element
; count (512) and stride (1) operands of the original call.
; CHECK: @__msan_barrier()
; CHECK: [[REG1:%[0-9]+]] = ptrtoint ptr addrspace(3) %_arg_localAcc to i64
; CHECK-NEXT: [[REG2:%[0-9]+]] = ptrtoint ptr addrspace(1) %_arg_globalAcc to i64
; CHECK-NEXT: call void @__msan_unpoison_strided_copy(i64 [[REG1]], i32 3, i64 [[REG2]], i32 1, i32 4, i64 512, i64 1)
%copy = call spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyiPU3AS3iPU3AS1immP13__spirv_Event(i32 2, ptr addrspace(3) %_arg_localAcc, ptr addrspace(1) %_arg_globalAcc, i64 512, i64 1, target("spirv.Event") zeroinitializer)

; Second copy: opposite direction (destination in addrspace(1), source in
; addrspace(3)) through the "ocl_event" flavour of the mangled name. Only the
; presence of the runtime call is verified.
; CHECK: __msan_unpoison_strided_copy
%copy2 = call spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS1iPU3AS3Kimm9ocl_event(i32 2, ptr addrspace(1) %_arg_globalAcc, ptr addrspace(3) %_arg_localAcc, i64 512, i64 1, target("spirv.Event") zeroinitializer)
; Third copy: vector element type ("Dv4_a" in the mangled name). Only the
; presence of the runtime call is verified.
; CHECK: __msan_unpoison_strided_copy
%copy3 = call spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS3Dv4_aPU3AS1KS_mm9ocl_event(i32 2, ptr addrspace(3) %_arg_localAcc, ptr addrspace(1) %_arg_globalAcc, i64 512, i64 1, target("spirv.Event") zeroinitializer)
ret void
}
Loading