diff --git a/etc/function-definitions.json b/etc/function-definitions.json index 39b6c970..fb3745cb 100644 --- a/etc/function-definitions.json +++ b/etc/function-definitions.json @@ -688,6 +688,7 @@ "src/libm_helper.rs", "src/math/arch/i686.rs", "src/math/arch/wasm32.rs", + "src/math/generic/sqrt.rs", "src/math/sqrt.rs" ], "type": "f64" @@ -696,6 +697,7 @@ "sources": [ "src/math/arch/i686.rs", "src/math/arch/wasm32.rs", + "src/math/generic/sqrt.rs", "src/math/sqrtf.rs" ], "type": "f32" diff --git a/src/math/generic/mod.rs b/src/math/generic/mod.rs index 08524b68..f3bd659e 100644 --- a/src/math/generic/mod.rs +++ b/src/math/generic/mod.rs @@ -1,5 +1,7 @@ mod copysign; mod fabs; +mod sqrt; pub use copysign::copysign; pub use fabs::fabs; +pub use sqrt::sqrt; diff --git a/src/math/generic/sqrt.rs b/src/math/generic/sqrt.rs new file mode 100644 index 00000000..b60952d1 --- /dev/null +++ b/src/math/generic/sqrt.rs @@ -0,0 +1,163 @@ +/* SPDX-License-Identifier: MIT */ +/* origin: musl src/math/sqrt.c. */ + +use core::ops; + +use super::super::support::{IntTy, cold_path, raise_invalid}; +use super::super::{CastFrom, CastInto, DInt, Float, HInt, Int, MinInt}; + +pub fn sqrt(x: F) -> F +where + F::Int: DInt + HInt + CastInto, + F::Int: ops::Rem, + u32: CastInto, + u8: CastInto, +{ + let zero = IntTy::::ZERO; + let one = IntTy::::ONE; + + let mut ix = x.to_bits(); + // Exponent and sign + let mut top = u32::cast_from(ix >> F::SIG_BITS); + + if top.wrapping_sub(1) >= F::EXP_MAX - 1 { + cold_path(); + // + if ix.overflowing_mul(f_int::(2u8)).0 == zero { + return x; + } + + // Positive infinity + if ix == F::EXP_MASK { + return x; + } + + // NaN or negative + if ix > F::EXP_MASK { + return raise_invalid(x); + } + + let scaled = x * F::from_parts(false, (F::SIG_BITS + F::EXP_BIAS) as i32, zero); + ix = scaled.to_bits(); + top = scaled.exp().unsigned(); + top = top.wrapping_sub(F::SIG_BITS); + } + + let even = (top & 1) != 0; + let mut m = (ix << F::EXP_BITS) | (one << (F::BITS - 1)); + if even { + m >>= 1; + } + top = (top.wrapping_add(F::EXP_MAX >> 1)) >> 1; + + // 32-bit three + let three = f_int::(0b11u8) << ((F::BITS / 2) - 2); + + let mut r: F::Int; + let mut s: F::Int; + let mut d: F::Int; + let mut u: F::Int; + let i: F::Int; + + // 17 for f32, 46 for f64 + i = (ix >> 46) % f_int::(128u8); + // i = (ix >> todo!()) % f_int::(128u8); + r = f_int::(RSQRT_TAB[usize::cast_from(i)]) << 16; + // TODO: can some of this casting back and forth be removed? + s = wmulh::(u32::cast_from(m >> 32), u32::cast_from(r)).cast(); + d = wmulh::(s.cast(), r.cast()).cast(); + u = three - d; + + r = wmulh::(r.cast(), u.cast()).cast() << 1; + s = wmulh::(s.cast(), u.cast()).cast() << 1; + d = wmulh::(s.cast(), r.cast()).cast(); + u = three - d; + r = wmulh::(r.cast(), u.cast()).cast() << 1; + /* |r sqrt(m) - 1| < 0x1.3704p-29 (measured worst-case) */ + r <<= 32; + s = mul64(m, r); + d = mul64(s, r); + u = (three << 32) - d; + s = mul64(s, u); /* repr: 3.61 */ + /* -0x1p-57 < s - sqrt(m) < 0x1.8001p-61 */ + s = (s - 2u8.cast()) >> 9; /* repr: 12.52 */ + // if F::BITS > 32 { + // } else { + // // + // } + // + // + // + + let d0: F::Int; + let d1: F::Int; + // let d2: F::Int; + + let y: F; + // let t: F; + + d0 = (m << 42).wrapping_sub(s.overflowing_mul(s).0); + d1 = s.wrapping_sub(d0); + // d2 = d1.wrapping_add(s).wrapping_add(one); + s += d1 >> F::BITS - 1; + s &= F::SIG_MASK; + // s &= 0x000fffffffffffff; + s |= f_int::(top) << F::SIG_BITS; + y = F::from_bits(s); + // if (FENV_SUPPORT) { + // /* handle rounding modes and inexact exception: + // only (s+1)^2 == 2^42 m case is exact otherwise + // add a tiny value to cause the fenv effects. */ + // uint64_t tiny = predict_false(d2==0) ? 0 : 0x0010000000000000; + // tiny |= (d1^d2) & 0x8000000000000000; + // t = asdouble(tiny); + // y = eval_as_double(y + t); + // } + y +} + +fn f_int(x: T) -> F::Int +where + F::Int: CastFrom, +{ + F::Int::cast_from(x) +} + +/// Widen multiply, returning the high half. +fn wmulh(a: I, b: I) -> I { + a.widen_mul(b).hi() +} + +fn mul64(a: I, b: I) -> I { + let ahi: I = a.hi().widen(); + let alo: I = a.lo().widen(); + let bhi: I = b.hi().widen(); + let blo: I = b.lo().widen(); + + (ahi * bhi) + (ahi * blo).hi().widen() + (alo * bhi).hi().widen() +} +// fn mul64(a: u32, b: u32) -> u32 { +// a.hi() & b.hi() + +// a.widen_mul(b).hi() +// } + +#[rustfmt::skip] +const RSQRT_TAB: [u16; 128] = [ + 0xb451,0xb2f0,0xb196,0xb044,0xaef9,0xadb6,0xac79,0xab43, + 0xaa14,0xa8eb,0xa7c8,0xa6aa,0xa592,0xa480,0xa373,0xa26b, + 0xa168,0xa06a,0x9f70,0x9e7b,0x9d8a,0x9c9d,0x9bb5,0x9ad1, + 0x99f0,0x9913,0x983a,0x9765,0x9693,0x95c4,0x94f8,0x9430, + 0x936b,0x92a9,0x91ea,0x912e,0x9075,0x8fbe,0x8f0a,0x8e59, + 0x8daa,0x8cfe,0x8c54,0x8bac,0x8b07,0x8a64,0x89c4,0x8925, + 0x8889,0x87ee,0x8756,0x86c0,0x862b,0x8599,0x8508,0x8479, + 0x83ec,0x8361,0x82d8,0x8250,0x81c9,0x8145,0x80c2,0x8040, + 0xff02,0xfd0e,0xfb25,0xf947,0xf773,0xf5aa,0xf3ea,0xf234, + 0xf087,0xeee3,0xed47,0xebb3,0xea27,0xe8a3,0xe727,0xe5b2, + 0xe443,0xe2dc,0xe17a,0xe020,0xdecb,0xdd7d,0xdc34,0xdaf1, + 0xd9b3,0xd87b,0xd748,0xd61a,0xd4f1,0xd3cd,0xd2ad,0xd192, + 0xd07b,0xcf69,0xce5b,0xcd51,0xcc4a,0xcb48,0xca4a,0xc94f, + 0xc858,0xc764,0xc674,0xc587,0xc49d,0xc3b7,0xc2d4,0xc1f4, + 0xc116,0xc03c,0xbf65,0xbe90,0xbdbe,0xbcef,0xbc23,0xbb59, + 0xba91,0xb9cc,0xb90a,0xb84a,0xb78c,0xb6d0,0xb617,0xb560, +]; diff --git a/src/math/sqrt.rs b/src/math/sqrt.rs index 2fd7070b..a9ac753d 100644 --- a/src/math/sqrt.rs +++ b/src/math/sqrt.rs @@ -90,6 +90,10 @@ pub fn sqrt(x: f64) -> f64 { args: x, } + if true { + return super::generic::sqrt(x); + } + use core::num::Wrapping; const TINY: f64 = 1.0e-300; @@ -226,12 +230,13 @@ pub fn sqrt(x: f64) -> f64 { #[cfg(test)] mod tests { + use super::super::Float; use super::*; #[test] fn sanity_check() { - assert_eq!(sqrt(100.0), 10.0); - assert_eq!(sqrt(4.0), 2.0); + assert_biteq!(sqrt(100.0), 10.0); + assert_biteq!(sqrt(4.0), 2.0); } /// The spec: https://en.cppreference.com/w/cpp/numeric/math/sqrt @@ -241,24 +246,27 @@ mod tests { assert!(sqrt(-1.0).is_nan()); assert!(sqrt(f64::NAN).is_nan()); for f in [0.0, -0.0, f64::INFINITY].iter().copied() { - assert_eq!(sqrt(f), f); + assert_biteq!(sqrt(f), f); } } #[test] #[allow(clippy::approx_constant)] fn conformance_tests() { - let values = [3.14159265359, 10000.0, f64::from_bits(0x0000000f), f64::INFINITY]; - let results = [ - 4610661241675116657u64, - 4636737291354636288u64, - 2197470602079456986u64, - 9218868437227405312u64, + let cases = [ + (3.14159265359, 4610661241675116657u64), + (10000.0, 4636737291354636288u64), + (f64::from_bits(0x0000000f), 2197470602079456986u64), + (f64::INFINITY, 9218868437227405312u64), ]; - for i in 0..values.len() { - let bits = f64::to_bits(sqrt(values[i])); - assert_eq!(results[i], bits); + for (input, output) in cases { + assert_biteq!( + sqrt(input), + f64::from_bits(output), + "input: {input} ({:#018x})", + input.to_bits() + ); } } } diff --git a/src/math/support/int_traits.rs b/src/math/support/int_traits.rs index db799c03..b8c33d73 100644 --- a/src/math/support/int_traits.rs +++ b/src/math/support/int_traits.rs @@ -55,10 +55,12 @@ pub trait Int: + ops::BitAnd + cmp::Ord + CastFrom + + CastFrom + CastFrom + CastFrom + CastFrom + CastInto + + CastInto + CastInto + CastInto + CastInto @@ -88,6 +90,7 @@ pub trait Int: fn wrapping_shr(self, other: u32) -> Self; fn rotate_left(self, other: u32) -> Self; fn overflowing_add(self, other: Self) -> (Self, bool); + fn overflowing_mul(self, other: Self) -> (Self, bool); fn leading_zeros(self) -> u32; fn ilog2(self) -> u32; } @@ -146,6 +149,10 @@ macro_rules! int_impl_common { ::overflowing_add(self, other) } + fn overflowing_mul(self, other: Self) -> (Self, bool) { + ::overflowing_mul(self, other) + } + fn leading_zeros(self) -> u32 { ::leading_zeros(self) } diff --git a/src/math/support/macros.rs b/src/math/support/macros.rs index 076fdf1f..bd5b09f9 100644 --- a/src/math/support/macros.rs +++ b/src/math/support/macros.rs @@ -110,19 +110,21 @@ macro_rules! hf64 { /// Assert `F::biteq` with better messages. #[cfg(test)] macro_rules! assert_biteq { - ($left:expr, $right:expr, $($arg:tt)*) => {{ - let bits = ($left.to_bits() * 0).leading_zeros(); // hack to get the width from the value + ($left:expr, $right:expr, $($tt:tt)*) => {{ + let l = $left; + let r = $right; + let bits = (l.to_bits() * 0).leading_zeros(); // hack to get the width from the value assert!( - $left.biteq($right), - "\nl: {l:?} ({lb:#0width$x})\nr: {r:?} ({rb:#0width$x})", - l = $left, - lb = $left.to_bits(), - r = $right, - rb = $right.to_bits(), - width = ((bits / 4) + 2) as usize + l.biteq(r), + "{}\nl: {l:?} ({lb:#0width$x})\nr: {r:?} ({rb:#0width$x})", + format_args!($($tt)*), + lb = l.to_bits(), + rb = r.to_bits(), + width = ((bits / 4) + 2) as usize, + ); }}; ($left:expr, $right:expr $(,)?) => { - assert_biteq!($left, $right,) + assert_biteq!($left, $right, "") }; } diff --git a/src/math/support/mod.rs b/src/math/support/mod.rs index e2f4e0e9..a80da908 100644 --- a/src/math/support/mod.rs +++ b/src/math/support/mod.rs @@ -10,3 +10,13 @@ pub(crate) use float_traits::{f32_from_bits, f64_from_bits}; #[allow(unused_imports)] pub use hex_float::{hf32, hf64}; pub use int_traits::{CastFrom, CastInto, DInt, HInt, Int, MinInt}; + +pub fn cold_path() { + #[cfg(intrinsics_enabled)] + core::intrinsics::cold_path(); +} + +/// Return `x`, first raising `FE_INVALID`. +pub fn raise_invalid(x: F) -> F { + (x - x) / (x - x) +}