diff --git a/lib/scholar/decomposition/pca.ex b/lib/scholar/decomposition/pca.ex index 97d6d620..41cbcdc9 100644 --- a/lib/scholar/decomposition/pca.ex +++ b/lib/scholar/decomposition/pca.ex @@ -2,53 +2,51 @@ defmodule Scholar.Decomposition.PCA do @moduledoc """ Principal Component Analysis (PCA). - The main concept of PCA is to find components (i.e. columns of a matrix) which explain the most variance - of data set [1]. The sample data is decomposed using linear combination of - vectors that lie on the directions of those components. + PCA is a method for reducing the dimensionality of the data by transforming the original features + into a new set of uncorrelated features called principal components, which capture the maximum + variance in the data. + It can be trained on the entirety of the data at once using `fit/2` or + incrementally for datasets that are too large to fit in memory using `incremental_fit/2`. The time complexity is $O(NP^2 + P^3)$ where $N$ is the number of samples and $P$ is the number of features. Space complexity is $O(P * (P+N))$. - Reference: - * [1] - [Principal Component Analysis](https://en.wikipedia.org/wiki/Principal_component_analysis) + References: + + * [1] Dimensionality Reduction with Principal Component Analysis. [Mathematics for Machine Learning](https://mml-book.github.io/book/mml-book.pdf), Chapter 10 + * [2] [Incremental Learning for Robust Visual Tracking](https://www.cs.toronto.edu/~dross/ivt/RossLimLinYang_ijcv.pdf) """ import Nx.Defn @derive {Nx.Container, - keep: [:num_components], + keep: [:whiten?], containers: [ :components, - :explained_variance, - :explained_variance_ratio, :singular_values, + :num_samples_seen, :mean, - :num_features, - :num_samples + :variance, + :explained_variance, + :explained_variance_ratio ]} defstruct [ :components, - :explained_variance, - :explained_variance_ratio, :singular_values, + :num_samples_seen, :mean, - :num_components, - :num_features, - :num_samples + :variance, + :explained_variance, + :explained_variance_ratio, + :whiten? ] - fit_opts_schema = [ + opts = [ num_components: [ - type: {:or, [:pos_integer, {:in, [nil]}]}, - default: nil, - doc: ~S""" - Number of components to keep. If `:num_components` is not set, all components are kept - which is the minimum value from number of features and number of samples. - """ - ] - ] - - transform_opts_schema = [ - whiten: [ + required: true, + type: :pos_integer, + doc: "The number of principal components to keep." + ], + whiten?: [ type: :boolean, default: false, doc: """ @@ -57,23 +55,18 @@ defmodule Scholar.Decomposition.PCA do Whitening will remove some information from the transformed signal (the relative variance scales of the components) but can sometimes improve the predictive accuracy of the downstream estimators by making their data respect some hard-wired assumptions. - """ ] ] - fit_transform_opts_schema = fit_opts_schema ++ transform_opts_schema - - @fit_opts_schema NimbleOptions.new!(fit_opts_schema) - @transform_opts_schema NimbleOptions.new!(transform_opts_schema) - @fit_transform_opts_schema NimbleOptions.new!(fit_transform_opts_schema) + @opts_schema NimbleOptions.new!(opts) @doc """ Fits a PCA for sample inputs `x`. ## Options - #{NimbleOptions.docs(@fit_opts_schema)} + #{NimbleOptions.docs(@opts_schema)} ## Return Values @@ -81,163 +74,363 @@ defmodule Scholar.Decomposition.PCA do * `:components` - Principal axes in feature space, representing the directions of maximum variance in the data.
Equivalently, the right singular vectors of the centered input data, parallel to its eigenvectors. - The components are sorted by `:explained_variance`. - - * `:explained_variance` - The amount of variance explained by each of the selected components. - The variance estimation uses `:num_samples - 1` degrees of freedom. - Equal to `:num_components` largest eigenvalues of the covariance matrix of `x`. - - * `:explained_variance_ratio` - Percentage of variance explained by each of the selected components. - If `:num_components` is not set then all components are stored and the sum of the ratios is equal to 1.0. + The components are sorted by decreasing `:explained_variance`. * `:singular_values` - The singular values corresponding to each of the selected components. The singular values are equal to the 2-norms of the `:num_components` variables in the lower-dimensional space. + * `:num_samples_seen` - Number of samples in the training data. + * `:mean` - Per-feature empirical mean, estimated from the training set. - * `:num_components` - It equals the parameter `:num_components`, or the lesser - value of `:num_features` and `:num_samples` if the parameter `:num_components` is `nil`. + * `:variance` - Per-feature empirical variance. - * `:num_features` - Number of features in the training data. + * `:explained_variance` - The amount of variance explained by each of the selected components. + The variance estimation uses `:num_samples_seen - 1` degrees of freedom. + Equal to the `:num_components` largest eigenvalues of the covariance matrix of `x`. - * `:num_samples` - Number of samples in the training data. + * `:explained_variance_ratio` - Percentage of variance explained by each of the selected components. + + * `:whiten?` - Whether to apply whitening. ## Examples - iex> x = Nx.tensor([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]]) - iex> Scholar.Decomposition.PCA.fit(x) - %Scholar.Decomposition.PCA{ - components: Nx.tensor( - [ - [-0.838727593421936, -0.5445511937141418], - [0.5445511937141418, -0.838727593421936] - ] - ), - explained_variance: Nx.tensor( - [7.939542293548584, 0.06045711785554886] - ), - explained_variance_ratio: Nx.tensor( - [0.9924428462982178, 0.007557140197604895] - ), - singular_values: Nx.tensor( - [6.300611972808838, 0.5498050451278687] - ), - mean: Nx.tensor( - [0.0, 0.0] - ), - num_components: 2, - num_features: Nx.tensor( - 2 - ), - num_samples: Nx.tensor( - 6 - ) - } + + iex> x = Scidata.Iris.download() |> elem(0) |> Nx.tensor() + iex> pca = Scholar.Decomposition.PCA.fit(x, num_components: 2) + iex> pca.components + Nx.tensor( + [ + [0.36182016134262085, -0.08202514797449112, 0.8565111756324768, 0.3588128685951233], + [0.6585038900375366, 0.7275884747505188, -0.17632202804088593, -0.07679986208677292] + ] + ) + iex> pca.singular_values + Nx.tensor([25.089859008789062, 6.007821559906006]) """ deftransform fit(x, opts \\ []) do - fit_n(x, NimbleOptions.validate!(opts, @fit_opts_schema)) - end + opts = NimbleOptions.validate!(opts, @opts_schema) - # TODO Add support for :num_components as a float when dynamic shapes will be implemented - defnp fit_n(x, opts) do if Nx.rank(x) != 2 do - raise ArgumentError, "expected x to have rank equal to: 2, got: #{inspect(Nx.rank(x))}" + raise ArgumentError, + """ + expected input tensor to have shape {num_samples, num_features}, \ + got tensor with shape: #{inspect(Nx.shape(x))}\ + """ end {num_samples, num_features} = Nx.shape(x) num_components = opts[:num_components] - mean = Nx.mean(x, axes: [0]) - x = x - mean - {decomposer,
singular_values, components} = Nx.LinAlg.svd(x, full_matrices?: false) - - num_components = - calculate_num_components( - num_components, - num_features, - num_samples - ) + cond do + num_components > num_samples -> + raise ArgumentError, + """ + num_components must be less than or equal to \ + num_samples = #{num_samples}, got #{num_components} + """ + + num_components > num_features -> + raise ArgumentError, + """ + num_components must be less than or equal to \ + num_features = #{num_features}, got #{num_components} + """ + + true -> + nil + end + + fit_n(x, opts) + end - {_, components} = flip_svd(decomposer, components) - components = components[[0..(num_components - 1), ..]] + defnp fit_n(x, opts) do + num_samples = Nx.axis_size(x, 0) |> Nx.u64() + num_components = opts[:num_components] - explained_variance = singular_values * singular_values / (num_samples - 1) + mean = Nx.mean(x, axes: [0]) + x_centered = x - mean + variance = Nx.sum(x_centered * x_centered / (num_samples - 1), axes: [0]) + {u, s, vt} = Nx.LinAlg.svd(x_centered, full_matrices?: false) + {_, vt} = Scholar.Decomposition.Utils.flip_svd(u, vt) + components = vt[0..(num_components - 1)] + explained_variance = s * s / (num_samples - 1) explained_variance_ratio = - (explained_variance / Nx.sum(explained_variance))[[0..(num_components - 1)]] + (explained_variance / Nx.sum(explained_variance))[0..(num_components - 1)] %__MODULE__{ components: components, - explained_variance: explained_variance[[0..(num_components - 1)]], - explained_variance_ratio: explained_variance_ratio, - singular_values: singular_values[[0..(num_components - 1)]], + singular_values: s[0..(num_components - 1)], + num_samples_seen: num_samples, mean: mean, - num_components: num_components, - num_features: num_features, - num_samples: num_samples + variance: variance, + explained_variance: explained_variance[0..(num_components - 1)], + explained_variance_ratio: explained_variance_ratio, + whiten?: opts[:whiten?] } end @doc """ - For a fitted `model` performs a decomposition. + Fits a PCA model on a stream of batches, folding each batch into the model as sketched below. ## Options - #{NimbleOptions.docs(@transform_opts_schema)} + #{NimbleOptions.docs(@opts_schema)} + + ## Return Values + + The function returns a struct with the following parameters: + + * `:components` - Principal axes in feature space, representing the directions of maximum variance in the data. + Equivalently, the right singular vectors of the centered input data, parallel to its eigenvectors. + The components are sorted by decreasing `:explained_variance`. + + * `:singular_values` - The singular values corresponding to each of the selected components. + The singular values are equal to the 2-norms of the `:num_components` variables in the lower-dimensional space. + + * `:num_samples_seen` - The number of data samples processed. + + * `:mean` - Per-feature empirical mean. + + * `:variance` - Per-feature empirical variance. + + * `:explained_variance` - Variance explained by each of the selected components. + + * `:explained_variance_ratio` - Percentage of variance explained by each of the selected components. + + * `:whiten?` - Whether to apply whitening.
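+
+ Each batch update follows the incremental SVD scheme of [2]: the running mean
+ and variance are folded in first, and the components are then refreshed by
+ decomposing a small stacked matrix instead of all the data seen so far. A rough
+ sketch of the per-batch update (taken from `partial_fit/2` below, with
+ validation and shape bookkeeping omitted):
+
+     matrix =
+       Nx.concatenate(
+         [
+           Nx.new_axis(singular_values, 1) * components,
+           x_centered,
+           mean_correction
+         ],
+         axis: 0
+       )
+
+     {u, s, vt} = Nx.LinAlg.svd(matrix, full_matrices?: false)
+     {_, vt} = Scholar.Decomposition.Utils.flip_svd(u, vt)
+     components = vt[0..(num_components - 1)]
+     singular_values = s[0..(num_components - 1)]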
+ + ## Examples + + iex> {x, _} = Scidata.Iris.download() + iex> batches = x |> Nx.tensor() |> Nx.to_batched(10) + iex> pca = Scholar.Decomposition.PCA.incremental_fit(batches, num_components: 2) + iex> pca.components + Nx.tensor( + [ + [-0.33354005217552185, 0.1048964187502861, -0.8618107080105579, -0.3674643635749817], + [-0.5862125754356384, -0.7916879057884216, 0.15874788165092468, -0.06621300429105759] + ] + ) + iex> pca.singular_values + Nx.tensor([77.05782028025969, 10.137848854064941]) + """ + deftransform incremental_fit(batches, opts) do + opts = NimbleOptions.validate!(opts, @opts_schema) + + Enum.reduce( + batches, + nil, + fn batch, model -> fit_batch(model, batch, opts) end + ) + end + + defp fit_batch(nil, batch, opts), do: fit(batch, opts) + defp fit_batch(%__MODULE__{} = model, batch, _opts), do: partial_fit(model, batch) + + @doc """ + Updates the parameters of a PCA model on samples `x`. + + ## Examples + + iex> {x, _} = Scidata.Iris.download() + iex> {first_batch, second_batch} = x |> Nx.tensor() |> Nx.split(75) + iex> pca = Scholar.Decomposition.PCA.fit(first_batch, num_components: 2) + iex> pca = Scholar.Decomposition.PCA.partial_fit(pca, second_batch) + iex> pca.components + Nx.tensor( + [ + [-0.3229745328426361, 0.09587063640356064, -0.8628664612770081, -0.37677285075187683], + [-0.6786625981330872, -0.7167785167694092, 0.14237160980701447, 0.07332050055265427] + ] + ) + iex> pca.singular_values + Nx.tensor([166.141845703125, 6.078948020935059]) + """ + deftransform partial_fit(model, x) do + if Nx.rank(x) != 2 do + raise ArgumentError, + """ + expected input tensor to have shape {num_samples, num_features}, \ + got tensor with shape: #{inspect(Nx.shape(x))}\ + """ + end + + {num_components, num_features_seen} = Nx.shape(model.components) + {num_samples, num_features} = Nx.shape(x) + + cond do + num_features_seen != num_features -> + raise ArgumentError, + """ + each batch must have the same number of features, \ + got #{num_features_seen} and #{num_features} + """ + + num_components > num_samples -> + raise ArgumentError, + """ + num_components must be less than or equal to \ + num_samples = #{num_samples}, got #{num_components} + """ + + true -> + nil + end + + partial_fit_n(model, x) + end + + defnp partial_fit_n(model, x) do + components = model.components + num_components = Nx.axis_size(components, 0) + singular_values = model.singular_values + num_samples_seen = model.num_samples_seen + mean = model.mean + variance = model.variance + {num_samples, _} = Nx.shape(x) + + {x_mean, x_centered, updated_num_samples_seen, updated_mean, updated_variance} = + incremental_mean_and_variance(x, num_samples_seen, mean, variance) + + mean_correction = + Nx.sqrt(num_samples_seen / updated_num_samples_seen) * num_samples * (mean - x_mean) + + mean_correction = Nx.new_axis(mean_correction, 0) + + matrix = + Nx.concatenate( + [ + Nx.new_axis(singular_values, 1) * components, + x_centered, + mean_correction + ], + axis: 0 + ) + + {u, s, vt} = Nx.LinAlg.svd(matrix, full_matrices?: false) + {_, vt} = Scholar.Decomposition.Utils.flip_svd(u, vt) + updated_components = vt[0..(num_components - 1)] + updated_singular_values = s[0..(num_components - 1)] + + updated_explained_variance = + updated_singular_values * updated_singular_values / (updated_num_samples_seen - 1) + + updated_explained_variance_ratio = + updated_singular_values * updated_singular_values / + Nx.sum(updated_variance * updated_num_samples_seen) + + %__MODULE__{ + components: updated_components, + singular_values: updated_singular_values, +
num_samples_seen: updated_num_samples_seen, + mean: updated_mean, + variance: updated_variance, + explained_variance: updated_explained_variance, + explained_variance_ratio: updated_explained_variance_ratio, + whiten?: model.whiten? + } + end + + defnp incremental_mean_and_variance(x, num_samples_seen, mean, variance) do + new_num_samples = Nx.axis_size(x, 0) + updated_num_samples_seen = num_samples_seen + new_num_samples + sum = num_samples_seen * mean + new_sum = Nx.sum(x, axes: [0]) + updated_mean = (sum + new_sum) / updated_num_samples_seen + new_mean = new_sum / new_num_samples + x_centered = x - new_mean + correction = Nx.sum(x_centered, axes: [0]) + + new_unnormalized_variance = + Nx.sum(x_centered * x_centered, axes: [0]) - correction * correction / new_num_samples + + unnormalized_variance = num_samples_seen * variance + seen_over_new = num_samples_seen / new_num_samples + + updated_unnormalized_variance = + unnormalized_variance + + new_unnormalized_variance + + seen_over_new / updated_num_samples_seen * + (sum / seen_over_new - new_sum) ** 2 + + updated_variance = updated_unnormalized_variance / updated_num_samples_seen + {new_mean, x_centered, updated_num_samples_seen, updated_mean, updated_variance} + end + + @doc """ + Performs a decomposition of samples `x` using a fitted `model`. ## Return Values The function returns a tensor with decomposed data. ## Examples - iex> x = Nx.tensor([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]]) - iex> model = Scholar.Decomposition.PCA.fit(x) - iex> Scholar.Decomposition.PCA.transform(model, x) + iex> x_fit = Scidata.Iris.download() |> elem(0) |> Nx.tensor() + iex> pca = Scholar.Decomposition.PCA.fit(x_fit, num_components: 2) + iex> x_transform = Nx.tensor([[5.2, 2.6, 2.475, 0.7], [6.1, 3.2, 3.95, 1.3], [7.0, 3.8, 5.425, 1.9]]) + iex> Scholar.Decomposition.PCA.transform(pca, x_transform) Nx.tensor( [ - [1.3832788467407227, 0.2941763997077942], - [2.222006320953369, -0.25037479400634766], - [3.605285167694092, 0.04380160570144653], - [-1.3832788467407227, -0.2941763997077942], - [-2.222006320953369, 0.25037479400634766], - [-3.605285167694092, -0.04380160570144653] + [-1.4739344120025635, -0.48932668566703796], + [0.28113049268722534, 0.2337251454591751], + [2.0361955165863037, 0.9567767977714539] ] ) """ - deftransform transform(model, x, opts \\ []) do - transform_n(model, x, NimbleOptions.validate!(opts, @transform_opts_schema)) + deftransform transform(model, x) do + if Nx.rank(x) != 2 do + raise ArgumentError, + """ + expected input tensor to have shape {num_samples, num_features}, \ + got tensor with shape: #{inspect(Nx.shape(x))}\ + """ + end + + num_features_seen = Nx.axis_size(model.components, 1) + num_features = Nx.axis_size(x, 1) + + if num_features_seen != num_features do + raise ArgumentError, + """ + expected input tensor to have the same number of features \ + as tensor used to fit the model, \ + got #{inspect(num_features)} \ + and #{inspect(num_features_seen)} + """ + end + + transform_n(model, x) end defnp transform_n( %__MODULE__{ components: components, explained_variance: explained_variance, - mean: mean + mean: mean, + whiten?: whiten? } = _model, - x, - opts + x ) do - whiten? = opts[:whiten] + x_centered = x - mean - x = x - mean - - x_transformed = Nx.dot(x, [1], components, [1]) + z = Nx.dot(x_centered, [1], components, [1]) if whiten?
do - x_transformed / Nx.sqrt(explained_variance) + z / Nx.sqrt(explained_variance) else - x_transformed + z end end @doc """ Fit the model with `x` and apply the dimensionality reduction on `x`. - This function is analogous to calling `fit/2` and then - `transform/3`, but it is calculated more efficiently. - - ## Options + This function is equivalent to calling `fit/2` and then + `transform/2`, but the result is computed more efficiently. + + ## Options - #{NimbleOptions.docs(@transform_opts_schema)} + #{NimbleOptions.docs(@opts_schema)} ## Return Values @@ -245,80 +438,44 @@ defmodule Scholar.Decomposition.PCA do ## Examples - iex> x = Nx.tensor([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]]) - iex> Scholar.Decomposition.PCA.fit_transform(x) + iex> x = Scidata.Iris.download() |> elem(0) |> Enum.take(6) |> Nx.tensor() + iex> Scholar.Decomposition.PCA.fit_transform(x, num_components: 2) Nx.tensor( [ - [1.3819537162780762, 0.2936314642429352], - [2.2231407165527344, -0.25125157833099365], - [3.6050944328308105, 0.04237968474626541], - [-1.3819535970687866, -0.29363128542900085], - [-2.2231407165527344, 0.2512516379356384], - [-3.6050944328308105, -0.04237968474626541] + [0.16441848874092102, 0.028548287227749825], + [-0.32804328203201294, 0.20709986984729767], + [-0.3284338414669037, -0.08318747580051422], + [-0.42237386107444763, -0.0735677033662796], + [0.17480169236660004, -0.11189625412225723], + [0.7396301627159119, 0.03300142288208008] ] ) """ - deftransform fit_transform(x, opts \\ []) do - fit_transform_n(x, NimbleOptions.validate!(opts, @fit_transform_opts_schema)) + deftransform fit_transform(x, opts) do + fit_transform_n(x, NimbleOptions.validate!(opts, @opts_schema)) end defnp fit_transform_n(x, opts) do - if Nx.rank(x) != 2 do - raise ArgumentError, "expected x to have rank equal to: 2, got: #{inspect(Nx.rank(x))}" - end - - {num_samples, num_features} = Nx.shape(x) num_components = opts[:num_components] - x = x - Nx.mean(x, axes: [0]) - {decomposer, singular_values, components} = Nx.LinAlg.svd(x, full_matrices?: false) - - num_components = - calculate_num_components( - num_components, - num_features, - num_samples - ) - - {decomposer, _components} = flip_svd(decomposer, components) - decomposer = decomposer[[.., 0..(num_components - 1)]] + mean = Nx.mean(x, axes: [0]) + x_centered = x - mean + {u, s, vt} = Nx.LinAlg.svd(x_centered, full_matrices?: false) + {u, _} = flip_svd(u, vt) + u = u[[.., 0..(num_components - 1)]] - if opts[:whiten] do - decomposer * Nx.sqrt(num_samples - 1) + if opts[:whiten?]
do + u * Nx.sqrt(Nx.axis_size(x, 0) - 1) else - decomposer * singular_values[[0..(num_components - 1)]] + u * s[0..(num_components - 1)] end end defnp flip_svd(u, v) do - # columns of u, rows of v max_abs_cols_idx = u |> Nx.abs() |> Nx.argmax(axis: 0, keep_axis: true) signs = u |> Nx.take_along_axis(max_abs_cols_idx, axis: 0) |> Nx.sign() |> Nx.squeeze() u = u * signs v = v * Nx.new_axis(signs, -1) {u, v} end - - deftransformp calculate_num_components( - num_components, - num_features, - num_samples - ) do - default_num_components = min(num_features, num_samples) - - cond do - num_components == nil -> - default_num_components - - num_components > 0 and num_components <= min(num_features, num_samples) and - is_integer(num_components) -> - num_components - - is_integer(num_components) -> - raise ArgumentError, - "expected :num_components to be integer in range 1 to #{inspect(min(num_samples, num_features))}, got: #{inspect(num_components)}" - - true -> - raise ArgumentError, "unexpected type of :num_components, got: #{inspect(num_components)}" - end - end end diff --git a/lib/scholar/manifold/tsne.ex b/lib/scholar/manifold/tsne.ex index 26df9a74..3a8884f1 100644 --- a/lib/scholar/manifold/tsne.ex +++ b/lib/scholar/manifold/tsne.ex @@ -100,7 +100,7 @@ defmodule Scholar.Manifold.TSNE do ## Examples - iex> x = Nx.iota({4,5}) + iex> x = Nx.iota({4, 5}) iex> key = Nx.Random.key(42) iex> Scholar.Manifold.TSNE.fit(x, num_components: 2, key: key) #Nx.Tensor< diff --git a/lib/scholar/neighbors/brute_knn.ex b/lib/scholar/neighbors/brute_knn.ex index 131b1d6d..707c8d67 100644 --- a/lib/scholar/neighbors/brute_knn.ex +++ b/lib/scholar/neighbors/brute_knn.ex @@ -217,7 +217,14 @@ defmodule Scholar.Neighbors.BruteKNN do defn get_batches(tensor, opts) do {size, dim} = Nx.shape(tensor) - batch_size = opts[:batch_size] + batch_size = min(size, opts[:batch_size]) + + min_batch_size = + case opts[:min_batch_size] do + nil -> 0 + b -> b + end + num_batches = div(size, batch_size) leftover_size = rem(size, batch_size) @@ -227,7 +234,7 @@ defmodule Scholar.Neighbors.BruteKNN do |> Nx.reshape({num_batches, batch_size, dim}) leftover = - if leftover_size > 0 do + if leftover_size > min_batch_size do Nx.slice_along_axis(tensor, num_batches * batch_size, leftover_size, axis: 0) else nil diff --git a/mix.exs b/mix.exs index b1d53755..439a86cd 100644 --- a/mix.exs +++ b/mix.exs @@ -33,7 +33,8 @@ defmodule Scholar.MixProject do {:nimble_options, "~> 0.5.2 or ~> 1.0"}, {:exla, ">= 0.0.0", only: :test}, {:polaris, "~> 0.1"}, - {:benchee, "~> 1.0", only: :dev} + {:benchee, "~> 1.0", only: :dev}, + {:scidata, "~> 0.1.11", only: :test} ] end diff --git a/mix.lock b/mix.lock index a16434fd..39712046 100644 --- a/mix.lock +++ b/mix.lock @@ -1,19 +1,23 @@ %{ - "benchee": {:hex, :benchee, "1.1.0", "f3a43817209a92a1fade36ef36b86e1052627fd8934a8b937ac9ab3a76c43062", [:mix], [{:deep_merge, "~> 1.0", [hex: :deep_merge, repo: "hexpm", optional: false]}, {:statistex, "~> 1.0", [hex: :statistex, repo: "hexpm", optional: false]}], "hexpm", "7da57d545003165a012b587077f6ba90b89210fd88074ce3c60ce239eb5e6d93"}, + "benchee": {:hex, :benchee, "1.3.1", "c786e6a76321121a44229dde3988fc772bca73ea75170a73fd5f4ddf1af95ccf", [:mix], [{:deep_merge, "~> 1.0", [hex: :deep_merge, repo: "hexpm", optional: false]}, {:statistex, "~> 1.0", [hex: :statistex, repo: "hexpm", optional: false]}, {:table, "~> 0.1.0", [hex: :table, repo: "hexpm", optional: true]}], "hexpm", "76224c58ea1d0391c8309a8ecbfe27d71062878f59bd41a390266bf4ac1cc56d"}, + "castore": {:hex, 
:castore, "0.1.22", "4127549e411bedd012ca3a308dede574f43819fe9394254ca55ab4895abfa1a2", [:mix], [], "hexpm", "c17576df47eb5aa1ee40cc4134316a99f5cad3e215d5c77b8dd3cfef12a22cac"}, "complex": {:hex, :complex, "0.5.0", "af2d2331ff6170b61bb738695e481b27a66780e18763e066ee2cd863d0b1dd92", [:mix], [], "hexpm", "2683bd3c184466cfb94fad74cbfddfaa94b860e27ad4ca1bffe3bff169d91ef1"}, "deep_merge": {:hex, :deep_merge, "1.0.0", "b4aa1a0d1acac393bdf38b2291af38cb1d4a52806cf7a4906f718e1feb5ee961", [:mix], [], "hexpm", "ce708e5f094b9cd4e8f2be4f00d2f4250c4095be93f8cd6d018c753894885430"}, - "earmark_parser": {:hex, :earmark_parser, "1.4.39", "424642f8335b05bb9eb611aa1564c148a8ee35c9c8a8bba6e129d51a3e3c6769", [:mix], [], "hexpm", "06553a88d1f1846da9ef066b87b57c6f605552cfbe40d20bd8d59cc6bde41944"}, - "elixir_make": {:hex, :elixir_make, "0.7.8", "505026f266552ee5aabca0b9f9c229cbb496c689537c9f922f3eb5431157efc7", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: true]}, {:certifi, "~> 2.0", [hex: :certifi, repo: "hexpm", optional: true]}], "hexpm", "7a71945b913d37ea89b06966e1342c85cfe549b15e6d6d081e8081c493062c07"}, - "ex_doc": {:hex, :ex_doc, "0.34.0", "ab95e0775db3df71d30cf8d78728dd9261c355c81382bcd4cefdc74610bef13e", [:mix], [{:earmark_parser, "~> 1.4.39", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_c, ">= 0.1.0", [hex: :makeup_c, repo: "hexpm", optional: true]}, {:makeup_elixir, "~> 0.14 or ~> 1.0", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1 or ~> 1.0", [hex: :makeup_erlang, repo: "hexpm", optional: false]}, {:makeup_html, ">= 0.1.0", [hex: :makeup_html, repo: "hexpm", optional: true]}], "hexpm", "60734fb4c1353f270c3286df4a0d51e65a2c1d9fba66af3940847cc65a8066d7"}, - "exla": {:hex, :exla, "0.7.1", "790493288cf4441abed98df0c4e98da15a2e3a7fa27cd2a1f74ec0693952c579", [:make, :mix], [{:elixir_make, "~> 0.6", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:nimble_pool, "~> 1.0", [hex: :nimble_pool, repo: "hexpm", optional: false]}, {:nx, "~> 0.7.1", [hex: :nx, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}, {:xla, "~> 0.6.0", [hex: :xla, repo: "hexpm", optional: false]}], "hexpm", "ec9c1698a9a17b859d79f9b3c1d75c370335580cdd0353db9c2017f86155e2ec"}, + "earmark_parser": {:hex, :earmark_parser, "1.4.41", "ab34711c9dc6212dda44fcd20ecb87ac3f3fce6f0ca2f28d4a00e4154f8cd599", [:mix], [], "hexpm", "a81a04c7e34b6617c2792e291b5a2e57ab316365c2644ddc553bb9ed863ebefa"}, + "elixir_make": {:hex, :elixir_make, "0.8.4", "4960a03ce79081dee8fe119d80ad372c4e7badb84c493cc75983f9d3bc8bde0f", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: true]}, {:certifi, "~> 2.0", [hex: :certifi, repo: "hexpm", optional: true]}], "hexpm", "6e7f1d619b5f61dfabd0a20aa268e575572b542ac31723293a4c1a567d5ef040"}, + "ex_doc": {:hex, :ex_doc, "0.34.2", "13eedf3844ccdce25cfd837b99bea9ad92c4e511233199440488d217c92571e8", [:mix], [{:earmark_parser, "~> 1.4.39", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_c, ">= 0.1.0", [hex: :makeup_c, repo: "hexpm", optional: true]}, {:makeup_elixir, "~> 0.14 or ~> 1.0", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1 or ~> 1.0", [hex: :makeup_erlang, repo: "hexpm", optional: false]}, {:makeup_html, ">= 0.1.0", [hex: :makeup_html, repo: "hexpm", optional: true]}], "hexpm", "5ce5f16b41208a50106afed3de6a2ed34f4acfd65715b82a0b84b49d995f95c1"}, + "exla": {:hex, :exla, "0.7.3", 
"51310270a0976974fc758f7b28ebd6ca8e099b3d6fc78b0d484c808e977cb914", [:make, :mix], [{:elixir_make, "~> 0.6", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:nimble_pool, "~> 1.0", [hex: :nimble_pool, repo: "hexpm", optional: false]}, {:nx, "~> 0.7.1", [hex: :nx, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}, {:xla, "~> 0.6.0", [hex: :xla, repo: "hexpm", optional: false]}], "hexpm", "5b3d5741a24aada21d3b0feb4b99d1fc3c8457f995a63ea16684d8d5678b96ff"}, + "jason": {:hex, :jason, "1.4.4", "b9226785a9aa77b6857ca22832cffa5d5011a667207eb2a0ad56adb5db443b8a", [:mix], [{:decimal, "~> 1.0 or ~> 2.0", [hex: :decimal, repo: "hexpm", optional: true]}], "hexpm", "c5eb0cab91f094599f94d55bc63409236a8ec69a21a67814529e8d5f6cc90b3b"}, "makeup": {:hex, :makeup, "1.1.2", "9ba8837913bdf757787e71c1581c21f9d2455f4dd04cfca785c70bbfff1a76a3", [:mix], [{:nimble_parsec, "~> 1.2.2 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "cce1566b81fbcbd21eca8ffe808f33b221f9eee2cbc7a1706fc3da9ff18e6cac"}, "makeup_elixir": {:hex, :makeup_elixir, "0.16.2", "627e84b8e8bf22e60a2579dad15067c755531fea049ae26ef1020cad58fe9578", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "41193978704763f6bbe6cc2758b84909e62984c7752b3784bd3c218bb341706b"}, - "makeup_erlang": {:hex, :makeup_erlang, "1.0.0", "6f0eff9c9c489f26b69b61440bf1b238d95badae49adac77973cbacae87e3c2e", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "ea7a9307de9d1548d2a72d299058d1fd2339e3d398560a0e46c27dab4891e4d2"}, - "nimble_options": {:hex, :nimble_options, "0.5.2", "42703307b924880f8c08d97719da7472673391905f528259915782bb346e0a1b", [:mix], [], "hexpm", "4da7f904b915fd71db549bcdc25f8d56f378ef7ae07dc1d372cbe72ba950dce0"}, + "makeup_erlang": {:hex, :makeup_erlang, "1.0.1", "c7f58c120b2b5aa5fd80d540a89fdf866ed42f1f3994e4fe189abebeab610839", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "8a89a1eeccc2d798d6ea15496a6e4870b75e014d1af514b1b71fa33134f57814"}, + "nimble_csv": {:hex, :nimble_csv, "1.2.0", "4e26385d260c61eba9d4412c71cea34421f296d5353f914afe3f2e71cce97722", [:mix], [], "hexpm", "d0628117fcc2148178b034044c55359b26966c6eaa8e2ce15777be3bbc91b12a"}, + "nimble_options": {:hex, :nimble_options, "1.1.1", "e3a492d54d85fc3fd7c5baf411d9d2852922f66e69476317787a7b2bb000a61b", [:mix], [], "hexpm", "821b2470ca9442c4b6984882fe9bb0389371b8ddec4d45a9504f00a66f650b44"}, "nimble_parsec": {:hex, :nimble_parsec, "1.4.0", "51f9b613ea62cfa97b25ccc2c1b4216e81df970acd8e16e8d1bdc58fef21370d", [:mix], [], "hexpm", "9c565862810fb383e9838c1dd2d7d2c437b3d13b267414ba6af33e50d2d1cf28"}, - "nimble_pool": {:hex, :nimble_pool, "1.0.0", "5eb82705d138f4dd4423f69ceb19ac667b3b492ae570c9f5c900bb3d2f50a847", [:mix], [], "hexpm", "80be3b882d2d351882256087078e1b1952a28bf98d0a287be87e4a24a710b67a"}, - "nx": {:hex, :nx, "0.7.1", "5f6376e3d18408116e8a84b8f4ac851fb07dfe61764a5410ebf0b5dcb69c1b7e", [:mix], [{:complex, "~> 0.5", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "e3ddd6a3f2a9bac79c67b3933368c25bb5ec814a883fc68aba8fd8a236751777"}, + "nimble_pool": {:hex, :nimble_pool, "1.1.0", "bf9c29fbdcba3564a8b800d1eeb5a3c58f36e1e11d7b7fb2e084a643f645f06b", [:mix], [], "hexpm", 
"af2e4e6b34197db81f7aad230c1118eac993acc0dae6bc83bac0126d4ae0813a"}, + "nx": {:hex, :nx, "0.7.3", "51ff45d9f9ff58b616f4221fa54ccddda98f30319bb8caaf86695234a469017a", [:mix], [{:complex, "~> 0.5", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "5ff29af84f08db9bda66b8ef7ce92ab583ab4f983629fe00b479f1e5c7c705a6"}, "polaris": {:hex, :polaris, "0.1.0", "dca61b18e3e801ecdae6ac9f0eca5f19792b44a5cb4b8d63db50fc40fc038d22", [:mix], [{:nx, "~> 0.5", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "13ef2b166650e533cb24b10e2f3b8ab4f2f449ba4d63156e8c569527f206e2c2"}, + "scidata": {:hex, :scidata, "0.1.11", "fe3358bac7d740374b4f2a7eff6a1cb02e5ee7f87f7cdb1e8648ad93c533165f", [:mix], [{:castore, "~> 0.1", [hex: :castore, repo: "hexpm", optional: false]}, {:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: false]}, {:nimble_csv, "~> 1.1", [hex: :nimble_csv, repo: "hexpm", optional: false]}, {:stb_image, "~> 0.4", [hex: :stb_image, repo: "hexpm", optional: true]}], "hexpm", "90873337a9d5fe880d640517efa93d3c07e46c8ba436de44117f581800549f93"}, "statistex": {:hex, :statistex, "1.0.0", "f3dc93f3c0c6c92e5f291704cf62b99b553253d7969e9a5fa713e5481cd858a5", [:mix], [], "hexpm", "ff9d8bee7035028ab4742ff52fc80a2aa35cece833cf5319009b52f1b5a86c27"}, "telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"}, "xla": {:hex, :xla, "0.6.0", "67bb7695efa4a23b06211dc212de6a72af1ad5a9e17325e05e0a87e4c241feb8", [:make, :mix], [{:elixir_make, "~> 0.4", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "dd074daf942312c6da87c7ed61b62fb1a075bced157f1cc4d47af2d7c9f44fb7"}, diff --git a/test/scholar/decomposition/pca_test.exs b/test/scholar/decomposition/pca_test.exs index e332b971..32e5af3c 100644 --- a/test/scholar/decomposition/pca_test.exs +++ b/test/scholar/decomposition/pca_test.exs @@ -20,33 +20,29 @@ defmodule Scholar.Decomposition.PCATest do end test "fit test - all default options" do - model = PCA.fit(x()) + model = PCA.fit(x(), num_components: 1) assert_all_close( model.components, - Nx.tensor([[-0.83849224, -0.54491354], [0.54491354, -0.83849224]]), + Nx.tensor([[-0.838727593421936, -0.5445511937141418]]), atol: 1.0e-3 ) - assert_all_close(model.explained_variance, Nx.tensor([7.9395432472229, 0.060456883162260056]), + assert_all_close(model.singular_values, Nx.tensor([6.300611972808838]), atol: 1.0e-3) + + assert model.num_samples_seen == Nx.u64(6) + + assert model.mean == Nx.tensor([0.0, 0.0]) + + assert_all_close(model.variance, Nx.tensor([5.599999904632568, 2.4000000953674316]), atol: 1.0e-3 ) - assert_all_close( - model.explained_variance_ratio, - Nx.tensor([0.9924429059028625, 0.007557110395282507]) - ) + assert_all_close(model.explained_variance, Nx.tensor([7.939542293548584]), atol: 1.0e-3) - assert_all_close(model.singular_values, Nx.tensor([6.30061232, 0.54980396]), atol: 1.0e-3) - assert model.mean == Nx.tensor([0.0, 0.0]) - assert model.num_components == 2 - assert model.num_samples == Nx.tensor(6) - assert model.num_features == Nx.tensor(2) - end + assert_all_close(model.explained_variance_ratio, Nx.tensor([0.9924428462982178])) - test "fit test - :num_components is integer" do - model = PCA.fit(x(), num_components: 1) - assert model.num_components == 1 + assert not model.whiten? 
end test "fit test - :num_components is integer and wide matrix" do @@ -60,6 +56,14 @@ defmodule Scholar.Decomposition.PCATest do atol: 1.0e-3 ) + assert_all_close(model.singular_values, Nx.tensor([38.89730453491211]), atol: 1.0e-3) + + assert model.num_samples_seen == Nx.u64(2) + + assert model.mean == Nx.tensor([28.5, 2.0, 3.5]) + + assert model.variance == Nx.tensor([1512.5, 0.0, 0.5]) + assert_all_close(model.explained_variance, Nx.tensor([1513.000244140625]), atol: 1.0e-3) assert_all_close( @@ -67,42 +71,52 @@ defmodule Scholar.Decomposition.PCATest do Nx.tensor([1.0]) ) - assert_all_close(model.singular_values, Nx.tensor([38.89730453491211]), atol: 1.0e-3) - assert model.mean == Nx.tensor([28.5, 2.0, 3.5]) - assert model.num_components == 1 - assert model.num_samples == Nx.tensor(2) - assert model.num_features == Nx.tensor(3) + assert not model.whiten? end - test "transform test - :whiten set to false" do - model = PCA.fit(x()) + test "transform test - :num_components set to 1" do + model = PCA.fit(x(), num_components: 1) assert_all_close( PCA.transform(model, x()), Nx.tensor([ - [1.3834056854248047, 0.2935786843299866], - [2.221898078918457, -0.2513348460197449], - [3.6053037643432617, 0.0422438383102417], - [-1.3834056854248047, -0.2935786843299866], - [-2.221898078918457, 0.2513348460197449], - [-3.6053037643432617, -0.0422438383102417] + [1.3832788467407227], + [2.222006320953369], + [3.605285167694092], + [-1.3832788467407227], + [-2.222006320953369], + [-3.605285167694092] ]), atol: 1.0e-2 ) end - test "transform test - :whiten set to false and and num components different than min(num_samples, num_components)" do + test "transform test - :num_components set to 2" do model = PCA.fit(x3(), num_components: 2) assert_all_close( model.components, Nx.tensor([ - [0.98732591, 0.15474766, -0.03522361], - [-0.14912572, 0.98053261, 0.12773922085762024] + [0.9874106645584106, 0.1541961133480072, -0.035266418009996414], + [-0.14880891144275665, 0.9811362028121948, 0.1234002411365509] ]), atol: 1.0e-2 ) + assert_all_close(model.singular_values, Nx.tensor([20.272085189819336, 3.1355254650115967]), + atol: 1.0e-2 + ) + + assert model.num_samples_seen == Nx.u64(6) + + assert_all_close(model.mean, Nx.tensor([3.83333333, 0.33333333, 1.66666667]), atol: 1.0e-2) + + assert_all_close( + model.variance, + Nx.tensor([80.16666412353516, 3.866666793823242, 0.6666666865348816]), + atol: 1.0e-2 + ) + assert_all_close( model.explained_variance, Nx.tensor([82.19153594970703, 1.966333031654358]), @@ -115,38 +129,29 @@ defmodule Scholar.Decomposition.PCATest do atol: 1.0e-2 ) - assert_all_close(model.singular_values, Nx.tensor([20.272090911865234, 3.13554849]), - atol: 1.0e-2 - ) - - assert_all_close(model.mean, Nx.tensor([3.83333333, 0.33333333, 1.66666667]), atol: 1.0e-2) - assert model.num_components == 2 - assert model.num_samples == Nx.tensor(6) - assert model.num_features == Nx.tensor(3) - assert_all_close( PCA.transform(model, x3()), Nx.tensor([ - [-5.02537027, -0.41628357768058777], - [-5.977472305297852, -0.39489707350730896], - [-7.084321975708008, -1.3540430068969727], - [-0.6961240172386169, 0.6928002834320068], - [17.23049002, -1.010930061340332], - [1.5527995824813843, 2.48335338] + [-5.025101184844971, -0.42440494894981384], + [-5.977245330810547, -0.3989962935447693], + [-7.083585739135742, -1.3547236919403076], + [-0.6965337991714478, 0.6958313584327698], + [17.231054306030273, -1.001592755317688], + [1.5514134168624878, 2.483886241912842] ]), atol: 1.0e-2 ) end test "transform test - 
:whiten set to false and different data in fit and transform" do - model = PCA.fit(x(), num_components: 2) + model = PCA.fit(x(), num_components: 1) assert_all_close( PCA.transform(model, x2()), Nx.tensor([ - [-3.018146276473999, -2.8090553283691406], - [-48.54806137084961, 24.394376754760742], - [-25.615192413330078, 8.298306465148926] + [-3.016932487487793], + [-48.558597564697266], + [-25.618776321411133] ]), atol: 1.0e-1, rtol: 1.0e-3 @@ -154,33 +159,33 @@ defmodule Scholar.Decomposition.PCATest do end test "transform test - :whiten set to true" do - model = PCA.fit(x()) + model = PCA.fit(x(), num_components: 1, whiten?: true) assert_all_close( - PCA.transform(model, x(), whiten: true), + PCA.transform(model, x()), Nx.tensor([ - [0.49096643924713135, 1.1939926147460938], - [0.7885448336601257, -1.0221858024597168], - [1.2795112133026123, 0.17180685698986053], - [-0.49096643924713135, -1.1939926147460938], - [-0.7885448336601257, 1.0221858024597168], - [-1.2795112133026123, -0.17180685698986053] + [0.4909214377403259], + [0.7885832190513611], + [1.279504656791687], + [-0.4909214377403259], + [-0.7885832190513611], + [-1.279504656791687] ]), atol: 1.0e-2 ) end - test "fit_transform test - :whiten set to false" do - model = PCA.fit(x()) + test "fit_transform test - :whiten? set to false and num_components set to 1" do + model = PCA.fit(x(), num_components: 1) assert_all_close( PCA.transform(model, x()), - PCA.fit_transform(x()), + PCA.fit_transform(x(), num_components: 1), atol: 1.0e-2 ) end - test "fit_transform test - :whiten set to false and and num components different than min(num_samples, num_components)" do + test "fit_transform test - :whiten? set to false and num_components set to 2" do model = PCA.fit(x3(), num_components: 2) assert_all_close( @@ -190,12 +195,12 @@ defmodule Scholar.Decomposition.PCATest do ) end - test "fit_transform test - :whiten set to true" do - model = PCA.fit(x()) + test "fit_transform test - :whiten? set to true" do + model = PCA.fit(x(), num_components: 1, whiten?: true) assert_all_close( - PCA.transform(model, x(), whiten: true), - PCA.fit_transform(x(), whiten: true), + PCA.transform(model, x()), + PCA.fit_transform(x(), num_components: 1, whiten?: true), atol: 1.0e-2 ) end @@ -203,15 +208,21 @@ defmodule Scholar.Decomposition.PCATest do describe "errors" do test "input rank different than 2" do assert_raise ArgumentError, - "expected x to have rank equal to: 2, got: 1", + """ + expected input tensor to have shape {num_samples, num_features}, \ + got tensor with shape: {4}\ + """, fn -> - PCA.fit(Nx.tensor([1, 2, 3, 4])) + PCA.fit(Nx.tensor([1, 2, 3, 4]), num_components: 1) end end - test "fit test - :num_components bigger than min(num_samples, num_features)" do + test "fit test - :num_components bigger than num_features" do assert_raise ArgumentError, - "expected :num_components to be integer in range 1 to 2, got: 4", + """ + num_components must be less than or equal to \ + num_features = 2, got 4 + """, fn -> PCA.fit(x(), num_components: 4) end @@ -219,25 +230,46 @@ defmodule Scholar.Decomposition.PCATest do test "fit test - :num_components is atom" do assert_raise NimbleOptions.ValidationError, - """ - expected :num_components option to match at least one given type, but didn't match any. 
Here are the reasons why it didn't match each of the allowed types: - - * invalid value for :num_components option: expected one of [nil], got: :two - * invalid value for :num_components option: expected positive integer, got: :two\ - """, + "invalid value for :num_components option: expected positive integer, got: :two", fn -> PCA.fit(x(), num_components: :two) end end - test "transform test - :whiten is not boolean" do + test "transform test - :whiten? is not boolean" do assert_raise NimbleOptions.ValidationError, - "invalid value for :whiten option: expected boolean, got: :yes", + "invalid value for :whiten? option: expected boolean, got: :yes", fn -> - model = PCA.fit(x()) - - PCA.transform(model, x(), whiten: :yes) + PCA.fit(x(), num_components: 1, whiten?: :yes) end end end + + test "partial_fit" do + model = PCA.fit(x()[0..2], num_components: 1) + model = PCA.partial_fit(model, x()[3..5]) + + assert Nx.shape(model.components) == {1, 2} + assert Nx.shape(model.singular_values) == {1} + assert model.num_samples_seen == Nx.u64(6) + assert model.mean == Nx.tensor([0.0, 0.0]) + assert Nx.shape(model.variance) == {2} + assert Nx.shape(model.explained_variance) == {1} + assert Nx.shape(model.explained_variance_ratio) == {1} + assert not model.whiten? + end + + test "incremental_fit" do + batches = Nx.to_batched(x(), 2) + model = PCA.incremental_fit(batches, num_components: 1) + + assert Nx.shape(model.components) == {1, 2} + assert Nx.shape(model.singular_values) == {1} + assert model.num_samples_seen == Nx.u64(6) + assert model.mean == Nx.tensor([0.0, 0.0]) + assert Nx.shape(model.variance) == {2} + assert Nx.shape(model.explained_variance) == {1} + assert Nx.shape(model.explained_variance_ratio) == {1} + assert not model.whiten? + end end
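A minimal end-to-end sketch of the API this diff adds (the Iris data comes from the doctests above; the batch size of 25 is an illustrative assumption, not part of the change):

    # Stream the training data through the model in fixed-size batches,
    # then project the samples onto the two learned components.
    x = Scidata.Iris.download() |> elem(0) |> Nx.tensor()
    batches = Nx.to_batched(x, 25)

    pca = Scholar.Decomposition.PCA.incremental_fit(batches, num_components: 2)
    x_reduced = Scholar.Decomposition.PCA.transform(pca, x)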
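For reference, the running-statistics merge implemented by `incremental_mean_and_variance` above is the standard pairwise update (notation assumed here rather than taken from the code: $m$, $\mu$, $\sigma^2$ are the count, per-feature mean, and variance already seen; $n$, $\bar{x}$, $\hat{\sigma}^2$ are the count, mean, and biased variance of the incoming batch):

$$\mu' = \frac{m\mu + n\bar{x}}{m + n}, \qquad \sigma'^2 = \frac{m\sigma^2 + n\hat{\sigma}^2 + \frac{mn}{m + n}\left(\mu - \bar{x}\right)^2}{m + n}$$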