From 61d2f01f86e3da1a85239ffa3f5aa54d844e3a98 Mon Sep 17 00:00:00 2001 From: Diane Napolitano Date: Wed, 27 Dec 2023 22:00:04 -0500 Subject: [PATCH] Adding option to compute MAE with sample weights --- src/elexsolver/TransitionSolver.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/elexsolver/TransitionSolver.py b/src/elexsolver/TransitionSolver.py index 3eee22fd..0c98b597 100644 --- a/src/elexsolver/TransitionSolver.py +++ b/src/elexsolver/TransitionSolver.py @@ -11,15 +11,13 @@ LOG = logging.getLogger(__name__) -def mean_absolute_error(Y_expected: np.ndarray, Y_pred: np.ndarray): +def mean_absolute_error(Y_expected: np.ndarray, Y_pred: np.ndarray, weights: np.ndarray | None = None): if isinstance(Y_expected, list): Y_expected = np.array(Y_expected) if isinstance(Y_pred, list): Y_pred = np.array(Y_pred) - absolute_errors = np.abs(Y_pred - Y_expected) - error_sum = np.sum(absolute_errors) - return error_sum / len(absolute_errors) + return np.average(np.abs(Y_expected - Y_pred), weights=weights) def weighted_absolute_percentage_error(Y_expected: np.ndarray, Y_pred: np.ndarray):