diff --git a/main.cpp b/main.cpp index 81db89ce0..8d015a6dc 100644 --- a/main.cpp +++ b/main.cpp @@ -1082,6 +1082,10 @@ static ExecutionModel stage_to_execution_model(const std::string &stage) return ExecutionModelMissKHR; else if (stage == "rcall") return ExecutionModelCallableKHR; + else if (stage == "mesh") + return spv::ExecutionModelMeshEXT; + else if (stage == "task") + return spv::ExecutionModelTaskEXT; else SPIRV_CROSS_THROW("Invalid stage."); } diff --git a/reference/opt/shaders/task/task-shader-basic-2.vk.spv14.nocompat.task.vk b/reference/opt/shaders/task/task-shader-basic-2.vk.spv14.nocompat.task.vk new file mode 100644 index 000000000..98704e22d --- /dev/null +++ b/reference/opt/shaders/task/task-shader-basic-2.vk.spv14.nocompat.task.vk @@ -0,0 +1,42 @@ +#version 450 +#extension GL_EXT_mesh_shader : require +layout(local_size_x = 4, local_size_y = 3, local_size_z = 2) in; + +struct Payload +{ + float v[3]; +}; + +shared float vs[24]; +taskPayloadSharedEXT Payload p; + +void main() +{ + vs[gl_LocalInvocationIndex] = 10.0; + barrier(); + if (gl_LocalInvocationIndex < 12u) + { + vs[gl_LocalInvocationIndex] += vs[gl_LocalInvocationIndex + 12u]; + } + barrier(); + if (gl_LocalInvocationIndex < 6u) + { + vs[gl_LocalInvocationIndex] += vs[gl_LocalInvocationIndex + 6u]; + } + barrier(); + if (gl_LocalInvocationIndex < 3u) + { + vs[gl_LocalInvocationIndex] += vs[gl_LocalInvocationIndex + 3u]; + } + barrier(); + p.v[gl_LocalInvocationIndex] = vs[gl_LocalInvocationIndex]; + if (vs[5] > 20.0) + { + EmitMeshTasksEXT(uint(int(vs[4])), uint(int(vs[6])), uint(int(vs[8]))); + } + else + { + EmitMeshTasksEXT(uint(int(vs[6])), 10u, 50u); + } +} + diff --git a/reference/opt/shaders/task/task-shader-basic.vk.spv14.nocompat.task.vk b/reference/opt/shaders/task/task-shader-basic.vk.spv14.nocompat.task.vk new file mode 100644 index 000000000..1d491e701 --- /dev/null +++ b/reference/opt/shaders/task/task-shader-basic.vk.spv14.nocompat.task.vk @@ -0,0 +1,35 @@ +#version 450 +#extension GL_EXT_mesh_shader : require +layout(local_size_x = 4, local_size_y = 3, local_size_z = 2) in; + +struct Payload +{ + float v[3]; +}; + +shared float vs[24]; +taskPayloadSharedEXT Payload p; + +void main() +{ + vs[gl_LocalInvocationIndex] = 10.0; + barrier(); + if (gl_LocalInvocationIndex < 12u) + { + vs[gl_LocalInvocationIndex] += vs[gl_LocalInvocationIndex + 12u]; + } + barrier(); + if (gl_LocalInvocationIndex < 6u) + { + vs[gl_LocalInvocationIndex] += vs[gl_LocalInvocationIndex + 6u]; + } + barrier(); + if (gl_LocalInvocationIndex < 3u) + { + vs[gl_LocalInvocationIndex] += vs[gl_LocalInvocationIndex + 3u]; + } + barrier(); + p.v[gl_LocalInvocationIndex] = vs[gl_LocalInvocationIndex]; + EmitMeshTasksEXT(uint(int(vs[4])), uint(int(vs[6])), uint(int(vs[8]))); +} + diff --git a/reference/shaders-no-opt/asm/task/task-shader.vk.nocompat.invalid.spv14.asm.task.vk b/reference/shaders-no-opt/asm/task/task-shader.vk.nocompat.invalid.spv14.asm.task.vk new file mode 100644 index 000000000..1d491e701 --- /dev/null +++ b/reference/shaders-no-opt/asm/task/task-shader.vk.nocompat.invalid.spv14.asm.task.vk @@ -0,0 +1,35 @@ +#version 450 +#extension GL_EXT_mesh_shader : require +layout(local_size_x = 4, local_size_y = 3, local_size_z = 2) in; + +struct Payload +{ + float v[3]; +}; + +shared float vs[24]; +taskPayloadSharedEXT Payload p; + +void main() +{ + vs[gl_LocalInvocationIndex] = 10.0; + barrier(); + if (gl_LocalInvocationIndex < 12u) + { + vs[gl_LocalInvocationIndex] += vs[gl_LocalInvocationIndex + 12u]; + } + barrier(); + if (gl_LocalInvocationIndex < 6u) + { + vs[gl_LocalInvocationIndex] += vs[gl_LocalInvocationIndex + 6u]; + } + barrier(); + if (gl_LocalInvocationIndex < 3u) + { + vs[gl_LocalInvocationIndex] += vs[gl_LocalInvocationIndex + 3u]; + } + barrier(); + p.v[gl_LocalInvocationIndex] = vs[gl_LocalInvocationIndex]; + EmitMeshTasksEXT(uint(int(vs[4])), uint(int(vs[6])), uint(int(vs[8]))); +} + diff --git a/reference/shaders/task/task-shader-basic-2.vk.spv14.nocompat.task.vk b/reference/shaders/task/task-shader-basic-2.vk.spv14.nocompat.task.vk new file mode 100644 index 000000000..98704e22d --- /dev/null +++ b/reference/shaders/task/task-shader-basic-2.vk.spv14.nocompat.task.vk @@ -0,0 +1,42 @@ +#version 450 +#extension GL_EXT_mesh_shader : require +layout(local_size_x = 4, local_size_y = 3, local_size_z = 2) in; + +struct Payload +{ + float v[3]; +}; + +shared float vs[24]; +taskPayloadSharedEXT Payload p; + +void main() +{ + vs[gl_LocalInvocationIndex] = 10.0; + barrier(); + if (gl_LocalInvocationIndex < 12u) + { + vs[gl_LocalInvocationIndex] += vs[gl_LocalInvocationIndex + 12u]; + } + barrier(); + if (gl_LocalInvocationIndex < 6u) + { + vs[gl_LocalInvocationIndex] += vs[gl_LocalInvocationIndex + 6u]; + } + barrier(); + if (gl_LocalInvocationIndex < 3u) + { + vs[gl_LocalInvocationIndex] += vs[gl_LocalInvocationIndex + 3u]; + } + barrier(); + p.v[gl_LocalInvocationIndex] = vs[gl_LocalInvocationIndex]; + if (vs[5] > 20.0) + { + EmitMeshTasksEXT(uint(int(vs[4])), uint(int(vs[6])), uint(int(vs[8]))); + } + else + { + EmitMeshTasksEXT(uint(int(vs[6])), 10u, 50u); + } +} + diff --git a/reference/shaders/task/task-shader-basic.vk.spv14.nocompat.task.vk b/reference/shaders/task/task-shader-basic.vk.spv14.nocompat.task.vk new file mode 100644 index 000000000..1d491e701 --- /dev/null +++ b/reference/shaders/task/task-shader-basic.vk.spv14.nocompat.task.vk @@ -0,0 +1,35 @@ +#version 450 +#extension GL_EXT_mesh_shader : require +layout(local_size_x = 4, local_size_y = 3, local_size_z = 2) in; + +struct Payload +{ + float v[3]; +}; + +shared float vs[24]; +taskPayloadSharedEXT Payload p; + +void main() +{ + vs[gl_LocalInvocationIndex] = 10.0; + barrier(); + if (gl_LocalInvocationIndex < 12u) + { + vs[gl_LocalInvocationIndex] += vs[gl_LocalInvocationIndex + 12u]; + } + barrier(); + if (gl_LocalInvocationIndex < 6u) + { + vs[gl_LocalInvocationIndex] += vs[gl_LocalInvocationIndex + 6u]; + } + barrier(); + if (gl_LocalInvocationIndex < 3u) + { + vs[gl_LocalInvocationIndex] += vs[gl_LocalInvocationIndex + 3u]; + } + barrier(); + p.v[gl_LocalInvocationIndex] = vs[gl_LocalInvocationIndex]; + EmitMeshTasksEXT(uint(int(vs[4])), uint(int(vs[6])), uint(int(vs[8]))); +} + diff --git a/shaders-no-opt/asm/task/task-shader.vk.nocompat.invalid.spv14.asm.task b/shaders-no-opt/asm/task/task-shader.vk.nocompat.invalid.spv14.asm.task new file mode 100644 index 000000000..cbef97ed1 --- /dev/null +++ b/shaders-no-opt/asm/task/task-shader.vk.nocompat.invalid.spv14.asm.task @@ -0,0 +1,132 @@ +; SPIR-V +; Version: 1.4 +; Generator: Khronos Glslang Reference Front End; 10 +; Bound: 93 +; Schema: 0 + OpCapability MeshShadingEXT + OpExtension "SPV_EXT_mesh_shader" + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint TaskEXT %main "main" %vs %gl_LocalInvocationIndex %p + OpExecutionMode %main LocalSize 4 3 2 + OpSource GLSL 450 + OpSourceExtension "GL_EXT_mesh_shader" + OpName %main "main" + OpName %vs "vs" + OpName %gl_LocalInvocationIndex "gl_LocalInvocationIndex" + OpName %Payload "Payload" + OpMemberName %Payload 0 "v" + OpName %p "p" + OpDecorate %gl_LocalInvocationIndex BuiltIn LocalInvocationIndex + OpDecorate %gl_WorkGroupSize BuiltIn WorkgroupSize + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %uint = OpTypeInt 32 0 + %uint_24 = OpConstant %uint 24 +%_arr_float_uint_24 = OpTypeArray %float %uint_24 +%_ptr_Workgroup__arr_float_uint_24 = OpTypePointer Workgroup %_arr_float_uint_24 + %vs = OpVariable %_ptr_Workgroup__arr_float_uint_24 Workgroup +%_ptr_Input_uint = OpTypePointer Input %uint +%gl_LocalInvocationIndex = OpVariable %_ptr_Input_uint Input + %float_10 = OpConstant %float 10 +%_ptr_Workgroup_float = OpTypePointer Workgroup %float + %uint_2 = OpConstant %uint 2 + %uint_264 = OpConstant %uint 264 + %uint_12 = OpConstant %uint 12 + %bool = OpTypeBool + %uint_6 = OpConstant %uint 6 + %uint_3 = OpConstant %uint 3 +%_arr_float_uint_3 = OpTypeArray %float %uint_3 + %Payload = OpTypeStruct %_arr_float_uint_3 +%_ptr_TaskPayloadWorkgroupEXT_Payload = OpTypePointer TaskPayloadWorkgroupEXT %Payload + %p = OpVariable %_ptr_TaskPayloadWorkgroupEXT_Payload TaskPayloadWorkgroupEXT + %int = OpTypeInt 32 1 + %int_0 = OpConstant %int 0 +%_ptr_TaskPayloadWorkgroupEXT_float = OpTypePointer TaskPayloadWorkgroupEXT %float + %int_4 = OpConstant %int 4 + %int_6 = OpConstant %int 6 + %int_8 = OpConstant %int 8 + %v3uint = OpTypeVector %uint 3 + %uint_4 = OpConstant %uint 4 +%gl_WorkGroupSize = OpConstantComposite %v3uint %uint_4 %uint_3 %uint_2 + %main = OpFunction %void None %3 + %5 = OpLabel + %14 = OpLoad %uint %gl_LocalInvocationIndex + %17 = OpAccessChain %_ptr_Workgroup_float %vs %14 + OpStore %17 %float_10 + OpControlBarrier %uint_2 %uint_2 %uint_264 + %20 = OpLoad %uint %gl_LocalInvocationIndex + %23 = OpULessThan %bool %20 %uint_12 + OpSelectionMerge %25 None + OpBranchConditional %23 %24 %25 + %24 = OpLabel + %26 = OpLoad %uint %gl_LocalInvocationIndex + %27 = OpLoad %uint %gl_LocalInvocationIndex + %28 = OpIAdd %uint %27 %uint_12 + %29 = OpAccessChain %_ptr_Workgroup_float %vs %28 + %30 = OpLoad %float %29 + %31 = OpAccessChain %_ptr_Workgroup_float %vs %26 + %32 = OpLoad %float %31 + %33 = OpFAdd %float %32 %30 + %34 = OpAccessChain %_ptr_Workgroup_float %vs %26 + OpStore %34 %33 + OpBranch %25 + %25 = OpLabel + OpControlBarrier %uint_2 %uint_2 %uint_264 + %35 = OpLoad %uint %gl_LocalInvocationIndex + %37 = OpULessThan %bool %35 %uint_6 + OpSelectionMerge %39 None + OpBranchConditional %37 %38 %39 + %38 = OpLabel + %40 = OpLoad %uint %gl_LocalInvocationIndex + %41 = OpLoad %uint %gl_LocalInvocationIndex + %42 = OpIAdd %uint %41 %uint_6 + %43 = OpAccessChain %_ptr_Workgroup_float %vs %42 + %44 = OpLoad %float %43 + %45 = OpAccessChain %_ptr_Workgroup_float %vs %40 + %46 = OpLoad %float %45 + %47 = OpFAdd %float %46 %44 + %48 = OpAccessChain %_ptr_Workgroup_float %vs %40 + OpStore %48 %47 + OpBranch %39 + %39 = OpLabel + OpControlBarrier %uint_2 %uint_2 %uint_264 + %49 = OpLoad %uint %gl_LocalInvocationIndex + %51 = OpULessThan %bool %49 %uint_3 + OpSelectionMerge %53 None + OpBranchConditional %51 %52 %53 + %52 = OpLabel + %54 = OpLoad %uint %gl_LocalInvocationIndex + %55 = OpLoad %uint %gl_LocalInvocationIndex + %56 = OpIAdd %uint %55 %uint_3 + %57 = OpAccessChain %_ptr_Workgroup_float %vs %56 + %58 = OpLoad %float %57 + %59 = OpAccessChain %_ptr_Workgroup_float %vs %54 + %60 = OpLoad %float %59 + %61 = OpFAdd %float %60 %58 + %62 = OpAccessChain %_ptr_Workgroup_float %vs %54 + OpStore %62 %61 + OpBranch %53 + %53 = OpLabel + OpControlBarrier %uint_2 %uint_2 %uint_264 + %69 = OpLoad %uint %gl_LocalInvocationIndex + %70 = OpLoad %uint %gl_LocalInvocationIndex + %71 = OpAccessChain %_ptr_Workgroup_float %vs %70 + %72 = OpLoad %float %71 + %74 = OpAccessChain %_ptr_TaskPayloadWorkgroupEXT_float %p %int_0 %69 + OpStore %74 %72 + %76 = OpAccessChain %_ptr_Workgroup_float %vs %int_4 + %77 = OpLoad %float %76 + %78 = OpConvertFToS %int %77 + %79 = OpBitcast %uint %78 + %81 = OpAccessChain %_ptr_Workgroup_float %vs %int_6 + %82 = OpLoad %float %81 + %83 = OpConvertFToS %int %82 + %84 = OpBitcast %uint %83 + %86 = OpAccessChain %_ptr_Workgroup_float %vs %int_8 + %87 = OpLoad %float %86 + %88 = OpConvertFToS %int %87 + %89 = OpBitcast %uint %88 + OpEmitMeshTasksEXT %79 %84 %89 %p + OpFunctionEnd diff --git a/spirv_common.hpp b/spirv_common.hpp index c8e748e6f..5c2ad7476 100644 --- a/spirv_common.hpp +++ b/spirv_common.hpp @@ -777,7 +777,8 @@ struct SPIRBlock : IVariant Unreachable, // Noop Kill, // Discard IgnoreIntersection, // Ray Tracing - TerminateRay // Ray Tracing + TerminateRay, // Ray Tracing + EmitMeshTasks // Mesh shaders }; enum Merge @@ -839,6 +840,13 @@ struct SPIRBlock : IVariant BlockID false_block = 0; BlockID default_block = 0; + // If terminator is EmitMeshTasksEXT. + struct + { + ID groups[3]; + ID payload; + } mesh = {}; + SmallVector ops; struct Phi diff --git a/spirv_cross.cpp b/spirv_cross.cpp index 17072c19a..04ea35fa5 100644 --- a/spirv_cross.cpp +++ b/spirv_cross.cpp @@ -98,7 +98,8 @@ bool Compiler::block_is_pure(const SPIRBlock &block) // This is a global side effect of the function. if (block.terminator == SPIRBlock::Kill || block.terminator == SPIRBlock::TerminateRay || - block.terminator == SPIRBlock::IgnoreIntersection) + block.terminator == SPIRBlock::IgnoreIntersection || + block.terminator == SPIRBlock::EmitMeshTasks) return false; for (auto &i : block.ops) @@ -155,6 +156,7 @@ bool Compiler::block_is_pure(const SPIRBlock &block) return false; // Mesh shader functions modify global state. + // (EmitMeshTasks is a terminator). case OpSetMeshOutputsEXT: return false; diff --git a/spirv_glsl.cpp b/spirv_glsl.cpp index e80a6ceae..bcd4f911c 100644 --- a/spirv_glsl.cpp +++ b/spirv_glsl.cpp @@ -498,6 +498,7 @@ void CompilerGLSL::find_static_extensions() break; case ExecutionModelMeshEXT: + case ExecutionModelTaskEXT: if (options.es || options.version < 450) SPIRV_CROSS_THROW("Mesh shaders require GLSL 450 or above."); if (!options.vulkan_semantics) @@ -16105,6 +16106,13 @@ void CompilerGLSL::emit_block_chain(SPIRBlock &block) statement("terminateRayEXT;"); break; + case SPIRBlock::EmitMeshTasks: + statement("EmitMeshTasksEXT(", + to_unpacked_expression(block.mesh.groups[0]), ", ", + to_unpacked_expression(block.mesh.groups[1]), ", ", + to_unpacked_expression(block.mesh.groups[2]), ");"); + break; + default: SPIRV_CROSS_THROW("Unimplemented block terminator."); } diff --git a/spirv_parser.cpp b/spirv_parser.cpp index 526c92cb5..49eb1933c 100644 --- a/spirv_parser.cpp +++ b/spirv_parser.cpp @@ -183,6 +183,15 @@ void Parser::parse(const Instruction &instruction) auto op = static_cast(instruction.op); uint32_t length = instruction.length; + // HACK for glslang that might emit OpEmitMeshTasksEXT followed by return / branch. + // Instead of failing hard, just ignore it. + if (ignore_trailing_block_opcodes) + { + ignore_trailing_block_opcodes = false; + if (op == OpReturn || op == OpBranch || op == OpUnreachable) + return; + } + switch (op) { case OpSourceContinued: @@ -1107,6 +1116,18 @@ void Parser::parse(const Instruction &instruction) current_block = nullptr; break; + case OpEmitMeshTasksEXT: + if (!current_block) + SPIRV_CROSS_THROW("Trying to end a non-existing block."); + current_block->terminator = SPIRBlock::EmitMeshTasks; + for (uint32_t i = 0; i < 3; i++) + current_block->mesh.groups[i] = ops[i]; + current_block->mesh.payload = length >= 4 ? ops[3] : 0; + current_block = nullptr; + // Currently glslang is bugged and does not treat EmitMeshTasksEXT as a terminator. + ignore_trailing_block_opcodes = true; + break; + case OpReturn: { if (!current_block) diff --git a/spirv_parser.hpp b/spirv_parser.hpp index d72fc71d8..dabc0e224 100644 --- a/spirv_parser.hpp +++ b/spirv_parser.hpp @@ -46,6 +46,8 @@ class Parser ParsedIR ir; SPIRFunction *current_function = nullptr; SPIRBlock *current_block = nullptr; + // For workarounds. + bool ignore_trailing_block_opcodes = false; void parse(const Instruction &instr); const uint32_t *stream(const Instruction &instr) const;