diff --git a/dcor/_dcor_internals.py b/dcor/_dcor_internals.py index af87525..21b390d 100644 --- a/dcor/_dcor_internals.py +++ b/dcor/_dcor_internals.py @@ -184,8 +184,8 @@ def _dcov_from_terms( ) -> Array: """Compute distance covariance WITHOUT centering first.""" first_term = mean_prod / n_samples - second_term = a_axis_sum @ b_axis_sum / n_samples - third_term = a_total_sum * b_total_sum / n_samples + second_term = a_axis_sum / n_samples @ b_axis_sum + third_term = a_total_sum / n_samples * b_total_sum if bias_corrected: first_term /= (n_samples - 3) diff --git a/dcor/tests/test_dcor.py b/dcor/tests/test_dcor.py index 2b7d981..d2a4fd4 100644 --- a/dcor/tests/test_dcor.py +++ b/dcor/tests/test_dcor.py @@ -512,6 +512,35 @@ def test_dcor_constant(self) -> None: corr_af_inv = dcor.distance_correlation_af_inv(a, a) self.assertAlmostEqual(corr_af_inv, 0) + def test_integer_overflow(self) -> None: + """Tests int overflow behavior detected in issue #59.""" + n_samples = 10000 + + # some simple data + arr1 = np.array([1, 2, 3] * n_samples) + arr2 = np.array([10, 20, 5] * n_samples) + + int_int = dcor.distance_correlation( + arr1, + arr2, + ) + float_int = dcor.distance_correlation( + arr1.astype(float), + arr2, + ) + int_float = dcor.distance_correlation( + arr1, + arr2.astype(float), + ) + float_float = dcor.distance_correlation( + arr1.astype(float), + arr2.astype(float), + ) + + self.assertAlmostEqual(int_int, float_float) + self.assertAlmostEqual(float_int, float_float) + self.assertAlmostEqual(int_float, float_float) + class TestDcorArrayAPI(unittest.TestCase): """Check that the energy distance works with the Array API standard."""