Skip to content

Commit

Permalink
[const-eval] error on NaN and infinite floats
Browse files Browse the repository at this point in the history
  • Loading branch information
teoxoy committed Oct 9, 2023
1 parent f29e5e4 commit 3ba1ad8
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 30 deletions.
67 changes: 38 additions & 29 deletions src/proc/constant_evaluator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ impl ExpressionConstnessTracker {
}
}

#[derive(Clone, Debug, PartialEq, thiserror::Error)]
#[derive(Clone, Debug, thiserror::Error)]
pub enum ConstantEvaluatorError {
#[error("Constants cannot access function arguments")]
FunctionArg,
Expand Down Expand Up @@ -144,6 +144,8 @@ pub enum ConstantEvaluatorError {
RemainderByZero,
#[error("RHS of shift operation is greater than or equal to 32")]
ShiftedMoreThan32Bits,
#[error(transparent)]
Literal(#[from] crate::valid::LiteralError),
}

impl<'a> ConstantEvaluator<'a> {
Expand Down Expand Up @@ -266,18 +268,18 @@ impl<'a> ConstantEvaluator<'a> {
Ok(self.constants[c].init)
}
Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => {
Ok(self.register_evaluated_expr(expr.clone(), span))
self.register_evaluated_expr(expr.clone(), span)
}
Expression::Compose { ty, ref components } => {
let components = components
.iter()
.map(|component| self.check_and_get(*component))
.collect::<Result<Vec<_>, _>>()?;
Ok(self.register_evaluated_expr(Expression::Compose { ty, components }, span))
self.register_evaluated_expr(Expression::Compose { ty, components }, span)
}
Expression::Splat { size, value } => {
let value = self.check_and_get(value)?;
Ok(self.register_evaluated_expr(Expression::Splat { size, value }, span))
self.register_evaluated_expr(Expression::Splat { size, value }, span)
}
Expression::AccessIndex { base, index } => {
let base = self.check_and_get(base)?;
Expand Down Expand Up @@ -391,7 +393,7 @@ impl<'a> ConstantEvaluator<'a> {
ty,
components: vec![value; size as usize],
};
Ok(self.register_evaluated_expr(expr, span))
self.register_evaluated_expr(expr, span)
}
Expression::ZeroValue(ty) => {
let inner = match self.types[ty].inner {
Expand All @@ -400,7 +402,7 @@ impl<'a> ConstantEvaluator<'a> {
};
let res_ty = self.types.insert(Type { name: None, inner }, span);
let expr = Expression::ZeroValue(res_ty);
Ok(self.register_evaluated_expr(expr, span))
self.register_evaluated_expr(expr, span)
}
_ => Err(ConstantEvaluatorError::SplatScalarOnly),
}
Expand Down Expand Up @@ -432,11 +434,11 @@ impl<'a> ConstantEvaluator<'a> {
Expression::ZeroValue(ty) => {
let dst_ty = get_dst_ty(ty)?;
let expr = Expression::ZeroValue(dst_ty);
Ok(self.register_evaluated_expr(expr, span))
self.register_evaluated_expr(expr, span)
}
Expression::Splat { value, .. } => {
let expr = Expression::Splat { size, value };
Ok(self.register_evaluated_expr(expr, span))
self.register_evaluated_expr(expr, span)
}
Expression::Compose { ty, ref components } => {
let dst_ty = get_dst_ty(ty)?;
Expand Down Expand Up @@ -464,7 +466,7 @@ impl<'a> ConstantEvaluator<'a> {
ty: dst_ty,
components: swizzled_components,
};
Ok(self.register_evaluated_expr(expr, span))
self.register_evaluated_expr(expr, span)
}
_ => Err(ConstantEvaluatorError::SwizzleVectorOnly),
}
Expand Down Expand Up @@ -561,7 +563,7 @@ impl<'a> ConstantEvaluator<'a> {
_ => return Err(ConstantEvaluatorError::InvalidMathArg),
};

Ok(self.register_evaluated_expr(expr, span))
self.register_evaluated_expr(expr, span)
}

fn math_clamp(
Expand Down Expand Up @@ -666,7 +668,7 @@ impl<'a> ConstantEvaluator<'a> {
_ => return Err(ConstantEvaluatorError::InvalidMathArg),
};

Ok(self.register_evaluated_expr(expr, span))
self.register_evaluated_expr(expr, span)
}

fn array_length(
Expand All @@ -680,7 +682,7 @@ impl<'a> ConstantEvaluator<'a> {
TypeInner::Array { size, .. } => match size {
crate::ArraySize::Constant(len) => {
let expr = Expression::Literal(Literal::U32(len.get()));
Ok(self.register_evaluated_expr(expr, span))
self.register_evaluated_expr(expr, span)
}
crate::ArraySize::Dynamic => {
Err(ConstantEvaluatorError::ArrayLengthDynamic)
Expand Down Expand Up @@ -718,7 +720,7 @@ impl<'a> ConstantEvaluator<'a> {
self.types.insert(Type { name: None, inner }, span)
}
};
Ok(self.register_evaluated_expr(Expression::ZeroValue(ty), span))
self.register_evaluated_expr(Expression::ZeroValue(ty), span)
}
}
Expression::Splat { size, value } => {
Expand Down Expand Up @@ -784,7 +786,7 @@ impl<'a> ConstantEvaluator<'a> {
Literal::zero(kind, width)
.ok_or(ConstantEvaluatorError::TypeNotConstructible)?,
);
Ok(self.register_evaluated_expr(expr, span))
self.register_evaluated_expr(expr, span)
}
TypeInner::Vector { size, kind, width } => {
let scalar_ty = self.types.insert(
Expand All @@ -799,7 +801,7 @@ impl<'a> ConstantEvaluator<'a> {
ty,
components: vec![el; size as usize],
};
Ok(self.register_evaluated_expr(expr, span))
self.register_evaluated_expr(expr, span)
}
TypeInner::Matrix {
columns,
Expand All @@ -822,7 +824,7 @@ impl<'a> ConstantEvaluator<'a> {
ty,
components: vec![el; columns as usize],
};
Ok(self.register_evaluated_expr(expr, span))
self.register_evaluated_expr(expr, span)
}
TypeInner::Array {
base,
Expand All @@ -834,7 +836,7 @@ impl<'a> ConstantEvaluator<'a> {
ty,
components: vec![el; size.get() as usize],
};
Ok(self.register_evaluated_expr(expr, span))
self.register_evaluated_expr(expr, span)
}
TypeInner::Struct { ref members, .. } => {
let types: Vec<_> = members.iter().map(|m| m.ty).collect();
Expand All @@ -843,7 +845,7 @@ impl<'a> ConstantEvaluator<'a> {
components.push(self.eval_zero_value_impl(ty, span)?);
}
let expr = Expression::Compose { ty, components };
Ok(self.register_evaluated_expr(expr, span))
self.register_evaluated_expr(expr, span)
}
_ => Err(ConstantEvaluatorError::TypeNotConstructible),
}
Expand Down Expand Up @@ -929,7 +931,7 @@ impl<'a> ConstantEvaluator<'a> {
_ => return Err(ConstantEvaluatorError::InvalidCastArg),
};

Ok(self.register_evaluated_expr(expr, span))
self.register_evaluated_expr(expr, span)
}

fn unary_op(
Expand Down Expand Up @@ -973,7 +975,7 @@ impl<'a> ConstantEvaluator<'a> {
_ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
};

Ok(self.register_evaluated_expr(expr, span))
self.register_evaluated_expr(expr, span)
}

fn binary_op(
Expand Down Expand Up @@ -1109,7 +1111,7 @@ impl<'a> ConstantEvaluator<'a> {
_ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
};

Ok(self.register_evaluated_expr(expr, span))
self.register_evaluated_expr(expr, span)
}

/// Deep copy `expr` from `expressions` into `self.expressions`.
Expand All @@ -1128,17 +1130,17 @@ impl<'a> ConstantEvaluator<'a> {
match expressions[expr] {
ref expr @ (Expression::Literal(_)
| Expression::Constant(_)
| Expression::ZeroValue(_)) => Ok(self.register_evaluated_expr(expr.clone(), span)),
| Expression::ZeroValue(_)) => self.register_evaluated_expr(expr.clone(), span),
Expression::Compose { ty, ref components } => {
let mut components = components.clone();
for component in &mut components {
*component = self.copy_from(*component, expressions)?;
}
Ok(self.register_evaluated_expr(Expression::Compose { ty, components }, span))
self.register_evaluated_expr(Expression::Compose { ty, components }, span)
}
Expression::Splat { size, value } => {
let value = self.copy_from(value, expressions)?;
Ok(self.register_evaluated_expr(Expression::Splat { size, value }, span))
self.register_evaluated_expr(Expression::Splat { size, value }, span)
}
_ => {
log::debug!("copy_from: SubexpressionsAreNotConstant");
Expand All @@ -1147,8 +1149,15 @@ impl<'a> ConstantEvaluator<'a> {
}
}

fn register_evaluated_expr(&mut self, expr: Expression, span: Span) -> Handle<Expression> {
// TODO: use the validate_literal function from https://github.com/gfx-rs/naga/pull/2508 here
fn register_evaluated_expr(
&mut self,
expr: Expression,
span: Span,
) -> Result<Handle<Expression>, ConstantEvaluatorError> {
match expr {
Expression::Literal(literal) => crate::valid::validate_literal(literal)?,
_ => {}
}

if let Some(FunctionLocalData {
ref mut emitter,
Expand All @@ -1164,14 +1173,14 @@ impl<'a> ConstantEvaluator<'a> {
let h = self.expressions.append(expr, span);
emitter.start(self.expressions);
expression_constness.insert(h);
h
Ok(h)
} else {
let h = self.expressions.append(expr, span);
expression_constness.insert(h);
h
Ok(h)
}
} else {
self.expressions.append(expr, span)
Ok(self.expressions.append(expr, span))
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/valid/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1568,7 +1568,7 @@ impl super::Validator {
}
}

fn validate_literal(literal: crate::Literal) -> Result<(), LiteralError> {
pub fn validate_literal(literal: crate::Literal) -> Result<(), LiteralError> {
let is_nan = match literal {
crate::Literal::F64(v) => v.is_nan(),
crate::Literal::F32(v) => v.is_nan(),
Expand Down
1 change: 1 addition & 0 deletions src/valid/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ use std::ops;
use crate::span::{AddSpan as _, WithSpan};
pub use analyzer::{ExpressionInfo, FunctionInfo, GlobalUse, Uniformity, UniformityRequirements};
pub use compose::ComposeError;
pub use expression::{validate_literal, LiteralError};
pub use expression::{ConstExpressionError, ExpressionError};
pub use function::{CallError, FunctionError, LocalVariableError};
pub use interface::{EntryPointError, GlobalVariableError, VaryingError};
Expand Down

0 comments on commit 3ba1ad8

Please sign in to comment.