Add stack as a callback #1482

Merged · 2 commits · May 13, 2024
11 changes: 11 additions & 0 deletions exla/lib/exla/backend.ex
@@ -243,6 +243,17 @@ defmodule EXLA.Backend do
    jit([], expr_fun, tensors, [List.to_tuple(tensors)])
  end

  @impl true
  def stack(out, tensors, axis) do
    out = Nx.to_template(out)

    expr_fun = fn tensors ->
      Nx.Defn.Expr.stack(out, Tuple.to_list(tensors), axis)
    end

    jit([], expr_fun, tensors, [List.to_tuple(tensors)])
  end

  @impl true
  def slice(out, tensor, start_indices, lengths, strides) do
    out = Nx.to_template(out)
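With this callback in place, EXLA builds the whole stack as a single jitted defn expression instead of falling back to Nx's generic per-tensor path. A minimal usage sketch (assumes the exla dependency is installed; values are illustrative):

    Nx.default_backend(EXLA.Backend)

    a = Nx.tensor([1, 2, 3])
    b = Nx.tensor([4, 5, 6])

    # Dispatches to EXLA.Backend.stack/3, which jits one :stack expression.
    Nx.stack([a, b]) |> Nx.shape()
    #=> {2, 3}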
9 changes: 6 additions & 3 deletions exla/lib/exla/defn.ex
@@ -1330,10 +1330,13 @@ defmodule EXLA.Defn do
  end

  defp to_operator(:concatenate, [[%Value{} | _rest] = tensors, axis], ans, _state) do
-    tensors =
-      tensors
-      |> Enum.map(&to_type(&1, ans.type))
+    tensors = Enum.map(tensors, &to_type(&1, ans.type))
    Value.concatenate(tensors, axis, expr_to_typespec(ans))
  end

  defp to_operator(:stack, [[%Value{} | _rest] = tensors, axis], ans, _state) do
    reshape_typespec = Typespec.tensor(ans.type, put_elem(ans.shape, axis, 1))
    tensors = Enum.map(tensors, &(&1 |> to_type(ans.type) |> Value.reshape(reshape_typespec)))
    Value.concatenate(tensors, axis, expr_to_typespec(ans))
  end

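There is no dedicated stack instruction in this lowering: each input is reshaped so the stacking axis becomes an explicit size-1 dimension, then everything is concatenated. A pure-Nx sketch of the same rewrite (shapes are illustrative):

    a = Nx.iota({3, 2})
    b = Nx.iota({3, 2})

    # Stacking along axis 0 yields output shape {2, 3, 2}, so each input is
    # reshaped to the output shape with the stack axis set to 1, i.e. {1, 3, 2}...
    reshaped = Enum.map([a, b], &Nx.reshape(&1, {1, 3, 2}))

    # ...and the reshaped tensors are concatenated along that axis.
    Nx.concatenate(reshaped, axis: 0) |> Nx.shape()
    #=> {2, 3, 2}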
78 changes: 42 additions & 36 deletions nx/lib/nx.ex
@@ -3251,20 +3251,9 @@ defmodule Nx do
"""
@doc type: :shape, from_backend: false
def new_axis(tensor, axis, name \\ nil) when is_integer(axis) do
apply_vectorized(tensor, fn tensor, offset ->
%{shape: shape, names: names} = tensor = to_tensor(tensor)
rank = tuple_size(shape)
norm = if axis < 0, do: axis + rank + 1, else: axis + offset

if norm not in offset..tuple_size(shape) do
raise ArgumentError,
"new axis position for shape #{inspect(shape)} must be " <>
"a number between #{-rank - 1 + offset} and #{rank - offset}, got: #{axis}"
end

new_shape = Tuple.insert_at(shape, norm, 1)
new_names = List.insert_at(names, norm, name)
impl!(tensor).reshape(%{tensor | shape: new_shape, names: new_names}, tensor)
apply_vectorized(tensor, fn %{shape: shape, names: names} = tensor, offset ->
{shape, names, _axis} = Nx.Shape.new_axis(shape, names, axis, name, 1, offset)
impl!(tensor).reshape(%{tensor | shape: shape, names: names}, tensor)
end)
end

@@ -14653,28 +14642,35 @@ defmodule Nx do
        t

      [_ | _] = tensors ->
-        [%T{vectorized_axes: vectorized_axes} | _] =
-          tensors = broadcast_vectors(tensors, align_ranks: true)
-
-        offset = length(vectorized_axes)
-        tensors = if vectorized_axes != [], do: Enum.map(tensors, &devectorize/1), else: tensors
-
-        {types, [s1 | _] = shapes, [n1 | _] = names} =
-          Enum.reduce(tensors, {[], [], []}, fn
-            %T{type: t, shape: s, names: n}, {types, shapes, names} ->
-              {[t | types], [s | shapes], [n | names]}
-          end)
-
-        axis = Nx.Shape.normalize_axis(s1, axis, n1, offset)
-        output_type = Enum.reduce(types, &Nx.Type.merge/2)
-
-        {output_shape, output_names} =
-          Nx.Shape.concatenate(Enum.reverse(shapes), Enum.reverse(names), axis)
-
-        out = %{hd(tensors) | type: output_type, shape: output_shape, names: output_names}
-        result = list_impl!(tensors).concatenate(out, tensors, axis)
-        vectorize(result, vectorized_axes)
+        concatenate_or_stack(
+          tensors,
+          fn shapes, names, offset -> Nx.Shape.concatenate(shapes, names, axis, offset) end,
+          fn out, tensors, axis -> list_impl!(tensors).concatenate(out, tensors, axis) end
+        )
    end
  end

+  defp concatenate_or_stack(tensors, shape_and_name, callback) do
+    [%T{vectorized_axes: vectorized_axes} | _] =
+      tensors = broadcast_vectors(tensors, align_ranks: true)
+
+    offset = length(vectorized_axes)
+    tensors = if vectorized_axes != [], do: Enum.map(tensors, &devectorize/1), else: tensors
+
+    {types, shapes, names} =
+      Enum.reduce(tensors, {[], [], []}, fn
+        %T{type: t, shape: s, names: n}, {types, shapes, names} ->
+          {[t | types], [s | shapes], [n | names]}
+      end)
+
+    output_type = Enum.reduce(types, &Nx.Type.merge/2)
+
+    {output_shape, output_names, axis} =
+      shape_and_name.(Enum.reverse(shapes), Enum.reverse(names), offset)
+
+    out = %{hd(tensors) | type: output_type, shape: output_shape, names: output_names}
+    result = callback.(out, tensors, axis)
+    vectorize(result, vectorized_axes)
+  end

  defp flatten_list_or_container(list) when is_list(list) do
@@ -14792,16 +14788,26 @@ defmodule Nx do
>

"""
-  @doc type: :ndim, from_backend: false
+  @doc type: :ndim
  def stack(tensors, opts \\ []) do
    opts = keyword!(opts, axis: 0, name: nil)
    axis = opts[:axis]
    name = opts[:name]

-    tensors
-    |> flatten_list_or_container()
-    |> Enum.map(&Nx.new_axis(&1, axis, name))
-    |> Nx.concatenate(axis: axis)
+    case flatten_list_or_container(tensors) do
+      [] ->
+        raise ArgumentError, "no tensors were given to stack"
+
+      [t] ->
+        Nx.new_axis(t, axis, name)
+
+      [_ | _] = tensors ->
+        concatenate_or_stack(
+          tensors,
+          fn shapes, names, offset -> Nx.Shape.stack(shapes, names, axis, name, offset) end,
+          fn out, tensors, axis -> list_impl!(tensors).stack(out, tensors, axis) end
+        )
+    end
  end

@doc """
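The public API is unchanged except for the new error on empty input and a fast path for a single tensor. A usage sketch of stack/2 after this change:

    iex> Nx.stack([Nx.tensor([1, 2]), Nx.tensor([3, 4])], name: :batch) |> Nx.shape()
    {2, 2}

    iex> Nx.stack([Nx.tensor([1, 2])]) |> Nx.shape()  # single tensor: just new_axis
    {1, 2}

    iex> Nx.stack([])
    ** (ArgumentError) no tensors were given to stack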
1 change: 1 addition & 0 deletions nx/lib/nx/backend.ex
@@ -76,6 +76,7 @@ defmodule Nx.Backend do
  @callback take_along_axis(out :: tensor, input :: tensor, indices :: tensor, axis) :: tensor
  @callback gather(out :: tensor, input :: tensor, indices :: tensor, keyword) :: tensor
  @callback concatenate(out :: tensor, tensor, axis) :: tensor
  @callback stack(out :: tensor, tensor, axis) :: tensor
  @callback select(out :: tensor, tensor, tensor, tensor) :: tensor

  @callback conv(out :: tensor, tensor, kernel :: tensor, keyword) :: tensor
13 changes: 13 additions & 0 deletions nx/lib/nx/binary_backend.ex
@@ -2036,6 +2036,19 @@ defmodule Nx.BinaryBackend do
    offset
  end

  @impl true
  def stack(out, tensors, axis) do
    %{shape: output_shape, type: {_, size} = output_type} = out

    tensors
    |> Enum.map(fn %{shape: shape} = t ->
      t = as_type(%{t | type: output_type}, t)
      {to_binary(t), Tuple.insert_at(shape, axis, 1)}
    end)
    |> bin_concatenate(size, axis, output_shape)
    |> then(&from_binary(out, &1))
  end

  @impl true
  def concatenate(out, tensors, axis) do
    %{shape: output_shape, type: {_, size} = output_type} = out
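Rather than materializing reshaped tensors, the binary backend only adjusts the shape metadata handed to bin_concatenate: each input's shape gets a size-1 dimension inserted at the stack axis. A worked example of that bookkeeping (BinaryBackend is Nx's default backend):

    a = Nx.tensor([[1, 2], [3, 4]])   # shape {2, 2}
    b = Nx.tensor([[5, 6], [7, 8]])   # shape {2, 2}

    # Internally each binary is treated as shape {2, 1, 2} and the pair is
    # concatenated along axis 1 into the {2, 2, 2} output.
    Nx.stack([a, b], axis: 1) |> Nx.shape()
    #=> {2, 2, 2}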
2 changes: 1 addition & 1 deletion nx/lib/nx/defn/evaluator.ex
@@ -20,7 +20,7 @@ defmodule Nx.Defn.Evaluator do
  alias Nx.Defn.{Composite, Expr, Tree}

  @creation_ops [:eye, :iota, :from_binary]
-  @list_ops [:concatenate]
+  @list_ops [:concatenate, :stack]
  @indices_ops [:slice, :put_slice]

  @impl true
6 changes: 6 additions & 0 deletions nx/lib/nx/defn/expr.ex
@@ -1207,6 +1207,12 @@ defmodule Nx.Defn.Expr do
    expr(out, context, :concatenate, [tensors, axis])
  end

  @impl true
  def stack(out, tensors, axis) do
    {tensors, context} = to_exprs(tensors)
    expr(out, context, :stack, [tensors, axis])
  end

  @impl true
  def triangular_solve(out, a, b, opts) do
    {[a, b], context} = to_exprs([a, b])
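Inside defn, stacking now produces a single :stack node whose first argument is the list of inputs, mirroring :concatenate. A sketch of how to observe this (assumes an Nx version that ships Nx.Defn.debug_expr/1):

    fun = Nx.Defn.debug_expr(fn a, b -> Nx.stack([a, b]) end)

    # The printed tree contains one :stack node instead of per-tensor
    # new_axis nodes followed by a :concatenate.
    fun.(Nx.template({3}, :f32), Nx.template({3}, :f32))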
15 changes: 15 additions & 0 deletions nx/lib/nx/defn/grad.ex
@@ -617,6 +617,21 @@ defmodule Nx.Defn.Grad do
    [{x, g}]
  end

  defp grad(:stack, [tensors, axis], ans, g) do
    zero_axes = List.duplicate(0, Nx.rank(ans))
    ans_shape_list = Tuple.to_list(ans.shape)

    {pairs, _} =
      Enum.map_reduce(tensors, 0, fn t, limit ->
        current_limit = 1 + limit
        start = List.replace_at(zero_axes, axis, limit)
        len = List.replace_at(ans_shape_list, axis, 1)
        {{t, Nx.slice(g, start, len)}, current_limit}
      end)

    pairs
  end

  defp grad(:concatenate, [tensors, axis], ans, g) do
    zero_axes = List.duplicate(0, Nx.rank(ans))
    ans_shape_list = Tuple.to_list(ans.shape)
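The rule above mirrors the :concatenate gradient: the upstream gradient g is cut into size-1 slices along the stack axis, one per input. A quick sanity check, sketched with Nx.Defn's grad:

    import Nx.Defn

    # sum(stack([a, b])) touches every element of a and b exactly once,
    # so both gradients are all-ones tensors with the inputs' shapes.
    defn stack_grad(a, b) do
      grad({a, b}, fn {x, y} -> Nx.sum(Nx.stack([x, y])) end)
    end

    stack_grad(Nx.tensor([1.0, 2.0]), Nx.tensor([3.0, 4.0]))
    #=> {Nx.tensor([1.0, 1.0]), Nx.tensor([1.0, 1.0])} (values, inspect output abbreviated)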
3 changes: 2 additions & 1 deletion nx/lib/nx/defn/tree.ex
@@ -202,7 +202,8 @@ defmodule Nx.Defn.Tree do
    {[%{token | hooks: hooks}], acc}
  end

-  def apply_args(%T{data: %Expr{op: :concatenate, args: [list | args]}}, _type, acc, fun) do
+  def apply_args(%T{data: %Expr{op: op, args: [list | args]}}, _type, acc, fun)
+      when op in [:concatenate, :stack] do
    {list, acc} = Enum.map_reduce(list, acc, fun)
    {[list | args], acc}
  end
67 changes: 53 additions & 14 deletions nx/lib/nx/shape.ex
@@ -1648,17 +1648,65 @@ defmodule Nx.Shape do
    {shape, names}
  end

  @doc """
  Returns the shape and name of a new axis.
  """
  def new_axis(shape, names, axis, name, size, offset) do
    rank = tuple_size(shape)
    norm = if axis < 0, do: axis + rank + 1, else: axis + offset

    if norm not in offset..tuple_size(shape) do
      raise ArgumentError,
            "new axis position for shape #{inspect(shape)} must be " <>
              "a number between #{-rank - 1 + offset} and #{rank - offset}, got: #{axis}"
    end

    new_shape = Tuple.insert_at(shape, norm, size)
    new_names = List.insert_at(names, norm, name)
    {new_shape, new_names, norm}
  end

  @doc """
  Returns the shape and names after a stack.

  ## Examples

      iex> Nx.Shape.stack([{3, 2}, {3, 2}, {3, 2}], [[nil, nil], [nil, :z], [:y, nil]], 0, :x, 0)
      {{3, 3, 2}, [:x, :y, :z], 0}
  """
  def stack(shapes, names, axis, name, offset) do
    names =
      Enum.zip_with(names, fn zipped ->
        Enum.reduce(zipped, &merge_names!(&1, &2, axis, axis))
      end)

    case Enum.uniq(shapes) do
      [shape] ->
        new_axis(shape, names, axis, name, length(shapes), offset)

      shapes ->
        raise ArgumentError,
              "can only stack tensors of the same shape, got distinct shapes: #{inspect(shapes)}"
    end
  end

  @doc """
  Returns the shape and names after a concat.

  ## Examples

-      iex> Nx.Shape.concatenate([{2, 3, 2}, {1, 3, 2}, {4, 3, 2}], [[:x, :y, :z], [:x, :y, :z], [:x, :y, :z]], 0)
-      {{7, 3, 2}, [:x, :y, :z]}
+      iex> Nx.Shape.concatenate([{2, 3, 2}, {1, 3, 2}, {4, 3, 2}], [[:x, :y, :z], [:x, :y, :z], [:x, :y, :z]], 0, 0)
+      {{7, 3, 2}, [:x, :y, :z], 0}
  """
-  def concatenate(shapes, names, axis) do
-    names = validate_concat_names!(names, axis)
-    {concat_dims(shapes, axis), names}
+  def concatenate([s1 | _] = shapes, [n1 | _] = names, axis, offset) do
+    axis = normalize_axis(s1, axis, n1, offset)
+
+    names =
+      Enum.zip_with(names, fn zipped ->
+        Enum.reduce(zipped, &merge_names!(&1, &2, axis, axis))
+      end)
+
+    {concat_dims(shapes, axis), names, axis}
  end

  defp concat_dims([s1 | shapes] = all_shapes, axis) do
@@ -2120,15 +2168,6 @@ defmodule Nx.Shape do
)
end

-  defp validate_concat_names!(names, axis) do
-    _ =
-      Enum.zip_with(names, fn zipped ->
-        Enum.reduce(zipped, &merge_names!(&1, &2, axis, axis))
-      end)
-
-    hd(names)
-  end

  def fft({}) do
    raise ArgumentError, "expected a tensor with rank > 0, got tensor with rank 0"
  end
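Unlike concatenate, stack requires every input shape to be identical, which the Enum.uniq/1 check above enforces before any axis bookkeeping. For example:

    Nx.stack([Nx.iota({2}), Nx.iota({3})])
    #=> ** (ArgumentError) can only stack tensors of the same shape,
    #   got distinct shapes: [{2}, {3}]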
10 changes: 10 additions & 0 deletions torchx/lib/torchx/backend.ex
@@ -337,6 +337,16 @@ defmodule Torchx.Backend do
    |> to_nx(out)
  end

  @impl true
  def stack(out, tensors, axis) do
    reshape = put_elem(out.shape, axis, 1)

    tensors
    |> Enum.map(&(&1 |> from_nx() |> Torchx.reshape(reshape)))
    |> Torchx.concatenate(axis)
    |> to_nx(out)
  end

  @impl true
  def gather(out, tensor, indices, opts) do
    tensor_axes = Nx.axes(tensor)
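This is the same reshape-then-concatenate strategy as the EXLA lowering, expressed through Torchx NIF calls. Usage is identical to any other backend (a sketch, assuming the torchx dependency is installed):

    Nx.default_backend(Torchx.Backend)

    # Stacking two {2} tensors along axis 1 reshapes each to {2, 1}
    # and concatenates into {2, 2} on the libtorch side.
    Nx.stack([Nx.tensor([1, 2]), Nx.tensor([3, 4])], axis: 1) |> Nx.shape()
    #=> {2, 2}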