diff --git a/lib/scholar/manifold/mds.ex b/lib/scholar/manifold/mds.ex index 4be27464..7976c9da 100644 --- a/lib/scholar/manifold/mds.ex +++ b/lib/scholar/manifold/mds.ex @@ -328,7 +328,6 @@ defmodule Scholar.Manifold.MDS do ## Examples iex> x = Nx.iota({4,5}) - iex> key = Nx.Random.key(42) iex> init = Nx.reverse(Nx.iota({4,2})) iex> Scholar.Manifold.MDS.fit(x, init) %Scholar.Manifold.MDS{ diff --git a/lib/scholar/neighbors/kd_tree.ex b/lib/scholar/neighbors/kd_tree.ex index 2dc93524..47d7ae04 100644 --- a/lib/scholar/neighbors/kd_tree.ex +++ b/lib/scholar/neighbors/kd_tree.ex @@ -188,13 +188,13 @@ defmodule Scholar.Neighbors.KDTree do {level, tags, _tensor, _band} = while {level = Nx.u32(0), tags, tensor, band}, level < levels - 1 do k = rem(level, dims) - indices = Nx.argsort(tensor[[.., k]] + band * tags, type: :u32) + indices = Nx.argsort(tensor[[.., k]] + band * tags, type: :u32, stable: true) tags = update_tags(tags, indices, level, levels, size) {level + 1, tags, tensor, band} end k = rem(level, dims) - Nx.argsort(tensor[[.., k]] + band * tags, type: :u32) + Nx.argsort(tensor[[.., k]] + band * tags, type: :u32, stable: true) end defnp update_tags(tags, indices, level, levels, size) do