Skip to content

Commit

Permalink
Let ConstantEvaluator see through Constant exprs in Splat exprs.
Browse files Browse the repository at this point in the history
  • Loading branch information
jimblandy authored and teoxoy committed Sep 29, 2023
1 parent a6d9fcd commit deb7c59
Show file tree
Hide file tree
Showing 6 changed files with 113 additions and 32 deletions.
89 changes: 86 additions & 3 deletions src/proc/constant_evaluator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -217,9 +217,9 @@ impl<'a> ConstantEvaluator<'a> {
.collect::<Result<Vec<_>, _>>()?;
Ok(self.register_evaluated_expr(Expression::Compose { ty, components }, span))
}
Expression::Splat { value, .. } => {
self.check(value)?;
Ok(self.register_evaluated_expr(expr.clone(), span))
Expression::Splat { size, value } => {
let value = self.check_and_get(value)?;
Ok(self.register_evaluated_expr(Expression::Splat { size, value }, span))
}
Expression::AccessIndex { base, index } => {
let base = self.check_and_get(base)?;
Expand Down Expand Up @@ -1424,4 +1424,87 @@ mod tests {
panic!("unexpected evaluation result")
}
}

#[test]
fn splat_of_constant() {
let mut types = UniqueArena::new();
let mut constants = Arena::new();
let mut const_expressions = Arena::new();

let i32_ty = types.insert(
Type {
name: None,
inner: TypeInner::Scalar {
kind: ScalarKind::Sint,
width: 4,
},
},
Default::default(),
);

let vec2_i32_ty = types.insert(
Type {
name: None,
inner: TypeInner::Vector {
size: VectorSize::Bi,
kind: ScalarKind::Sint,
width: 4,
},
},
Default::default(),
);

let h = constants.append(
Constant {
name: None,
r#override: crate::Override::None,
ty: i32_ty,
init: const_expressions
.append(Expression::Literal(Literal::I32(4)), Default::default()),
},
Default::default(),
);

let h_expr = const_expressions.append(Expression::Constant(h), Default::default());

let mut solver = ConstantEvaluator {
types: &mut types,
constants: &constants,
expressions: &mut const_expressions,
function_local_data: None,
};

let solved_compose = solver
.try_eval_and_append(
&Expression::Splat {
size: VectorSize::Bi,
value: h_expr,
},
Default::default(),
)
.unwrap();
let solved_negate = solver
.try_eval_and_append(
&Expression::Unary {
op: UnaryOperator::Negate,
expr: solved_compose,
},
Default::default(),
)
.unwrap();

let pass = match const_expressions[solved_negate] {
Expression::Compose { ty, ref components } => {
ty == vec2_i32_ty
&& components.iter().all(|&component| {
let component = &const_expressions[component];
matches!(*component, Expression::Literal(Literal::I32(-4)))
})
}
_ => false,
};
if !pass {
panic!("unexpected evaluation result")
}
}
}
2 changes: 1 addition & 1 deletion tests/out/glsl/const-exprs.main.Compute.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ void non_constant_initializers() {
}

void splat_of_constant() {
_group_0_binding_0_cs = -(ivec4(FOUR));
_group_0_binding_0_cs = ivec4(-4, -4, -4, -4);
return;
}

Expand Down
2 changes: 1 addition & 1 deletion tests/out/hlsl/const-exprs.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ void non_constant_initializers()

void splat_of_constant()
{
out_.Store4(0, asuint(-((FOUR).xxxx)));
out_.Store4(0, asuint(int4(-4, -4, -4, -4)));
return;
}

Expand Down
2 changes: 1 addition & 1 deletion tests/out/msl/const-exprs.msl
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ void non_constant_initializers(
void splat_of_constant(
device metal::int4& out
) {
out = -(metal::int4(FOUR));
out = metal::int4(-4, -4, -4, -4);
return;
}

Expand Down
48 changes: 23 additions & 25 deletions tests/out/spv/const-exprs.spvasm
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
; SPIR-V
; Version: 1.1
; Generator: rspirv
; Bound: 92
; Bound: 90
OpCapability Shader
OpExtension "SPV_KHR_storage_buffer_storage_class"
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %82 "main"
OpExecutionMode %82 LocalSize 1 1 1
OpEntryPoint GLCompute %80 "main"
OpExecutionMode %80 LocalSize 1 1 1
OpDecorate %8 DescriptorSet 0
OpDecorate %8 Binding 0
OpDecorate %9 Block
Expand Down Expand Up @@ -50,9 +50,8 @@ OpMemberDecorate %12 0 Offset 0
%53 = OpTypePointer Function %4
%55 = OpConstantNull %4
%57 = OpConstantNull %4
%72 = OpConstantComposite %3 %7 %7 %7 %7
%78 = OpConstant %4 -4
%79 = OpConstantComposite %3 %78 %78 %78 %78
%72 = OpConstant %4 -4
%73 = OpConstantComposite %3 %72 %72 %72 %72
%15 = OpFunction %2 None %16
%14 = OpLabel
%20 = OpAccessChain %17 %8 %18
Expand Down Expand Up @@ -107,31 +106,30 @@ OpFunctionEnd
%70 = OpFunction %2 None %16
%69 = OpLabel
%71 = OpAccessChain %17 %8 %18
OpBranch %73
%73 = OpLabel
%74 = OpSNegate %3 %72
OpStore %71 %74
OpBranch %74
%74 = OpLabel
OpStore %71 %73
OpReturn
OpFunctionEnd
%76 = OpFunction %2 None %16
%75 = OpLabel
%77 = OpAccessChain %17 %8 %18
OpBranch %80
%80 = OpLabel
OpStore %77 %79
OpBranch %78
%78 = OpLabel
OpStore %77 %73
OpReturn
OpFunctionEnd
%82 = OpFunction %2 None %16
%81 = OpLabel
%83 = OpAccessChain %17 %8 %18
%84 = OpAccessChain %30 %11 %18
OpBranch %85
%85 = OpLabel
%86 = OpFunctionCall %2 %15
%87 = OpFunctionCall %2 %29
%88 = OpFunctionCall %2 %36
%89 = OpFunctionCall %2 %48
%90 = OpFunctionCall %2 %70
%91 = OpFunctionCall %2 %76
%80 = OpFunction %2 None %16
%79 = OpLabel
%81 = OpAccessChain %17 %8 %18
%82 = OpAccessChain %30 %11 %18
OpBranch %83
%83 = OpLabel
%84 = OpFunctionCall %2 %15
%85 = OpFunctionCall %2 %29
%86 = OpFunctionCall %2 %36
%87 = OpFunctionCall %2 %48
%88 = OpFunctionCall %2 %70
%89 = OpFunctionCall %2 %76
OpReturn
OpFunctionEnd
2 changes: 1 addition & 1 deletion tests/out/wgsl/const-exprs.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ fn non_constant_initializers() {
}

fn splat_of_constant() {
out = -(vec4(FOUR));
out = vec4<i32>(-4, -4, -4, -4);
return;
}

Expand Down

0 comments on commit deb7c59

Please sign in to comment.