Skip to content

Commit

Permalink
Add det curve (#258)
Browse files Browse the repository at this point in the history
  • Loading branch information
srzeszut authored Apr 17, 2024
1 parent 5c1786f commit 15c2eb5
Showing 1 changed file with 66 additions and 0 deletions.
66 changes: 66 additions & 0 deletions lib/scholar/metrics/classification.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1486,4 +1486,70 @@ defmodule Scholar.Metrics.Classification do
)
end
end

@doc """
Computes the Detection Error Tradeoff (DET) curve: error rates for different
probability thresholds.

Note: This metric is used for evaluation of ranking and error tradeoffs of
a binary classification task.

## Arguments

  * `y_true` - tensor of ground-truth binary labels.
  * `y_score` - tensor of scores, checked against `y_true` for shape compatibility.
  * `distinct_value_indices` - indices computed from `y_score` with
    `distinct_value_indices/1`, as shown in the examples below.
  * `weights` - sample weights, defaults to `1.0`.

## Return value

A tuple `{fpr, fnr, thresholds}` holding the false positive rates, the false
negative rates and the thresholds they were computed at. All three tensors are
returned reversed with respect to the order produced by the underlying binary
classification curve.

## Examples

    iex> y_true = Nx.tensor([0, 0, 1, 1])
    iex> y_score = Nx.tensor([0.1, 0.4, 0.35, 0.8])
    iex> distinct_value_indices = Scholar.Metrics.Classification.distinct_value_indices(y_score)
    iex> {fpr, fnr, thresholds} = Scholar.Metrics.Classification.det_curve(y_true, y_score, distinct_value_indices)
    iex> fpr
    #Nx.Tensor<
      f32[4]
      [1.0, 0.5, 0.5, 0.0]
    >
    iex> fnr
    #Nx.Tensor<
      f32[4]
      [0.0, 0.0, 0.5, 0.5]
    >
    iex> thresholds
    #Nx.Tensor<
      f32[4]
      [0.10000000149011612, 0.3499999940395355, 0.4000000059604645, 0.800000011920929]
    >

    iex> y_true = Nx.tensor([0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1])
    iex> y_score = Nx.tensor([0.1, 0.1, 0.2, 0.2, 0.3, 0.3, 0.4, 0.4, 0.5, 0.5, 0.6, 0.7, 0.7, 0.8, 0.9])
    iex> distinct_value_indices = Scholar.Metrics.Classification.distinct_value_indices(y_score)
    iex> {fpr, fnr, thresholds} = Scholar.Metrics.Classification.det_curve(y_true, y_score, distinct_value_indices)
    iex> fpr
    #Nx.Tensor<
      f32[9]
      [1.0, 0.6666666865348816, 0.6666666865348816, 0.6666666865348816, 0.3333333432674408, 0.3333333432674408, 0.1666666716337204, 0.1666666716337204, 0.0]
    >
    iex> fnr
    #Nx.Tensor<
      f32[9]
      [0.0, 0.0, 0.2222222238779068, 0.4444444477558136, 0.4444444477558136, 0.6666666865348816, 0.6666666865348816, 0.8888888955116272, 0.8888888955116272]
    >
    iex> thresholds
    #Nx.Tensor<
      f32[9]
      [0.10000000149011612, 0.20000000298023224, 0.30000001192092896, 0.4000000059604645, 0.5, 0.6000000238418579, 0.699999988079071, 0.800000011920929, 0.8999999761581421]
    >
"""
defn det_curve(y_true, y_score, distinct_value_indices, weights \\ 1.0) do
  # Weights are validated against the number of samples along the first axis of `y_true`.
  sample_weights =
    validate_weights(weights, Nx.axis_size(y_true, 0), type: to_float_type(y_true))

  check_shape(y_true, y_score)

  {false_positives, true_positives, cutoffs} =
    binary_clf_curve(y_true, y_score, distinct_value_indices, sample_weights)

  # The last entry of each count tensor is used as the total for its class.
  # NOTE(review): if `y_true` holds a single class, one of these totals is
  # presumably zero and the matching rate divides by zero - confirm intended behavior.
  total_positives = true_positives[[-1]]
  total_negatives = false_positives[[-1]]

  # False negatives are the positives not counted as true positives at each threshold.
  false_positive_rate = Nx.reverse(false_positives) / total_negatives
  false_negative_rate = Nx.reverse(total_positives - true_positives) / total_positives

  {false_positive_rate, false_negative_rate, Nx.reverse(cutoffs)}
end
end

0 comments on commit 15c2eb5

Please sign in to comment.