Skip to content

Commit

Permalink
Adding option to compute MAE with sample weights
Browse files Browse the repository at this point in the history
  • Loading branch information
dmnapolitano committed Dec 28, 2023
1 parent 27a92be commit 61d2f01
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions src/elexsolver/TransitionSolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,13 @@
LOG = logging.getLogger(__name__)


def mean_absolute_error(Y_expected: np.ndarray, Y_pred: np.ndarray):
def mean_absolute_error(Y_expected: np.ndarray, Y_pred: np.ndarray, weights: np.ndarray | None = None):
if isinstance(Y_expected, list):
Y_expected = np.array(Y_expected)
if isinstance(Y_pred, list):
Y_pred = np.array(Y_pred)

absolute_errors = np.abs(Y_pred - Y_expected)
error_sum = np.sum(absolute_errors)
return error_sum / len(absolute_errors)
return np.average(np.abs(Y_expected - Y_pred), weights=weights)


def weighted_absolute_percentage_error(Y_expected: np.ndarray, Y_pred: np.ndarray):
Expand Down

0 comments on commit 61d2f01

Please sign in to comment.