diff --git a/src/proc/constant_evaluator.rs b/src/proc/constant_evaluator.rs index b2e0c8346d..0f4b039b18 100644 --- a/src/proc/constant_evaluator.rs +++ b/src/proc/constant_evaluator.rs @@ -867,22 +867,29 @@ impl ConstantEvaluator<'_> { (components.len(), false) }; - components - .iter() - .flat_map(move |component| { - if let ( - true, - &Expression::Compose { - ty: _, - components: ref subcomponents, - }, - ) = (is_vector, &self.expressions[*component]) + fn flattener<'c>( + component: &'c Handle, + is_vector: bool, + expressions: &'c Arena, + ) -> &'c [Handle] { + if is_vector { + if let Expression::Compose { + ty: _, + components: ref subcomponents, + } = expressions[*component] { - subcomponents - } else { - std::slice::from_ref(component) + return subcomponents; } - }) + } + std::slice::from_ref(component) + } + + // Expressions like `vec4(vec3(vec2(6, 7), 8), 9)` require us to flatten + // two levels. + components + .iter() + .flat_map(move |component| flattener(component, is_vector, self.expressions)) + .flat_map(move |component| flattener(component, is_vector, self.expressions)) .take(size) .cloned() } diff --git a/tests/in/const-exprs.wgsl b/tests/in/const-exprs.wgsl index d51deee9ef..c89e61d499 100644 --- a/tests/in/const-exprs.wgsl +++ b/tests/in/const-exprs.wgsl @@ -1,8 +1,6 @@ -@group(0) @binding(0) -var out: vec4; - -@group(0) @binding(1) -var out2: i32; +@group(0) @binding(0) var out: vec4; +@group(0) @binding(1) var out2: i32; +@group(0) @binding(2) var out3: i32; @compute @workgroup_size(1) fn main() { @@ -11,4 +9,6 @@ fn main() { out = vec4(a, b).wzyx; out2 = vec4(a, b)[1]; + + out3 = vec4(vec3(vec2(6, 7), 8), 9)[0]; } diff --git a/tests/out/glsl/const-exprs.main.Compute.glsl b/tests/out/glsl/const-exprs.main.Compute.glsl index 9918cd68ab..ff634004ca 100644 --- a/tests/out/glsl/const-exprs.main.Compute.glsl +++ b/tests/out/glsl/const-exprs.main.Compute.glsl @@ -9,12 +9,15 @@ layout(std430) buffer type_block_0Compute { ivec4 _group_0_binding_0_cs; }; layout(std430) buffer type_1_block_1Compute { int _group_0_binding_1_cs; }; +layout(std430) buffer type_1_block_2Compute { int _group_0_binding_2_cs; }; + void main() { ivec2 a = ivec2(1, 2); ivec2 b = ivec2(3, 4); _group_0_binding_0_cs = ivec4(4, 3, 2, 1); _group_0_binding_1_cs = 2; + _group_0_binding_2_cs = 6; return; } diff --git a/tests/out/hlsl/const-exprs.hlsl b/tests/out/hlsl/const-exprs.hlsl index f483a9b56d..f6faee1d40 100644 --- a/tests/out/hlsl/const-exprs.hlsl +++ b/tests/out/hlsl/const-exprs.hlsl @@ -1,5 +1,6 @@ RWByteAddressBuffer out_ : register(u0); RWByteAddressBuffer out2_ : register(u1); +RWByteAddressBuffer out3_ : register(u2); [numthreads(1, 1, 1)] void main() @@ -8,5 +9,6 @@ void main() int2 b = int2(3, 4); out_.Store4(0, asuint(int4(4, 3, 2, 1))); out2_.Store(0, asuint(2)); + out3_.Store(0, asuint(6)); return; } diff --git a/tests/out/msl/const-exprs.msl b/tests/out/msl/const-exprs.msl index d8f38b4fc2..19b9b727fb 100644 --- a/tests/out/msl/const-exprs.msl +++ b/tests/out/msl/const-exprs.msl @@ -8,10 +8,12 @@ using metal::uint; kernel void main_( device metal::int4& out [[user(fake0)]] , device int& out2_ [[user(fake0)]] +, device int& out3_ [[user(fake0)]] ) { metal::int2 a = metal::int2(1, 2); metal::int2 b = metal::int2(3, 4); out = metal::int4(4, 3, 2, 1); out2_ = 2; + out3_ = 6; return; } diff --git a/tests/out/spv/const-exprs.spvasm b/tests/out/spv/const-exprs.spvasm index fa9fb4fd51..ae26ded97e 100644 --- a/tests/out/spv/const-exprs.spvasm +++ b/tests/out/spv/const-exprs.spvasm @@ -1,51 +1,67 @@ ; SPIR-V ; Version: 1.1 ; Generator: rspirv -; Bound: 30 +; Bound: 41 OpCapability Shader OpExtension "SPV_KHR_storage_buffer_storage_class" %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 -OpEntryPoint GLCompute %15 "main" -OpExecutionMode %15 LocalSize 1 1 1 -OpDecorate %8 DescriptorSet 0 -OpDecorate %8 Binding 0 -OpDecorate %9 Block -OpMemberDecorate %9 0 Offset 0 -OpDecorate %11 DescriptorSet 0 -OpDecorate %11 Binding 1 -OpDecorate %12 Block -OpMemberDecorate %12 0 Offset 0 +OpEntryPoint GLCompute %20 "main" +OpExecutionMode %20 LocalSize 1 1 1 +OpDecorate %10 DescriptorSet 0 +OpDecorate %10 Binding 0 +OpDecorate %11 Block +OpMemberDecorate %11 0 Offset 0 +OpDecorate %13 DescriptorSet 0 +OpDecorate %13 Binding 1 +OpDecorate %14 Block +OpMemberDecorate %14 0 Offset 0 +OpDecorate %16 DescriptorSet 0 +OpDecorate %16 Binding 2 +OpDecorate %17 Block +OpMemberDecorate %17 0 Offset 0 %2 = OpTypeVoid %4 = OpTypeInt 32 1 %3 = OpTypeVector %4 4 %5 = OpTypeVector %4 2 -%6 = OpConstant %4 0 -%7 = OpConstant %4 1 -%9 = OpTypeStruct %3 -%10 = OpTypePointer StorageBuffer %9 -%8 = OpVariable %10 StorageBuffer -%12 = OpTypeStruct %4 -%13 = OpTypePointer StorageBuffer %12 -%11 = OpVariable %13 StorageBuffer -%16 = OpTypeFunction %2 -%17 = OpTypePointer StorageBuffer %3 -%19 = OpTypeInt 32 0 -%18 = OpConstant %19 0 -%21 = OpTypePointer StorageBuffer %4 -%23 = OpConstant %4 2 -%24 = OpConstantComposite %5 %7 %23 -%25 = OpConstant %4 3 -%26 = OpConstant %4 4 -%27 = OpConstantComposite %5 %25 %26 -%28 = OpConstantComposite %3 %26 %25 %23 %7 -%15 = OpFunction %2 None %16 -%14 = OpLabel -%20 = OpAccessChain %17 %8 %18 -%22 = OpAccessChain %21 %11 %18 -OpBranch %29 -%29 = OpLabel -OpStore %20 %28 -OpStore %22 %23 +%6 = OpTypeVector %4 3 +%7 = OpConstant %4 0 +%8 = OpConstant %4 1 +%9 = OpConstant %4 2 +%11 = OpTypeStruct %3 +%12 = OpTypePointer StorageBuffer %11 +%10 = OpVariable %12 StorageBuffer +%14 = OpTypeStruct %4 +%15 = OpTypePointer StorageBuffer %14 +%13 = OpVariable %15 StorageBuffer +%17 = OpTypeStruct %4 +%18 = OpTypePointer StorageBuffer %17 +%16 = OpVariable %18 StorageBuffer +%21 = OpTypeFunction %2 +%22 = OpTypePointer StorageBuffer %3 +%24 = OpTypeInt 32 0 +%23 = OpConstant %24 0 +%26 = OpTypePointer StorageBuffer %4 +%29 = OpConstantComposite %5 %8 %9 +%30 = OpConstant %4 3 +%31 = OpConstant %4 4 +%32 = OpConstantComposite %5 %30 %31 +%33 = OpConstantComposite %3 %31 %30 %9 %8 +%34 = OpConstant %4 6 +%35 = OpConstant %4 7 +%36 = OpConstantComposite %5 %34 %35 +%37 = OpConstant %4 8 +%38 = OpConstantComposite %6 %36 %37 +%39 = OpConstant %4 9 +%20 = OpFunction %2 None %21 +%19 = OpLabel +%25 = OpAccessChain %22 %10 %23 +%27 = OpAccessChain %26 %13 %23 +%28 = OpAccessChain %26 %16 %23 +OpBranch %40 +%40 = OpLabel +OpStore %25 %33 +OpStore %27 %9 +OpStore %28 %34 OpReturn OpFunctionEnd \ No newline at end of file diff --git a/tests/out/wgsl/const-exprs.wgsl b/tests/out/wgsl/const-exprs.wgsl index b58339b8af..201535836e 100644 --- a/tests/out/wgsl/const-exprs.wgsl +++ b/tests/out/wgsl/const-exprs.wgsl @@ -2,6 +2,8 @@ var out: vec4; @group(0) @binding(1) var out2_: i32; +@group(0) @binding(2) +var out3_: i32; @compute @workgroup_size(1, 1, 1) fn main() { @@ -9,5 +11,6 @@ fn main() { let b = vec2(3, 4); out = vec4(4, 3, 2, 1); out2_ = 2; + out3_ = 6; return; }