From 7d53e0d2d00152d22defd906a18ef5659838402d Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Tue, 17 Oct 2023 20:33:12 -0400 Subject: [PATCH] Still root causing the fused l2 nn issue --- .../raft/distance/detail/distance_ops/l2_exp.cuh | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh b/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh index c8a9798590..2f14df328e 100644 --- a/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh @@ -32,11 +32,15 @@ struct l2_exp_cutlass_op { { AccT outVal = aNorm + bNorm - DataT(2.0) * accVal; + if (raft::sqrt(outVal) == 0.002918735) { + printf("aNorm: %lf, bNorm:%lf, acc: %lf, outVal: %lf\n", aNorm, bNorm, accVal, outVal); + } + /** * Self-neighboring points should have (aNorm == bNorm) == accVal and the dot product (accVal) * can sometimes have round-off errors, which will cause (aNorm == bNorm) ~ accVal instead. */ - outVal = outVal * (outVal > 1e-4 && !(aNorm == bNorm && accVal > 0.0)); + outVal = outVal * (raft::abs(outVal) >= sizeof(DataT) == 4 ? 1e-5 : 1e-14); return sqrt ? raft::sqrt(outVal) : outVal; } @@ -91,12 +95,16 @@ struct l2_exp_distance_op { DataT accVal = acc[i][j]; DataT val = regxn[i] + regyn[j] - (DataT)2.0 * accVal; + if (regxn[i] == regyn[j]) { + printf("aNorm: %lf, bNorm:%lf, acc: %lf, outVal: %lf\n", regxn[i], regyn[j], accVal, val); + } + /** * Self-neighboring points should have (aNorm == bNorm) == accVal and the dot product * (accVal) can sometimes have round-off errors, which will cause (aNorm == bNorm) ~ accVal * instead. */ - acc[i][j] = val * (val >= 1e-4 && !(regxn[i] == regyn[j] && accVal > 0.0)); + acc[i][j] = val * (raft::abs(val) >= sizeof(DataT) == 4 ? 1e-5 : 1e-14); } } if (sqrt) {