From cf817235207d6ccce68c8cd013910f9a02529f02 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Krsto=20Prorokovi=C4=87?= Date: Wed, 31 Jul 2024 16:07:39 +0200 Subject: [PATCH 1/6] Update --- lib/scholar/linear/logistic_regression.ex | 13 +- lib/scholar/metrics/classification.ex | 7 +- lib/scholar/naive_bayes/complement.ex | 72 +++++--- lib/scholar/naive_bayes/multinomial.ex | 3 +- lib/scholar/neighbors/large_vis.ex | 2 +- lib/scholar/preprocessing.ex | 12 +- lib/scholar/preprocessing/one_hot_encoder.ex | 90 ++++----- lib/scholar/preprocessing/ordinal_encoder.ex | 183 ++++++++++++------- 8 files changed, 236 insertions(+), 146 deletions(-) diff --git a/lib/scholar/linear/logistic_regression.ex b/lib/scholar/linear/logistic_regression.ex index 3c62c9fa..1cc42dbc 100644 --- a/lib/scholar/linear/logistic_regression.ex +++ b/lib/scholar/linear/logistic_regression.ex @@ -143,17 +143,22 @@ defmodule Scholar.Linear.LogisticRegression do # Logistic Regression training loop defnp fit_n(x, y, coef, bias, coef_optimizer_state, bias_optimizer_state, opts) do + num_samples = Nx.axis_size(x, 0) iterations = opts[:iterations] num_classes = opts[:num_classes] optimizer_update_fn = opts[:optimizer_update_fn] - y = Scholar.Preprocessing.one_hot_encode(y, num_classes: num_classes) + y_one_hot = + y + |> Nx.new_axis(1) + |> Nx.broadcast({num_samples, num_classes}) + |> Nx.equal(Nx.iota({num_samples, num_classes}, axis: 1)) {{final_coef, final_bias}, _} = while {{coef, bias}, - {x, iterations, y, coef_optimizer_state, bias_optimizer_state, + {x, iterations, y_one_hot, coef_optimizer_state, bias_optimizer_state, has_converged = Nx.u8(0), iter = 0}}, iter < iterations and not has_converged do - {loss, {coef_grad, bias_grad}} = loss_and_grad(coef, bias, x, y) + {loss, {coef_grad, bias_grad}} = loss_and_grad(coef, bias, x, y_one_hot) {coef_updates, coef_optimizer_state} = optimizer_update_fn.(coef_grad, coef_optimizer_state, coef) @@ -168,7 +173,7 @@ defmodule Scholar.Linear.LogisticRegression do has_converged = Nx.sum(Nx.abs(loss)) < Nx.size(x) * opts[:eps] {{coef, bias}, - {x, iterations, y, coef_optimizer_state, bias_optimizer_state, has_converged, iter + 1}} + {x, iterations, y_one_hot, coef_optimizer_state, bias_optimizer_state, has_converged, iter + 1}} end %__MODULE__{ diff --git a/lib/scholar/metrics/classification.ex b/lib/scholar/metrics/classification.ex index c7ccb8fa..4031f988 100644 --- a/lib/scholar/metrics/classification.ex +++ b/lib/scholar/metrics/classification.ex @@ -1307,6 +1307,7 @@ defmodule Scholar.Metrics.Classification do raise ArgumentError, "y_true and y_prob must have the same size along axis 0" end + num_samples = Nx.size(y_true) num_classes = opts[:num_classes] if Nx.axis_size(y_prob, 1) != num_classes do @@ -1321,8 +1322,10 @@ defmodule Scholar.Metrics.Classification do ) y_true_onehot = - ordinal_encode(y_true, num_classes: num_classes) - |> one_hot_encode(num_classes: num_classes) + y_true + |> Nx.new_axis(1) + |> Nx.broadcast({num_samples, num_classes}) + |> Nx.equal(Nx.iota({num_samples, num_classes}, axis: 1)) y_prob = Nx.clip(y_prob, 0, 1) diff --git a/lib/scholar/naive_bayes/complement.ex b/lib/scholar/naive_bayes/complement.ex index 4ad69dec..a9f7e759 100644 --- a/lib/scholar/naive_bayes/complement.ex +++ b/lib/scholar/naive_bayes/complement.ex @@ -11,7 +11,7 @@ defmodule Scholar.NaiveBayes.Complement do Reference: - * [1] - [Paper about Complement Naive Bayes Algorithm](https://cdn.aaai.org/ICML/2003/ICML03-081.pdf) + * [1] [Tackling the Poor Assumptions of Naive Bayes Text 
Classifiers](https://cdn.aaai.org/ICML/2003/ICML03-081.pdf) """ import Nx.Defn import Scholar.Shared @@ -93,8 +93,8 @@ defmodule Scholar.NaiveBayes.Complement do @opts_schema NimbleOptions.new!(opts_schema) @doc """ - The multinomial Naive Bayes classifier is suitable for classification with - discrete features (e.g., word counts for text classification) + Fits a complement naive Bayes classifier. The function assumes that the targets `y` are integers + between 0 and `num_classes` - 1 (inclusive). ## Options @@ -340,35 +340,49 @@ defmodule Scholar.NaiveBayes.Complement do classes_encoded = Nx.iota({num_classes}) - classes = + y_one_hot = y - |> Scholar.Preprocessing.ordinal_encode(num_classes: num_classes) - |> Scholar.Preprocessing.one_hot_encode(num_classes: num_classes) + |> Nx.new_axis(1) + |> Nx.broadcast({num_samples, num_classes}) + |> Nx.equal(Nx.iota({num_samples, num_classes}, axis: 1)) + |> Nx.as_type(x_type) - {_, classes_features} = classes_shape = Nx.shape(classes) - - classes = - cond do - classes_features == 1 and num_classes == 2 -> - Nx.concatenate([1 - classes, classes], axis: 1) - - classes_features == 1 and num_classes != 2 -> - Nx.broadcast(1.0, classes_shape) - - true -> - classes - end - - classes = + y_weighted = if opts[:sample_weights_flag], - do: classes * Nx.reshape(sample_weights, {:auto, 1}), - else: classes - - {_, n_classes} = Nx.shape(classes) - class_count = Nx.broadcast(Nx.tensor(0.0, type: x_type), {n_classes}) - feature_count = Nx.broadcast(Nx.tensor(0.0, type: x_type), {n_classes, num_features}) - feature_count = feature_count + Nx.dot(classes, [0], x, [0]) - class_count = class_count + Nx.sum(classes, axes: [0]) + do: Nx.reshape(sample_weights, {num_samples, 1}) * y_one_hot, + else: y_one_hot + + # classes = + # y + # |> Scholar.Preprocessing.ordinal_encode(num_classes: num_classes) + # |> Scholar.Preprocessing.one_hot_encode(num_classes: num_classes) + + # {_, classes_features} = classes_shape = Nx.shape(classes) + + # classes = + # cond do + # classes_features == 1 and num_classes == 2 -> + # Nx.concatenate([1 - classes, classes], axis: 1) + + # classes_features == 1 and num_classes != 2 -> + # Nx.broadcast(1.0, classes_shape) + + # true -> + # classes + # end + + # classes = + # if opts[:sample_weights_flag], + # do: classes * Nx.reshape(sample_weights, {:auto, 1}), + # else: classes + + # {_, n_classes} = Nx.shape(classes) + # class_count = Nx.broadcast(Nx.tensor(0.0, type: x_type), {n_classes}) + # class_count = class_count + Nx.sum(classes, axes: [0]) + class_count = Nx.sum(y_weighted, axes: [0]) + # feature_count = Nx.broadcast(Nx.tensor(0.0, type: x_type), {n_classes, num_features}) + # feature_count = feature_count + Nx.dot(classes, [0], x, [0]) + feature_count = Nx.dot(y_weighted, [0], x, [0]) feature_all = Nx.sum(feature_count, axes: [0]) alpha = check_alpha(alpha, opts[:force_alpha], num_features) complement_count = feature_all + alpha - feature_count diff --git a/lib/scholar/naive_bayes/multinomial.ex b/lib/scholar/naive_bayes/multinomial.ex index de2bd1bc..f3c81151 100644 --- a/lib/scholar/naive_bayes/multinomial.ex +++ b/lib/scholar/naive_bayes/multinomial.ex @@ -72,7 +72,7 @@ defmodule Scholar.NaiveBayes.Multinomial do @opts_schema NimbleOptions.new!(opts_schema) @doc """ - Fits a naive Bayes model. The function assumes that targets `y` are integers + Fits a naive Bayes model. The function assumes that the targets `y` are integers between 0 and `num_classes` - 1 (inclusive). 
Otherwise, those samples will not contribute to `class_count`. @@ -239,6 +239,7 @@ defmodule Scholar.NaiveBayes.Multinomial do y_one_hot = y |> Nx.new_axis(1) + |> Nx.broadcast({num_samples, num_classes}) |> Nx.equal(Nx.iota({num_samples, num_classes}, axis: 1)) |> Nx.as_type(type) diff --git a/lib/scholar/neighbors/large_vis.ex b/lib/scholar/neighbors/large_vis.ex index 076a9c9c..1fc15b99 100644 --- a/lib/scholar/neighbors/large_vis.ex +++ b/lib/scholar/neighbors/large_vis.ex @@ -45,7 +45,7 @@ defmodule Scholar.Neighbors.LargeVis do ], num_iters: [ type: :non_neg_integer, - default: 1, + default: 3, doc: "The number of times to perform neighborhood expansion." ], key: [ diff --git a/lib/scholar/preprocessing.ex b/lib/scholar/preprocessing.ex index 79f04bbf..ccdfb56a 100644 --- a/lib/scholar/preprocessing.ex +++ b/lib/scholar/preprocessing.ex @@ -145,14 +145,14 @@ defmodule Scholar.Preprocessing do ## Examples - iex> Scholar.Preprocessing.ordinal_encode(Nx.tensor([3, 2, 4, 56, 2, 4, 2]), num_classes: 4) + iex> Scholar.Preprocessing.ordinal_encode(Nx.tensor([3, 2, 4, 56, 2, 4, 2])) #Nx.Tensor< - s64[7] + u64[7] [1, 0, 2, 3, 0, 2, 0] > """ - defn ordinal_encode(tensor, opts \\ []) do - Scholar.Preprocessing.OrdinalEncoder.fit_transform(tensor, opts) + defn ordinal_encode(tensor) do + Scholar.Preprocessing.OrdinalEncoder.fit_transform(tensor) end @doc """ @@ -161,7 +161,7 @@ defmodule Scholar.Preprocessing do ## Examples - iex> Scholar.Preprocessing.one_hot_encode(Nx.tensor([2, 0, 3, 2, 1, 1, 0]), num_classes: 4) + iex> Scholar.Preprocessing.one_hot_encode(Nx.tensor([2, 0, 3, 2, 1, 1, 0]), num_categories: 4) #Nx.Tensor< u8[7][4] [ @@ -175,7 +175,7 @@ defmodule Scholar.Preprocessing do ] > """ - defn one_hot_encode(tensor, opts \\ []) do + defn one_hot_encode(tensor, opts) do Scholar.Preprocessing.OneHotEncoder.fit_transform(tensor, opts) end diff --git a/lib/scholar/preprocessing/one_hot_encoder.ex b/lib/scholar/preprocessing/one_hot_encoder.ex index 79a56ac4..8bf33457 100644 --- a/lib/scholar/preprocessing/one_hot_encoder.ex +++ b/lib/scholar/preprocessing/one_hot_encoder.ex @@ -7,15 +7,15 @@ defmodule Scholar.Preprocessing.OneHotEncoder do """ import Nx.Defn - @derive {Nx.Container, containers: [:encoder, :one_hot]} - defstruct [:encoder, :one_hot] + @derive {Nx.Container, containers: [:ordinal_encoder]} + defstruct [:ordinal_encoder] encode_schema = [ - num_classes: [ + num_categories: [ required: true, type: :pos_integer, doc: """ - Number of classes to be encoded. + The number of categories to be encoded. 
""" ] ] @@ -31,37 +31,32 @@ defmodule Scholar.Preprocessing.OneHotEncoder do ## Examples - iex> t = Nx.tensor([3, 2, 4, 56, 2, 4, 2]) - iex> Scholar.Preprocessing.OneHotEncoder.fit(t, num_classes: 4) + iex> tensor = Nx.tensor([3, 2, 4, 56, 2, 4, 2]) + iex> Scholar.Preprocessing.OneHotEncoder.fit(tensor, num_categories: 4) %Scholar.Preprocessing.OneHotEncoder{ - encoder: %Scholar.Preprocessing.OrdinalEncoder{ - encoding_tensor: Nx.tensor( - [ - [0, 2], - [1, 3], - [2, 4], - [3, 56] - ] + ordinal_encoder: %Scholar.Preprocessing.OrdinalEncoder{ + categories: Nx.tensor([2, 3, 4, 56] ) - }, - one_hot: Nx.tensor( - [ - [1, 0, 0, 0], - [0, 1, 0, 0], - [0, 0, 1, 0], - [0, 0, 0, 1] - ], type: :u8 - ) + } } """ - deftransform fit(tensor, opts \\ []) do - fit_n(tensor, NimbleOptions.validate!(opts, @encode_schema)) + deftransform fit(tensor, opts) do + if Nx.rank(tensor) != 1 do + raise ArgumentError, + """ + expected input tensor to have shape {num_samples}, \ + got tensor with shape: #{inspect(Nx.shape(tensor))} + """ + end + + opts = NimbleOptions.validate!(opts, @encode_schema) + + fit_n(tensor, opts) end defnp fit_n(tensor, opts) do - encoder = Scholar.Preprocessing.OrdinalEncoder.fit(tensor, opts) - one_hot = Nx.iota({opts[:num_classes]}) == Nx.iota({opts[:num_classes], 1}) - %__MODULE__{encoder: encoder, one_hot: one_hot} + ordinal_encoder = Scholar.Preprocessing.OrdinalEncoder.fit(tensor, opts) + %__MODULE__{ordinal_encoder: ordinal_encoder} end @doc """ @@ -70,9 +65,9 @@ defmodule Scholar.Preprocessing.OneHotEncoder do ## Examples - iex> t = Nx.tensor([3, 2, 4, 56, 2, 4, 2]) - iex> encoder = Scholar.Preprocessing.OneHotEncoder.fit(t, num_classes: 4) - iex> Scholar.Preprocessing.OneHotEncoder.transform(encoder, t) + iex> tensor = Nx.tensor([3, 2, 4, 56, 2, 4, 2]) + iex> encoder = Scholar.Preprocessing.OneHotEncoder.fit(tensor, num_categories: 4) + iex> Scholar.Preprocessing.OneHotEncoder.transform(encoder, tensor) #Nx.Tensor< u8[7][4] [ @@ -86,8 +81,8 @@ defmodule Scholar.Preprocessing.OneHotEncoder do ] > - iex> t = Nx.tensor([3, 2, 4, 56, 2, 4, 2]) - iex> encoder = Scholar.Preprocessing.OneHotEncoder.fit(t, num_classes: 4) + iex> tensor = Nx.tensor([3, 2, 4, 56, 2, 4, 2]) + iex> encoder = Scholar.Preprocessing.OneHotEncoder.fit(tensor, num_categories: 4) iex> new_tensor = Nx.tensor([2, 3, 4, 3, 4, 56, 2]) iex> Scholar.Preprocessing.OneHotEncoder.transform(encoder, new_tensor) #Nx.Tensor< @@ -103,9 +98,15 @@ defmodule Scholar.Preprocessing.OneHotEncoder do ] > """ - defn transform(%__MODULE__{encoder: encoder, one_hot: one_hot}, tensor) do - decoded = Scholar.Preprocessing.OrdinalEncoder.transform(encoder, tensor) - Nx.take(one_hot, decoded) + defn transform(%__MODULE__{ordinal_encoder: ordinal_encoder}, tensor) do + num_categories = Nx.size(ordinal_encoder.categories) + num_samples = Nx.size(tensor) + encoded = + ordinal_encoder + |> Scholar.Preprocessing.OrdinalEncoder.transform(tensor) + |> Nx.new_axis(1) + |> Nx.broadcast({num_samples, num_categories}) + encoded == Nx.iota({num_samples, num_categories}, axis: 1) end @doc """ @@ -113,8 +114,8 @@ defmodule Scholar.Preprocessing.OneHotEncoder do ## Examples - iex> t = Nx.tensor([3, 2, 4, 56, 2, 4, 2]) - iex> Scholar.Preprocessing.OneHotEncoder.fit_transform(t, num_classes: 4) + iex> tensor = Nx.tensor([3, 2, 4, 56, 2, 4, 2]) + iex> Scholar.Preprocessing.OneHotEncoder.fit_transform(tensor, num_categories: 4) #Nx.Tensor< u8[7][4] [ @@ -128,9 +129,14 @@ defmodule Scholar.Preprocessing.OneHotEncoder do ] > """ - defn fit_transform(tensor, opts 
\\ []) do
-    tensor
-    |> fit(opts)
-    |> transform(tensor)
+  defn fit_transform(tensor, opts) do
+    num_samples = Nx.size(tensor)
+    num_categories = opts[:num_categories]
+    encoded =
+      tensor
+      |> Scholar.Preprocessing.OrdinalEncoder.fit_transform()
+      |> Nx.new_axis(1)
+      |> Nx.broadcast({num_samples, num_categories})
+    encoded == Nx.iota({num_samples, num_categories}, axis: 1)
   end
 end
diff --git a/lib/scholar/preprocessing/ordinal_encoder.ex b/lib/scholar/preprocessing/ordinal_encoder.ex
index 99cf5ad3..e4053c74 100644
--- a/lib/scholar/preprocessing/ordinal_encoder.ex
+++ b/lib/scholar/preprocessing/ordinal_encoder.ex
@@ -1,22 +1,22 @@
 defmodule Scholar.Preprocessing.OrdinalEncoder do
   @moduledoc """
   Implements encoder that converts integer value (substitute of categorical data in tensors) into other integer value.
-  The values assigned starts from `0` and go up to `num_classes - 1`.They are maintained in sorted manner.
+  The values assigned start from `0` and go up to `num_categories - 1`. They are maintained in a sorted manner.
   This means that for x < y => encoded_value(x) < encoded_value(y).
 
   Currently the module supports only 1D tensors.
   """
   import Nx.Defn
 
-  @derive {Nx.Container, containers: [:encoding_tensor]}
-  defstruct [:encoding_tensor]
+  @derive {Nx.Container, containers: [:categories]}
+  defstruct [:categories]
 
   encode_schema = [
-    num_classes: [
+    num_categories: [
       required: true,
       type: :pos_integer,
       doc: """
-      Number of classes to be encoded.
+      The number of categories to be encoded.
       """
     ]
   ]
@@ -32,62 +32,67 @@ defmodule Scholar.Preprocessing.OrdinalEncoder do
 
   ## Examples
 
-      iex> t = Nx.tensor([3, 2, 4, 56, 2, 4, 2])
-      iex> Scholar.Preprocessing.OrdinalEncoder.fit(t, num_classes: 4)
+      iex> tensor = Nx.tensor([3, 2, 4, 56, 2, 4, 2])
+      iex> Scholar.Preprocessing.OrdinalEncoder.fit(tensor, num_categories: 4)
       %Scholar.Preprocessing.OrdinalEncoder{
-        encoding_tensor: Nx.tensor(
-          [
-            [0, 2],
-            [1, 3],
-            [2, 4],
-            [3, 56]
-          ]
-        )
+        categories: Nx.tensor([2, 3, 4, 56])
       }
   """
   deftransform fit(tensor, opts \\ []) do
-    fit_n(tensor, NimbleOptions.validate!(opts, @encode_schema))
-  end
+    if Nx.rank(tensor) != 1 do
+      raise ArgumentError,
+            """
+            expected input tensor to have shape {num_samples}, \
+            got tensor with shape: #{inspect(Nx.shape(tensor))}
+            """
+    end
 
-  defnp fit_n(tensor, opts) do
-    sorted = Nx.sort(tensor)
-    num_classes = opts[:num_classes]
+    opts = NimbleOptions.validate!(opts, @encode_schema)
 
-    # A mask with a single 1 in every group of equal values
-    representative_mask =
-      Nx.concatenate([
-        sorted[0..-2//1] != sorted[1..-1//1],
-        Nx.tensor([1])
-      ])
+    fit_n(tensor, opts)
+  end
 
-    representative_indices =
-      representative_mask
-      |> Nx.argsort(direction: :desc)
-      |> Nx.slice_along_axis(0, num_classes)
-
-    representative_values = Nx.take(sorted, representative_indices) |> Nx.new_axis(-1)
-
-    encoding_tensor =
-      Nx.concatenate([Nx.iota(Nx.shape(representative_values)), representative_values], axis: 1)
+  defnp fit_n(tensor, opts) do
+    categories =
+      if Nx.size(tensor) > 1 do
+        sorted = Nx.sort(tensor)
+        num_categories = opts[:num_categories]
+
+        # A mask with a single 1 in every group of equal values
+        representative_mask =
+          Nx.concatenate([
+            sorted[0..-2//1] != sorted[1..-1//1],
+            Nx.tensor([true])
+          ])
+
+        representative_indices =
+          representative_mask
+          |> Nx.argsort(direction: :desc, stable: true)
+          |> Nx.slice_along_axis(0, num_categories)
+
+        Nx.take(sorted, representative_indices)
+      else
+        tensor
+      end
 
-    %__MODULE__{encoding_tensor: encoding_tensor}
+    
%__MODULE__{categories: categories} end @doc """ - Encodes a tensor's values into integers from range 0 to `:num_classes - 1` or -1 if the value did not occur in fitting process. + Encodes tensor elements into integers from range 0 to `:num_categories - 1` or -1 if the value did not occur in fitting process. ## Examples - iex> t = Nx.tensor([3, 2, 4, 56, 2, 4, 2]) - iex> encoder = Scholar.Preprocessing.OrdinalEncoder.fit(t, num_classes: 4) - iex> Scholar.Preprocessing.OrdinalEncoder.transform(encoder, t) + iex> tensor = Nx.tensor([3, 2, 4, 56, 2, 4, 2]) + iex> encoder = Scholar.Preprocessing.OrdinalEncoder.fit(tensor, num_categories: 4) + iex> Scholar.Preprocessing.OrdinalEncoder.transform(encoder, tensor) #Nx.Tensor< s64[7] [1, 0, 2, 3, 0, 2, 0] > - iex> t = Nx.tensor([3, 2, 4, 56, 2, 4, 2]) - iex> encoder = Scholar.Preprocessing.OrdinalEncoder.fit(t, num_classes: 4) + iex> tensor = Nx.tensor([3, 2, 4, 56, 2, 4, 2]) + iex> encoder = Scholar.Preprocessing.OrdinalEncoder.fit(tensor, num_categories: 4) iex> new_tensor = Nx.tensor([2, 3, 4, 5, 4, 56, 2]) iex> Scholar.Preprocessing.OrdinalEncoder.transform(encoder, new_tensor) #Nx.Tensor< @@ -95,16 +100,24 @@ defmodule Scholar.Preprocessing.OrdinalEncoder do [0, 1, 2, -1, 2, 3, 0] > """ - defn transform(%__MODULE__{encoding_tensor: encoding_tensor}, tensor) do - tensor_size = Nx.axis_size(encoding_tensor, 0) - decode_size = Nx.axis_size(tensor, 0) + defn transform(%__MODULE__{categories: categories}, tensor) do + if Nx.rank(tensor) != 1 do + raise ArgumentError, + """ + expected input tensor to have shape {num_samples}, \ + got tensor with shape: #{inspect(Nx.shape(tensor))} + """ + end + + num_categories = Nx.size(categories) + size = Nx.size(tensor) input_vectorized_axes = tensor.vectorized_axes tensor = - Nx.revectorize(tensor, [x: decode_size], target_shape: {:auto}) + Nx.revectorize(tensor, [x: size], target_shape: {:auto}) left = 0 - right = tensor_size - 1 + right = num_categories - 1 label = -1 [left, right, label, tensor] = @@ -116,22 +129,37 @@ defmodule Scholar.Preprocessing.OrdinalEncoder do ]) {label, _} = - while {label, {left, right, tensor, encoding_tensor}}, left <= right do + while {label, {left, right, tensor, categories}}, left <= right do curr = Nx.quotient(left + right, 2) cond do - tensor[0] > encoding_tensor[curr][1] -> - {label, {curr + 1, right, tensor, encoding_tensor}} + tensor[0] > categories[curr] -> + {label, {curr + 1, right, tensor, categories}} - tensor[0] < encoding_tensor[curr][1] -> - {label, {left, curr - 1, tensor, encoding_tensor}} + tensor[0] < categories[curr] -> + {label, {left, curr - 1, tensor, categories}} true -> - {encoding_tensor[curr][0], {1, 0, tensor, encoding_tensor}} + {curr, {1, 0, tensor, categories}} end end - Nx.revectorize(label, input_vectorized_axes, target_shape: {decode_size}) + Nx.revectorize(label, input_vectorized_axes, target_shape: {size}) + end + + @doc """ + Decodes tensor elements into original categories seen during fitting. 
+
+  ## Examples
+
+      iex> tensor = Nx.tensor([3, 2, 4, 56, 2, 4, 2])
+      iex> encoder = Scholar.Preprocessing.OrdinalEncoder.fit(tensor, num_categories: 4)
+      iex> encoded = Nx.tensor([1, 0, 2, 3, 1, 0, 2])
+      iex> Scholar.Preprocessing.OrdinalEncoder.inverse_transform(encoder, encoded)
+      Nx.tensor([3, 2, 4, 56, 3, 2, 4])
+  """
+  deftransform inverse_transform(%__MODULE__{categories: categories}, encoded_tensor) do
+    Nx.take(categories, encoded_tensor)
   end
 
   @doc """
@@ -139,16 +167,49 @@ defmodule Scholar.Preprocessing.OrdinalEncoder do
 
   ## Examples
 
-      iex> t = Nx.tensor([3, 2, 4, 56, 2, 4, 2])
-      iex> Scholar.Preprocessing.OrdinalEncoder.fit_transform(t, num_classes: 4)
+      iex> tensor = Nx.tensor([3, 2, 4, 56, 2, 4, 2])
+      iex> Scholar.Preprocessing.OrdinalEncoder.fit_transform(tensor)
       #Nx.Tensor<
-        s64[7]
+        u64[7]
         [1, 0, 2, 3, 0, 2, 0]
       >
   """
-  defn fit_transform(tensor, opts \\ []) do
-    tensor
-    |> fit(opts)
-    |> transform(tensor)
+  deftransform fit_transform(tensor) do
+    if Nx.rank(tensor) != 1 do
+      raise ArgumentError,
+            """
+            expected input tensor to have shape {num_samples}, \
+            got tensor with shape: #{inspect(Nx.shape(tensor))}
+            """
+    end
+
+    fit_transform_n(tensor)
+  end
+
+  defnp fit_transform_n(tensor) do
+    size = Nx.size(tensor)
+    indices = Nx.argsort(tensor, type: :u64)
+    sorted = Nx.take(tensor, indices)
+
+    change_indices =
+      Nx.concatenate([
+        Nx.tensor([true]),
+        sorted[0..(size - 2)] != sorted[1..(size - 1)]
+      ])
+
+    ordinal_values =
+      change_indices
+      |> Nx.as_type(:u64)
+      |> Nx.cumulative_sum()
+      |> Nx.subtract(1)
+
+    inverse =
+      Nx.indexed_put(
+        Nx.broadcast(Nx.u64(0), {size}),
+        Nx.new_axis(indices, 1),
+        Nx.iota({size}, type: :u64)
+      )
+
+    Nx.take(ordinal_values, inverse)
+  end
 end
From 32b98fde967258c55670146aa78b2e77f604fa26 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Krsto=20Prorokovi=C4=87?=
Date: Wed, 31 Jul 2024 16:32:11 +0200
Subject: [PATCH 2/6] mix format

---
 lib/scholar/linear/logistic_regression.ex    |  4 +++-
 lib/scholar/preprocessing/one_hot_encoder.ex | 22 ++++++++++++++++++--
 2 files changed, 23 insertions(+), 3 deletions(-)

diff --git a/lib/scholar/linear/logistic_regression.ex b/lib/scholar/linear/logistic_regression.ex
index 1cc42dbc..f2af2b8a 100644
--- a/lib/scholar/linear/logistic_regression.ex
+++ b/lib/scholar/linear/logistic_regression.ex
@@ -147,6 +147,7 @@ defmodule Scholar.Linear.LogisticRegression do
     iterations = opts[:iterations]
     num_classes = opts[:num_classes]
     optimizer_update_fn = opts[:optimizer_update_fn]
+
     y_one_hot =
       y
       |> Nx.new_axis(1)
@@ -173,7 +174,8 @@ defmodule Scholar.Linear.LogisticRegression do
       has_converged = Nx.sum(Nx.abs(loss)) < Nx.size(x) * opts[:eps]
 
       {{coef, bias},
-       {x, iterations, y_one_hot, coef_optimizer_state, bias_optimizer_state, has_converged, iter + 1}}
+       {x, iterations, y_one_hot, coef_optimizer_state, bias_optimizer_state, has_converged,
+        iter + 1}}
     end
 
     %__MODULE__{
diff --git a/lib/scholar/preprocessing/one_hot_encoder.ex b/lib/scholar/preprocessing/one_hot_encoder.ex
index 8bf33457..e11ee694 100644
--- a/lib/scholar/preprocessing/one_hot_encoder.ex
+++ b/lib/scholar/preprocessing/one_hot_encoder.ex
@@ -101,16 +101,19 @@ defmodule Scholar.Preprocessing.OneHotEncoder do
   defn transform(%__MODULE__{ordinal_encoder: ordinal_encoder}, tensor) do
     num_categories = Nx.size(ordinal_encoder.categories)
     num_samples = Nx.size(tensor)
+
     encoded =
       ordinal_encoder
       |> Scholar.Preprocessing.OrdinalEncoder.transform(tensor)
      |> Nx.new_axis(1)
       |> Nx.broadcast({num_samples, num_categories})
+
     encoded == Nx.iota({num_samples, num_categories}, axis: 1)
   end
 
   @doc """
   Apply encoding on the provided tensor directly. It's equivalent to `fit/2` and then `transform/2` on the same data.
 
   ## Examples
 
       iex> tensor = Nx.tensor([3, 2, 4, 56, 2, 4, 2])
       iex> Scholar.Preprocessing.OneHotEncoder.fit_transform(tensor, num_categories: 4)
       #Nx.Tensor<
         u8[7][4]
         [
           [0, 1, 0, 0],
           [1, 0, 0, 0],
           [0, 0, 1, 0],
           [0, 0, 0, 1],
           [1, 0, 0, 0],
           [0, 0, 1, 0],
           [1, 0, 0, 0]
         ]
       >
   """
-  defn fit_transform(tensor, opts) do
+  deftransform fit_transform(tensor, opts) do
+    if Nx.rank(tensor) != 1 do
+      raise ArgumentError,
+            """
+            expected input tensor to have shape {num_samples}, \
+            got tensor with shape: #{inspect(Nx.shape(tensor))}
+            """
+    end
+
+    opts = NimbleOptions.validate!(opts, @encode_schema)
+    fit_transform_n(tensor, opts)
+  end
+
+  defnp fit_transform_n(tensor, opts) do
     num_samples = Nx.size(tensor)
     num_categories = opts[:num_categories]
+
     encoded =
       tensor
       |> Scholar.Preprocessing.OrdinalEncoder.fit_transform()
       |> Nx.new_axis(1)
       |> Nx.broadcast({num_samples, num_categories})
+
     encoded == Nx.iota({num_samples, num_categories}, axis: 1)
   end
 end
From b024bd13508acb35394704ff531109f9ecf9c359 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Krsto=20Prorokovi=C4=87?=
Date: Wed, 31 Jul 2024 16:34:28 +0200
Subject: [PATCH 3/6] Remove Scholar.Preprocessing import from
 Scholar.Metrics.Classification

---
 lib/scholar/metrics/classification.ex | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/lib/scholar/metrics/classification.ex b/lib/scholar/metrics/classification.ex
index 4031f988..42f857a7 100644
--- a/lib/scholar/metrics/classification.ex
+++ b/lib/scholar/metrics/classification.ex
@@ -13,7 +13,6 @@ defmodule Scholar.Metrics.Classification do
 
   import Nx.Defn, except: [assert_shape: 2, assert_shape_pattern: 2]
   import Scholar.Shared
-  import Scholar.Preprocessing
   alias Scholar.Integrate
 
   general_schema = [
@@ -1310,6 +1309,7 @@ defmodule Scholar.Metrics.Classification do
 
     num_samples = Nx.size(y_true)
     num_classes = opts[:num_classes]
+
     if Nx.axis_size(y_prob, 1) != num_classes do
       raise ArgumentError, "y_prob must have a size of num_classes along axis 1"
     end
From 93c7566d3d70a2c5a41f1d0f9d0f5fb5c109dbcc Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Krsto=20Prorokovi=C4=87?=
Date: Wed, 31 Jul 2024 16:53:06 +0200
Subject: [PATCH 4/6] mix format

---
 lib/scholar/metrics/classification.ex | 1 -
 1 file changed, 1 deletion(-)

diff --git a/lib/scholar/metrics/classification.ex b/lib/scholar/metrics/classification.ex
index 42f857a7..03c926fc 100644
--- a/lib/scholar/metrics/classification.ex
+++ b/lib/scholar/metrics/classification.ex
@@ -1309,7 +1309,6 @@ defmodule Scholar.Metrics.Classification do
 
     num_samples = Nx.size(y_true)
     num_classes = opts[:num_classes]
-
     if Nx.axis_size(y_prob, 1) != num_classes do
       raise ArgumentError, "y_prob must have a size of num_classes along axis 1"
     end
From 883ff5900129615db656ef78c1eb8cbdf3089215 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Krsto=20Prorokovi=C4=87?=
Date: Thu, 1 Aug 2024 13:23:29 +0200
Subject: [PATCH 5/6] Remove commented out code from NaiveBayes.Complement

---
 lib/scholar/naive_bayes/complement.ex | 29 ---------------------------
 1 file changed, 29 deletions(-)

diff --git a/lib/scholar/naive_bayes/complement.ex b/lib/scholar/naive_bayes/complement.ex
index a9f7e759..fcdcc633 100644
--- a/lib/scholar/naive_bayes/complement.ex
+++ b/lib/scholar/naive_bayes/complement.ex
@@ -352,36 +352,7 @@ defmodule Scholar.NaiveBayes.Complement do
       do: Nx.reshape(sample_weights, {num_samples, 1}) * y_one_hot,
       else: y_one_hot
 
-    # classes =
-    #   y
-    #   |> 
Scholar.Preprocessing.ordinal_encode(num_classes: num_classes) - # |> Scholar.Preprocessing.one_hot_encode(num_classes: num_classes) - - # {_, classes_features} = classes_shape = Nx.shape(classes) - - # classes = - # cond do - # classes_features == 1 and num_classes == 2 -> - # Nx.concatenate([1 - classes, classes], axis: 1) - - # classes_features == 1 and num_classes != 2 -> - # Nx.broadcast(1.0, classes_shape) - - # true -> - # classes - # end - - # classes = - # if opts[:sample_weights_flag], - # do: classes * Nx.reshape(sample_weights, {:auto, 1}), - # else: classes - - # {_, n_classes} = Nx.shape(classes) - # class_count = Nx.broadcast(Nx.tensor(0.0, type: x_type), {n_classes}) - # class_count = class_count + Nx.sum(classes, axes: [0]) class_count = Nx.sum(y_weighted, axes: [0]) - # feature_count = Nx.broadcast(Nx.tensor(0.0, type: x_type), {n_classes, num_features}) - # feature_count = feature_count + Nx.dot(classes, [0], x, [0]) feature_count = Nx.dot(y_weighted, [0], x, [0]) feature_all = Nx.sum(feature_count, axes: [0]) alpha = check_alpha(alpha, opts[:force_alpha], num_features) From 95375fae11cb4b6696d960d55f263e1469a36aae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Krsto=20Prorokovi=C4=87?= Date: Thu, 1 Aug 2024 18:21:06 +0200 Subject: [PATCH 6/6] update docstrings --- lib/scholar/metrics/classification.ex | 10 ++++++---- lib/scholar/naive_bayes/complement.ex | 3 ++- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/lib/scholar/metrics/classification.ex b/lib/scholar/metrics/classification.ex index 03c926fc..7a1f33c8 100644 --- a/lib/scholar/metrics/classification.ex +++ b/lib/scholar/metrics/classification.ex @@ -1262,8 +1262,10 @@ defmodule Scholar.Metrics.Classification do each class, from which the log loss is computed by averaging the negative log of the probability forecasted for the true class over a number of samples. - `y_true` should contain `num_classes` unique values, and the sum of `y_prob` - along axis 1 should be 1 to respect the law of total probability. + `y_true` should be a tensor of shape {num_samples} containing values + between 0 and num_classes - 1 (inclusive). + `y_prob` should be a tensor of shape {num_samples, num_classes} containing + predicted probability distributions over classes for each sample. ## Options @@ -1320,7 +1322,7 @@ defmodule Scholar.Metrics.Classification do type: to_float_type(y_prob) ) - y_true_onehot = + y_one_hot = y_true |> Nx.new_axis(1) |> Nx.broadcast({num_samples, num_classes}) @@ -1329,7 +1331,7 @@ defmodule Scholar.Metrics.Classification do y_prob = Nx.clip(y_prob, 0, 1) sample_loss = - Nx.multiply(y_true_onehot, y_prob) + Nx.multiply(y_one_hot, y_prob) |> Nx.sum(axes: [-1]) |> Nx.log() |> Nx.negate() diff --git a/lib/scholar/naive_bayes/complement.ex b/lib/scholar/naive_bayes/complement.ex index fcdcc633..7204b0cf 100644 --- a/lib/scholar/naive_bayes/complement.ex +++ b/lib/scholar/naive_bayes/complement.ex @@ -94,7 +94,8 @@ defmodule Scholar.NaiveBayes.Complement do @doc """ Fits a complement naive Bayes classifier. The function assumes that the targets `y` are integers - between 0 and `num_classes` - 1 (inclusive). + between 0 and `num_classes` - 1 (inclusive). Otherwise, those samples will not + contribute to `class_count`. ## Options
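
A quick end-to-end sketch of the reworked preprocessing API after this series, for reviewers who want to try it in IEx. It is illustrative only and assumes the patched modules compile as-is; the values mirror the doctests added in PATCH 1 (`:num_categories` replaces `:num_classes`, and `inverse_transform/2` is new):

    tensor = Nx.tensor([3, 2, 4, 56, 2, 4, 2])

    # Fit keeps the distinct categories sorted: [2, 3, 4, 56].
    encoder = Scholar.Preprocessing.OrdinalEncoder.fit(tensor, num_categories: 4)

    # Each element is replaced by the index of its category: 3 -> 1, 2 -> 0, ...
    encoded = Scholar.Preprocessing.OrdinalEncoder.transform(encoder, tensor)
    # => Nx.tensor([1, 0, 2, 3, 0, 2, 0])

    # New in this series: recover the original categories from ordinal labels.
    Scholar.Preprocessing.OrdinalEncoder.inverse_transform(encoder, encoded)
    # => Nx.tensor([3, 2, 4, 56, 2, 4, 2])

    # One-hot encoding builds on the same ordinal labels; row i has a single 1
    # in the column given by encoded[i].
    Scholar.Preprocessing.OneHotEncoder.fit_transform(tensor, num_categories: 4)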