Skip to content

Commit

Permalink
fix: Disallow non-finite values for ConstF64 (#1075)
Browse files Browse the repository at this point in the history
Fixes #1049 .

Other solutions (custom serialization/deserialization for `ConstF64`, or
custom deserializer taking "null" to NaN) are difficult or otherwise
problematic, and it may actually be a good thing to disallow non-finite
float constants as it can catch bugs earlier.
  • Loading branch information
cqc-alec authored May 17, 2024
1 parent 4d974cf commit 4dba950
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 8 deletions.
31 changes: 31 additions & 0 deletions hugr/src/algorithm/const_fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -512,4 +512,35 @@ mod test {
let expected = Value::true_val();
assert_fully_folded(&h, &expected);
}

#[test]
fn test_const_fold_to_nonfinite() {
let reg = ExtensionRegistry::try_new([
PRELUDE.to_owned(),
arithmetic::float_types::EXTENSION.to_owned(),
])
.unwrap();

// HUGR computing 1.0 / 1.0
let mut build =
DFGBuilder::new(FunctionType::new(type_row![], vec![FLOAT64_TYPE])).unwrap();
let x0 = build.add_load_const(Value::extension(ConstF64::new(1.0)));
let x1 = build.add_load_const(Value::extension(ConstF64::new(1.0)));
let x2 = build.add_dataflow_op(FloatOps::fdiv, [x0, x1]).unwrap();
let mut h0 = build.finish_hugr_with_outputs(x2.outputs(), &reg).unwrap();
constant_fold_pass(&mut h0, &reg);
let expected = Value::extension(ConstF64::new(1.0));
assert_fully_folded(&h0, &expected);
assert_eq!(h0.node_count(), 5);

// HUGR computing 1.0 / 0.0
let mut build =
DFGBuilder::new(FunctionType::new(type_row![], vec![FLOAT64_TYPE])).unwrap();
let x0 = build.add_load_const(Value::extension(ConstF64::new(1.0)));
let x1 = build.add_load_const(Value::extension(ConstF64::new(0.0)));
let x2 = build.add_dataflow_op(FloatOps::fdiv, [x0, x1]).unwrap();
let mut h1 = build.finish_hugr_with_outputs(x2.outputs(), &reg).unwrap();
constant_fold_pass(&mut h1, &reg);
assert_eq!(h1.node_count(), 8);
}
}
4 changes: 0 additions & 4 deletions hugr/src/hugr/serialize/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -404,10 +404,6 @@ fn roundtrip_sumtype(#[case] sum_type: SumType) {
#[case(Value::extension(ConstF64::new(-1.5)))]
#[case(Value::extension(ConstF64::new(0.0)))]
#[case(Value::extension(ConstF64::new(-0.0)))]
// These cases fail
// #[case(Value::extension(ConstF64::new(std::f64::NAN)))]
// #[case(Value::extension(ConstF64::new(std::f64::INFINITY)))]
// #[case(Value::extension(ConstF64::new(std::f64::NEG_INFINITY)))]
#[case(Value::extension(ConstF64::new(f64::MIN_POSITIVE)))]
#[case(Value::sum(1,[Value::extension(ConstInt::new_u(2,1).unwrap())], SumType::new([vec![], vec![INT_TYPES[2].clone()]])).unwrap())]
#[case(Value::tuple([Value::false_val(), Value::extension(ConstInt::new_s(2,1).unwrap())]))]
Expand Down
13 changes: 10 additions & 3 deletions hugr/src/std_extensions/arithmetic/float_ops/const_fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,11 @@ impl ConstFold for BinaryFold {
consts: &[(IncomingPort, ops::Value)],
) -> ConstFoldResult {
let [f1, f2] = get_floats(consts)?;

let res = ConstF64::new((self.0)(f1, f2));
let x: f64 = (self.0)(f1, f2);
if !x.is_finite() {
return None;
}
let res = ConstF64::new(x);
Some(vec![(0.into(), res.into())])
}
}
Expand Down Expand Up @@ -115,7 +118,11 @@ impl ConstFold for UnaryFold {
consts: &[(IncomingPort, ops::Value)],
) -> ConstFoldResult {
let [f1] = get_floats(consts)?;
let res = ConstF64::new((self.0)(f1));
let x: f64 = (self.0)(f1);
if !x.is_finite() {
return None;
}
let res = ConstF64::new(x);
Some(vec![(0.into(), res.into())])
}
}
Expand Down
6 changes: 5 additions & 1 deletion hugr/src/std_extensions/arithmetic/float_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,11 @@ impl std::ops::Deref for ConstF64 {

impl ConstF64 {
/// Create a new [`ConstF64`]
pub const fn new(value: f64) -> Self {
pub fn new(value: f64) -> Self {
// This function can't be `const` because `is_finite()` is not yet stable as a const function.
if !value.is_finite() {
panic!("ConstF64 must have a finite value.");
}
Self { value }
}

Expand Down
3 changes: 3 additions & 0 deletions specification/hugr.md
Original file line number Diff line number Diff line change
Expand Up @@ -1782,6 +1782,9 @@ Other operations:
The `float64` type represents IEEE 754-2019 floating-point data of 64
bits.

Non-finite `float64` values (i.e. NaN and ±infinity) are not allowed in `Const`
nodes.

#### `arithmetic.float`

Floating-point operations are defined as follows. All operations below
Expand Down

0 comments on commit 4dba950

Please sign in to comment.