Skip to content

Commit

Permalink
Add stack as a callback (#1482)
Browse files Browse the repository at this point in the history
  • Loading branch information
josevalim authored May 13, 2024
1 parent 510e689 commit e5c82f8
Show file tree
Hide file tree
Showing 11 changed files with 160 additions and 55 deletions.
11 changes: 11 additions & 0 deletions exla/lib/exla/backend.ex
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,17 @@ defmodule EXLA.Backend do
jit([], expr_fun, tensors, [List.to_tuple(tensors)])
end

@impl true
def stack(out, tensors, axis) do
out = Nx.to_template(out)

expr_fun = fn tensors ->
Nx.Defn.Expr.stack(out, Tuple.to_list(tensors), axis)
end

jit([], expr_fun, tensors, [List.to_tuple(tensors)])
end

@impl true
def slice(out, tensor, start_indices, lengths, strides) do
out = Nx.to_template(out)
Expand Down
9 changes: 6 additions & 3 deletions exla/lib/exla/defn.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1293,10 +1293,13 @@ defmodule EXLA.Defn do
end

defp to_operator(:concatenate, [[%Value{} | _rest] = tensors, axis], ans, _state) do
tensors =
tensors
|> Enum.map(&to_type(&1, ans.type))
tensors = Enum.map(tensors, &to_type(&1, ans.type))
Value.concatenate(tensors, axis, expr_to_typespec(ans))
end

defp to_operator(:stack, [[%Value{} | _rest] = tensors, axis], ans, _state) do
reshape_typespec = Typespec.tensor(ans.type, put_elem(ans.shape, axis, 1))
tensors = Enum.map(tensors, &(&1 |> to_type(ans.type) |> Value.reshape(reshape_typespec)))
Value.concatenate(tensors, axis, expr_to_typespec(ans))
end

Expand Down
78 changes: 42 additions & 36 deletions nx/lib/nx.ex
Original file line number Diff line number Diff line change
Expand Up @@ -3251,20 +3251,9 @@ defmodule Nx do
"""
@doc type: :shape, from_backend: false
def new_axis(tensor, axis, name \\ nil) when is_integer(axis) do
apply_vectorized(tensor, fn tensor, offset ->
%{shape: shape, names: names} = tensor = to_tensor(tensor)
rank = tuple_size(shape)
norm = if axis < 0, do: axis + rank + 1, else: axis + offset

if norm not in offset..tuple_size(shape) do
raise ArgumentError,
"new axis position for shape #{inspect(shape)} must be " <>
"a number between #{-rank - 1 + offset} and #{rank - offset}, got: #{axis}"
end

new_shape = Tuple.insert_at(shape, norm, 1)
new_names = List.insert_at(names, norm, name)
impl!(tensor).reshape(%{tensor | shape: new_shape, names: new_names}, tensor)
apply_vectorized(tensor, fn %{shape: shape, names: names} = tensor, offset ->
{shape, names, _axis} = Nx.Shape.new_axis(shape, names, axis, name, 1, offset)
impl!(tensor).reshape(%{tensor | shape: shape, names: names}, tensor)
end)
end

Expand Down Expand Up @@ -14668,28 +14657,35 @@ defmodule Nx do
t

[_ | _] = tensors ->
[%T{vectorized_axes: vectorized_axes} | _] =
tensors = broadcast_vectors(tensors, align_ranks: true)
concatenate_or_stack(
tensors,
fn shapes, names, offset -> Nx.Shape.concatenate(shapes, names, axis, offset) end,
fn out, tensors, axis -> list_impl!(tensors).concatenate(out, tensors, axis) end
)
end
end

offset = length(vectorized_axes)
tensors = if vectorized_axes != [], do: Enum.map(tensors, &devectorize/1), else: tensors
defp concatenate_or_stack(tensors, shape_and_name, callback) do
[%T{vectorized_axes: vectorized_axes} | _] =
tensors = broadcast_vectors(tensors, align_ranks: true)

{types, [s1 | _] = shapes, [n1 | _] = names} =
Enum.reduce(tensors, {[], [], []}, fn
%T{type: t, shape: s, names: n}, {types, shapes, names} ->
{[t | types], [s | shapes], [n | names]}
end)
offset = length(vectorized_axes)
tensors = if vectorized_axes != [], do: Enum.map(tensors, &devectorize/1), else: tensors

{types, shapes, names} =
Enum.reduce(tensors, {[], [], []}, fn
%T{type: t, shape: s, names: n}, {types, shapes, names} ->
{[t | types], [s | shapes], [n | names]}
end)

axis = Nx.Shape.normalize_axis(s1, axis, n1, offset)
output_type = Enum.reduce(types, &Nx.Type.merge/2)
output_type = Enum.reduce(types, &Nx.Type.merge/2)

{output_shape, output_names} =
Nx.Shape.concatenate(Enum.reverse(shapes), Enum.reverse(names), axis)
{output_shape, output_names, axis} =
shape_and_name.(Enum.reverse(shapes), Enum.reverse(names), offset)

out = %{hd(tensors) | type: output_type, shape: output_shape, names: output_names}
result = list_impl!(tensors).concatenate(out, tensors, axis)
vectorize(result, vectorized_axes)
end
out = %{hd(tensors) | type: output_type, shape: output_shape, names: output_names}
result = callback.(out, tensors, axis)
vectorize(result, vectorized_axes)
end

defp flatten_list_or_container(list) when is_list(list) do
Expand Down Expand Up @@ -14807,16 +14803,26 @@ defmodule Nx do
>
"""
@doc type: :ndim, from_backend: false
@doc type: :ndim
def stack(tensors, opts \\ []) do
opts = keyword!(opts, axis: 0, name: nil)
axis = opts[:axis]
name = opts[:name]

tensors
|> flatten_list_or_container()
|> Enum.map(&Nx.new_axis(&1, axis, name))
|> Nx.concatenate(axis: axis)
case flatten_list_or_container(tensors) do
[] ->
raise ArgumentError, "no tensors were given to stack"

[t] ->
Nx.new_axis(t, axis, name)

[_ | _] = tensors ->
concatenate_or_stack(
tensors,
fn shapes, names, offset -> Nx.Shape.stack(shapes, names, axis, name, offset) end,
fn out, tensors, axis -> list_impl!(tensors).stack(out, tensors, axis) end
)
end
end

@doc """
Expand Down
1 change: 1 addition & 0 deletions nx/lib/nx/backend.ex
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ defmodule Nx.Backend do
@callback put_slice(out :: tensor, tensor, tensor, list) :: tensor
@callback gather(out :: tensor, input :: tensor, indices :: tensor, keyword) :: tensor
@callback concatenate(out :: tensor, tensor, axis) :: tensor
@callback stack(out :: tensor, tensor, axis) :: tensor
@callback select(out :: tensor, tensor, tensor, tensor) :: tensor

@callback conv(out :: tensor, tensor, kernel :: tensor, keyword) :: tensor
Expand Down
13 changes: 13 additions & 0 deletions nx/lib/nx/binary_backend.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1999,6 +1999,19 @@ defmodule Nx.BinaryBackend do
offset
end

@impl true
def stack(out, tensors, axis) do
%{shape: output_shape, type: {_, size} = output_type} = out

tensors
|> Enum.map(fn %{shape: shape} = t ->
t = as_type(%{t | type: output_type}, t)
{to_binary(t), Tuple.insert_at(shape, axis, 1)}
end)
|> bin_concatenate(size, axis, output_shape)
|> then(&from_binary(out, &1))
end

@impl true
def concatenate(out, tensors, axis) do
%{shape: output_shape, type: {_, size} = output_type} = out
Expand Down
2 changes: 1 addition & 1 deletion nx/lib/nx/defn/evaluator.ex
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ defmodule Nx.Defn.Evaluator do
alias Nx.Defn.{Composite, Expr, Tree}

@creation_ops [:eye, :iota, :from_binary]
@list_ops [:concatenate]
@list_ops [:concatenate, :stack]
@indices_ops [:slice, :put_slice]

@impl true
Expand Down
6 changes: 6 additions & 0 deletions nx/lib/nx/defn/expr.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1201,6 +1201,12 @@ defmodule Nx.Defn.Expr do
expr(out, context, :concatenate, [tensors, axis])
end

@impl true
def stack(out, tensors, axis) do
{tensors, context} = to_exprs(tensors)
expr(out, context, :stack, [tensors, axis])
end

@impl true
def triangular_solve(out, a, b, opts) do
{[a, b], context} = to_exprs([a, b])
Expand Down
15 changes: 15 additions & 0 deletions nx/lib/nx/defn/grad.ex
Original file line number Diff line number Diff line change
Expand Up @@ -614,6 +614,21 @@ defmodule Nx.Defn.Grad do
[{x, g}]
end

defp grad(:stack, [tensors, axis], ans, g) do
zero_axes = List.duplicate(0, Nx.rank(ans))
ans_shape_list = Tuple.to_list(ans.shape)

{pairs, _} =
Enum.map_reduce(tensors, 0, fn t, limit ->
current_limit = 1 + limit
start = List.replace_at(zero_axes, axis, limit)
len = List.replace_at(ans_shape_list, axis, 1)
{{t, Nx.slice(g, start, len)}, current_limit}
end)

pairs
end

defp grad(:concatenate, [tensors, axis], ans, g) do
zero_axes = List.duplicate(0, Nx.rank(ans))
ans_shape_list = Tuple.to_list(ans.shape)
Expand Down
3 changes: 2 additions & 1 deletion nx/lib/nx/defn/tree.ex
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,8 @@ defmodule Nx.Defn.Tree do
{[%{token | hooks: hooks}], acc}
end

def apply_args(%T{data: %Expr{op: :concatenate, args: [list | args]}}, _type, acc, fun) do
def apply_args(%T{data: %Expr{op: op, args: [list | args]}}, _type, acc, fun)
when op in [:concatenate, :stack] do
{list, acc} = Enum.map_reduce(list, acc, fun)
{[list | args], acc}
end
Expand Down
67 changes: 53 additions & 14 deletions nx/lib/nx/shape.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1648,17 +1648,65 @@ defmodule Nx.Shape do
{shape, names}
end

@doc """
Returns the shape and name of new axis.
"""
def new_axis(shape, names, axis, name, size, offset) do
rank = tuple_size(shape)
norm = if axis < 0, do: axis + rank + 1, else: axis + offset

if norm not in offset..tuple_size(shape) do
raise ArgumentError,
"new axis position for shape #{inspect(shape)} must be " <>
"a number between #{-rank - 1 + offset} and #{rank - offset}, got: #{axis}"
end

new_shape = Tuple.insert_at(shape, norm, size)
new_names = List.insert_at(names, norm, name)
{new_shape, new_names, norm}
end

@doc """
Returns the shape and names after a stack.
## Examples
iex> Nx.Shape.stack([{3, 2}, {3, 2}, {3, 2}], [[nil, nil], [nil, :z], [:y, nil]], 0, :x, 0)
{{3, 3, 2}, [:x, :y, :z], 0}
"""
def stack(shapes, names, axis, name, offset) do
names =
Enum.zip_with(names, fn zipped ->
Enum.reduce(zipped, &merge_names!(&1, &2, axis, axis))
end)

case Enum.uniq(shapes) do
[shape] ->
new_axis(shape, names, axis, name, length(shapes), offset)

shapes ->
raise ArgumentError,
"can only stack tensors of the same shape, got distinct shapes: #{inspect(shapes)}"
end
end

@doc """
Returns the shape and names after a concat.
## Examples
iex> Nx.Shape.concatenate([{2, 3, 2}, {1, 3, 2}, {4, 3, 2}], [[:x, :y, :z], [:x, :y, :z], [:x, :y, :z]], 0)
{{7, 3, 2}, [:x, :y, :z]}
iex> Nx.Shape.concatenate([{2, 3, 2}, {1, 3, 2}, {4, 3, 2}], [[:x, :y, :z], [:x, :y, :z], [:x, :y, :z]], 0, 0)
{{7, 3, 2}, [:x, :y, :z], 0}
"""
def concatenate(shapes, names, axis) do
names = validate_concat_names!(names, axis)
{concat_dims(shapes, axis), names}
def concatenate([s1 | _] = shapes, [n1 | _] = names, axis, offset) do
axis = normalize_axis(s1, axis, n1, offset)

names =
Enum.zip_with(names, fn zipped ->
Enum.reduce(zipped, &merge_names!(&1, &2, axis, axis))
end)

{concat_dims(shapes, axis), names, axis}
end

defp concat_dims([s1 | shapes] = all_shapes, axis) do
Expand Down Expand Up @@ -2120,15 +2168,6 @@ defmodule Nx.Shape do
)
end

defp validate_concat_names!(names, axis) do
_ =
Enum.zip_with(names, fn zipped ->
Enum.reduce(zipped, &merge_names!(&1, &2, axis, axis))
end)

hd(names)
end

def fft({}) do
raise ArgumentError, "expected a tensor with rank > 0, got tensor with rank 0"
end
Expand Down
10 changes: 10 additions & 0 deletions torchx/lib/torchx/backend.ex
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,16 @@ defmodule Torchx.Backend do
|> to_nx(out)
end

@impl true
def stack(out, tensors, axis) do
reshape = put_elem(out.shape, axis, 1)

tensors
|> Enum.map(&(&1 |> from_nx() |> Torchx.reshape(reshape)))
|> Torchx.concatenate(axis)
|> to_nx(out)
end

@impl true
def gather(out, tensor, indices, opts) do
tensor_axes = Nx.axes(tensor)
Expand Down

0 comments on commit e5c82f8

Please sign in to comment.