From dacf106f63125a3b57866e19449fdf15cc7a27ca Mon Sep 17 00:00:00 2001 From: Mateusz Sluszniak <56299341+msluszniak@users.noreply.github.com> Date: Wed, 15 Nov 2023 17:21:10 +0100 Subject: [PATCH] MDS (#205) --- lib/scholar/covariance.ex | 182 --------- lib/scholar/manifold/mds.ex | 412 ++++++++++++++++++++ lib/scholar/metrics/distance.ex | 10 +- test/scholar/covariance/covariance_test.exs | 13 - test/scholar/manifold/mds_test.exs | 214 ++++++++++ test/scholar/manifold/tsne_test.exs | 140 +++---- 6 files changed, 702 insertions(+), 269 deletions(-) delete mode 100644 lib/scholar/covariance.ex create mode 100644 lib/scholar/manifold/mds.ex delete mode 100644 test/scholar/covariance/covariance_test.exs create mode 100644 test/scholar/manifold/mds_test.exs diff --git a/lib/scholar/covariance.ex b/lib/scholar/covariance.ex deleted file mode 100644 index 57fefa51..00000000 --- a/lib/scholar/covariance.ex +++ /dev/null @@ -1,182 +0,0 @@ -defmodule Scholar.Covariance do - @moduledoc ~S""" - Algorithms to estimate the covariance of features given a set of points. - - Time complexity of covariance estimation is $O(N * K^2)$ where $N$ is the number of samples - and $K$ is the number of features. - """ - import Nx.Defn - - opts = [ - center: [ - type: :boolean, - default: true, - doc: """ - If `true`, data will be centered before computation. - If `false`, data will not be centered before computation. - Useful when working with data whose mean is almost, but not exactly zero. - """ - ], - biased: [ - type: :boolean, - default: true, - doc: """ - If `true`, the matrix will be computed using biased covariation. If `false`, - algorithm uses unbiased covariation. - """ - ] - ] - - @opts_schema NimbleOptions.new!(opts) - - @deprecated "Use Nx.convariance/2 instead" - @doc """ - Computes covariance matrix for sample inputs `x`. 
- - The value on the position $Cov_{ij}$ in the $Cov$ matrix is calculated using the formula: - - #{~S''' - $$ Cov(X\_i, X\_j) = \frac{\sum\_{k}\left(x\_k - - \bar{x}\right)\left(y\_k - \bar{y}\right)}{N - 1} - $$ - Where: - * $X_i$ is a $i$th row of input - - * $x_k$ is a $k$th value of $X_i$ - - * $y_k$ is a $k$th value of $X_j$ - - * $\bar{x}$ is the mean of $X_i$ - - * $\bar{y}$ is the mean of $X_j$ - - * $N$ is the number of samples - - This is a non-biased version of covariance. - The biased version has $N$ in denominator instead of $N - 1$. - '''} - - ## Options - - #{NimbleOptions.docs(@opts_schema)} - - ## Example - - iex> Scholar.Covariance.covariance_matrix(Nx.tensor([[3, 6, 5], [26, 75, 3], [23, 4, 1]])) - #Nx.Tensor< - f32[3][3] - [ - [104.22222137451172, 195.5555419921875, -13.333333015441895], - [195.5555419921875, 1089.5555419921875, 1.3333333730697632], - [-13.333333015441895, 1.3333333730697632, 2.6666667461395264] - ] - > - - iex> Scholar.Covariance.covariance_matrix(Nx.tensor([[3, 6], [2, 3], [7, 9], [5, 3]])) - #Nx.Tensor< - f32[2][2] - [ - [3.6875, 3.1875], - [3.1875, 6.1875] - ] - > - - iex> Scholar.Covariance.covariance_matrix(Nx.tensor([[3, 6, 5], [26, 75, 3], [23, 4, 1]]), - ...> biased: false - ...> ) - #Nx.Tensor< - f32[3][3] - [ - [156.3333282470703, 293.33331298828125, -20.0], - [293.33331298828125, 1634.333251953125, 2.0], - [-20.0, 2.0, 4.0] - ] - > - """ - deftransform covariance_matrix(x, opts \\ []) do - covariance_matrix_n(x, NimbleOptions.validate!(opts, @opts_schema)) - end - - defnp covariance_matrix_n(x, opts) do - if Nx.rank(x) != 2 do - raise ArgumentError, "expected data to have rank equal 2, got: #{inspect(Nx.rank(x))}" - end - - num_samples = Nx.axis_size(x, 0) - x = if opts[:center], do: x - Nx.mean(x, axes: [0]), else: x - matrix = Nx.dot(x, [0], x, [0]) - - if opts[:biased] do - matrix / num_samples - else - matrix / (num_samples - 1) - end - end - - @deprecated "Use Scholar.Stats.correlation_matrix/2 instead" - @doc 
""" - Computes correlation matrix for sample inputs `x`. - - The value on the position $Corr_{ij}$ in the $Corr$ matrix is calculated using the formula: - #{~S''' - $$ Corr(X\_i, X\_j) = \frac{Cov(X\_i, X\_j)}{\sqrt{Cov(X\_i, X\_i)Cov(X\_j, X\_j)}} $$ - Where: - * $X_i$ is a $i$th row of input - - * $Cov(X\_i, X\_j)$ is covariance between features $X_i$ and $X_j$ - '''} - - ## Options - - #{NimbleOptions.docs(@opts_schema)} - - ## Example - - iex> Scholar.Covariance.correlation_matrix(Nx.tensor([[3, 6, 5], [26, 75, 3], [23, 4, 1]])) - #Nx.Tensor< - f32[3][3] - [ - [1.0, 0.580316960811615, -0.7997867465019226], - [0.580316960811615, 1.0, 0.024736011400818825], - [-0.7997867465019226, 0.024736011400818825, 1.0] - ] - > - - iex> Scholar.Covariance.correlation_matrix(Nx.tensor([[3, 6], [2, 3], [7, 9], [5, 3]])) - #Nx.Tensor< - f32[2][2] - [ - [1.0, 0.6673083305358887], - [0.6673083305358887, 1.0] - ] - > - - iex> Scholar.Covariance.correlation_matrix(Nx.tensor([[3, 6, 5], [26, 75, 3], [23, 4, 1]]), - ...> biased: false - ...> ) - #Nx.Tensor< - f32[3][3] - [ - [1.0, 0.5803170204162598, -0.7997867465019226], - [0.5803170204162598, 1.0, 0.024736013263463974], - [-0.7997867465019226, 0.024736013263463974, 1.0] - ] - > - """ - - deftransform correlation_matrix(x, opts \\ []) do - correlation_matrix_n(x, NimbleOptions.validate!(opts, @opts_schema)) - end - - defnp correlation_matrix_n(x, opts) do - variances = - if opts[:biased] do - Nx.variance(x, axes: [0]) - else - Nx.variance(x, axes: [0], ddof: 1) - end - - Scholar.Covariance.covariance_matrix(x, opts) / - Nx.sqrt(Nx.new_axis(variances, 1) * Nx.new_axis(variances, 0)) - end -end diff --git a/lib/scholar/manifold/mds.ex b/lib/scholar/manifold/mds.ex new file mode 100644 index 00000000..4be27464 --- /dev/null +++ b/lib/scholar/manifold/mds.ex @@ -0,0 +1,412 @@ +defmodule Scholar.Manifold.MDS do + @moduledoc """ + Multidimensional Scaling (MDS) is a nonlinear dimensionality reduction technique that places samples in a low-dimensional space so that their pairwise distances approximate those of the original data. 
+ + ## References + + * Borg, I. and Groenen, P. J. F. (2005). Modern Multidimensional Scaling: Theory and Applications (2nd ed.). Springer. + """ + import Nx.Defn + import Scholar.Shared + alias Scholar.Metrics.Distance + + @derive {Nx.Container, containers: [:embedding, :stress, :n_iter]} + defstruct [:embedding, :stress, :n_iter] + + opts_schema = [ + num_components: [ + type: :pos_integer, + default: 2, + doc: ~S""" + Dimension of the embedded space. + """ + ], + metric: [ + type: :boolean, + default: true, + doc: ~S""" + If `true`, use dissimilarities as metric distances in the embedding space. + """ + ], + normalized_stress: [ + type: :boolean, + default: false, + doc: ~S""" + If `true`, normalize the stress by the sum of squared dissimilarities. + Only valid if `metric` is `false`. + """ + ], + eps: [ + type: :float, + default: 1.0e-3, + doc: ~S""" + Tolerance for stopping criterion. + """ + ], + max_iter: [ + type: :pos_integer, + default: 300, + doc: ~S""" + Maximum number of iterations for the optimization. + """ + ], + key: [ + type: {:custom, Scholar.Options, :key, []}, + doc: """ + Determines random number generation for the initial embedding. + If the key is not provided, it is set to `Nx.Random.key(System.system_time())`. + """ + ], + n_init: [ + type: :pos_integer, + default: 8, + doc: ~S""" + Number of times the embedding will be computed with different random initializations. + The final embedding is the embedding with the lowest stress. 
+ """ + ] + ] + + @opts_schema NimbleOptions.new!(opts_schema) + + # initialize x randomly or pass the init x earlier + defnp smacof(dissimilarities, x, max_iter, opts) do + similarities_flat = Nx.flatten(dissimilarities) + similarities_flat_indices = lower_triangle_indices(dissimilarities) + + similarities_flat_w = Nx.take(similarities_flat, similarities_flat_indices) + + metric = if opts[:metric], do: 1, else: 0 + normalized_stress = if opts[:normalized_stress], do: 1, else: 0 + eps = opts[:eps] + n = Nx.axis_size(dissimilarities, 0) + + {{x, stress, i}, _} = + while {{x, _stress = Nx.Constants.infinity(Nx.type(dissimilarities)), i = 0}, + {dissimilarities, max_iter, similarities_flat_indices, similarities_flat, + similarities_flat_w, old_stress = Nx.Constants.infinity(Nx.type(dissimilarities)), + metric, normalized_stress, eps, stop_value = 0}}, + i < max_iter and not stop_value do + dis = Distance.pairwise_euclidean(x) + + disparities = + if metric do + dissimilarities + else + dis_flat = Nx.flatten(dis) + + dis_flat_indices = lower_triangle_indices(dis) + + dis_flat_w = Nx.take(dis_flat, dis_flat_indices) + + disparities_flat_model = + Scholar.Linear.IsotonicRegression.fit(similarities_flat_w, dis_flat_w, + increasing: true + ) + + model = special_preprocess(disparities_flat_model) + + disparities_flat = + Scholar.Linear.IsotonicRegression.predict(model, similarities_flat_w) + + disparities = + Nx.indexed_put( + dis_flat, + Nx.new_axis(similarities_flat_indices, -1), + disparities_flat + ) + + disparities = Nx.reshape(disparities, {n, n}) + + disparities * Nx.sqrt(n * (n - 1) / 2 / Nx.sum(disparities ** 2)) + end + + stress = Nx.sum((Nx.flatten(dis) - Nx.flatten(disparities)) ** 2) / 2 + + stress = + if normalized_stress do + Nx.sqrt(stress / (Nx.sum(Nx.flatten(disparities) ** 2) / 2)) + else + stress + end + + dis = Nx.select(dis == 0, 1.0e-5, dis) + ratio = disparities / dis + b = -ratio + b = Nx.put_diagonal(b, Nx.take_diagonal(b) + Nx.sum(ratio, axes: 
[1])) + x = 1.0 / n * Nx.dot(b, x) + + dis = Nx.sum(Nx.sqrt(Nx.sum(x ** 2, axes: [1]))) + + stop_value = if old_stress - stress / dis < eps, do: 1, else: 0 + + old_stress = stress / dis + + {{x, stress, i + 1}, + {dissimilarities, max_iter, similarities_flat_indices, similarities_flat, + similarities_flat_w, old_stress, metric, normalized_stress, eps, stop_value}} + end + + {x, stress, i} + end + + defnp mds_main_loop(dissimilarities, x, _key, opts) do + n_init = opts[:n_init] + + type = Nx.Type.merge(to_float_type(x), to_float_type(dissimilarities)) + dissimilarities = Nx.as_type(dissimilarities, type) + x = Nx.as_type(x, type) + + dissimilarities = Distance.pairwise_euclidean(dissimilarities) + + {{best, best_stress, best_iter}, _} = + while {{best = x, best_stress = Nx.Constants.infinity(type), best_iter = 0}, + {n_init, dissimilarities, x, i = 0}}, + i < n_init do + {temp, stress, iter} = smacof(dissimilarities, x, opts[:max_iter], opts) + + {best, best_stress, best_iter} = + if stress < best_stress, do: {temp, stress, iter}, else: {best, best_stress, best_iter} + + {{best, best_stress, best_iter}, {n_init, dissimilarities, x, i + 1}} + end + + {best, best_stress, best_iter} + end + + defnp mds_main_loop(dissimilarities, key, opts) do + n_init = opts[:n_init] + max_iter = opts[:max_iter] + num_samples = Nx.axis_size(dissimilarities, 0) + + type = to_float_type(dissimilarities) + dissimilarities = Nx.as_type(dissimilarities, type) + + {dummy, key} = + Nx.Random.uniform(key, + shape: {num_samples, opts[:num_components]}, + type: type + ) + + dissimilarities = Distance.pairwise_euclidean(dissimilarities) + + {{best, best_stress, best_iter}, _} = + while {{best = dummy, best_stress = Nx.Constants.infinity(type), best_iter = 0}, + {n_init, key, max_iter, dissimilarities, i = 0}}, + i < n_init do + num_samples = Nx.axis_size(dissimilarities, 0) + + {x, key} = + Nx.Random.uniform(key, shape: {num_samples, opts[:num_components]}, type: type) + + {temp, stress, iter} = 
smacof(dissimilarities, x, max_iter, opts) + + {best, best_stress, best_iter} = + if stress < best_stress, do: {temp, stress, iter}, else: {best, best_stress, best_iter} + + {{best, best_stress, best_iter}, {n_init, key, max_iter, dissimilarities, i + 1}} + end + + {best, best_stress, best_iter} + end + + defnp lower_triangle_indices(tensor) do + n = Nx.axis_size(tensor, 0) + + temp = Nx.broadcast(Nx.s64(0), {div(n * (n - 1), 2)}) + + {temp, _} = + while {temp, {i = 0, j = 0}}, i < n ** 2 do + {temp, j} = + if Nx.remainder(i, n) < Nx.quotient(i, n) do + temp = Nx.indexed_put(temp, Nx.new_axis(j, -1), i) + {temp, j + 1} + else + {temp, j} + end + + {temp, {i + 1, j}} + end + + temp + end + + @doc """ + Fits MDS for sample inputs `x` using default options and a random initial embedding. It is a simplified version of the `fit/3` function. + + ## Options + + #{NimbleOptions.docs(@opts_schema)} + + ## Return Values + + Returns struct with embedded data, stress value, and number of iterations for best run. + + ## Examples + + iex> x = Nx.iota({4,5}) + iex> key = Nx.Random.key(42) + iex> Scholar.Manifold.MDS.fit(x, key: key) + %Scholar.Manifold.MDS{ + embedding: Nx.tensor( + [ + [13.072145462036133, -10.424199104309082], + [5.13038969039917, -2.341259479522705], + [-5.651908874511719, 1.7662434577941895], + [-12.550626754760742, 10.999215126037598] + ] + ), + stress: Nx.tensor( + 0.36994707584381104 + ), + n_iter: Nx.tensor( + 20 + ) + } + """ + deftransform fit(x) do + opts = NimbleOptions.validate!([], @opts_schema) + key = Keyword.get_lazy(opts, :key, fn -> Nx.Random.key(System.system_time()) end) + fit_n(x, key, opts) + end + + @doc """ + Fits MDS for sample inputs `x` using the given options and a random initial embedding. It is a simplified version of the `fit/3` function. + + ## Options + + #{NimbleOptions.docs(@opts_schema)} + + ## Return Values + + Returns struct with embedded data, stress value, and number of iterations for best run. 
+ + ## Examples + + iex> x = Nx.iota({4,5}) + iex> key = Nx.Random.key(42) + iex> Scholar.Manifold.MDS.fit(x, num_components: 2, key: key) + %Scholar.Manifold.MDS{ + embedding: Nx.tensor( + [ + [13.072145462036133, -10.424199104309082], + [5.13038969039917, -2.341259479522705], + [-5.651908874511719, 1.7662434577941895], + [-12.550626754760742, 10.999215126037598] + ] + ), + stress: Nx.tensor( + 0.36994707584381104 + ), + n_iter: Nx.tensor( + 20 + ) + } + """ + deftransform fit(x, opts) when is_list(opts) do + opts = NimbleOptions.validate!(opts, @opts_schema) + key = Keyword.get_lazy(opts, :key, fn -> Nx.Random.key(System.system_time()) end) + fit_n(x, key, opts) + end + + defnp fit_n(x, key, opts) do + {best, best_stress, best_iter} = mds_main_loop(x, key, opts) + %__MODULE__{embedding: best, stress: best_stress, n_iter: best_iter} + end + + @doc """ + Fits MDS for sample inputs `x`, using `init` as the initial embedding and default options. It is a simplified version of the `fit/3` function. + + ## Options + + #{NimbleOptions.docs(@opts_schema)} + + ## Return Values + + Returns struct with embedded data, stress value, and number of iterations for best run. + + ## Examples + + iex> x = Nx.iota({4,5}) + iex> key = Nx.Random.key(42) + iex> init = Nx.reverse(Nx.iota({4,2})) + iex> Scholar.Manifold.MDS.fit(x, init) + %Scholar.Manifold.MDS{ + embedding: Nx.tensor( + [ + [11.858541488647461, 11.858541488647461], + [3.9528470039367676, 3.9528470039367676], + [-3.9528470039367676, -3.9528470039367676], + [-11.858541488647461, -11.858541488647461] + ] + ), + stress: Nx.tensor( + 0.0 + ), + n_iter: Nx.tensor( + 3 + ) + } + """ + deftransform fit(x, init) do + opts = NimbleOptions.validate!([], @opts_schema) + key = Keyword.get_lazy(opts, :key, fn -> Nx.Random.key(System.system_time()) end) + fit_n(x, init, key, opts) + end + + @doc """ + Fits MDS for sample inputs `x`, using `init` as the initial embedding and the given options. 
+ + ## Options + + #{NimbleOptions.docs(@opts_schema)} + + ## Return Values + + Returns struct with embedded data, stress value, and number of iterations for best run. + + ## Examples + + iex> x = Nx.iota({4,5}) + iex> key = Nx.Random.key(42) + iex> init = Nx.reverse(Nx.iota({4,3})) + iex> Scholar.Manifold.MDS.fit(x, init, num_components: 3, key: key) + %Scholar.Manifold.MDS{ + embedding: Nx.tensor( + [ + [9.682458877563477, 9.682458877563477, 9.682458877563477], + [3.2274858951568604, 3.2274858951568604, 3.2274858951568604], + [-3.2274863719940186, -3.2274863719940186, -3.2274863719940186], + [-9.682458877563477, -9.682458877563477, -9.682458877563477] + ] + ), + stress: Nx.tensor( + 9.094947017729282e-12 + ), + n_iter: Nx.tensor( + 3 + ) + } + """ + deftransform fit(x, init, opts) when is_list(opts) do + opts = NimbleOptions.validate!(opts, @opts_schema) + key = Keyword.get_lazy(opts, :key, fn -> Nx.Random.key(System.system_time()) end) + fit_n(x, init, key, opts) + end + + defnp fit_n(x, init, key, opts) do + {best, best_stress, best_iter} = mds_main_loop(x, init, key, opts) + %__MODULE__{embedding: best, stress: best_stress, n_iter: best_iter} + end + + defnp special_preprocess(model) do + %Scholar.Linear.IsotonicRegression{ + model + | preprocess: + Scholar.Interpolation.Linear.fit( + model.x_thresholds, + model.y_thresholds + ) + } + end +end diff --git a/lib/scholar/metrics/distance.ex b/lib/scholar/metrics/distance.ex index d937c36b..2d8c7292 100644 --- a/lib/scholar/metrics/distance.ex +++ b/lib/scholar/metrics/distance.ex @@ -598,7 +598,8 @@ defmodule Scholar.Metrics.Distance do """ defn pairwise_squared_euclidean(x) do x_norm = Nx.sum(x * x, axes: [1], keep_axes: true) - Nx.max(0, x_norm + Nx.transpose(x_norm) - 2 * Nx.dot(x, [-1], x, [-1])) + dist = Nx.max(0, x_norm + Nx.transpose(x_norm) - 2 * Nx.dot(x, [-1], x, [-1])) + Nx.put_diagonal(dist, Nx.broadcast(Nx.tensor(0, type: Nx.type(dist)), {Nx.axis_size(x, 0)})) end @doc """ @@ -688,8 +689,8 @@ 
defmodule Scholar.Metrics.Distance do [ [0.0, 0.0793418288230896, 0.1139642596244812, 0.13029760122299194, 0.1397092342376709, 0.14581435918807983], [0.0793418288230896, 0.0, 0.0032819509506225586, 0.006624102592468262, 0.008954286575317383, 0.01060718297958374], - [0.1139642596244812, 0.0032819509506225586, 1.1920928955078125e-7, 5.82277774810791e-4, 0.0013980269432067871, 0.0020949840545654297], - [0.13029760122299194, 0.006624102592468262, 5.82277774810791e-4, 5.960464477539063e-8, 1.7595291137695312e-4, 4.686713218688965e-4], + [0.1139642596244812, 0.0032819509506225586, 0.0, 5.82277774810791e-4, 0.0013980269432067871, 0.0020949840545654297], + [0.13029760122299194, 0.006624102592468262, 5.82277774810791e-4, 0.0, 1.7595291137695312e-4, 4.686713218688965e-4], [0.1397092342376709, 0.008954286575317383, 0.0013980269432067871, 1.7595291137695312e-4, 0.0, 7.027387619018555e-5], [0.14581435918807983, 0.01060718297958374, 0.0020949840545654297, 4.686713218688965e-4, 7.027387619018555e-5, 0.0] ] @@ -697,6 +698,7 @@ defmodule Scholar.Metrics.Distance do """ defn pairwise_cosine(x) do x_normalized = Scholar.Preprocessing.normalize(x, axes: [1]) - Nx.max(0, 1 - Nx.dot(x_normalized, [-1], x_normalized, [-1])) + dist = Nx.max(0, 1 - Nx.dot(x_normalized, [-1], x_normalized, [-1])) + Nx.put_diagonal(dist, Nx.broadcast(Nx.tensor(0, type: Nx.type(dist)), {Nx.axis_size(x, 0)})) end end diff --git a/test/scholar/covariance/covariance_test.exs b/test/scholar/covariance/covariance_test.exs deleted file mode 100644 index 340ce444..00000000 --- a/test/scholar/covariance/covariance_test.exs +++ /dev/null @@ -1,13 +0,0 @@ -defmodule Scholar.CovarianceTest do - use Scholar.Case, async: true - alias Scholar.Covariance - doctest Covariance - - describe "errors" do - test "rank of input not equal to 2" do - assert_raise ArgumentError, - "expected data to have rank equal 2, got: 3", - fn -> Covariance.covariance_matrix(Nx.tensor([[[1, 2], [3, 4]]])) end - end - end -end diff --git 
a/test/scholar/manifold/mds_test.exs b/test/scholar/manifold/mds_test.exs new file mode 100644 index 00000000..9548c42f --- /dev/null +++ b/test/scholar/manifold/mds_test.exs @@ -0,0 +1,214 @@ +defmodule Scholar.Manifold.MDSTest do + use Scholar.Case, async: true + alias Scholar.Manifold.MDS + doctest MDS + + def x() do + Nx.iota({10, 50}) + end + + def key() do + Nx.Random.key(42) + end + + test "all default" do + model = MDS.fit(x(), key: key()) + + assert_all_close( + model.embedding, + Nx.tensor([ + [-1200.2181396484375, -1042.11083984375], + [-985.0137939453125, -750.6790771484375], + [-706.14013671875, -532.040771484375], + [-402.91387939453125, -344.8670959472656], + [-163.916015625, -77.55931091308594], + [137.63134765625, 111.43733215332031], + [450.9678649902344, 284.375], + [712.8345947265625, 524.7731323242188], + [935.5824584960938, 807.938720703125], + [1221.1859130859375, 1018.7330932617188] + ]) + ) + + assert_all_close(model.stress, 390.99090576171875) + assert_all_close(model.n_iter, Nx.tensor(93)) + end + + test "non-default num_components" do + model = MDS.fit(x(), num_components: 5, key: key()) + + assert_all_close( + model.embedding, + Nx.tensor([ + [ + -753.6793823242188, + -1215.4837646484375, + -604.7247314453125, + 136.4171905517578, + 306.40069580078125 + ], + [ + -536.5817260742188, + -982.4620971679688, + -457.0782165527344, + 113.92396545410156, + 232.4468994140625 + ], + [ + -368.63018798828125, + -696.4574584960938, + -344.57952880859375, + 120.76351165771484, + 164.8445281982422 + ], + [ + -216.7689666748047, + -406.25250244140625, + -248.1519317626953, + 85.04608154296875, + 65.24085235595703 + ], + [ + -52.247528076171875, + -147.1763916015625, + -101.35457611083984, + 38.1723747253418, + -30.632429122924805 + ], + [ + 112.39735412597656, + 137.04685974121094, + 19.412824630737305, + -15.030402183532715, + -70.50560760498047 + ], + [ + 270.4787902832031, + 401.2673034667969, + 195.0669403076172, + -36.55331039428711, + 
-105.96172332763672 + ], + [ + 381.5384216308594, + 700.9933471679688, + 348.5281982421875, + -61.6308708190918, + -142.2161407470703 + ], + [ + 508.359130859375, + 980.1181030273438, + 507.714111328125, + -153.78958129882812, + -164.67311096191406 + ], + [ + 655.1340942382812, + 1228.4066162109375, + 685.1668701171875, + -227.31895446777344, + -254.94395446777344 + ] + ]) + ) + + assert_all_close(model.stress, 641.52490234375) + assert_all_close(model.n_iter, Nx.tensor(130)) + end + + test "non-default metric" do + model = MDS.fit(x(), metric: false, key: key()) + + assert_all_close( + model.embedding, + Nx.tensor([ + [0.4611709713935852, -0.2790529131889343], + [0.10750522464513779, 0.3869015574455261], + [0.10845339298248291, -0.619588315486908], + [-0.3274216949939728, -0.2036580592393875], + [0.432122141122818, 0.4288368821144104], + [-0.2664470970630646, 0.1712798774242401], + [-0.46502357721328735, 0.015750018879771233], + [0.35657963156700134, 0.028018075972795486], + [-0.11095760017633438, -0.3872125744819641], + [-0.20736312866210938, 0.41101184487342834] + ]) + ) + + assert_all_close(model.stress, 1.2879878282546997) + assert_all_close(model.n_iter, Nx.tensor(18)) + end + + test "option normalized_stress with metric set to false" do + model = + MDS.fit(x(), metric: false, key: key(), normalized_stress: true) + + assert_all_close( + model.embedding, + Nx.tensor([ + [-0.5107499957084656, -0.5828369855880737], + [-0.008806264027953148, -0.4549526870250702], + [-0.5534653663635254, -0.02513509802520275], + [0.11427811533212662, 0.17350295186042786], + [0.45669451355934143, 0.20050597190856934], + [0.010616336017847061, -0.09705149382352829], + [-0.27859434485435486, 0.3822994530200958], + [0.353694885969162, -0.17320780456066132], + [0.49716615676879883, -0.10724353790283203], + [-0.12109922617673874, 0.4835425913333893] + ]) + ) + + assert_all_close(model.stress, 0.24878354370594025) + assert_all_close(model.n_iter, Nx.tensor(8)) + end + + test "epsilon set 
to a smaller then default value" do + model = MDS.fit(x(), key: key(), eps: 1.0e-4) + + assert_all_close( + model.embedding, + Nx.tensor([ + [-1210.0882568359375, -1031.7977294921875], + [-975.5465087890625, -762.0400390625], + [-702.4406127929688, -536.8583984375], + [-407.59564208984375, -339.2669372558594], + [-155.48812866210938, -88.38337707519531], + [139.16014099121094, 109.21498107910156], + [439.32861328125, 299.55438232421875], + [705.1881713867188, 533.9835205078125], + [944.5883178710938, 798.3893432617188], + [1222.8939208984375, 1017.2042846679688] + ]) + ) + + # as expected, a smaller stress (loss) and a larger number of iterations than with all defaults + assert_all_close(model.stress, 86.7530288696289) + assert_all_close(model.n_iter, Nx.tensor(197)) + end + + test "smaller max_iter value (100)" do + model = MDS.fit(x(), key: key(), eps: 1.0e-4, max_iter: 100) + + assert_all_close( + model.embedding, + Nx.tensor([ + [-1201.354736328125, -1040.9530029296875], + [-983.8963012695312, -752.0218505859375], + [-705.769775390625, -532.5292358398438], + [-403.48602294921875, -344.1869201660156], + [-162.94749450683594, -78.80620574951172], + [137.8325958251953, 111.14128875732422], + [449.6497497558594, 286.1120300292969], + [711.979736328125, 525.80078125], + [936.6356811523438, 806.8446655273438], + [1221.3565673828125, 1018.59814453125] + ]) + ) + + # same params as in the previous test, but the number of iterations is capped at 100 + assert_all_close(model.stress, 337.1789245605469) + assert_all_close(model.n_iter, Nx.tensor(100)) + end +end diff --git a/test/scholar/manifold/tsne_test.exs b/test/scholar/manifold/tsne_test.exs index 884321d4..9d6b8402 100644 --- a/test/scholar/manifold/tsne_test.exs +++ b/test/scholar/manifold/tsne_test.exs @@ -101,16 +101,16 @@ defmodule Scholar.Manifold.TSNETest do expected = Nx.tensor([ - [-40.56040954589844, -68.16744995117188], - [4.963836669921875, 15.062470436096191], - [52.409217834472656, 62.55381393432617], - 
[3.6575255393981934, -21.331544876098633], - [1.9406723976135254, 9.503897666931152], - [50.68553161621094, 97.37788391113281], - [12.861334800720215, -5.144858360290527], - [-0.08688473701477051, 11.097810745239258], - [-16.302968978881836, 4.448402404785156], - [-69.56835174560547, -105.40238189697266] + [-6.426266193389893, 8.332388877868652], + [-4.163736343383789, -2.494565963745117], + [-31.295236587524414, -31.229860305786133], + [-15.800895690917969, 30.764965057373047], + [2.5463109016418457, 0.6726856231689453], + [65.74525451660156, 27.774755477905273], + [-75.81666564941406, 25.76173973083496], + [-1.8086342811584473, -11.712549209594727], + [62.6073112487793, -31.151100158691406], + [4.412214279174805, -16.7202091217041] ]) assert_all_close(embedding, expected) @@ -133,16 +133,16 @@ defmodule Scholar.Manifold.TSNETest do expected = Nx.tensor([ - [-0.720294713973999, 7.20734977722168, -6.391678810119629], - [-7.667532444000244, 3.2291605472564697, 9.317089080810547], - [-0.24787211418151855, -2.148440361022949, -3.8253002166748047], - [2.0674548149108887, 5.781192779541016, 5.811742782592773], - [28.587512969970703, -37.42428207397461, -90.23539733886719], - [-17.36945152282715, 17.298032760620117, 71.87922668457031], - [1.9807157516479492, -1.1625878810882568, 7.809732437133789], - [-1.256453275680542, 3.8095149993896484, 2.0424277782440186], - [-0.840096116065979, 3.5864388942718506, 6.742667198181152], - [-4.535242557525635, -0.17502427101135254, -3.1499686241149902] + [-1.2834149599075317, -65.1137924194336, -93.11331939697266], + [6.327005386352539, -6.078601360321045, -13.47922134399414], + [-10.256635665893555, 4.116176605224609, 3.168473243713379], + [9.371475219726562, 7.126360893249512, -4.854560852050781], + [17.84250259399414, 72.40726470947266, 102.68474578857422], + [-1.725927472114563, -0.2951417863368988, 6.848171234130859], + [1.584486961364746, -2.4704296588897705, -1.3732502460479736], + [-6.182180881500244, -13.586234092712402, 
-5.3855509757995605], + [-15.67674732208252, 4.1522440910339355, 14.045616149902344], + [-0.0035001635551452637, -0.25596773624420166, -8.5401611328125] ]) assert_all_close(embedding, expected) @@ -153,16 +153,16 @@ defmodule Scholar.Manifold.TSNETest do expected = Nx.tensor([ - [-9.334256172180176, 69.1719741821289], - [12.588174819946289, -16.575504302978516], - [35.39522933959961, 56.1950798034668], - [-31.067523956298828, 178.23777770996094], - [-134.37229919433594, -197.263671875], - [-58.42694854736328, -212.14434814453125], - [26.330211639404297, 16.41795539855957], - [124.88433074951172, -55.47932052612305], - [43.05698776245117, -13.188211441040039], - [-9.053865432739258, 174.62806701660156] + [-47.81830596923828, 123.67109680175781], + [1.1313371658325195, 12.589241027832031], + [171.11619567871094, -143.08013916015625], + [174.58798217773438, 18.175519943237305], + [-12.315837860107422, 16.606094360351562], + [-171.61473083496094, -7.717001914978027], + [-145.23251342773438, 91.37753295898438], + [18.64778709411621, -117.26519775390625], + [-28.962860107421875, 24.56144142150879], + [40.460567474365234, -18.91954231262207] ]) assert_all_close(embedding, expected) @@ -173,16 +173,16 @@ defmodule Scholar.Manifold.TSNETest do expected = Nx.tensor([ - [-1.5070298910140991, 1.7770516872406006], - [1.0464317798614502, 6.572760581970215], - [3.0646090507507324, 0.7151472568511963], - [-3.227736234664917, 0.7268509268760681], - [2.950016498565674, 2.1362667083740234], - [-2.0649163722991943, 1.4648656845092773], - [-0.26644080877304077, 0.9178383946418762], - [0.7323529720306396, -13.481656074523926], - [1.5609462261199951, -0.4036065340042114], - [-2.288417100906372, -0.4257314205169678] + [20.1014347076416, -8.380684852600098], + [-0.44876933097839355, -0.0935119092464447], + [-0.706710159778595, 5.1501970291137695], + [-3.5239882469177246, -2.9893593788146973], + [-0.4266725778579712, -2.9560775756835938], + [-0.4950576424598694, 8.39889907836914], + 
[5.351792335510254, -7.0025529861450195], + [-7.3377509117126465, 3.231973171234131], + [-8.974485397338867, 0.6503140926361084], + [-3.5397887229919434, 3.9906702041625977] ]) assert_all_close(embedding, expected) @@ -193,16 +193,16 @@ defmodule Scholar.Manifold.TSNETest do expected = Nx.tensor([ - [24.384241104125977, 9.98823070526123], - [9.558063507080078, 8.126869201660156], - [-86.89659881591797, -21.715559005737305], - [-1.2773358821868896, -6.753499984741211], - [-5.65122127532959, -13.650869369506836], - [0.5404881238937378, -24.523698806762695], - [2.29805064201355, -14.350067138671875], - [-10.488426208496094, -19.941368103027344], - [123.80097961425781, 99.55735778808594], - [-56.26921463012695, -16.739578247070312] + [-18.915882110595703, -2.7932233810424805], + [-71.29559326171875, -81.77560424804688], + [14.130539894104004, -157.9973907470703], + [-10.224050521850586, -4.1823410987854], + [-18.021028518676758, -9.254148483276367], + [-26.070053100585938, -1.6795881986618042], + [-13.233553886413574, 18.421363830566406], + [31.8947811126709, 104.57353973388672], + [129.38308715820312, 127.35003662109375], + [-17.648744583129883, 7.335616111755371] ]) assert_all_close(embedding, expected) @@ -213,16 +213,16 @@ defmodule Scholar.Manifold.TSNETest do expected = Nx.tensor([ - [17.54478645324707, -4.611420631408691], - [12.03654670715332, 4.872840881347656], - [6.249638080596924, -1.6746138334274292], - [94.60688018798828, 55.435829162597656], - [10.510167121887207, 4.246967315673828], - [19.59783935546875, 0.7140516638755798], - [-56.172481536865234, -14.603594779968262], - [-3.935652256011963, -2.9040825366973877], - [-116.89958953857422, -42.6511116027832], - [16.4613094329834, 1.176856517791748] + [-156.66343688964844, -34.973445892333984], + [78.5311050415039, -101.2839126586914], + [75.17900848388672, 34.4267692565918], + [-43.72159194946289, -102.33932495117188], + [153.0662078857422, 101.59364318847656], + [84.43276977539062, 0.04915308952331543], 
+ [-44.736968994140625, -35.98069381713867], + [-79.83869171142578, 370.6204528808594], + [-32.189876556396484, -120.83247375488281], + [-34.056640625, -111.28446960449219] ]) assert_all_close(embedding, expected) @@ -253,16 +253,16 @@ defmodule Scholar.Manifold.TSNETest do expected = Nx.tensor([ - [-67.28031921386719, -6.817644119262695], - [-16.588186264038086, 93.91526794433594], - [14.03242015838623, -4.765425205230713], - [-20.412851333618164, 7.198407173156738], - [6.30552864074707, -16.696596145629883], - [36.38324737548828, -9.629066467285156], - [-26.198514938354492, -6.854467391967773], - [92.41073608398438, -39.33834457397461], - [9.657635688781738, -9.092098236083984], - [-28.308706283569336, -7.919832706451416] + [-20.31999397277832, -16.71266746520996], + [1.3730521202087402, -8.214937210083008], + [-9.726605415344238, 5.619704723358154], + [17.57297706604004, 20.00632095336914], + [4.100193023681641, -14.283998489379883], + [-35.61519241333008, 39.55061340332031], + [-9.747889518737793, -1.3829679489135742], + [4.221410274505615, -10.61603832244873], + [22.889684677124023, 52.71660232543945], + [25.252023696899414, -66.6830825805664] ]) assert_all_close(embedding, expected)