fix: broadcast vectors for grad calculation (#1535)
Co-authored-by: José Valim <[email protected]>
polvalente and josevalim authored Sep 24, 2024
1 parent 8102cd9 commit 762d3ee
Showing 4 changed files with 189 additions and 37 deletions.
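
Before the per-file diffs, a minimal usage sketch of the behavior this commit fixes, mirroring the new test case added in nx/test/nx/defn/grad_test.exs (values taken from that test): the gradient with respect to a non-vectorized input of a function that also receives a vectorized tensor is now broadcast across the vectorized axes.

# Sketch based on the new test: x is vectorized over :x, y is a plain number.
x = Nx.tensor([[1, 2, 3], [4, 5, 6]]) |> Nx.vectorize(:x)
y = 1

grad = Nx.Defn.grad(y, fn y -> Nx.add(x, y) end)
# grad == Nx.tensor([3.0, 3.0]) |> Nx.vectorize([:x])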
10 changes: 7 additions & 3 deletions nx/lib/nx.ex
@@ -5420,9 +5420,13 @@ defmodule Nx do
{_, [], 0} ->
fun.(left, right)

{[devec_left, devec_right], canonical_vectorized_axes, _offset} ->
devec_left
|> fun.(devec_right)
{[devec_left, devec_right], canonical_vectorized_axes, offset} ->
leading_names = Keyword.keys(canonical_vectorized_axes)
l = %{devec_left | names: leading_names ++ Enum.drop(devec_left.names, offset)}
r = %{devec_right | names: leading_names ++ Enum.drop(devec_right.names, offset)}

l
|> fun.(r)
|> vectorize(canonical_vectorized_axes)
end
end
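
The rewritten clause above gives each devectorized operand the canonical vectorized axis names (followed by its remaining inner names) before fun is applied, so name-based broadcasting of vectorized axes still lines up. A hedged sketch of the user-level behavior this helper backs; the axis names :i and :j are arbitrary choices for illustration:

# Tensors vectorized over different axis names broadcast over the union
# of their vectorized axes when combined.
a = Nx.tensor([1, 2, 3]) |> Nx.vectorize(:i)
b = Nx.tensor([10, 20]) |> Nx.vectorize(:j)

Nx.add(a, b)
# => result vectorized over [i: 3, j: 2], each entry a scalar sum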
6 changes: 5 additions & 1 deletion nx/lib/nx/defn/expr.ex
@@ -1394,6 +1394,10 @@ defmodule Nx.Defn.Expr do

## Constant helpers and related optimizations

defp constant(%{vectorized_axes: [_ | _]} = out, number) do
tensor(Nx.fill(out, number, type: out.type))
end

defp constant(%{shape: shape, type: type} = out, number) do
number =
cond do
@@ -1661,7 +1665,7 @@

defp counter_to_name(counter), do: [?a + counter]

defp to_type_shape(%{type: type, shape: shape}) do
defp to_type_shape(%{vectorized_axes: [], type: type, shape: shape}) do
brackets =
shape
|> Tuple.to_list()
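
The new constant/2 clause above materializes constants for vectorized outputs via Nx.fill rather than emitting a scalar constant expression, so the constant carries the same vectorized axes as out. A short sketch of that building block, assuming (as the clause does) that Nx.fill/3 preserves the input's vectorized axes:

# Filling a vectorized template yields a constant with the template's
# vectorized axes and inner shape.
out = Nx.iota({2, 3}, type: {:f, 32}) |> Nx.vectorize(:batch)
Nx.fill(out, 1, type: out.type)
# => a {3}-shaped tensor of ones, vectorized over [batch: 2]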
124 changes: 96 additions & 28 deletions nx/lib/nx/defn/grad.ex
@@ -19,10 +19,10 @@ defmodule Nx.Defn.Grad do
{:env, env} = Function.info(fun, :env)
ids = stop_grads(env, ids)

# save vectorized axes before devectorizing
expr = to_grad |> fun.()
expr = fun.(to_grad)

transformed_expr = transform.(expr) |> validate_expr!() |> Nx.devectorize(keep_names: false)
transformed_expr =
expr |> transform.() |> validate_expr!()

{parents, nodes} = parents_tree(transformed_expr, ids)

Composite.traverse(
to_grad,
{nodes, grads},
fn %{vectorized_axes: vectorized_axes} = node, acc ->
node
|> Nx.devectorize(keep_names: false)
|> to_grad(to_grad_ids, parents, acc)
|> then(fn {node, acc} ->
{Nx.vectorize(node, vectorized_axes), acc}
end)
fn node, acc ->
to_grad(node, to_grad_ids, parents, acc)
end
)

{expr, graded}
end

defp constant(float, shape) do
shape = Nx.shape(shape)
defp constant(float, %T{shape: shape} = t) do
names = List.duplicate(nil, tuple_size(shape))
Expr.constant(%T{shape: shape, type: {:f, 32}, names: names}, float, [])
Expr.constant(%T{t | names: names, type: {:f, 32}}, float, [])
end

defp validate_expr!(%T{data: %Expr{}} = expr) do
@@ -94,47 +88,88 @@
[:equal, :greater, :greater_equal, :less, :less_equal, :not_equal, :argsort]

defp parents_tree(expr, nodes) do
Composite.reduce(expr, {%{}, nodes}, &recur_parents_tree/2)
Composite.reduce(
expr,
{%{}, nodes},
&recur_parents_tree(
Nx.devectorize(&1, keep_names: true),
&2,
Keyword.keys(&1.vectorized_axes)
)
)
end

defp recur_parents_tree(%T{data: %Expr{id: id, op: op}} = t, {parents, nodes}) do
defp recur_parents_tree(%T{data: %Expr{id: id, op: op}} = t, {parents, nodes}, vectorized_names) do
case nodes do
%{^id => _} -> {parents, nodes}
%{} -> parents_args(op, t, id, {parents, Map.put(nodes, id, t)})
%{^id => _} ->
{parents, nodes}

%{} ->
# We use this to compute the proper axis sizes for the tensor
nodes = Map.put(nodes, id, {t, vectorized_names})

parents_args(op, t, id, {parents, nodes}, vectorized_names)
end
end

defp parents_args(:metadata, %{data: %{args: [_, %{stop_grad: true}]}}, _id, acc) do
defp parents_args(
:metadata,
%{data: %{args: [_, %{stop_grad: true}]}},
_id,
acc,
_parent_vectorized_names
) do
acc
end

defp parents_args(:optional, %{data: %{args: [call, _expr, callback]}} = t, id, acc) do
defp parents_args(
:optional,
%{data: %{args: [call, _expr, callback]}} = t,
id,
acc,
parent_vectorized_names
) do
expr = apply(callback, call.data.args)

# Now traverse over the optional expression where args are the new parameters.
# Once we access the parameter itself, we point the parameter to the arg.
{parents, nodes} =
Composite.reduce(expr, acc, fn expr, {parents, nodes} ->
parents = Map.update(parents, expr.data.id, [id], &[id | &1])
recur_parents_tree(expr, {parents, nodes})
{{parents, nodes}, _} =
Composite.reduce(expr, {acc, parent_vectorized_names}, fn
expr, {{parents, nodes}, expr_vectorized_names} ->
arg_vectorized_names = compute_arg_vectorized_names(expr, expr_vectorized_names)
parents = Map.update(parents, expr.data.id, [id], &[id | &1])

acc =
recur_parents_tree(
expr,
{parents, nodes},
arg_vectorized_names
)

{acc, expr_vectorized_names}
end)

{parents, Map.put(nodes, id, put_in(t.data.args, [call, expr, callback]))}
updated_node =
{put_in(t.data.args, [call, expr, callback]), parent_vectorized_names}

{parents, Map.put(nodes, id, updated_node)}
end

# We register cond as a special node to avoid pretraversing it.
# Instead we traverse it early on in the grad computation.
defp parents_args(:cond, _, id, {parents, nodes}) do
defp parents_args(:cond, _, id, {parents, nodes}, _parent_vectorized_names) do
{Map.update(parents, __MODULE__, [id], &[id | &1]), nodes}
end

defp parents_args(op, t, parent_id, acc) do
defp parents_args(op, t, parent_id, acc, parent_vectorized_names) do
reduce_args(op, t, acc, fn arg, {parents, nodes} ->
if arg.data.op in @constants do
{parents, nodes}
else
arg_vectorized_names = compute_arg_vectorized_names(t, parent_vectorized_names)
parents = Map.update(parents, arg.data.id, [parent_id], &[parent_id | &1])
recur_parents_tree(arg, {parents, nodes})

recur_parents_tree(arg, {parents, nodes}, arg_vectorized_names)
end
end)
end
@@ -191,10 +226,27 @@
case nodes do
%{^id => _} ->
{nodes, grads} = traverse_parents(id, to_grad_ids, parents, {nodes, grads})
{ans, nodes} = Map.pop!(nodes, id)
{{ans, vectorized_names}, nodes} = Map.pop!(nodes, id)
%T{data: %Expr{op: op, args: args}} = ans
{gs, grads} = Map.pop(grads, id)

{args, ans} =
if vectorized_names != [] do
args =
Enum.map(args, fn
%T{} = arg ->
revectorize_node(arg, vectorized_names)

opt ->
opt
end)

ans = Nx.vectorize(ans, vectorized_names)
{args, ans}
else
{args, ans}
end

case gs do
nil ->
{nodes, grads}
Expand All @@ -213,6 +265,22 @@ defmodule Nx.Defn.Grad do
end
end

defp compute_arg_vectorized_names(%{vectorized_axes: vectorized_axes}, []),
do: Keyword.keys(vectorized_axes)

defp compute_arg_vectorized_names(
%{vectorized_axes: vectorized_axes, names: names},
parent_names
) do
Keyword.keys(vectorized_axes) ++ Enum.filter(names, &(&1 in parent_names))
end

defp revectorize_node(node, vectorized_names) do
vectorized_names = compute_arg_vectorized_names(node, vectorized_names)

Nx.vectorize(node, vectorized_names)
end

defp update_grads(:elem, [%{type: {:tuple, size}} = tuple, pos], _ans, g, _to_grad_ids, grads) do
update_in(grads[tuple.data.id], fn tuple ->
tuple = tuple || Tuple.duplicate([], size)
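
The grad pass above devectorizes each node with keep_names: true and later re-vectorizes it from the recorded names (see compute_arg_vectorized_names/2 and revectorize_node/2). A small sketch of the round-trip this relies on, using only public Nx calls:

# keep_names: true turns vectorized axes into named regular axes, so the
# original vectorization can be rebuilt from those names later.
t = Nx.iota({2, 3}) |> Nx.vectorize(:x)

devec = Nx.devectorize(t, keep_names: true)
# devec has shape {2, 3}, names [:x, nil], and no vectorized axes

Nx.vectorize(devec, [:x])
# == t again: the leading axis is re-vectorized under the name :x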
86 changes: 81 additions & 5 deletions nx/test/nx/defn/grad_test.exs
@@ -4256,21 +4256,97 @@ defmodule Nx.Defn.GradTest do
end

describe "vectorization" do
test "supports vectorization" do
test "supports combination of vectorized and non-vectorized tensors" do
x = Nx.tensor([[1, 2, 3], [4, 5, 6]]) |> Nx.vectorize(:x)
y = 1

grad = Nx.Defn.grad(y, fn y -> Nx.add(x, y) end)

assert grad == Nx.tensor([3.0, 3.0]) |> Nx.vectorize([:x])
end

test "supports combination of vectorized and non-vectorized tensors over composed function" do
x = Nx.tensor([[1, 2, 3], [4, 5, 6]], names: [:x, :y]) |> Nx.vectorize(:x)
y = 1

grad = Nx.Defn.grad(y, fn y -> Nx.add(y, Nx.sin(x)) end)
assert grad == Nx.tensor([3.0, 3.0]) |> Nx.vectorize([:x])

grad = Nx.Defn.grad(x, fn x -> Nx.add(y, Nx.sin(x)) end)
assert grad == Nx.cos(x)
end

# Skipping this as it's not supported yet.
@tag :skip
test "edge case where the same name changes meaning" do
x = Nx.tensor([[1], [2], [3]]) |> Nx.vectorize(x: 3)

grad =
Nx.Defn.grad(x, fn t ->
devec = Nx.devectorize(t, keep_names: true)
new_axis = Nx.reshape(devec, {1, 3, 1}, names: [:x, nil, nil])

Nx.vectorize(new_axis, x: 1)
end)

assert grad == Nx.tensor([[1], [1], [1]]) |> Nx.vectorize(x: 3)
end

test "supports heterogenous vectorization combinations" do
x = Nx.tensor([[1, 2, 3], [4, 5, 6]])
y = Nx.tensor([10, 20])

# first case: y is a vectorized scalar, x is a vectorized vector, with different vectorized axis names
# expected result: equivalent to fully broadcasting one tensor onto the other
x_vec = Nx.vectorize(x, :x)
y_vec = Nx.vectorize(y, :y)
{grad_x_vec, grad_y_vec} = Nx.Defn.grad({x_vec, y_vec}, fn {a, b} -> Nx.multiply(a, b) end)

grad_fun = fn x, y ->
Nx.Defn.grad({x, y}, fn {a, b} -> Nx.multiply(a, b) end)
end

{grad_x_vec, grad_y_vec} = grad_fun.(x_vec, y_vec)

# Explicit assertion on the results
assert grad_x_vec ==
Nx.tensor([[30.0, 30.0, 30.0], [30.0, 30.0, 30.0]])
|> Nx.vectorize(x_vec.vectorized_axes)
Nx.tensor([
[
[10.0, 10.0, 10.0],
[20.0, 20.0, 20.0]
],
[
[10.0, 10.0, 10.0],
[20.0, 20.0, 20.0]
]
])
|> Nx.vectorize([:x, :y])

assert grad_y_vec ==
Nx.tensor([
[6.0, 6.0],
[15.0, 15.0]
])
|> Nx.vectorize([:x, :y])

assert grad_y_vec == Nx.tensor([21.0, 21.0]) |> Nx.vectorize(y_vec.vectorized_axes)
# Conceptual assertion: the result should be equivalent to calling Nx.Defn.grad with
# each cross-entry of the combined vectors [(x0, y0), (x0, y1), (x1, y0), (x1, y1)]

{x0y0_wrt_x, x0y0_wrt_y} = grad_fun.(x[0], y[0])
{x0y1_wrt_x, x0y1_wrt_y} = grad_fun.(x[0], y[1])
{x1y0_wrt_x, x1y0_wrt_y} = grad_fun.(x[1], y[0])
{x1y1_wrt_x, x1y1_wrt_y} = grad_fun.(x[1], y[1])

assert grad_x_vec ==
[x0y0_wrt_x, x0y1_wrt_x, x1y0_wrt_x, x1y1_wrt_x]
|> Nx.stack()
|> Nx.reshape({2, 2, 3})
|> Nx.vectorize([:x, :y])

assert grad_y_vec ==
[x0y0_wrt_y, x0y1_wrt_y, x1y0_wrt_y, x1y1_wrt_y]
|> Nx.stack()
|> Nx.reshape({2, 2})
|> Nx.vectorize([:x, :y])

# second case: y is a vectorized scalar, x is a vectorized vector, with the same vectorized axis name
# expected result: equivalent to "row-wise" broadcasting
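
The remainder of this test is collapsed in this view. As a rough, hand-derived sketch of the "same axis name" case the trailing comment describes (not the test's actual assertions): sharing the vectorized axis name pairs entries row-wise instead of crossing them.

# Both inputs are vectorized over :x, so x[i] is paired with y[i].
x_vec = Nx.tensor([[1, 2, 3], [4, 5, 6]]) |> Nx.vectorize(:x)
y_vec = Nx.tensor([10, 20]) |> Nx.vectorize(:x)

{gx, gy} = Nx.Defn.grad({x_vec, y_vec}, fn {a, b} -> Nx.multiply(a, b) end)
# Hand-derived expectation:
# gx == Nx.tensor([[10.0, 10.0, 10.0], [20.0, 20.0, 20.0]]) |> Nx.vectorize([:x])
# gy == Nx.tensor([6.0, 15.0]) |> Nx.vectorize([:x])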
