From ee4f5c1beb3cfbed17babfd01d872aa6632c74d7 Mon Sep 17 00:00:00 2001 From: Emmanuel Lujan Date: Thu, 21 Dec 2023 20:19:20 -0500 Subject: [PATCH] Update of metrics. --- src/Metrics/metrics.jl | 80 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 79 insertions(+), 1 deletion(-) diff --git a/src/Metrics/metrics.jl b/src/Metrics/metrics.jl index f16b0bb8..1046e882 100644 --- a/src/Metrics/metrics.jl +++ b/src/Metrics/metrics.jl @@ -1,4 +1,81 @@ -export calc_metrics, get_metrics +export calc_metrics, get_metrics, mae, rmse, rsq, mean_cos + +""" + mae(x_pred, x) + +`x_pred`: vector of predicted values. E.g. predicted energies. +`x`: vector of true values. E.g. DFT energies. + +Returns mean absolute error. +""" +function mae(x_pred, x) + return sum(abs.(x_pred .- x)) / length(x) +end + +""" + rmse(x_pred, x) + +`x_pred`: vector of predicted values. E.g. predicted energies. +`x`: vector of true values. E.g. DFT energies. + +Returns mean root mean square error. +""" +function rmse(x_pred, x) + return sqrt(sum((x_pred .- x) .^ 2) / length(x)) +end + +""" + rsq(x_pred, x) + +`x_pred`: vector of predicted values. E.g. predicted energies. +`x`: vector of true values. E.g. DFT energies. + +Returns R-squared. +""" +function rsq(x_pred, x) + return 1 - sum((x_pred .- x) .^ 2) / sum((x .- mean(x)) .^ 2) +end + +""" + mean_cos(x_pred, x) + +`x_pred`: vector of predicted forces, +`x`: vector of true forces. + +Returns mean cosine. +""" +function mean_cos(x_pred, x) + x_pred_v = collect(eachcol(reshape(x_pred, 3, :))) + x_v = collect(eachcol(reshape(x, 3, :))) + x_cos = dot.(x_v, x_pred_v) ./ (norm.(x_v) .* norm.(x_pred_v)) + x_mean_cos = mean(filter(!isnan, x_cos)) + return x_mean_cos +end + +""" + get_metrics( + x_pred, + x; + metrics = [mae, rmse, rsq], + label = "x" + ) + +`x_pred`: vector of predicted forces, +`x`: vector of true forces. +`metrics`: vector of metrics. +`label`: label used as prefix in dictionary keys. + +Returns and OrderedDict with different metrics. +""" +function get_metrics( + x_pred, + x; + metrics = [mae, rmse, rsq], + label = "x" +) + return OrderedDict( "$(label)_$(Symbol(m))" => m(x_pred, x) + for m in metrics) +end """ @@ -17,6 +94,7 @@ function calc_metrics(x_pred, x) return x_mae, x_rmse, x_rsq end + """ get_metrics( e_train_pred, e_train, e_test_pred, e_test)