From 4d94474a70f59d236146c71dbc34d3fe7dcc478a Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Sun, 17 Sep 2023 17:51:14 -0700 Subject: [PATCH] ConstantEvaluator::swizzle: Handle vector concatenation. --- src/proc/constant_evaluator.rs | 78 +++++++++++++++++--- tests/in/const-exprs.wgsl | 9 +++ tests/out/glsl/const-exprs.main.Compute.glsl | 17 +++++ tests/out/hlsl/const-exprs.hlsl | 10 +++ tests/out/hlsl/const-exprs.ron | 12 +++ tests/out/msl/const-exprs.msl | 15 ++++ tests/out/spv/const-exprs.spvasm | 41 ++++++++++ tests/out/wgsl/const-exprs.wgsl | 10 +++ tests/snapshots.rs | 4 + 9 files changed, 187 insertions(+), 9 deletions(-) create mode 100644 tests/in/const-exprs.wgsl create mode 100644 tests/out/glsl/const-exprs.main.Compute.glsl create mode 100644 tests/out/hlsl/const-exprs.hlsl create mode 100644 tests/out/hlsl/const-exprs.ron create mode 100644 tests/out/msl/const-exprs.msl create mode 100644 tests/out/spv/const-exprs.spvasm create mode 100644 tests/out/wgsl/const-exprs.wgsl diff --git a/src/proc/constant_evaluator.rs b/src/proc/constant_evaluator.rs index 19b635072e..037b043df4 100644 --- a/src/proc/constant_evaluator.rs +++ b/src/proc/constant_evaluator.rs @@ -80,6 +80,8 @@ pub enum ConstantEvaluatorError { SplatScalarOnly, #[error("Can only swizzle vector constants")] SwizzleVectorOnly, + #[error("swizzle component not present in source expression")] + SwizzleOutOfBounds, #[error("Type is not constructible")] TypeNotConstructible, #[error("Subexpression(s) are not constant")] @@ -305,20 +307,31 @@ impl ConstantEvaluator<'_> { let expr = Expression::Splat { size, value }; Ok(self.register_constant(expr, span)) } - Expression::Compose { - ty, - components: ref src_components, - } => { + Expression::Compose { ty, ref components } => { let dst_ty = get_dst_ty(ty)?; - let components = pattern + let mut flattened = [src_constant; 4]; // dummy value + let len = self + .flatten_compose(ty, components) + .zip(flattened.iter_mut()) + .map(|(component, elt)| *elt = component) + .count(); + let flattened = &flattened[..len]; + + let swizzled_components = pattern[..size as usize] .iter() - .take(size as usize) - .map(|&sc| src_components[sc as usize]) - .collect(); + .map(|&sc| { + let sc = sc as usize; + if let Some(elt) = flattened.get(sc) { + Ok(*elt) + } else { + Err(ConstantEvaluatorError::SwizzleOutOfBounds) + } + }) + .collect::>, _>>()?; let expr = Expression::Compose { ty: dst_ty, - components, + components: swizzled_components, }; Ok(self.register_constant(expr, span)) } @@ -827,6 +840,53 @@ impl ConstantEvaluator<'_> { self.expressions.append(expr, span) } + + /// Return an iterator over the individual components assembled by a + /// `Compose` expression. + /// + /// Given `ty` and `components` from an `Expression::Compose`, return an + /// iterator over the components of the resulting value. + /// + /// Normally, this would just be an iterator over `components`. However, + /// `Compose` expressions can concatenate vectors, in which case the i'th + /// value being composed is not generally the i'th element of `components`. + /// This function consults `ty` to decide if this concatenation is occuring, + /// and returns an iterator that produces the components of the result of + /// the `Compose` expression in either case. + fn flatten_compose<'c>( + &'c self, + ty: Handle, + components: &'c [Handle], + ) -> impl Iterator> + 'c { + // Returning `impl Iterator` is a bit tricky. We may or may not want to + // flatten the components, but we have to settle on a single concrete + // type to return. The below is a single iterator chain that handles + // both the flattening and non-flattening cases. + let (size, is_vector) = if let TypeInner::Vector { size, .. } = self.types[ty].inner { + (size as usize, true) + } else { + (components.len(), false) + }; + + components + .iter() + .flat_map(move |component| { + if let ( + true, + &Expression::Compose { + ty: _, + components: ref subcomponents, + }, + ) = (is_vector, &self.expressions[*component]) + { + subcomponents + } else { + std::slice::from_ref(component) + } + }) + .take(size) + .cloned() + } } /// Helper function to implement the GLSL `max` function for floats. diff --git a/tests/in/const-exprs.wgsl b/tests/in/const-exprs.wgsl new file mode 100644 index 0000000000..2396698627 --- /dev/null +++ b/tests/in/const-exprs.wgsl @@ -0,0 +1,9 @@ +@group(0) @binding(0) +var out: vec4; + +@compute @workgroup_size(1) +fn main() { + let a = vec2(1, 2); + let b = vec2(3, 4); + out = vec4(a, b).wzyx; +} diff --git a/tests/out/glsl/const-exprs.main.Compute.glsl b/tests/out/glsl/const-exprs.main.Compute.glsl new file mode 100644 index 0000000000..b23764df2b --- /dev/null +++ b/tests/out/glsl/const-exprs.main.Compute.glsl @@ -0,0 +1,17 @@ +#version 310 es + +precision highp float; +precision highp int; + +layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; + +layout(std430) buffer type_block_0Compute { ivec4 _group_0_binding_0_cs; }; + + +void main() { + ivec2 a = ivec2(1, 2); + ivec2 b = ivec2(3, 4); + _group_0_binding_0_cs = ivec4(4, 3, 2, 1); + return; +} + diff --git a/tests/out/hlsl/const-exprs.hlsl b/tests/out/hlsl/const-exprs.hlsl new file mode 100644 index 0000000000..220380fae9 --- /dev/null +++ b/tests/out/hlsl/const-exprs.hlsl @@ -0,0 +1,10 @@ +RWByteAddressBuffer out_ : register(u0); + +[numthreads(1, 1, 1)] +void main() +{ + int2 a = int2(1, 2); + int2 b = int2(3, 4); + out_.Store4(0, asuint(int4(4, 3, 2, 1))); + return; +} diff --git a/tests/out/hlsl/const-exprs.ron b/tests/out/hlsl/const-exprs.ron new file mode 100644 index 0000000000..a07b03300b --- /dev/null +++ b/tests/out/hlsl/const-exprs.ron @@ -0,0 +1,12 @@ +( + vertex:[ + ], + fragment:[ + ], + compute:[ + ( + entry_point:"main", + target_profile:"cs_5_1", + ), + ], +) diff --git a/tests/out/msl/const-exprs.msl b/tests/out/msl/const-exprs.msl new file mode 100644 index 0000000000..19f9293dc2 --- /dev/null +++ b/tests/out/msl/const-exprs.msl @@ -0,0 +1,15 @@ +// language: metal2.0 +#include +#include + +using metal::uint; + + +kernel void main_( + device metal::int4& out [[user(fake0)]] +) { + metal::int2 a = metal::int2(1, 2); + metal::int2 b = metal::int2(3, 4); + out = metal::int4(4, 3, 2, 1); + return; +} diff --git a/tests/out/spv/const-exprs.spvasm b/tests/out/spv/const-exprs.spvasm new file mode 100644 index 0000000000..f0d12fdd2b --- /dev/null +++ b/tests/out/spv/const-exprs.spvasm @@ -0,0 +1,41 @@ +; SPIR-V +; Version: 1.1 +; Generator: rspirv +; Bound: 25 +OpCapability Shader +OpExtension "SPV_KHR_storage_buffer_storage_class" +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %12 "main" +OpExecutionMode %12 LocalSize 1 1 1 +OpDecorate %8 DescriptorSet 0 +OpDecorate %8 Binding 0 +OpDecorate %9 Block +OpMemberDecorate %9 0 Offset 0 +%2 = OpTypeVoid +%4 = OpTypeInt 32 1 +%3 = OpTypeVector %4 4 +%5 = OpTypeVector %4 2 +%6 = OpConstant %4 0 +%7 = OpConstant %4 1 +%9 = OpTypeStruct %3 +%10 = OpTypePointer StorageBuffer %9 +%8 = OpVariable %10 StorageBuffer +%13 = OpTypeFunction %2 +%14 = OpTypePointer StorageBuffer %3 +%16 = OpTypeInt 32 0 +%15 = OpConstant %16 0 +%18 = OpConstant %4 2 +%19 = OpConstantComposite %5 %7 %18 +%20 = OpConstant %4 3 +%21 = OpConstant %4 4 +%22 = OpConstantComposite %5 %20 %21 +%23 = OpConstantComposite %3 %21 %20 %18 %7 +%12 = OpFunction %2 None %13 +%11 = OpLabel +%17 = OpAccessChain %14 %8 %15 +OpBranch %24 +%24 = OpLabel +OpStore %17 %23 +OpReturn +OpFunctionEnd \ No newline at end of file diff --git a/tests/out/wgsl/const-exprs.wgsl b/tests/out/wgsl/const-exprs.wgsl new file mode 100644 index 0000000000..d368ba7fb8 --- /dev/null +++ b/tests/out/wgsl/const-exprs.wgsl @@ -0,0 +1,10 @@ +@group(0) @binding(0) +var out: vec4; + +@compute @workgroup_size(1, 1, 1) +fn main() { + let a = vec2(1, 2); + let b = vec2(3, 4); + out = vec4(4, 3, 2, 1); + return; +} diff --git a/tests/snapshots.rs b/tests/snapshots.rs index dce0a7edf9..95a4137a8a 100644 --- a/tests/snapshots.rs +++ b/tests/snapshots.rs @@ -616,6 +616,10 @@ fn convert_wgsl() { "constructors", Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL, ), + ( + "const-exprs", + Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL, + ), ]; for &(name, targets) in inputs.iter() {