From e2b67a868b7191237374226218756c1a62fabd4e Mon Sep 17 00:00:00 2001 From: ReinUsesLisp Date: Fri, 15 May 2020 01:43:44 -0300 Subject: [PATCH] shader/other: Implement thread comparisons (NV_shader_thread_group) Hardware S2R special registers match gl_Thread*MaskNV. We can trivially implement these using Nvidia's extension on OpenGL or naively stubbing them with the ARB instructions to match. This might cause issues if the host device warp size doesn't match Nvidia's. That said, this is unlikely on proper shaders. Refer to the attached url for more documentation about these flags. https://www.khronos.org/registry/OpenGL/extensions/NV/NV_shader_thread_group.txt --- .../renderer_opengl/gl_shader_decompiler.cpp | 23 +++++++++++++++++++ .../renderer_vulkan/vk_shader_decompiler.cpp | 23 +++++++++++++++++++ src/video_core/shader/decode/other.cpp | 21 +++++++++++++++++ src/video_core/shader/node.h | 5 ++++ 4 files changed, 72 insertions(+) diff --git a/src/video_core/renderer_opengl/gl_shader_decompiler.cpp b/src/video_core/renderer_opengl/gl_shader_decompiler.cpp index 960ebf1a12..c83a08d422 100644 --- a/src/video_core/renderer_opengl/gl_shader_decompiler.cpp +++ b/src/video_core/renderer_opengl/gl_shader_decompiler.cpp @@ -2309,6 +2309,18 @@ private: return {"gl_SubGroupInvocationARB", Type::Uint}; } + template + Expression ThreadMask(Operation) { + if (device.HasWarpIntrinsics()) { + return {fmt::format("gl_Thread{}MaskNV", comparison), Type::Uint}; + } + if (device.HasShaderBallot()) { + return {fmt::format("uint(gl_SubGroup{}MaskARB)", comparison), Type::Uint}; + } + LOG_ERROR(Render_OpenGL, "Thread mask intrinsics are required by the shader"); + return {"0U", Type::Uint}; + } + Expression ShuffleIndexed(Operation operation) { std::string value = VisitOperand(operation, 0).AsFloat(); @@ -2337,6 +2349,12 @@ private: static constexpr std::string_view NotEqual = "!="; static constexpr std::string_view GreaterEqual = ">="; + static constexpr std::string_view Eq = "Eq"; + static constexpr std::string_view Ge = "Ge"; + static constexpr std::string_view Gt = "Gt"; + static constexpr std::string_view Le = "Le"; + static constexpr std::string_view Lt = "Lt"; + static constexpr std::string_view Add = "Add"; static constexpr std::string_view Min = "Min"; static constexpr std::string_view Max = "Max"; @@ -2554,6 +2572,11 @@ private: &GLSLDecompiler::VoteEqual, &GLSLDecompiler::ThreadId, + &GLSLDecompiler::ThreadMask, + &GLSLDecompiler::ThreadMask, + &GLSLDecompiler::ThreadMask, + &GLSLDecompiler::ThreadMask, + &GLSLDecompiler::ThreadMask, &GLSLDecompiler::ShuffleIndexed, &GLSLDecompiler::MemoryBarrierGL, diff --git a/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp b/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp index 167e20e919..f4ccc98489 100644 --- a/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp +++ b/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp @@ -515,6 +515,16 @@ private: void DeclareCommon() { thread_id = DeclareInputBuiltIn(spv::BuiltIn::SubgroupLocalInvocationId, t_in_uint, "thread_id"); + thread_masks[0] = + DeclareInputBuiltIn(spv::BuiltIn::SubgroupEqMask, t_in_uint4, "thread_eq_mask"); + thread_masks[1] = + DeclareInputBuiltIn(spv::BuiltIn::SubgroupGeMask, t_in_uint4, "thread_ge_mask"); + thread_masks[2] = + DeclareInputBuiltIn(spv::BuiltIn::SubgroupGtMask, t_in_uint4, "thread_gt_mask"); + thread_masks[3] = + DeclareInputBuiltIn(spv::BuiltIn::SubgroupLeMask, t_in_uint4, "thread_le_mask"); + thread_masks[4] = + DeclareInputBuiltIn(spv::BuiltIn::SubgroupLtMask, t_in_uint4, "thread_lt_mask"); } void DeclareVertex() { @@ -2175,6 +2185,13 @@ private: return {OpLoad(t_uint, thread_id), Type::Uint}; } + template + Expression ThreadMask(Operation) { + // TODO(Rodrigo): Handle devices with different warp sizes + const Id mask = thread_masks[index]; + return {OpLoad(t_uint, AccessElement(t_in_uint, mask, 0)), Type::Uint}; + } + Expression ShuffleIndexed(Operation operation) { const Id value = AsFloat(Visit(operation[0])); const Id index = AsUint(Visit(operation[1])); @@ -2639,6 +2656,11 @@ private: &SPIRVDecompiler::Vote<&Module::OpSubgroupAllEqualKHR>, &SPIRVDecompiler::ThreadId, + &SPIRVDecompiler::ThreadMask<0>, // Eq + &SPIRVDecompiler::ThreadMask<1>, // Ge + &SPIRVDecompiler::ThreadMask<2>, // Gt + &SPIRVDecompiler::ThreadMask<3>, // Le + &SPIRVDecompiler::ThreadMask<4>, // Lt &SPIRVDecompiler::ShuffleIndexed, &SPIRVDecompiler::MemoryBarrierGL, @@ -2763,6 +2785,7 @@ private: Id workgroup_id{}; Id local_invocation_id{}; Id thread_id{}; + std::array thread_masks{}; // eq, ge, gt, le, lt VertexIndices in_indices; VertexIndices out_indices; diff --git a/src/video_core/shader/decode/other.cpp b/src/video_core/shader/decode/other.cpp index d4f95b18c6..399a455c41 100644 --- a/src/video_core/shader/decode/other.cpp +++ b/src/video_core/shader/decode/other.cpp @@ -109,6 +109,27 @@ u32 ShaderIR::DecodeOther(NodeBlock& bb, u32 pc) { return Operation(OperationCode::WorkGroupIdY); case SystemVariable::CtaIdZ: return Operation(OperationCode::WorkGroupIdZ); + case SystemVariable::EqMask: + case SystemVariable::LtMask: + case SystemVariable::LeMask: + case SystemVariable::GtMask: + case SystemVariable::GeMask: + uses_warps = true; + switch (instr.sys20) { + case SystemVariable::EqMask: + return Operation(OperationCode::ThreadEqMask); + case SystemVariable::LtMask: + return Operation(OperationCode::ThreadLtMask); + case SystemVariable::LeMask: + return Operation(OperationCode::ThreadLeMask); + case SystemVariable::GtMask: + return Operation(OperationCode::ThreadGtMask); + case SystemVariable::GeMask: + return Operation(OperationCode::ThreadGeMask); + default: + UNREACHABLE(); + return Immediate(0u); + } default: UNIMPLEMENTED_MSG("Unhandled system move: {}", static_cast(instr.sys20.Value())); diff --git a/src/video_core/shader/node.h b/src/video_core/shader/node.h index f75b622407..cce8aeebe5 100644 --- a/src/video_core/shader/node.h +++ b/src/video_core/shader/node.h @@ -226,6 +226,11 @@ enum class OperationCode { VoteEqual, /// (bool) -> bool ThreadId, /// () -> uint + ThreadEqMask, /// () -> uint + ThreadGeMask, /// () -> uint + ThreadGtMask, /// () -> uint + ThreadLeMask, /// () -> uint + ThreadLtMask, /// () -> uint ShuffleIndexed, /// (uint value, uint index) -> uint MemoryBarrierGL, /// () -> void