From 3eec55abd016ac1c73279cabee2edff758a48e6b Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Sun, 17 Sep 2023 17:51:14 -0700 Subject: [PATCH] ConstantEvaluator::swizzle: Handle vector concatenation. --- src/proc/constant_evaluator.rs | 78 ++++++++++++++++++++++++++++---- tests/in/const-exprs.wgsl | 5 ++ tests/out/hlsl/const-exprs.hlsl | 7 +++ tests/out/hlsl/const-exprs.ron | 8 ++++ tests/out/msl/const-exprs.msl | 13 ++++++ tests/out/spv/const-exprs.spvasm | 26 +++++++++++ tests/out/wgsl/const-exprs.wgsl | 6 +++ tests/snapshots.rs | 4 ++ 8 files changed, 138 insertions(+), 9 deletions(-) create mode 100644 tests/in/const-exprs.wgsl create mode 100644 tests/out/hlsl/const-exprs.hlsl create mode 100644 tests/out/hlsl/const-exprs.ron create mode 100644 tests/out/msl/const-exprs.msl create mode 100644 tests/out/spv/const-exprs.spvasm create mode 100644 tests/out/wgsl/const-exprs.wgsl diff --git a/src/proc/constant_evaluator.rs b/src/proc/constant_evaluator.rs index 19b635072e..037b043df4 100644 --- a/src/proc/constant_evaluator.rs +++ b/src/proc/constant_evaluator.rs @@ -80,6 +80,8 @@ pub enum ConstantEvaluatorError { SplatScalarOnly, #[error("Can only swizzle vector constants")] SwizzleVectorOnly, + #[error("swizzle component not present in source expression")] + SwizzleOutOfBounds, #[error("Type is not constructible")] TypeNotConstructible, #[error("Subexpression(s) are not constant")] @@ -305,20 +307,31 @@ impl ConstantEvaluator<'_> { let expr = Expression::Splat { size, value }; Ok(self.register_constant(expr, span)) } - Expression::Compose { - ty, - components: ref src_components, - } => { + Expression::Compose { ty, ref components } => { let dst_ty = get_dst_ty(ty)?; - let components = pattern + let mut flattened = [src_constant; 4]; // dummy value + let len = self + .flatten_compose(ty, components) + .zip(flattened.iter_mut()) + .map(|(component, elt)| *elt = component) + .count(); + let flattened = &flattened[..len]; + + let swizzled_components = pattern[..size as usize] .iter() - .take(size as usize) - .map(|&sc| src_components[sc as usize]) - .collect(); + .map(|&sc| { + let sc = sc as usize; + if let Some(elt) = flattened.get(sc) { + Ok(*elt) + } else { + Err(ConstantEvaluatorError::SwizzleOutOfBounds) + } + }) + .collect::>, _>>()?; let expr = Expression::Compose { ty: dst_ty, - components, + components: swizzled_components, }; Ok(self.register_constant(expr, span)) } @@ -827,6 +840,53 @@ impl ConstantEvaluator<'_> { self.expressions.append(expr, span) } + + /// Return an iterator over the individual components assembled by a + /// `Compose` expression. + /// + /// Given `ty` and `components` from an `Expression::Compose`, return an + /// iterator over the components of the resulting value. + /// + /// Normally, this would just be an iterator over `components`. However, + /// `Compose` expressions can concatenate vectors, in which case the i'th + /// value being composed is not generally the i'th element of `components`. + /// This function consults `ty` to decide if this concatenation is occuring, + /// and returns an iterator that produces the components of the result of + /// the `Compose` expression in either case. + fn flatten_compose<'c>( + &'c self, + ty: Handle, + components: &'c [Handle], + ) -> impl Iterator> + 'c { + // Returning `impl Iterator` is a bit tricky. We may or may not want to + // flatten the components, but we have to settle on a single concrete + // type to return. The below is a single iterator chain that handles + // both the flattening and non-flattening cases. + let (size, is_vector) = if let TypeInner::Vector { size, .. } = self.types[ty].inner { + (size as usize, true) + } else { + (components.len(), false) + }; + + components + .iter() + .flat_map(move |component| { + if let ( + true, + &Expression::Compose { + ty: _, + components: ref subcomponents, + }, + ) = (is_vector, &self.expressions[*component]) + { + subcomponents + } else { + std::slice::from_ref(component) + } + }) + .take(size) + .cloned() + } } /// Helper function to implement the GLSL `max` function for floats. diff --git a/tests/in/const-exprs.wgsl b/tests/in/const-exprs.wgsl new file mode 100644 index 0000000000..27502783be --- /dev/null +++ b/tests/in/const-exprs.wgsl @@ -0,0 +1,5 @@ +fn f() -> vec4 { + let a = vec2(1, 2); + let b = vec2(3, 4); + return vec4(a, b).wzyx; +} diff --git a/tests/out/hlsl/const-exprs.hlsl b/tests/out/hlsl/const-exprs.hlsl new file mode 100644 index 0000000000..303f539fb1 --- /dev/null +++ b/tests/out/hlsl/const-exprs.hlsl @@ -0,0 +1,7 @@ +int4 f() +{ + int2 a = int2(1, 2); + int2 b = int2(3, 4); + return int4(4, 3, 2, 1); +} + diff --git a/tests/out/hlsl/const-exprs.ron b/tests/out/hlsl/const-exprs.ron new file mode 100644 index 0000000000..4d056ac29b --- /dev/null +++ b/tests/out/hlsl/const-exprs.ron @@ -0,0 +1,8 @@ +( + vertex:[ + ], + fragment:[ + ], + compute:[ + ], +) diff --git a/tests/out/msl/const-exprs.msl b/tests/out/msl/const-exprs.msl new file mode 100644 index 0000000000..fa1842a5e8 --- /dev/null +++ b/tests/out/msl/const-exprs.msl @@ -0,0 +1,13 @@ +// language: metal2.0 +#include +#include + +using metal::uint; + + +metal::int4 f( +) { + metal::int2 a = metal::int2(1, 2); + metal::int2 b = metal::int2(3, 4); + return metal::int4(4, 3, 2, 1); +} diff --git a/tests/out/spv/const-exprs.spvasm b/tests/out/spv/const-exprs.spvasm new file mode 100644 index 0000000000..c5a328bf47 --- /dev/null +++ b/tests/out/spv/const-exprs.spvasm @@ -0,0 +1,26 @@ +; SPIR-V +; Version: 1.1 +; Generator: rspirv +; Bound: 17 +OpCapability Shader +OpCapability Linkage +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +%2 = OpTypeVoid +%4 = OpTypeInt 32 1 +%3 = OpTypeVector %4 4 +%5 = OpTypeVector %4 2 +%8 = OpTypeFunction %3 +%9 = OpConstant %4 1 +%10 = OpConstant %4 2 +%11 = OpConstantComposite %5 %9 %10 +%12 = OpConstant %4 3 +%13 = OpConstant %4 4 +%14 = OpConstantComposite %5 %12 %13 +%15 = OpConstantComposite %3 %13 %12 %10 %9 +%7 = OpFunction %3 None %8 +%6 = OpLabel +OpBranch %16 +%16 = OpLabel +OpReturnValue %15 +OpFunctionEnd \ No newline at end of file diff --git a/tests/out/wgsl/const-exprs.wgsl b/tests/out/wgsl/const-exprs.wgsl new file mode 100644 index 0000000000..5ec55e6071 --- /dev/null +++ b/tests/out/wgsl/const-exprs.wgsl @@ -0,0 +1,6 @@ +fn f() -> vec4 { + let a = vec2(1, 2); + let b = vec2(3, 4); + return vec4(4, 3, 2, 1); +} + diff --git a/tests/snapshots.rs b/tests/snapshots.rs index dce0a7edf9..95a4137a8a 100644 --- a/tests/snapshots.rs +++ b/tests/snapshots.rs @@ -616,6 +616,10 @@ fn convert_wgsl() { "constructors", Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL, ), + ( + "const-exprs", + Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL, + ), ]; for &(name, targets) in inputs.iter() {