diff --git a/exla/lib/exla/backend.ex b/exla/lib/exla/backend.ex index 8e0b87f362..460ceea900 100644 --- a/exla/lib/exla/backend.ex +++ b/exla/lib/exla/backend.ex @@ -325,7 +325,6 @@ defmodule EXLA.Backend do {:reverse, [:tensor, :axes], [:tensor]}, {:dot, [:left, :c1, :b1, :right, :c2, :b2], [:left, :right]}, {:clip, [:tensor, :min, :max], [:tensor, :min, :max]}, - {:take, [:tensor, :indices, :axis], [:tensor, :indices]}, {:take_along_axis, [:tensor, :indices, :axis], [:tensor, :indices]}, {:gather, [:input, :indices, :opts], [:input, :indices]}, {:select, [:pred, :on_true, :on_false], [:pred, :on_true, :on_false]}, diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index 3cde646c68..8cfb7e8710 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -1220,29 +1220,6 @@ defmodule EXLA.Defn do Value.dynamic_update_slice(tensor, slice, start_indices, expr_to_typespec(ans)) end - defp to_operator(:take, [%Value{} = tensor, indices, axis], ans, _state) do - tensor_rank = tensor |> op_shape() |> tuple_size() - indices_rank = indices |> op_shape() |> tuple_size() - result_rank = tensor_rank - 1 + indices_rank - - index_vector_dim = indices_rank - slice_sizes = tensor |> op_shape() |> put_elem(axis, 1) |> Tuple.to_list() - offset_dims = result_rank |> axes_for_rank() |> delete_slice(axis, indices_rank) - collapsed_slice_dims = [axis] - start_index_map = [axis] - - Value.gather( - tensor, - indices, - index_vector_dim, - slice_sizes, - offset_dims, - collapsed_slice_dims, - start_index_map, - expr_to_typespec(ans) - ) - end - defp to_operator(:take_along_axis, [%Value{} = tensor, indices, axis], ans, state) do %{shape: indices_shape} = indices_typespec = Value.get_typespec(indices) indices_rank = tuple_size(indices_shape) @@ -1962,11 +1939,6 @@ defmodule EXLA.Defn do # Helpers - defp delete_slice(enumerable, index, length) do - {left, right} = Enum.split(enumerable, index) - left ++ Enum.drop(right, length) - end - defp apply_mlir_broadcasted_bin_op(op, out, left, right) do left_typespec = Value.get_typespec(left) right_typespec = Value.get_typespec(right) diff --git a/nx/lib/nx.ex b/nx/lib/nx.ex index 87831150b3..a0d87bd1fe 100644 --- a/nx/lib/nx.ex +++ b/nx/lib/nx.ex @@ -14130,13 +14130,17 @@ defmodule Nx do else tensor = devectorize(tensor, keep_names: false) indices = devectorize(indices, keep_names: false) + gather_indices = new_axis(indices, rank(indices)) - impl!(tensor).take( - %{tensor | shape: inner_shape, names: inner_names}, - tensor, - indices, - axis - ) + {indices_axes, tensor_axes} = Enum.split(axes(inner_shape), rank(indices)) + {leading, trailing} = Enum.split(tensor_axes, axis) + + transpose_axes = leading ++ indices_axes ++ trailing + + tensor + |> gather(gather_indices, axes: [axis]) + |> transpose(axes: transpose_axes) + |> reshape(inner_shape, names: inner_names) end end diff --git a/nx/lib/nx/backend.ex b/nx/lib/nx/backend.ex index 5bc4a9fc65..339c7010dd 100644 --- a/nx/lib/nx/backend.ex +++ b/nx/lib/nx/backend.ex @@ -73,7 +73,6 @@ defmodule Nx.Backend do @callback clip(out :: tensor, tensor, min :: tensor, max :: tensor) :: tensor @callback slice(out :: tensor, tensor, list, list, list) :: tensor @callback put_slice(out :: tensor, tensor, tensor, list) :: tensor - @callback take(out :: tensor, input :: tensor, indices :: tensor, axis) :: tensor @callback take_along_axis(out :: tensor, input :: tensor, indices :: tensor, axis) :: tensor @callback gather(out :: tensor, input :: tensor, indices :: tensor, keyword) :: tensor @callback concatenate(out :: tensor, tensor, axis) :: tensor diff --git a/nx/lib/nx/binary_backend.ex b/nx/lib/nx/binary_backend.ex index 7ff9bc13f4..a4f5de664a 100644 --- a/nx/lib/nx/binary_backend.ex +++ b/nx/lib/nx/binary_backend.ex @@ -1939,47 +1939,6 @@ defmodule Nx.BinaryBackend do from_binary(out, data) end - @impl true - def take(out, tensor, indices, axis) do - # We iterate over the indices in a flat manner, - # and take a unit tensor slice along axis given - # by each index. Then we concatenate the tensors - # along the axis, which gives us the result with - # index dimensions flattened and we just reshape. - - %T{type: {_, size}, shape: shape} = tensor - %T{type: {_, idx_size}} = indices - - data = to_binary(tensor) - tensor_rank = tuple_size(shape) - slice_start = List.duplicate(0, tensor_rank) - slice_lengths = shape |> Tuple.to_list() |> List.replace_at(axis, 1) - slice_shape = List.to_tuple(slice_lengths) - strides = List.duplicate(1, tensor_rank) - - slices = - for <> do - idx = binary_to_number(bin, indices.type) - - if idx < 0 or idx >= elem(shape, axis) do - raise ArgumentError, - "index #{idx} is out of bounds for axis #{axis} in shape #{inspect(shape)}" - end - - slice_start = List.replace_at(slice_start, axis, idx) - - slice_data = - bin_slice(data, shape, size, slice_start, slice_lengths, strides, slice_shape) - - {slice_data, slice_shape} - end - - concat_shape = put_elem(tensor.shape, axis, length(slices)) - result_data = bin_concatenate(slices, size, axis, concat_shape) - - from_binary(out, result_data) - end - @impl true def take_along_axis( %T{type: output_type} = output, diff --git a/nx/lib/nx/defn/expr.ex b/nx/lib/nx/defn/expr.ex index 35a7ebb65e..b5fb56bc72 100644 --- a/nx/lib/nx/defn/expr.ex +++ b/nx/lib/nx/defn/expr.ex @@ -1183,12 +1183,6 @@ defmodule Nx.Defn.Expr do expr(out, context, :put_slice, [tensor, start, slice]) end - @impl true - def take(out, tensor, indices, axis) do - {[tensor, indices], context} = to_exprs([tensor, indices]) - expr(out, context, :take, [tensor, indices, axis]) - end - @impl true def take_along_axis(out, tensor, indices, axis) do {[tensor, indices], context} = to_exprs([tensor, indices]) diff --git a/torchx/lib/torchx/backend.ex b/torchx/lib/torchx/backend.ex index f660849a2c..80b5127454 100644 --- a/torchx/lib/torchx/backend.ex +++ b/torchx/lib/torchx/backend.ex @@ -337,52 +337,6 @@ defmodule Torchx.Backend do |> to_nx(out) end - @impl true - def take(out, t, i, axis) do - axes = Nx.axes(t) - - indices_shape = - axes - |> Enum.map(fn - ^axis -> Tuple.product(i.shape) - _ -> 1 - end) - |> List.to_tuple() - - idx_tiling = - t.shape - |> Tuple.to_list() - |> Enum.with_index(fn - _x, ^axis -> 1 - x, _ -> x - end) - - indices_for_axis = - i - |> Nx.reshape(indices_shape) - |> Nx.tile(idx_tiling) - - num_elements = Tuple.product(indices_for_axis.shape) - - indices = - axes - |> Enum.map(fn - ^axis -> - Nx.reshape(indices_for_axis, {num_elements, 1}) - - current -> - # current when current < axis -> - indices_for_axis - |> Nx.shape() - |> Nx.iota(axis: current, backend: __MODULE__) - |> Nx.reshape({num_elements, 1}) - end) - |> Nx.concatenate(axis: 1) - - # TODO: maybe rewrite it as gather now behaves differently - gather(out, t, indices, []) - end - @impl true def gather(out, tensor, indices, opts) do tensor_axes = Nx.axes(tensor)