diff --git a/src/lib.rs b/src/lib.rs index e7abe1e1c5..563325a7d2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -854,7 +854,9 @@ pub enum TypeInner { #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub enum Literal { + /// May not be NaN or infinity. F64(f64), + /// May not be NaN or infinity. F32(f32), U32(u32), I32(i32), diff --git a/src/valid/expression.rs b/src/valid/expression.rs index 3426bf008e..0b61253682 100644 --- a/src/valid/expression.rs +++ b/src/valid/expression.rs @@ -124,6 +124,29 @@ pub enum ExpressionError { InvalidWorkGroupUniformLoadResultType(Handle), #[error("Shader requires capability {0:?}")] MissingCapabilities(super::Capabilities), + #[error(transparent)] + Literal(#[from] LiteralError), +} + +#[derive(Clone, Debug, thiserror::Error)] +pub enum ConstExpressionError { + #[error("The expression is not a constant expression")] + NonConst, + #[error(transparent)] + Compose(#[from] super::ComposeError), + #[error("Type resolution failed")] + Type(#[from] ResolveError), + #[error(transparent)] + Literal(#[from] LiteralError), +} + +#[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] +pub enum LiteralError { + #[error("Float literal is NaN")] + NaN, + #[error("Float literal is infinite")] + Infinity, } #[cfg(feature = "validate")] @@ -158,11 +181,14 @@ impl super::Validator { handle: Handle, gctx: crate::proc::GlobalCtx, mod_info: &ModuleInfo, - ) -> Result<(), super::ConstExpressionError> { + ) -> Result<(), ConstExpressionError> { use crate::Expression as E; match gctx.const_expressions[handle] { - E::Literal(_) | E::Constant(_) | E::ZeroValue(_) => {} + E::Literal(literal) => { + validate_literal(literal)?; + } + E::Constant(_) | E::ZeroValue(_) => {} E::Compose { ref components, ty } => { validate_compose( ty, @@ -310,7 +336,11 @@ impl super::Validator { } ShaderStages::all() } - E::Literal(_) | E::Constant(_) | E::ZeroValue(_) => ShaderStages::all(), + E::Literal(literal) => { + validate_literal(literal)?; + ShaderStages::all() + } + E::Constant(_) | E::ZeroValue(_) => ShaderStages::all(), E::Compose { ref components, ty } => { validate_compose( ty, @@ -1529,3 +1559,25 @@ impl super::Validator { } } } + +fn validate_literal(literal: crate::Literal) -> Result<(), LiteralError> { + let is_nan = match literal { + crate::Literal::F64(v) => v.is_nan(), + crate::Literal::F32(v) => v.is_nan(), + _ => false, + }; + if is_nan { + return Err(LiteralError::NaN); + } + + let is_infinite = match literal { + crate::Literal::F64(v) => v.is_infinite(), + crate::Literal::F32(v) => v.is_infinite(), + _ => false, + }; + if is_infinite { + return Err(LiteralError::Infinity); + } + + Ok(()) +} diff --git a/src/valid/mod.rs b/src/valid/mod.rs index f99a2055ce..6175aa0945 100644 --- a/src/valid/mod.rs +++ b/src/valid/mod.rs @@ -24,7 +24,7 @@ use std::ops; use crate::span::{AddSpan as _, WithSpan}; pub use analyzer::{ExpressionInfo, FunctionInfo, GlobalUse, Uniformity, UniformityRequirements}; pub use compose::ComposeError; -pub use expression::ExpressionError; +pub use expression::{ConstExpressionError, ExpressionError}; pub use function::{CallError, FunctionError, LocalVariableError}; pub use interface::{EntryPointError, GlobalVariableError, VaryingError}; pub use r#type::{Disalignment, TypeError, TypeFlags}; @@ -180,16 +180,6 @@ pub struct Validator { valid_expression_set: BitSet, } -#[derive(Clone, Debug, thiserror::Error)] -pub enum ConstExpressionError { - #[error("The expression is not a constant expression")] - NonConst, - #[error(transparent)] - Compose(#[from] ComposeError), - #[error("Type resolution failed")] - Type(#[from] crate::proc::ResolveError), -} - #[derive(Clone, Debug, thiserror::Error)] pub enum ConstantError { #[error("The type doesn't match the constant")]