Skip to content

Commit

Permalink
apply code formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
srzeszut committed Apr 16, 2024
1 parent 1563c94 commit fab020e
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions lib/scholar/metrics/classification.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1517,8 +1517,8 @@ defmodule Scholar.Metrics.Classification do
iex> y_true = Nx.tensor([0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1])
iex> y_score = Nx.tensor([0.1, 0.1, 0.2, 0.2, 0.3, 0.3, 0.4, 0.4, 0.5, 0.5, 0.6, 0.7, 0.7, 0.8, 0.9])
iex> distinct_value_indices = Scholar.Metrics.Classification1.distinct_value_indices(y_score)
iex> {fpr, fnr, thresholds} = Scholar.Metrics.Classification1.det_curve(y_true, y_score, distinct_value_indices)
iex> distinct_value_indices = Scholar.Metrics.Classification.distinct_value_indices(y_score)
iex> {fpr, fnr, thresholds} = Scholar.Metrics.Classification.det_curve(y_true, y_score, distinct_value_indices)
iex> fpr
#Nx.Tensor<
f32[9]
Expand All @@ -1536,7 +1536,6 @@ defmodule Scholar.Metrics.Classification do
>
"""
defn det_curve(y_true, y_score, distinct_value_indices, weights \\ 1.0) do

num_samples = Nx.axis_size(y_true, 0)

weights = validate_weights(weights, num_samples, type: to_float_type(y_true))
Expand All @@ -1546,12 +1545,11 @@ defmodule Scholar.Metrics.Classification do
{fps, tps, thresholds} =
binary_clf_curve(y_true, y_score, distinct_value_indices, weights)

positive_count= tps[[-1]]
positive_count = tps[[-1]]
negative_count = fps[[-1]]

fns = positive_count - tps

{Nx.reverse(fps)/ negative_count, Nx.reverse(fns)/ positive_count, Nx.reverse(thresholds)}
{Nx.reverse(fps) / negative_count, Nx.reverse(fns) / positive_count, Nx.reverse(thresholds)}
end

end

0 comments on commit fab020e

Please sign in to comment.