Skip to content

Commit

Permalink
Remove EXLA implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
Benjamin-Philip committed Feb 10, 2024
1 parent 9dc356b commit 7c3ec92
Show file tree
Hide file tree
Showing 2 changed files with 0 additions and 39 deletions.
1 change: 0 additions & 1 deletion exla/lib/exla/backend.ex
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,6 @@ defmodule EXLA.Backend do
{:dot, [:left, :c1, :b1, :right, :c2, :b2], [:left, :right]},
{:clip, [:tensor, :min, :max], [:tensor, :min, :max]},
{:take, [:tensor, :indices, :axis], [:tensor, :indices]},
{:take_along_axis, [:tensor, :indices, :axis], [:tensor, :indices]},
{:gather, [:input, :indices, :opts], [:input, :indices]},
{:select, [:pred, :on_true, :on_false], [:pred, :on_true, :on_false]},
{:conv, [:tensor, :kernel, :opts], [:tensor, :kernel]},
Expand Down
38 changes: 0 additions & 38 deletions exla/lib/exla/defn.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1608,44 +1608,6 @@ defmodule EXLA.Defn do
)
end

defp to_operator(:take_along_axis, [%mod{} = tensor, indices, axis], _ans, state) do
indices_shape = op_shape(indices)
indices_rank = tuple_size(indices_shape)

axes_range = 0..(indices_rank - 1)//1

index_vector_dim = indices_rank
slice_sizes = List.duplicate(1, indices_rank)
offset_dims = []
collapsed_slice_dims = Enum.to_list(axes_range)
start_index_map = Enum.to_list(axes_range)

indices_exla_shape = mod.get_shape(indices)

iotas =
Enum.map(axes_range, fn axis ->
mod.iota(state.builder, indices_exla_shape, axis)
end)

new_axis_shape = Tuple.append(indices_shape, 1)

indices =
iotas
|> List.replace_at(axis, indices)
|> Enum.map(&mod.reshape(&1, new_axis_shape))
|> mod.concatenate(indices_rank)

mod.gather(
tensor,
indices,
index_vector_dim,
slice_sizes,
offset_dims,
collapsed_slice_dims,
start_index_map
)
end

defp to_operator(:gather, [%mod{} = tensor, indices, opts], _ans, _state) do
axes = Keyword.fetch!(opts, :axes)
tensor_shape = op_shape(tensor)
Expand Down

0 comments on commit 7c3ec92

Please sign in to comment.