diff --git a/src/proc/constant_evaluator.rs b/src/proc/constant_evaluator.rs index 19b635072e..b2e0c8346d 100644 --- a/src/proc/constant_evaluator.rs +++ b/src/proc/constant_evaluator.rs @@ -80,6 +80,8 @@ pub enum ConstantEvaluatorError { SplatScalarOnly, #[error("Can only swizzle vector constants")] SwizzleVectorOnly, + #[error("swizzle component not present in source expression")] + SwizzleOutOfBounds, #[error("Type is not constructible")] TypeNotConstructible, #[error("Subexpression(s) are not constant")] @@ -305,20 +307,31 @@ impl ConstantEvaluator<'_> { let expr = Expression::Splat { size, value }; Ok(self.register_constant(expr, span)) } - Expression::Compose { - ty, - components: ref src_components, - } => { + Expression::Compose { ty, ref components } => { let dst_ty = get_dst_ty(ty)?; - let components = pattern + let mut flattened = [src_constant; 4]; // dummy value + let len = self + .flatten_compose(ty, components) + .zip(flattened.iter_mut()) + .map(|(component, elt)| *elt = component) + .count(); + let flattened = &flattened[..len]; + + let swizzled_components = pattern[..size as usize] .iter() - .take(size as usize) - .map(|&sc| src_components[sc as usize]) - .collect(); + .map(|&sc| { + let sc = sc as usize; + if let Some(elt) = flattened.get(sc) { + Ok(*elt) + } else { + Err(ConstantEvaluatorError::SwizzleOutOfBounds) + } + }) + .collect::>, _>>()?; let expr = Expression::Compose { ty: dst_ty, - components, + components: swizzled_components, }; Ok(self.register_constant(expr, span)) } @@ -454,9 +467,8 @@ impl ConstantEvaluator<'_> { .components() .ok_or(ConstantEvaluatorError::InvalidAccessBase)?; - components - .get(index) - .copied() + self.flatten_compose(ty, components) + .nth(index) .ok_or(ConstantEvaluatorError::InvalidAccessIndex) } _ => Err(ConstantEvaluatorError::InvalidAccessBase), @@ -827,6 +839,53 @@ impl ConstantEvaluator<'_> { self.expressions.append(expr, span) } + + /// Return an iterator over the individual components assembled by a + /// `Compose` expression. + /// + /// Given `ty` and `components` from an `Expression::Compose`, return an + /// iterator over the components of the resulting value. + /// + /// Normally, this would just be an iterator over `components`. However, + /// `Compose` expressions can concatenate vectors, in which case the i'th + /// value being composed is not generally the i'th element of `components`. + /// This function consults `ty` to decide if this concatenation is occuring, + /// and returns an iterator that produces the components of the result of + /// the `Compose` expression in either case. + fn flatten_compose<'c>( + &'c self, + ty: Handle, + components: &'c [Handle], + ) -> impl Iterator> + 'c { + // Returning `impl Iterator` is a bit tricky. We may or may not want to + // flatten the components, but we have to settle on a single concrete + // type to return. The below is a single iterator chain that handles + // both the flattening and non-flattening cases. + let (size, is_vector) = if let TypeInner::Vector { size, .. } = self.types[ty].inner { + (size as usize, true) + } else { + (components.len(), false) + }; + + components + .iter() + .flat_map(move |component| { + if let ( + true, + &Expression::Compose { + ty: _, + components: ref subcomponents, + }, + ) = (is_vector, &self.expressions[*component]) + { + subcomponents + } else { + std::slice::from_ref(component) + } + }) + .take(size) + .cloned() + } } /// Helper function to implement the GLSL `max` function for floats. diff --git a/tests/in/const-exprs.wgsl b/tests/in/const-exprs.wgsl new file mode 100644 index 0000000000..d51deee9ef --- /dev/null +++ b/tests/in/const-exprs.wgsl @@ -0,0 +1,14 @@ +@group(0) @binding(0) +var out: vec4; + +@group(0) @binding(1) +var out2: i32; + +@compute @workgroup_size(1) +fn main() { + let a = vec2(1, 2); + let b = vec2(3, 4); + out = vec4(a, b).wzyx; + + out2 = vec4(a, b)[1]; +} diff --git a/tests/out/glsl/const-exprs.main.Compute.glsl b/tests/out/glsl/const-exprs.main.Compute.glsl new file mode 100644 index 0000000000..9918cd68ab --- /dev/null +++ b/tests/out/glsl/const-exprs.main.Compute.glsl @@ -0,0 +1,20 @@ +#version 310 es + +precision highp float; +precision highp int; + +layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; + +layout(std430) buffer type_block_0Compute { ivec4 _group_0_binding_0_cs; }; + +layout(std430) buffer type_1_block_1Compute { int _group_0_binding_1_cs; }; + + +void main() { + ivec2 a = ivec2(1, 2); + ivec2 b = ivec2(3, 4); + _group_0_binding_0_cs = ivec4(4, 3, 2, 1); + _group_0_binding_1_cs = 2; + return; +} + diff --git a/tests/out/hlsl/const-exprs.hlsl b/tests/out/hlsl/const-exprs.hlsl new file mode 100644 index 0000000000..f483a9b56d --- /dev/null +++ b/tests/out/hlsl/const-exprs.hlsl @@ -0,0 +1,12 @@ +RWByteAddressBuffer out_ : register(u0); +RWByteAddressBuffer out2_ : register(u1); + +[numthreads(1, 1, 1)] +void main() +{ + int2 a = int2(1, 2); + int2 b = int2(3, 4); + out_.Store4(0, asuint(int4(4, 3, 2, 1))); + out2_.Store(0, asuint(2)); + return; +} diff --git a/tests/out/hlsl/const-exprs.ron b/tests/out/hlsl/const-exprs.ron new file mode 100644 index 0000000000..a07b03300b --- /dev/null +++ b/tests/out/hlsl/const-exprs.ron @@ -0,0 +1,12 @@ +( + vertex:[ + ], + fragment:[ + ], + compute:[ + ( + entry_point:"main", + target_profile:"cs_5_1", + ), + ], +) diff --git a/tests/out/msl/const-exprs.msl b/tests/out/msl/const-exprs.msl new file mode 100644 index 0000000000..d8f38b4fc2 --- /dev/null +++ b/tests/out/msl/const-exprs.msl @@ -0,0 +1,17 @@ +// language: metal2.0 +#include +#include + +using metal::uint; + + +kernel void main_( + device metal::int4& out [[user(fake0)]] +, device int& out2_ [[user(fake0)]] +) { + metal::int2 a = metal::int2(1, 2); + metal::int2 b = metal::int2(3, 4); + out = metal::int4(4, 3, 2, 1); + out2_ = 2; + return; +} diff --git a/tests/out/spv/const-exprs.spvasm b/tests/out/spv/const-exprs.spvasm new file mode 100644 index 0000000000..fa9fb4fd51 --- /dev/null +++ b/tests/out/spv/const-exprs.spvasm @@ -0,0 +1,51 @@ +; SPIR-V +; Version: 1.1 +; Generator: rspirv +; Bound: 30 +OpCapability Shader +OpExtension "SPV_KHR_storage_buffer_storage_class" +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %15 "main" +OpExecutionMode %15 LocalSize 1 1 1 +OpDecorate %8 DescriptorSet 0 +OpDecorate %8 Binding 0 +OpDecorate %9 Block +OpMemberDecorate %9 0 Offset 0 +OpDecorate %11 DescriptorSet 0 +OpDecorate %11 Binding 1 +OpDecorate %12 Block +OpMemberDecorate %12 0 Offset 0 +%2 = OpTypeVoid +%4 = OpTypeInt 32 1 +%3 = OpTypeVector %4 4 +%5 = OpTypeVector %4 2 +%6 = OpConstant %4 0 +%7 = OpConstant %4 1 +%9 = OpTypeStruct %3 +%10 = OpTypePointer StorageBuffer %9 +%8 = OpVariable %10 StorageBuffer +%12 = OpTypeStruct %4 +%13 = OpTypePointer StorageBuffer %12 +%11 = OpVariable %13 StorageBuffer +%16 = OpTypeFunction %2 +%17 = OpTypePointer StorageBuffer %3 +%19 = OpTypeInt 32 0 +%18 = OpConstant %19 0 +%21 = OpTypePointer StorageBuffer %4 +%23 = OpConstant %4 2 +%24 = OpConstantComposite %5 %7 %23 +%25 = OpConstant %4 3 +%26 = OpConstant %4 4 +%27 = OpConstantComposite %5 %25 %26 +%28 = OpConstantComposite %3 %26 %25 %23 %7 +%15 = OpFunction %2 None %16 +%14 = OpLabel +%20 = OpAccessChain %17 %8 %18 +%22 = OpAccessChain %21 %11 %18 +OpBranch %29 +%29 = OpLabel +OpStore %20 %28 +OpStore %22 %23 +OpReturn +OpFunctionEnd \ No newline at end of file diff --git a/tests/out/wgsl/const-exprs.wgsl b/tests/out/wgsl/const-exprs.wgsl new file mode 100644 index 0000000000..b58339b8af --- /dev/null +++ b/tests/out/wgsl/const-exprs.wgsl @@ -0,0 +1,13 @@ +@group(0) @binding(0) +var out: vec4; +@group(0) @binding(1) +var out2_: i32; + +@compute @workgroup_size(1, 1, 1) +fn main() { + let a = vec2(1, 2); + let b = vec2(3, 4); + out = vec4(4, 3, 2, 1); + out2_ = 2; + return; +} diff --git a/tests/snapshots.rs b/tests/snapshots.rs index dce0a7edf9..95a4137a8a 100644 --- a/tests/snapshots.rs +++ b/tests/snapshots.rs @@ -616,6 +616,10 @@ fn convert_wgsl() { "constructors", Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL, ), + ( + "const-exprs", + Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL, + ), ]; for &(name, targets) in inputs.iter() {