diff --git a/exla/lib/exla/backend.ex b/exla/lib/exla/backend.ex index 4f394d4c6a..973c1b7353 100644 --- a/exla/lib/exla/backend.ex +++ b/exla/lib/exla/backend.ex @@ -261,7 +261,6 @@ defmodule EXLA.Backend do {:dot, [:left, :c1, :b1, :right, :c2, :b2], [:left, :right]}, {:clip, [:tensor, :min, :max], [:tensor, :min, :max]}, {:take, [:tensor, :indices, :axis], [:tensor, :indices]}, - {:take_along_axis, [:tensor, :indices, :axis], [:tensor, :indices]}, {:gather, [:input, :indices, :opts], [:input, :indices]}, {:select, [:pred, :on_true, :on_false], [:pred, :on_true, :on_false]}, {:conv, [:tensor, :kernel, :opts], [:tensor, :kernel]}, diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index d29e4bbcd4..e08a38135b 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -1608,44 +1608,6 @@ defmodule EXLA.Defn do ) end - defp to_operator(:take_along_axis, [%mod{} = tensor, indices, axis], _ans, state) do - indices_shape = op_shape(indices) - indices_rank = tuple_size(indices_shape) - - axes_range = 0..(indices_rank - 1)//1 - - index_vector_dim = indices_rank - slice_sizes = List.duplicate(1, indices_rank) - offset_dims = [] - collapsed_slice_dims = Enum.to_list(axes_range) - start_index_map = Enum.to_list(axes_range) - - indices_exla_shape = mod.get_shape(indices) - - iotas = - Enum.map(axes_range, fn axis -> - mod.iota(state.builder, indices_exla_shape, axis) - end) - - new_axis_shape = Tuple.append(indices_shape, 1) - - indices = - iotas - |> List.replace_at(axis, indices) - |> Enum.map(&mod.reshape(&1, new_axis_shape)) - |> mod.concatenate(indices_rank) - - mod.gather( - tensor, - indices, - index_vector_dim, - slice_sizes, - offset_dims, - collapsed_slice_dims, - start_index_map - ) - end - defp to_operator(:gather, [%mod{} = tensor, indices, opts], _ans, _state) do axes = Keyword.fetch!(opts, :axes) tensor_shape = op_shape(tensor)