Introduce encoders in separate modules (elixir-nx#225)
* Introduce encoders in separate modules

* Update preprocessing.ex

* Add module docs
msluszniak authored Dec 29, 2023
1 parent d541516 commit 9819798
Showing 5 changed files with 313 additions and 65 deletions.
76 changes: 11 additions & 65 deletions lib/scholar/preprocessing.ex
@@ -16,16 +16,6 @@ defmodule Scholar.Preprocessing do
]
]

encode_schema = [
num_classes: [
required: true,
type: :pos_integer,
doc: """
Number of classes to be encoded.
"""
]
]

normalize_schema =
general_schema ++
[
@@ -58,7 +48,6 @@ defmodule Scholar.Preprocessing do

@normalize_schema NimbleOptions.new!(normalize_schema)
@binarize_schema NimbleOptions.new!(binarize_schema)
@encode_schema NimbleOptions.new!(encode_schema)

@doc """
Standardizes the tensor by removing the mean and scaling to unit variance.
@@ -75,7 +64,7 @@ defmodule Scholar.Preprocessing do
>
"""
deftransform standard_scale(tensor, opts \\ []) do
defn standard_scale(tensor, opts \\ []) do
Scholar.Preprocessing.StandardScaler.fit_transform(tensor, opts)
end

@@ -110,7 +99,7 @@ defmodule Scholar.Preprocessing do
1.0
>
"""
deftransform max_abs_scale(tensor, opts \\ []) do
defn max_abs_scale(tensor, opts \\ []) do
Scholar.Preprocessing.MaxAbsScaler.fit_transform(tensor, opts)
end

@@ -134,7 +123,7 @@ defmodule Scholar.Preprocessing do
0.0
>
"""
deftransform min_max_scale(tensor, opts \\ []) do
defn min_max_scale(tensor, opts \\ []) do
Scholar.Preprocessing.MinMaxScaler.fit_transform(tensor, opts)
end

@@ -176,11 +165,8 @@ defmodule Scholar.Preprocessing do
end

@doc """
Encodes a tensor's values into integers in the range 0 to `:num_classes - 1`.
## Options
#{NimbleOptions.docs(@encode_schema)}
It is a shortcut for `Scholar.Preprocessing.OrdinalEncoder.fit_transform/2`.
See `Scholar.Preprocessing.OrdinalEncoder` for more information.
## Examples
@@ -190,42 +176,13 @@ defmodule Scholar.Preprocessing do
[1, 0, 2, 3, 0, 2, 0]
>
"""
deftransform ordinal_encode(tensor, opts \\ []) do
ordinal_encode_n(tensor, NimbleOptions.validate!(opts, @encode_schema))
end

defnp ordinal_encode_n(tensor, opts) do
sorted = Nx.sort(tensor)
num_classes = opts[:num_classes]

# A mask with a single 1 in every group of equal values
representative_mask =
Nx.concatenate([
sorted[0..-2//1] != sorted[1..-1//1],
Nx.tensor([1])
])

representative_indices =
representative_mask
|> Nx.argsort(direction: :desc)
|> Nx.slice_along_axis(0, num_classes)

representative_values = Nx.take(sorted, representative_indices)

(Nx.new_axis(tensor, 1) ==
Nx.new_axis(representative_values, 0))
|> Nx.argmax(axis: 1)
defn ordinal_encode(tensor, opts \\ []) do
Scholar.Preprocessing.OrdinalEncoder.fit_transform(tensor, opts)
end

@doc """
Encode labels as a one-hot numeric tensor.
Labels must be integers from 0 to `:num_classes - 1`. If the data does
not meet the condition, please use `ordinal_encode/2` first.
## Options
#{NimbleOptions.docs(@encode_schema)}
It is a shortcut for `Scholar.Preprocessing.OneHotEncoder.fit_transform/2`.
See `Scholar.Preprocessing.OneHotEncoder` for more information.
## Examples
@@ -243,19 +200,8 @@ defmodule Scholar.Preprocessing do
]
>
"""
deftransform one_hot_encode(tensor, opts \\ []) do
one_hot_encode_n(tensor, NimbleOptions.validate!(opts, @encode_schema))
end

defnp one_hot_encode_n(tensor, opts) do
{len} = Nx.shape(tensor)

if opts[:num_classes] > len do
raise ArgumentError,
"expected :num_classes to be at most as length of label vector"
end

Nx.new_axis(tensor, -1) == Nx.iota({1, opts[:num_classes]})
defn one_hot_encode(tensor, opts \\ []) do
Scholar.Preprocessing.OneHotEncoder.fit_transform(tensor, opts)
end

@doc """
…
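
After this change, the encoding functions in `Scholar.Preprocessing` are thin shortcuts into the dedicated modules. A minimal sketch of the resulting call paths, reusing the values from the doctests above:

    t = Nx.tensor([3, 2, 4, 56, 2, 4, 2])

    # The shortcut kept in Scholar.Preprocessing...
    Scholar.Preprocessing.ordinal_encode(t, num_classes: 4)
    #=> Nx.tensor([1, 0, 2, 3, 0, 2, 0])

    # ...now delegates to the dedicated module, so both calls are equivalent:
    Scholar.Preprocessing.OrdinalEncoder.fit_transform(t, num_classes: 4)
    #=> Nx.tensor([1, 0, 2, 3, 0, 2, 0])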
136 changes: 136 additions & 0 deletions lib/scholar/preprocessing/one_hot_encoder.ex
@@ -0,0 +1,136 @@
defmodule Scholar.Preprocessing.OneHotEncoder do
@moduledoc """
Implements an encoder that converts integer values (stand-ins for categorical data in tensors) into one-hot vectors.
The indices of the ones follow the sorted order of the encoded values: for two values x < y, one_index(x) < one_index(y).
For example, for the values 2 < 3 < 4 < 56 the ones land at indices 0, 1, 2, and 3 respectively (see the doctests below).
Currently the module supports only 1D tensors.
"""
import Nx.Defn

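# Deriving Nx.Container lets the struct flow through defn code (e.g. transform/2
# below), with :encoder and :one_hot carried along as tensor containers.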
@derive {Nx.Container, containers: [:encoder, :one_hot]}
defstruct [:encoder, :one_hot]

encode_schema = [
num_classes: [
required: true,
type: :pos_integer,
doc: """
Number of classes to be encoded.
"""
]
]

@encode_schema NimbleOptions.new!(encode_schema)

@doc """
Creates a mapping from values to one-hot vectors.
## Options
#{NimbleOptions.docs(@encode_schema)}
## Examples
iex> t = Nx.tensor([3, 2, 4, 56, 2, 4, 2])
iex> Scholar.Preprocessing.OneHotEncoder.fit(t, num_classes: 4)
%Scholar.Preprocessing.OneHotEncoder{
encoder: %Scholar.Preprocessing.OrdinalEncoder{
encoding_tensor: Nx.tensor(
[
[0, 2],
[1, 3],
[2, 4],
[3, 56]
]
)
},
one_hot: Nx.tensor(
[
[1, 0, 0, 0],
[0, 1, 0, 0],
[0, 0, 1, 0],
[0, 0, 0, 1]
], type: :u8
)
}
"""
deftransform fit(tensor, opts \\ []) do
fit_n(tensor, NimbleOptions.validate!(opts, @encode_schema))
end

defnp fit_n(tensor, opts) do
encoder = Scholar.Preprocessing.OrdinalEncoder.fit(tensor, opts)
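# Comparing a row iota against a column iota broadcasts to the
# num_classes x num_classes identity matrix, whose rows are the one-hot vectors.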
one_hot = Nx.iota({opts[:num_classes]}) == Nx.iota({opts[:num_classes], 1})
%__MODULE__{encoder: encoder, one_hot: one_hot}
end

@doc """
Encode labels as a one-hot numeric tensor. All values passed to `transform/2` must have been
seen during `fit/2`; otherwise an error occurs.
## Examples
iex> t = Nx.tensor([3, 2, 4, 56, 2, 4, 2])
iex> encoder = Scholar.Preprocessing.OneHotEncoder.fit(t, num_classes: 4)
iex> Scholar.Preprocessing.OneHotEncoder.transform(encoder, t)
#Nx.Tensor<
u8[7][4]
[
[0, 1, 0, 0],
[1, 0, 0, 0],
[0, 0, 1, 0],
[0, 0, 0, 1],
[1, 0, 0, 0],
[0, 0, 1, 0],
[1, 0, 0, 0]
]
>
iex> t = Nx.tensor([3, 2, 4, 56, 2, 4, 2])
iex> encoder = Scholar.Preprocessing.OneHotEncoder.fit(t, num_classes: 4)
iex> new_tensor = Nx.tensor([2, 3, 4, 3, 4, 56, 2])
iex> Scholar.Preprocessing.OneHotEncoder.transform(encoder, new_tensor)
#Nx.Tensor<
u8[7][4]
[
[1, 0, 0, 0],
[0, 1, 0, 0],
[0, 0, 1, 0],
[0, 1, 0, 0],
[0, 0, 1, 0],
[0, 0, 0, 1],
[1, 0, 0, 0]
]
>
"""
defn transform(%__MODULE__{encoder: encoder, one_hot: one_hot}, tensor) do
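# First map raw values to ordinal indices, then pick the matching rows
# of the identity matrix built in fit/2 to produce the one-hot encoding.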
decoded = Scholar.Preprocessing.OrdinalEncoder.transform(encoder, tensor)
Nx.take(one_hot, decoded)
end

@doc """
Applies encoding to the provided tensor directly. It's equivalent to calling `fit/2` and then `transform/2` on the same data.
## Examples
iex> t = Nx.tensor([3, 2, 4, 56, 2, 4, 2])
iex> Scholar.Preprocessing.OneHotEncoder.fit_transform(t, num_classes: 4)
#Nx.Tensor<
u8[7][4]
[
[0, 1, 0, 0],
[1, 0, 0, 0],
[0, 0, 1, 0],
[0, 0, 0, 1],
[1, 0, 0, 0],
[0, 0, 1, 0],
[1, 0, 0, 0]
]
>
"""
defn fit_transform(tensor, opts \\ []) do
tensor
|> fit(opts)
|> transform(tensor)
end
end
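
For reference, a minimal sketch (reusing values from the doctests above) of how the two new modules relate: taking `Nx.argmax` over the last axis of the one-hot output recovers the ordinal encoding of the same tensor.

    t = Nx.tensor([3, 2, 4, 56, 2, 4, 2])

    one_hot = Scholar.Preprocessing.OneHotEncoder.fit_transform(t, num_classes: 4)
    ordinal = Scholar.Preprocessing.OrdinalEncoder.fit_transform(t, num_classes: 4)

    # The position of the 1 in each row is exactly the ordinal label.
    Nx.argmax(one_hot, axis: 1)
    #=> Nx.tensor([1, 0, 2, 3, 0, 2, 0]), equal to `ordinal`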
(3 more changed files not shown)
