Skip to content

Commit

Permalink
Grad for gather/indexed ops with axes (#1360)
Browse files Browse the repository at this point in the history
  • Loading branch information
polvalente authored Nov 11, 2023
1 parent 0e92f61 commit bb76359
Show file tree
Hide file tree
Showing 7 changed files with 263 additions and 262 deletions.
13 changes: 6 additions & 7 deletions exla/lib/exla/defn.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1517,22 +1517,21 @@ defmodule EXLA.Defn do
Value.gather(tensor, indices, slice_sizes, offset_dims, axes, axes, index_vector_dim)
end

defp to_operator(:gather, [tensor, indices, opts], ans, _state) do
defp to_operator(:gather, [tensor, indices, opts], _ans, _state) do
axes = Keyword.fetch!(opts, :axes)
tensor_shape = op_shape(tensor)
tensor_rank = tuple_size(tensor_shape)
tensor_axes = axes_for_rank(tensor_rank)
index_vector_dim = tuple_size(op_shape(indices)) - 1

slice_sizes =
for i <- 0..(tensor_rank - 1) do
for i <- tensor_axes do
if i in axes, do: 1, else: elem(tensor_shape, i)
end

offset_dims = axes_for_rank(tensor_rank) -- axes

tensor
|> EXLA.Op.gather(indices, index_vector_dim, slice_sizes, offset_dims, axes, axes)
|> EXLA.Op.reshape(ans.shape)
batch_size = tensor_rank - length(axes)
offset_dims = count_up(batch_size, batch_size)
EXLA.Op.gather(tensor, indices, index_vector_dim, slice_sizes, offset_dims, axes, axes)
end

defp to_operator(:reverse, [%Value{} = tensor, axes], _ans, _state) do
Expand Down
92 changes: 36 additions & 56 deletions nx/lib/nx.ex
Original file line number Diff line number Diff line change
Expand Up @@ -12663,48 +12663,6 @@ defmodule Nx do
]
>
iex> Nx.transpose(Nx.iota({2, 3, 4}, names: [:batch, :x, :y]), axes: [:y, :batch, :x])
#Nx.Tensor<
s64[y: 4][batch: 2][x: 3]
[
[
[0, 4, 8],
[12, 16, 20]
],
[
[1, 5, 9],
[13, 17, 21]
],
[
[2, 6, 10],
[14, 18, 22]
],
[
[3, 7, 11],
[15, 19, 23]
]
]
>
iex> Nx.transpose(Nx.iota({2, 3, 4}, names: [:batch, :x, :y]), axes: [:batch, :y, :x])
#Nx.Tensor<
s64[batch: 2][y: 4][x: 3]
[
[
[0, 4, 8],
[1, 5, 9],
[2, 6, 10],
[3, 7, 11]
],
[
[12, 16, 20],
[13, 17, 21],
[14, 18, 22],
[15, 19, 23]
]
]
>
### Vectorized tensors
For vectorized tensors, transpose will manipulate the inner shape only,
Expand Down Expand Up @@ -14331,31 +14289,53 @@ defmodule Nx do
### Gathering subsets
iex> t = Nx.tensor([[1, 2, 3], [3, 4, 5]])
iex> Nx.gather(t, Nx.tensor([[1], [0], [1]]))
iex> Nx.gather(t, Nx.tensor([[1], [0]]))
#Nx.Tensor<
s64[3][3]
s64[2][3]
[
[3, 4, 5],
[1, 2, 3],
[3, 4, 5]
[1, 2, 3]
]
>
The `:axes` option controls which dimensions the indexes point to,
this can be useful, for example, to access columns instead of rows:
this can be useful, for example, to access columns instead of rows.
Note can also access the same index several times:
iex> t = Nx.tensor([[[1, 2, 3]], [[4, 5, 6]]])
iex> Nx.gather(t, Nx.tensor([[1], [0]]), axes: [2])
iex> t = Nx.tensor([[1, 2, 3], [4, 5, 6]])
iex> Nx.gather(t, Nx.tensor([[1], [0], [2], [1]]), axes: [1])
#Nx.Tensor<
s64[2][2][1]
s64[4][2]
[
[2, 5],
[1, 4],
[3, 6],
[2, 5]
]
>
The overall output shape will have the format of the indices shape
(except the last element) followed by all non-indexed dimensions of
the tensor. Here is a more complex example:
iex> t = Nx.iota({2, 1, 3})
iex> Nx.gather(t, Nx.tensor([[[1], [0], [2]]]), axes: [2])
#Nx.Tensor<
s64[1][3][2][1]
[
[
[2],
[1]
],
[
[5],
[4]
[
[1],
[4]
],
[
[0],
[3]
],
[
[2],
[5]
]
]
]
>
Expand Down
Loading

0 comments on commit bb76359

Please sign in to comment.