Skip to content

Commit

Permalink
fix(mmd): numerical instability
Browse files Browse the repository at this point in the history
  • Loading branch information
marcpinet committed Dec 10, 2024
1 parent 6fbf884 commit 5dc15e8
Showing 1 changed file with 3 additions and 5 deletions.
8 changes: 3 additions & 5 deletions neuralnetlib/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,11 +504,9 @@ def gaussian_kernel(x: np.ndarray, y: np.ndarray, sigma: float) -> np.ndarray:
k_yy = gaussian_kernel(y_true, y_true, sigma)
k_xy = gaussian_kernel(y_pred, y_true, sigma)

xx_term = (np.sum(k_xx) - np.sum(np.diag(k_xx))) / \
(n * (n - 1)) if n > 1 else 0
yy_term = (np.sum(k_yy) - np.sum(np.diag(k_yy))) / \
(m * (m - 1)) if m > 1 else 0
xy_term = np.sum(k_xy) / (n * m)
xx_term = (np.sum(k_xx) - np.sum(np.diag(k_xx))) / (n * (n - 1) + 1e-8)
yy_term = (np.sum(k_yy) - np.sum(np.diag(k_yy))) / (m * (m - 1) + 1e-8)
xy_term = np.sum(k_xy) / (n * m + 1e-8)

return float(xx_term + yy_term - 2 * xy_term)

Expand Down

0 comments on commit 5dc15e8

Please sign in to comment.