From 3bd2834b4f70c9d9a95567278cfb7f2c5512ad63 Mon Sep 17 00:00:00 2001 From: Fredrik Fornwall Date: Tue, 29 Aug 2023 21:34:55 +0200 Subject: [PATCH] [wgsl-in] Handle all(bool) and any(bool) (#2445) Fixes #1911. --- src/front/wgsl/lower/mod.rs | 21 +++- tests/in/standard.wgsl | 8 ++ .../glsl/standard.derivatives.Fragment.glsl | 13 ++- tests/out/hlsl/standard.hlsl | 14 ++- tests/out/msl/standard.msl | 14 ++- tests/out/spv/standard.spvasm | 98 ++++++++++--------- tests/out/wgsl/standard.wgsl | 13 ++- 7 files changed, 120 insertions(+), 61 deletions(-) diff --git a/src/front/wgsl/lower/mod.rs b/src/front/wgsl/lower/mod.rs index 4c5acf0ad9..23d3282364 100644 --- a/src/front/wgsl/lower/mod.rs +++ b/src/front/wgsl/lower/mod.rs @@ -1696,7 +1696,26 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let argument = self.expression(args.next()?, ctx.reborrow())?; args.finish()?; - crate::Expression::Relational { fun, argument } + // Check for no-op all(bool) and any(bool): + let argument_unmodified = matches!( + fun, + crate::RelationalFunction::All | crate::RelationalFunction::Any + ) && { + ctx.grow_types(argument)?; + matches!( + ctx.resolved_inner(argument), + &crate::TypeInner::Scalar { + kind: crate::ScalarKind::Bool, + .. + } + ) + }; + + if argument_unmodified { + return Ok(Some(argument)); + } else { + crate::Expression::Relational { fun, argument } + } } else if let Some((axis, ctrl)) = conv::map_derivative(function.name) { let mut args = ctx.prepare_args(arguments, 1, span); let expr = self.expression(args.next()?, ctx.reborrow())?; diff --git a/tests/in/standard.wgsl b/tests/in/standard.wgsl index 79f5632cea..9fdc344bc9 100644 --- a/tests/in/standard.wgsl +++ b/tests/in/standard.wgsl @@ -1,5 +1,11 @@ // Standard functions. +fn test_any_and_all_for_bool() -> bool { + let a = any(true); + return all(a); +} + + @fragment fn derivatives(@builtin(position) foo: vec4) -> @location(0) vec4 { var x = dpdxCoarse(foo); @@ -14,5 +20,7 @@ fn derivatives(@builtin(position) foo: vec4) -> @location(0) vec4 { y = dpdy(foo); z = fwidth(foo); + let a = test_any_and_all_for_bool(); + return (x + y) * z; } diff --git a/tests/out/glsl/standard.derivatives.Fragment.glsl b/tests/out/glsl/standard.derivatives.Fragment.glsl index 331ffb5c01..18e150a3fe 100644 --- a/tests/out/glsl/standard.derivatives.Fragment.glsl +++ b/tests/out/glsl/standard.derivatives.Fragment.glsl @@ -5,6 +5,10 @@ precision highp int; layout(location = 0) out vec4 _fs2p_location0; +bool test_any_and_all_for_bool() { + return true; +} + void main() { vec4 foo = gl_FragCoord; vec4 x = vec4(0.0); @@ -28,10 +32,11 @@ void main() { y = _e11; vec4 _e12 = fwidth(foo); z = _e12; - vec4 _e13 = x; - vec4 _e14 = y; - vec4 _e16 = z; - _fs2p_location0 = ((_e13 + _e14) * _e16); + bool _e13 = test_any_and_all_for_bool(); + vec4 _e14 = x; + vec4 _e15 = y; + vec4 _e17 = z; + _fs2p_location0 = ((_e14 + _e15) * _e17); return; } diff --git a/tests/out/hlsl/standard.hlsl b/tests/out/hlsl/standard.hlsl index a2bcc70ec4..d3fd537ebe 100644 --- a/tests/out/hlsl/standard.hlsl +++ b/tests/out/hlsl/standard.hlsl @@ -2,6 +2,11 @@ struct FragmentInput_derivatives { float4 foo_1 : SV_Position; }; +bool test_any_and_all_for_bool() +{ + return true; +} + float4 derivatives(FragmentInput_derivatives fragmentinput_derivatives) : SV_Target0 { float4 foo = fragmentinput_derivatives.foo_1; @@ -27,8 +32,9 @@ float4 derivatives(FragmentInput_derivatives fragmentinput_derivatives) : SV_Tar y = _expr11; float4 _expr12 = fwidth(foo); z = _expr12; - float4 _expr13 = x; - float4 _expr14 = y; - float4 _expr16 = z; - return ((_expr13 + _expr14) * _expr16); + const bool _e13 = test_any_and_all_for_bool(); + float4 _expr14 = x; + float4 _expr15 = y; + float4 _expr17 = z; + return ((_expr14 + _expr15) * _expr17); } diff --git a/tests/out/msl/standard.msl b/tests/out/msl/standard.msl index bca5f0cb00..f02243eaac 100644 --- a/tests/out/msl/standard.msl +++ b/tests/out/msl/standard.msl @@ -5,6 +5,11 @@ using metal::uint; +bool test_any_and_all_for_bool( +) { + return true; +} + struct derivativesInput { }; struct derivativesOutput { @@ -34,8 +39,9 @@ fragment derivativesOutput derivatives( y = _e11; metal::float4 _e12 = metal::fwidth(foo); z = _e12; - metal::float4 _e13 = x; - metal::float4 _e14 = y; - metal::float4 _e16 = z; - return derivativesOutput { (_e13 + _e14) * _e16 }; + bool _e13 = test_any_and_all_for_bool(); + metal::float4 _e14 = x; + metal::float4 _e15 = y; + metal::float4 _e17 = z; + return derivativesOutput { (_e14 + _e15) * _e17 }; } diff --git a/tests/out/spv/standard.spvasm b/tests/out/spv/standard.spvasm index 6d9445808b..07bb2a7908 100644 --- a/tests/out/spv/standard.spvasm +++ b/tests/out/spv/standard.spvasm @@ -1,56 +1,66 @@ ; SPIR-V ; Version: 1.1 ; Generator: rspirv -; Bound: 33 +; Bound: 40 OpCapability Shader OpCapability DerivativeControl %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 -OpEntryPoint Fragment %16 "derivatives" %11 %14 -OpExecutionMode %16 OriginUpperLeft -OpDecorate %11 BuiltIn FragCoord -OpDecorate %14 Location 0 +OpEntryPoint Fragment %22 "derivatives" %17 %20 +OpExecutionMode %22 OriginUpperLeft +OpDecorate %17 BuiltIn FragCoord +OpDecorate %20 Location 0 %2 = OpTypeVoid -%4 = OpTypeFloat 32 -%3 = OpTypeVector %4 4 -%6 = OpTypePointer Function %3 -%7 = OpConstantNull %3 -%12 = OpTypePointer Input %3 -%11 = OpVariable %12 Input -%15 = OpTypePointer Output %3 -%14 = OpVariable %15 Output -%17 = OpTypeFunction %2 -%16 = OpFunction %2 None %17 +%3 = OpTypeBool +%5 = OpTypeFloat 32 +%4 = OpTypeVector %5 4 +%8 = OpTypeFunction %3 +%9 = OpConstantTrue %3 +%12 = OpTypePointer Function %4 +%13 = OpConstantNull %4 +%18 = OpTypePointer Input %4 +%17 = OpVariable %18 Input +%21 = OpTypePointer Output %4 +%20 = OpVariable %21 Output +%23 = OpTypeFunction %2 +%7 = OpFunction %3 None %8 +%6 = OpLabel +OpBranch %10 %10 = OpLabel -%5 = OpVariable %6 Function %7 -%8 = OpVariable %6 Function %7 -%9 = OpVariable %6 Function %7 -%13 = OpLoad %3 %11 -OpBranch %18 -%18 = OpLabel -%19 = OpDPdxCoarse %3 %13 -OpStore %5 %19 -%20 = OpDPdyCoarse %3 %13 -OpStore %8 %20 -%21 = OpFwidthCoarse %3 %13 -OpStore %9 %21 -%22 = OpDPdxFine %3 %13 -OpStore %5 %22 -%23 = OpDPdyFine %3 %13 -OpStore %8 %23 -%24 = OpFwidthFine %3 %13 -OpStore %9 %24 -%25 = OpDPdx %3 %13 -OpStore %5 %25 -%26 = OpDPdy %3 %13 -OpStore %8 %26 -%27 = OpFwidth %3 %13 -OpStore %9 %27 -%28 = OpLoad %3 %5 -%29 = OpLoad %3 %8 -%30 = OpFAdd %3 %28 %29 -%31 = OpLoad %3 %9 -%32 = OpFMul %3 %30 %31 +OpReturnValue %9 +OpFunctionEnd +%22 = OpFunction %2 None %23 +%16 = OpLabel +%11 = OpVariable %12 Function %13 +%14 = OpVariable %12 Function %13 +%15 = OpVariable %12 Function %13 +%19 = OpLoad %4 %17 +OpBranch %24 +%24 = OpLabel +%25 = OpDPdxCoarse %4 %19 +OpStore %11 %25 +%26 = OpDPdyCoarse %4 %19 +OpStore %14 %26 +%27 = OpFwidthCoarse %4 %19 +OpStore %15 %27 +%28 = OpDPdxFine %4 %19 +OpStore %11 %28 +%29 = OpDPdyFine %4 %19 +OpStore %14 %29 +%30 = OpFwidthFine %4 %19 +OpStore %15 %30 +%31 = OpDPdx %4 %19 +OpStore %11 %31 +%32 = OpDPdy %4 %19 OpStore %14 %32 +%33 = OpFwidth %4 %19 +OpStore %15 %33 +%34 = OpFunctionCall %3 %7 +%35 = OpLoad %4 %11 +%36 = OpLoad %4 %14 +%37 = OpFAdd %4 %35 %36 +%38 = OpLoad %4 %15 +%39 = OpFMul %4 %37 %38 +OpStore %20 %39 OpReturn OpFunctionEnd \ No newline at end of file diff --git a/tests/out/wgsl/standard.wgsl b/tests/out/wgsl/standard.wgsl index 80e8f24989..886cf09193 100644 --- a/tests/out/wgsl/standard.wgsl +++ b/tests/out/wgsl/standard.wgsl @@ -1,3 +1,7 @@ +fn test_any_and_all_for_bool() -> bool { + return true; +} + @fragment fn derivatives(@builtin(position) foo: vec4) -> @location(0) vec4 { var x: vec4; @@ -22,8 +26,9 @@ fn derivatives(@builtin(position) foo: vec4) -> @location(0) vec4 { y = _e11; let _e12 = fwidth(foo); z = _e12; - let _e13 = x; - let _e14 = y; - let _e16 = z; - return ((_e13 + _e14) * _e16); + let _e13 = test_any_and_all_for_bool(); + let _e14 = x; + let _e15 = y; + let _e17 = z; + return ((_e14 + _e15) * _e17); }