Skip to content

Commit

Permalink
ConstantEvaluator::swizzle: Handle vector concatenation and indexing (#…
Browse files Browse the repository at this point in the history
…2485)

* ConstantEvaluator::swizzle: Handle vector concatenation, indexing.

* Handle vector Compose expressions nested two deep.

* Move `flatten_compose` to `proc`, and make it a free function.

* [spv-out] Ensure that we flatten Compose for OpConstantCompose.
  • Loading branch information
jimblandy committed Sep 20, 2023
1 parent 9668a1b commit 7ca5d3f
Show file tree
Hide file tree
Showing 12 changed files with 270 additions and 28 deletions.
27 changes: 15 additions & 12 deletions src/back/spv/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -243,21 +243,24 @@ impl<'w> BlockContext<'w> {
self.writer.constant_ids[init.index()]
}
crate::Expression::ZeroValue(_) => self.writer.get_constant_null(result_type_id),
crate::Expression::Compose {
ty: _,
ref components,
} => {
crate::Expression::Compose { ty, ref components } => {
self.temp_list.clear();
for &component in components {
self.temp_list.push(self.cached[component]);
}

if self.ir_function.expressions.is_const(expr_handle) {
let ty = self
.writer
.get_expression_lookup_type(&self.fun_info[expr_handle].ty);
self.writer.get_constant_composite(ty, &self.temp_list)
self.temp_list.extend(
crate::proc::flatten_compose(
ty,
components,
&self.ir_function.expressions,
&self.ir_module.types,
)
.map(|component| self.cached[component]),
);
self.writer
.get_constant_composite(LookupType::Handle(ty), &self.temp_list)
} else {
self.temp_list
.extend(components.iter().map(|&component| self.cached[component]));

let id = self.gen_id();
block.body.push(Instruction::composite_construct(
result_type_id,
Expand Down
12 changes: 8 additions & 4 deletions src/back/spv/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1269,10 +1269,14 @@ impl Writer {
self.get_constant_null(type_id)
}
crate::Expression::Compose { ty, ref components } => {
let component_ids: Vec<_> = components
.iter()
.map(|component| self.constant_ids[component.index()])
.collect();
let component_ids: Vec<_> = crate::proc::flatten_compose(
ty,
components,
&ir_module.const_expressions,
&ir_module.types,
)
.map(|component| self.constant_ids[component.index()])
.collect();
self.get_constant_composite(LookupType::Handle(ty), component_ids.as_slice())
}
crate::Expression::Splat { size, value } => {
Expand Down
36 changes: 24 additions & 12 deletions src/proc/constant_evaluator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ pub enum ConstantEvaluatorError {
SplatScalarOnly,
#[error("Can only swizzle vector constants")]
SwizzleVectorOnly,
#[error("swizzle component not present in source expression")]
SwizzleOutOfBounds,
#[error("Type is not constructible")]
TypeNotConstructible,
#[error("Subexpression(s) are not constant")]
Expand Down Expand Up @@ -306,20 +308,31 @@ impl ConstantEvaluator<'_> {
let expr = Expression::Splat { size, value };
Ok(self.register_evaluated_expr(expr, span))
}
Expression::Compose {
ty,
components: ref src_components,
} => {
Expression::Compose { ty, ref components } => {
let dst_ty = get_dst_ty(ty)?;

let components = pattern
let mut flattened = [src_constant; 4]; // dummy value
let len =
crate::proc::flatten_compose(ty, components, self.expressions, self.types)
.zip(flattened.iter_mut())
.map(|(component, elt)| *elt = component)
.count();
let flattened = &flattened[..len];

let swizzled_components = pattern[..size as usize]
.iter()
.take(size as usize)
.map(|&sc| src_components[sc as usize])
.collect();
.map(|&sc| {
let sc = sc as usize;
if let Some(elt) = flattened.get(sc) {
Ok(*elt)
} else {
Err(ConstantEvaluatorError::SwizzleOutOfBounds)
}
})
.collect::<Result<Vec<Handle<Expression>>, _>>()?;
let expr = Expression::Compose {
ty: dst_ty,
components,
components: swizzled_components,
};
Ok(self.register_evaluated_expr(expr, span))
}
Expand Down Expand Up @@ -455,9 +468,8 @@ impl ConstantEvaluator<'_> {
.components()
.ok_or(ConstantEvaluatorError::InvalidAccessBase)?;

components
.get(index)
.copied()
crate::proc::flatten_compose(ty, components, self.expressions, self.types)
.nth(index)
.ok_or(ConstantEvaluatorError::InvalidAccessIndex)
}
_ => Err(ConstantEvaluatorError::InvalidAccessBase),
Expand Down
55 changes: 55 additions & 0 deletions src/proc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -655,6 +655,61 @@ impl GlobalCtx<'_> {
}
}

/// Return an iterator over the individual components assembled by a
/// `Compose` expression.
///
/// Given `ty` and `components` from an `Expression::Compose`, return an
/// iterator over the components of the resulting value.
///
/// Normally, this would just be an iterator over `components`. However,
/// `Compose` expressions can concatenate vectors, in which case the i'th
/// value being composed is not generally the i'th element of `components`.
/// This function consults `ty` to decide if this concatenation is occuring,
/// and returns an iterator that produces the components of the result of
/// the `Compose` expression in either case.
pub fn flatten_compose<'arenas>(
ty: crate::Handle<crate::Type>,
components: &'arenas [crate::Handle<crate::Expression>],
expressions: &'arenas crate::Arena<crate::Expression>,
types: &'arenas crate::UniqueArena<crate::Type>,
) -> impl Iterator<Item = crate::Handle<crate::Expression>> + 'arenas {
// Returning `impl Iterator` is a bit tricky. We may or may not want to
// flatten the components, but we have to settle on a single concrete
// type to return. The below is a single iterator chain that handles
// both the flattening and non-flattening cases.
let (size, is_vector) = if let crate::TypeInner::Vector { size, .. } = types[ty].inner {
(size as usize, true)
} else {
(components.len(), false)
};

fn flattener<'c>(
component: &'c crate::Handle<crate::Expression>,
is_vector: bool,
expressions: &'c crate::Arena<crate::Expression>,
) -> &'c [crate::Handle<crate::Expression>] {
if is_vector {
if let crate::Expression::Compose {
ty: _,
components: ref subcomponents,
} = expressions[*component]
{
return subcomponents;
}
}
std::slice::from_ref(component)
}

// Expressions like `vec4(vec3(vec2(6, 7), 8), 9)` require us to flatten
// two levels.
components
.iter()
.flat_map(move |component| flattener(component, is_vector, expressions))
.flat_map(move |component| flattener(component, is_vector, expressions))
.take(size)
.cloned()
}

