diff --git a/nx/lib/nx/defn/grad.ex b/nx/lib/nx/defn/grad.ex index 57c882459e..c04ed72679 100644 --- a/nx/lib/nx/defn/grad.ex +++ b/nx/lib/nx/defn/grad.ex @@ -153,9 +153,6 @@ defmodule Nx.Defn.Grad do defp reduce_args(:take_along_axis, %{data: %{args: [arg | _]}}, acc, fun), do: fun.(arg, acc) - defp reduce_args(:take, %{data: %{args: [arg | _]}}, acc, fun), - do: fun.(arg, acc) - defp reduce_args(:gather, %{data: %{args: [arg | _]}}, acc, fun), do: fun.(arg, acc) @@ -704,69 +701,6 @@ defmodule Nx.Defn.Grad do [{t, g}] end - defp grad(:take, [t, i, axis], _ans, g) do - axes_range = 0..(Nx.rank(t) - 1)//1 - - indices_shape = - axes_range - |> Enum.flat_map(fn - ^axis -> Tuple.to_list(i.shape) - _ -> [1] - end) - |> List.to_tuple() - - idx_tiling = - t.shape - |> Tuple.to_list() - |> Enum.with_index(fn - _x, ^axis -> - List.duplicate(1, Nx.rank(i)) - - x, _ -> - x - end) - |> List.flatten() - - num_elements = Tuple.product(g.shape) - - indices_for_axis = - i - |> Nx.reshape(indices_shape) - |> Nx.tile(idx_tiling) - - axis_offset = Nx.rank(i) - 1 - - indices = - axes_range - |> Enum.map(fn - ^axis -> - indices_for_axis - |> Nx.reshape({num_elements, 1}) - - current when current < axis -> - indices_for_axis - |> Nx.shape() - |> Nx.iota(axis: current) - |> Nx.reshape({num_elements, 1}) - - current when current > axis -> - indices_for_axis - |> Nx.shape() - |> Nx.iota(axis: current + axis_offset) - |> Nx.reshape({num_elements, 1}) - end) - |> Nx.concatenate(axis: 1) - - updates = Nx.reshape(g, {num_elements}) - - g = - t - |> Expr.broadcast(0, Nx.shape(t), Nx.axes(t)) - |> Nx.indexed_add(indices, updates) - - [{t, g}] - end - defp grad(:gather, [t, i, opts], _ans, g) do i_axes = opts[:axes] i_shape = i.shape