Skip to content

Commit

Permalink
[const-eval] error on NaN and infinite floats
Browse files Browse the repository at this point in the history
  • Loading branch information
teoxoy committed Oct 10, 2023
1 parent 6167ef5 commit ee0fd1a
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 30 deletions.
69 changes: 40 additions & 29 deletions src/proc/constant_evaluator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ impl ExpressionConstnessTracker {
}
}

#[derive(Clone, Debug, PartialEq, thiserror::Error)]
#[derive(Clone, Debug, thiserror::Error)]
pub enum ConstantEvaluatorError {
#[error("Constants cannot access function arguments")]
FunctionArg,
Expand Down Expand Up @@ -144,6 +144,8 @@ pub enum ConstantEvaluatorError {
RemainderByZero,
#[error("RHS of shift operation is greater than or equal to 32")]
ShiftedMoreThan32Bits,
#[error(transparent)]
Literal(#[from] crate::valid::LiteralError),
}

impl<'a> ConstantEvaluator<'a> {
Expand Down Expand Up @@ -270,18 +272,18 @@ impl<'a> ConstantEvaluator<'a> {
Ok(self.constants[c].init)
}
Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => {
Ok(self.register_evaluated_expr(expr.clone(), span))
self.register_evaluated_expr(expr.clone(), span)
}
Expression::Compose { ty, ref components } => {
let components = components
.iter()
.map(|component| self.check_and_get(*component))
.collect::<Result<Vec<_>, _>>()?;
Ok(self.register_evaluated_expr(Expression::Compose { ty, components }, span))
self.register_evaluated_expr(Expression::Compose { ty, components }, span)
}
Expression::Splat { size, value } => {
let value = self.check_and_get(value)?;
Ok(self.register_evaluated_expr(Expression::Splat { size, value }, span))
self.register_evaluated_expr(Expression::Splat { size, value }, span)
}
Expression::AccessIndex { base, index } => {
let base = self.check_and_get(base)?;
Expand Down Expand Up @@ -395,7 +397,7 @@ impl<'a> ConstantEvaluator<'a> {
ty,
components: vec![value; size as usize],
};
Ok(self.register_evaluated_expr(expr, span))
self.register_evaluated_expr(expr, span)
}
Expression::ZeroValue(ty) => {
let inner = match self.types[ty].inner {
Expand All @@ -404,7 +406,7 @@ impl<'a> ConstantEvaluator<'a> {
};
let res_ty = self.types.insert(Type { name: None, inner }, span);
let expr = Expression::ZeroValue(res_ty);
Ok(self.register_evaluated_expr(expr, span))
self.register_evaluated_expr(expr, span)
}
_ => Err(ConstantEvaluatorError::SplatScalarOnly),
}
Expand Down Expand Up @@ -436,11 +438,11 @@ impl<'a> ConstantEvaluator<'a> {
Expression::ZeroValue(ty) => {
let dst_ty = get_dst_ty(ty)?;
let expr = Expression::ZeroValue(dst_ty);
Ok(self.register_evaluated_expr(expr, span))
self.register_evaluated_expr(expr, span)
}
Expression::Splat { value, .. } => {
let expr = Expression::Splat { size, value };
Ok(self.register_evaluated_expr(expr, span))
self.register_evaluated_expr(expr, span)
}
Expression::Compose { ty, ref components } => {
let dst_ty = get_dst_ty(ty)?;
Expand Down Expand Up @@ -468,7 +470,7 @@ impl<'a> ConstantEvaluator<'a> {
ty: dst_ty,
components: swizzled_components,
};
Ok(self.register_evaluated_expr(expr, span))
self.register_evaluated_expr(expr, span)
}
_ => Err(ConstantEvaluatorError::SwizzleVectorOnly),
}
Expand Down Expand Up @@ -565,7 +567,7 @@ impl<'a> ConstantEvaluator<'a> {
_ => return Err(ConstantEvaluatorError::InvalidMathArg),
};

Ok(self.register_evaluated_expr(expr, span))
self.register_evaluated_expr(expr, span)
}

fn math_clamp(
Expand Down Expand Up @@ -670,7 +672,7 @@ impl<'a> ConstantEvaluator<'a> {
_ => return Err(ConstantEvaluatorError::InvalidMathArg),
};

Ok(self.register_evaluated_expr(expr, span))
self.register_evaluated_expr(expr, span)
}

fn array_length(
Expand All @@ -684,7 +686,7 @@ impl<'a> ConstantEvaluator<'a> {
TypeInner::Array { size, .. } => match size {
crate::ArraySize::Constant(len) => {
let expr = Expression::Literal(Literal::U32(len.get()));
Ok(self.register_evaluated_expr(expr, span))
self.register_evaluated_expr(expr, span)
}
crate::ArraySize::Dynamic => {
Err(ConstantEvaluatorError::ArrayLengthDynamic)
Expand Down Expand Up @@ -722,7 +724,7 @@ impl<'a> ConstantEvaluator<'a> {
self.types.insert(Type { name: None, inner }, span)
}
};
Ok(self.register_evaluated_expr(Expression::ZeroValue(ty), span))
self.register_evaluated_expr(Expression::ZeroValue(ty), span)
}
}
Expression::Splat { size, value } => {
Expand Down Expand Up @@ -788,7 +790,7 @@ impl<'a> ConstantEvaluator<'a> {
Literal::zero(kind, width)
.ok_or(ConstantEvaluatorError::TypeNotConstructible)?,
);
Ok(self.register_evaluated_expr(expr, span))
self.register_evaluated_expr(expr, span)
}
TypeInner::Vector { size, kind, width } => {
let scalar_ty = self.types.insert(
Expand All @@ -803,7 +805,7 @@ impl<'a> ConstantEvaluator<'a> {
ty,
components: vec![el; size as usize],
};
Ok(self.register_evaluated_expr(expr, span))
self.register_evaluated_expr(expr, span)
}
TypeInner::Matrix {
columns,
Expand All @@ -826,7 +828,7 @@ impl<'a> ConstantEvaluator<'a> {
ty,
components: vec![el; columns as usize],
};
Ok(self.register_evaluated_expr(expr, span))
self.register_evaluated_expr(expr, span)
}
TypeInner::Array {
base,
Expand All @@ -838,7 +840,7 @@ impl<'a> ConstantEvaluator<'a> {
ty,
components: vec![el; size.get() as usize],
};
Ok(self.register_evaluated_expr(expr, span))
self.register_evaluated_expr(expr, span)
}
TypeInner::Struct { ref members, .. } => {
let types: Vec<_> = members.iter().map(|m| m.ty).collect();
Expand All @@ -847,7 +849,7 @@ impl<'a> ConstantEvaluator<'a> {
components.push(self.eval_zero_value_impl(ty, span)?);
}
let expr = Expression::Compose { ty, components };
Ok(self.register_evaluated_expr(expr, span))
self.register_evaluated_expr(expr, span)
}
_ => Err(ConstantEvaluatorError::TypeNotConstructible),
}
Expand Down Expand Up @@ -933,7 +935,7 @@ impl<'a> ConstantEvaluator<'a> {
_ => return Err(ConstantEvaluatorError::InvalidCastArg),
};

Ok(self.register_evaluated_expr(expr, span))
self.register_evaluated_expr(expr, span)
}

fn unary_op(
Expand Down Expand Up @@ -977,7 +979,7 @@ impl<'a> ConstantEvaluator<'a> {
_ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
};

Ok(self.register_evaluated_expr(expr, span))
self.register_evaluated_expr(expr, span)
}

fn binary_op(
Expand Down Expand Up @@ -1113,7 +1115,7 @@ impl<'a> ConstantEvaluator<'a> {
_ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
};

Ok(self.register_evaluated_expr(expr, span))
self.register_evaluated_expr(expr, span)
}

/// Deep copy `expr` from `expressions` into `self.expressions`.
Expand All @@ -1132,17 +1134,17 @@ impl<'a> ConstantEvaluator<'a> {
match expressions[expr] {
ref expr @ (Expression::Literal(_)
| Expression::Constant(_)
| Expression::ZeroValue(_)) => Ok(self.register_evaluated_expr(expr.clone(), span)),
| Expression::ZeroValue(_)) => self.register_evaluated_expr(expr.clone(), span),
Expression::Compose { ty, ref components } => {
let mut components = components.clone();
for component in &mut components {
*component = self.copy_from(*component, expressions)?;
}
Ok(self.register_evaluated_expr(Expression::Compose { ty, components }, span))
self.register_evaluated_expr(Expression::Compose { ty, components }, span)
}
Expression::Splat { size, value } => {
let value = self.copy_from(value, expressions)?;
Ok(self.register_evaluated_expr(Expression::Splat { size, value }, span))
self.register_evaluated_expr(Expression::Splat { size, value }, span)
}
_ => {
log::debug!("copy_from: SubexpressionsAreNotConstant");
Expand All @@ -1151,8 +1153,17 @@ impl<'a> ConstantEvaluator<'a> {
}
}

fn register_evaluated_expr(&mut self, expr: Expression, span: Span) -> Handle<Expression> {
// TODO: use the validate_literal function from https://github.com/gfx-rs/naga/pull/2508 here
fn register_evaluated_expr(
&mut self,
expr: Expression,
span: Span,
) -> Result<Handle<Expression>, ConstantEvaluatorError> {
// It suffices to only check literals, since we only register one
// expression at a time, `Compose` expressions can only refer to other
// expressions, and `ZeroValue` expressions are always okay.
if let Expression::Literal(literal) = expr {
crate::valid::validate_literal(literal)?;
}

if let Some(FunctionLocalData {
ref mut emitter,
Expand All @@ -1168,14 +1179,14 @@ impl<'a> ConstantEvaluator<'a> {
let h = self.expressions.append(expr, span);
emitter.start(self.expressions);
expression_constness.insert(h);
h
Ok(h)
} else {
let h = self.expressions.append(expr, span);
expression_constness.insert(h);
h
Ok(h)
}
} else {
self.expressions.append(expr, span)
Ok(self.expressions.append(expr, span))
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/valid/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1566,7 +1566,7 @@ impl super::Validator {
}
}

fn validate_literal(literal: crate::Literal) -> Result<(), LiteralError> {
pub fn validate_literal(literal: crate::Literal) -> Result<(), LiteralError> {
let is_nan = match literal {
crate::Literal::F64(v) => v.is_nan(),
crate::Literal::F32(v) => v.is_nan(),
Expand Down
1 change: 1 addition & 0 deletions src/valid/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ use std::ops;
use crate::span::{AddSpan as _, WithSpan};
pub use analyzer::{ExpressionInfo, FunctionInfo, GlobalUse, Uniformity, UniformityRequirements};
pub use compose::ComposeError;
pub use expression::{validate_literal, LiteralError};
pub use expression::{ConstExpressionError, ExpressionError};
pub use function::{CallError, FunctionError, LocalVariableError};
pub use interface::{EntryPointError, GlobalVariableError, VaryingError};
Expand Down

0 comments on commit ee0fd1a

Please sign in to comment.