Skip to content

Commit

Permalink
xd again tsvd bugfix
Browse files Browse the repository at this point in the history
  • Loading branch information
norm4nn committed Oct 31, 2024
1 parent ac119e0 commit 15662fa
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 6 deletions.
11 changes: 8 additions & 3 deletions lib/scholar/decomposition/utils.ex
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,14 @@ defmodule Scholar.Decomposition.Utils do
import Nx.Defn
require Nx

defn flip_svd(u, v) do
max_abs_cols_idx = u |> Nx.abs() |> Nx.argmax(axis: 0, keep_axis: true)
signs = u |> Nx.take_along_axis(max_abs_cols_idx, axis: 0) |> Nx.sign() |> Nx.squeeze()
defn flip_svd(u, v, u_based \\ true) do
base = if u_based do
u
else
Nx.transpose(v)
end
max_abs_cols_idx = base |> Nx.abs() |> Nx.argmax(axis: 0, keep_axis: true)
signs = base |> Nx.take_along_axis(max_abs_cols_idx, axis: 0) |> Nx.sign() |> Nx.squeeze()
u = u * signs
v = v * Nx.new_axis(signs, -1)
{u, v}
Expand Down
16 changes: 13 additions & 3 deletions test/scholar/decomposition/truncated_svd_tests.exs
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,27 @@ defmodule Scholar.Decomposition.TruncatedSVDTest do
alias Scholar.Decomposition.TruncatedSVD
doctest TruncatedSVD

defp key do
Nx.Random.key(1)
end

defp x do
Nx.tensor()
key = key()
{x, _new_key} = Nx.Random.multivariate_normal(key, Nx.tensor([0.0, 0.0, 0.0, 0.0]), Nx.tensor([[3.0, 2.0, 1.0, 9.0], [1.0, 2.0, 3.0, 8.2], [1.3, 1.0, 2.2, 2.4], [1.8, 1.0, 2.0, 2.9]]), shape: {50}, type: :f32)
x
end
# key = Nx.Random.key(1)
# {x, _new_key} = Nx.Random.multivariate_normal(key, Nx.tensor([0.0, 0.0, 0.0, 0.0]), Nx.tensor([[3.0, 2.0, 1.0, 9.0], [1.0, 2.0, 3.0, 8.2], [1.3, 1.0, 2.2, 2.4], [1.8, 1.0, 2.0, 2.9]]), shape: {50}, type: :f32)
# tsvd = Scholar.Decomposition.TruncatedSVD.fit_transform(x, num_components: 2, key: key)


test "fit test - all default options" do

key = key()
x = x()
end

test "fit_transform test - all default options" do

#tsvd = Scholar.Decomposition.TruncatedSVD.fit_transform(x, num_components: 2, key: key)
end

test "fit_transform test - :num_components" do
Expand Down

0 comments on commit 15662fa

Please sign in to comment.