diff --git a/src/elexsolver/TransitionSolver.py b/src/elexsolver/TransitionSolver.py index 3eee22fd..0c98b597 100644 --- a/src/elexsolver/TransitionSolver.py +++ b/src/elexsolver/TransitionSolver.py @@ -11,15 +11,13 @@ LOG = logging.getLogger(__name__) -def mean_absolute_error(Y_expected: np.ndarray, Y_pred: np.ndarray): +def mean_absolute_error(Y_expected: np.ndarray, Y_pred: np.ndarray, weights: np.ndarray | None = None): if isinstance(Y_expected, list): Y_expected = np.array(Y_expected) if isinstance(Y_pred, list): Y_pred = np.array(Y_pred) - absolute_errors = np.abs(Y_pred - Y_expected) - error_sum = np.sum(absolute_errors) - return error_sum / len(absolute_errors) + return np.average(np.abs(Y_expected - Y_pred), weights=weights) def weighted_absolute_percentage_error(Y_expected: np.ndarray, Y_pred: np.ndarray):