diff --git a/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp b/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp index 0175f2fb3698b..970b83de5ee33 100644 --- a/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp @@ -612,13 +612,10 @@ void SPIRVAsmPrinter::outputExecutionMode(const Module &M) { // Collect the SPIRVTypes for fp16, fp32, and fp64 and the constant of // type int32 with 0 value to represent the FP Fast Math Mode. std::vector SPIRVFloatTypes; - const MachineInstr *ConstZero = nullptr; + const MachineInstr *ConstZeroInt32 = nullptr; for (const MachineInstr *MI : MAI->getMSInstrs(SPIRV::MB_TypeConstVars)) { - // Skip if the instruction is not OpTypeFloat or OpConstant. unsigned OpCode = MI->getOpcode(); - if (OpCode != SPIRV::OpTypeFloat && OpCode != SPIRV::OpConstantNull) - continue; // Collect the SPIRV type if it's a float. if (OpCode == SPIRV::OpTypeFloat) { @@ -629,14 +626,18 @@ void SPIRVAsmPrinter::outputExecutionMode(const Module &M) { continue; } SPIRVFloatTypes.push_back(MI); - } else { + continue; + } + + if (OpCode == SPIRV::OpConstantNull) { // Check if the constant is int32, if not skip it. const MachineRegisterInfo &MRI = MI->getMF()->getRegInfo(); MachineInstr *TypeMI = MRI.getVRegDef(MI->getOperand(1).getReg()); - if (!TypeMI || TypeMI->getOperand(1).getImm() != 32) - continue; - - ConstZero = MI; + bool IsInt32Ty = TypeMI && + TypeMI->getOpcode() == SPIRV::OpTypeInt && + TypeMI->getOperand(1).getImm() == 32; + if (IsInt32Ty) + ConstZeroInt32 = MI; } } @@ -657,9 +658,9 @@ void SPIRVAsmPrinter::outputExecutionMode(const Module &M) { MCRegister TypeReg = MAI->getRegisterAlias(MF, MI->getOperand(0).getReg()); Inst.addOperand(MCOperand::createReg(TypeReg)); - assert(ConstZero && "There should be a constant zero."); + assert(ConstZeroInt32 && "There should be a constant zero."); MCRegister ConstReg = MAI->getRegisterAlias( - ConstZero->getMF(), ConstZero->getOperand(0).getReg()); + ConstZeroInt32->getMF(), ConstZeroInt32->getOperand(0).getReg()); Inst.addOperand(MCOperand::createReg(ConstReg)); outputMCInst(Inst); } diff --git a/llvm/test/CodeGen/SPIRV/non_int_constant_null.ll b/llvm/test/CodeGen/SPIRV/non_int_constant_null.ll new file mode 100644 index 0000000000000..0ba016aaa30aa --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/non_int_constant_null.ll @@ -0,0 +1,25 @@ +; RUN: llc -mtriple spirv64-unknown-unknown %s --spirv-ext=+SPV_KHR_float_controls2 -o - | FileCheck %s +; RUN: %if spirv-tools %{ llc -mtriple spirv64-unknown-unknown %s --spirv-ext=+SPV_KHR_float_controls2 -o - -filetype=obj | spirv-val %} + +@A = addrspace(1) constant [1 x i8] zeroinitializer + +; CHECK: OpName %[[#FOO:]] "foo" +; CHECK: OpName %[[#A:]] "A" +; CHECK: OpDecorate %[[#A]] Constant +; CHECK: OpDecorate %[[#A]] LinkageAttributes "A" Export +; CHECK: %[[#INT8:]] = OpTypeInt 8 0 +; CHECK: %[[#INT32:]] = OpTypeInt 32 0 +; CHECK: %[[#ONE:]] = OpConstant %[[#INT32]] 1 +; CHECK: %[[#ARR_INT8:]] = OpTypeArray %[[#INT8]] %7 +; CHECK: %[[#ARR_INT8_PTR:]] = OpTypePointer CrossWorkgroup %[[#ARR_INT8]] +; CHECK: %[[#ARR_INT8_ZERO:]] = OpConstantNull %[[#ARR_INT8]] +; CHECK: %13 = OpVariable %[[#ARR_INT8_PTR]] CrossWorkgroup %[[#ARR_INT8_ZERO]] +; CHECK: %[[#FOO]] = OpFunction +; CHECK: = OpLabel +; CHECK: OpReturn +; CHECK: OpFunctionEnd + +define spir_kernel void @foo() { +entry: + ret void +}