From 394a12d3c04acd8fdee6c30710498de765598a92 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Sun, 15 Sep 2024 02:24:13 -0300 Subject: [PATCH 01/22] fix: broadcast vectors for grad calculation --- nx/lib/nx/defn/grad.ex | 15 ++++++--- nx/test/nx/defn/grad_test.exs | 62 ++++++++++++++++++++++++++++++++--- 2 files changed, 68 insertions(+), 9 deletions(-) diff --git a/nx/lib/nx/defn/grad.ex b/nx/lib/nx/defn/grad.ex index 17044a40419..49e5784c140 100644 --- a/nx/lib/nx/defn/grad.ex +++ b/nx/lib/nx/defn/grad.ex @@ -5,12 +5,19 @@ defmodule Nx.Defn.Grad do alias Nx.Tensor, as: T def transform(to_grad, fun, transform) do - {to_grad, ids} = - Composite.traverse(to_grad, %{}, fn to_grad, ids -> + broadcasted_nodes = + [to_grad] + |> Composite.flatten_list() + |> Nx.broadcast_vectors() + + {to_grad, {ids, []}} = + Composite.traverse(to_grad, {%{}, broadcasted_nodes}, fn _to_grad, + {ids, [broadcasted_node | nodes]} -> to_grad = - Expr.metadata(to_grad, %{__MODULE__ => :to_grad}) + Expr.metadata(broadcasted_node, %{__MODULE__ => :to_grad}) - {to_grad, Map.put(ids, to_grad.data.id, :stop)} + ids = Map.put(ids, to_grad.data.id, :stop) + {to_grad, {ids, nodes}} end) # Collect all IDs in the function environment and mark diff --git a/nx/test/nx/defn/grad_test.exs b/nx/test/nx/defn/grad_test.exs index 4242c548f3d..16613614b24 100644 --- a/nx/test/nx/defn/grad_test.exs +++ b/nx/test/nx/defn/grad_test.exs @@ -4238,7 +4238,19 @@ defmodule Nx.Defn.GradTest do end describe "vectorization" do - test "supports vectorization" do + test "supports combination of vectorized and non-vectorized tensors" do + x = Nx.tensor([[1, 2, 3], [4, 5, 6]]) |> Nx.vectorize(:x) + y = Nx.tensor(1) + + {grad_x, grad_y} = Nx.Defn.grad({x, y}, fn {a, b} -> Nx.add(a, b) end) + + assert grad_x == + Nx.tensor([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]) |> Nx.vectorize(x.vectorized_axes) + + assert grad_y == Nx.tensor([3.0, 3.0]) |> Nx.vectorize(x.vectorized_axes) + end + + test "supports heterogenous vectorization combinations" do x = Nx.tensor([[1, 2, 3], [4, 5, 6]]) y = Nx.tensor([10, 20]) @@ -4246,13 +4258,53 @@ defmodule Nx.Defn.GradTest do # expected result: equivalent to fully broadcasting one tensor onto the other x_vec = Nx.vectorize(x, :x) y_vec = Nx.vectorize(y, :y) - {grad_x_vec, grad_y_vec} = Nx.Defn.grad({x_vec, y_vec}, fn {a, b} -> Nx.multiply(a, b) end) + grad_fun = fn x, y -> + Nx.Defn.grad({x, y}, fn {a, b} -> Nx.multiply(a, b) end) + end + + {grad_x_vec, grad_y_vec} = grad_fun.(x_vec, y_vec) + + # Explicit assertion on the results assert grad_x_vec == - Nx.tensor([[30.0, 30.0, 30.0], [30.0, 30.0, 30.0]]) - |> Nx.vectorize(x_vec.vectorized_axes) + Nx.tensor([ + [ + [10.0, 10.0, 10.0], + [20.0, 20.0, 20.0] + ], + [ + [10.0, 10.0, 10.0], + [20.0, 20.0, 20.0] + ] + ]) + |> Nx.vectorize([:x, :y]) - assert grad_y_vec == Nx.tensor([21.0, 21.0]) |> Nx.vectorize(y_vec.vectorized_axes) + assert grad_y_vec == + Nx.tensor([ + [6.0, 6.0], + [15.0, 15.0] + ]) + |> Nx.vectorize([:x, :y]) + + # Conceptual assertion: the result should be equivalent to calling Nx.Defn.grad with + # each cross-entry of the combined vectors [(x0, y0), (x0, y1), (x1, y0), (x1, y1)] + + {x0y0_wrt_x, x0y0_wrt_y} = grad_fun.(x[0], y[0]) + {x0y1_wrt_x, x0y1_wrt_y} = grad_fun.(x[0], y[1]) + {x1y0_wrt_x, x1y0_wrt_y} = grad_fun.(x[1], y[0]) + {x1y1_wrt_x, x1y1_wrt_y} = grad_fun.(x[1], y[1]) + + assert grad_x_vec == + [x0y0_wrt_x, x0y1_wrt_x, x1y0_wrt_x, x1y1_wrt_x] + |> Nx.stack() + |> Nx.reshape({2, 2, 3}) + |> 
Nx.vectorize([:x, :y]) + + assert grad_y_vec == + [x0y0_wrt_y, x0y1_wrt_y, x1y0_wrt_y, x1y1_wrt_y] + |> Nx.stack() + |> Nx.reshape({2, 2}) + |> Nx.vectorize([:x, :y]) # second case: y is vectorized scalar, x is vectorized vectors, same vectorized axis name # expected result: equivalent to "row-wise" broadcasting From 414726b4e1d304f271fd151cb7cecada45ade5aa Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Sun, 15 Sep 2024 02:45:36 -0300 Subject: [PATCH 02/22] fix attempt --- nx/lib/nx/defn/grad.ex | 31 +++++++++++++++---------------- nx/test/nx/defn/grad_test.exs | 9 +++------ 2 files changed, 18 insertions(+), 22 deletions(-) diff --git a/nx/lib/nx/defn/grad.ex b/nx/lib/nx/defn/grad.ex index 49e5784c140..b6555f8d228 100644 --- a/nx/lib/nx/defn/grad.ex +++ b/nx/lib/nx/defn/grad.ex @@ -5,19 +5,23 @@ defmodule Nx.Defn.Grad do alias Nx.Tensor, as: T def transform(to_grad, fun, transform) do - broadcasted_nodes = - [to_grad] - |> Composite.flatten_list() - |> Nx.broadcast_vectors() + to_grad = + Composite.traverse(to_grad, fn to_grad -> + Expr.metadata(to_grad, %{__MODULE__ => :to_grad}) + end) + + # save vectorized axes before devectorizing + expr = fun.(to_grad) - {to_grad, {ids, []}} = - Composite.traverse(to_grad, {%{}, broadcasted_nodes}, fn _to_grad, - {ids, [broadcasted_node | nodes]} -> - to_grad = - Expr.metadata(broadcasted_node, %{__MODULE__ => :to_grad}) + transformed_expr = + expr |> transform.() |> validate_expr!() |> Nx.devectorize(keep_names: false) - ids = Map.put(ids, to_grad.data.id, :stop) - {to_grad, {ids, nodes}} + {to_grad, ids} = + Composite.traverse(to_grad, %{}, fn node, ids -> + [node, _expr] = Nx.broadcast_vectors([node, expr]) + node = Expr.metadata(node, %{__MODULE__ => :to_grad}) + ids = Map.put(ids, node.data.id, :stop) + {node, ids} end) # Collect all IDs in the function environment and mark @@ -25,11 +29,6 @@ defmodule Nx.Defn.Grad do # traversing trees when not necessary. 
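    # (`Function.info(fun, :env)` below is Erlang's `fun_info/2`: for an
    # anonymous function it returns the values captured in its closure, so
    # any tensor the grad function closes over is marked as a stop-grad.)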
{:env, env} = Function.info(fun, :env) ids = stop_grads(env, ids) - - # save vectorized axes before devectorizing - expr = to_grad |> fun.() - - transformed_expr = transform.(expr) |> validate_expr!() |> Nx.devectorize(keep_names: false) {parents, nodes} = parents_tree(transformed_expr, ids) to_grad_ids = {to_grad, ids} diff --git a/nx/test/nx/defn/grad_test.exs b/nx/test/nx/defn/grad_test.exs index 16613614b24..c7c3bca55e1 100644 --- a/nx/test/nx/defn/grad_test.exs +++ b/nx/test/nx/defn/grad_test.exs @@ -4240,14 +4240,11 @@ defmodule Nx.Defn.GradTest do describe "vectorization" do test "supports combination of vectorized and non-vectorized tensors" do x = Nx.tensor([[1, 2, 3], [4, 5, 6]]) |> Nx.vectorize(:x) - y = Nx.tensor(1) + y = 1 - {grad_x, grad_y} = Nx.Defn.grad({x, y}, fn {a, b} -> Nx.add(a, b) end) + grad = Nx.Defn.grad(y, fn y -> Nx.add(x, y) end) - assert grad_x == - Nx.tensor([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]) |> Nx.vectorize(x.vectorized_axes) - - assert grad_y == Nx.tensor([3.0, 3.0]) |> Nx.vectorize(x.vectorized_axes) + assert grad == Nx.tensor([3.0, 3.0]) |> Nx.vectorize([:x]) end test "supports heterogenous vectorization combinations" do From a08d0fd05564d47ccba1f6bd7fd5215998ea0c90 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Mon, 16 Sep 2024 06:38:30 -0300 Subject: [PATCH 03/22] test: make core tests pass --- nx/lib/nx.ex | 10 +++- nx/lib/nx/defn/expr.ex | 4 +- nx/lib/nx/defn/grad.ex | 124 +++++++++++++++++++++++++++++++++-------- 3 files changed, 111 insertions(+), 27 deletions(-) diff --git a/nx/lib/nx.ex b/nx/lib/nx.ex index 7845b65703a..beaa293d770 100644 --- a/nx/lib/nx.ex +++ b/nx/lib/nx.ex @@ -5420,9 +5420,13 @@ defmodule Nx do {_, [], 0} -> fun.(left, right) - {[devec_left, devec_right], canonical_vectorized_axes, _offset} -> - devec_left - |> fun.(devec_right) + {[devec_left, devec_right], canonical_vectorized_axes, offset} -> + leading_names = Keyword.keys(canonical_vectorized_axes) + l = %{devec_left | names: leading_names ++ Enum.drop(devec_left.names, offset)} + r = %{devec_right | names: leading_names ++ Enum.drop(devec_right.names, offset)} + + l + |> fun.(r) |> vectorize(canonical_vectorized_axes) end end diff --git a/nx/lib/nx/defn/expr.ex b/nx/lib/nx/defn/expr.ex index bb07ea30d7a..e1d88f01cfb 100644 --- a/nx/lib/nx/defn/expr.ex +++ b/nx/lib/nx/defn/expr.ex @@ -94,13 +94,13 @@ defmodule Nx.Defn.Expr do def metadata(expr, metadata) when is_map(metadata) do case to_container_expr(expr) do %{data: %{context: context}} = res -> - expr(res, context, :metadata, [Nx.devectorize(expr), metadata]) + expr(res, context, :metadata, [expr, metadata]) t when is_tuple(t) -> context = elem(t, 0).data.context tuple( - expr(tuple_out(tuple_size(t)), context, :metadata, [Nx.devectorize(expr), metadata]), + expr(tuple_out(tuple_size(t)), context, :metadata, [expr, metadata]), Tuple.to_list(t) ) end diff --git a/nx/lib/nx/defn/grad.ex b/nx/lib/nx/defn/grad.ex index b6555f8d228..bc824d75bb4 100644 --- a/nx/lib/nx/defn/grad.ex +++ b/nx/lib/nx/defn/grad.ex @@ -5,25 +5,28 @@ defmodule Nx.Defn.Grad do alias Nx.Tensor, as: T def transform(to_grad, fun, transform) do - to_grad = - Composite.traverse(to_grad, fn to_grad -> - Expr.metadata(to_grad, %{__MODULE__ => :to_grad}) - end) - - # save vectorized axes before devectorizing - expr = fun.(to_grad) - - transformed_expr = - expr |> transform.() |> validate_expr!() |> Nx.devectorize(keep_names: false) - {to_grad, ids} = Composite.traverse(to_grad, %{}, fn node, ids -> - [node, 
_expr] = Nx.broadcast_vectors([node, expr]) node = Expr.metadata(node, %{__MODULE__ => :to_grad}) ids = Map.put(ids, node.data.id, :stop) {node, ids} end) + expr = fun.(to_grad) + + transformed_expr = + expr |> transform.() |> validate_expr!() + + # |> Nx.devectorize(keep_names: false) + + # to_grad = + # Composite.traverse(to_grad, fn node -> + # [_expr, node] = Nx.broadcast_vectors([expr, node]) + # # ids = Map.put(ids, node.data.id, :stop) + # # {node, ids} + # node + # end) + # Collect all IDs in the function environment and mark # them as stop grads. This is an optimization to avoid # traversing trees when not necessary. @@ -38,13 +41,14 @@ defmodule Nx.Defn.Grad do Composite.traverse( to_grad, {nodes, grads}, - fn %{vectorized_axes: vectorized_axes} = node, acc -> + fn node, acc -> node - |> Nx.devectorize(keep_names: false) + # |> Nx.devectorize(keep_names: false) |> to_grad(to_grad_ids, parents, acc) - |> then(fn {node, acc} -> - {Nx.vectorize(node, vectorized_axes), acc} - end) + + # |> then(fn {node, acc} -> + # {Nx.vectorize(node, vectorized_axes), acc} + # end) end ) @@ -52,9 +56,16 @@ defmodule Nx.Defn.Grad do end defp constant(float, shape) do - shape = Nx.shape(shape) - names = List.duplicate(nil, tuple_size(shape)) - Expr.constant(%T{shape: shape, type: {:f, 32}, names: names}, float, []) + case shape do + %T{vectorized_axes: [_ | _]} = t -> + # [_expr, t] = Nx.broadcast_vectors([shape, float], align_ranks: false) + Expr.tensor(Nx.fill(t, float, type: :f32)) + + t -> + shape = Nx.shape(t) + names = List.duplicate(nil, tuple_size(shape)) + Expr.constant(%T{shape: shape, type: {:f, 32}, names: names}, float, []) + end end defp validate_expr!(%T{data: %Expr{}} = expr) do @@ -344,6 +355,8 @@ defmodule Nx.Defn.Grad do @verify_grad Application.compile_env(:nx, :verify_grad, false) defp update_grads(op, args, ans, g, _to_grad_ids, grads) do + args = revectorize_args(args, ans) + pairs = grad(op, args, ans, g) if @verify_grad do @@ -1349,9 +1362,76 @@ defmodule Nx.Defn.Grad do ## General helpers - defp unbroadcast(%{shape: shape} = x, res, %{shape: shape}), do: {x, res} + defp revectorize_args(args, ans) do + names_ans = + Enum.with_index(Keyword.keys(ans.vectorized_axes) ++ ans.names, fn name, idx -> + if(name, do: name, else: {:ans, idx}) + end) + + for arg <- args do + case arg do + %T{names: names} -> + names = Enum.with_index(names, fn name, idx -> if(name, do: {name, idx}) end) + + vectorized_axes = + names + |> Enum.reduce([], fn + nil, acc -> + acc + + {name, _idx}, acc -> + if name in names_ans do + [name | acc] + else + acc + end + end) + |> Enum.reverse() + + Nx.vectorize(arg, vectorized_axes) + + arg -> + arg + end + end + end + + defp unbroadcast(x, res, ans) do + # ans := y, x + + # y: [a, b, c] + # x: [b, d] + # ans: [b, a, c, d] + + # res/dx -> [b, d] {a, c, ...} + + vectorized_axes_x = Keyword.keys(x.vectorized_axes) + vectorized_axes_ans = Keyword.keys(ans.vectorized_axes) + + # num_extra_axes = length(vectorized_axes_ans -- vectorized_axes_x) + + permutation = + vectorized_axes_ans + |> Enum.with_index() + |> Enum.sort_by(fn {axis, _idx} -> axis in vectorized_axes_x end) + |> Enum.map(fn {_, idx} -> idx end) + + num_vectorized_axes = length(permutation) + + inner_axes = Enum.to_list(num_vectorized_axes..(num_vectorized_axes + Nx.rank(res) - 1)//1) + + res = + res + |> Nx.devectorize() + |> Nx.transpose(axes: permutation ++ inner_axes) + |> Nx.vectorize(vectorized_axes_x) + + unbroadcast2(x, res, ans) + end + + defp unbroadcast2(%{shape: shape} = x, res, %{shape: 
shape}), do: {x, res} - defp unbroadcast(%{shape: shape} = x, res, %{shape: new_shape}) do + defp unbroadcast2(%{shape: shape} = x, res, %{shape: new_shape}) do axes = Nx.Shape.broadcast_axes(shape, new_shape) {x, grad_broadcast(x, new_shape, axes, res)} end From 2f7c5f1d03cf9866ad275a724acd610ffaeca717 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Mon, 16 Sep 2024 06:41:16 -0300 Subject: [PATCH 04/22] fix: inspect vectorized axes as usual --- nx/lib/nx/defn/expr.ex | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/nx/lib/nx/defn/expr.ex b/nx/lib/nx/defn/expr.ex index e1d88f01cfb..300dc7fd4d3 100644 --- a/nx/lib/nx/defn/expr.ex +++ b/nx/lib/nx/defn/expr.ex @@ -1661,11 +1661,11 @@ defmodule Nx.Defn.Expr do defp counter_to_name(counter), do: [?a + counter] - defp to_type_shape(%{type: type, shape: shape}) do - brackets = - shape - |> Tuple.to_list() - |> Enum.map(&[?[, Integer.to_string(&1), ?]]) + defp to_type_shape(%{vectorized_axes: vectorized_axes, type: type, shape: shape}) do + axes = + Keyword.values(vectorized_axes) ++ Tuple.to_list(shape) + + brackets = Enum.map(axes, &[?[, Integer.to_string(&1), ?]]) IO.iodata_to_binary([Nx.Type.to_string(type) | brackets]) end From d87ffa13e5e9c5468049f2e8824ddc3438f4a357 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Mon, 16 Sep 2024 06:43:17 -0300 Subject: [PATCH 05/22] chore: revert some changes --- nx/lib/nx/defn/grad.ex | 22 ++++++---------------- 1 file changed, 6 insertions(+), 16 deletions(-) diff --git a/nx/lib/nx/defn/grad.ex b/nx/lib/nx/defn/grad.ex index bc824d75bb4..33f5e24bcea 100644 --- a/nx/lib/nx/defn/grad.ex +++ b/nx/lib/nx/defn/grad.ex @@ -12,26 +12,17 @@ defmodule Nx.Defn.Grad do {node, ids} end) - expr = fun.(to_grad) - - transformed_expr = - expr |> transform.() |> validate_expr!() - - # |> Nx.devectorize(keep_names: false) - - # to_grad = - # Composite.traverse(to_grad, fn node -> - # [_expr, node] = Nx.broadcast_vectors([expr, node]) - # # ids = Map.put(ids, node.data.id, :stop) - # # {node, ids} - # node - # end) - # Collect all IDs in the function environment and mark # them as stop grads. This is an optimization to avoid # traversing trees when not necessary. 
{:env, env} = Function.info(fun, :env) ids = stop_grads(env, ids) + + expr = fun.(to_grad) + + transformed_expr = + expr |> transform.() |> validate_expr!() + {parents, nodes} = parents_tree(transformed_expr, ids) to_grad_ids = {to_grad, ids} @@ -58,7 +49,6 @@ defmodule Nx.Defn.Grad do defp constant(float, shape) do case shape do %T{vectorized_axes: [_ | _]} = t -> - # [_expr, t] = Nx.broadcast_vectors([shape, float], align_ranks: false) Expr.tensor(Nx.fill(t, float, type: :f32)) t -> From 7fbdffdb06017f6fb47404d70153fcb508d4369f Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Mon, 16 Sep 2024 06:44:15 -0300 Subject: [PATCH 06/22] chore: remove commented code --- nx/lib/nx/defn/grad.ex | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/nx/lib/nx/defn/grad.ex b/nx/lib/nx/defn/grad.ex index 33f5e24bcea..bf2c8855b1a 100644 --- a/nx/lib/nx/defn/grad.ex +++ b/nx/lib/nx/defn/grad.ex @@ -29,19 +29,7 @@ defmodule Nx.Defn.Grad do grads = %{transformed_expr.data.id => [constant(1.0, transformed_expr)]} {graded, _} = - Composite.traverse( - to_grad, - {nodes, grads}, - fn node, acc -> - node - # |> Nx.devectorize(keep_names: false) - |> to_grad(to_grad_ids, parents, acc) - - # |> then(fn {node, acc} -> - # {Nx.vectorize(node, vectorized_axes), acc} - # end) - end - ) + Composite.traverse(to_grad, {nodes, grads}, &to_grad(&1, to_grad_ids, parents, &2)) {expr, graded} end From db0b6f048adb8455c022b29785fea1c588659d3a Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Mon, 16 Sep 2024 06:45:42 -0300 Subject: [PATCH 07/22] chore: remove stray comments --- nx/lib/nx/defn/grad.ex | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/nx/lib/nx/defn/grad.ex b/nx/lib/nx/defn/grad.ex index bf2c8855b1a..9958753ab4e 100644 --- a/nx/lib/nx/defn/grad.ex +++ b/nx/lib/nx/defn/grad.ex @@ -1375,14 +1375,6 @@ defmodule Nx.Defn.Grad do end defp unbroadcast(x, res, ans) do - # ans := y, x - - # y: [a, b, c] - # x: [b, d] - # ans: [b, a, c, d] - - # res/dx -> [b, d] {a, c, ...} - vectorized_axes_x = Keyword.keys(x.vectorized_axes) vectorized_axes_ans = Keyword.keys(ans.vectorized_axes) @@ -1404,12 +1396,12 @@ defmodule Nx.Defn.Grad do |> Nx.transpose(axes: permutation ++ inner_axes) |> Nx.vectorize(vectorized_axes_x) - unbroadcast2(x, res, ans) + grad_broadcast(x, res, ans) end - defp unbroadcast2(%{shape: shape} = x, res, %{shape: shape}), do: {x, res} + defp grad_broadcast(%{shape: shape} = x, res, %{shape: shape}), do: {x, res} - defp unbroadcast2(%{shape: shape} = x, res, %{shape: new_shape}) do + defp grad_broadcast(%{shape: shape} = x, res, %{shape: new_shape}) do axes = Nx.Shape.broadcast_axes(shape, new_shape) {x, grad_broadcast(x, new_shape, axes, res)} end From 20cc168030f33a0b4934cea22eb6ca10171c573e Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Mon, 16 Sep 2024 06:46:06 -0300 Subject: [PATCH 08/22] chore: remove more stray comments --- nx/lib/nx/defn/grad.ex | 2 -- 1 file changed, 2 deletions(-) diff --git a/nx/lib/nx/defn/grad.ex b/nx/lib/nx/defn/grad.ex index 9958753ab4e..b7bbc8867fe 100644 --- a/nx/lib/nx/defn/grad.ex +++ b/nx/lib/nx/defn/grad.ex @@ -1378,8 +1378,6 @@ defmodule Nx.Defn.Grad do vectorized_axes_x = Keyword.keys(x.vectorized_axes) vectorized_axes_ans = Keyword.keys(ans.vectorized_axes) - # num_extra_axes = length(vectorized_axes_ans -- vectorized_axes_x) - permutation = vectorized_axes_ans 
|> Enum.with_index() From 22b9a24359ddf8794237fb85e0e7c4715f53dd19 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Mon, 16 Sep 2024 12:49:17 -0300 Subject: [PATCH 09/22] refactor: support vectorized constant --- nx/lib/nx/defn/expr.ex | 5 +++++ nx/lib/nx/defn/grad.ex | 39 ++++++++++++++++++++++----------------- 2 files changed, 27 insertions(+), 17 deletions(-) diff --git a/nx/lib/nx/defn/expr.ex b/nx/lib/nx/defn/expr.ex index 300dc7fd4d3..666d6280742 100644 --- a/nx/lib/nx/defn/expr.ex +++ b/nx/lib/nx/defn/expr.ex @@ -1394,6 +1394,11 @@ defmodule Nx.Defn.Expr do ## Constant helpers and related optimizations + defp constant(%{vectorized_axes: [_ | _]} = out, number) do + out = %{out | names: Enum.map(out.names, fn _ -> nil end)} + tensor(Nx.fill(out, number, type: out.type)) + end + defp constant(%{shape: shape, type: type} = out, number) do number = cond do diff --git a/nx/lib/nx/defn/grad.ex b/nx/lib/nx/defn/grad.ex index b7bbc8867fe..60866f0cc9b 100644 --- a/nx/lib/nx/defn/grad.ex +++ b/nx/lib/nx/defn/grad.ex @@ -34,16 +34,10 @@ defmodule Nx.Defn.Grad do {expr, graded} end - defp constant(float, shape) do - case shape do - %T{vectorized_axes: [_ | _]} = t -> - Expr.tensor(Nx.fill(t, float, type: :f32)) - - t -> - shape = Nx.shape(t) - names = List.duplicate(nil, tuple_size(shape)) - Expr.constant(%T{shape: shape, type: {:f, 32}, names: names}, float, []) - end + defp constant(float, %T{shape: shape} = t) do + names = List.duplicate(nil, tuple_size(shape)) + + Expr.constant(%T{t | names: names, type: {:f, 32}}, float, []) end defp validate_expr!(%T{data: %Expr{}} = expr) do @@ -1351,22 +1345,33 @@ defmodule Nx.Defn.Grad do %T{names: names} -> names = Enum.with_index(names, fn name, idx -> if(name, do: {name, idx}) end) - vectorized_axes = + {vectorized_axes, offset} = names - |> Enum.reduce([], fn + |> Enum.reduce({[], 0}, fn nil, acc -> acc - {name, _idx}, acc -> + {name, _idx}, {acc, count} -> if name in names_ans do - [name | acc] + {[name | acc], count + 1} else - acc + {acc, count} end end) - |> Enum.reverse() - Nx.vectorize(arg, vectorized_axes) + axes_names = Enum.reverse(vectorized_axes) + + {vec_shape_list, shape_list} = arg.shape |> Tuple.to_list() |> Enum.split(offset) + + vectorized_axes = + Enum.zip(axes_names, vec_shape_list) + + %{ + arg + | vectorized_axes: vectorized_axes, + names: Enum.drop(arg.names, offset), + shape: List.to_tuple(shape_list) + } arg -> arg From 8f60a7173b140a87d3da4ee63075d10761bfa85b Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Mon, 16 Sep 2024 17:21:22 -0300 Subject: [PATCH 10/22] test: add x * sin(y) grad test --- nx/lib/nx/defn/expr.ex | 1 - nx/lib/nx/defn/grad.ex | 1 - nx/test/nx/defn/grad_test.exs | 11 +++++++++++ 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/nx/lib/nx/defn/expr.ex b/nx/lib/nx/defn/expr.ex index 666d6280742..d472834578b 100644 --- a/nx/lib/nx/defn/expr.ex +++ b/nx/lib/nx/defn/expr.ex @@ -1395,7 +1395,6 @@ defmodule Nx.Defn.Expr do ## Constant helpers and related optimizations defp constant(%{vectorized_axes: [_ | _]} = out, number) do - out = %{out | names: Enum.map(out.names, fn _ -> nil end)} tensor(Nx.fill(out, number, type: out.type)) end diff --git a/nx/lib/nx/defn/grad.ex b/nx/lib/nx/defn/grad.ex index 60866f0cc9b..dba6071616d 100644 --- a/nx/lib/nx/defn/grad.ex +++ b/nx/lib/nx/defn/grad.ex @@ -36,7 +36,6 @@ defmodule Nx.Defn.Grad do defp constant(float, %T{shape: shape} = t) do names = 
List.duplicate(nil, tuple_size(shape)) - Expr.constant(%T{t | names: names, type: {:f, 32}}, float, []) end diff --git a/nx/test/nx/defn/grad_test.exs b/nx/test/nx/defn/grad_test.exs index c7c3bca55e1..508379ff976 100644 --- a/nx/test/nx/defn/grad_test.exs +++ b/nx/test/nx/defn/grad_test.exs @@ -4247,6 +4247,17 @@ defmodule Nx.Defn.GradTest do assert grad == Nx.tensor([3.0, 3.0]) |> Nx.vectorize([:x]) end + test "supports combination of vectorized and non-vectorized tensors over composed function" do + x = Nx.tensor([[1, 2, 3], [4, 5, 6]]) |> Nx.vectorize(:x) + y = 1 + + grad = Nx.Defn.grad(y, fn y -> Nx.add(y, Nx.sin(x)) end) + assert grad == Nx.tensor([3.0, 3.0]) |> Nx.vectorize([:x]) + + grad = Nx.Defn.grad(x, fn x -> Nx.add(y, Nx.sin(x)) end) + assert grad == Nx.cos(x) + end + test "supports heterogenous vectorization combinations" do x = Nx.tensor([[1, 2, 3], [4, 5, 6]]) y = Nx.tensor([10, 20]) From a026c0ebbb5fdea5500eeaefe4d6bdb3252796fd Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Tue, 24 Sep 2024 03:46:15 -0300 Subject: [PATCH 11/22] feat: revectorize args only when strictly necessary --- nx/lib/nx/defn/grad.ex | 30 ++++++++++++++++++++---------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/nx/lib/nx/defn/grad.ex b/nx/lib/nx/defn/grad.ex index dba6071616d..c49b2428a4f 100644 --- a/nx/lib/nx/defn/grad.ex +++ b/nx/lib/nx/defn/grad.ex @@ -326,7 +326,7 @@ defmodule Nx.Defn.Grad do @verify_grad Application.compile_env(:nx, :verify_grad, false) defp update_grads(op, args, ans, g, _to_grad_ids, grads) do - args = revectorize_args(args, ans) + args = revectorize_args(args, ans, g) pairs = grad(op, args, ans, g) @@ -1333,20 +1333,30 @@ defmodule Nx.Defn.Grad do ## General helpers - defp revectorize_args(args, ans) do - names_ans = - Enum.with_index(Keyword.keys(ans.vectorized_axes) ++ ans.names, fn name, idx -> - if(name, do: name, else: {:ans, idx}) - end) + defp revectorize_args(args, [ans | _], g) do + revectorize_args(args, ans, g) + end + + defp revectorize_args(args, %{} = ans, [g | _]) do + revectorize_args(args, ans, g) + end + + defp revectorize_args(args, %{vectorized_axes: []}, %{vectorized_axes: []}) do + args + end + + defp revectorize_args(args, ans, g) do + names_ans = Keyword.keys(ans.vectorized_axes) ++ Keyword.keys(g.vectorized_axes) ++ ans.names + + names_ans = names_ans |> Enum.filter(& &1) |> MapSet.new() for arg <- args do - case arg do + case Nx.devectorize(arg) do %T{names: names} -> names = Enum.with_index(names, fn name, idx -> if(name, do: {name, idx}) end) {vectorized_axes, offset} = - names - |> Enum.reduce({[], 0}, fn + Enum.reduce(names, {[], 0}, fn nil, acc -> acc @@ -1372,7 +1382,7 @@ defmodule Nx.Defn.Grad do shape: List.to_tuple(shape_list) } - arg -> + _arg -> arg end end From 37408b316401de93248cd9763bca2941bf01dc8a Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Tue, 24 Sep 2024 03:49:28 -0300 Subject: [PATCH 12/22] fix: correctness of revectorize_args over possible kw_args functions --- nx/lib/nx/defn/grad.ex | 56 ++++++++++++++++++++++-------------------- 1 file changed, 29 insertions(+), 27 deletions(-) diff --git a/nx/lib/nx/defn/grad.ex b/nx/lib/nx/defn/grad.ex index c49b2428a4f..293a39123f5 100644 --- a/nx/lib/nx/defn/grad.ex +++ b/nx/lib/nx/defn/grad.ex @@ -4,6 +4,8 @@ defmodule Nx.Defn.Grad do alias Nx.Defn.{Composite, Expr, Tree} alias Nx.Tensor, as: T + require Nx + def transform(to_grad, fun, transform) do {to_grad, ids} = 
Composite.traverse(to_grad, %{}, fn node, ids -> @@ -1350,42 +1352,42 @@ defmodule Nx.Defn.Grad do names_ans = names_ans |> Enum.filter(& &1) |> MapSet.new() - for arg <- args do - case Nx.devectorize(arg) do - %T{names: names} -> - names = Enum.with_index(names, fn name, idx -> if(name, do: {name, idx}) end) + {tensor_args, kw_args} = Enum.split_while(args, &Nx.is_tensor/1) - {vectorized_axes, offset} = - Enum.reduce(names, {[], 0}, fn - nil, acc -> - acc + revec_tensor_args = + for arg <- tensor_args do + %T{names: names} = arg = Nx.devectorize(arg) + names = Enum.with_index(names, fn name, idx -> if(name, do: {name, idx}) end) - {name, _idx}, {acc, count} -> - if name in names_ans do - {[name | acc], count + 1} - else - {acc, count} - end - end) + {vectorized_axes, offset} = + Enum.reduce(names, {[], 0}, fn + nil, acc -> + acc - axes_names = Enum.reverse(vectorized_axes) + {name, _idx}, {acc, count} -> + if name in names_ans do + {[name | acc], count + 1} + else + {acc, count} + end + end) - {vec_shape_list, shape_list} = arg.shape |> Tuple.to_list() |> Enum.split(offset) + axes_names = Enum.reverse(vectorized_axes) - vectorized_axes = - Enum.zip(axes_names, vec_shape_list) + {vec_shape_list, shape_list} = arg.shape |> Tuple.to_list() |> Enum.split(offset) - %{ - arg - | vectorized_axes: vectorized_axes, - names: Enum.drop(arg.names, offset), - shape: List.to_tuple(shape_list) - } + vectorized_axes = + Enum.zip(axes_names, vec_shape_list) - _arg -> + %{ arg + | vectorized_axes: vectorized_axes, + names: Enum.drop(arg.names, offset), + shape: List.to_tuple(shape_list) + } end - end + + revec_tensor_args ++ kw_args end defp unbroadcast(x, res, ans) do From 380c3303feb3ce66809dc9a453454a0b96669fbe Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Tue, 24 Sep 2024 13:07:14 -0300 Subject: [PATCH 13/22] refactor: simpler revectorization of nodes --- nx/lib/nx/defn/grad.ex | 74 +++++++++--------------------------------- 1 file changed, 16 insertions(+), 58 deletions(-) diff --git a/nx/lib/nx/defn/grad.ex b/nx/lib/nx/defn/grad.ex index 293a39123f5..c47290ab89a 100644 --- a/nx/lib/nx/defn/grad.ex +++ b/nx/lib/nx/defn/grad.ex @@ -36,6 +36,16 @@ defmodule Nx.Defn.Grad do {expr, graded} end + defp revectorize_node(%{vectorized_axes: vectorized_axes, names: names} = node) do + vec_names = Enum.take_while(names, &(not is_nil(&1))) + + node + |> Nx.devectorize() + |> Nx.vectorize(vectorized_axes ++ vec_names) + end + + defp revectorize_node(arg), do: arg + defp constant(float, %T{shape: shape} = t) do names = List.duplicate(nil, tuple_size(shape)) Expr.constant(%T{t | names: names, type: {:f, 32}}, float, []) @@ -184,6 +194,9 @@ defmodule Nx.Defn.Grad do %T{data: %Expr{op: op, args: args}} = ans {gs, grads} = Map.pop(grads, id) + args = + Enum.map(args, &revectorize_node/1) + case gs do nil -> {nodes, grads} @@ -328,8 +341,6 @@ defmodule Nx.Defn.Grad do @verify_grad Application.compile_env(:nx, :verify_grad, false) defp update_grads(op, args, ans, g, _to_grad_ids, grads) do - args = revectorize_args(args, ans, g) - pairs = grad(op, args, ans, g) if @verify_grad do @@ -1335,63 +1346,8 @@ defmodule Nx.Defn.Grad do ## General helpers - defp revectorize_args(args, [ans | _], g) do - revectorize_args(args, ans, g) - end - - defp revectorize_args(args, %{} = ans, [g | _]) do - revectorize_args(args, ans, g) - end - - defp revectorize_args(args, %{vectorized_axes: []}, %{vectorized_axes: []}) do - args - end - - defp revectorize_args(args, ans, g) do - 
names_ans = Keyword.keys(ans.vectorized_axes) ++ Keyword.keys(g.vectorized_axes) ++ ans.names - - names_ans = names_ans |> Enum.filter(& &1) |> MapSet.new() - - {tensor_args, kw_args} = Enum.split_while(args, &Nx.is_tensor/1) - - revec_tensor_args = - for arg <- tensor_args do - %T{names: names} = arg = Nx.devectorize(arg) - names = Enum.with_index(names, fn name, idx -> if(name, do: {name, idx}) end) - - {vectorized_axes, offset} = - Enum.reduce(names, {[], 0}, fn - nil, acc -> - acc - - {name, _idx}, {acc, count} -> - if name in names_ans do - {[name | acc], count + 1} - else - {acc, count} - end - end) - - axes_names = Enum.reverse(vectorized_axes) - - {vec_shape_list, shape_list} = arg.shape |> Tuple.to_list() |> Enum.split(offset) - - vectorized_axes = - Enum.zip(axes_names, vec_shape_list) - - %{ - arg - | vectorized_axes: vectorized_axes, - names: Enum.drop(arg.names, offset), - shape: List.to_tuple(shape_list) - } - end - - revec_tensor_args ++ kw_args - end - defp unbroadcast(x, res, ans) do - vectorized_axes_x = Keyword.keys(x.vectorized_axes) + vectorized_axes_x = Keyword.keys(x.vectorized_axes) ++ Enum.filter(x.names, & &1) vectorized_axes_ans = Keyword.keys(ans.vectorized_axes) permutation = @@ -1410,6 +1366,8 @@ defmodule Nx.Defn.Grad do |> Nx.transpose(axes: permutation ++ inner_axes) |> Nx.vectorize(vectorized_axes_x) + x = x |> Nx.devectorize(keep_names: false) |> Nx.vectorize(vectorized_axes_x) + grad_broadcast(x, res, ans) end From 1487b17ab6b40fb4dc8d279fff1509d34fe65940 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Tue, 24 Sep 2024 13:14:36 -0300 Subject: [PATCH 14/22] refactor: revectorize in place --- nx/lib/nx/defn/grad.ex | 32 ++++++++++++++++++++++---------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/nx/lib/nx/defn/grad.ex b/nx/lib/nx/defn/grad.ex index c47290ab89a..c0965438f2d 100644 --- a/nx/lib/nx/defn/grad.ex +++ b/nx/lib/nx/defn/grad.ex @@ -36,16 +36,6 @@ defmodule Nx.Defn.Grad do {expr, graded} end - defp revectorize_node(%{vectorized_axes: vectorized_axes, names: names} = node) do - vec_names = Enum.take_while(names, &(not is_nil(&1))) - - node - |> Nx.devectorize() - |> Nx.vectorize(vectorized_axes ++ vec_names) - end - - defp revectorize_node(arg), do: arg - defp constant(float, %T{shape: shape} = t) do names = List.duplicate(nil, tuple_size(shape)) Expr.constant(%T{t | names: names, type: {:f, 32}}, float, []) @@ -215,6 +205,28 @@ defmodule Nx.Defn.Grad do end end + defp revectorize_node(%{shape: shape, vectorized_axes: vectorized_axes, names: names} = node) do + {reverse_vec_names, all_nil_names, reverse_inner_shape} = + Enum.zip_reduce(names, Tuple.to_list(shape), {[], [], []}, fn + nil, axis_size, {v, n, s} -> + {v, [nil | n], [axis_size | s]} + + name, axis_size, {v, n, s} -> + {[{name, axis_size} | v], n, s} + end) + + inner_shape = reverse_inner_shape |> Enum.reverse() |> List.to_tuple() + + %{ + node + | vectorized_axes: vectorized_axes ++ Enum.reverse(reverse_vec_names), + names: all_nil_names, + shape: inner_shape + } + end + + defp revectorize_node(arg), do: arg + defp update_grads(:elem, [%{type: {:tuple, size}} = tuple, pos], _ans, g, _to_grad_ids, grads) do update_in(grads[tuple.data.id], fn tuple -> tuple = tuple || Tuple.duplicate([], size) From a832956f20ace98278cb001605924aefc6c7c50e Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Tue, 24 Sep 2024 15:06:16 -0300 Subject: [PATCH 15/22] fix: propagate vectorized axes 
throughout the recursion' --- nx/lib/nx.ex | 9 +- nx/lib/nx/defn/grad.ex | 164 ++++++++++++++++++++++++---------- nx/test/nx/defn/grad_test.exs | 2 +- 3 files changed, 124 insertions(+), 51 deletions(-) diff --git a/nx/lib/nx.ex b/nx/lib/nx.ex index beaa293d770..ed070ef4b06 100644 --- a/nx/lib/nx.ex +++ b/nx/lib/nx.ex @@ -4906,12 +4906,19 @@ defmodule Nx do def devectorize(%T{shape: shape, names: names, vectorized_axes: vectorized_axes} = tensor, opts) when vectorized_axes != [] do - opts = keyword!(opts, keep_names: true) + opts = keyword!(opts, keep_names: true, drop_inner_names: false) {vectorized_names, vectorized_sizes} = Enum.unzip(vectorized_axes) output_shape_l = vectorized_sizes ++ Tuple.to_list(shape) output_shape = List.to_tuple(output_shape_l) + names = + if opts[:drop_inner_names] do + Enum.map(names, fn _ -> nil end) + else + names + end + output_names = if opts[:keep_names] do vectorized_names ++ names diff --git a/nx/lib/nx/defn/grad.ex b/nx/lib/nx/defn/grad.ex index c0965438f2d..358b84f38d8 100644 --- a/nx/lib/nx/defn/grad.ex +++ b/nx/lib/nx/defn/grad.ex @@ -83,47 +83,93 @@ defmodule Nx.Defn.Grad do [:equal, :greater, :greater_equal, :less, :less_equal, :not_equal, :argsort] defp parents_tree(expr, nodes) do - Composite.reduce(expr, {%{}, nodes}, &recur_parents_tree/2) + Composite.reduce( + expr, + {%{}, nodes}, + &recur_parents_tree(Nx.devectorize(&1, keep_names: true), &2, &1.vectorized_axes) + ) end - defp recur_parents_tree(%T{data: %Expr{id: id, op: op}} = t, {parents, nodes}) do + defp recur_parents_tree(%T{data: %Expr{id: id, op: op}} = t, {parents, nodes}, vectorized_axes) do case nodes do - %{^id => _} -> {parents, nodes} - %{} -> parents_args(op, t, id, {parents, Map.put(nodes, id, t)}) + %{^id => _} -> + {parents, nodes} + + %{} -> + parent_vectorized_axes = compute_arg_vectorized_axes(t, vectorized_axes) + + nodes = Map.put(nodes, id, {Nx.devectorize(t, keep_names: true), parent_vectorized_axes}) + + parents_args(op, t, id, {parents, nodes}, vectorized_axes) end end - defp parents_args(:metadata, %{data: %{args: [_, %{stop_grad: true}]}}, _id, acc) do + defp parents_args( + :metadata, + %{data: %{args: [_, %{stop_grad: true}]}}, + _id, + acc, + _parent_vectorized_axes + ) do acc end - defp parents_args(:optional, %{data: %{args: [call, _expr, callback]}} = t, id, acc) do + defp parents_args( + :optional, + %{data: %{args: [call, _expr, callback]}} = t, + id, + acc, + parent_vectorized_axes + ) do expr = apply(callback, call.data.args) # Now traverse over the optional expression where args are the new parameters. # Once we access the parameter itself, we point the parameter to the arg. 
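    # (The vectorized axes seen at the call site are threaded through the
    # traversal below, so that the parameters introduced by the optional
    # callback are revectorized under the same axis names as the caller.)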
- {parents, nodes} = - Composite.reduce(expr, acc, fn expr, {parents, nodes} -> - parents = Map.update(parents, expr.data.id, [id], &[id | &1]) - recur_parents_tree(expr, {parents, nodes}) - end) + {{parents, nodes}, _} = + Composite.reduce( + Nx.devectorize(expr, keep_names: true), + {acc, parent_vectorized_axes}, + fn expr, {{parents, nodes}, expr_vectorized_axes} -> + arg_vectorized_axes = compute_arg_vectorized_axes(expr, expr_vectorized_axes) + parents = Map.update(parents, expr.data.id, [id], &[id | &1]) + + acc = + recur_parents_tree( + expr, + {parents, nodes}, + arg_vectorized_axes + ) + + {acc, expr_vectorized_axes} + end + ) + + updated_node = + {put_in(t.data.args, [call, expr, callback]) |> Nx.devectorize(keep_names: true), + parent_vectorized_axes} - {parents, Map.put(nodes, id, put_in(t.data.args, [call, expr, callback]))} + {parents, Map.put(nodes, id, updated_node)} end # We register cond as a special node to avoid pretraversing it. # Instead we traverse it early on on the grad computation. - defp parents_args(:cond, _, id, {parents, nodes}) do + defp parents_args(:cond, _, id, {parents, nodes}, _parent_vectorized_axes) do {Map.update(parents, __MODULE__, [id], &[id | &1]), nodes} end - defp parents_args(op, t, parent_id, acc) do + defp parents_args(op, t, parent_id, acc, parent_vectorized_axes) do reduce_args(op, t, acc, fn arg, {parents, nodes} -> if arg.data.op in @constants do {parents, nodes} else + arg_vectorized_axes = compute_arg_vectorized_axes(t, parent_vectorized_axes) parents = Map.update(parents, arg.data.id, [parent_id], &[parent_id | &1]) - recur_parents_tree(arg, {parents, nodes}) + + recur_parents_tree( + Nx.devectorize(arg, keep_names: true), + {parents, nodes}, + arg_vectorized_axes + ) end end) end @@ -180,12 +226,26 @@ defmodule Nx.Defn.Grad do case nodes do %{^id => _} -> {nodes, grads} = traverse_parents(id, to_grad_ids, parents, {nodes, grads}) - {ans, nodes} = Map.pop!(nodes, id) + {{ans, vectorized_axes}, nodes} = Map.pop!(nodes, id) %T{data: %Expr{op: op, args: args}} = ans {gs, grads} = Map.pop(grads, id) - args = - Enum.map(args, &revectorize_node/1) + {args, ans} = + if vectorized_axes != [] do + args = + Enum.map(args, fn + arg when Nx.is_tensor(arg) -> + revectorize_node(arg, vectorized_axes) + + opt -> + opt + end) + + ans = Nx.vectorize(ans, vectorized_axes) + {args, ans} + else + {args, ans} + end case gs do nil -> @@ -205,27 +265,33 @@ defmodule Nx.Defn.Grad do end end - defp revectorize_node(%{shape: shape, vectorized_axes: vectorized_axes, names: names} = node) do - {reverse_vec_names, all_nil_names, reverse_inner_shape} = - Enum.zip_reduce(names, Tuple.to_list(shape), {[], [], []}, fn - nil, axis_size, {v, n, s} -> - {v, [nil | n], [axis_size | s]} + defp compute_arg_vectorized_axes(%{vectorized_axes: vectorized_axes}, []), do: vectorized_axes - name, axis_size, {v, n, s} -> - {[{name, axis_size} | v], n, s} + defp compute_arg_vectorized_axes( + %{vectorized_axes: vectorized_axes, names: names, shape: shape}, + parent_vectorized_axes + ) do + parent_names = Keyword.keys(parent_vectorized_axes) + + reversed_inner_axes = + Enum.zip_reduce(names, Tuple.to_list(shape), [], fn name, axis_size, acc -> + if name in parent_names do + [{name, axis_size} | acc] + else + acc + end end) - inner_shape = reverse_inner_shape |> Enum.reverse() |> List.to_tuple() - - %{ - node - | vectorized_axes: vectorized_axes ++ Enum.reverse(reverse_vec_names), - names: all_nil_names, - shape: inner_shape - } + vectorized_axes ++ Enum.reverse(reversed_inner_axes) 
end - defp revectorize_node(arg), do: arg + defp revectorize_node(node, vectorized_axes) do + vectorized_axes = compute_arg_vectorized_axes(node, vectorized_axes) + + node + |> Nx.devectorize(keep_names: false) + |> Nx.vectorize(vectorized_axes) + end defp update_grads(:elem, [%{type: {:tuple, size}} = tuple, pos], _ans, g, _to_grad_ids, grads) do update_in(grads[tuple.data.id], fn tuple -> @@ -1359,26 +1425,26 @@ defmodule Nx.Defn.Grad do ## General helpers defp unbroadcast(x, res, ans) do - vectorized_axes_x = Keyword.keys(x.vectorized_axes) ++ Enum.filter(x.names, & &1) - vectorized_axes_ans = Keyword.keys(ans.vectorized_axes) + # vectorized_axes_x = Keyword.keys(x.vectorized_axes) ++ Enum.filter(x.names, & &1) + # vectorized_axes_ans = Keyword.keys(ans.vectorized_axes) - permutation = - vectorized_axes_ans - |> Enum.with_index() - |> Enum.sort_by(fn {axis, _idx} -> axis in vectorized_axes_x end) - |> Enum.map(fn {_, idx} -> idx end) + # permutation = + # vectorized_axes_ans + # |> Enum.with_index() + # |> Enum.sort_by(fn {axis, _idx} -> axis in vectorized_axes_x end) + # |> Enum.map(fn {_, idx} -> idx end) - num_vectorized_axes = length(permutation) + # num_vectorized_axes = length(permutation) - inner_axes = Enum.to_list(num_vectorized_axes..(num_vectorized_axes + Nx.rank(res) - 1)//1) + # inner_axes = Enum.to_list(num_vectorized_axes..(num_vectorized_axes + Nx.rank(res) - 1)//1) - res = - res - |> Nx.devectorize() - |> Nx.transpose(axes: permutation ++ inner_axes) - |> Nx.vectorize(vectorized_axes_x) + # res = + # res + # |> Nx.devectorize() + # |> Nx.transpose(axes: permutation ++ inner_axes) + # |> Nx.vectorize(vectorized_axes_x) - x = x |> Nx.devectorize(keep_names: false) |> Nx.vectorize(vectorized_axes_x) + # x = x |> Nx.devectorize(keep_names: false) |> Nx.vectorize(vectorized_axes_x) grad_broadcast(x, res, ans) end diff --git a/nx/test/nx/defn/grad_test.exs b/nx/test/nx/defn/grad_test.exs index 508379ff976..52e1926735f 100644 --- a/nx/test/nx/defn/grad_test.exs +++ b/nx/test/nx/defn/grad_test.exs @@ -4248,7 +4248,7 @@ defmodule Nx.Defn.GradTest do end test "supports combination of vectorized and non-vectorized tensors over composed function" do - x = Nx.tensor([[1, 2, 3], [4, 5, 6]]) |> Nx.vectorize(:x) + x = Nx.tensor([[1, 2, 3], [4, 5, 6]], names: [:x, :y]) |> Nx.vectorize(:x) y = 1 grad = Nx.Defn.grad(y, fn y -> Nx.add(y, Nx.sin(x)) end) From 24b9ea59a2b76ccd60f9f5143ea09d1e90c06e6e Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Tue, 24 Sep 2024 15:27:45 -0300 Subject: [PATCH 16/22] chore: revert some code due to code review --- nx/lib/nx.ex | 9 +-------- nx/lib/nx/defn/expr.ex | 14 +++++++------- nx/lib/nx/defn/grad.ex | 18 +++++++++++++----- nx/test/nx/defn/grad_test.exs | 16 ++++++++++++++++ 4 files changed, 37 insertions(+), 20 deletions(-) diff --git a/nx/lib/nx.ex b/nx/lib/nx.ex index ed070ef4b06..beaa293d770 100644 --- a/nx/lib/nx.ex +++ b/nx/lib/nx.ex @@ -4906,19 +4906,12 @@ defmodule Nx do def devectorize(%T{shape: shape, names: names, vectorized_axes: vectorized_axes} = tensor, opts) when vectorized_axes != [] do - opts = keyword!(opts, keep_names: true, drop_inner_names: false) + opts = keyword!(opts, keep_names: true) {vectorized_names, vectorized_sizes} = Enum.unzip(vectorized_axes) output_shape_l = vectorized_sizes ++ Tuple.to_list(shape) output_shape = List.to_tuple(output_shape_l) - names = - if opts[:drop_inner_names] do - Enum.map(names, fn _ -> nil end) - else - names - end - output_names = if 
opts[:keep_names] do vectorized_names ++ names diff --git a/nx/lib/nx/defn/expr.ex b/nx/lib/nx/defn/expr.ex index d472834578b..997950a1f3c 100644 --- a/nx/lib/nx/defn/expr.ex +++ b/nx/lib/nx/defn/expr.ex @@ -94,13 +94,13 @@ defmodule Nx.Defn.Expr do def metadata(expr, metadata) when is_map(metadata) do case to_container_expr(expr) do %{data: %{context: context}} = res -> - expr(res, context, :metadata, [expr, metadata]) + expr(res, context, :metadata, [Nx.devectorize(expr), metadata]) t when is_tuple(t) -> context = elem(t, 0).data.context tuple( - expr(tuple_out(tuple_size(t)), context, :metadata, [expr, metadata]), + expr(tuple_out(tuple_size(t)), context, :metadata, [Nx.devectorize(expr), metadata]), Tuple.to_list(t) ) end @@ -1665,11 +1665,11 @@ defmodule Nx.Defn.Expr do defp counter_to_name(counter), do: [?a + counter] - defp to_type_shape(%{vectorized_axes: vectorized_axes, type: type, shape: shape}) do - axes = - Keyword.values(vectorized_axes) ++ Tuple.to_list(shape) - - brackets = Enum.map(axes, &[?[, Integer.to_string(&1), ?]]) + defp to_type_shape(%{vectorized_axes: [], type: type, shape: shape}) do + brackets = + shape + |> Tuple.to_list() + |> Enum.map(&[?[, Integer.to_string(&1), ?]]) IO.iodata_to_binary([Nx.Type.to_string(type) | brackets]) end diff --git a/nx/lib/nx/defn/grad.ex b/nx/lib/nx/defn/grad.ex index 358b84f38d8..363afb014bc 100644 --- a/nx/lib/nx/defn/grad.ex +++ b/nx/lib/nx/defn/grad.ex @@ -8,10 +8,11 @@ defmodule Nx.Defn.Grad do def transform(to_grad, fun, transform) do {to_grad, ids} = - Composite.traverse(to_grad, %{}, fn node, ids -> - node = Expr.metadata(node, %{__MODULE__ => :to_grad}) - ids = Map.put(ids, node.data.id, :stop) - {node, ids} + Composite.traverse(to_grad, %{}, fn to_grad, ids -> + to_grad = + Expr.metadata(to_grad, %{__MODULE__ => :to_grad}) + + {to_grad, Map.put(ids, to_grad.data.id, :stop)} end) # Collect all IDs in the function environment and mark @@ -31,7 +32,13 @@ defmodule Nx.Defn.Grad do grads = %{transformed_expr.data.id => [constant(1.0, transformed_expr)]} {graded, _} = - Composite.traverse(to_grad, {nodes, grads}, &to_grad(&1, to_grad_ids, parents, &2)) + Composite.traverse( + to_grad, + {nodes, grads}, + fn %{vectorized_axes: vectorized_axes} = node, acc -> + to_grad(node, to_grad_ids, parents, acc) + end + ) {expr, graded} end @@ -96,6 +103,7 @@ defmodule Nx.Defn.Grad do {parents, nodes} %{} -> + # We use this to compute the proper axis sizes for the tensor parent_vectorized_axes = compute_arg_vectorized_axes(t, vectorized_axes) nodes = Map.put(nodes, id, {Nx.devectorize(t, keep_names: true), parent_vectorized_axes}) diff --git a/nx/test/nx/defn/grad_test.exs b/nx/test/nx/defn/grad_test.exs index 52e1926735f..7cd9824cb67 100644 --- a/nx/test/nx/defn/grad_test.exs +++ b/nx/test/nx/defn/grad_test.exs @@ -4258,6 +4258,22 @@ defmodule Nx.Defn.GradTest do assert grad == Nx.cos(x) end + # Skipping this as it's not supported yet. 
+ @tag :skip + test "edge case where the same name changes meaning" do + x = Nx.tensor([[1], [2], [3]]) |> Nx.vectorize(x: 3) + + grad = + Nx.Defn.grad(x, fn t -> + devec = Nx.devectorize(t, keep_names: true) + new_axis = Nx.reshape(devec, {1, 3, 1}, names: [:x, nil, nil]) + + Nx.vectorize(new_axis, x: 1) + end) + + assert grad == Nx.tensor([[1], [1], [1]]) |> Nx.vectorize(x: 3) + end + test "supports heterogenous vectorization combinations" do x = Nx.tensor([[1, 2, 3], [4, 5, 6]]) y = Nx.tensor([10, 20]) From d0f93c9efe58d0d383aa68667a20b3c13aad0e3d Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Tue, 24 Sep 2024 15:28:39 -0300 Subject: [PATCH 17/22] chore: revert unbroadcast --- nx/lib/nx/defn/grad.ex | 29 ++--------------------------- 1 file changed, 2 insertions(+), 27 deletions(-) diff --git a/nx/lib/nx/defn/grad.ex b/nx/lib/nx/defn/grad.ex index 363afb014bc..07f4acdd557 100644 --- a/nx/lib/nx/defn/grad.ex +++ b/nx/lib/nx/defn/grad.ex @@ -1432,34 +1432,9 @@ defmodule Nx.Defn.Grad do ## General helpers - defp unbroadcast(x, res, ans) do - # vectorized_axes_x = Keyword.keys(x.vectorized_axes) ++ Enum.filter(x.names, & &1) - # vectorized_axes_ans = Keyword.keys(ans.vectorized_axes) + defp unbroadcast(%{shape: shape} = x, res, %{shape: shape}), do: {x, res} - # permutation = - # vectorized_axes_ans - # |> Enum.with_index() - # |> Enum.sort_by(fn {axis, _idx} -> axis in vectorized_axes_x end) - # |> Enum.map(fn {_, idx} -> idx end) - - # num_vectorized_axes = length(permutation) - - # inner_axes = Enum.to_list(num_vectorized_axes..(num_vectorized_axes + Nx.rank(res) - 1)//1) - - # res = - # res - # |> Nx.devectorize() - # |> Nx.transpose(axes: permutation ++ inner_axes) - # |> Nx.vectorize(vectorized_axes_x) - - # x = x |> Nx.devectorize(keep_names: false) |> Nx.vectorize(vectorized_axes_x) - - grad_broadcast(x, res, ans) - end - - defp grad_broadcast(%{shape: shape} = x, res, %{shape: shape}), do: {x, res} - - defp grad_broadcast(%{shape: shape} = x, res, %{shape: new_shape}) do + defp unbroadcast(%{shape: shape} = x, res, %{shape: new_shape}) do axes = Nx.Shape.broadcast_axes(shape, new_shape) {x, grad_broadcast(x, new_shape, axes, res)} end From b89111d5f6bd5912039316be5e8c903336235890 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Tue, 24 Sep 2024 15:33:28 -0300 Subject: [PATCH 18/22] chore: remove some devectorization occurences --- nx/lib/nx/defn/grad.ex | 20 ++++++-------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/nx/lib/nx/defn/grad.ex b/nx/lib/nx/defn/grad.ex index 07f4acdd557..48a2f3350f7 100644 --- a/nx/lib/nx/defn/grad.ex +++ b/nx/lib/nx/defn/grad.ex @@ -106,7 +106,7 @@ defmodule Nx.Defn.Grad do # We use this to compute the proper axis sizes for the tensor parent_vectorized_axes = compute_arg_vectorized_axes(t, vectorized_axes) - nodes = Map.put(nodes, id, {Nx.devectorize(t, keep_names: true), parent_vectorized_axes}) + nodes = Map.put(nodes, id, {t, parent_vectorized_axes}) parents_args(op, t, id, {parents, nodes}, vectorized_axes) end @@ -134,10 +134,8 @@ defmodule Nx.Defn.Grad do # Now traverse over the optional expression where args are the new parameters. # Once we access the parameter itself, we point the parameter to the arg. 
{{parents, nodes}, _} = - Composite.reduce( - Nx.devectorize(expr, keep_names: true), - {acc, parent_vectorized_axes}, - fn expr, {{parents, nodes}, expr_vectorized_axes} -> + Composite.reduce(expr, {acc, parent_vectorized_axes}, fn + expr, {{parents, nodes}, expr_vectorized_axes} -> arg_vectorized_axes = compute_arg_vectorized_axes(expr, expr_vectorized_axes) parents = Map.update(parents, expr.data.id, [id], &[id | &1]) @@ -149,12 +147,10 @@ defmodule Nx.Defn.Grad do ) {acc, expr_vectorized_axes} - end - ) + end) updated_node = - {put_in(t.data.args, [call, expr, callback]) |> Nx.devectorize(keep_names: true), - parent_vectorized_axes} + {put_in(t.data.args, [call, expr, callback]), parent_vectorized_axes} {parents, Map.put(nodes, id, updated_node)} end @@ -173,11 +169,7 @@ defmodule Nx.Defn.Grad do arg_vectorized_axes = compute_arg_vectorized_axes(t, parent_vectorized_axes) parents = Map.update(parents, arg.data.id, [parent_id], &[parent_id | &1]) - recur_parents_tree( - Nx.devectorize(arg, keep_names: true), - {parents, nodes}, - arg_vectorized_axes - ) + recur_parents_tree(arg, {parents, nodes}, arg_vectorized_axes) end end) end From 9075ef0e3032f216c5bf594f13e8cc3149ffe578 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Tue, 24 Sep 2024 15:49:01 -0300 Subject: [PATCH 19/22] chore: simplify vectorized axes calculation --- nx/lib/nx/defn/grad.ex | 75 +++++++++++++++++++++--------------------- 1 file changed, 37 insertions(+), 38 deletions(-) diff --git a/nx/lib/nx/defn/grad.ex b/nx/lib/nx/defn/grad.ex index 48a2f3350f7..7177db183e3 100644 --- a/nx/lib/nx/defn/grad.ex +++ b/nx/lib/nx/defn/grad.ex @@ -4,8 +4,6 @@ defmodule Nx.Defn.Grad do alias Nx.Defn.{Composite, Expr, Tree} alias Nx.Tensor, as: T - require Nx - def transform(to_grad, fun, transform) do {to_grad, ids} = Composite.traverse(to_grad, %{}, fn to_grad, ids -> @@ -35,7 +33,7 @@ defmodule Nx.Defn.Grad do Composite.traverse( to_grad, {nodes, grads}, - fn %{vectorized_axes: vectorized_axes} = node, acc -> + fn node, acc -> to_grad(node, to_grad_ids, parents, acc) end ) @@ -93,22 +91,24 @@ defmodule Nx.Defn.Grad do Composite.reduce( expr, {%{}, nodes}, - &recur_parents_tree(Nx.devectorize(&1, keep_names: true), &2, &1.vectorized_axes) + &recur_parents_tree( + Nx.devectorize(&1, keep_names: true), + &2, + Keyword.keys(&1.vectorized_axes) + ) ) end - defp recur_parents_tree(%T{data: %Expr{id: id, op: op}} = t, {parents, nodes}, vectorized_axes) do + defp recur_parents_tree(%T{data: %Expr{id: id, op: op}} = t, {parents, nodes}, vectorized_names) do case nodes do %{^id => _} -> {parents, nodes} %{} -> # We use this to compute the proper axis sizes for the tensor - parent_vectorized_axes = compute_arg_vectorized_axes(t, vectorized_axes) + nodes = Map.put(nodes, id, {t, vectorized_names}) - nodes = Map.put(nodes, id, {t, parent_vectorized_axes}) - - parents_args(op, t, id, {parents, nodes}, vectorized_axes) + parents_args(op, t, id, {parents, nodes}, vectorized_names) end end @@ -117,7 +117,7 @@ defmodule Nx.Defn.Grad do %{data: %{args: [_, %{stop_grad: true}]}}, _id, acc, - _parent_vectorized_axes + _parent_vectorized_names ) do acc end @@ -127,49 +127,49 @@ defmodule Nx.Defn.Grad do %{data: %{args: [call, _expr, callback]}} = t, id, acc, - parent_vectorized_axes + parent_vectorized_names ) do expr = apply(callback, call.data.args) # Now traverse over the optional expression where args are the new parameters. # Once we access the parameter itself, we point the parameter to the arg. 
{{parents, nodes}, _} = - Composite.reduce(expr, {acc, parent_vectorized_axes}, fn - expr, {{parents, nodes}, expr_vectorized_axes} -> - arg_vectorized_axes = compute_arg_vectorized_axes(expr, expr_vectorized_axes) + Composite.reduce(expr, {acc, parent_vectorized_names}, fn + expr, {{parents, nodes}, expr_vectorized_names} -> + arg_vectorized_names = compute_arg_vectorized_names(expr, expr_vectorized_names) parents = Map.update(parents, expr.data.id, [id], &[id | &1]) acc = recur_parents_tree( expr, {parents, nodes}, - arg_vectorized_axes + arg_vectorized_names ) - {acc, expr_vectorized_axes} + {acc, expr_vectorized_names} end) updated_node = - {put_in(t.data.args, [call, expr, callback]), parent_vectorized_axes} + {put_in(t.data.args, [call, expr, callback]), parent_vectorized_names} {parents, Map.put(nodes, id, updated_node)} end # We register cond as a special node to avoid pretraversing it. # Instead we traverse it early on on the grad computation. - defp parents_args(:cond, _, id, {parents, nodes}, _parent_vectorized_axes) do + defp parents_args(:cond, _, id, {parents, nodes}, _parent_vectorized_names) do {Map.update(parents, __MODULE__, [id], &[id | &1]), nodes} end - defp parents_args(op, t, parent_id, acc, parent_vectorized_axes) do + defp parents_args(op, t, parent_id, acc, parent_vectorized_names) do reduce_args(op, t, acc, fn arg, {parents, nodes} -> if arg.data.op in @constants do {parents, nodes} else - arg_vectorized_axes = compute_arg_vectorized_axes(t, parent_vectorized_axes) + arg_vectorized_names = compute_arg_vectorized_names(t, parent_vectorized_names) parents = Map.update(parents, arg.data.id, [parent_id], &[parent_id | &1]) - recur_parents_tree(arg, {parents, nodes}, arg_vectorized_axes) + recur_parents_tree(arg, {parents, nodes}, arg_vectorized_names) end end) end @@ -226,22 +226,22 @@ defmodule Nx.Defn.Grad do case nodes do %{^id => _} -> {nodes, grads} = traverse_parents(id, to_grad_ids, parents, {nodes, grads}) - {{ans, vectorized_axes}, nodes} = Map.pop!(nodes, id) + {{ans, vectorized_names}, nodes} = Map.pop!(nodes, id) %T{data: %Expr{op: op, args: args}} = ans {gs, grads} = Map.pop(grads, id) {args, ans} = - if vectorized_axes != [] do + if vectorized_names != [] do args = Enum.map(args, fn - arg when Nx.is_tensor(arg) -> - revectorize_node(arg, vectorized_axes) + %T{} = arg -> + revectorize_node(arg, vectorized_names) opt -> opt end) - ans = Nx.vectorize(ans, vectorized_axes) + ans = Nx.vectorize(ans, vectorized_names) {args, ans} else {args, ans} @@ -265,32 +265,31 @@ defmodule Nx.Defn.Grad do end end - defp compute_arg_vectorized_axes(%{vectorized_axes: vectorized_axes}, []), do: vectorized_axes + defp compute_arg_vectorized_names(%{vectorized_axes: vectorized_axes}, []), + do: Keyword.keys(vectorized_axes) - defp compute_arg_vectorized_axes( - %{vectorized_axes: vectorized_axes, names: names, shape: shape}, - parent_vectorized_axes + defp compute_arg_vectorized_names( + %{vectorized_axes: vectorized_axes, names: names}, + parent_names ) do - parent_names = Keyword.keys(parent_vectorized_axes) - reversed_inner_axes = - Enum.zip_reduce(names, Tuple.to_list(shape), [], fn name, axis_size, acc -> + Enum.reduce(names, [], fn name, acc -> if name in parent_names do - [{name, axis_size} | acc] + [name | acc] else acc end end) - vectorized_axes ++ Enum.reverse(reversed_inner_axes) + Keyword.keys(vectorized_axes) ++ Enum.reverse(reversed_inner_axes) end - defp revectorize_node(node, vectorized_axes) do - vectorized_axes = compute_arg_vectorized_axes(node, 
vectorized_names)

    node
    |> Nx.devectorize(keep_names: false)
    |> Nx.vectorize(vectorized_names)
  end

  defp update_grads(:elem, [%{type: {:tuple, size}} = tuple, pos], _ans, g, _to_grad_ids, grads) do

From fcc4e104e28b10edd22f83889c217cd8e2e154b2 Mon Sep 17 00:00:00 2001
From: Paulo Valente <16843419+polvalente@users.noreply.github.com>
Date: Tue, 24 Sep 2024 15:51:04 -0300
Subject: [PATCH 20/22] chore: remove another superfluous devectorize

---
 nx/lib/nx/defn/grad.ex | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/nx/lib/nx/defn/grad.ex b/nx/lib/nx/defn/grad.ex
index 7177db183e3..5c7b9beb6db 100644
--- a/nx/lib/nx/defn/grad.ex
+++ b/nx/lib/nx/defn/grad.ex
@@ -287,9 +287,7 @@ defmodule Nx.Defn.Grad do
   defp revectorize_node(node, vectorized_names) do
     vectorized_names = compute_arg_vectorized_names(node, vectorized_names)
 
-    node
-    |> Nx.devectorize(keep_names: false)
-    |> Nx.vectorize(vectorized_names)
+    Nx.vectorize(node, vectorized_names)
   end
 
   defp update_grads(:elem, [%{type: {:tuple, size}} = tuple, pos], _ans, g, _to_grad_ids, grads) do

From add013488c334f90dc7855541f1126904ef044f9 Mon Sep 17 00:00:00 2001
From: Paulo Valente <16843419+polvalente@users.noreply.github.com>
Date: Tue, 24 Sep 2024 17:41:42 -0300
Subject: [PATCH 21/22] Update nx/lib/nx/defn/grad.ex
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: José Valim
---
 nx/lib/nx/defn/grad.ex | 11 +----------
 1 file changed, 1 insertion(+), 10 deletions(-)

diff --git a/nx/lib/nx/defn/grad.ex b/nx/lib/nx/defn/grad.ex
index 14eac0b1df2..fcef50b0d62 100644
--- a/nx/lib/nx/defn/grad.ex
+++ b/nx/lib/nx/defn/grad.ex
@@ -272,16 +272,7 @@ defmodule Nx.Defn.Grad do
          %{vectorized_axes: vectorized_axes, names: names},
          parent_names
        ) do
-    reversed_inner_axes =
-      Enum.reduce(names, [], fn name, acc ->
-        if name in parent_names do
-          [name | acc]
-        else
-          acc
-        end
-      end)
-
-    Keyword.keys(vectorized_axes) ++ Enum.reverse(reversed_inner_axes)
+    Keyword.keys(vectorized_axes) ++ Enum.filter(names, & &1 in parent_names)
   end
 
   defp revectorize_node(node, vectorized_names) do

From 8d94bc0980116e1f789b1fcbd80688b9a26e7a38 Mon Sep 17 00:00:00 2001
From: Paulo Valente <16843419+polvalente@users.noreply.github.com>
Date: Tue, 24 Sep 2024 17:52:23 -0300
Subject: [PATCH 22/22] chore: format

---
 nx/lib/nx/defn/grad.ex | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/nx/lib/nx/defn/grad.ex b/nx/lib/nx/defn/grad.ex
index fcef50b0d62..33a7a0deed1 100644
--- a/nx/lib/nx/defn/grad.ex
+++ b/nx/lib/nx/defn/grad.ex
@@ -272,7 +272,7 @@ defmodule Nx.Defn.Grad do
          %{vectorized_axes: vectorized_axes, names: names},
          parent_names
        ) do
-    Keyword.keys(vectorized_axes) ++ Enum.filter(names, & &1 in parent_names)
+    Keyword.keys(vectorized_axes) ++ Enum.filter(names, &(&1 in parent_names))
   end
 
   defp revectorize_node(node, vectorized_names) do
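
As an end-to-end illustration of what the series enables, here is a usage
sketch (values taken from the tests added in PATCH 01/22 and PATCH 02/22;
it assumes the final state of this branch):

    x = Nx.tensor([[1, 2, 3], [4, 5, 6]]) |> Nx.vectorize(:x)
    y = 1

    # The gradient is computed per vectorized entry: each entry of x has
    # three elements, so the cotangents sum to 3.0 for every entry, and the
    # result keeps x's vectorized axes.
    grad = Nx.Defn.grad(y, fn y -> Nx.add(x, y) end)
    # => Nx.tensor([3.0, 3.0]) |> Nx.vectorize([:x])

    # Distinct vectorized axes combine as a cross product, as in the
    # "supports heterogenous vectorization combinations" test: each grad
    # picks up both axes.
    x_vec = Nx.tensor([[1, 2, 3], [4, 5, 6]]) |> Nx.vectorize(:x)
    y_vec = Nx.tensor([10, 20]) |> Nx.vectorize(:y)

    {grad_x_vec, grad_y_vec} =
      Nx.Defn.grad({x_vec, y_vec}, fn {a, b} -> Nx.multiply(a, b) end)

    # grad_x_vec is vectorized over [:x, :y], each entry filled with the
    # matching value of y; grad_y_vec is vectorized over [:x, :y], each
    # entry equal to the sum of the matching row of x (6.0 or 15.0).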