diff --git a/lib/scholar/metrics/classification.ex b/lib/scholar/metrics/classification.ex index 03c926fc..7a1f33c8 100644 --- a/lib/scholar/metrics/classification.ex +++ b/lib/scholar/metrics/classification.ex @@ -1262,8 +1262,10 @@ defmodule Scholar.Metrics.Classification do each class, from which the log loss is computed by averaging the negative log of the probability forecasted for the true class over a number of samples. - `y_true` should contain `num_classes` unique values, and the sum of `y_prob` - along axis 1 should be 1 to respect the law of total probability. + `y_true` should be a tensor of shape {num_samples} containing values + between 0 and num_classes - 1 (inclusive). + `y_prob` should be a tensor of shape {num_samples, num_classes} containing + predicted probability distributions over classes for each sample. ## Options @@ -1320,7 +1322,7 @@ defmodule Scholar.Metrics.Classification do type: to_float_type(y_prob) ) - y_true_onehot = + y_one_hot = y_true |> Nx.new_axis(1) |> Nx.broadcast({num_samples, num_classes}) @@ -1329,7 +1331,7 @@ defmodule Scholar.Metrics.Classification do y_prob = Nx.clip(y_prob, 0, 1) sample_loss = - Nx.multiply(y_true_onehot, y_prob) + Nx.multiply(y_one_hot, y_prob) |> Nx.sum(axes: [-1]) |> Nx.log() |> Nx.negate() diff --git a/lib/scholar/naive_bayes/complement.ex b/lib/scholar/naive_bayes/complement.ex index fcdcc633..7204b0cf 100644 --- a/lib/scholar/naive_bayes/complement.ex +++ b/lib/scholar/naive_bayes/complement.ex @@ -94,7 +94,8 @@ defmodule Scholar.NaiveBayes.Complement do @doc """ Fits a complement naive Bayes classifier. The function assumes that the targets `y` are integers - between 0 and `num_classes` - 1 (inclusive). + between 0 and `num_classes` - 1 (inclusive). Otherwise, those samples will not + contribute to `class_count`. ## Options