Skip to content

Commit

Permalink
Still root causing the fused l2 nn issue
Browse files Browse the repository at this point in the history
  • Loading branch information
cjnolet committed Oct 18, 2023
1 parent f52765f commit 7d53e0d
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,15 @@ struct l2_exp_cutlass_op {
{
AccT outVal = aNorm + bNorm - DataT(2.0) * accVal;

if (raft::sqrt(outVal) == 0.002918735) {
printf("aNorm: %lf, bNorm:%lf, acc: %lf, outVal: %lf\n", aNorm, bNorm, accVal, outVal);
}

/**
* Self-neighboring points should have (aNorm == bNorm) == accVal and the dot product (accVal)
* can sometimes have round-off errors, which will cause (aNorm == bNorm) ~ accVal instead.
*/
outVal = outVal * (outVal > 1e-4 && !(aNorm == bNorm && accVal > 0.0));
outVal = outVal * (raft::abs(outVal) >= sizeof(DataT) == 4 ? 1e-5 : 1e-14);
return sqrt ? raft::sqrt(outVal) : outVal;
}

Expand Down Expand Up @@ -91,12 +95,16 @@ struct l2_exp_distance_op {
DataT accVal = acc[i][j];
DataT val = regxn[i] + regyn[j] - (DataT)2.0 * accVal;

if (regxn[i] == regyn[j]) {
printf("aNorm: %lf, bNorm:%lf, acc: %lf, outVal: %lf\n", regxn[i], regyn[j], accVal, val);
}

/**
* Self-neighboring points should have (aNorm == bNorm) == accVal and the dot product
* (accVal) can sometimes have round-off errors, which will cause (aNorm == bNorm) ~ accVal
* instead.
*/
acc[i][j] = val * (val >= 1e-4 && !(regxn[i] == regyn[j] && accVal > 0.0));
acc[i][j] = val * (raft::abs(val) >= sizeof(DataT) == 4 ? 1e-5 : 1e-14);
}
}
if (sqrt) {
Expand Down

0 comments on commit 7d53e0d

Please sign in to comment.