Skip to content

Commit

Permalink
Remove other uses of :take
Browse files Browse the repository at this point in the history
josevalim committed May 12, 2024
1 parent 24ed6fa commit 50ecf0a
Showing 1 changed file with 0 additions and 66 deletions.
66 changes: 0 additions & 66 deletions nx/lib/nx/defn/grad.ex
Original file line number Diff line number Diff line change
@@ -153,9 +153,6 @@ defmodule Nx.Defn.Grad do
defp reduce_args(:take_along_axis, %{data: %{args: [arg | _]}}, acc, fun),
do: fun.(arg, acc)

defp reduce_args(:take, %{data: %{args: [arg | _]}}, acc, fun),
do: fun.(arg, acc)

defp reduce_args(:gather, %{data: %{args: [arg | _]}}, acc, fun),
do: fun.(arg, acc)

@@ -704,69 +701,6 @@ defmodule Nx.Defn.Grad do
[{t, g}]
end

defp grad(:take, [t, i, axis], _ans, g) do
axes_range = 0..(Nx.rank(t) - 1)//1

indices_shape =
axes_range
|> Enum.flat_map(fn
^axis -> Tuple.to_list(i.shape)
_ -> [1]
end)
|> List.to_tuple()

idx_tiling =
t.shape
|> Tuple.to_list()
|> Enum.with_index(fn
_x, ^axis ->
List.duplicate(1, Nx.rank(i))

x, _ ->
x
end)
|> List.flatten()

num_elements = Tuple.product(g.shape)

indices_for_axis =
i
|> Nx.reshape(indices_shape)
|> Nx.tile(idx_tiling)

axis_offset = Nx.rank(i) - 1

indices =
axes_range
|> Enum.map(fn
^axis ->
indices_for_axis
|> Nx.reshape({num_elements, 1})

current when current < axis ->
indices_for_axis
|> Nx.shape()
|> Nx.iota(axis: current)
|> Nx.reshape({num_elements, 1})

current when current > axis ->
indices_for_axis
|> Nx.shape()
|> Nx.iota(axis: current + axis_offset)
|> Nx.reshape({num_elements, 1})
end)
|> Nx.concatenate(axis: 1)

updates = Nx.reshape(g, {num_elements})

g =
t
|> Expr.broadcast(0, Nx.shape(t), Nx.axes(t))
|> Nx.indexed_add(indices, updates)

[{t, g}]
end

defp grad(:gather, [t, i, opts], _ans, g) do
i_axes = opts[:axes]
i_shape = i.shape

0 comments on commit 50ecf0a

Please sign in to comment.