diff --git a/neuralnetlib/metrics.py b/neuralnetlib/metrics.py index dadc43f..6f7501b 100644 --- a/neuralnetlib/metrics.py +++ b/neuralnetlib/metrics.py @@ -504,11 +504,9 @@ def gaussian_kernel(x: np.ndarray, y: np.ndarray, sigma: float) -> np.ndarray: k_yy = gaussian_kernel(y_true, y_true, sigma) k_xy = gaussian_kernel(y_pred, y_true, sigma) - xx_term = (np.sum(k_xx) - np.sum(np.diag(k_xx))) / \ - (n * (n - 1)) if n > 1 else 0 - yy_term = (np.sum(k_yy) - np.sum(np.diag(k_yy))) / \ - (m * (m - 1)) if m > 1 else 0 - xy_term = np.sum(k_xy) / (n * m) + xx_term = (np.sum(k_xx) - np.sum(np.diag(k_xx))) / (n * (n - 1) + 1e-8) + yy_term = (np.sum(k_yy) - np.sum(np.diag(k_yy))) / (m * (m - 1) + 1e-8) + xy_term = np.sum(k_xy) / (n * m + 1e-8) return float(xx_term + yy_term - 2 * xy_term)