diff --git a/src/proc/constant_evaluator.rs b/src/proc/constant_evaluator.rs index de727da63c..f2352a12e1 100644 --- a/src/proc/constant_evaluator.rs +++ b/src/proc/constant_evaluator.rs @@ -128,6 +128,8 @@ pub enum ConstantEvaluatorError { InvalidMathArg, #[error("{0:?} built-in function expects {1:?} arguments but {2:?} were supplied")] InvalidMathArgCount(crate::MathFunction, usize, usize), + #[error("value of `low` is greater than `high` for clamp built-in function")] + InvalidClamp, #[error("Splat is defined only on scalar values")] SplatScalarOnly, #[error("Can only swizzle vector constants")] @@ -497,60 +499,181 @@ impl<'a> ConstantEvaluator<'a> { )); } - let const0 = &self.expressions[arg]; - let const1 = arg1.map(|arg| &self.expressions[arg]); - let const2 = arg2.map(|arg| &self.expressions[arg]); - let _const3 = arg3.map(|arg| &self.expressions[arg]); - match fun { - crate::MathFunction::Pow => { - let literal = match (const0, const1.unwrap()) { - (&Expression::Literal(value0), &Expression::Literal(value1)) => { - match (value0, value1) { - (Literal::I32(a), Literal::I32(b)) => Literal::I32(a.pow(b as u32)), - (Literal::U32(a), Literal::U32(b)) => Literal::U32(a.pow(b)), - (Literal::F32(a), Literal::F32(b)) => Literal::F32(a.powf(b)), - _ => return Err(ConstantEvaluatorError::InvalidMathArg), - } + crate::MathFunction::Pow => self.math_pow(arg, arg1.unwrap(), span), + crate::MathFunction::Clamp => self.math_clamp(arg, arg1.unwrap(), arg2.unwrap(), span), + fun => Err(ConstantEvaluatorError::NotImplemented(format!( + "{fun:?} built-in function" + ))), + } + } + + fn math_pow( + &mut self, + e1: Handle, + e2: Handle, + span: Span, + ) -> Result, ConstantEvaluatorError> { + let e1 = self.eval_zero_value_and_splat(e1, span)?; + let e2 = self.eval_zero_value_and_splat(e2, span)?; + + let expr = match (&self.expressions[e1], &self.expressions[e2]) { + (&Expression::Literal(Literal::F32(a)), &Expression::Literal(Literal::F32(b))) => { + Expression::Literal(Literal::F32(a.powf(b))) + } + ( + &Expression::Compose { + components: ref src_components0, + ty: ty0, + }, + &Expression::Compose { + components: ref src_components1, + ty: ty1, + }, + ) if ty0 == ty1 + && matches!( + self.types[ty0].inner, + crate::TypeInner::Vector { + kind: crate::ScalarKind::Float, + .. } - _ => return Err(ConstantEvaluatorError::InvalidMathArg), - }; + ) => + { + let mut components: Vec<_> = crate::proc::flatten_compose( + ty0, + src_components0, + self.expressions, + self.types, + ) + .chain(crate::proc::flatten_compose( + ty1, + src_components1, + self.expressions, + self.types, + )) + .collect(); + + let mid = components.len() / 2; + let (first, last) = components.split_at_mut(mid); + for (a, b) in first.iter_mut().zip(&*last) { + *a = self.math_pow(*a, *b, span)?; + } + components.drain(mid..); - let expr = Expression::Literal(literal); - Ok(self.register_evaluated_expr(expr, span)) + Expression::Compose { + ty: ty0, + components, + } } - crate::MathFunction::Clamp => { - let literal = match (const0, const1.unwrap(), const2.unwrap()) { - ( - &Expression::Literal(value0), - &Expression::Literal(value1), - &Expression::Literal(value2), - ) => match (value0, value1, value2) { - (Literal::I32(a), Literal::I32(b), Literal::I32(c)) => { - Literal::I32(a.clamp(b, c)) + _ => return Err(ConstantEvaluatorError::InvalidMathArg), + }; + + Ok(self.register_evaluated_expr(expr, span)) + } + + fn math_clamp( + &mut self, + e: Handle, + low: Handle, + high: Handle, + span: Span, + ) -> Result, ConstantEvaluatorError> { + let e = self.eval_zero_value_and_splat(e, span)?; + let low = self.eval_zero_value_and_splat(low, span)?; + let high = self.eval_zero_value_and_splat(high, span)?; + + let expr = match ( + &self.expressions[e], + &self.expressions[low], + &self.expressions[high], + ) { + (&Expression::Literal(e), &Expression::Literal(low), &Expression::Literal(high)) => { + let literal = match (e, low, high) { + (Literal::I32(e), Literal::I32(low), Literal::I32(high)) => { + if low > high { + return Err(ConstantEvaluatorError::InvalidClamp); + } else { + Literal::I32(e.clamp(low, high)) } - (Literal::U32(a), Literal::U32(b), Literal::U32(c)) => { - Literal::U32(a.clamp(b, c)) + } + (Literal::U32(e), Literal::U32(low), Literal::U32(high)) => { + if low > high { + return Err(ConstantEvaluatorError::InvalidClamp); + } else { + Literal::U32(e.clamp(low, high)) } - (Literal::F32(a), Literal::F32(b), Literal::F32(c)) => { - Literal::F32(glsl_float_clamp(a, b, c)) + } + (Literal::F32(e), Literal::F32(low), Literal::F32(high)) => { + if low > high { + return Err(ConstantEvaluatorError::InvalidClamp); + } else { + Literal::F32(e.clamp(low, high)) } - _ => return Err(ConstantEvaluatorError::InvalidMathArg), - }, - _ => { - return Err(ConstantEvaluatorError::NotImplemented( - "clamp built-in function with vector values".into(), - )) } + _ => return Err(ConstantEvaluatorError::InvalidMathArg), }; + Expression::Literal(literal) + } + ( + &Expression::Compose { + components: ref src_components0, + ty: ty0, + }, + &Expression::Compose { + components: ref src_components1, + ty: ty1, + }, + &Expression::Compose { + components: ref src_components2, + ty: ty2, + }, + ) if ty0 == ty1 + && ty0 == ty2 + && matches!( + self.types[ty0].inner, + crate::TypeInner::Vector { + kind: crate::ScalarKind::Float, + .. + } + ) => + { + let mut components: Vec<_> = crate::proc::flatten_compose( + ty0, + src_components0, + self.expressions, + self.types, + ) + .chain(crate::proc::flatten_compose( + ty1, + src_components1, + self.expressions, + self.types, + )) + .chain(crate::proc::flatten_compose( + ty2, + src_components2, + self.expressions, + self.types, + )) + .collect(); + + let chunk_size = components.len() / 3; + let (es, rem) = components.split_at_mut(chunk_size); + let (lows, highs) = rem.split_at(chunk_size); + for ((e, low), high) in es.iter_mut().zip(lows).zip(highs) { + *e = self.math_clamp(*e, *low, *high, span)?; + } + components.drain(chunk_size..); - let expr = Expression::Literal(literal); - Ok(self.register_evaluated_expr(expr, span)) + Expression::Compose { + ty: ty0, + components, + } } - fun => Err(ConstantEvaluatorError::NotImplemented(format!( - "{fun:?} built-in function" - ))), - } + _ => return Err(ConstantEvaluatorError::InvalidMathArg), + }; + + Ok(self.register_evaluated_expr(expr, span)) } fn array_length( @@ -996,6 +1119,8 @@ impl<'a> ConstantEvaluator<'a> { } fn register_evaluated_expr(&mut self, expr: Expression, span: Span) -> Handle { + // TODO: use the validate_literal function from https://github.com/gfx-rs/naga/pull/2508 here + if let Some(FunctionLocalData { ref mut emitter, ref mut block, @@ -1022,57 +1147,6 @@ impl<'a> ConstantEvaluator<'a> { } } -/// Helper function to implement the GLSL `max` function for floats. -/// -/// While Rust does provide a `f64::max` method, it has a different behavior than the -/// GLSL `max` for NaNs. In Rust, if any of the arguments is a NaN, then the other -/// is returned. -/// -/// This leads to different results in the following example -/// ``` -/// use std::cmp::max; -/// std::f64::NAN.max(1.0); -/// ``` -/// -/// Rust will return `1.0` while GLSL should return NaN. -fn glsl_float_max(x: f32, y: f32) -> f32 { - if x < y { - y - } else { - x - } -} - -/// Helper function to implement the GLSL `min` function for floats. -/// -/// While Rust does provide a `f64::min` method, it has a different behavior than the -/// GLSL `min` for NaNs. In Rust, if any of the arguments is a NaN, then the other -/// is returned. -/// -/// This leads to different results in the following example -/// ``` -/// use std::cmp::min; -/// std::f64::NAN.min(1.0); -/// ``` -/// -/// Rust will return `1.0` while GLSL should return NaN. -fn glsl_float_min(x: f32, y: f32) -> f32 { - if y < x { - y - } else { - x - } -} - -/// Helper function to implement the GLSL `clamp` function for floats. -/// -/// While Rust does provide a `f64::clamp` method, it panics if either -/// `min` or `max` are `NaN`s which is not the behavior specified by -/// the glsl specification. -fn glsl_float_clamp(value: f32, min: f32, max: f32) -> f32 { - glsl_float_min(glsl_float_max(value, min), max) -} - #[cfg(test)] mod tests { use std::vec; @@ -1084,19 +1158,6 @@ mod tests { use super::{Behavior, ConstantEvaluator}; - #[test] - fn nan_handling() { - assert!(super::glsl_float_max(f32::NAN, 2.0).is_nan()); - assert!(!super::glsl_float_max(2.0, f32::NAN).is_nan()); - - assert!(super::glsl_float_min(f32::NAN, 2.0).is_nan()); - assert!(!super::glsl_float_min(2.0, f32::NAN).is_nan()); - - assert!(super::glsl_float_clamp(f32::NAN, 1.0, 2.0).is_nan()); - assert!(!super::glsl_float_clamp(1.0, f32::NAN, 2.0).is_nan()); - assert!(!super::glsl_float_clamp(1.0, 2.0, f32::NAN).is_nan()); - } - #[test] fn unary_op() { let mut types = UniqueArena::new();