Skip to content

Commit

Permalink
Update of metrics.
Browse files Browse the repository at this point in the history
  • Loading branch information
emmanuellujan committed Dec 22, 2023
1 parent 897fcf1 commit ee4f5c1
Showing 1 changed file with 79 additions and 1 deletion.
80 changes: 79 additions & 1 deletion src/Metrics/metrics.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,81 @@
export calc_metrics, get_metrics
export calc_metrics, get_metrics, mae, rmse, rsq, mean_cos

"""
mae(x_pred, x)
`x_pred`: vector of predicted values. E.g. predicted energies.
`x`: vector of true values. E.g. DFT energies.
Returns mean absolute error.
"""
function mae(x_pred, x)
return sum(abs.(x_pred .- x)) / length(x)
end

"""
rmse(x_pred, x)
`x_pred`: vector of predicted values. E.g. predicted energies.
`x`: vector of true values. E.g. DFT energies.
Returns mean root mean square error.
"""
function rmse(x_pred, x)
return sqrt(sum((x_pred .- x) .^ 2) / length(x))
end

"""
rsq(x_pred, x)
`x_pred`: vector of predicted values. E.g. predicted energies.
`x`: vector of true values. E.g. DFT energies.
Returns R-squared.
"""
function rsq(x_pred, x)
return 1 - sum((x_pred .- x) .^ 2) / sum((x .- mean(x)) .^ 2)
end

"""
mean_cos(x_pred, x)
`x_pred`: vector of predicted forces,
`x`: vector of true forces.
Returns mean cosine.
"""
function mean_cos(x_pred, x)
x_pred_v = collect(eachcol(reshape(x_pred, 3, :)))
x_v = collect(eachcol(reshape(x, 3, :)))
x_cos = dot.(x_v, x_pred_v) ./ (norm.(x_v) .* norm.(x_pred_v))
x_mean_cos = mean(filter(!isnan, x_cos))
return x_mean_cos
end

"""
get_metrics(
x_pred,
x;
metrics = [mae, rmse, rsq],
label = "x"
)
`x_pred`: vector of predicted forces,
`x`: vector of true forces.
`metrics`: vector of metrics.
`label`: label used as prefix in dictionary keys.
Returns and OrderedDict with different metrics.
"""
function get_metrics(
x_pred,
x;
metrics = [mae, rmse, rsq],
label = "x"
)
return OrderedDict( "$(label)_$(Symbol(m))" => m(x_pred, x)
for m in metrics)
end


"""
Expand All @@ -17,6 +94,7 @@ function calc_metrics(x_pred, x)
return x_mae, x_rmse, x_rsq
end


"""
get_metrics( e_train_pred, e_train, e_test_pred, e_test)
Expand Down

0 comments on commit ee4f5c1

Please sign in to comment.