diff --git a/src/back/spv/block.rs b/src/back/spv/block.rs index d698931f6d..342a3805f5 100644 --- a/src/back/spv/block.rs +++ b/src/back/spv/block.rs @@ -243,21 +243,24 @@ impl<'w> BlockContext<'w> { self.writer.constant_ids[init.index()] } crate::Expression::ZeroValue(_) => self.writer.get_constant_null(result_type_id), - crate::Expression::Compose { - ty: _, - ref components, - } => { + crate::Expression::Compose { ty, ref components } => { self.temp_list.clear(); - for &component in components { - self.temp_list.push(self.cached[component]); - } - if self.ir_function.expressions.is_const(expr_handle) { - let ty = self - .writer - .get_expression_lookup_type(&self.fun_info[expr_handle].ty); - self.writer.get_constant_composite(ty, &self.temp_list) + self.temp_list.extend( + crate::proc::flatten_compose( + ty, + components, + &self.ir_function.expressions, + &self.ir_module.types, + ) + .map(|component| self.cached[component]), + ); + self.writer + .get_constant_composite(LookupType::Handle(ty), &self.temp_list) } else { + self.temp_list + .extend(components.iter().map(|&component| self.cached[component])); + let id = self.gen_id(); block.body.push(Instruction::composite_construct( result_type_id, diff --git a/src/back/spv/writer.rs b/src/back/spv/writer.rs index 7d79377786..0f4809565c 100644 --- a/src/back/spv/writer.rs +++ b/src/back/spv/writer.rs @@ -1269,10 +1269,14 @@ impl Writer { self.get_constant_null(type_id) } crate::Expression::Compose { ty, ref components } => { - let component_ids: Vec<_> = components - .iter() - .map(|component| self.constant_ids[component.index()]) - .collect(); + let component_ids: Vec<_> = crate::proc::flatten_compose( + ty, + components, + &ir_module.const_expressions, + &ir_module.types, + ) + .map(|component| self.constant_ids[component.index()]) + .collect(); self.get_constant_composite(LookupType::Handle(ty), component_ids.as_slice()) } crate::Expression::Splat { size, value } => { diff --git a/src/proc/constant_evaluator.rs b/src/proc/constant_evaluator.rs index 6b917c4d67..9d2055ff6a 100644 --- a/src/proc/constant_evaluator.rs +++ b/src/proc/constant_evaluator.rs @@ -73,6 +73,8 @@ pub enum ConstantEvaluatorError { SplatScalarOnly, #[error("Can only swizzle vector constants")] SwizzleVectorOnly, + #[error("swizzle component not present in source expression")] + SwizzleOutOfBounds, #[error("Type is not constructible")] TypeNotConstructible, #[error("Subexpression(s) are not constant")] @@ -306,20 +308,31 @@ impl ConstantEvaluator<'_> { let expr = Expression::Splat { size, value }; Ok(self.register_evaluated_expr(expr, span)) } - Expression::Compose { - ty, - components: ref src_components, - } => { + Expression::Compose { ty, ref components } => { let dst_ty = get_dst_ty(ty)?; - let components = pattern + let mut flattened = [src_constant; 4]; // dummy value + let len = + crate::proc::flatten_compose(ty, components, self.expressions, self.types) + .zip(flattened.iter_mut()) + .map(|(component, elt)| *elt = component) + .count(); + let flattened = &flattened[..len]; + + let swizzled_components = pattern[..size as usize] .iter() - .take(size as usize) - .map(|&sc| src_components[sc as usize]) - .collect(); + .map(|&sc| { + let sc = sc as usize; + if let Some(elt) = flattened.get(sc) { + Ok(*elt) + } else { + Err(ConstantEvaluatorError::SwizzleOutOfBounds) + } + }) + .collect::>, _>>()?; let expr = Expression::Compose { ty: dst_ty, - components, + components: swizzled_components, }; Ok(self.register_evaluated_expr(expr, span)) } @@ -455,9 +468,8 @@ impl ConstantEvaluator<'_> { .components() .ok_or(ConstantEvaluatorError::InvalidAccessBase)?; - components - .get(index) - .copied() + crate::proc::flatten_compose(ty, components, self.expressions, self.types) + .nth(index) .ok_or(ConstantEvaluatorError::InvalidAccessIndex) } _ => Err(ConstantEvaluatorError::InvalidAccessBase), diff --git a/src/proc/mod.rs b/src/proc/mod.rs index b654f5c4b2..cb08ce49e6 100644 --- a/src/proc/mod.rs +++ b/src/proc/mod.rs @@ -656,6 +656,61 @@ impl GlobalCtx<'_> { } } +/// Return an iterator over the individual components assembled by a +/// `Compose` expression. +/// +/// Given `ty` and `components` from an `Expression::Compose`, return an +/// iterator over the components of the resulting value. +/// +/// Normally, this would just be an iterator over `components`. However, +/// `Compose` expressions can concatenate vectors, in which case the i'th +/// value being composed is not generally the i'th element of `components`. +/// This function consults `ty` to decide if this concatenation is occuring, +/// and returns an iterator that produces the components of the result of +/// the `Compose` expression in either case. +pub fn flatten_compose<'arenas>( + ty: crate::Handle, + components: &'arenas [crate::Handle], + expressions: &'arenas crate::Arena, + types: &'arenas crate::UniqueArena, +) -> impl Iterator> + 'arenas { + // Returning `impl Iterator` is a bit tricky. We may or may not want to + // flatten the components, but we have to settle on a single concrete + // type to return. The below is a single iterator chain that handles + // both the flattening and non-flattening cases. + let (size, is_vector) = if let crate::TypeInner::Vector { size, .. } = types[ty].inner { + (size as usize, true) + } else { + (components.len(), false) + }; + + fn flattener<'c>( + component: &'c crate::Handle, + is_vector: bool, + expressions: &'c crate::Arena, + ) -> &'c [crate::Handle] { + if is_vector { + if let crate::Expression::Compose { + ty: _, + components: ref subcomponents, + } = expressions[*component] + { + return subcomponents; + } + } + std::slice::from_ref(component) + } + + // Expressions like `vec4(vec3(vec2(6, 7), 8), 9)` require us to flatten + // two levels. + components + .iter() + .flat_map(move |component| flattener(component, is_vector, expressions)) + .flat_map(move |component| flattener(component, is_vector, expressions)) + .take(size) + .cloned() +} + #[test] fn test_matrix_size() { let module = crate::Module::default(); diff --git a/tests/in/const-exprs.wgsl b/tests/in/const-exprs.wgsl new file mode 100644 index 0000000000..c89e61d499 --- /dev/null +++ b/tests/in/const-exprs.wgsl @@ -0,0 +1,14 @@ +@group(0) @binding(0) var out: vec4; +@group(0) @binding(1) var out2: i32; +@group(0) @binding(2) var out3: i32; + +@compute @workgroup_size(1) +fn main() { + let a = vec2(1, 2); + let b = vec2(3, 4); + out = vec4(a, b).wzyx; + + out2 = vec4(a, b)[1]; + + out3 = vec4(vec3(vec2(6, 7), 8), 9)[0]; +} diff --git a/tests/out/glsl/const-exprs.main.Compute.glsl b/tests/out/glsl/const-exprs.main.Compute.glsl new file mode 100644 index 0000000000..ff634004ca --- /dev/null +++ b/tests/out/glsl/const-exprs.main.Compute.glsl @@ -0,0 +1,23 @@ +#version 310 es + +precision highp float; +precision highp int; + +layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; + +layout(std430) buffer type_block_0Compute { ivec4 _group_0_binding_0_cs; }; + +layout(std430) buffer type_1_block_1Compute { int _group_0_binding_1_cs; }; + +layout(std430) buffer type_1_block_2Compute { int _group_0_binding_2_cs; }; + + +void main() { + ivec2 a = ivec2(1, 2); + ivec2 b = ivec2(3, 4); + _group_0_binding_0_cs = ivec4(4, 3, 2, 1); + _group_0_binding_1_cs = 2; + _group_0_binding_2_cs = 6; + return; +} + diff --git a/tests/out/hlsl/const-exprs.hlsl b/tests/out/hlsl/const-exprs.hlsl new file mode 100644 index 0000000000..f6faee1d40 --- /dev/null +++ b/tests/out/hlsl/const-exprs.hlsl @@ -0,0 +1,14 @@ +RWByteAddressBuffer out_ : register(u0); +RWByteAddressBuffer out2_ : register(u1); +RWByteAddressBuffer out3_ : register(u2); + +[numthreads(1, 1, 1)] +void main() +{ + int2 a = int2(1, 2); + int2 b = int2(3, 4); + out_.Store4(0, asuint(int4(4, 3, 2, 1))); + out2_.Store(0, asuint(2)); + out3_.Store(0, asuint(6)); + return; +} diff --git a/tests/out/hlsl/const-exprs.ron b/tests/out/hlsl/const-exprs.ron new file mode 100644 index 0000000000..a07b03300b --- /dev/null +++ b/tests/out/hlsl/const-exprs.ron @@ -0,0 +1,12 @@ +( + vertex:[ + ], + fragment:[ + ], + compute:[ + ( + entry_point:"main", + target_profile:"cs_5_1", + ), + ], +) diff --git a/tests/out/msl/const-exprs.msl b/tests/out/msl/const-exprs.msl new file mode 100644 index 0000000000..19b9b727fb --- /dev/null +++ b/tests/out/msl/const-exprs.msl @@ -0,0 +1,19 @@ +// language: metal2.0 +#include +#include + +using metal::uint; + + +kernel void main_( + device metal::int4& out [[user(fake0)]] +, device int& out2_ [[user(fake0)]] +, device int& out3_ [[user(fake0)]] +) { + metal::int2 a = metal::int2(1, 2); + metal::int2 b = metal::int2(3, 4); + out = metal::int4(4, 3, 2, 1); + out2_ = 2; + out3_ = 6; + return; +} diff --git a/tests/out/spv/const-exprs.spvasm b/tests/out/spv/const-exprs.spvasm new file mode 100644 index 0000000000..23f7d242eb --- /dev/null +++ b/tests/out/spv/const-exprs.spvasm @@ -0,0 +1,60 @@ +; SPIR-V +; Version: 1.1 +; Generator: rspirv +; Bound: 34 +OpCapability Shader +OpExtension "SPV_KHR_storage_buffer_storage_class" +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %16 "main" +OpExecutionMode %16 LocalSize 1 1 1 +OpDecorate %6 DescriptorSet 0 +OpDecorate %6 Binding 0 +OpDecorate %7 Block +OpMemberDecorate %7 0 Offset 0 +OpDecorate %9 DescriptorSet 0 +OpDecorate %9 Binding 1 +OpDecorate %10 Block +OpMemberDecorate %10 0 Offset 0 +OpDecorate %12 DescriptorSet 0 +OpDecorate %12 Binding 2 +OpDecorate %13 Block +OpMemberDecorate %13 0 Offset 0 +%2 = OpTypeVoid +%4 = OpTypeInt 32 1 +%3 = OpTypeVector %4 4 +%5 = OpTypeVector %4 2 +%7 = OpTypeStruct %3 +%8 = OpTypePointer StorageBuffer %7 +%6 = OpVariable %8 StorageBuffer +%10 = OpTypeStruct %4 +%11 = OpTypePointer StorageBuffer %10 +%9 = OpVariable %11 StorageBuffer +%13 = OpTypeStruct %4 +%14 = OpTypePointer StorageBuffer %13 +%12 = OpVariable %14 StorageBuffer +%17 = OpTypeFunction %2 +%18 = OpTypePointer StorageBuffer %3 +%20 = OpTypeInt 32 0 +%19 = OpConstant %20 0 +%22 = OpTypePointer StorageBuffer %4 +%25 = OpConstant %4 1 +%26 = OpConstant %4 2 +%27 = OpConstantComposite %5 %25 %26 +%28 = OpConstant %4 3 +%29 = OpConstant %4 4 +%30 = OpConstantComposite %5 %28 %29 +%31 = OpConstantComposite %3 %29 %28 %26 %25 +%32 = OpConstant %4 6 +%16 = OpFunction %2 None %17 +%15 = OpLabel +%21 = OpAccessChain %18 %6 %19 +%23 = OpAccessChain %22 %9 %19 +%24 = OpAccessChain %22 %12 %19 +OpBranch %33 +%33 = OpLabel +OpStore %21 %31 +OpStore %23 %26 +OpStore %24 %32 +OpReturn +OpFunctionEnd \ No newline at end of file diff --git a/tests/out/wgsl/const-exprs.wgsl b/tests/out/wgsl/const-exprs.wgsl new file mode 100644 index 0000000000..201535836e --- /dev/null +++ b/tests/out/wgsl/const-exprs.wgsl @@ -0,0 +1,16 @@ +@group(0) @binding(0) +var out: vec4; +@group(0) @binding(1) +var out2_: i32; +@group(0) @binding(2) +var out3_: i32; + +@compute @workgroup_size(1, 1, 1) +fn main() { + let a = vec2(1, 2); + let b = vec2(3, 4); + out = vec4(4, 3, 2, 1); + out2_ = 2; + out3_ = 6; + return; +} diff --git a/tests/snapshots.rs b/tests/snapshots.rs index 495bea9598..2bc7f45444 100644 --- a/tests/snapshots.rs +++ b/tests/snapshots.rs @@ -777,6 +777,10 @@ fn convert_wgsl() { Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL, ), ("msl-varyings", Targets::METAL), + ( + "const-exprs", + Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL, + ), ]; for &(name, targets) in inputs.iter() {