#[test]
fn test_matrix_size() {
let module = crate::Module::default();
Expand Down
14 changes: 14 additions & 0 deletions tests/in/const-exprs.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
@group(0) @binding(0) var<storage, read_write> out: vec4<i32>;
@group(0) @binding(1) var<storage, read_write> out2: i32;
@group(0) @binding(2) var<storage, read_write> out3: i32;

@compute @workgroup_size(1)
fn main() {
let a = vec2(1, 2);
let b = vec2(3, 4);
out = vec4(a, b).wzyx;

out2 = vec4(a, b)[1];

out3 = vec4(vec3(vec2(6, 7), 8), 9)[0];
}
23 changes: 23 additions & 0 deletions tests/out/glsl/const-exprs.main.Compute.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#version 310 es

precision highp float;
precision highp int;

layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;

layout(std430) buffer type_block_0Compute { ivec4 _group_0_binding_0_cs; };

layout(std430) buffer type_1_block_1Compute { int _group_0_binding_1_cs; };

layout(std430) buffer type_1_block_2Compute { int _group_0_binding_2_cs; };


void main() {
ivec2 a = ivec2(1, 2);
ivec2 b = ivec2(3, 4);
_group_0_binding_0_cs = ivec4(4, 3, 2, 1);
_group_0_binding_1_cs = 2;
_group_0_binding_2_cs = 6;
return;
}

14 changes: 14 additions & 0 deletions tests/out/hlsl/const-exprs.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
RWByteAddressBuffer out_ : register(u0);
RWByteAddressBuffer out2_ : register(u1);
RWByteAddressBuffer out3_ : register(u2);

[numthreads(1, 1, 1)]
void main()
{
int2 a = int2(1, 2);
int2 b = int2(3, 4);
out_.Store4(0, asuint(int4(4, 3, 2, 1)));
out2_.Store(0, asuint(2));
out3_.Store(0, asuint(6));
return;
}
12 changes: 12 additions & 0 deletions tests/out/hlsl/const-exprs.ron
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
(
vertex:[
],
fragment:[
],
compute:[
(
entry_point:"main",
target_profile:"cs_5_1",
),
],
)
19 changes: 19 additions & 0 deletions tests/out/msl/const-exprs.msl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// language: metal2.0
#include <metal_stdlib>
#include <simd/simd.h>

using metal::uint;


kernel void main_(
device metal::int4& out [[user(fake0)]]
, device int& out2_ [[user(fake0)]]
, device int& out3_ [[user(fake0)]]
) {
metal::int2 a = metal::int2(1, 2);
metal::int2 b = metal::int2(3, 4);
out = metal::int4(4, 3, 2, 1);
out2_ = 2;
out3_ = 6;
return;
}
66 changes: 66 additions & 0 deletions tests/out/spv/const-exprs.spvasm
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
; SPIR-V
; Version: 1.1
; Generator: rspirv
; Bound: 40
OpCapability Shader
OpExtension "SPV_KHR_storage_buffer_storage_class"
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %17 "main"
OpExecutionMode %17 LocalSize 1 1 1
OpDecorate %7 DescriptorSet 0
OpDecorate %7 Binding 0
OpDecorate %8 Block
OpMemberDecorate %8 0 Offset 0
OpDecorate %10 DescriptorSet 0
OpDecorate %10 Binding 1
OpDecorate %11 Block
OpMemberDecorate %11 0 Offset 0
OpDecorate %13 DescriptorSet 0
OpDecorate %13 Binding 2
OpDecorate %14 Block
OpMemberDecorate %14 0 Offset 0
%2 = OpTypeVoid
%4 = OpTypeInt 32 1
%3 = OpTypeVector %4 4
%5 = OpTypeVector %4 2
%6 = OpTypeVector %4 3
%8 = OpTypeStruct %3
%9 = OpTypePointer StorageBuffer %8
%7 = OpVariable %9 StorageBuffer
%11 = OpTypeStruct %4
%12 = OpTypePointer StorageBuffer %11
%10 = OpVariable %12 StorageBuffer
%14 = OpTypeStruct %4
%15 = OpTypePointer StorageBuffer %14
%13 = OpVariable %15 StorageBuffer
%18 = OpTypeFunction %2
%19 = OpTypePointer StorageBuffer %3
%21 = OpTypeInt 32 0
%20 = OpConstant %21 0
%23 = OpTypePointer StorageBuffer %4
%26 = OpConstant %4 1
%27 = OpConstant %4 2
%28 = OpConstantComposite %5 %26 %27
%29 = OpConstant %4 3
%30 = OpConstant %4 4
%31 = OpConstantComposite %5 %29 %30
%32 = OpConstantComposite %3 %30 %29 %27 %26
%33 = OpConstant %4 6
%34 = OpConstant %4 7
%35 = OpConstantComposite %5 %33 %34
%36 = OpConstant %4 8
%37 = OpConstantComposite %6 %33 %34 %36
%38 = OpConstant %4 9
%17 = OpFunction %2 None %18
%16 = OpLabel
%22 = OpAccessChain %19 %7 %20
%24 = OpAccessChain %23 %10 %20
%25 = OpAccessChain %23 %13 %20
OpBranch %39
%39 = OpLabel
OpStore %22 %32
OpStore %24 %27
OpStore %25 %33
OpReturn
OpFunctionEnd
16 changes: 16 additions & 0 deletions tests/out/wgsl/const-exprs.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
@group(0) @binding(0)
var<storage, read_write> out: vec4<i32>;
@group(0) @binding(1)
var<storage, read_write> out2_: i32;
@group(0) @binding(2)
var<storage, read_write> out3_: i32;

@compute @workgroup_size(1, 1, 1)
fn main() {
let a = vec2<i32>(1, 2);
let b = vec2<i32>(3, 4);
out = vec4<i32>(4, 3, 2, 1);
out2_ = 2;
out3_ = 6;
return;
}
4 changes: 4 additions & 0 deletions tests/snapshots.rs
Original file line number Diff line number Diff line change
Expand Up @@ -642,6 +642,10 @@ fn convert_wgsl() {
"constructors",
Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL,
),
(
"const-exprs",
Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL,
),
];

for &(name, targets) in inputs.iter() {
Expand Down

0 comments on commit 7ca5d3f

Please sign in to comment.