diff --git a/lib/scholar/decomposition/truncated_svd.ex b/lib/scholar/decomposition/truncated_svd.ex index a6d28974..1a19b63c 100644 --- a/lib/scholar/decomposition/truncated_svd.ex +++ b/lib/scholar/decomposition/truncated_svd.ex @@ -127,7 +127,7 @@ defmodule Scholar.Decomposition.TruncatedSVD do > iex> key = Nx.Random.key(0) iex> x = Nx.tensor([[0, 0, 3], [1, 0, 3], [1, 1, 3], [3, 3, 3], [4, 4.5, 3]]) - iex> tsvd = Scholar.Decomposition.TruncatedSVD.fit_transform(x, num_components: 2, key: key) + iex> Scholar.Decomposition.TruncatedSVD.fit_transform(x, num_components: 2, key: key) #Nx.Tensor< f32[5][2] [ diff --git a/test/scholar/decomposition/truncated_svd_test.exs b/test/scholar/decomposition/truncated_svd_test.exs index 92002a32..096023b6 100644 --- a/test/scholar/decomposition/truncated_svd_test.exs +++ b/test/scholar/decomposition/truncated_svd_test.exs @@ -24,7 +24,7 @@ defmodule Scholar.Decomposition.TruncatedSVDTest do type: :f32 ) - tsvd = Scholar.Decomposition.TruncatedSVD.fit(x, key: key) + model = Scholar.Decomposition.TruncatedSVD.fit(x, key: key) assert_all_close( model.components, @@ -80,7 +80,7 @@ defmodule Scholar.Decomposition.TruncatedSVDTest do x_reduced = Scholar.Decomposition.TruncatedSVD.fit_transform(x, key: key) assert_all_close( - model.singular_values, + x_reduced, Nx.tensor([ [4.441530227661133, -1.5630521774291992], [-2.187946081161499, -1.2309558391571045], @@ -117,7 +117,7 @@ defmodule Scholar.Decomposition.TruncatedSVDTest do x_reduced = Scholar.Decomposition.TruncatedSVD.fit_transform(x, key: key, num_components: 3) assert_all_close( - model.singular_values, + x_reduced, Nx.tensor([ [4.441530704498291, -1.5630513429641724, 0.08955635130405426], [-2.1879451274871826, -1.2309576272964478, 1.2222723960876465], @@ -154,7 +154,7 @@ defmodule Scholar.Decomposition.TruncatedSVDTest do x_reduced = Scholar.Decomposition.TruncatedSVD.fit_transform(x, key: key, num_oversamples: 20) assert_all_close( - model.singular_values, + x_reduced, Nx.tensor([ [4.441530227661133, -1.5630521774291992], [-2.187946081161499, -1.2309565544128418], @@ -191,7 +191,7 @@ defmodule Scholar.Decomposition.TruncatedSVDTest do x_reduced = Scholar.Decomposition.TruncatedSVD.fit_transform(x, key: key, num_iter: 20) assert_all_close( - model.singular_values, + x_reduced, Nx.tensor([ [4.441530227661133, -1.5630522966384888], [-2.18794584274292, -1.2309566736221313],