diff --git a/lib/scholar/covariance.ex b/lib/scholar/covariance.ex deleted file mode 100644 index f1ffc37c..00000000 --- a/lib/scholar/covariance.ex +++ /dev/null @@ -1,179 +0,0 @@ -defmodule Scholar.Covariance do - @moduledoc ~S""" - Algorithms to estimate the covariance of features given a set of points. - """ - import Nx.Defn - - opts = [ - center: [ - type: :boolean, - default: true, - doc: """ - If `true`, data will be centered before computation. - If `false`, data will not be centered before computation. - Useful when working with data whose mean is almost, but not exactly zero. - """ - ], - biased: [ - type: :boolean, - default: true, - doc: """ - If `true`, the matrix will be computed using biased covariation. If `false`, - algorithm uses unbiased covariation. - """ - ] - ] - - @opts_schema NimbleOptions.new!(opts) - - @deprecated "Use Nx.convariance/2 instead" - @doc """ - Computes covariance matrix for sample inputs `x`. - - The value on the position $Cov_{ij}$ in the $Cov$ matrix is calculated using the formula: - - #{~S''' - $$ Cov(X\_i, X\_j) = \frac{\sum\_{k}\left(x\_k - - \bar{x}\right)\left(y\_k - \bar{y}\right)}{N - 1} - $$ - Where: - * $X_i$ is a $i$th row of input - - * $x_k$ is a $k$th value of $X_i$ - - * $y_k$ is a $k$th value of $X_j$ - - * $\bar{x}$ is the mean of $X_i$ - - * $\bar{y}$ is the mean of $X_j$ - - * $N$ is the number of samples - - This is a non-biased version of covariance. - The biased version has $N$ in denominator instead of $N - 1$. - '''} - - ## Options - - #{NimbleOptions.docs(@opts_schema)} - - ## Example - - iex> Scholar.Covariance.covariance_matrix(Nx.tensor([[3, 6, 5], [26, 75, 3], [23, 4, 1]])) - #Nx.Tensor< - f32[3][3] - [ - [104.22222137451172, 195.5555419921875, -13.333333015441895], - [195.5555419921875, 1089.5555419921875, 1.3333333730697632], - [-13.333333015441895, 1.3333333730697632, 2.6666667461395264] - ] - > - - iex> Scholar.Covariance.covariance_matrix(Nx.tensor([[3, 6], [2, 3], [7, 9], [5, 3]])) - #Nx.Tensor< - f32[2][2] - [ - [3.6875, 3.1875], - [3.1875, 6.1875] - ] - > - - iex> Scholar.Covariance.covariance_matrix(Nx.tensor([[3, 6, 5], [26, 75, 3], [23, 4, 1]]), - ...> biased: false - ...> ) - #Nx.Tensor< - f32[3][3] - [ - [156.3333282470703, 293.33331298828125, -20.0], - [293.33331298828125, 1634.333251953125, 2.0], - [-20.0, 2.0, 4.0] - ] - > - """ - deftransform covariance_matrix(x, opts \\ []) do - covariance_matrix_n(x, NimbleOptions.validate!(opts, @opts_schema)) - end - - defnp covariance_matrix_n(x, opts) do - if Nx.rank(x) != 2 do - raise ArgumentError, "expected data to have rank equal 2, got: #{inspect(Nx.rank(x))}" - end - - num_samples = Nx.axis_size(x, 0) - x = if opts[:center], do: x - Nx.mean(x, axes: [0]), else: x - matrix = Nx.dot(x, [0], x, [0]) - - if opts[:biased] do - matrix / num_samples - else - matrix / (num_samples - 1) - end - end - - @deprecated "Use Scholar.Stats.correlation_matrix/2 instead" - @doc """ - Computes correlation matrix for sample inputs `x`. 
- - The value on the position $Corr_{ij}$ in the $Corr$ matrix is calculated using the formula: - #{~S''' - $$ Corr(X\_i, X\_j) = \frac{Cov(X\_i, X\_j)}{\sqrt{Cov(X\_i, X\_i)Cov(X\_j, X\_j)}} $$ - Where: - * $X_i$ is a $i$th row of input - - * $Cov(X\_i, X\_j)$ is covariance between features $X_i$ and $X_j$ - '''} - - ## Options - - #{NimbleOptions.docs(@opts_schema)} - - ## Example - - iex> Scholar.Covariance.correlation_matrix(Nx.tensor([[3, 6, 5], [26, 75, 3], [23, 4, 1]])) - #Nx.Tensor< - f32[3][3] - [ - [1.0, 0.580316960811615, -0.7997867465019226], - [0.580316960811615, 1.0, 0.024736011400818825], - [-0.7997867465019226, 0.024736011400818825, 1.0] - ] - > - - iex> Scholar.Covariance.correlation_matrix(Nx.tensor([[3, 6], [2, 3], [7, 9], [5, 3]])) - #Nx.Tensor< - f32[2][2] - [ - [1.0, 0.6673083305358887], - [0.6673083305358887, 1.0] - ] - > - - iex> Scholar.Covariance.correlation_matrix(Nx.tensor([[3, 6, 5], [26, 75, 3], [23, 4, 1]]), - ...> biased: false - ...> ) - #Nx.Tensor< - f32[3][3] - [ - [1.0, 0.5803170204162598, -0.7997867465019226], - [0.5803170204162598, 1.0, 0.024736013263463974], - [-0.7997867465019226, 0.024736013263463974, 1.0] - ] - > - """ - - deftransform correlation_matrix(x, opts \\ []) do - correlation_matrix_n(x, NimbleOptions.validate!(opts, @opts_schema)) - end - - defnp correlation_matrix_n(x, opts) do - variances = - if opts[:biased] do - Nx.variance(x, axes: [0]) - else - Nx.variance(x, axes: [0], ddof: 1) - end - - Scholar.Covariance.covariance_matrix(x, opts) / - Nx.sqrt(Nx.new_axis(variances, 1) * Nx.new_axis(variances, 0)) - end -end diff --git a/lib/scholar/linear/isotonic_regression.ex b/lib/scholar/linear/isotonic_regression.ex index 22996eb0..2198200a 100644 --- a/lib/scholar/linear/isotonic_regression.ex +++ b/lib/scholar/linear/isotonic_regression.ex @@ -4,6 +4,9 @@ defmodule Scholar.Linear.IsotonicRegression do observations by solving a convex optimization problem. It is a form of regression analysis that can be used as an alternative to polynomial regression to fit nonlinear data. + + Time complexity of isotonic regression is $O(N^2)$ where $N$ is the + number of points. """ require Nx import Nx.Defn, except: [transform: 2] @@ -306,6 +309,24 @@ defmodule Scholar.Linear.IsotonicRegression do } end + @doc """ + Preprocesses the `model` for prediction. + + Returns an updated `model`. This is a special version of `preprocess/1` that + does not trim duplicates so it can be used in defns. It is not recommended + to use this function directly. + """ + defn special_preprocess(model) do + %__MODULE__{ + model + | preprocess: + Scholar.Interpolation.Linear.fit( + model.x_thresholds, + model.y_thresholds + ) + } + end + deftransform check_preprocess(model) do if model.preprocess == {} do raise ArgumentError, diff --git a/lib/scholar/manifold/mds.ex b/lib/scholar/manifold/mds.ex new file mode 100644 index 00000000..4d86d7e4 --- /dev/null +++ b/lib/scholar/manifold/mds.ex @@ -0,0 +1,401 @@ +defmodule Scholar.Manifold.MDS do + @moduledoc """ + Multidimensional Scaling (MDS) is a nonlinear dimensionality reduction technique that embeds samples in a lower-dimensional space while preserving their pairwise dissimilarities as well as possible. This implementation uses the SMACOF algorithm. 
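+ + The embedding is computed by iteratively minimizing the *stress* of the configuration, which for metric MDS is: + + #{~S''' + $$ Stress(X) = \sum\_{i < j} \left(d\_{ij}(X) - \hat{d}\_{ij}\right)^2 $$ + '''} + + where $d_{ij}(X)$ is the Euclidean distance between samples $i$ and $j$ in the embedding and the second term is the corresponding input dissimilarity. For nonmetric MDS the dissimilarities are replaced by disparities obtained with isotonic regression, so only their ordering influences the result. 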
+ + ## References + + * Kruskal, J. B. - "Multidimensional scaling by optimizing goodness of fit to a nonmetric hypothesis", Psychometrika (1964) + """ + import Nx.Defn + import Scholar.Shared + alias Scholar.Metrics.Distance + + @derive {Nx.Container, containers: [:embedding, :stress, :n_iter]} + defstruct [:embedding, :stress, :n_iter] + + opts_schema = [ + num_components: [ + type: :pos_integer, + default: 2, + doc: ~S""" + Dimension of the embedded space. + """ + ], + metric: [ + type: :boolean, + default: true, + doc: ~S""" + If `true`, perform metric MDS using the dissimilarities as distances; if `false`, + perform nonmetric MDS, which only preserves the ordering of the dissimilarities. + """ + ], + normalized_stress: [ + type: :boolean, + default: false, + doc: ~S""" + If `true`, normalize the stress by the sum of squared dissimilarities. + Only valid if `metric` is `false`. + """ + ], + eps: [ + type: :float, + default: 1.0e-3, + doc: ~S""" + Tolerance for stopping criterion. + """ + ], + max_iter: [ + type: :pos_integer, + default: 300, + doc: ~S""" + Maximum number of iterations for the optimization. + """ + ], + key: [ + type: {:custom, Scholar.Options, :key, []}, + doc: """ + Determines random number generation for the random initialization of the embedding. + If the key is not provided, it is set to `Nx.Random.key(System.system_time())`. + """ + ], + n_init: [ + type: :pos_integer, + default: 4, + doc: ~S""" + Number of times the embedding will be computed with different random initializations. + The final embedding is the embedding with the lowest stress. + """ + ] + ] + + @opts_schema NimbleOptions.new!(opts_schema) + + # `x` is either a random initial embedding or the one passed in by the caller + defnp smacof(dissimilarities, x, max_iter, opts) do + similarities_flat = Nx.flatten(dissimilarities) + similarities_flat_indices = lower_triangle_indices(dissimilarities) + + similarities_flat_w = Nx.take(similarities_flat, similarities_flat_indices) + + metric = if opts[:metric], do: 1, else: 0 + normalized_stress = if opts[:normalized_stress], do: 1, else: 0 + eps = opts[:eps] + n = Nx.axis_size(dissimilarities, 0) + + {{x, stress, i}, _} = + while {{x, _stress = Nx.Constants.infinity(Nx.type(dissimilarities)), i = 0}, + {dissimilarities, max_iter, similarities_flat_indices, similarities_flat, + similarities_flat_w, old_stress = Nx.Constants.infinity(Nx.type(dissimilarities)), + metric, normalized_stress, eps, stop_value = 0}}, + i < max_iter and not stop_value do + dis = Distance.pairwise_euclidean(x) + + disparities = + if metric do + dissimilarities + else + dis_flat = Nx.flatten(dis) + + dis_flat_indices = lower_triangle_indices(dis) + + dis_flat_w = Nx.take(dis_flat, dis_flat_indices) + + disparities_flat_model = + Scholar.Linear.IsotonicRegression.fit(similarities_flat_w, dis_flat_w, + increasing: true + ) + + model = Scholar.Linear.IsotonicRegression.special_preprocess(disparities_flat_model) + + disparities_flat = + Scholar.Linear.IsotonicRegression.predict(model, similarities_flat_w) + + disparities = + Nx.indexed_put( + dis_flat, + Nx.new_axis(similarities_flat_indices, -1), + disparities_flat + ) + + disparities = Nx.reshape(disparities, {n, n}) + + disparities * Nx.sqrt(n * (n - 1) / 2 / Nx.sum(disparities ** 2)) + end + + stress = Nx.sum((Nx.flatten(dis) - Nx.flatten(disparities)) ** 2) / 2 + + stress = + if normalized_stress do + Nx.sqrt(stress / (Nx.sum(Nx.flatten(disparities) ** 2) / 2)) + else + stress + end + + dis = Nx.select(dis == 0, 1.0e-5, dis) + ratio = disparities / dis + b = -ratio + b = Nx.put_diagonal(b, Nx.take_diagonal(b) + Nx.sum(ratio, axes: [1])) + x = Nx.dot(b, x) * (1.0 / n) + + 
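# The update above is the Guttman transform used by SMACOF: x is replaced by B(x) * x / n, + # where B(x) has entries -disparities/dis off the diagonal and a diagonal chosen so that + # every row sums to zero. For fixed disparities this majorization step never increases the stress. + + 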
dis = Nx.sum(Nx.sqrt(Nx.sum(x ** 2, axes: [1]))) + + stop_value = if old_stress - stress / dis < eps, do: 1, else: 0 + + old_stress = stress / dis + + {{x, stress, i + 1}, + {dissimilarities, max_iter, similarities_flat_indices, similarities_flat, + similarities_flat_w, old_stress, metric, normalized_stress, eps, stop_value}} + end + + {x, stress, i} + end + + defnp mds_main_loop(dissimilarities, x, _key, opts) do + n_init = opts[:n_init] + + type = Nx.Type.merge(to_float_type(x), to_float_type(dissimilarities)) + dissimilarities = Nx.as_type(dissimilarities, type) + x = Nx.as_type(x, type) + + dissimilarities = Distance.pairwise_euclidean(dissimilarities) + + {{best, best_stress, best_iter}, _} = + while {{best = x, best_stress = Nx.Constants.infinity(type), best_iter = 0}, + {n_init, dissimilarities, x, i = 0}}, + i < n_init do + {temp, stress, iter} = smacof(dissimilarities, x, opts[:max_iter], opts) + + {best, best_stress, best_iter} = + if stress < best_stress, do: {temp, stress, iter}, else: {best, best_stress, best_iter} + + {{best, best_stress, best_iter}, {n_init, dissimilarities, x, i + 1}} + end + + {best, best_stress, best_iter} + end + + defnp mds_main_loop(dissimilarities, key, opts) do + n_init = opts[:n_init] + max_iter = opts[:max_iter] + num_samples = Nx.axis_size(dissimilarities, 0) + + type = to_float_type(dissimilarities) + dissimilarities = Nx.as_type(dissimilarities, type) + + {dummy, new_key} = + Nx.Random.uniform(key, + shape: {num_samples, opts[:num_components]}, + type: type + ) + + dissimilarities = Distance.pairwise_euclidean(dissimilarities) + + {{best, best_stress, best_iter}, _} = + while {{best = dummy, best_stress = Nx.Constants.infinity(type), best_iter = 0}, + {n_init, new_key, max_iter, dissimilarities, i = 0}}, + i < n_init do + num_samples = Nx.axis_size(dissimilarities, 0) + + {x, new_key} = + Nx.Random.uniform(new_key, shape: {num_samples, opts[:num_components]}, type: type) + + {temp, stress, iter} = smacof(dissimilarities, x, max_iter, opts) + + {best, best_stress, best_iter} = + if stress < best_stress, do: {temp, stress, iter}, else: {best, best_stress, best_iter} + + {{best, best_stress, best_iter}, {n_init, new_key, max_iter, dissimilarities, i + 1}} + end + + {best, best_stress, best_iter} + end + + defnp lower_triangle_indices(tensor) do + n = Nx.axis_size(tensor, 0) + + temp = Nx.broadcast(Nx.s64(0), {div(n * (n - 1), 2)}) + + {temp, _} = + while {temp, {i = 0, j = 0}}, i < n ** 2 do + {temp, j} = + if Nx.remainder(i, n) < Nx.quotient(i, n) do + temp = Nx.indexed_put(temp, Nx.new_axis(j, -1), i) + {temp, j + 1} + else + {temp, j} + end + + {temp, {i + 1, j}} + end + + temp + end + + @doc """ + Fits MDS for sample inputs `x`. It is simpyfied version of `fit/3` function. + + ## Options + + #{NimbleOptions.docs(@opts_schema)} + + ## Return Values + + Returns struct with embedded data, stress value, and number of iterations for best run. 
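+ + The returned struct has the following fields: + + * `:embedding` - the embedded coordinates of the samples, one row per sample + + * `:stress` - the final value of the stress for the returned embedding + + * `:n_iter` - the number of iterations used by the best run 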
+ + ## Examples + + iex> x = Nx.iota({4,5}) + iex> key = Nx.Random.key(42) + iex> Scholar.Manifold.MDS.fit(x, key: key) + %Scholar.Manifold.MDS{ + embedding: Nx.tensor( + [ + [16.3013916015625, -3.444634437561035], + [5.866805553436279, 1.6378790140151978], + [-5.487184524536133, 0.5837264657020569], + [-16.681013107299805, 1.2230290174484253] + ] + ), + stress: Nx.tensor( + 0.3993147909641266 + ), + n_iter: Nx.tensor( + 23 + ) + } + """ + deftransform fit(x) do + opts = NimbleOptions.validate!([], @opts_schema) + key = Keyword.get_lazy(opts, :key, fn -> Nx.Random.key(System.system_time()) end) + fit_n(x, key, opts) + end + + @doc """ + Fits MDS for sample inputs `x`. It is simpyfied version of `fit/3` function. + + ## Options + + #{NimbleOptions.docs(@opts_schema)} + + ## Return Values + + Returns struct with embedded data, stress value, and number of iterations for best run. + + ## Examples + + iex> x = Nx.iota({4,5}) + iex> key = Nx.Random.key(42) + iex> Scholar.Manifold.MDS.fit(x, num_components: 2, key: key) + %Scholar.Manifold.MDS{ + embedding: Nx.tensor( + [ + [16.3013916015625, -3.444634437561035], + [5.866805553436279, 1.6378790140151978], + [-5.487184524536133, 0.5837264657020569], + [-16.681013107299805, 1.2230290174484253] + ] + ), + stress: Nx.tensor( + 0.3993147909641266 + ), + n_iter: Nx.tensor( + 23 + ) + } + """ + deftransform fit(x, opts) when is_list(opts) do + opts = NimbleOptions.validate!(opts, @opts_schema) + key = Keyword.get_lazy(opts, :key, fn -> Nx.Random.key(System.system_time()) end) + fit_n(x, key, opts) + end + + defnp fit_n(x, key, opts) do + {best, best_stress, best_iter} = mds_main_loop(x, key, opts) + %__MODULE__{embedding: best, stress: best_stress, n_iter: best_iter} + end + + @doc """ + Fits MDS for sample inputs `x`. It is simpyfied version of `fit/3` function. + + ## Options + + #{NimbleOptions.docs(@opts_schema)} + + ## Return Values + + Returns struct with embedded data, stress value, and number of iterations for best run. + + ## Examples + + iex> x = Nx.iota({4,5}) + iex> key = Nx.Random.key(42) + iex> init = Nx.reverse(Nx.iota({4,2})) + iex> Scholar.Manifold.MDS.fit(x, init) + %Scholar.Manifold.MDS{ + embedding: Nx.tensor( + [ + [11.858541488647461, 11.858541488647461], + [3.9528470039367676, 3.9528470039367676], + [-3.9528470039367676, -3.9528470039367676], + [-11.858541488647461, -11.858541488647461] + ] + ), + stress: Nx.tensor( + 0.0 + ), + n_iter: Nx.tensor( + 3 + ) + } + """ + deftransform fit(x, init) do + opts = NimbleOptions.validate!([], @opts_schema) + key = Keyword.get_lazy(opts, :key, fn -> Nx.Random.key(System.system_time()) end) + fit_n(x, init, key, opts) + end + + @doc """ + Fits MDS for sample inputs `x`. It is simpyfied version of `fit/3` function. + + ## Options + + #{NimbleOptions.docs(@opts_schema)} + + ## Return Values + + Returns struct with embedded data, stress value, and number of iterations for best run. 
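+ + When `init` is given, it is used as the starting configuration for every one of the `:n_init` runs, so it is expected to have shape `{num_samples, num_components}` and the result is deterministic for a fixed `init`. 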
+ + ## Examples + + iex> x = Nx.iota({4,5}) + iex> key = Nx.Random.key(42) + iex> init = Nx.reverse(Nx.iota({4,3})) + iex> Scholar.Manifold.MDS.fit(x, init, num_components: 3, key: key) + %Scholar.Manifold.MDS{ + embedding: Nx.tensor( + [ + [9.682458877563477, 9.682458877563477, 9.682458877563477], + [3.2274858951568604, 3.2274858951568604, 3.2274858951568604], + [-3.2274863719940186, -3.2274863719940186, -3.2274863719940186], + [-9.682458877563477, -9.682458877563477, -9.682458877563477] + ] + ), + stress: Nx.tensor( + 9.094947017729282e-12 + ), + n_iter: Nx.tensor( + 3 + ) + } + """ + deftransform fit(x, init, opts) when is_list(opts) do + opts = NimbleOptions.validate!(opts, @opts_schema) + key = Keyword.get_lazy(opts, :key, fn -> Nx.Random.key(System.system_time()) end) + fit_n(x, init, key, opts) + end + + defnp fit_n(x, init, key, opts) do + {best, best_stress, best_iter} = mds_main_loop(x, init, key, opts) + %__MODULE__{embedding: best, stress: best_stress, n_iter: best_iter} + end +end diff --git a/test/scholar/manifold/mds_test.exs b/test/scholar/manifold/mds_test.exs new file mode 100644 index 00000000..87d5cd93 --- /dev/null +++ b/test/scholar/manifold/mds_test.exs @@ -0,0 +1,212 @@ +defmodule Scholar.Manifold.MDSTest do + use Scholar.Case, async: true + alias Scholar.Manifold.MDS + doctest MDS + + def x() do + Nx.iota({10, 50}) + end + + def key() do + Nx.Random.key(42) + end + + test "non-default num_components" do + key = key() + x = x() + model = EXLA.jit_apply(&MDS.fit(&1, num_components: 5, key: &2), [x, key]) + + assert_all_close( + model.embedding, + Nx.tensor([ + [ + 57.28269577026367, + -678.6760864257812, + 811.1503295898438, + -251.1714324951172, + 1156.7987060546875 + ], + [ + 7.623606204986572, + -544.2373046875, + 604.0946655273438, + -225.99559020996094, + 903.2800903320312 + ], + [ + -7.334737300872803, + -429.81671142578125, + 402.1512145996094, + -163.3682861328125, + 639.9016723632812 + ], + [ + 13.86670207977295, + -296.5096435546875, + 223.15061950683594, + -84.07274627685547, + 374.4827575683594 + ], + [ + 38.73623275756836, + -134.54620361328125, + 50.4241943359375, + -38.010799407958984, + 113.90003967285156 + ], + [ + 18.940887451171875, + 30.962879180908203, + -127.7795639038086, + 45.001678466796875, + -131.29234313964844 + ], + [ + 18.05344581604004, + 222.0098114013672, + -292.34197998046875, + 86.87554168701172, + -378.58544921875 + ], + [ + -3.060556173324585, + 429.6268005371094, + -436.2151794433594, + 146.84103393554688, + -621.5556640625 + ], + [ + -55.395423889160156, + 613.6642456054688, + -565.1470947265625, + 225.3615264892578, + -882.4739379882812 + ], + [ + -88.7128677368164, + 787.5221557617188, + -669.4872436523438, + 258.53912353515625, + -1174.455810546875 + ] + ]) + ) + + assert_all_close(model.stress, 698.4426879882812) + assert_all_close(model.n_iter, Nx.tensor(152)) + end + + test "non-default metric" do + key = key() + x = x() + model = EXLA.jit_apply(&MDS.fit(&1, metric: false, key: &2), [x, key]) + + assert_all_close( + model.embedding, + Nx.tensor([ + [-0.23465712368488312, 0.6921732425689697], + [-0.3380763530731201, 0.4378605782985687], + [-0.15237200260162354, 0.26230522990226746], + [0.09990488737821579, 0.2603200674057007], + [0.15598554909229279, 0.03315458819270134], + [0.41043558716773987, 0.13559512794017792], + [0.24686546623706818, -0.24366283416748047], + [0.1395486444234848, -0.4151153564453125], + [-0.07875102013349533, -0.530768096446991], + [-0.21976199746131897, -0.6417303681373596] + ]) + ) + + 
assert_all_close(model.stress, 0.1966342180967331) + assert_all_close(model.n_iter, Nx.tensor(38)) + end + + test "option normalized_stress with metric set to false" do + key = key() + x = x() + + model = + EXLA.jit_apply(&MDS.fit(&1, metric: false, key: &2, normalized_stress: true), [x, key]) + + assert_all_close( + model.embedding, + Nx.tensor([ + [-0.17997372150421143, 0.7225074768066406], + [-0.3138044774532318, 0.3934117257595062], + [-0.0900932177901268, 0.19507794082164764], + [0.2092301845550537, 0.295993834733963], + [0.24611115455627441, 0.0019988759886473417], + [0.4951189458370209, 0.08028026670217514], + [0.12963972985744476, -0.3193856179714203], + [0.19291982054710388, -0.44776636362075806], + [-0.2770233750343323, -0.4146113097667694], + [-0.3582141101360321, -0.5444929003715515] + ]) + ) + + assert_all_close(model.stress, 0.13638167083263397) + assert_all_close(model.n_iter, Nx.tensor(20)) + end + + test "epsilon set to a smaller then default value" do + key = key() + x = x() + + model = + EXLA.jit_apply(&MDS.fit(&1, metric: false, key: &2, normalized_stress: true, eps: 1.0e-4), [ + x, + key + ]) + + assert_all_close( + model.embedding, + Nx.tensor([ + [-0.35130712389945984, 0.6258886456489563], + [-0.4270354211330414, 0.4396686255931854], + [-0.30671024322509766, 0.2688262462615967], + [-0.12758131325244904, 0.18020282685756683], + [-0.05403336510062218, 0.01867777667939663], + [0.17203716933727264, 0.044468216598033905], + [0.2791652977466583, -0.09437420219182968], + [0.2869844138622284, -0.3071449398994446], + [0.2768166959285736, -0.49931082129478455], + [0.2563020884990692, -0.678329348564148] + ]) + ) + + # as expected smaller value of stress (loss) and bigger number of iterations + assert_all_close(model.stress, 0.03167537972331047) + assert_all_close(model.n_iter, Nx.tensor(116)) + end + + test "smaller max_iter value (100)" do + key = key() + x = x() + + model = + EXLA.jit_apply( + &MDS.fit(&1, metric: false, key: &2, normalized_stress: true, eps: 1.0e-4, max_iter: 100), + [x, key] + ) + + assert_all_close( + model.embedding, + Nx.tensor([ + [-0.34521010518074036, 0.6345276236534119], + [-0.4247266352176666, 0.43899187445640564], + [-0.2903931438922882, 0.2677172124385834], + [-0.09941618889570236, 0.19031266868114471], + [-0.03261081129312515, 0.019261524081230164], + [0.2049849033355713, 0.07233452051877975], + [0.29381951689720154, -0.09455471485853195], + [0.27441200613975525, -0.320201575756073], + [0.2368578165769577, -0.5156480669975281], + [0.19262047111988068, -0.6936381459236145] + ]) + ) + + # same params as in previous test, but smaller number of iterations, cupped on 100 + assert_all_close(model.stress, 0.040396787226200104) + assert_all_close(model.n_iter, Nx.tensor(100)) + end +end