diff --git a/exla/lib/exla/backend.ex b/exla/lib/exla/backend.ex
index 581ef613c5..3ede3c33df 100644
--- a/exla/lib/exla/backend.ex
+++ b/exla/lib/exla/backend.ex
@@ -243,6 +243,17 @@ defmodule EXLA.Backend do
     jit([], expr_fun, tensors, [List.to_tuple(tensors)])
   end
 
+  @impl true
+  def stack(out, tensors, axis) do
+    out = Nx.to_template(out)
+
+    expr_fun = fn tensors ->
+      Nx.Defn.Expr.stack(out, Tuple.to_list(tensors), axis)
+    end
+
+    jit([], expr_fun, tensors, [List.to_tuple(tensors)])
+  end
+
   @impl true
   def slice(out, tensor, start_indices, lengths, strides) do
     out = Nx.to_template(out)
diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex
index 310fa40463..0d3e18fd7a 100644
--- a/exla/lib/exla/defn.ex
+++ b/exla/lib/exla/defn.ex
@@ -1293,10 +1293,13 @@ defmodule EXLA.Defn do
   end
 
   defp to_operator(:concatenate, [[%Value{} | _rest] = tensors, axis], ans, _state) do
-    tensors =
-      tensors
-      |> Enum.map(&to_type(&1, ans.type))
+    tensors = Enum.map(tensors, &to_type(&1, ans.type))
+    Value.concatenate(tensors, axis, expr_to_typespec(ans))
+  end
 
+  defp to_operator(:stack, [[%Value{} | _rest] = tensors, axis], ans, _state) do
+    reshape_typespec = Typespec.tensor(ans.type, put_elem(ans.shape, axis, 1))
+    tensors = Enum.map(tensors, &(&1 |> to_type(ans.type) |> Value.reshape(reshape_typespec)))
     Value.concatenate(tensors, axis, expr_to_typespec(ans))
   end
 
diff --git a/nx/lib/nx.ex b/nx/lib/nx.ex
index f8158634e9..24bea7ae7e 100644
--- a/nx/lib/nx.ex
+++ b/nx/lib/nx.ex
@@ -3251,20 +3251,9 @@ defmodule Nx do
   """
   @doc type: :shape, from_backend: false
   def new_axis(tensor, axis, name \\ nil) when is_integer(axis) do
-    apply_vectorized(tensor, fn tensor, offset ->
-      %{shape: shape, names: names} = tensor = to_tensor(tensor)
-      rank = tuple_size(shape)
-      norm = if axis < 0, do: axis + rank + 1, else: axis + offset
-
-      if norm not in offset..tuple_size(shape) do
-        raise ArgumentError,
-              "new axis position for shape #{inspect(shape)} must be " <>
-                "a number between #{-rank - 1 + offset} and #{rank - offset}, got: #{axis}"
-      end
-
-      new_shape = Tuple.insert_at(shape, norm, 1)
-      new_names = List.insert_at(names, norm, name)
-      impl!(tensor).reshape(%{tensor | shape: new_shape, names: new_names}, tensor)
+    apply_vectorized(tensor, fn %{shape: shape, names: names} = tensor, offset ->
+      {shape, names, _axis} = Nx.Shape.new_axis(shape, names, axis, name, 1, offset)
+      impl!(tensor).reshape(%{tensor | shape: shape, names: names}, tensor)
     end)
   end
 
@@ -14668,28 +14657,35 @@ defmodule Nx do
         t
 
       [_ | _] = tensors ->
-        [%T{vectorized_axes: vectorized_axes} | _] =
-          tensors = broadcast_vectors(tensors, align_ranks: true)
+        concatenate_or_stack(
+          tensors,
+          fn shapes, names, offset -> Nx.Shape.concatenate(shapes, names, axis, offset) end,
+          fn out, tensors, axis -> list_impl!(tensors).concatenate(out, tensors, axis) end
+        )
+    end
+  end
 
-        offset = length(vectorized_axes)
-        tensors = if vectorized_axes != [], do: Enum.map(tensors, &devectorize/1), else: tensors
+  defp concatenate_or_stack(tensors, shape_and_name, callback) do
+    [%T{vectorized_axes: vectorized_axes} | _] =
+      tensors = broadcast_vectors(tensors, align_ranks: true)
 
-        {types, [s1 | _] = shapes, [n1 | _] = names} =
-          Enum.reduce(tensors, {[], [], []}, fn
-            %T{type: t, shape: s, names: n}, {types, shapes, names} ->
-              {[t | types], [s | shapes], [n | names]}
-          end)
+    offset = length(vectorized_axes)
+    tensors = if vectorized_axes != [], do: Enum.map(tensors, &devectorize/1), else: tensors
+
+    {types, shapes, names} =
+      Enum.reduce(tensors, {[], [], []}, fn
+        %T{type: t, shape: s, names: n}, {types, shapes, names} ->
+          {[t | types], [s | shapes], [n | names]}
+      end)
 
-        axis = Nx.Shape.normalize_axis(s1, axis, n1, offset)
-        output_type = Enum.reduce(types, &Nx.Type.merge/2)
+    output_type = Enum.reduce(types, &Nx.Type.merge/2)
 
-        {output_shape, output_names} =
-          Nx.Shape.concatenate(Enum.reverse(shapes), Enum.reverse(names), axis)
+    {output_shape, output_names, axis} =
+      shape_and_name.(Enum.reverse(shapes), Enum.reverse(names), offset)
 
-        out = %{hd(tensors) | type: output_type, shape: output_shape, names: output_names}
-        result = list_impl!(tensors).concatenate(out, tensors, axis)
-        vectorize(result, vectorized_axes)
-    end
+    out = %{hd(tensors) | type: output_type, shape: output_shape, names: output_names}
+    result = callback.(out, tensors, axis)
+    vectorize(result, vectorized_axes)
   end
 
   defp flatten_list_or_container(list) when is_list(list) do
@@ -14807,16 +14803,26 @@ defmodule Nx do
      >
 
  """
-  @doc type: :ndim, from_backend: false
+  @doc type: :ndim
  def stack(tensors, opts \\ []) do
    opts = keyword!(opts, axis: 0, name: nil)
    axis = opts[:axis]
    name = opts[:name]
 
-    tensors
-    |> flatten_list_or_container()
-    |> Enum.map(&Nx.new_axis(&1, axis, name))
-    |> Nx.concatenate(axis: axis)
+    case flatten_list_or_container(tensors) do
+      [] ->
+        raise ArgumentError, "no tensors were given to stack"
+
+      [t] ->
+        Nx.new_axis(t, axis, name)
+
+      [_ | _] = tensors ->
+        concatenate_or_stack(
+          tensors,
+          fn shapes, names, offset -> Nx.Shape.stack(shapes, names, axis, name, offset) end,
+          fn out, tensors, axis -> list_impl!(tensors).stack(out, tensors, axis) end
+        )
+    end
  end
 
  @doc """
diff --git a/nx/lib/nx/backend.ex b/nx/lib/nx/backend.ex
index 0c2b573a99..6442319ea3 100644
--- a/nx/lib/nx/backend.ex
+++ b/nx/lib/nx/backend.ex
@@ -75,6 +75,7 @@ defmodule Nx.Backend do
   @callback put_slice(out :: tensor, tensor, tensor, list) :: tensor
   @callback gather(out :: tensor, input :: tensor, indices :: tensor, keyword) :: tensor
   @callback concatenate(out :: tensor, tensor, axis) :: tensor
+  @callback stack(out :: tensor, tensor, axis) :: tensor
   @callback select(out :: tensor, tensor, tensor, tensor) :: tensor
 
   @callback conv(out :: tensor, tensor, kernel :: tensor, keyword) :: tensor
diff --git a/nx/lib/nx/binary_backend.ex b/nx/lib/nx/binary_backend.ex
index f24c5e2dda..402045c80f 100644
--- a/nx/lib/nx/binary_backend.ex
+++ b/nx/lib/nx/binary_backend.ex
@@ -1999,6 +1999,19 @@ defmodule Nx.BinaryBackend do
     offset
   end
 
+  @impl true
+  def stack(out, tensors, axis) do
+    %{shape: output_shape, type: {_, size} = output_type} = out
+
+    tensors
+    |> Enum.map(fn %{shape: shape} = t ->
+      t = as_type(%{t | type: output_type}, t)
+      {to_binary(t), Tuple.insert_at(shape, axis, 1)}
+    end)
+    |> bin_concatenate(size, axis, output_shape)
+    |> then(&from_binary(out, &1))
+  end
+
   @impl true
   def concatenate(out, tensors, axis) do
     %{shape: output_shape, type: {_, size} = output_type} = out
diff --git a/nx/lib/nx/defn/evaluator.ex b/nx/lib/nx/defn/evaluator.ex
index 24bae22572..382430694d 100644
--- a/nx/lib/nx/defn/evaluator.ex
+++ b/nx/lib/nx/defn/evaluator.ex
@@ -20,7 +20,7 @@ defmodule Nx.Defn.Evaluator do
   alias Nx.Defn.{Composite, Expr, Tree}
 
   @creation_ops [:eye, :iota, :from_binary]
-  @list_ops [:concatenate]
+  @list_ops [:concatenate, :stack]
   @indices_ops [:slice, :put_slice]
 
   @impl true
diff --git a/nx/lib/nx/defn/expr.ex b/nx/lib/nx/defn/expr.ex
index f534c04bee..e79731255a 100644
--- a/nx/lib/nx/defn/expr.ex
+++ b/nx/lib/nx/defn/expr.ex
@@ -1201,6 +1201,12 @@ defmodule Nx.Defn.Expr do
     expr(out, context, :concatenate, [tensors, axis])
   end
 
+  @impl true
+  def stack(out, tensors, axis) do
+    {tensors, context} = to_exprs(tensors)
+    expr(out, context, :stack, [tensors, axis])
+  end
+
   @impl true
   def triangular_solve(out, a, b, opts) do
     {[a, b], context} = to_exprs([a, b])
diff --git a/nx/lib/nx/defn/grad.ex b/nx/lib/nx/defn/grad.ex
index 7fab315a70..17044a4041 100644
--- a/nx/lib/nx/defn/grad.ex
+++ b/nx/lib/nx/defn/grad.ex
@@ -614,6 +614,21 @@ defmodule Nx.Defn.Grad do
     [{x, g}]
   end
 
+  defp grad(:stack, [tensors, axis], ans, g) do
+    zero_axes = List.duplicate(0, Nx.rank(ans))
+    ans_shape_list = Tuple.to_list(ans.shape)
+
+    {pairs, _} =
+      Enum.map_reduce(tensors, 0, fn t, limit ->
+        current_limit = 1 + limit
+        start = List.replace_at(zero_axes, axis, limit)
+        len = List.replace_at(ans_shape_list, axis, 1)
+        {{t, Nx.slice(g, start, len)}, current_limit}
+      end)
+
+    pairs
+  end
+
   defp grad(:concatenate, [tensors, axis], ans, g) do
     zero_axes = List.duplicate(0, Nx.rank(ans))
     ans_shape_list = Tuple.to_list(ans.shape)
diff --git a/nx/lib/nx/defn/tree.ex b/nx/lib/nx/defn/tree.ex
index 9d66f10780..67826582a3 100644
--- a/nx/lib/nx/defn/tree.ex
+++ b/nx/lib/nx/defn/tree.ex
@@ -202,7 +202,8 @@ defmodule Nx.Defn.Tree do
     {[%{token | hooks: hooks}], acc}
   end
 
-  def apply_args(%T{data: %Expr{op: :concatenate, args: [list | args]}}, _type, acc, fun) do
+  def apply_args(%T{data: %Expr{op: op, args: [list | args]}}, _type, acc, fun)
+      when op in [:concatenate, :stack] do
     {list, acc} = Enum.map_reduce(list, acc, fun)
     {[list | args], acc}
   end
diff --git a/nx/lib/nx/shape.ex b/nx/lib/nx/shape.ex
index 46f21da0c5..537fe5b947 100644
--- a/nx/lib/nx/shape.ex
+++ b/nx/lib/nx/shape.ex
@@ -1648,17 +1648,65 @@ defmodule Nx.Shape do
     {shape, names}
   end
 
+  @doc """
+  Returns the shape and name of new axis.
+  """
+  def new_axis(shape, names, axis, name, size, offset) do
+    rank = tuple_size(shape)
+    norm = if axis < 0, do: axis + rank + 1, else: axis + offset
+
+    if norm not in offset..tuple_size(shape) do
+      raise ArgumentError,
+            "new axis position for shape #{inspect(shape)} must be " <>
+              "a number between #{-rank - 1 + offset} and #{rank - offset}, got: #{axis}"
+    end
+
+    new_shape = Tuple.insert_at(shape, norm, size)
+    new_names = List.insert_at(names, norm, name)
+    {new_shape, new_names, norm}
+  end
+
+  @doc """
+  Returns the shape and names after a stack.
+
+  ## Examples
+
+      iex> Nx.Shape.stack([{3, 2}, {3, 2}, {3, 2}], [[nil, nil], [nil, :z], [:y, nil]], 0, :x, 0)
+      {{3, 3, 2}, [:x, :y, :z], 0}
+  """
+  def stack(shapes, names, axis, name, offset) do
+    names =
+      Enum.zip_with(names, fn zipped ->
+        Enum.reduce(zipped, &merge_names!(&1, &2, axis, axis))
+      end)
+
+    case Enum.uniq(shapes) do
+      [shape] ->
+        new_axis(shape, names, axis, name, length(shapes), offset)
+
+      shapes ->
+        raise ArgumentError,
+              "can only stack tensors of the same shape, got distinct shapes: #{inspect(shapes)}"
+    end
+  end
+
   @doc """
   Returns the shape and names after a concat.
 
   ## Examples
 
-      iex> Nx.Shape.concatenate([{2, 3, 2}, {1, 3, 2}, {4, 3, 2}], [[:x, :y, :z], [:x, :y, :z], [:x, :y, :z]], 0)
-      {{7, 3, 2}, [:x, :y, :z]}
+      iex> Nx.Shape.concatenate([{2, 3, 2}, {1, 3, 2}, {4, 3, 2}], [[:x, :y, :z], [:x, :y, :z], [:x, :y, :z]], 0, 0)
+      {{7, 3, 2}, [:x, :y, :z], 0}
   """
-  def concatenate(shapes, names, axis) do
-    names = validate_concat_names!(names, axis)
-    {concat_dims(shapes, axis), names}
+  def concatenate([s1 | _] = shapes, [n1 | _] = names, axis, offset) do
+    axis = normalize_axis(s1, axis, n1, offset)
+
+    names =
+      Enum.zip_with(names, fn zipped ->
+        Enum.reduce(zipped, &merge_names!(&1, &2, axis, axis))
+      end)
+
+    {concat_dims(shapes, axis), names, axis}
   end
 
   defp concat_dims([s1 | shapes] = all_shapes, axis) do
@@ -2120,15 +2168,6 @@
     )
   end
 
-  defp validate_concat_names!(names, axis) do
-    _ =
-      Enum.zip_with(names, fn zipped ->
-        Enum.reduce(zipped, &merge_names!(&1, &2, axis, axis))
-      end)
-
-    hd(names)
-  end
-
   def fft({}) do
     raise ArgumentError, "expected a tensor with rank > 0, got tensor with rank 0"
   end
diff --git a/torchx/lib/torchx/backend.ex b/torchx/lib/torchx/backend.ex
index 3c771f6fa7..a4333a6fca 100644
--- a/torchx/lib/torchx/backend.ex
+++ b/torchx/lib/torchx/backend.ex
@@ -337,6 +337,16 @@ defmodule Torchx.Backend do
     |> to_nx(out)
   end
 
+  @impl true
+  def stack(out, tensors, axis) do
+    reshape = put_elem(out.shape, axis, 1)
+
+    tensors
+    |> Enum.map(&(&1 |> from_nx() |> Torchx.reshape(reshape)))
+    |> Torchx.concatenate(axis)
+    |> to_nx(out)
+  end
+
   @impl true
   def gather(out, tensor, indices, opts) do
     tensor_axes = Nx.axes(tensor)
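# Minimal usage sketch of the behaviour this patch wires up, assuming the
# public Nx.stack/2 API stays as documented: stacking still inserts a new
# axis, but dispatch now goes through the dedicated stack/3 backend callback
# instead of expanding into Nx.new_axis/3 plus Nx.concatenate/2. The tensors
# and the :pair axis name below are illustrative only.
a = Nx.tensor([1, 2, 3])
b = Nx.tensor([4, 5, 6])

Nx.stack([a, b])
# => a {2, 3} tensor; the new axis is inserted at position 0 by default

Nx.stack([a, b], axis: 1, name: :pair)
# => a {3, 2} tensor; the new axis sits at position 1 and is named :pair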