Remove the Nx.take backend callback (#1439)
Benjamin-Philip authored May 12, 2024
1 parent ad45733 commit 24ed6fa
Showing 7 changed files with 10 additions and 129 deletions.
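In short: the take callback is removed from the Nx.Backend behaviour and from Nx.Defn.Expr, and Nx.take/3 is instead implemented once in Nx itself on top of gather/3 (plus a transpose and a reshape), so EXLA, the binary backend, and Torchx each drop their bespoke implementations.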
1 change: 0 additions & 1 deletion exla/lib/exla/backend.ex
@@ -325,7 +325,6 @@ defmodule EXLA.Backend do
{:reverse, [:tensor, :axes], [:tensor]},
{:dot, [:left, :c1, :b1, :right, :c2, :b2], [:left, :right]},
{:clip, [:tensor, :min, :max], [:tensor, :min, :max]},
{:take, [:tensor, :indices, :axis], [:tensor, :indices]},
{:take_along_axis, [:tensor, :indices, :axis], [:tensor, :indices]},
{:gather, [:input, :indices, :opts], [:input, :indices]},
{:select, [:pred, :on_true, :on_false], [:pred, :on_true, :on_false]},
28 changes: 0 additions & 28 deletions exla/lib/exla/defn.ex
@@ -1220,29 +1220,6 @@ defmodule EXLA.Defn do
Value.dynamic_update_slice(tensor, slice, start_indices, expr_to_typespec(ans))
end

defp to_operator(:take, [%Value{} = tensor, indices, axis], ans, _state) do
tensor_rank = tensor |> op_shape() |> tuple_size()
indices_rank = indices |> op_shape() |> tuple_size()
result_rank = tensor_rank - 1 + indices_rank

index_vector_dim = indices_rank
slice_sizes = tensor |> op_shape() |> put_elem(axis, 1) |> Tuple.to_list()
offset_dims = result_rank |> axes_for_rank() |> delete_slice(axis, indices_rank)
collapsed_slice_dims = [axis]
start_index_map = [axis]

Value.gather(
tensor,
indices,
index_vector_dim,
slice_sizes,
offset_dims,
collapsed_slice_dims,
start_index_map,
expr_to_typespec(ans)
)
end

defp to_operator(:take_along_axis, [%Value{} = tensor, indices, axis], ans, state) do
%{shape: indices_shape} = indices_typespec = Value.get_typespec(indices)
indices_rank = tuple_size(indices_shape)
@@ -1962,11 +1939,6 @@

# Helpers

defp delete_slice(enumerable, index, length) do
{left, right} = Enum.split(enumerable, index)
left ++ Enum.drop(right, length)
end

defp apply_mlir_broadcasted_bin_op(op, out, left, right) do
left_typespec = Value.get_typespec(left)
right_typespec = Value.get_typespec(right)
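The removed :take clause lowered take onto a single XLA gather by deriving the gather parameters from the take axis. A minimal sketch of that arithmetic, re-derived outside EXLA (the TakeLowering module name is hypothetical; the computation mirrors the deleted lines, including the now-unused delete_slice/3 helper):

defmodule TakeLowering do
  # Derive XLA gather parameters for take(tensor, indices, axis).
  def gather_params(tensor_shape, indices_shape, axis) do
    tensor_rank = tuple_size(tensor_shape)
    indices_rank = tuple_size(indices_shape)
    result_rank = tensor_rank - 1 + indices_rank

    %{
      index_vector_dim: indices_rank,
      slice_sizes: tensor_shape |> put_elem(axis, 1) |> Tuple.to_list(),
      offset_dims: delete_slice(Enum.to_list(0..(result_rank - 1)), axis, indices_rank),
      collapsed_slice_dims: [axis],
      start_index_map: [axis]
    }
  end

  # Drop `length` consecutive elements starting at `index`.
  defp delete_slice(enumerable, index, length) do
    {left, right} = Enum.split(enumerable, index)
    left ++ Enum.drop(right, length)
  end
end

TakeLowering.gather_params({4, 3}, {2}, 0)
# => %{collapsed_slice_dims: [0], index_vector_dim: 1, offset_dims: [1],
#      slice_sizes: [1, 3], start_index_map: [0]}

In other words, each index selects a unit slice along the take axis, that axis collapses away, and the remaining tensor axes survive as offset dimensions of the gather result.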
16 changes: 10 additions & 6 deletions nx/lib/nx.ex
@@ -14130,13 +14130,17 @@ defmodule Nx do
else
tensor = devectorize(tensor, keep_names: false)
indices = devectorize(indices, keep_names: false)
+ gather_indices = new_axis(indices, rank(indices))

- impl!(tensor).take(
-   %{tensor | shape: inner_shape, names: inner_names},
-   tensor,
-   indices,
-   axis
- )
+ {indices_axes, tensor_axes} = Enum.split(axes(inner_shape), rank(indices))
+ {leading, trailing} = Enum.split(tensor_axes, axis)
+
+ transpose_axes = leading ++ indices_axes ++ trailing
+
+ tensor
+ |> gather(gather_indices, axes: [axis])
+ |> transpose(axes: transpose_axes)
+ |> reshape(inner_shape, names: inner_names)
end
end

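With the callback gone, take is expressed for every backend as a gather followed by a transpose and a reshape. A minimal sketch of the equivalence through the public API (example values assumed for illustration):

t = Nx.tensor([[1, 2], [3, 4], [5, 6]])
idx = Nx.tensor([2, 0])

Nx.take(t, idx, axis: 0)
# => [[5, 6], [1, 2]]

# The gather-based path: a trailing axis turns each index into a
# one-element coordinate vector, and gather then pulls full slices
# along the take axis. For axis 0 the transpose is the identity,
# so this already matches take:
gather_idx = Nx.new_axis(idx, Nx.rank(idx))
Nx.gather(t, gather_idx, axes: [0])
# => [[5, 6], [1, 2]]

For a non-zero axis, the gathered axes land in front, which is why the new code transposes them back into position before reshaping to the expected output shape.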
1 change: 0 additions & 1 deletion nx/lib/nx/backend.ex
@@ -73,7 +73,6 @@ defmodule Nx.Backend do
@callback clip(out :: tensor, tensor, min :: tensor, max :: tensor) :: tensor
@callback slice(out :: tensor, tensor, list, list, list) :: tensor
@callback put_slice(out :: tensor, tensor, tensor, list) :: tensor
@callback take(out :: tensor, input :: tensor, indices :: tensor, axis) :: tensor
@callback take_along_axis(out :: tensor, input :: tensor, indices :: tensor, axis) :: tensor
@callback gather(out :: tensor, input :: tensor, indices :: tensor, keyword) :: tensor
@callback concatenate(out :: tensor, tensor, axis) :: tensor
41 changes: 0 additions & 41 deletions nx/lib/nx/binary_backend.ex
@@ -1939,47 +1939,6 @@ defmodule Nx.BinaryBackend do
from_binary(out, data)
end

@impl true
def take(out, tensor, indices, axis) do
# We iterate over the indices in a flat manner
# and, for each index, take a unit-length slice
# of the tensor along the given axis. We then
# concatenate the slices along that axis, which
# yields the result with the index dimensions
# flattened; the output shape restores them.

%T{type: {_, size}, shape: shape} = tensor
%T{type: {_, idx_size}} = indices

data = to_binary(tensor)
tensor_rank = tuple_size(shape)
slice_start = List.duplicate(0, tensor_rank)
slice_lengths = shape |> Tuple.to_list() |> List.replace_at(axis, 1)
slice_shape = List.to_tuple(slice_lengths)
strides = List.duplicate(1, tensor_rank)

slices =
for <<bin::size(idx_size)-bitstring <- to_binary(indices)>> do
idx = binary_to_number(bin, indices.type)

if idx < 0 or idx >= elem(shape, axis) do
raise ArgumentError,
"index #{idx} is out of bounds for axis #{axis} in shape #{inspect(shape)}"
end

slice_start = List.replace_at(slice_start, axis, idx)

slice_data =
bin_slice(data, shape, size, slice_start, slice_lengths, strides, slice_shape)

{slice_data, slice_shape}
end

concat_shape = put_elem(tensor.shape, axis, length(slices))
result_data = bin_concatenate(slices, size, axis, concat_shape)

from_binary(out, result_data)
end

@impl true
def take_along_axis(
%T{type: output_type} = output,
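The removed implementation is the slice-and-concatenate strategy its comment describes. A rough equivalent in terms of the public Nx API (a sketch for intuition only; the deleted code operated directly on binaries and also validated each index against the axis size):

t = Nx.tensor([[1, 2], [3, 4], [5, 6]])
axis = 0

[2, 0]
|> Enum.map(fn i -> Nx.slice_along_axis(t, i, 1, axis: axis) end)
|> Nx.concatenate(axis: axis)
# => [[5, 6], [1, 2]], same as Nx.take(t, Nx.tensor([2, 0]), axis: 0)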
6 changes: 0 additions & 6 deletions nx/lib/nx/defn/expr.ex
@@ -1183,12 +1183,6 @@ defmodule Nx.Defn.Expr do
expr(out, context, :put_slice, [tensor, start, slice])
end

@impl true
def take(out, tensor, indices, axis) do
{[tensor, indices], context} = to_exprs([tensor, indices])
expr(out, context, :take, [tensor, indices, axis])
end

@impl true
def take_along_axis(out, tensor, indices, axis) do
{[tensor, indices], context} = to_exprs([tensor, indices])
46 changes: 0 additions & 46 deletions torchx/lib/torchx/backend.ex
@@ -337,52 +337,6 @@ defmodule Torchx.Backend do
|> to_nx(out)
end

@impl true
def take(out, t, i, axis) do
axes = Nx.axes(t)

indices_shape =
axes
|> Enum.map(fn
^axis -> Tuple.product(i.shape)
_ -> 1
end)
|> List.to_tuple()

idx_tiling =
t.shape
|> Tuple.to_list()
|> Enum.with_index(fn
_x, ^axis -> 1
x, _ -> x
end)

indices_for_axis =
i
|> Nx.reshape(indices_shape)
|> Nx.tile(idx_tiling)

num_elements = Tuple.product(indices_for_axis.shape)

indices =
axes
|> Enum.map(fn
^axis ->
Nx.reshape(indices_for_axis, {num_elements, 1})

current ->
# current when current < axis ->
indices_for_axis
|> Nx.shape()
|> Nx.iota(axis: current, backend: __MODULE__)
|> Nx.reshape({num_elements, 1})
end)
|> Nx.concatenate(axis: 1)

# TODO: maybe rewrite this, as gather now behaves differently
gather(out, t, indices, [])
end

@impl true
def gather(out, tensor, indices, opts) do
tensor_axes = Nx.axes(tensor)
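The removed Torchx version already routed take through gather by materializing one full coordinate per output element. A hypothetical walk-through of that construction with concrete shapes (values chosen for illustration):

t = Nx.tensor([[1, 2, 3], [4, 5, 6]])
i = Nx.tensor([2, 0])

# Broadcast the indices across the non-take axis (axis 1 is the take axis)...
indices_for_axis = i |> Nx.reshape({1, 2}) |> Nx.tile([2, 1])

# ...then pair each index with its iota coordinate on axis 0:
rows = Nx.iota({2, 2}, axis: 0) |> Nx.reshape({4, 1})
cols = Nx.reshape(indices_for_axis, {4, 1})
indices = Nx.concatenate([rows, cols], axis: 1)
# indices: [[0, 2], [0, 0], [1, 2], [1, 0]]

Nx.gather(t, indices) |> Nx.reshape({2, 2})
# => [[3, 1], [6, 4]], matching Nx.take(t, i, axis: 1)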
