From 4dba950dd1e013c92bde76de245c88dbf68f6b74 Mon Sep 17 00:00:00 2001 From: Alec Edgington <54802828+cqc-alec@users.noreply.github.com> Date: Fri, 17 May 2024 14:50:18 +0100 Subject: [PATCH] fix: Disallow non-finite values for `ConstF64` (#1075) Fixes #1049 . Other solutions (custom serialization/deserialization for `ConstF64`, or custom deserializer taking "null" to NaN) are difficult or otherwise problematic, and it may actually be a good thing to disallow non-finite float constants as it can catch bugs earlier. --- hugr/src/algorithm/const_fold.rs | 31 +++++++++++++++++++ hugr/src/hugr/serialize/test.rs | 4 --- .../arithmetic/float_ops/const_fold.rs | 13 ++++++-- .../std_extensions/arithmetic/float_types.rs | 6 +++- specification/hugr.md | 3 ++ 5 files changed, 49 insertions(+), 8 deletions(-) diff --git a/hugr/src/algorithm/const_fold.rs b/hugr/src/algorithm/const_fold.rs index f2b0dab24..f0d66e6f6 100644 --- a/hugr/src/algorithm/const_fold.rs +++ b/hugr/src/algorithm/const_fold.rs @@ -512,4 +512,35 @@ mod test { let expected = Value::true_val(); assert_fully_folded(&h, &expected); } + + #[test] + fn test_const_fold_to_nonfinite() { + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::float_types::EXTENSION.to_owned(), + ]) + .unwrap(); + + // HUGR computing 1.0 / 1.0 + let mut build = + DFGBuilder::new(FunctionType::new(type_row![], vec![FLOAT64_TYPE])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstF64::new(1.0))); + let x1 = build.add_load_const(Value::extension(ConstF64::new(1.0))); + let x2 = build.add_dataflow_op(FloatOps::fdiv, [x0, x1]).unwrap(); + let mut h0 = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h0, ®); + let expected = Value::extension(ConstF64::new(1.0)); + assert_fully_folded(&h0, &expected); + assert_eq!(h0.node_count(), 5); + + // HUGR computing 1.0 / 0.0 + let mut build = + DFGBuilder::new(FunctionType::new(type_row![], vec![FLOAT64_TYPE])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstF64::new(1.0))); + let x1 = build.add_load_const(Value::extension(ConstF64::new(0.0))); + let x2 = build.add_dataflow_op(FloatOps::fdiv, [x0, x1]).unwrap(); + let mut h1 = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h1, ®); + assert_eq!(h1.node_count(), 8); + } } diff --git a/hugr/src/hugr/serialize/test.rs b/hugr/src/hugr/serialize/test.rs index 069740459..12bfdb98f 100644 --- a/hugr/src/hugr/serialize/test.rs +++ b/hugr/src/hugr/serialize/test.rs @@ -404,10 +404,6 @@ fn roundtrip_sumtype(#[case] sum_type: SumType) { #[case(Value::extension(ConstF64::new(-1.5)))] #[case(Value::extension(ConstF64::new(0.0)))] #[case(Value::extension(ConstF64::new(-0.0)))] -// These cases fail -// #[case(Value::extension(ConstF64::new(std::f64::NAN)))] -// #[case(Value::extension(ConstF64::new(std::f64::INFINITY)))] -// #[case(Value::extension(ConstF64::new(std::f64::NEG_INFINITY)))] #[case(Value::extension(ConstF64::new(f64::MIN_POSITIVE)))] #[case(Value::sum(1,[Value::extension(ConstInt::new_u(2,1).unwrap())], SumType::new([vec![], vec![INT_TYPES[2].clone()]])).unwrap())] #[case(Value::tuple([Value::false_val(), Value::extension(ConstInt::new_s(2,1).unwrap())]))] diff --git a/hugr/src/std_extensions/arithmetic/float_ops/const_fold.rs b/hugr/src/std_extensions/arithmetic/float_ops/const_fold.rs index 0323bcfe6..a97a2d1c7 100644 --- a/hugr/src/std_extensions/arithmetic/float_ops/const_fold.rs +++ b/hugr/src/std_extensions/arithmetic/float_ops/const_fold.rs @@ -54,8 +54,11 @@ impl ConstFold for BinaryFold { consts: &[(IncomingPort, ops::Value)], ) -> ConstFoldResult { let [f1, f2] = get_floats(consts)?; - - let res = ConstF64::new((self.0)(f1, f2)); + let x: f64 = (self.0)(f1, f2); + if !x.is_finite() { + return None; + } + let res = ConstF64::new(x); Some(vec![(0.into(), res.into())]) } } @@ -115,7 +118,11 @@ impl ConstFold for UnaryFold { consts: &[(IncomingPort, ops::Value)], ) -> ConstFoldResult { let [f1] = get_floats(consts)?; - let res = ConstF64::new((self.0)(f1)); + let x: f64 = (self.0)(f1); + if !x.is_finite() { + return None; + } + let res = ConstF64::new(x); Some(vec![(0.into(), res.into())]) } } diff --git a/hugr/src/std_extensions/arithmetic/float_types.rs b/hugr/src/std_extensions/arithmetic/float_types.rs index b663f136f..713e2e030 100644 --- a/hugr/src/std_extensions/arithmetic/float_types.rs +++ b/hugr/src/std_extensions/arithmetic/float_types.rs @@ -40,7 +40,11 @@ impl std::ops::Deref for ConstF64 { impl ConstF64 { /// Create a new [`ConstF64`] - pub const fn new(value: f64) -> Self { + pub fn new(value: f64) -> Self { + // This function can't be `const` because `is_finite()` is not yet stable as a const function. + if !value.is_finite() { + panic!("ConstF64 must have a finite value."); + } Self { value } } diff --git a/specification/hugr.md b/specification/hugr.md index ace4410bd..0bc912266 100644 --- a/specification/hugr.md +++ b/specification/hugr.md @@ -1782,6 +1782,9 @@ Other operations: The `float64` type represents IEEE 754-2019 floating-point data of 64 bits. +Non-finite `float64` values (i.e. NaN and ±infinity) are not allowed in `Const` +nodes. + #### `arithmetic.float` Floating-point operations are defined as follows. All operations below