Skip to content

Commit

Permalink
fixed tests
Browse files Browse the repository at this point in the history
  • Loading branch information
norm4nn committed Nov 1, 2024
1 parent efed62b commit 5dc01cc
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
2 changes: 1 addition & 1 deletion lib/scholar/decomposition/truncated_svd.ex
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ defmodule Scholar.Decomposition.TruncatedSVD do
>
iex> key = Nx.Random.key(0)
iex> x = Nx.tensor([[0, 0, 3], [1, 0, 3], [1, 1, 3], [3, 3, 3], [4, 4.5, 3]])
iex> tsvd = Scholar.Decomposition.TruncatedSVD.fit_transform(x, num_components: 2, key: key)
iex> Scholar.Decomposition.TruncatedSVD.fit_transform(x, num_components: 2, key: key)
#Nx.Tensor<
f32[5][2]
[
Expand Down
10 changes: 5 additions & 5 deletions test/scholar/decomposition/truncated_svd_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ defmodule Scholar.Decomposition.TruncatedSVDTest do
type: :f32
)

tsvd = Scholar.Decomposition.TruncatedSVD.fit(x, key: key)
model = Scholar.Decomposition.TruncatedSVD.fit(x, key: key)

assert_all_close(
model.components,
Expand Down Expand Up @@ -80,7 +80,7 @@ defmodule Scholar.Decomposition.TruncatedSVDTest do
x_reduced = Scholar.Decomposition.TruncatedSVD.fit_transform(x, key: key)

assert_all_close(
model.singular_values,
x_reduced,
Nx.tensor([
[4.441530227661133, -1.5630521774291992],
[-2.187946081161499, -1.2309558391571045],
Expand Down Expand Up @@ -117,7 +117,7 @@ defmodule Scholar.Decomposition.TruncatedSVDTest do
x_reduced = Scholar.Decomposition.TruncatedSVD.fit_transform(x, key: key, num_components: 3)

assert_all_close(
model.singular_values,
x_reduced,
Nx.tensor([
[4.441530704498291, -1.5630513429641724, 0.08955635130405426],
[-2.1879451274871826, -1.2309576272964478, 1.2222723960876465],
Expand Down Expand Up @@ -154,7 +154,7 @@ defmodule Scholar.Decomposition.TruncatedSVDTest do
x_reduced = Scholar.Decomposition.TruncatedSVD.fit_transform(x, key: key, num_oversamples: 20)

assert_all_close(
model.singular_values,
x_reduced,
Nx.tensor([
[4.441530227661133, -1.5630521774291992],
[-2.187946081161499, -1.2309565544128418],
Expand Down Expand Up @@ -191,7 +191,7 @@ defmodule Scholar.Decomposition.TruncatedSVDTest do
x_reduced = Scholar.Decomposition.TruncatedSVD.fit_transform(x, key: key, num_iter: 20)

assert_all_close(
model.singular_values,
x_reduced,
Nx.tensor([
[4.441530227661133, -1.5630522966384888],
[-2.18794584274292, -1.2309566736221313],
Expand Down

0 comments on commit 5dc01cc

Please sign in to comment.