Manifold learning notebooks (elixir-nx#278)
* Add manifold learning algorithm notebook and corrections to trimap

* Persistent outputs

* Remove tests with unsupported metrics

* Apply suggestions from code review

Co-authored-by: Krsto Proroković <[email protected]>

* Apply suggestions from code review

* Format

* Change tests to well-defined states

* Update notebooks/manifold_learning.livemd

---------

Co-authored-by: Krsto Proroković <[email protected]>
msluszniak and krstopro authored Jun 16, 2024
1 parent 433041f commit accb6b7
Showing 4 changed files with 256 additions and 70 deletions.
63 changes: 33 additions & 30 deletions lib/scholar/manifold/trimap.ex
```diff
@@ -133,10 +133,10 @@ defmodule Scholar.Manifold.Trimap do
   @opts_schema NimbleOptions.new!(opts_schema)
 
   defnp tempered_log(x, t) do
-    if Nx.abs(t - 1) < 1.0e-5 do
+    if Nx.abs(t - 1.0) < 1.0e-5 do
       Nx.log(x)
     else
-      (x ** (1 - t) - 1) * (1 / (1 - t))
+      1.0 / (1.0 - t) * (x ** (1.0 - t) - 1.0)
     end
   end
```
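
In math form, the tempered logarithm computed above is

$$\log_t(x) = \frac{x^{1-t} - 1}{1 - t},$$

which tends to $\log(x)$ as $t \to 1$; the guard on $|t - 1|$ switches to the exact logarithm near that singularity.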

```diff
@@ -195,7 +195,7 @@ defmodule Scholar.Manifold.Trimap do
     {samples, key, _, _, _} =
       while {samples, key, discard, rejects, i}, Nx.any(discard) do
         {new_samples, key} = Nx.Random.randint(key, 0, opts[:maxval], shape: {elem(shape, 1)})
-        discard = in1d(new_samples, rejects[i])
+        discard = in1d(new_samples, rejects[i]) or in1d(new_samples, samples)
         samples = Nx.select(discard, samples, new_samples)
         {samples, key, in1d(samples, rejects[i]), rejects, i}
       end
```
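
`in1d/2` is a private helper of this module rather than a public `Nx` function. For readers who want a mental model, here is a minimal sketch of such a membership test in pure `Nx` (hypothetical code modeled on NumPy's `np.in1d`, not the library's actual implementation):

```elixir
defmodule In1dSketch do
  import Nx.Defn

  # Returns a u8 mask with 1 where an element of `a` also occurs in `b`:
  # broadcast the pairwise comparison, then reduce each row with any.
  defn in1d(a, b) do
    a
    |> Nx.new_axis(1)
    |> Nx.equal(Nx.new_axis(b, 0))
    |> Nx.any(axes: [1])
  end
end

In1dSketch.in1d(Nx.tensor([1, 2, 3]), Nx.tensor([2, 4]))
#=> #Nx.Tensor<u8[3] [0, 1, 0]>
```

With that reading, the fixed line also rejects candidates that collide with samples already drawn, not only with the reject list.
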
```diff
@@ -245,9 +245,9 @@ defmodule Scholar.Manifold.Trimap do
     sim = triplets[[.., 1]]
     out = triplets[[.., 2]]
 
-    p_sim = handle_dist(inputs[anc], inputs[sim], opts) / (sig[anc] * sig[sim])
+    p_sim = -(handle_dist(inputs[anc], inputs[sim], opts) ** 2) / (sig[anc] * sig[sim])
 
-    p_out = handle_dist(inputs[anc], inputs[out], opts) / (sig[anc] * sig[out])
+    p_out = -(handle_dist(inputs[anc], inputs[out], opts) ** 2) / (sig[anc] * sig[out])
 
     flip = p_sim < p_out
     weights = p_sim - p_out
```
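
With the fix, both scores use the scaled negative squared distances from the TriMap formulation: for a triplet with anchor $i$, similar point $j$, and outlier $k$,

$$p_{\text{sim}} = -\frac{d(x_i, x_j)^2}{\sigma_i \sigma_j}, \qquad p_{\text{out}} = -\frac{d(x_i, x_k)^2}{\sigma_i \sigma_k}.$$
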
```diff
@@ -269,7 +269,7 @@ defmodule Scholar.Manifold.Trimap do
     hits = Nx.flatten(neighbors)
 
     distances =
-      handle_dist(inputs[anchors], inputs[hits], opts) |> Nx.reshape({num_points, :auto})
+      (handle_dist(inputs[anchors], inputs[hits], opts) ** 2) |> Nx.reshape({num_points, :auto})
 
     sigmas = Nx.max(Nx.mean(Nx.sqrt(distances[[.., 3..5]]), axes: [1]), 1.0e-10)
```
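
Since `distances` now holds squared distances, `Nx.sqrt/1` recovers $d_{ij}$, and each scale is the mean distance to the 4th–6th nearest neighbors (columns `3..5`), floored for numerical stability:

$$\sigma_i = \max\left(\frac{1}{3}\sum_{j=4}^{6} d_{ij},\ 10^{-10}\right)$$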

```diff
@@ -282,27 +282,28 @@ defmodule Scholar.Manifold.Trimap do
   end
 
   defnp find_triplet_weights(inputs, triplets, neighbors, sigmas, distances, opts \\ []) do
-    {num_points, num_inliners} = Nx.shape(neighbors)
+    {num_points, num_inliers} = Nx.shape(neighbors)
 
     p_sim = -Nx.flatten(distances)
 
-    num_outliers = div(Nx.axis_size(triplets, 0), num_points * num_inliners)
+    num_outliers = div(Nx.axis_size(triplets, 0), num_points * num_inliers)
 
     p_sim =
-      Nx.tile(Nx.reshape(p_sim, {num_points, num_inliners}), [1, num_outliers]) |> Nx.flatten()
+      Nx.tile(Nx.reshape(p_sim, {num_points, num_inliers}), [1, num_outliers]) |> Nx.flatten()
 
-    out_distances = handle_dist(inputs[triplets[[.., 0]]], inputs[triplets[[.., 2]]], opts)
+    out_distances = handle_dist(inputs[triplets[[.., 0]]], inputs[triplets[[.., 2]]], opts) ** 2
 
     p_out = -out_distances / (sigmas[triplets[[.., 0]]] * sigmas[triplets[[.., 2]]])
     p_sim - p_out
   end
 
   defnp generate_triplets(key, inputs, opts \\ []) do
-    num_inliners = opts[:num_inliers]
+    num_inliers = opts[:num_inliers]
     num_random = opts[:num_random]
     weight_temp = opts[:weight_temp]
     num_points = Nx.axis_size(inputs, 0)
-    num_extra = min(num_inliners + 50, num_points)
+
+    num_extra = min(num_inliers + 50, num_points)
 
     neighbors =
       case opts[:knn_algorithm] do
```
```diff
@@ -364,29 +365,29 @@ defmodule Scholar.Manifold.Trimap do
 
     {knn_distances, neighbors, sigmas} = find_scaled_neighbors(inputs, neighbors, opts)
 
-    neighbors = neighbors[[.., 0..num_inliners]]
-    knn_distances = knn_distances[[.., 0..num_inliners]]
+    neighbors = neighbors[[.., 0..num_inliers]]
+    knn_distances = knn_distances[[.., 0..num_inliers]]
 
     {triplets, key} =
       sample_knn_triplets(key, neighbors,
         num_outliers: opts[:num_outliers],
-        num_inliers: num_inliners,
+        num_inliers: num_inliers,
         num_points: num_points
       )
 
     weights =
       find_triplet_weights(
         inputs,
         triplets,
-        neighbors[[.., 1..num_inliners]],
+        neighbors[[.., 1..num_inliers]],
         sigmas,
-        knn_distances[[.., 1..num_inliners]],
+        knn_distances[[.., 1..num_inliers]],
         opts
       )
 
     flip = weights < 0
     anchors = triplets[[.., 0]] |> Nx.reshape({:auto, 1})
-    pairs = triplets[[.., 1..2]]
+    pairs = triplets[[.., 1..-1//1]]
 
     pairs =
       Nx.select(
```
```diff
@@ -446,7 +447,7 @@ defmodule Scholar.Manifold.Trimap do
       {loss, num_violated}
     end
 
-  defn trimap_loss(embedding, triplets, weights) do
+  defn trimap_loss({embedding, triplets, weights}) do
     {loss, _} = trimap_metrics(embedding, triplets, weights)
     loss
   end
```
```diff
@@ -499,21 +500,18 @@ defmodule Scholar.Manifold.Trimap do
     {triplets, weights, key, applied_pca?} =
       case triplets do
         {} ->
-          inputs =
+          {inputs, applied_pca} =
             if num_components > @dim_pca do
               inputs = inputs - Nx.mean(inputs, axes: [0])
               {u, s, vt} = Nx.LinAlg.SVD.svd(inputs, full_matrices: false)
               inputs = Nx.dot(u[[.., 0..@dim_pca]] * s[0..@dim_pca], vt[[0..@dim_pca, ..]])
-
-              inputs = inputs - Nx.reduce_min(inputs)
-              inputs = inputs / Nx.reduce_max(inputs)
-              inputs - Nx.mean(inputs, axes: [0])
+              {inputs, Nx.u8(1)}
             else
-              inputs
+              {inputs, Nx.u8(0)}
             end
 
           {triplets, weights, key} = generate_triplets(key, inputs, opts)
-          {triplets, weights, key, Nx.u8(1)}
+          {triplets, weights, key, applied_pca}
 
         _ ->
           {triplets, weights, key, Nx.u8(0)}
```
```diff
@@ -535,7 +533,10 @@ defmodule Scholar.Manifold.Trimap do
 
       opts[:init_embedding_type] == 1 ->
         {random_embedding, _key} =
-          Nx.Random.normal(key, shape: {num_points, opts[:num_components]})
+          Nx.Random.normal(key,
+            shape: {num_points, opts[:num_components]},
+            type: to_float_type(inputs)
+          )
 
         random_embedding * @init_scale
     end
```
```diff
@@ -553,10 +554,12 @@ defmodule Scholar.Manifold.Trimap do
     {embeddings, _} =
       while {embeddings, {vel, gain, lr, triplets, weights, i = Nx.s64(0)}},
             i < opts[:num_iters] do
-        gamma = if i < @switch_iter, do: @final_momentum, else: @init_momentum
-        grad = trimap_loss(embeddings + gamma * vel, triplets, weights)
+        gamma = if i < @switch_iter, do: @init_momentum, else: @final_momentum
+
+        gradient =
+          grad(embeddings + gamma * vel, fn x -> trimap_loss({x, triplets, weights}) end)
 
-        {embeddings, vel, gain} = update_embedding_dbd(embeddings, grad, vel, gain, lr, i)
+        {embeddings, vel, gain} = update_embedding_dbd(embeddings, gradient, vel, gain, lr, i)
 
         {embeddings, {vel, gain, lr, triplets, weights, i + 1}}
       end
```
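
The corrected loop evaluates the gradient at the momentum look-ahead point (a Nesterov-style step) and restores the intended schedule, starting from the initial momentum and switching to the final one:

$$g_t = \nabla_\theta \,\mathcal{L}(\theta_t + \gamma_t v_t), \qquad \gamma_t = \begin{cases} \gamma_{\text{init}} & t < t_{\text{switch}} \\ \gamma_{\text{final}} & t \ge t_{\text{switch}} \end{cases}$$
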
Binary file added notebooks/files/mammoth.png
149 changes: 149 additions & 0 deletions notebooks/manifold_learning.livemd
@@ -0,0 +1,149 @@
<!-- livebook:{"persist_outputs":true} -->

# Manifold Learning

```elixir
Mix.install([
  {:scholar, github: "elixir-nx/scholar"},
  {:explorer, "~> 0.8.2", override: true},
  {:exla, "~> 0.7.2"},
  {:nx, "~> 0.7.2"},
  {:req, "~> 0.4.14"},
  {:kino_vega_lite, "~> 0.1.11"},
  {:kino, "~> 0.12.3"},
  {:kino_explorer, "~> 0.1.18"},
  {:tucan, "~> 0.3.1"}
])
```

## Setup

We will use `Explorer` in this notebook, so let's define an alias for its main module, `DataFrame`:

```elixir
require Explorer.DataFrame, as: DF
```

And let's configure `EXLA` as our default backend (where our tensors are stored) and compiler (which compiles Scholar code) across the notebook and all branched sections:

```elixir
Nx.global_default_backend(EXLA.Backend)
Nx.Defn.global_default_options(compiler: EXLA)
```

## Testing Manifold Learning Functionalities

In this notebook we explore how manifold learning algorithms work and what results we can get from them.

First, let's fetch the dataset we will experiment on. The data represents the 3D coordinates of a mammoth. Below is a figure of the original dataset.

![](files/mammoth.png)

```elixir
source = "https://raw.githubusercontent.com/MNoichl/UMAP-examples-mammoth-/master/mammoth_a.csv"

data = Req.get!(source).body

df = DF.load_csv!(data)
```

Now, let's convert the dataframe into a tensor so we can manipulate the data using `Scholar`.

```elixir
tensor_data = Nx.stack(df, axis: 1)
```

Since there are almost 1 million data points and they are sorted, we shuffle the dataset and then use only a part of it.

<!-- livebook:{"branch_parent_index":1} -->

## Trimap

We start with Trimap. It's a manifold learning algorithm based on nearest neighbors. It preserves the global structure of the dataset, but it doesn't handle the local structure properly. Let's see what Trimap produces on the mammoth dataset.

```elixir
{tensor_data, key} = Nx.Random.shuffle(Nx.Random.key(42), tensor_data)

trimap_res =
  Scholar.Manifold.Trimap.transform(tensor_data[[0..10000, ..]],
    key: Nx.Random.key(55),
    num_components: 2,
    num_inliers: 12,
    num_outliers: 4,
    weight_temp: 0.5,
    learning_rate: 0.1,
    metric: :squared_euclidean
  )
```

Now, let's plot the results of the Trimap algorithm:

```elixir
coords = [
  x: trimap_res[[.., 0]] |> Nx.to_flat_list(),
  y: trimap_res[[.., 1]] |> Nx.to_flat_list()
]

Tucan.layers([
  Tucan.scatter(coords, "x", "y", point_size: 1)
])
|> Tucan.set_size(300, 300)
|> Tucan.set_title(
  "Mammoth dataset with reduced dimensionality using Trimap",
  offset: 25
)
```

We can certainly recognize a mammoth in this picture, so Trimap indeed preserved the global structure of the data. The result is similar to a projection of the 3D mammoth onto the YZ plane. Let's plot that projection and compare the two.

```elixir
coords = [
  x: tensor_data[[0..10000, 1]] |> Nx.to_flat_list(),
  y: tensor_data[[0..10000, 2]] |> Nx.to_flat_list()
]

Tucan.layers([
  Tucan.scatter(coords, "x", "y", point_size: 1)
])
|> Tucan.set_size(300, 300)
|> Tucan.set_title(
  "Mammoth dataset projected onto the YZ plane",
  offset: 25
)
```

These two plots are similar, but there are some important differences. Even if the second figure seems "prettier", it is less informative than the result of Trimap: in the first figure we can spot two tusks, while in the second they overlap and we see only one. Similarly, the legs overlap in the first plot, while in the second they are spread out and don't intersect each other.

## t-SNE

Now, let's try a different algorithm: t-SNE.

```elixir
tsne_res =
  Scholar.Manifold.TSNE.fit(tensor_data[[0..2000, ..]],
    key: Nx.Random.key(55),
    num_components: 2,
    perplexity: 125,
    exaggeration: 10.0,
    learning_rate: 500,
    metric: :squared_euclidean
  )
```

```elixir
coords = [
  x: tsne_res[[.., 0]] |> Nx.to_flat_list(),
  y: tsne_res[[.., 1]] |> Nx.to_flat_list()
]

Tucan.layers([
  Tucan.scatter(coords, "x", "y", point_size: 1)
])
|> Tucan.set_size(300, 300)
|> Tucan.set_title(
  "Mammoth dataset with reduced dimensionality using t-SNE",
  offset: 25
)
```

As we can see, t-SNE gives completely different results than Trimap, because it rests on a completely different mathematical formulation. t-SNE is also a slower algorithm, so it can't be used on datasets as big as those Trimap can handle. However, t-SNE preserves some features of the mammoth, such as the small tusks, the feet, and the torso. You can experiment with the *perplexity* parameter, which can substantially change the output of the algorithm.
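
For example, a quick sweep could rerun the call above with a few values (a sketch; the perplexities 5, 30, and 125 are only illustrative):

```elixir
# Sketch: recompute the t-SNE embedding for several perplexities.
# Lower values emphasize local structure, higher values global structure.
for perplexity <- [5, 30, 125] do
  Scholar.Manifold.TSNE.fit(tensor_data[[0..2000, ..]],
    key: Nx.Random.key(55),
    num_components: 2,
    perplexity: perplexity,
    exaggeration: 10.0,
    learning_rate: 500,
    metric: :squared_euclidean
  )
end
```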