diff --git a/lib/scholar/cluster/hierarchical.ex b/lib/scholar/cluster/hierarchical.ex new file mode 100644 index 00000000..481b6cd2 --- /dev/null +++ b/lib/scholar/cluster/hierarchical.ex @@ -0,0 +1,453 @@ +defmodule Scholar.Cluster.Hierarchical do + @moduledoc """ + Performs [hierarchical, agglomerative clustering](https://en.wikipedia.org/wiki/Hierarchical_clustering#Agglomerative_clustering_example) + on a dataset. + + Hierarchical clustering is good for when the number of clusters is not known ahead of time. + It also allows for the creation of a [dendrogram plot](https://en.wikipedia.org/wiki/Dendrogram) + (regardless of the dimensionality of the dataset) which can be used to select the number of + clusters in a post-processing step. + + ## Limitations + + Due to the requirements of the current implementation, only these options are supported: + + * `dissimilarity: :euclidean` + * `linkage: :average | :complete | :single | :ward | :weighted` + + Our current algorithm is $O(\\frac{n^2}{p} \\cdot \\log(n))$ where $n$ is the number of data points + and $p$ is the number of processors. + This is better than the generic algorithm which is $O(n^3)$. + It is also parallel, which means that runtime decreases in direct proportion to the number of + processors. + + However, the implementation requires certain theoretical properties of the dissimilarities and + linkages. + As such, we've restricted the options to only those combinations with the correct properties. + + In the future, we plan to add additional algorithms which won't have the same restrictions. + """ + import Nx.Defn + + defstruct [:clades, :dissimilarities, :num_points, :sizes] + + @dissimilarity_types [ + :euclidean + # :precomputed + ] + + @linkage_types [ + :average, + # :centroid, + :complete, + # :median, + :single, + :ward, + :weighted + ] + + @fit_opts_schema [ + dissimilarity: [ + type: {:in, @dissimilarity_types}, + default: :euclidean, + doc: """ + Pairwise dissimilarity function: computes the 'dissimilarity' between each pair of data points. + Dissimilarity is analogous to distance, but without the expectation that the triangle + inequality holds. + + Choices: + + * `:euclidean` - L2 norm. + + See "Limitations" in the moduledoc for an explanation of the lack of choices. + """ + ], + linkage: [ + type: {:in, @linkage_types}, + default: :single, + doc: ~S""" + Linkage function: how to compute the intra-clade dissimilarity of two clades if they were + merged. + + Choices: + + * `:average` - The unweighted average dissimilarity across all pairs of points. + + * `:complete` - (Historic name) The maximum dissimilarity across all pairs of points. + + * `:single` - (Historic name) The minimum dissimilarity across all pairs of points. + + * `:ward` - (Named for [Ward's method](https://en.wikipedia.org/wiki/Ward%27s_method)) + The minimum increase in sum of squares (MISSQ) of dissimilarities. + + * `:weighted` - The weighted average dissimilarity across all pairs of points. + """ + ] + ] + @doc """ + Use hierarchical clustering to form the initial model to be clustered with `labels_list/2` or + `labels_map/2`. + + ## Options + + #{NimbleOptions.docs(@fit_opts_schema)} + + ## Return values + + Returns a `Scholar.Cluster.Hierarchical` struct with the following fields: + + * `clades` (`Nx.Tensor` with shape `{n - 1, 2}`) - + Contains the indices of the pair of clades merged at each step of the agglomerative + clustering process. 
+ + Agglomerative clustering starts by considering each datum in `data` its own singleton group + or ["clade"](https://en.wikipedia.org/wiki/Clade). + It then picks two clades to merge into a new clade containing the data from both. + It does this until there is a single clade remaining. + + Since each datum starts as its own clade, e.g. `data[0]` is clade `0`, indexing of new clades + starts at `n` where `n` is the size of the original `data` tensor. + If `clades[k] == [i, j]`, then clades `i` and `j` were merged to form `k + n`. + + * `dissimilarities` (`Nx.Tensor` with shape `{n - 1}`) - + Contains a metric that measures the intra-clade closeness of each newly formed clade. + Represented by the heights of each clade in a dendrogram plot. + Determined by both the `:dissimilarity` and `:linkage` options. + + * `num_points` (`pos_integer/0`) - + Number of points in the dataset. + Must be $\\geq 3$. + + * `sizes` (`Nx.Tensor` with shape `{n - 1}`) - + `sizes[i]` is the size of clade `i`. + If clade `k` was created by merging clades `i` and `j`, then + `sizes[k] == sizes[i] + sizes[j]`. + + ## Examples + + iex> data = Nx.tensor([[2], [7], [9], [0], [3]]) + iex> Hierarchical.fit(data) + %Scholar.Cluster.Hierarchical{ + clades: Nx.tensor([[0, 4], [1, 2], [3, 5], [6, 7]]), + dissimilarities: Nx.tensor([1.0, 2.0, 2.0, 4.0]), + num_points: 5, + sizes: Nx.tensor([2, 2, 3, 5]) + } + """ + deftransform fit(%Nx.Tensor{} = data, opts \\ []) do + opts = NimbleOptions.validate!(opts, @fit_opts_schema) + dissimilarity = opts[:dissimilarity] + linkage = opts[:linkage] + + dissimilarity_fun = + case dissimilarity do + # :precomputed -> &Function.identity/1 + :euclidean -> &Scholar.Metrics.Distance.pairwise_euclidean/1 + end + + update_fun = + case linkage do + :average -> &average/6 + # :centroid -> ¢roid/6 + :complete -> &complete/6 + # :median -> &median/6 + :single -> &single/6 + :ward -> &ward/6 + :weighted -> &weighted/6 + end + + dendrogram_fun = + case linkage do + # TODO: :centroid, :median + l when l in [:average, :complete, :single, :ward, :weighted] -> + ¶llel_nearest_neighbor/3 + end + + n = + case Nx.shape(data) do + {n, _num_features} -> + n + + other -> + raise ArgumentError, + "Expected a rank 2 (`{num_obs, num_features}`) tensor, found shape: #{inspect(other)}." + end + + if n < 3 do + raise ArgumentError, "Must have a minimum of 3 data points, found: #{n}." + end + + {clades, diss, sizes} = dendrogram_fun.(data, dissimilarity_fun, update_fun) + + %__MODULE__{ + clades: clades, + dissimilarities: diss, + num_points: n, + sizes: sizes + } + end + + # Clade functions + + defnp parallel_nearest_neighbor(data, dissimilarity_fun, update_fun) do + pairwise = dissimilarity_fun.(data) + {n, _} = Nx.shape(pairwise) + pairwise = Nx.broadcast(:infinity, {n}) |> Nx.make_diagonal() |> Nx.add(pairwise) + clades = Nx.broadcast(-1, {n - 1, 2}) + sizes = Nx.broadcast(1, {2 * n - 1}) + pointers = Nx.broadcast(-1, {2 * n - 2}) + diss = Nx.tensor(:infinity, type: Nx.type(pairwise)) |> Nx.broadcast({n - 1}) + + {{clades, diss, sizes}, _} = + while {{clades, diss, sizes}, {count = 0, pointers, pairwise}}, count < n - 1 do + # Indexes of who I am nearest to + nearest = Nx.argmin(pairwise, axis: 1) + + # Take who I am nearest to is nearest to + nearest_of_nearest = Nx.take(nearest, nearest) + + # If the entry is pointing back at me, then we are a clade + clades_selector = nearest_of_nearest == Nx.iota({n}) + + # Now let's get the links that form clades. + # They are bidirectional but let's keep only one side. 
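+        # E.g. for 3 points with nearest = [1, 0, 0]: nearest_of_nearest = [0, 1, 1],
+        # so points 0 and 1 point back at each other (a mutual pair) while point 2
+        # does not. Keeping only the `nearest > nearest_of_nearest` side records each
+        # pair once, at its smaller index i (links[i] = j with i < j); every other
+        # entry gets the sentinel value n and is skipped by merge_clades/9.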
+ links = Nx.select(clades_selector and nearest > nearest_of_nearest, nearest, n) + + {clades, count, pointers, pairwise, diss, sizes} = + merge_clades(clades, count, pointers, pairwise, diss, sizes, links, n, update_fun) + + {{clades, diss, sizes}, {count, pointers, pairwise}} + end + + sizes = sizes[n..(2 * n - 2)] + perm = Nx.argsort(diss, stable: false, type: :u32) + {clades[perm], diss[perm], sizes[perm]} + end + + defnp merge_clades(clades, count, pointers, pairwise, diss, sizes, links, n, update_fun) do + {{clades, count, pointers, pairwise, diss, sizes}, _} = + while {{clades, count, pointers, pairwise, diss, sizes}, links}, + i <- 0..(Nx.size(links) - 1) do + # i < j because of how links is formed. + # i will become the new clade index and we "infinity-out" j. + j = links[i] + + if j == n do + {{clades, count, pointers, pairwise, diss, sizes}, links} + else + # Clades a and b (i and j of pairwise) are being merged into c. + indices = [i, j] |> Nx.stack() |> Nx.new_axis(-1) + a = find_clade(pointers, i) + b = find_clade(pointers, j) + c = count + n + + # Update clades + new_clade = Nx.stack([a, b]) |> Nx.sort() |> Nx.new_axis(0) + clades = Nx.put_slice(clades, [count, 0], new_clade) + + # Update sizes + sa = sizes[i] + sb = sizes[j] + sc = sa + sb + sizes = Nx.indexed_put(sizes, Nx.stack([i, c]) |> Nx.new_axis(-1), Nx.stack([sc, sc])) + + # Update dissimilarities + diss = Nx.indexed_put(diss, Nx.stack([count]), pairwise[i][j]) + + # Update pointers + pointers = Nx.indexed_put(pointers, indices, Nx.stack([c, c])) + + # Update pairwise + updates = + update_fun.(pairwise[i], pairwise[j], pairwise[i][j], sa, sb, sc) + |> Nx.indexed_put(indices, Nx.broadcast(:infinity, {2})) + + pairwise = + pairwise + |> Nx.put_slice([i, 0], Nx.reshape(updates, {1, n})) + |> Nx.put_slice([0, i], Nx.reshape(updates, {n, 1})) + |> Nx.put_slice([j, 0], Nx.broadcast(:infinity, {1, n})) + |> Nx.put_slice([0, j], Nx.broadcast(:infinity, {n, 1})) + + {{clades, count + 1, pointers, pairwise, diss, sizes}, links} + end + end + + {clades, count, pointers, pairwise, diss, sizes} + end + + defnp find_clade(pointers, i) do + {i, _, _} = + while {_current = i, next = pointers[i], pointers}, next != -1 do + {next, pointers[next], pointers} + end + + i + end + + # Dissimilarity update functions + + defnp average(dac, dbc, _dab, sa, sb, _sc), + do: (sa * dac + sb * dbc) / (sa + sb) + + # defnp centroid(dac, dbc, dab, sa, sb, _sc), + # do: Nx.sqrt((sa * dac + sb * dbc) / (sa + sb) - sa * sb * dab / (sa + sb) ** 2) + + defnp complete(dac, dbc, _dab, _sa, _sb, _sc), + do: Nx.max(dac, dbc) + + # defnp median(dac, dbc, dab, _sa, _sb, _sc), + # do: Nx.sqrt(dac / 2 + dbc / 2 - dab / 4) + + defnp single(dac, dbc, _dab, _sa, _sb, _sc), + do: Nx.min(dac, dbc) + + defnp ward(dac, dbc, dab, sa, sb, sc), + do: Nx.sqrt(((sa + sc) * dac + (sb + sc) * dbc - sc * dab) / (sa + sb + sc)) + + defnp weighted(dac, dbc, _dab, _sa, _sb, _sc), + do: (dac + dbc) / 2 + + # Cluster label functions + + @label_opts_schema [ + cluster_by: [ + type: :non_empty_keyword_list, + required: true, + keys: [ + height: [ + type: :float, + doc: "Height of the dendrogram to use as the split point for clusters." + ], + num_clusters: [ + type: :pos_integer, + doc: "Number of clusters to form." + ] + ], + doc: """ + How to select which clades from the dendrogram should form the final clusters. + Must provide either a height or a number of clusters. 
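+
+      For example: `cluster_by: [height: 2.5]` or `cluster_by: [num_clusters: 3]`.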
+ """ + ] + ] + @doc """ + Cluster a `Scholar.Cluster.Hierarchical` struct into a map of cluster labels to member indices. + + ## Options + + #{NimbleOptions.docs(@label_opts_schema)} + + ## Return values + + Returns a map where the keys are integers from `0..(k - 1)` where `k` is the number of clusters. + Each value is a cluster represented by a list of member indices. + E.g. if the result map was `%{0 => [0, 1], 1 => [2]}`, then elements `[0, 1]` of the data would + be in cluster `0` and the singleton element `[2]` would be in cluster `1`. + + Cluster labels are arbitrary, but deterministic. + + ## Examples + + iex> data = Nx.tensor([[5], [5], [5], [10], [10]]) + iex> model = Hierarchical.fit(data) + iex> Hierarchical.labels_map(model, cluster_by: [num_clusters: 2]) + %{0 => [0, 1, 2], 1 => [3, 4]} + """ + def labels_map(%__MODULE__{} = model, opts) do + opts = NimbleOptions.validate!(opts, @label_opts_schema) + + raw_clusters = + case opts[:cluster_by] do + [height: height] -> + cluster_by_height(model, height) + + [num_clusters: num_clusters] -> + cond do + num_clusters > model.num_points -> + raise ArgumentError, "`num_clusters` may not exceed number of data points." + + num_clusters == model.num_points -> + Nx.broadcast(0, {model.num_points}) + + # The other cases are validated by NimbleOptions. + true -> + cluster_by_num_clusters(model, num_clusters) + end + + _ -> + raise ArgumentError, "Must pass exactly one of `:height` or `:num_clusters`" + end + + # Give the clusters labels 0..(k - 1) and ensure those labels are deterministic by sorting by + # the minimum element. + raw_clusters + |> Enum.sort_by(fn {_label, cluster} -> Enum.min(cluster) end) + |> Enum.with_index() + |> Enum.flat_map(fn {{_, v}, i} -> v |> Enum.sort() |> Enum.map(&{&1, i}) end) + |> Enum.group_by(fn {_, v} -> v end, fn {k, _} -> k end) + end + + @doc """ + Cluster a `Scholar.Cluster.Hierarchical` struct into a list of cluster labels. + + ## Options + + #{NimbleOptions.docs(@label_opts_schema)} + + ## Return values + + Returns a list of length `n` and values `0..(k - 1)` where `n` is the number of data points and + `k` is the number of clusters formed. + The `i`th element of the result list is the label of the `i`th data point's cluster. + + Cluster labels are arbitrary, but deterministic. 
+ + ## Examples + + iex> data = Nx.tensor([[5], [5], [5], [10], [10]]) + iex> model = Hierarchical.fit(data) + iex> Hierarchical.labels_list(model, cluster_by: [num_clusters: 2]) + [0, 0, 0, 1, 1] + """ + def labels_list(%__MODULE__{} = model, opts) do + model + |> labels_map(opts) + |> Enum.flat_map(fn {k, vs} -> Enum.map(vs, &{&1, k}) end) + |> Enum.sort() + |> Enum.map(fn {_, v} -> v end) + end + + defp cluster_by_height(model, height_cutoff) do + clusters = Map.new(0..(model.num_points - 1), &{&1, [&1]}) + + Enum.zip(Nx.to_list(model.clades), Nx.to_list(model.dissimilarities)) + |> Enum.with_index(model.num_points) + |> Enum.reduce_while(clusters, fn {{[a, b], height}, c}, clusters -> + if height >= height_cutoff do + {:halt, clusters} + else + {:cont, merge_clusters(clusters, a, b, c)} + end + end) + end + + defp cluster_by_num_clusters(model, num_clusters) do + clusters = Map.new(0..(model.num_points - 1), &{&1, [&1]}) + + Nx.to_list(model.clades) + |> Enum.with_index(model.num_points) + |> Enum.reduce_while(clusters, fn {[a, b], c}, clusters -> + if c + num_clusters == 2 * model.num_points do + {:halt, clusters} + else + {:cont, merge_clusters(clusters, a, b, c)} + end + end) + end + + defp merge_clusters(clusters, a, b, c) do + clusters + |> Map.put(c, clusters[a] ++ clusters[b]) + |> Map.drop([a, b]) + end +end diff --git a/mix.exs b/mix.exs index 85d2ed85..971e8817 100644 --- a/mix.exs +++ b/mix.exs @@ -30,7 +30,7 @@ defmodule Scholar.MixProject do defp deps do [ {:ex_doc, "~> 0.30", only: :docs}, - {:nx, "~> 0.6.3 or ~> 0.7", override: true}, + {:nx, "~> 0.6.4 or ~> 0.7", override: true}, {:nimble_options, "~> 0.5.2 or ~> 1.0"}, {:exla, "~> 0.6.3 or ~> 0.7", optional: true}, {:polaris, "~> 0.1"}, @@ -54,6 +54,7 @@ defmodule Scholar.MixProject do logo: "images/scholar_simplified.png", extra_section: "Guides", extras: [ + # "notebooks/hierarchical_clustering.livemd", "README.md", "notebooks/linear_regression.livemd", "notebooks/k_means.livemd", @@ -66,6 +67,7 @@ defmodule Scholar.MixProject do Scholar.Cluster.AffinityPropagation, Scholar.Cluster.DBSCAN, Scholar.Cluster.GaussianMixture, + Scholar.Cluster.Hierarchical, Scholar.Cluster.KMeans, Scholar.Decomposition.PCA, Scholar.Integrate, diff --git a/mix.lock b/mix.lock index 2d8829a8..5e93f393 100644 --- a/mix.lock +++ b/mix.lock @@ -11,7 +11,7 @@ "makeup_erlang": {:hex, :makeup_erlang, "0.1.2", "ad87296a092a46e03b7e9b0be7631ddcf64c790fa68a9ef5323b6cbb36affc72", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "f3f5a1ca93ce6e092d92b6d9c049bcda58a3b617a8d888f8e7231c85630e8108"}, "nimble_options": {:hex, :nimble_options, "0.5.2", "42703307b924880f8c08d97719da7472673391905f528259915782bb346e0a1b", [:mix], [], "hexpm", "4da7f904b915fd71db549bcdc25f8d56f378ef7ae07dc1d372cbe72ba950dce0"}, "nimble_parsec": {:hex, :nimble_parsec, "1.4.0", "51f9b613ea62cfa97b25ccc2c1b4216e81df970acd8e16e8d1bdc58fef21370d", [:mix], [], "hexpm", "9c565862810fb383e9838c1dd2d7d2c437b3d13b267414ba6af33e50d2d1cf28"}, - "nx": {:hex, :nx, "0.6.3", "058d0173a9af9a688dfba74f471a5abf8f858625abdd06ba3c597cfeffed1599", [:mix], [{:complex, "~> 0.5", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "955f86f39df744e64d576c89ba276b7b4cf15ea01d484d9f0c6d54485c091b7b"}, + "nx": {:hex, :nx, "0.6.4", "948d9f42f81e63fc901d243ac0a985c8bb87358be62e27826cfd67f58bc640af", [:mix], [{:complex, "~> 0.5", [hex: :complex, repo: "hexpm", optional: 
false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "bb9c2e2e3545b5eb4739d69046a988daaa212d127dba7d97801c291616aff6d6"}, "polaris": {:hex, :polaris, "0.1.0", "dca61b18e3e801ecdae6ac9f0eca5f19792b44a5cb4b8d63db50fc40fc038d22", [:mix], [{:nx, "~> 0.5", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "13ef2b166650e533cb24b10e2f3b8ab4f2f449ba4d63156e8c569527f206e2c2"}, "statistex": {:hex, :statistex, "1.0.0", "f3dc93f3c0c6c92e5f291704cf62b99b553253d7969e9a5fa713e5481cd858a5", [:mix], [], "hexpm", "ff9d8bee7035028ab4742ff52fc80a2aa35cece833cf5319009b52f1b5a86c27"}, "telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"}, diff --git a/notebooks/hierarchical_clustering.livemd b/notebooks/hierarchical_clustering.livemd new file mode 100644 index 00000000..b7096029 --- /dev/null +++ b/notebooks/hierarchical_clustering.livemd @@ -0,0 +1,217 @@ +# Hierarchical Clustering + +```elixir +app_root = Path.join(__DIR__, "..") + +Mix.install( + [ + {:kino, "~> 0.10.0"}, + {:nx, github: "elixir-nx/nx", sparse: "nx", override: true}, + {:scholar, path: app_root} + ], + config_path: Path.join(app_root, "config/config.exs"), + lockfile: Path.join(app_root, "mix.lock") +) +``` + +## Introduction + +```elixir +defmodule Scholar.Kino.CanvasDendrogram do + use Kino.JS + + def new(graph), do: Kino.JS.new(__MODULE__, graph) + + asset "main.js" do + """ + function scale_fun(from_range, to_range) { + let [from_min, from_max] = from_range; + let [to_min, to_max] = to_range; + return function(x) { + return to_min + (x - from_min) / (from_max - from_min) * (to_max - to_min) + } + } + + function text_height(ctx, x) { + let text_metrics = ctx.measureText(x); + return text_metrics.actualBoundingBoxAscent + text_metrics.actualBoundingBoxDescent; + } + + function text_width(ctx, x) { + let text_metrics = ctx.measureText(x); + return text_metrics.actualBoundingBoxRight + text_metrics.actualBoundingBoxLeft; + } + + function draw(ctx, canvas_params, dendrogram) { + let height = canvas_params.height; + let width = canvas_params.width; + + let clades = dendrogram.clades + let dissimilarities = dendrogram.dissimilarities + let num_leaves = dendrogram.num_points + let max_dissimilarity = dissimilarities[dissimilarities.length - 1]; + + let x_min = 0; + let x_max = num_leaves - 1; + let y_min = 0; + let y_max = max_dissimilarity; + + let x_tick_labels = [...Array(x_max + 1).keys()]; + let y_tick_labels = [...Array(Math.floor(y_max) + 1).keys()]; + let x_tick_label_height = Math.max(...x_tick_labels.map((l) => text_height(ctx, l))); + let y_tick_label_width = Math.max(...y_tick_labels.map((l) => text_width(ctx, l))); + let x_tick_area_height = 10; + let y_tick_area_width = 10; + + let margin = 10; + let plot_left = margin + y_tick_label_width + y_tick_area_width; + let plot_top = margin; + let plot_height = height - 2*margin - x_tick_label_height - x_tick_area_height; + let plot_width = width - 2*margin - y_tick_label_width - y_tick_area_width; + + let x_range = x_max - x_min; + let y_range = y_max - y_min; + let data_margin = 0.1; + let x_data_min = x_min - x_range * data_margin / 2; + let x_data_max = x_max + x_range * data_margin / 2; + let y_data_min = y_min - y_range * data_margin / 2; + let y_data_max = y_max + y_range * data_margin / 2; + let scale_x = scale_fun([x_data_min, x_data_max], [plot_left, plot_left + plot_width]); + 
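+    // Note: the y output range is flipped (plot bottom -> plot top) because canvas
+    // coordinates grow downward, so larger dissimilarities are drawn higher up.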
let scale_y = scale_fun([y_data_min, y_data_max], [plot_top + plot_height, plot_top]); + + // Axes + ctx.beginPath(); + ctx.moveTo(plot_left, plot_top); + ctx.lineTo(plot_left, plot_top + plot_height); + ctx.lineTo(plot_left + plot_width, plot_top + plot_height); + ctx.lineTo(plot_left + plot_width, plot_top); + ctx.lineTo(plot_left, plot_top); + ctx.stroke(); + ctx.closePath(); + + // x-ticks + for(let x of x_tick_labels) { + ctx.beginPath(); + ctx.moveTo(scale_x(x), scale_y(y_data_min)); + ctx.lineTo(scale_x(x), scale_y(y_data_min) + x_tick_area_height / 2); + ctx.stroke(); + ctx.closePath(); + } + + // y-ticks + for(let y of y_tick_labels) { + ctx.beginPath(); + ctx.moveTo(scale_x(x_data_min), scale_y(y)); + ctx.lineTo(scale_x(x_data_min) - y_tick_area_width / 2, scale_y(y)); + ctx.stroke(); + ctx.closePath(); + } + + // x-tick labels + ctx.textAlign = "center"; + for(let x of x_tick_labels) { + ctx.strokeText(x, scale_x(x), scale_y(y_data_min) + x_tick_area_height + x_tick_label_height); + } + + // y-tick labels + ctx.textBaseline = "middle"; + ctx.textAlign = "end"; + for(let y of y_tick_labels) { + ctx.strokeText(y, scale_x(x_data_min) - y_tick_area_width, scale_y(y)); + } + + // Leaves + let coords = new Map(); + for (let i = 0; i < num_leaves; i++) { + let x = scale_x(i); + let y = scale_y(0); + + ctx.beginPath(); + ctx.arc(x, y, 5, 0, Math.PI * 2); + ctx.fill(); + + coords.set(i, [x, y]); + } + + // Clades + for (let i = 0; i < clades.length; i++) { + let [a, b] = clades[i] + let c = i + num_leaves + let d = dissimilarities[i] + + let [ax, ay] = coords.get(a); + let [bx, by] = coords.get(b); + let cx = (ax + bx) / 2; + let cy = scale_y(d); + + ctx.beginPath(); + ctx.moveTo(ax, ay); + ctx.lineTo(ax, cy); + ctx.lineTo(bx, cy); + ctx.lineTo(bx, by); + ctx.stroke(); + ctx.closePath(); + + ctx.beginPath(); + ctx.arc(cx, cy, 5, 0, Math.PI * 2); + ctx.fill(); + + coords.delete(a); + coords.delete(b); + coords.set(c, [cx, cy]); + } + } + + export function init(ctx, input) { + let dendrogram = input.dendrogram + let canvas_params = input.canvas + let canvas_el_id = "dendrogram-plot"; + + ctx.root.innerHTML = + ` + `; + + let canvas_el = document.getElementById(canvas_el_id); + + // Check for canvas support + if (canvas_el.getContext) { + let canvas_ctx = canvas_el.getContext("2d"); + draw(canvas_ctx, canvas_params, dendrogram); + } + } + """ + end +end +``` + +```elixir +# 5 | 0 1 3 4 +# 4 | 2 5 +# 3 | +# 2 | 6 +# 1 | 7 8 +# 0 +-+-+-+-+-+ +# 0 1 2 3 4 5 + +# Tensor form of the data sketched above +data = Nx.tensor([[1, 5], [2, 5], [1, 4], [4, 5], [5, 5], [5, 4], [1, 2], [1, 1], [2, 1]]) + +# Build model from data +model = Scholar.Cluster.Hierarchical.fit(data, dissimilarity: :euclidean, linkage: :average) + +# Make a JSON-serializable "dendrogram" +dendrogram = + model + |> Map.from_struct() + |> Map.new(fn + {k, %Nx.Tensor{} = v} -> {k, Nx.to_list(v)} + {k, v} -> {k, v} + end) + +# Plot +Scholar.Kino.CanvasDendrogram.new(%{dendrogram: dendrogram, canvas: %{width: 400, height: 400}}) +``` diff --git a/test/scholar/cluster/hierarchical_test.exs b/test/scholar/cluster/hierarchical_test.exs new file mode 100644 index 00000000..6c4e5d55 --- /dev/null +++ b/test/scholar/cluster/hierarchical_test.exs @@ -0,0 +1,218 @@ +defmodule Scholar.Cluster.HierarchicalTest do + use Scholar.Case, async: true + + alias Scholar.Cluster.Hierarchical + + doctest Hierarchical + + describe "basic example" do + test "works" do + # This diagram represents data. `0` appears at the coordinates (1, 5). 
The 0th entry of data + # is `[1, 5]`. Same for 1, etc. + # + # 5 | 0 1 3 4 + # 4 | 2 5 + # 3 | + # 2 | 6 + # 1 | 7 8 + # 0 +-+-+-+-+-+ + # 0 1 2 3 4 5 + data = Nx.tensor([[1, 5], [2, 5], [1, 4], [4, 5], [5, 5], [5, 4], [1, 2], [1, 1], [2, 1]]) + + # This diagram represents the sequence of expected merges. The data starts off with all + # points as singleton clades. The first step of the algorithm merges singleton clades + # 0: [0] and 1: [1] to form clade 9: [0, 1]. This process continues until all clades have + # been merged into a single clade with all points. + # + # 0 1 2 3 4 5 6 7 8 + # 8: [0] [1] [2] [3] [4] [5] [6] [7] [8] + # 9 2 3 4 5 6 7 8 + # 9: [01] [2] [3] [4] [5] [6] [7] [8] + # ---- + # 9 2 10 5 6 7 8 + # 10: [01] [2] [34] [5] [6] [7] [8] + # ---- + # 9 2 10 5 11 8 + # 11: [01] [2] [34] [5] [67] [8] + # ---- + # 12 10 5 11 8 + # 12: [012] [34] [5] [67] [8] + # ----- + # 12 13 11 8 + # 13: [012] [345] [67] [8] + # ----- + # 12 13 14 + # 14: [012] [345] [678] + # ----- + # 15 14 + # 15: [012345] [678] + # -------- + # 16 + # 16: [012345678] + # ----------- + model = Hierarchical.fit(data, dissimilarity: :euclidean, linkage: :single) + + # The dendrogram formation part of the algorithm should've formed the following clades, + # dissimilarities, and sizes (which collectively form the dendrogram). + assert model.clades == + Nx.tensor([ + [0, 1], + [3, 4], + [6, 7], + [2, 9], + [5, 10], + [8, 11], + [12, 13], + [14, 15] + ]) + + assert model.dissimilarities == + Nx.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 2.0]) + + assert model.sizes == Nx.tensor([2, 2, 2, 3, 3, 3, 6, 9]) + + # The clustering part of the algorithm uses the `cluster_by: [num_clusters: 3]` option to + # take the model and form 3 clusters. + labels_map = Hierarchical.labels_map(model, cluster_by: [num_clusters: 3]) + assert labels_map == %{0 => [0, 1, 2], 1 => [3, 4, 5], 2 => [6, 7, 8]} + + # We can also return a list of each datum's cluster label. 
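+      # Labels themselves are arbitrary but deterministic: clusters are ordered by
+      # their smallest member index, so the cluster containing datum 0 is labeled 0.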
+ labels_list = Hierarchical.labels_list(model, cluster_by: [num_clusters: 3]) + assert labels_list == [0, 0, 0, 1, 1, 1, 2, 2, 2] + end + end + + describe "linkages" do + setup do + %{data: Nx.tensor([[1, 5], [2, 5], [1, 4], [4, 5], [5, 5], [5, 4], [1, 2], [1, 1], [2, 1]])} + end + + test "average", %{data: data} do + model = Hierarchical.fit(data, linkage: :average) + + assert model.dissimilarities == + Nx.tensor([ + 1.0, + 1.0, + 1.0, + 1.2071068286895752, + 1.2071068286895752, + 1.2071068286895752, + 3.396751642227173, + 4.092065334320068 + ]) + end + + test "complete", %{data: data} do + model = Hierarchical.fit(data, linkage: :complete) + + assert model.dissimilarities == + Nx.tensor([ + 1.0, + 1.0, + 1.0, + # sqrt(2) + 1.4142135381698608, + 1.4142135381698608, + 1.4142135381698608, + # sqrt(17) + 4.123105525970459, + # 4 * sqrt(2) + 5.656854152679443 + ]) + end + + test "single", %{data: data} do + model = Hierarchical.fit(data, linkage: :single) + assert model.dissimilarities == Nx.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 2.0]) + end + + @tag :skip + test "ward", %{data: data} do + model = Hierarchical.fit(data, linkage: :ward) + + # (Approximately, from Scipy) + assert model.dissimilarities == + Nx.tensor([ + 1.0, + 1.0, + 1.0, + 1.29099445, + 1.29099445, + 1.29099445, + 5.77350269, + 7.45355992 + ]) + end + + test "weighted", %{data: data} do + model = Hierarchical.fit(data, linkage: :weighted) + + assert model.dissimilarities == + Nx.tensor([ + 1.0, + 1.0, + 1.0, + 1.2071068286895752, + 1.2071068286895752, + 1.2071068286895752, + 3.32379412651062, + 4.1218791007995605 + ]) + end + end + + describe "cluster labels" do + setup do + %{model: Hierarchical.fit(Nx.tensor([[2], [7], [9], [0], [3]]))} + end + + test "cluster by height", %{model: model} do + labels_map = Hierarchical.labels_map(model, cluster_by: [height: 2.5]) + assert labels_map == %{0 => [0, 3, 4], 1 => [1, 2]} + labels_list = Hierarchical.labels_list(model, cluster_by: [height: 2.5]) + assert labels_list == [0, 1, 1, 0, 0] + end + + test "cluster by number of clusters", %{model: model} do + labels_map = Hierarchical.labels_map(model, cluster_by: [num_clusters: 3]) + assert labels_map == %{0 => [0, 4], 1 => [1, 2], 2 => [3]} + labels_list = Hierarchical.labels_list(model, cluster_by: [num_clusters: 3]) + assert labels_list == [0, 1, 1, 2, 0] + end + end + + describe "errors" do + test "need a rank 2 tensor" do + assert_raise( + ArgumentError, + "Expected a rank 2 (`{num_obs, num_features}`) tensor, found shape: {3}.", + fn -> + Hierarchical.fit(Nx.tensor([1, 2, 3])) + end + ) + end + + test "need at least 3 data points" do + assert_raise(ArgumentError, "Must have a minimum of 3 data points, found: 2.", fn -> + Hierarchical.fit(Nx.tensor([[1], [2]])) + end) + end + + test "num_clusters may not exceed number of data points" do + model = Hierarchical.fit(Nx.tensor([[1], [2], [3]])) + + assert_raise(ArgumentError, "`num_clusters` may not exceed number of data points.", fn -> + Hierarchical.labels_list(model, cluster_by: [num_clusters: 4]) + end) + end + + test "additional option validations" do + model = Hierarchical.fit(Nx.tensor([[1], [2], [3]])) + + assert_raise(ArgumentError, "Must pass exactly one of `:height` or `:num_clusters`", fn -> + Hierarchical.labels_list(model, cluster_by: [num_clusters: 2, height: 1.0]) + end) + end + end +end
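
For reference, a short usage sketch tying the pieces together. All values are taken from the doctest and the `cluster labels` tests above, using the default `:euclidean` dissimilarity and `:single` linkage:

```elixir
data = Nx.tensor([[2], [7], [9], [0], [3]])
model = Scholar.Cluster.Hierarchical.fit(data)

# Merge heights of the dendrogram. Any cutoff strictly between 2.0 and 4.0
# leaves two clades, so `height: 2.5` and `num_clusters: 2` agree on this data.
Nx.to_list(model.dissimilarities)
#=> [1.0, 2.0, 2.0, 4.0]

Scholar.Cluster.Hierarchical.labels_map(model, cluster_by: [height: 2.5])
#=> %{0 => [0, 3, 4], 1 => [1, 2]}

Scholar.Cluster.Hierarchical.labels_list(model, cluster_by: [height: 2.5])
#=> [0, 1, 1, 0, 0]
```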