diff --git a/nx/lib/nx.ex b/nx/lib/nx.ex
index 7845b65703..beaa293d77 100644
--- a/nx/lib/nx.ex
+++ b/nx/lib/nx.ex
@@ -5420,9 +5420,13 @@ defmodule Nx do
       {_, [], 0} ->
         fun.(left, right)
 
-      {[devec_left, devec_right], canonical_vectorized_axes, _offset} ->
-        devec_left
-        |> fun.(devec_right)
+      {[devec_left, devec_right], canonical_vectorized_axes, offset} ->
+        leading_names = Keyword.keys(canonical_vectorized_axes)
+        l = %{devec_left | names: leading_names ++ Enum.drop(devec_left.names, offset)}
+        r = %{devec_right | names: leading_names ++ Enum.drop(devec_right.names, offset)}
+
+        l
+        |> fun.(r)
         |> vectorize(canonical_vectorized_axes)
     end
   end
diff --git a/nx/lib/nx/defn/expr.ex b/nx/lib/nx/defn/expr.ex
index bb07ea30d7..997950a1f3 100644
--- a/nx/lib/nx/defn/expr.ex
+++ b/nx/lib/nx/defn/expr.ex
@@ -1394,6 +1394,10 @@ defmodule Nx.Defn.Expr do
 
   ## Constant helpers and related optimizations
 
+  defp constant(%{vectorized_axes: [_ | _]} = out, number) do
+    tensor(Nx.fill(out, number, type: out.type))
+  end
+
   defp constant(%{shape: shape, type: type} = out, number) do
     number =
       cond do
@@ -1661,7 +1665,7 @@ defmodule Nx.Defn.Expr do
 
   defp counter_to_name(counter), do: [?a + counter]
 
-  defp to_type_shape(%{type: type, shape: shape}) do
+  defp to_type_shape(%{vectorized_axes: [], type: type, shape: shape}) do
     brackets =
       shape
       |> Tuple.to_list()
diff --git a/nx/lib/nx/defn/grad.ex b/nx/lib/nx/defn/grad.ex
index bdfebcc22b..33a7a0deed 100644
--- a/nx/lib/nx/defn/grad.ex
+++ b/nx/lib/nx/defn/grad.ex
@@ -19,10 +19,10 @@ defmodule Nx.Defn.Grad do
     {:env, env} = Function.info(fun, :env)
     ids = stop_grads(env, ids)
 
-    # save vectorized axes before devectorizing
-    expr = to_grad |> fun.()
+    expr = fun.(to_grad)
 
-    transformed_expr = transform.(expr) |> validate_expr!() |> Nx.devectorize(keep_names: false)
+    transformed_expr =
+      expr |> transform.() |> validate_expr!()
 
     {parents, nodes} = parents_tree(transformed_expr, ids)
 
@@ -33,23 +33,17 @@ defmodule Nx.Defn.Grad do
       Composite.traverse(
         to_grad,
         {nodes, grads},
-        fn %{vectorized_axes: vectorized_axes} = node, acc ->
-          node
-          |> Nx.devectorize(keep_names: false)
-          |> to_grad(to_grad_ids, parents, acc)
-          |> then(fn {node, acc} ->
-            {Nx.vectorize(node, vectorized_axes), acc}
-          end)
+        fn node, acc ->
+          to_grad(node, to_grad_ids, parents, acc)
         end
       )
 
     {expr, graded}
  end
 
-  defp constant(float, shape) do
-    shape = Nx.shape(shape)
+  defp constant(float, %T{shape: shape} = t) do
     names = List.duplicate(nil, tuple_size(shape))
-    Expr.constant(%T{shape: shape, type: {:f, 32}, names: names}, float, [])
+    Expr.constant(%T{t | names: names, type: {:f, 32}}, float, [])
   end
 
   defp validate_expr!(%T{data: %Expr{}} = expr) do
@@ -94,47 +88,88 @@ defmodule Nx.Defn.Grad do
     [:equal, :greater, :greater_equal, :less, :less_equal, :not_equal, :argsort]
 
   defp parents_tree(expr, nodes) do
-    Composite.reduce(expr, {%{}, nodes}, &recur_parents_tree/2)
+    Composite.reduce(
+      expr,
+      {%{}, nodes},
+      &recur_parents_tree(
+        Nx.devectorize(&1, keep_names: true),
+        &2,
+        Keyword.keys(&1.vectorized_axes)
+      )
+    )
   end
 
-  defp recur_parents_tree(%T{data: %Expr{id: id, op: op}} = t, {parents, nodes}) do
+  defp recur_parents_tree(%T{data: %Expr{id: id, op: op}} = t, {parents, nodes}, vectorized_names) do
     case nodes do
-      %{^id => _} -> {parents, nodes}
-      %{} -> parents_args(op, t, id, {parents, Map.put(nodes, id, t)})
+      %{^id => _} ->
+        {parents, nodes}
+
+      %{} ->
+        # We use this to compute the proper axis sizes for the tensor
+        nodes = Map.put(nodes, id, {t, vectorized_names})
+
+        parents_args(op, t, id, {parents, nodes}, vectorized_names)
     end
   end
 
