diff --git a/src/shader_recompiler/frontend/ir/ir_emitter.cpp b/src/shader_recompiler/frontend/ir/ir_emitter.cpp index 1c5ae0109b..9d7dc034c9 100644 --- a/src/shader_recompiler/frontend/ir/ir_emitter.cpp +++ b/src/shader_recompiler/frontend/ir/ir_emitter.cpp @@ -32,16 +32,16 @@ U32 IREmitter::Imm32(s32 value) const { return U32{Value{static_cast(value)}}; } -U32 IREmitter::Imm32(f32 value) const { - return U32{Value{Common::BitCast(value)}}; +F32 IREmitter::Imm32(f32 value) const { + return F32{Value{value}}; } U64 IREmitter::Imm64(u64 value) const { return U64{Value{value}}; } -U64 IREmitter::Imm64(f64 value) const { - return U64{Value{Common::BitCast(value)}}; +F64 IREmitter::Imm64(f64 value) const { + return F64{Value{value}}; } void IREmitter::Branch(IR::Block* label) { @@ -121,11 +121,11 @@ void IREmitter::SetOFlag(const U1& value) { Inst(Opcode::SetOFlag, value); } -U32 IREmitter::GetAttribute(IR::Attribute attribute) { - return Inst(Opcode::GetAttribute, attribute); +F32 IREmitter::GetAttribute(IR::Attribute attribute) { + return Inst(Opcode::GetAttribute, attribute); } -void IREmitter::SetAttribute(IR::Attribute attribute, const U32& value) { +void IREmitter::SetAttribute(IR::Attribute attribute, const F32& value) { Inst(Opcode::SetAttribute, attribute, value); } @@ -225,50 +225,113 @@ U1 IREmitter::GetOverflowFromOp(const Value& op) { return Inst(Opcode::GetOverflowFromOp, op); } -U16U32U64 IREmitter::FPAdd(const U16U32U64& a, const U16U32U64& b, FpControl control) { +F16F32F64 IREmitter::FPAdd(const F16F32F64& a, const F16F32F64& b, FpControl control) { if (a.Type() != a.Type()) { throw InvalidArgument("Mismatching types {} and {}", a.Type(), b.Type()); } switch (a.Type()) { - case Type::U16: - return Inst(Opcode::FPAdd16, Flags{control}, a, b); - case Type::U32: - return Inst(Opcode::FPAdd32, Flags{control}, a, b); - case Type::U64: - return Inst(Opcode::FPAdd64, Flags{control}, a, b); + case Type::F16: + return Inst(Opcode::FPAdd16, Flags{control}, a, b); + case Type::F32: + return Inst(Opcode::FPAdd32, Flags{control}, a, b); + case Type::F64: + return Inst(Opcode::FPAdd64, Flags{control}, a, b); default: ThrowInvalidType(a.Type()); } } -Value IREmitter::CompositeConstruct(const UAny& e1, const UAny& e2) { +Value IREmitter::CompositeConstruct(const Value& e1, const Value& e2) { if (e1.Type() != e2.Type()) { throw InvalidArgument("Mismatching types {} and {}", e1.Type(), e2.Type()); } - return Inst(Opcode::CompositeConstruct2, e1, e2); + switch (e1.Type()) { + case Type::U32: + return Inst(Opcode::CompositeConstructU32x2, e1, e2); + case Type::F16: + return Inst(Opcode::CompositeConstructF16x2, e1, e2); + case Type::F32: + return Inst(Opcode::CompositeConstructF32x2, e1, e2); + case Type::F64: + return Inst(Opcode::CompositeConstructF64x2, e1, e2); + default: + ThrowInvalidType(e1.Type()); + } } -Value IREmitter::CompositeConstruct(const UAny& e1, const UAny& e2, const UAny& e3) { +Value IREmitter::CompositeConstruct(const Value& e1, const Value& e2, const Value& e3) { if (e1.Type() != e2.Type() || e1.Type() != e3.Type()) { throw InvalidArgument("Mismatching types {}, {}, and {}", e1.Type(), e2.Type(), e3.Type()); } - return Inst(Opcode::CompositeConstruct3, e1, e2, e3); + switch (e1.Type()) { + case Type::U32: + return Inst(Opcode::CompositeConstructU32x3, e1, e2, e3); + case Type::F16: + return Inst(Opcode::CompositeConstructF16x3, e1, e2, e3); + case Type::F32: + return Inst(Opcode::CompositeConstructF32x3, e1, e2, e3); + case Type::F64: + return Inst(Opcode::CompositeConstructF64x3, e1, e2, e3); + default: + ThrowInvalidType(e1.Type()); + } } -Value IREmitter::CompositeConstruct(const UAny& e1, const UAny& e2, const UAny& e3, - const UAny& e4) { +Value IREmitter::CompositeConstruct(const Value& e1, const Value& e2, const Value& e3, + const Value& e4) { if (e1.Type() != e2.Type() || e1.Type() != e3.Type() || e1.Type() != e4.Type()) { throw InvalidArgument("Mismatching types {}, {}, {}, and {}", e1.Type(), e2.Type(), e3.Type(), e4.Type()); } - return Inst(Opcode::CompositeConstruct4, e1, e2, e3, e4); + switch (e1.Type()) { + case Type::U32: + return Inst(Opcode::CompositeConstructU32x4, e1, e2, e3, e4); + case Type::F16: + return Inst(Opcode::CompositeConstructF16x4, e1, e2, e3, e4); + case Type::F32: + return Inst(Opcode::CompositeConstructF32x4, e1, e2, e3, e4); + case Type::F64: + return Inst(Opcode::CompositeConstructF64x4, e1, e2, e3, e4); + default: + ThrowInvalidType(e1.Type()); + } } -UAny IREmitter::CompositeExtract(const Value& vector, size_t element) { - if (element >= 4) { - throw InvalidArgument("Out of bounds element {}", element); +Value IREmitter::CompositeExtract(const Value& vector, size_t element) { + const auto read = [&](Opcode opcode, size_t limit) -> Value { + if (element >= limit) { + throw InvalidArgument("Out of bounds element {}", element); + } + return Inst(opcode, vector, Value{static_cast(element)}); + }; + switch (vector.Type()) { + case Type::U32x2: + return read(Opcode::CompositeExtractU32x2, 2); + case Type::U32x3: + return read(Opcode::CompositeExtractU32x3, 3); + case Type::U32x4: + return read(Opcode::CompositeExtractU32x4, 4); + case Type::F16x2: + return read(Opcode::CompositeExtractF16x2, 2); + case Type::F16x3: + return read(Opcode::CompositeExtractF16x3, 3); + case Type::F16x4: + return read(Opcode::CompositeExtractF16x4, 4); + case Type::F32x2: + return read(Opcode::CompositeExtractF32x2, 2); + case Type::F32x3: + return read(Opcode::CompositeExtractF32x3, 3); + case Type::F32x4: + return read(Opcode::CompositeExtractF32x4, 4); + case Type::F64x2: + return read(Opcode::CompositeExtractF64x2, 2); + case Type::F64x3: + return read(Opcode::CompositeExtractF64x3, 3); + case Type::F64x4: + return read(Opcode::CompositeExtractF64x4, 4); + default: + ThrowInvalidType(vector.Type()); } - return Inst(Opcode::CompositeExtract, vector, Imm32(static_cast(element))); } UAny IREmitter::Select(const U1& condition, const UAny& true_value, const UAny& false_value) { @@ -289,6 +352,36 @@ UAny IREmitter::Select(const U1& condition, const UAny& true_value, const UAny& } } +template <> +IR::U32 IREmitter::BitCast(const IR::F32& value) { + return Inst(Opcode::BitCastU32F32, value); +} + +template <> +IR::F32 IREmitter::BitCast(const IR::U32& value) { + return Inst(Opcode::BitCastF32U32, value); +} + +template <> +IR::U16 IREmitter::BitCast(const IR::F16& value) { + return Inst(Opcode::BitCastU16F16, value); +} + +template <> +IR::F16 IREmitter::BitCast(const IR::U16& value) { + return Inst(Opcode::BitCastF16U16, value); +} + +template <> +IR::U64 IREmitter::BitCast(const IR::F64& value) { + return Inst(Opcode::BitCastU64F64, value); +} + +template <> +IR::F64 IREmitter::BitCast(const IR::U64& value) { + return Inst(Opcode::BitCastF64U64, value); +} + U64 IREmitter::PackUint2x32(const Value& vector) { return Inst(Opcode::PackUint2x32, vector); } @@ -305,75 +398,75 @@ Value IREmitter::UnpackFloat2x16(const U32& value) { return Inst(Opcode::UnpackFloat2x16, value); } -U64 IREmitter::PackDouble2x32(const Value& vector) { - return Inst(Opcode::PackDouble2x32, vector); +F64 IREmitter::PackDouble2x32(const Value& vector) { + return Inst(Opcode::PackDouble2x32, vector); } -Value IREmitter::UnpackDouble2x32(const U64& value) { +Value IREmitter::UnpackDouble2x32(const F64& value) { return Inst(Opcode::UnpackDouble2x32, value); } -U16U32U64 IREmitter::FPMul(const U16U32U64& a, const U16U32U64& b, FpControl control) { +F16F32F64 IREmitter::FPMul(const F16F32F64& a, const F16F32F64& b, FpControl control) { if (a.Type() != b.Type()) { throw InvalidArgument("Mismatching types {} and {}", a.Type(), b.Type()); } switch (a.Type()) { - case Type::U16: - return Inst(Opcode::FPMul16, Flags{control}, a, b); - case Type::U32: - return Inst(Opcode::FPMul32, Flags{control}, a, b); - case Type::U64: - return Inst(Opcode::FPMul64, Flags{control}, a, b); + case Type::F16: + return Inst(Opcode::FPMul16, Flags{control}, a, b); + case Type::F32: + return Inst(Opcode::FPMul32, Flags{control}, a, b); + case Type::F64: + return Inst(Opcode::FPMul64, Flags{control}, a, b); default: ThrowInvalidType(a.Type()); } } -U16U32U64 IREmitter::FPFma(const U16U32U64& a, const U16U32U64& b, const U16U32U64& c, +F16F32F64 IREmitter::FPFma(const F16F32F64& a, const F16F32F64& b, const F16F32F64& c, FpControl control) { if (a.Type() != b.Type() || a.Type() != c.Type()) { throw InvalidArgument("Mismatching types {}, {}, and {}", a.Type(), b.Type(), c.Type()); } switch (a.Type()) { - case Type::U16: - return Inst(Opcode::FPFma16, Flags{control}, a, b, c); - case Type::U32: - return Inst(Opcode::FPFma32, Flags{control}, a, b, c); - case Type::U64: - return Inst(Opcode::FPFma64, Flags{control}, a, b, c); + case Type::F16: + return Inst(Opcode::FPFma16, Flags{control}, a, b, c); + case Type::F32: + return Inst(Opcode::FPFma32, Flags{control}, a, b, c); + case Type::F64: + return Inst(Opcode::FPFma64, Flags{control}, a, b, c); default: ThrowInvalidType(a.Type()); } } -U16U32U64 IREmitter::FPAbs(const U16U32U64& value) { +F16F32F64 IREmitter::FPAbs(const F16F32F64& value) { switch (value.Type()) { case Type::U16: - return Inst(Opcode::FPAbs16, value); + return Inst(Opcode::FPAbs16, value); case Type::U32: - return Inst(Opcode::FPAbs32, value); + return Inst(Opcode::FPAbs32, value); case Type::U64: - return Inst(Opcode::FPAbs64, value); + return Inst(Opcode::FPAbs64, value); default: ThrowInvalidType(value.Type()); } } -U16U32U64 IREmitter::FPNeg(const U16U32U64& value) { +F16F32F64 IREmitter::FPNeg(const F16F32F64& value) { switch (value.Type()) { case Type::U16: - return Inst(Opcode::FPNeg16, value); + return Inst(Opcode::FPNeg16, value); case Type::U32: - return Inst(Opcode::FPNeg32, value); + return Inst(Opcode::FPNeg32, value); case Type::U64: - return Inst(Opcode::FPNeg64, value); + return Inst(Opcode::FPNeg64, value); default: ThrowInvalidType(value.Type()); } } -U16U32U64 IREmitter::FPAbsNeg(const U16U32U64& value, bool abs, bool neg) { - U16U32U64 result{value}; +F16F32F64 IREmitter::FPAbsNeg(const F16F32F64& value, bool abs, bool neg) { + F16F32F64 result{value}; if (abs) { result = FPAbs(value); } @@ -383,108 +476,108 @@ U16U32U64 IREmitter::FPAbsNeg(const U16U32U64& value, bool abs, bool neg) { return result; } -U32 IREmitter::FPCosNotReduced(const U32& value) { - return Inst(Opcode::FPCosNotReduced, value); +F32 IREmitter::FPCosNotReduced(const F32& value) { + return Inst(Opcode::FPCosNotReduced, value); } -U32 IREmitter::FPExp2NotReduced(const U32& value) { - return Inst(Opcode::FPExp2NotReduced, value); +F32 IREmitter::FPExp2NotReduced(const F32& value) { + return Inst(Opcode::FPExp2NotReduced, value); } -U32 IREmitter::FPLog2(const U32& value) { - return Inst(Opcode::FPLog2, value); +F32 IREmitter::FPLog2(const F32& value) { + return Inst(Opcode::FPLog2, value); } -U32U64 IREmitter::FPRecip(const U32U64& value) { +F32F64 IREmitter::FPRecip(const F32F64& value) { switch (value.Type()) { case Type::U32: - return Inst(Opcode::FPRecip32, value); + return Inst(Opcode::FPRecip32, value); case Type::U64: - return Inst(Opcode::FPRecip64, value); + return Inst(Opcode::FPRecip64, value); default: ThrowInvalidType(value.Type()); } } -U32U64 IREmitter::FPRecipSqrt(const U32U64& value) { +F32F64 IREmitter::FPRecipSqrt(const F32F64& value) { switch (value.Type()) { case Type::U32: - return Inst(Opcode::FPRecipSqrt32, value); + return Inst(Opcode::FPRecipSqrt32, value); case Type::U64: - return Inst(Opcode::FPRecipSqrt64, value); + return Inst(Opcode::FPRecipSqrt64, value); default: ThrowInvalidType(value.Type()); } } -U32 IREmitter::FPSinNotReduced(const U32& value) { - return Inst(Opcode::FPSinNotReduced, value); +F32 IREmitter::FPSinNotReduced(const F32& value) { + return Inst(Opcode::FPSinNotReduced, value); } -U32 IREmitter::FPSqrt(const U32& value) { - return Inst(Opcode::FPSqrt, value); +F32 IREmitter::FPSqrt(const F32& value) { + return Inst(Opcode::FPSqrt, value); } -U16U32U64 IREmitter::FPSaturate(const U16U32U64& value) { +F16F32F64 IREmitter::FPSaturate(const F16F32F64& value) { switch (value.Type()) { case Type::U16: - return Inst(Opcode::FPSaturate16, value); + return Inst(Opcode::FPSaturate16, value); case Type::U32: - return Inst(Opcode::FPSaturate32, value); + return Inst(Opcode::FPSaturate32, value); case Type::U64: - return Inst(Opcode::FPSaturate64, value); + return Inst(Opcode::FPSaturate64, value); default: ThrowInvalidType(value.Type()); } } -U16U32U64 IREmitter::FPRoundEven(const U16U32U64& value) { +F16F32F64 IREmitter::FPRoundEven(const F16F32F64& value) { switch (value.Type()) { case Type::U16: - return Inst(Opcode::FPRoundEven16, value); + return Inst(Opcode::FPRoundEven16, value); case Type::U32: - return Inst(Opcode::FPRoundEven32, value); + return Inst(Opcode::FPRoundEven32, value); case Type::U64: - return Inst(Opcode::FPRoundEven64, value); + return Inst(Opcode::FPRoundEven64, value); default: ThrowInvalidType(value.Type()); } } -U16U32U64 IREmitter::FPFloor(const U16U32U64& value) { +F16F32F64 IREmitter::FPFloor(const F16F32F64& value) { switch (value.Type()) { case Type::U16: - return Inst(Opcode::FPFloor16, value); + return Inst(Opcode::FPFloor16, value); case Type::U32: - return Inst(Opcode::FPFloor32, value); + return Inst(Opcode::FPFloor32, value); case Type::U64: - return Inst(Opcode::FPFloor64, value); + return Inst(Opcode::FPFloor64, value); default: ThrowInvalidType(value.Type()); } } -U16U32U64 IREmitter::FPCeil(const U16U32U64& value) { +F16F32F64 IREmitter::FPCeil(const F16F32F64& value) { switch (value.Type()) { case Type::U16: - return Inst(Opcode::FPCeil16, value); + return Inst(Opcode::FPCeil16, value); case Type::U32: - return Inst(Opcode::FPCeil32, value); + return Inst(Opcode::FPCeil32, value); case Type::U64: - return Inst(Opcode::FPCeil64, value); + return Inst(Opcode::FPCeil64, value); default: ThrowInvalidType(value.Type()); } } -U16U32U64 IREmitter::FPTrunc(const U16U32U64& value) { +F16F32F64 IREmitter::FPTrunc(const F16F32F64& value) { switch (value.Type()) { case Type::U16: - return Inst(Opcode::FPTrunc16, value); + return Inst(Opcode::FPTrunc16, value); case Type::U32: - return Inst(Opcode::FPTrunc32, value); + return Inst(Opcode::FPTrunc32, value); case Type::U64: - return Inst(Opcode::FPTrunc64, value); + return Inst(Opcode::FPTrunc64, value); default: ThrowInvalidType(value.Type()); } @@ -605,7 +698,7 @@ U1 IREmitter::LogicalNot(const U1& value) { return Inst(Opcode::LogicalNot, value); } -U32U64 IREmitter::ConvertFToS(size_t bitsize, const U16U32U64& value) { +U32U64 IREmitter::ConvertFToS(size_t bitsize, const F16F32F64& value) { switch (bitsize) { case 16: switch (value.Type()) { @@ -645,7 +738,7 @@ U32U64 IREmitter::ConvertFToS(size_t bitsize, const U16U32U64& value) { } } -U32U64 IREmitter::ConvertFToU(size_t bitsize, const U16U32U64& value) { +U32U64 IREmitter::ConvertFToU(size_t bitsize, const F16F32F64& value) { switch (bitsize) { case 16: switch (value.Type()) { @@ -685,7 +778,7 @@ U32U64 IREmitter::ConvertFToU(size_t bitsize, const U16U32U64& value) { } } -U32U64 IREmitter::ConvertFToI(size_t bitsize, bool is_signed, const U16U32U64& value) { +U32U64 IREmitter::ConvertFToI(size_t bitsize, bool is_signed, const F16F32F64& value) { if (is_signed) { return ConvertFToS(bitsize, value); } else { diff --git a/src/shader_recompiler/frontend/ir/ir_emitter.h b/src/shader_recompiler/frontend/ir/ir_emitter.h index 84b844898f..bfd9916cca 100644 --- a/src/shader_recompiler/frontend/ir/ir_emitter.h +++ b/src/shader_recompiler/frontend/ir/ir_emitter.h @@ -27,9 +27,9 @@ public: [[nodiscard]] U16 Imm16(u16 value) const; [[nodiscard]] U32 Imm32(u32 value) const; [[nodiscard]] U32 Imm32(s32 value) const; - [[nodiscard]] U32 Imm32(f32 value) const; + [[nodiscard]] F32 Imm32(f32 value) const; [[nodiscard]] U64 Imm64(u64 value) const; - [[nodiscard]] U64 Imm64(f64 value) const; + [[nodiscard]] F64 Imm64(f64 value) const; void Branch(IR::Block* label); void BranchConditional(const U1& cond, IR::Block* true_label, IR::Block* false_label); @@ -55,8 +55,8 @@ public: void SetCFlag(const U1& value); void SetOFlag(const U1& value); - [[nodiscard]] U32 GetAttribute(IR::Attribute attribute); - void SetAttribute(IR::Attribute attribute, const U32& value); + [[nodiscard]] F32 GetAttribute(IR::Attribute attribute); + void SetAttribute(IR::Attribute attribute, const F32& value); [[nodiscard]] U32 WorkgroupIdX(); [[nodiscard]] U32 WorkgroupIdY(); @@ -87,44 +87,47 @@ public: [[nodiscard]] U1 GetCarryFromOp(const Value& op); [[nodiscard]] U1 GetOverflowFromOp(const Value& op); - [[nodiscard]] Value CompositeConstruct(const UAny& e1, const UAny& e2); - [[nodiscard]] Value CompositeConstruct(const UAny& e1, const UAny& e2, const UAny& e3); - [[nodiscard]] Value CompositeConstruct(const UAny& e1, const UAny& e2, const UAny& e3, - const UAny& e4); - [[nodiscard]] UAny CompositeExtract(const Value& vector, size_t element); + [[nodiscard]] Value CompositeConstruct(const Value& e1, const Value& e2); + [[nodiscard]] Value CompositeConstruct(const Value& e1, const Value& e2, const Value& e3); + [[nodiscard]] Value CompositeConstruct(const Value& e1, const Value& e2, const Value& e3, + const Value& e4); + [[nodiscard]] Value CompositeExtract(const Value& vector, size_t element); [[nodiscard]] UAny Select(const U1& condition, const UAny& true_value, const UAny& false_value); + template + [[nodiscard]] Dest BitCast(const Source& value); + [[nodiscard]] U64 PackUint2x32(const Value& vector); [[nodiscard]] Value UnpackUint2x32(const U64& value); [[nodiscard]] U32 PackFloat2x16(const Value& vector); [[nodiscard]] Value UnpackFloat2x16(const U32& value); - [[nodiscard]] U64 PackDouble2x32(const Value& vector); - [[nodiscard]] Value UnpackDouble2x32(const U64& value); + [[nodiscard]] F64 PackDouble2x32(const Value& vector); + [[nodiscard]] Value UnpackDouble2x32(const F64& value); - [[nodiscard]] U16U32U64 FPAdd(const U16U32U64& a, const U16U32U64& b, FpControl control = {}); - [[nodiscard]] U16U32U64 FPMul(const U16U32U64& a, const U16U32U64& b, FpControl control = {}); - [[nodiscard]] U16U32U64 FPFma(const U16U32U64& a, const U16U32U64& b, const U16U32U64& c, + [[nodiscard]] F16F32F64 FPAdd(const F16F32F64& a, const F16F32F64& b, FpControl control = {}); + [[nodiscard]] F16F32F64 FPMul(const F16F32F64& a, const F16F32F64& b, FpControl control = {}); + [[nodiscard]] F16F32F64 FPFma(const F16F32F64& a, const F16F32F64& b, const F16F32F64& c, FpControl control = {}); - [[nodiscard]] U16U32U64 FPAbs(const U16U32U64& value); - [[nodiscard]] U16U32U64 FPNeg(const U16U32U64& value); - [[nodiscard]] U16U32U64 FPAbsNeg(const U16U32U64& value, bool abs, bool neg); + [[nodiscard]] F16F32F64 FPAbs(const F16F32F64& value); + [[nodiscard]] F16F32F64 FPNeg(const F16F32F64& value); + [[nodiscard]] F16F32F64 FPAbsNeg(const F16F32F64& value, bool abs, bool neg); - [[nodiscard]] U32 FPCosNotReduced(const U32& value); - [[nodiscard]] U32 FPExp2NotReduced(const U32& value); - [[nodiscard]] U32 FPLog2(const U32& value); - [[nodiscard]] U32U64 FPRecip(const U32U64& value); - [[nodiscard]] U32U64 FPRecipSqrt(const U32U64& value); - [[nodiscard]] U32 FPSinNotReduced(const U32& value); - [[nodiscard]] U32 FPSqrt(const U32& value); - [[nodiscard]] U16U32U64 FPSaturate(const U16U32U64& value); - [[nodiscard]] U16U32U64 FPRoundEven(const U16U32U64& value); - [[nodiscard]] U16U32U64 FPFloor(const U16U32U64& value); - [[nodiscard]] U16U32U64 FPCeil(const U16U32U64& value); - [[nodiscard]] U16U32U64 FPTrunc(const U16U32U64& value); + [[nodiscard]] F32 FPCosNotReduced(const F32& value); + [[nodiscard]] F32 FPExp2NotReduced(const F32& value); + [[nodiscard]] F32 FPLog2(const F32& value); + [[nodiscard]] F32F64 FPRecip(const F32F64& value); + [[nodiscard]] F32F64 FPRecipSqrt(const F32F64& value); + [[nodiscard]] F32 FPSinNotReduced(const F32& value); + [[nodiscard]] F32 FPSqrt(const F32& value); + [[nodiscard]] F16F32F64 FPSaturate(const F16F32F64& value); + [[nodiscard]] F16F32F64 FPRoundEven(const F16F32F64& value); + [[nodiscard]] F16F32F64 FPFloor(const F16F32F64& value); + [[nodiscard]] F16F32F64 FPCeil(const F16F32F64& value); + [[nodiscard]] F16F32F64 FPTrunc(const F16F32F64& value); [[nodiscard]] U32U64 IAdd(const U32U64& a, const U32U64& b); [[nodiscard]] U32U64 ISub(const U32U64& a, const U32U64& b); @@ -154,9 +157,9 @@ public: [[nodiscard]] U1 LogicalXor(const U1& a, const U1& b); [[nodiscard]] U1 LogicalNot(const U1& value); - [[nodiscard]] U32U64 ConvertFToS(size_t bitsize, const U16U32U64& value); - [[nodiscard]] U32U64 ConvertFToU(size_t bitsize, const U16U32U64& value); - [[nodiscard]] U32U64 ConvertFToI(size_t bitsize, bool is_signed, const U16U32U64& value); + [[nodiscard]] U32U64 ConvertFToS(size_t bitsize, const F16F32F64& value); + [[nodiscard]] U32U64 ConvertFToU(size_t bitsize, const F16F32F64& value); + [[nodiscard]] U32U64 ConvertFToI(size_t bitsize, bool is_signed, const F16F32F64& value); [[nodiscard]] U32U64 ConvertU(size_t result_bitsize, const U32U64& value); diff --git a/src/shader_recompiler/frontend/ir/opcode.inc b/src/shader_recompiler/frontend/ir/opcode.inc index 4596bf39f7..6eb105d929 100644 --- a/src/shader_recompiler/frontend/ir/opcode.inc +++ b/src/shader_recompiler/frontend/ir/opcode.inc @@ -52,15 +52,15 @@ OPCODE(LoadGlobalS8, U32, U64, OPCODE(LoadGlobalU16, U32, U64, ) OPCODE(LoadGlobalS16, U32, U64, ) OPCODE(LoadGlobal32, U32, U64, ) -OPCODE(LoadGlobal64, Opaque, U64, ) -OPCODE(LoadGlobal128, Opaque, U64, ) +OPCODE(LoadGlobal64, U32x2, U64, ) +OPCODE(LoadGlobal128, U32x4, U64, ) OPCODE(WriteGlobalU8, Void, U64, U32, ) OPCODE(WriteGlobalS8, Void, U64, U32, ) OPCODE(WriteGlobalU16, Void, U64, U32, ) OPCODE(WriteGlobalS16, Void, U64, U32, ) OPCODE(WriteGlobal32, Void, U64, U32, ) -OPCODE(WriteGlobal64, Void, U64, Opaque, ) -OPCODE(WriteGlobal128, Void, U64, Opaque, ) +OPCODE(WriteGlobal64, Void, U64, U32x2, ) +OPCODE(WriteGlobal128, Void, U64, U32x4, ) // Storage buffer operations OPCODE(LoadStorageU8, U32, U32, U32, ) @@ -68,21 +68,41 @@ OPCODE(LoadStorageS8, U32, U32, OPCODE(LoadStorageU16, U32, U32, U32, ) OPCODE(LoadStorageS16, U32, U32, U32, ) OPCODE(LoadStorage32, U32, U32, U32, ) -OPCODE(LoadStorage64, Opaque, U32, U32, ) -OPCODE(LoadStorage128, Opaque, U32, U32, ) -OPCODE(WriteStorageU8, Void, U32, U32, U32, ) -OPCODE(WriteStorageS8, Void, U32, U32, U32, ) -OPCODE(WriteStorageU16, Void, U32, U32, U32, ) -OPCODE(WriteStorageS16, Void, U32, U32, U32, ) -OPCODE(WriteStorage32, Void, U32, U32, U32, ) -OPCODE(WriteStorage64, Void, U32, U32, Opaque, ) -OPCODE(WriteStorage128, Void, U32, U32, Opaque, ) +OPCODE(LoadStorage64, U32x2, U32, U32, ) +OPCODE(LoadStorage128, U32x4, U32, U32, ) +OPCODE(WriteStorageU8, Void, U32, U32, U32, ) +OPCODE(WriteStorageS8, Void, U32, U32, U32, ) +OPCODE(WriteStorageU16, Void, U32, U32, U32, ) +OPCODE(WriteStorageS16, Void, U32, U32, U32, ) +OPCODE(WriteStorage32, Void, U32, U32, U32, ) +OPCODE(WriteStorage64, Void, U32, U32, U32x2, ) +OPCODE(WriteStorage128, Void, U32, U32, U32x4, ) // Vector utility -OPCODE(CompositeConstruct2, Opaque, Opaque, Opaque, ) -OPCODE(CompositeConstruct3, Opaque, Opaque, Opaque, Opaque, ) -OPCODE(CompositeConstruct4, Opaque, Opaque, Opaque, Opaque, Opaque, ) -OPCODE(CompositeExtract, Opaque, Opaque, U32, ) +OPCODE(CompositeConstructU32x2, U32x2, U32, U32, ) +OPCODE(CompositeConstructU32x3, U32x3, U32, U32, U32, ) +OPCODE(CompositeConstructU32x4, U32x4, U32, U32, U32, U32, ) +OPCODE(CompositeExtractU32x2, U32, U32x2, U32, ) +OPCODE(CompositeExtractU32x3, U32, U32x3, U32, ) +OPCODE(CompositeExtractU32x4, U32, U32x4, U32, ) +OPCODE(CompositeConstructF16x2, F16x2, F16, F16, ) +OPCODE(CompositeConstructF16x3, F16x3, F16, F16, F16, ) +OPCODE(CompositeConstructF16x4, F16x4, F16, F16, F16, F16, ) +OPCODE(CompositeExtractF16x2, F16, F16x2, U32, ) +OPCODE(CompositeExtractF16x3, F16, F16x3, U32, ) +OPCODE(CompositeExtractF16x4, F16, F16x4, U32, ) +OPCODE(CompositeConstructF32x2, F32x2, F32, F32, ) +OPCODE(CompositeConstructF32x3, F32x3, F32, F32, F32, ) +OPCODE(CompositeConstructF32x4, F32x4, F32, F32, F32, F32, ) +OPCODE(CompositeExtractF32x2, F32, F32x2, U32, ) +OPCODE(CompositeExtractF32x3, F32, F32x3, U32, ) +OPCODE(CompositeExtractF32x4, F32, F32x4, U32, ) +OPCODE(CompositeConstructF64x2, F64x2, F64, F64, ) +OPCODE(CompositeConstructF64x3, F64x3, F64, F64, F64, ) +OPCODE(CompositeConstructF64x4, F64x4, F64, F64, F64, F64, ) +OPCODE(CompositeExtractF64x2, F64, F64x2, U32, ) +OPCODE(CompositeExtractF64x3, F64, F64x3, U32, ) +OPCODE(CompositeExtractF64x4, F64, F64x4, U32, ) // Select operations OPCODE(Select8, U8, U1, U8, U8, ) @@ -91,12 +111,18 @@ OPCODE(Select32, U32, U1, OPCODE(Select64, U64, U1, U64, U64, ) // Bitwise conversions -OPCODE(PackUint2x32, U64, Opaque, ) -OPCODE(UnpackUint2x32, Opaque, U64, ) -OPCODE(PackFloat2x16, U32, Opaque, ) -OPCODE(UnpackFloat2x16, Opaque, U32, ) -OPCODE(PackDouble2x32, U64, Opaque, ) -OPCODE(UnpackDouble2x32, Opaque, U64, ) +OPCODE(BitCastU16F16, U16, F16, ) +OPCODE(BitCastU32F32, U32, F32, ) +OPCODE(BitCastU64F64, U64, F64, ) +OPCODE(BitCastF16U16, F16, U16, ) +OPCODE(BitCastF32U32, F32, U32, ) +OPCODE(BitCastF64U64, F64, U64, ) +OPCODE(PackUint2x32, U64, U32x2, ) +OPCODE(UnpackUint2x32, U32x2, U64, ) +OPCODE(PackFloat2x16, U32, F16x2, ) +OPCODE(UnpackFloat2x16, F16x2, U32, ) +OPCODE(PackDouble2x32, U64, U32x2, ) +OPCODE(UnpackDouble2x32, U32x2, U64, ) // Pseudo-operation, handled specially at final emit OPCODE(GetZeroFromOp, U1, Opaque, ) @@ -105,52 +131,52 @@ OPCODE(GetCarryFromOp, U1, Opaq OPCODE(GetOverflowFromOp, U1, Opaque, ) // Floating-point operations -OPCODE(FPAbs16, U16, U16, ) -OPCODE(FPAbs32, U32, U32, ) -OPCODE(FPAbs64, U64, U64, ) -OPCODE(FPAdd16, U16, U16, U16, ) -OPCODE(FPAdd32, U32, U32, U32, ) -OPCODE(FPAdd64, U64, U64, U64, ) -OPCODE(FPFma16, U16, U16, U16, U16, ) -OPCODE(FPFma32, U32, U32, U32, U32, ) -OPCODE(FPFma64, U64, U64, U64, U64, ) -OPCODE(FPMax32, U32, U32, U32, ) -OPCODE(FPMax64, U64, U64, U64, ) -OPCODE(FPMin32, U32, U32, U32, ) -OPCODE(FPMin64, U64, U64, U64, ) -OPCODE(FPMul16, U16, U16, U16, ) -OPCODE(FPMul32, U32, U32, U32, ) -OPCODE(FPMul64, U64, U64, U64, ) -OPCODE(FPNeg16, U16, U16, ) -OPCODE(FPNeg32, U32, U32, ) -OPCODE(FPNeg64, U64, U64, ) -OPCODE(FPRecip32, U32, U32, ) -OPCODE(FPRecip64, U64, U64, ) -OPCODE(FPRecipSqrt32, U32, U32, ) -OPCODE(FPRecipSqrt64, U64, U64, ) -OPCODE(FPSqrt, U32, U32, ) -OPCODE(FPSin, U32, U32, ) -OPCODE(FPSinNotReduced, U32, U32, ) -OPCODE(FPExp2, U32, U32, ) -OPCODE(FPExp2NotReduced, U32, U32, ) -OPCODE(FPCos, U32, U32, ) -OPCODE(FPCosNotReduced, U32, U32, ) -OPCODE(FPLog2, U32, U32, ) -OPCODE(FPSaturate16, U16, U16, ) -OPCODE(FPSaturate32, U32, U32, ) -OPCODE(FPSaturate64, U64, U64, ) -OPCODE(FPRoundEven16, U16, U16, ) -OPCODE(FPRoundEven32, U32, U32, ) -OPCODE(FPRoundEven64, U64, U64, ) -OPCODE(FPFloor16, U16, U16, ) -OPCODE(FPFloor32, U32, U32, ) -OPCODE(FPFloor64, U64, U64, ) -OPCODE(FPCeil16, U16, U16, ) -OPCODE(FPCeil32, U32, U32, ) -OPCODE(FPCeil64, U64, U64, ) -OPCODE(FPTrunc16, U16, U16, ) -OPCODE(FPTrunc32, U32, U32, ) -OPCODE(FPTrunc64, U64, U64, ) +OPCODE(FPAbs16, F16, F16, ) +OPCODE(FPAbs32, F32, F32, ) +OPCODE(FPAbs64, F64, F64, ) +OPCODE(FPAdd16, F16, F16, F16, ) +OPCODE(FPAdd32, F32, F32, F32, ) +OPCODE(FPAdd64, F64, F64, F64, ) +OPCODE(FPFma16, F16, F16, F16, F16, ) +OPCODE(FPFma32, F32, F32, F32, F32, ) +OPCODE(FPFma64, F64, F64, F64, F64, ) +OPCODE(FPMax32, F32, F32, F32, ) +OPCODE(FPMax64, F64, F64, F64, ) +OPCODE(FPMin32, F32, F32, F32, ) +OPCODE(FPMin64, F64, F64, F64, ) +OPCODE(FPMul16, F16, F16, F16, ) +OPCODE(FPMul32, F32, F32, F32, ) +OPCODE(FPMul64, F64, F64, F64, ) +OPCODE(FPNeg16, F16, F16, ) +OPCODE(FPNeg32, F32, F32, ) +OPCODE(FPNeg64, F64, F64, ) +OPCODE(FPRecip32, F32, F32, ) +OPCODE(FPRecip64, F64, F64, ) +OPCODE(FPRecipSqrt32, F32, F32, ) +OPCODE(FPRecipSqrt64, F64, F64, ) +OPCODE(FPSqrt, F32, F32, ) +OPCODE(FPSin, F32, F32, ) +OPCODE(FPSinNotReduced, F32, F32, ) +OPCODE(FPExp2, F32, F32, ) +OPCODE(FPExp2NotReduced, F32, F32, ) +OPCODE(FPCos, F32, F32, ) +OPCODE(FPCosNotReduced, F32, F32, ) +OPCODE(FPLog2, F32, F32, ) +OPCODE(FPSaturate16, F16, F16, ) +OPCODE(FPSaturate32, F32, F32, ) +OPCODE(FPSaturate64, F64, F64, ) +OPCODE(FPRoundEven16, F16, F16, ) +OPCODE(FPRoundEven32, F32, F32, ) +OPCODE(FPRoundEven64, F64, F64, ) +OPCODE(FPFloor16, F16, F16, ) +OPCODE(FPFloor32, F32, F32, ) +OPCODE(FPFloor64, F64, F64, ) +OPCODE(FPCeil16, F16, F16, ) +OPCODE(FPCeil32, F32, F32, ) +OPCODE(FPCeil64, F64, F64, ) +OPCODE(FPTrunc16, F16, F16, ) +OPCODE(FPTrunc32, F32, F32, ) +OPCODE(FPTrunc64, F64, F64, ) // Integer operations OPCODE(IAdd32, U32, U32, U32, ) @@ -188,24 +214,24 @@ OPCODE(LogicalXor, U1, U1, OPCODE(LogicalNot, U1, U1, ) // Conversion operations -OPCODE(ConvertS16F16, U32, U16, ) -OPCODE(ConvertS16F32, U32, U32, ) -OPCODE(ConvertS16F64, U32, U64, ) -OPCODE(ConvertS32F16, U32, U16, ) -OPCODE(ConvertS32F32, U32, U32, ) -OPCODE(ConvertS32F64, U32, U64, ) -OPCODE(ConvertS64F16, U64, U16, ) -OPCODE(ConvertS64F32, U64, U32, ) -OPCODE(ConvertS64F64, U64, U64, ) -OPCODE(ConvertU16F16, U32, U16, ) -OPCODE(ConvertU16F32, U32, U32, ) -OPCODE(ConvertU16F64, U32, U64, ) -OPCODE(ConvertU32F16, U32, U16, ) -OPCODE(ConvertU32F32, U32, U32, ) -OPCODE(ConvertU32F64, U32, U64, ) -OPCODE(ConvertU64F16, U64, U16, ) -OPCODE(ConvertU64F32, U64, U32, ) -OPCODE(ConvertU64F64, U64, U64, ) +OPCODE(ConvertS16F16, U32, F16, ) +OPCODE(ConvertS16F32, U32, F32, ) +OPCODE(ConvertS16F64, U32, F64, ) +OPCODE(ConvertS32F16, U32, F16, ) +OPCODE(ConvertS32F32, U32, F32, ) +OPCODE(ConvertS32F64, U32, F64, ) +OPCODE(ConvertS64F16, U64, F16, ) +OPCODE(ConvertS64F32, U64, F32, ) +OPCODE(ConvertS64F64, U64, F64, ) +OPCODE(ConvertU16F16, U32, F16, ) +OPCODE(ConvertU16F32, U32, F32, ) +OPCODE(ConvertU16F64, U32, F64, ) +OPCODE(ConvertU32F16, U32, F16, ) +OPCODE(ConvertU32F32, U32, F32, ) +OPCODE(ConvertU32F64, U32, F64, ) +OPCODE(ConvertU64F16, U64, F16, ) +OPCODE(ConvertU64F32, U64, F32, ) +OPCODE(ConvertU64F64, U64, F64, ) OPCODE(ConvertU64U32, U64, U32, ) OPCODE(ConvertU32U64, U32, U64, ) diff --git a/src/shader_recompiler/frontend/ir/type.cpp b/src/shader_recompiler/frontend/ir/type.cpp index 13cc091956..f28341bfe7 100644 --- a/src/shader_recompiler/frontend/ir/type.cpp +++ b/src/shader_recompiler/frontend/ir/type.cpp @@ -11,7 +11,9 @@ namespace Shader::IR { std::string NameOf(Type type) { static constexpr std::array names{ - "Opaque", "Label", "Reg", "Pred", "Attribute", "U1", "U8", "U16", "U32", "U64", + "Opaque", "Label", "Reg", "Pred", "Attribute", "U1", "U8", "U16", "U32", + "U64", "F16", "F32", "F64", "U32x2", "U32x3", "U32x4", "F16x2", "F16x3", + "F16x4", "F32x2", "F32x3", "F32x4", "F64x2", "F64x3", "F64x4", }; const size_t bits{static_cast(type)}; if (bits == 0) { diff --git a/src/shader_recompiler/frontend/ir/type.h b/src/shader_recompiler/frontend/ir/type.h index 397875018b..9a32ca1e8a 100644 --- a/src/shader_recompiler/frontend/ir/type.h +++ b/src/shader_recompiler/frontend/ir/type.h @@ -25,6 +25,21 @@ enum class Type { U16 = 1 << 7, U32 = 1 << 8, U64 = 1 << 9, + F16 = 1 << 10, + F32 = 1 << 11, + F64 = 1 << 12, + U32x2 = 1 << 13, + U32x3 = 1 << 14, + U32x4 = 1 << 15, + F16x2 = 1 << 16, + F16x3 = 1 << 17, + F16x4 = 1 << 18, + F32x2 = 1 << 19, + F32x3 = 1 << 20, + F32x4 = 1 << 21, + F64x2 = 1 << 22, + F64x3 = 1 << 23, + F64x4 = 1 << 24, }; DECLARE_ENUM_FLAG_OPERATORS(Type) diff --git a/src/shader_recompiler/frontend/ir/value.cpp b/src/shader_recompiler/frontend/ir/value.cpp index 59a9b10dc9..93ff8ccf16 100644 --- a/src/shader_recompiler/frontend/ir/value.cpp +++ b/src/shader_recompiler/frontend/ir/value.cpp @@ -26,8 +26,12 @@ Value::Value(u16 value) noexcept : type{Type::U16}, imm_u16{value} {} Value::Value(u32 value) noexcept : type{Type::U32}, imm_u32{value} {} +Value::Value(f32 value) noexcept : type{Type::F32}, imm_f32{value} {} + Value::Value(u64 value) noexcept : type{Type::U64}, imm_u64{value} {} +Value::Value(f64 value) noexcept : type{Type::F64}, imm_f64{value} {} + bool Value::IsIdentity() const noexcept { return type == Type::Opaque && inst->Opcode() == Opcode::Identity; } @@ -122,6 +126,14 @@ u32 Value::U32() const { return imm_u32; } +f32 Value::F32() const { + if (IsIdentity()) { + return inst->Arg(0).F32(); + } + ValidateAccess(Type::F32); + return imm_f32; +} + u64 Value::U64() const { if (IsIdentity()) { return inst->Arg(0).U64(); @@ -152,11 +164,27 @@ bool Value::operator==(const Value& other) const { case Type::U8: return imm_u8 == other.imm_u8; case Type::U16: + case Type::F16: return imm_u16 == other.imm_u16; case Type::U32: + case Type::F32: return imm_u32 == other.imm_u32; case Type::U64: + case Type::F64: return imm_u64 == other.imm_u64; + case Type::U32x2: + case Type::U32x3: + case Type::U32x4: + case Type::F16x2: + case Type::F16x3: + case Type::F16x4: + case Type::F32x2: + case Type::F32x3: + case Type::F32x4: + case Type::F64x2: + case Type::F64x3: + case Type::F64x4: + break; } throw LogicError("Invalid type {}", type); } diff --git a/src/shader_recompiler/frontend/ir/value.h b/src/shader_recompiler/frontend/ir/value.h index 31f8317940..2f3688c736 100644 --- a/src/shader_recompiler/frontend/ir/value.h +++ b/src/shader_recompiler/frontend/ir/value.h @@ -28,7 +28,9 @@ public: explicit Value(u8 value) noexcept; explicit Value(u16 value) noexcept; explicit Value(u32 value) noexcept; + explicit Value(f32 value) noexcept; explicit Value(u64 value) noexcept; + explicit Value(f64 value) noexcept; [[nodiscard]] bool IsIdentity() const noexcept; [[nodiscard]] bool IsEmpty() const noexcept; @@ -46,6 +48,7 @@ public: [[nodiscard]] u8 U8() const; [[nodiscard]] u16 U16() const; [[nodiscard]] u32 U32() const; + [[nodiscard]] f32 F32() const; [[nodiscard]] u64 U64() const; [[nodiscard]] bool operator==(const Value& other) const; @@ -65,7 +68,9 @@ private: u8 imm_u8; u16 imm_u16; u32 imm_u32; + f32 imm_f32; u64 imm_u64; + f64 imm_f64; }; }; @@ -93,8 +98,13 @@ using U8 = TypedValue; using U16 = TypedValue; using U32 = TypedValue; using U64 = TypedValue; +using F16 = TypedValue; +using F32 = TypedValue; +using F64 = TypedValue; using U32U64 = TypedValue; +using F32F64 = TypedValue; using U16U32U64 = TypedValue; +using F16F32F64 = TypedValue; using UAny = TypedValue; } // namespace Shader::IR diff --git a/src/shader_recompiler/frontend/maxwell/translate/impl/floating_point_add.cpp b/src/shader_recompiler/frontend/maxwell/translate/impl/floating_point_add.cpp index d2c44b9ccd..cb3a326cfa 100644 --- a/src/shader_recompiler/frontend/maxwell/translate/impl/floating_point_add.cpp +++ b/src/shader_recompiler/frontend/maxwell/translate/impl/floating_point_add.cpp @@ -11,7 +11,7 @@ namespace Shader::Maxwell { namespace { void FADD(TranslatorVisitor& v, u64 insn, bool sat, bool cc, bool ftz, FpRounding fp_rounding, - const IR::U32& src_b, bool abs_a, bool neg_a, bool abs_b, bool neg_b) { + const IR::F32& src_b, bool abs_a, bool neg_a, bool abs_b, bool neg_b) { union { u64 raw; BitField<0, 8, IR::Reg> dest_reg; @@ -24,17 +24,17 @@ void FADD(TranslatorVisitor& v, u64 insn, bool sat, bool cc, bool ftz, FpRoundin if (cc) { throw NotImplementedException("FADD CC"); } - const IR::U32 op_a{v.ir.FPAbsNeg(v.X(fadd.src_a), abs_a, neg_a)}; - const IR::U32 op_b{v.ir.FPAbsNeg(src_b, abs_b, neg_b)}; + const IR::F32 op_a{v.ir.FPAbsNeg(v.F(fadd.src_a), abs_a, neg_a)}; + const IR::F32 op_b{v.ir.FPAbsNeg(src_b, abs_b, neg_b)}; IR::FpControl control{ .no_contraction{true}, .rounding{CastFpRounding(fp_rounding)}, .fmz_mode{ftz ? IR::FmzMode::FTZ : IR::FmzMode::None}, }; - v.X(fadd.dest_reg, v.ir.FPAdd(op_a, op_b, control)); + v.F(fadd.dest_reg, v.ir.FPAdd(op_a, op_b, control)); } -void FADD(TranslatorVisitor& v, u64 insn, const IR::U32& src_b) { +void FADD(TranslatorVisitor& v, u64 insn, const IR::F32& src_b) { union { u64 raw; BitField<39, 2, FpRounding> fp_rounding; @@ -53,7 +53,7 @@ void FADD(TranslatorVisitor& v, u64 insn, const IR::U32& src_b) { } // Anonymous namespace void TranslatorVisitor::FADD_reg(u64 insn) { - FADD(*this, insn, GetReg20(insn)); + FADD(*this, insn, GetReg20F(insn)); } void TranslatorVisitor::FADD_cbuf(u64) { diff --git a/src/shader_recompiler/frontend/maxwell/translate/impl/floating_point_conversion_integer.cpp b/src/shader_recompiler/frontend/maxwell/translate/impl/floating_point_conversion_integer.cpp index c4288d9a83..acd8445ad1 100644 --- a/src/shader_recompiler/frontend/maxwell/translate/impl/floating_point_conversion_integer.cpp +++ b/src/shader_recompiler/frontend/maxwell/translate/impl/floating_point_conversion_integer.cpp @@ -55,21 +55,21 @@ size_t BitSize(DestFormat dest_format) { } } -void TranslateF2I(TranslatorVisitor& v, u64 insn, const IR::U16U32U64& op_a) { +void TranslateF2I(TranslatorVisitor& v, u64 insn, const IR::F16F32F64& src_a) { // F2I is used to convert from a floating point value to an integer const F2I f2i{insn}; - const IR::U16U32U64 float_value{v.ir.FPAbsNeg(op_a, f2i.abs != 0, f2i.neg != 0)}; - const IR::U16U32U64 rounded_value{[&] { + const IR::F16F32F64 op_a{v.ir.FPAbsNeg(src_a, f2i.abs != 0, f2i.neg != 0)}; + const IR::F16F32F64 rounded_value{[&] { switch (f2i.rounding) { case Rounding::Round: - return v.ir.FPRoundEven(float_value); + return v.ir.FPRoundEven(op_a); case Rounding::Floor: - return v.ir.FPFloor(float_value); + return v.ir.FPFloor(op_a); case Rounding::Ceil: - return v.ir.FPCeil(float_value); + return v.ir.FPCeil(op_a); case Rounding::Trunc: - return v.ir.FPTrunc(float_value); + return v.ir.FPTrunc(op_a); default: throw NotImplementedException("Invalid F2I rounding {}", f2i.rounding.Value()); } @@ -105,12 +105,12 @@ void TranslatorVisitor::F2I_reg(u64 insn) { BitField<20, 8, IR::Reg> src_reg; } const f2i{insn}; - const IR::U16U32U64 op_a{[&]() -> IR::U16U32U64 { + const IR::F16F32F64 op_a{[&]() -> IR::F16F32F64 { switch (f2i.base.src_format) { case SrcFormat::F16: - return ir.CompositeExtract(ir.UnpackFloat2x16(X(f2i.src_reg)), f2i.base.half); + return IR::F16{ir.CompositeExtract(ir.UnpackFloat2x16(X(f2i.src_reg)), f2i.base.half)}; case SrcFormat::F32: - return X(f2i.src_reg); + return F(f2i.src_reg); case SrcFormat::F64: return ir.PackDouble2x32(ir.CompositeConstruct(X(f2i.src_reg), X(f2i.src_reg + 1))); default: diff --git a/src/shader_recompiler/frontend/maxwell/translate/impl/floating_point_fused_multiply_add.cpp b/src/shader_recompiler/frontend/maxwell/translate/impl/floating_point_fused_multiply_add.cpp index 30ca052ec5..1464f2807a 100644 --- a/src/shader_recompiler/frontend/maxwell/translate/impl/floating_point_fused_multiply_add.cpp +++ b/src/shader_recompiler/frontend/maxwell/translate/impl/floating_point_fused_multiply_add.cpp @@ -9,7 +9,7 @@ namespace Shader::Maxwell { namespace { -void FFMA(TranslatorVisitor& v, u64 insn, const IR::U32& src_b, const IR::U32& src_c, bool neg_a, +void FFMA(TranslatorVisitor& v, u64 insn, const IR::F32& src_b, const IR::F32& src_c, bool neg_a, bool neg_b, bool neg_c, bool sat, bool cc, FmzMode fmz_mode, FpRounding fp_rounding) { union { u64 raw; @@ -23,18 +23,18 @@ void FFMA(TranslatorVisitor& v, u64 insn, const IR::U32& src_b, const IR::U32& s if (cc) { throw NotImplementedException("FFMA CC"); } - const IR::U32 op_a{v.ir.FPAbsNeg(v.X(ffma.src_a), false, neg_a)}; - const IR::U32 op_b{v.ir.FPAbsNeg(src_b, false, neg_b)}; - const IR::U32 op_c{v.ir.FPAbsNeg(src_c, false, neg_c)}; + const IR::F32 op_a{v.ir.FPAbsNeg(v.F(ffma.src_a), false, neg_a)}; + const IR::F32 op_b{v.ir.FPAbsNeg(src_b, false, neg_b)}; + const IR::F32 op_c{v.ir.FPAbsNeg(src_c, false, neg_c)}; const IR::FpControl fp_control{ .no_contraction{true}, .rounding{CastFpRounding(fp_rounding)}, .fmz_mode{CastFmzMode(fmz_mode)}, }; - v.X(ffma.dest_reg, v.ir.FPFma(op_a, op_b, op_c, fp_control)); + v.F(ffma.dest_reg, v.ir.FPFma(op_a, op_b, op_c, fp_control)); } -void FFMA(TranslatorVisitor& v, u64 insn, const IR::U32& src_b, const IR::U32& src_c) { +void FFMA(TranslatorVisitor& v, u64 insn, const IR::F32& src_b, const IR::F32& src_c) { union { u64 raw; BitField<47, 1, u64> cc; @@ -51,7 +51,7 @@ void FFMA(TranslatorVisitor& v, u64 insn, const IR::U32& src_b, const IR::U32& s } // Anonymous namespace void TranslatorVisitor::FFMA_reg(u64 insn) { - FFMA(*this, insn, GetReg20(insn), GetReg39(insn)); + FFMA(*this, insn, GetReg20F(insn), GetReg39F(insn)); } void TranslatorVisitor::FFMA_rc(u64) { @@ -59,7 +59,7 @@ void TranslatorVisitor::FFMA_rc(u64) { } void TranslatorVisitor::FFMA_cr(u64 insn) { - FFMA(*this, insn, GetCbuf(insn), GetReg39(insn)); + FFMA(*this, insn, GetCbufF(insn), GetReg39F(insn)); } void TranslatorVisitor::FFMA_imm(u64) { diff --git a/src/shader_recompiler/frontend/maxwell/translate/impl/floating_point_multi_function.cpp b/src/shader_recompiler/frontend/maxwell/translate/impl/floating_point_multi_function.cpp index e2ab0dab22..90cddb18b4 100644 --- a/src/shader_recompiler/frontend/maxwell/translate/impl/floating_point_multi_function.cpp +++ b/src/shader_recompiler/frontend/maxwell/translate/impl/floating_point_multi_function.cpp @@ -35,8 +35,8 @@ void TranslatorVisitor::MUFU(u64 insn) { BitField<50, 1, u64> sat; } const mufu{insn}; - const IR::U32 op_a{ir.FPAbsNeg(X(mufu.src_reg), mufu.abs != 0, mufu.neg != 0)}; - IR::U32 value{[&]() -> IR::U32 { + const IR::F32 op_a{ir.FPAbsNeg(F(mufu.src_reg), mufu.abs != 0, mufu.neg != 0)}; + IR::F32 value{[&]() -> IR::F32 { switch (mufu.operation) { case Operation::Cos: return ir.FPCosNotReduced(op_a); @@ -65,7 +65,7 @@ void TranslatorVisitor::MUFU(u64 insn) { value = ir.FPSaturate(value); } - X(mufu.dest_reg, value); + F(mufu.dest_reg, value); } } // namespace Shader::Maxwell diff --git a/src/shader_recompiler/frontend/maxwell/translate/impl/floating_point_multiply.cpp b/src/shader_recompiler/frontend/maxwell/translate/impl/floating_point_multiply.cpp index 743a1e2f0f..1b1d38be7a 100644 --- a/src/shader_recompiler/frontend/maxwell/translate/impl/floating_point_multiply.cpp +++ b/src/shader_recompiler/frontend/maxwell/translate/impl/floating_point_multiply.cpp @@ -4,6 +4,7 @@ #include "common/bit_field.h" #include "common/common_types.h" +#include "shader_recompiler/frontend/ir/ir_emitter.h" #include "shader_recompiler/frontend/ir/modifiers.h" #include "shader_recompiler/frontend/maxwell/translate/impl/common_encoding.h" #include "shader_recompiler/frontend/maxwell/translate/impl/impl.h" @@ -43,7 +44,7 @@ float ScaleFactor(Scale scale) { throw NotImplementedException("Invalid FMUL scale {}", scale); } -void FMUL(TranslatorVisitor& v, u64 insn, const IR::U32& src_b, FmzMode fmz_mode, +void FMUL(TranslatorVisitor& v, u64 insn, const IR::F32& src_b, FmzMode fmz_mode, FpRounding fp_rounding, Scale scale, bool sat, bool cc, bool neg_b) { union { u64 raw; @@ -57,23 +58,23 @@ void FMUL(TranslatorVisitor& v, u64 insn, const IR::U32& src_b, FmzMode fmz_mode if (sat) { throw NotImplementedException("FMUL SAT"); } - IR::U32 op_a{v.X(fmul.src_a)}; + IR::F32 op_a{v.F(fmul.src_a)}; if (scale != Scale::None) { if (fmz_mode != FmzMode::FTZ || fp_rounding != FpRounding::RN) { throw NotImplementedException("FMUL scale with non-FMZ or non-RN modifiers"); } op_a = v.ir.FPMul(op_a, v.ir.Imm32(ScaleFactor(scale))); } - const IR::U32 op_b{v.ir.FPAbsNeg(src_b, false, neg_b)}; + const IR::F32 op_b{v.ir.FPAbsNeg(src_b, false, neg_b)}; const IR::FpControl fp_control{ .no_contraction{true}, .rounding{CastFpRounding(fp_rounding)}, .fmz_mode{CastFmzMode(fmz_mode)}, }; - v.X(fmul.dest_reg, v.ir.FPMul(op_a, op_b, fp_control)); + v.F(fmul.dest_reg, v.ir.FPMul(op_a, op_b, fp_control)); } -void FMUL(TranslatorVisitor& v, u64 insn, const IR::U32& src_b) { +void FMUL(TranslatorVisitor& v, u64 insn, const IR::F32& src_b) { union { u64 raw; BitField<39, 2, FpRounding> fp_rounding; @@ -90,7 +91,7 @@ void FMUL(TranslatorVisitor& v, u64 insn, const IR::U32& src_b) { } // Anonymous namespace void TranslatorVisitor::FMUL_reg(u64 insn) { - return FMUL(*this, insn, GetReg20(insn)); + return FMUL(*this, insn, GetReg20F(insn)); } void TranslatorVisitor::FMUL_cbuf(u64) { diff --git a/src/shader_recompiler/frontend/maxwell/translate/impl/impl.cpp b/src/shader_recompiler/frontend/maxwell/translate/impl/impl.cpp index 548c7f611d..3c9eaddd94 100644 --- a/src/shader_recompiler/frontend/maxwell/translate/impl/impl.cpp +++ b/src/shader_recompiler/frontend/maxwell/translate/impl/impl.cpp @@ -12,10 +12,18 @@ IR::U32 TranslatorVisitor::X(IR::Reg reg) { return ir.GetReg(reg); } +IR::F32 TranslatorVisitor::F(IR::Reg reg) { + return ir.BitCast(X(reg)); +} + void TranslatorVisitor::X(IR::Reg dest_reg, const IR::U32& value) { ir.SetReg(dest_reg, value); } +void TranslatorVisitor::F(IR::Reg dest_reg, const IR::F32& value) { + X(dest_reg, ir.BitCast(value)); +} + IR::U32 TranslatorVisitor::GetReg20(u64 insn) { union { u64 raw; @@ -32,6 +40,14 @@ IR::U32 TranslatorVisitor::GetReg39(u64 insn) { return X(reg.index); } +IR::F32 TranslatorVisitor::GetReg20F(u64 insn) { + return ir.BitCast(GetReg20(insn)); +} + +IR::F32 TranslatorVisitor::GetReg39F(u64 insn) { + return ir.BitCast(GetReg39(insn)); +} + IR::U32 TranslatorVisitor::GetCbuf(u64 insn) { union { u64 raw; @@ -49,6 +65,10 @@ IR::U32 TranslatorVisitor::GetCbuf(u64 insn) { return ir.GetCbuf(binding, byte_offset); } +IR::F32 TranslatorVisitor::GetCbufF(u64 insn) { + return ir.BitCast(GetCbuf(insn)); +} + IR::U32 TranslatorVisitor::GetImm20(u64 insn) { union { u64 raw; diff --git a/src/shader_recompiler/frontend/maxwell/translate/impl/impl.h b/src/shader_recompiler/frontend/maxwell/translate/impl/impl.h index ef6d977fef..b701605d73 100644 --- a/src/shader_recompiler/frontend/maxwell/translate/impl/impl.h +++ b/src/shader_recompiler/frontend/maxwell/translate/impl/impl.h @@ -296,12 +296,18 @@ public: void XMAD_imm(u64 insn); [[nodiscard]] IR::U32 X(IR::Reg reg); + [[nodiscard]] IR::F32 F(IR::Reg reg); + void X(IR::Reg dest_reg, const IR::U32& value); + void F(IR::Reg dest_reg, const IR::F32& value); [[nodiscard]] IR::U32 GetReg20(u64 insn); [[nodiscard]] IR::U32 GetReg39(u64 insn); + [[nodiscard]] IR::F32 GetReg20F(u64 insn); + [[nodiscard]] IR::F32 GetReg39F(u64 insn); [[nodiscard]] IR::U32 GetCbuf(u64 insn); + [[nodiscard]] IR::F32 GetCbufF(u64 insn); [[nodiscard]] IR::U32 GetImm20(u64 insn); diff --git a/src/shader_recompiler/frontend/maxwell/translate/impl/load_store_attribute.cpp b/src/shader_recompiler/frontend/maxwell/translate/impl/load_store_attribute.cpp index 23512db1a4..de65173e8d 100644 --- a/src/shader_recompiler/frontend/maxwell/translate/impl/load_store_attribute.cpp +++ b/src/shader_recompiler/frontend/maxwell/translate/impl/load_store_attribute.cpp @@ -5,22 +5,23 @@ #include "common/bit_field.h" #include "common/common_types.h" #include "shader_recompiler/exception.h" +#include "shader_recompiler/frontend/ir/ir_emitter.h" #include "shader_recompiler/frontend/maxwell/opcode.h" #include "shader_recompiler/frontend/maxwell/translate/impl/impl.h" namespace Shader::Maxwell { namespace { enum class InterpolationMode : u64 { - Pass = 0, - Multiply = 1, - Constant = 2, - Sc = 3, + Pass, + Multiply, + Constant, + Sc, }; enum class SampleMode : u64 { - Default = 0, - Centroid = 1, - Offset = 2, + Default, + Centroid, + Offset, }; } // Anonymous namespace @@ -54,12 +55,12 @@ void TranslatorVisitor::IPA(u64 insn) { } const IR::Attribute attribute{ipa.attribute}; - IR::U32 value{ir.GetAttribute(attribute)}; + IR::F32 value{ir.GetAttribute(attribute)}; if (IR::IsGeneric(attribute)) { // const bool is_perspective{UnimplementedReadHeader(GenericAttributeIndex(attribute))}; const bool is_perspective{false}; if (is_perspective) { - const IR::U32 rcp_position_w{ir.FPRecip(ir.GetAttribute(IR::Attribute::PositionW))}; + const IR::F32 rcp_position_w{ir.FPRecip(ir.GetAttribute(IR::Attribute::PositionW))}; value = ir.FPMul(value, rcp_position_w); } } @@ -68,7 +69,7 @@ void TranslatorVisitor::IPA(u64 insn) { case InterpolationMode::Pass: break; case InterpolationMode::Multiply: - value = ir.FPMul(value, ir.GetReg(ipa.multiplier)); + value = ir.FPMul(value, F(ipa.multiplier)); break; case InterpolationMode::Constant: throw NotImplementedException("IPA.CONSTANT"); @@ -86,7 +87,7 @@ void TranslatorVisitor::IPA(u64 insn) { value = ir.FPSaturate(value); } - ir.SetReg(ipa.dest_reg, value); + F(ipa.dest_reg, value); } } // namespace Shader::Maxwell diff --git a/src/shader_recompiler/frontend/maxwell/translate/impl/load_store_memory.cpp b/src/shader_recompiler/frontend/maxwell/translate/impl/load_store_memory.cpp index c9669c6178..9f1570479d 100644 --- a/src/shader_recompiler/frontend/maxwell/translate/impl/load_store_memory.cpp +++ b/src/shader_recompiler/frontend/maxwell/translate/impl/load_store_memory.cpp @@ -114,7 +114,7 @@ void TranslatorVisitor::LDG(u64 insn) { } const IR::Value vector{ir.LoadGlobal64(address)}; for (int i = 0; i < 2; ++i) { - X(dest_reg + i, ir.CompositeExtract(vector, i)); + X(dest_reg + i, IR::U32{ir.CompositeExtract(vector, i)}); } break; } @@ -124,7 +124,7 @@ void TranslatorVisitor::LDG(u64 insn) { } const IR::Value vector{ir.LoadGlobal128(address)}; for (int i = 0; i < 4; ++i) { - X(dest_reg + i, ir.CompositeExtract(vector, i)); + X(dest_reg + i, IR::U32{ir.CompositeExtract(vector, i)}); } break; } diff --git a/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp b/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp index 02f5b653d4..7fb3192d8e 100644 --- a/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp +++ b/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp @@ -5,6 +5,7 @@ #include #include +#include "common/bit_cast.h" #include "common/bit_util.h" #include "shader_recompiler/exception.h" #include "shader_recompiler/frontend/ir/microinstruction.h" @@ -25,6 +26,8 @@ template return value.U1(); } else if constexpr (std::is_same_v) { return value.U32(); + } else if constexpr (std::is_same_v) { + return value.F32(); } else if constexpr (std::is_same_v) { return value.U64(); } @@ -115,6 +118,19 @@ void FoldLogicalAnd(IR::Inst& inst) { } } +template +void FoldBitCast(IR::Inst& inst, IR::Opcode reverse) { + const IR::Value value{inst.Arg(0)}; + if (value.IsImmediate()) { + inst.ReplaceUsesWith(IR::Value{Common::BitCast(Arg(value))}); + return; + } + IR::Inst* const arg_inst{value.InstRecursive()}; + if (value.InstRecursive()->Opcode() == reverse) { + inst.ReplaceUsesWith(arg_inst->Arg(0)); + } +} + void ConstantPropagation(IR::Inst& inst) { switch (inst.Opcode()) { case IR::Opcode::GetRegister: @@ -123,6 +139,10 @@ void ConstantPropagation(IR::Inst& inst) { return FoldGetPred(inst); case IR::Opcode::IAdd32: return FoldAdd(inst); + case IR::Opcode::BitCastF32U32: + return FoldBitCast(inst, IR::Opcode::BitCastU32F32); + case IR::Opcode::BitCastU32F32: + return FoldBitCast(inst, IR::Opcode::BitCastF32U32); case IR::Opcode::IAdd64: return FoldAdd(inst); case IR::Opcode::BitFieldUExtract: diff --git a/src/shader_recompiler/ir_opt/global_memory_to_storage_buffer_pass.cpp b/src/shader_recompiler/ir_opt/global_memory_to_storage_buffer_pass.cpp index ee69a5c9d9..34393e1d57 100644 --- a/src/shader_recompiler/ir_opt/global_memory_to_storage_buffer_pass.cpp +++ b/src/shader_recompiler/ir_opt/global_memory_to_storage_buffer_pass.cpp @@ -108,8 +108,8 @@ bool MeetsBias(const StorageBufferAddr& storage_buffer, const Bias& bias) noexce storage_buffer.offset < bias.offset_end; } -/// Ignores a global memory operation, reads return zero and writes are ignored -void IgnoreGlobalMemory(IR::Block& block, IR::Block::iterator inst) { +/// Discards a global memory operation, reads return zero and writes are ignored +void DiscardGlobalMemory(IR::Block& block, IR::Block::iterator inst) { const IR::Value zero{u32{0}}; switch (inst->Opcode()) { case IR::Opcode::LoadGlobalS8: @@ -120,12 +120,12 @@ void IgnoreGlobalMemory(IR::Block& block, IR::Block::iterator inst) { inst->ReplaceUsesWith(zero); break; case IR::Opcode::LoadGlobal64: - inst->ReplaceUsesWith( - IR::Value{&*block.PrependNewInst(inst, IR::Opcode::CompositeConstruct2, {zero, zero})}); + inst->ReplaceUsesWith(IR::Value{ + &*block.PrependNewInst(inst, IR::Opcode::CompositeConstructU32x2, {zero, zero})}); break; case IR::Opcode::LoadGlobal128: inst->ReplaceUsesWith(IR::Value{&*block.PrependNewInst( - inst, IR::Opcode::CompositeConstruct4, {zero, zero, zero, zero})}); + inst, IR::Opcode::CompositeConstructU32x4, {zero, zero, zero, zero})}); break; case IR::Opcode::WriteGlobalS8: case IR::Opcode::WriteGlobalU8: @@ -137,7 +137,8 @@ void IgnoreGlobalMemory(IR::Block& block, IR::Block::iterator inst) { inst->Invalidate(); break; default: - throw LogicError("Invalid opcode to ignore its global memory operation {}", inst->Opcode()); + throw LogicError("Invalid opcode to discard its global memory operation {}", + inst->Opcode()); } } @@ -196,7 +197,7 @@ void CollectStorageBuffers(IR::Block& block, IR::Block::iterator inst, storage_buffer = Track(addr, nullptr); if (!storage_buffer) { // If that also failed, drop the global memory usage - IgnoreGlobalMemory(block, inst); + DiscardGlobalMemory(block, inst); } } // Collect storage buffer and the instruction @@ -242,12 +243,12 @@ std::optional TrackLowAddress(IR::IREmitter& ir, IR::Inst* inst) { if (vector.IsImmediate()) { return std::nullopt; } - // This vector is expected to be a CompositeConstruct2 + // This vector is expected to be a CompositeConstructU32x2 IR::Inst* const vector_inst{vector.InstRecursive()}; - if (vector_inst->Opcode() != IR::Opcode::CompositeConstruct2) { + if (vector_inst->Opcode() != IR::Opcode::CompositeConstructU32x2) { return std::nullopt; } - // Grab the first argument from the CompositeConstruct2, this is the low address. + // Grab the first argument from the CompositeConstructU32x2, this is the low address. // Re-apply the offset in case we found one. const IR::U32 low_addr{vector_inst->Arg(0)}; return imm_offset != 0 ? IR::U32{ir.IAdd(low_addr, ir.Imm32(imm_offset))} : low_addr; diff --git a/src/shader_recompiler/main.cpp b/src/shader_recompiler/main.cpp index 4022c6fe2a..e6596d8287 100644 --- a/src/shader_recompiler/main.cpp +++ b/src/shader_recompiler/main.cpp @@ -52,7 +52,7 @@ int main() { // RunDatabase(); // FileEnvironment env{"D:\\Shaders\\Database\\test.bin"}; - FileEnvironment env{"D:\\Shaders\\Database\\Oninaki\\CS8F146B41DB6BD826.bin"}; + FileEnvironment env{"D:\\Shaders\\Database\\Oninaki\\CS15C2FB1F0B965767.bin"}; auto cfg{std::make_unique(env, 0)}; // fmt::print(stdout, "{}\n", cfg->Dot());