-  defp parents_args(:metadata, %{data: %{args: [_, %{stop_grad: true}]}}, _id, acc) do
+  defp parents_args(
+         :metadata,
+         %{data: %{args: [_, %{stop_grad: true}]}},
+         _id,
+         acc,
+         _parent_vectorized_names
+       ) do
     acc
   end
 
-  defp parents_args(:optional, %{data: %{args: [call, _expr, callback]}} = t, id, acc) do
+  defp parents_args(
+         :optional,
+         %{data: %{args: [call, _expr, callback]}} = t,
+         id,
+         acc,
+         parent_vectorized_names
+       ) do
     expr = apply(callback, call.data.args)
 
     # Now traverse over the optional expression where args are the new parameters.
     # Once we access the parameter itself, we point the parameter to the arg.
-    {parents, nodes} =
-      Composite.reduce(expr, acc, fn expr, {parents, nodes} ->
-        parents = Map.update(parents, expr.data.id, [id], &[id | &1])
-        recur_parents_tree(expr, {parents, nodes})
+    {{parents, nodes}, _} =
+      Composite.reduce(expr, {acc, parent_vectorized_names}, fn
+        expr, {{parents, nodes}, expr_vectorized_names} ->
+          arg_vectorized_names = compute_arg_vectorized_names(expr, expr_vectorized_names)
+          parents = Map.update(parents, expr.data.id, [id], &[id | &1])
+
+          acc =
+            recur_parents_tree(
+              expr,
+              {parents, nodes},
+              arg_vectorized_names
+            )
+
+          {acc, expr_vectorized_names}
       end)
 
-    {parents, Map.put(nodes, id, put_in(t.data.args, [call, expr, callback]))}
+    updated_node =
+      {put_in(t.data.args, [call, expr, callback]), parent_vectorized_names}
+
+    {parents, Map.put(nodes, id, updated_node)}
   end
 
   # We register cond as a special node to avoid pretraversing it.
   # Instead we traverse it early on on the grad computation.
-  defp parents_args(:cond, _, id, {parents, nodes}) do
+  defp parents_args(:cond, _, id, {parents, nodes}, _parent_vectorized_names) do
     {Map.update(parents, __MODULE__, [id], &[id | &1]), nodes}
   end
 
-  defp parents_args(op, t, parent_id, acc) do
+  defp parents_args(op, t, parent_id, acc, parent_vectorized_names) do
     reduce_args(op, t, acc, fn arg, {parents, nodes} ->
       if arg.data.op in @constants do
         {parents, nodes}
       else
+        arg_vectorized_names = compute_arg_vectorized_names(t, parent_vectorized_names)
        parents = Map.update(parents, arg.data.id, [parent_id], &[parent_id | &1])
-        recur_parents_tree(arg, {parents, nodes})
+
+        recur_parents_tree(arg, {parents, nodes}, arg_vectorized_names)
       end
     end)
   end
@@ -191,10 +226,27 @@ defmodule Nx.Defn.Grad do
     case nodes do
       %{^id => _} ->
         {nodes, grads} = traverse_parents(id, to_grad_ids, parents, {nodes, grads})
-        {ans, nodes} = Map.pop!(nodes, id)
+        {{ans, vectorized_names}, nodes} = Map.pop!(nodes, id)
         %T{data: %Expr{op: op, args: args}} = ans
         {gs, grads} = Map.pop(grads, id)
 
+        {args, ans} =
+          if vectorized_names != [] do
+            args =
+              Enum.map(args, fn
+                %T{} = arg ->
+                  revectorize_node(arg, vectorized_names)
+
+                opt ->
+                  opt
+              end)
+
+            ans = Nx.vectorize(ans, vectorized_names)
+            {args, ans}
+          else
+            {args, ans}
+          end
+
         case gs do
           nil ->
             {nodes, grads}
@@ -213,6 +265,22 @@ defmodule Nx.Defn.Grad do
     end
   end
 
+  defp compute_arg_vectorized_names(%{vectorized_axes: vectorized_axes}, []),
+    do: Keyword.keys(vectorized_axes)
+
+  defp compute_arg_vectorized_names(
+         %{vectorized_axes: vectorized_axes, names: names},
+         parent_names
+       ) do
+    Keyword.keys(vectorized_axes) ++ Enum.filter(names, &(&1 in parent_names))
+  end
+
+  defp revectorize_node(node, vectorized_names) do
+    vectorized_names = compute_arg_vectorized_names(node, vectorized_names)
+
+    Nx.vectorize(node, vectorized_names)
+  end
+
   defp update_grads(:elem, [%{type: {:tuple, size}} = tuple, pos], _ans, g, _to_grad_ids, grads) do
     update_in(grads[tuple.data.id], fn tuple ->
       tuple = tuple || Tuple.duplicate([], size)
diff --git a/nx/test/nx/defn/grad_test.exs b/nx/test/nx/defn/grad_test.exs
index 3a04276c27..6ab7e8ef78 100644
--- a/nx/test/nx/defn/grad_test.exs
+++ b/nx/test/nx/defn/grad_test.exs
@@ -4256,7 +4256,43 @@ defmodule Nx.Defn.GradTest do
   end
 
   describe "vectorization" do
-    test "supports vectorization" do
+    test "supports combination of vectorized and non-vectorized tensors" do
+      x = Nx.tensor([[1, 2, 3], [4, 5, 6]]) |> Nx.vectorize(:x)
+      y = 1
+
+      grad = Nx.Defn.grad(y, fn y -> Nx.add(x, y) end)
+
+      assert grad == Nx.tensor([3.0, 3.0]) |> Nx.vectorize([:x])
+    end
+
+    test "supports combination of vectorized and non-vectorized tensors over composed function" do
+      x = Nx.tensor([[1, 2, 3], [4, 5, 6]], names: [:x, :y]) |> Nx.vectorize(:x)
+      y = 1
+
+      grad = Nx.Defn.grad(y, fn y -> Nx.add(y, Nx.sin(x)) end)
+      assert grad == Nx.tensor([3.0, 3.0]) |> Nx.vectorize([:x])
+
+      grad = Nx.Defn.grad(x, fn x -> Nx.add(y, Nx.sin(x)) end)
+      assert grad == Nx.cos(x)
+    end
+
+    # Skipping this as it's not supported yet.
+    @tag :skip
+    test "edge case where the same name changes meaning" do
+      x = Nx.tensor([[1], [2], [3]]) |> Nx.vectorize(x: 3)
+
+      grad =
+        Nx.Defn.grad(x, fn t ->
+          devec = Nx.devectorize(t, keep_names: true)
+          new_axis = Nx.reshape(devec, {1, 3, 1}, names: [:x, nil, nil])
+
+          Nx.vectorize(new_axis, x: 1)
+        end)
+
+      assert grad == Nx.tensor([[1], [1], [1]]) |> Nx.vectorize(x: 3)
+    end
+
+    test "supports heterogeneous vectorization combinations" do
       x = Nx.tensor([[1, 2, 3], [4, 5, 6]])
       y = Nx.tensor([10, 20])
 
@@ -4264,13 +4300,53 @@ defmodule Nx.Defn.GradTest do
       # expected result: equivalent to fully broadcasting one tensor onto the other
       x_vec = Nx.vectorize(x, :x)
       y_vec = Nx.vectorize(y, :y)
-      {grad_x_vec, grad_y_vec} = Nx.Defn.grad({x_vec, y_vec}, fn {a, b} -> Nx.multiply(a, b) end)
+      grad_fun = fn x, y ->
+        Nx.Defn.grad({x, y}, fn {a, b} -> Nx.multiply(a, b) end)
+      end
+
+      {grad_x_vec, grad_y_vec} = grad_fun.(x_vec, y_vec)
+
+      # Explicit assertion on the results
       assert grad_x_vec ==
-               Nx.tensor([[30.0, 30.0, 30.0], [30.0, 30.0, 30.0]])
-               |> Nx.vectorize(x_vec.vectorized_axes)
+               Nx.tensor([
+                 [
+                   [10.0, 10.0, 10.0],
+                   [20.0, 20.0, 20.0]
+                 ],
+                 [
+                   [10.0, 10.0, 10.0],
+                   [20.0, 20.0, 20.0]
+                 ]
+               ])
+               |> Nx.vectorize([:x, :y])
+
+      assert grad_y_vec ==
+               Nx.tensor([
+                 [6.0, 6.0],
+                 [15.0, 15.0]
+               ])
+               |> Nx.vectorize([:x, :y])
 
-      assert grad_y_vec == Nx.tensor([21.0, 21.0]) |> Nx.vectorize(y_vec.vectorized_axes)
+      # Conceptual assertion: the result should be equivalent to calling Nx.Defn.grad with
+      # each cross-entry of the combined vectors [(x0, y0), (x0, y1), (x1, y0), (x1, y1)]
+
+      {x0y0_wrt_x, x0y0_wrt_y} = grad_fun.(x[0], y[0])
+      {x0y1_wrt_x, x0y1_wrt_y} = grad_fun.(x[0], y[1])
+      {x1y0_wrt_x, x1y0_wrt_y} = grad_fun.(x[1], y[0])
+      {x1y1_wrt_x, x1y1_wrt_y} = grad_fun.(x[1], y[1])
+
+      assert grad_x_vec ==
+               [x0y0_wrt_x, x0y1_wrt_x, x1y0_wrt_x, x1y1_wrt_x]
+               |> Nx.stack()
+               |> Nx.reshape({2, 2, 3})
+               |> Nx.vectorize([:x, :y])
+
+      assert grad_y_vec ==
+               [x0y0_wrt_y, x0y1_wrt_y, x1y0_wrt_y, x1y1_wrt_y]
+               |> Nx.stack()
+               |> Nx.reshape({2, 2})
+               |> Nx.vectorize([:x, :y])
 
       # second case: y is vectorized scalar, x is vectorized vectors, same vectorized axis name
       # expected result: equivalent to "row-wise" broadcasting