fix: broadcast vectors for grad calculation #1535

Merged
merged 23 commits on Sep 24, 2024
Changes from all commits
23 commits
394a12d
fix: broadcast vectors for grad calculation
polvalente Sep 15, 2024
414726b
fix attempt
polvalente Sep 15, 2024
a08d0fd
test: make core tests pass
polvalente Sep 16, 2024
2f7c5f1
fix: inspect vectorized axes as usual
polvalente Sep 16, 2024
d87ffa1
chore: revert some changes
polvalente Sep 16, 2024
7fbdffd
chore: remove commented code
polvalente Sep 16, 2024
db0b6f0
chore: remove stray comments
polvalente Sep 16, 2024
20cc168
chore: remove more stray comments
polvalente Sep 16, 2024
22b9a24
refactor: support vectorized constant
polvalente Sep 16, 2024
8f60a71
test: add x * sin(y) grad test
polvalente Sep 16, 2024
a026c0e
feat: revectorize args only when strictly necessary
polvalente Sep 24, 2024
37408b3
fix: correctness of revectorize_args over possible kw_args functions
polvalente Sep 24, 2024
380c330
refactor: simpler revectorization of nodes
polvalente Sep 24, 2024
1487b17
refactor: revectorize in place
polvalente Sep 24, 2024
a832956
fix: propagate vectorized axes throughout the recursion'
polvalente Sep 24, 2024
24b9ea5
chore: revert some code due to code review
polvalente Sep 24, 2024
d0f93c9
chore: revert unbroadcast
polvalente Sep 24, 2024
b89111d
chore: remove some devectorization occurences
polvalente Sep 24, 2024
9075ef0
chore: simplify vectorized axes calculation
polvalente Sep 24, 2024
fcc4e10
chore: remove another superfluous devectorize
polvalente Sep 24, 2024
affdc90
Merge branch 'main' into pv-fix/vectorized-grad
polvalente Sep 24, 2024
add0134
Update nx/lib/nx/defn/grad.ex
polvalente Sep 24, 2024
8d94bc0
chore: format
polvalente Sep 24, 2024
10 changes: 7 additions & 3 deletions nx/lib/nx.ex
@@ -5420,9 +5420,13 @@ defmodule Nx do
{_, [], 0} ->
fun.(left, right)

{[devec_left, devec_right], canonical_vectorized_axes, _offset} ->
devec_left
|> fun.(devec_right)
{[devec_left, devec_right], canonical_vectorized_axes, offset} ->
leading_names = Keyword.keys(canonical_vectorized_axes)
l = %{devec_left | names: leading_names ++ Enum.drop(devec_left.names, offset)}
r = %{devec_right | names: leading_names ++ Enum.drop(devec_right.names, offset)}

l
|> fun.(r)
|> vectorize(canonical_vectorized_axes)
end
end
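
For context, a minimal usage sketch of the behavior this hunk touches (not part of the diff; the tensors and axis names are illustrative): after the change, the devectorized operands carry the canonical vectorized axis names as their leading names before `fun` is applied, so binary operations between tensors vectorized over different axes keep that name information while broadcasting.

# Illustrative only: two tensors vectorized over differently named axes.
x = Nx.tensor([[1, 2, 3], [4, 5, 6]]) |> Nx.vectorize(:x)  # 2 entries of shape {3}
y = Nx.tensor([10, 20, 30]) |> Nx.vectorize(:y)            # 3 scalar entries

# Binary ops broadcast the vectorized axes against each other;
# the result is vectorized over [:x, :y] with an inner shape of {3}.
z = Nx.add(x, y)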
6 changes: 5 additions & 1 deletion nx/lib/nx/defn/expr.ex
@@ -1394,6 +1394,10 @@ defmodule Nx.Defn.Expr do

## Constant helpers and related optimizations

defp constant(%{vectorized_axes: [_ | _]} = out, number) do
tensor(Nx.fill(out, number, type: out.type))
end

defp constant(%{shape: shape, type: type} = out, number) do
number =
cond do
@@ -1661,7 +1665,7 @@

defp counter_to_name(counter), do: [?a + counter]

defp to_type_shape(%{type: type, shape: shape}) do
defp to_type_shape(%{vectorized_axes: [], type: type, shape: shape}) do
brackets =
shape
|> Tuple.to_list()
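
A short note on the new `constant/2` clause above (an illustrative sketch, not part of the diff): when the output template has vectorized axes, the constant is built by filling a tensor shaped like the template, so the resulting expression keeps those vectorized axes instead of collapsing to a bare scalar constant. Roughly:

# Illustrative only: a template carrying a vectorized axis.
out = Nx.iota({2, 3}, type: :f32) |> Nx.vectorize(:x)

# The new clause materializes the constant against the template
# (using Nx.fill, as in the clause above) before turning it into an
# expression; the vectorized :x axis is preserved in the result.
filled = Nx.fill(out, 1.0, type: out.type)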
124 changes: 96 additions & 28 deletions nx/lib/nx/defn/grad.ex
@@ -19,10 +19,10 @@ defmodule Nx.Defn.Grad do
{:env, env} = Function.info(fun, :env)
ids = stop_grads(env, ids)

# save vectorized axes before devectorizing
expr = to_grad |> fun.()
expr = fun.(to_grad)

transformed_expr = transform.(expr) |> validate_expr!() |> Nx.devectorize(keep_names: false)
transformed_expr =
expr |> transform.() |> validate_expr!()

{parents, nodes} = parents_tree(transformed_expr, ids)

@@ -33,23 +33,17 @@
Composite.traverse(
to_grad,
{nodes, grads},
fn %{vectorized_axes: vectorized_axes} = node, acc ->
node
|> Nx.devectorize(keep_names: false)
|> to_grad(to_grad_ids, parents, acc)
|> then(fn {node, acc} ->
{Nx.vectorize(node, vectorized_axes), acc}
end)
fn node, acc ->
to_grad(node, to_grad_ids, parents, acc)
end
)

{expr, graded}
end

defp constant(float, shape) do
shape = Nx.shape(shape)
defp constant(float, %T{shape: shape} = t) do
names = List.duplicate(nil, tuple_size(shape))
Expr.constant(%T{shape: shape, type: {:f, 32}, names: names}, float, [])
Expr.constant(%T{t | names: names, type: {:f, 32}}, float, [])
end

defp validate_expr!(%T{data: %Expr{}} = expr) do
@@ -94,47 +88,88 @@
[:equal, :greater, :greater_equal, :less, :less_equal, :not_equal, :argsort]

defp parents_tree(expr, nodes) do
Composite.reduce(expr, {%{}, nodes}, &recur_parents_tree/2)
Composite.reduce(
expr,
{%{}, nodes},
&recur_parents_tree(
Nx.devectorize(&1, keep_names: true),
&2,
Keyword.keys(&1.vectorized_axes)
)
)
end

defp recur_parents_tree(%T{data: %Expr{id: id, op: op}} = t, {parents, nodes}) do
defp recur_parents_tree(%T{data: %Expr{id: id, op: op}} = t, {parents, nodes}, vectorized_names) do
case nodes do
%{^id => _} -> {parents, nodes}
%{} -> parents_args(op, t, id, {parents, Map.put(nodes, id, t)})
%{^id => _} ->
{parents, nodes}

%{} ->
# We use this to compute the proper axis sizes for the tensor
nodes = Map.put(nodes, id, {t, vectorized_names})

parents_args(op, t, id, {parents, nodes}, vectorized_names)
end
end

defp parents_args(:metadata, %{data: %{args: [_, %{stop_grad: true}]}}, _id, acc) do
defp parents_args(
:metadata,
%{data: %{args: [_, %{stop_grad: true}]}},
_id,
acc,
_parent_vectorized_names
) do
acc
end

defp parents_args(:optional, %{data: %{args: [call, _expr, callback]}} = t, id, acc) do
defp parents_args(
:optional,
%{data: %{args: [call, _expr, callback]}} = t,
id,
acc,
parent_vectorized_names
) do
expr = apply(callback, call.data.args)

# Now traverse over the optional expression where args are the new parameters.
# Once we access the parameter itself, we point the parameter to the arg.
{parents, nodes} =
Composite.reduce(expr, acc, fn expr, {parents, nodes} ->
parents = Map.update(parents, expr.data.id, [id], &[id | &1])
recur_parents_tree(expr, {parents, nodes})
{{parents, nodes}, _} =
Composite.reduce(expr, {acc, parent_vectorized_names}, fn
expr, {{parents, nodes}, expr_vectorized_names} ->
arg_vectorized_names = compute_arg_vectorized_names(expr, expr_vectorized_names)
parents = Map.update(parents, expr.data.id, [id], &[id | &1])

acc =
recur_parents_tree(
expr,
{parents, nodes},
arg_vectorized_names
)

{acc, expr_vectorized_names}
end)

{parents, Map.put(nodes, id, put_in(t.data.args, [call, expr, callback]))}
updated_node =
{put_in(t.data.args, [call, expr, callback]), parent_vectorized_names}

{parents, Map.put(nodes, id, updated_node)}
end

# We register cond as a special node to avoid pretraversing it.
# Instead we traverse it early on in the grad computation.
defp parents_args(:cond, _, id, {parents, nodes}) do
defp parents_args(:cond, _, id, {parents, nodes}, _parent_vectorized_names) do
{Map.update(parents, __MODULE__, [id], &[id | &1]), nodes}
end

defp parents_args(op, t, parent_id, acc) do
defp parents_args(op, t, parent_id, acc, parent_vectorized_names) do
reduce_args(op, t, acc, fn arg, {parents, nodes} ->
if arg.data.op in @constants do
{parents, nodes}
else
arg_vectorized_names = compute_arg_vectorized_names(t, parent_vectorized_names)
parents = Map.update(parents, arg.data.id, [parent_id], &[parent_id | &1])
recur_parents_tree(arg, {parents, nodes})

recur_parents_tree(arg, {parents, nodes}, arg_vectorized_names)
end
end)
end
@@ -191,10 +226,27 @@
case nodes do
%{^id => _} ->
{nodes, grads} = traverse_parents(id, to_grad_ids, parents, {nodes, grads})
{ans, nodes} = Map.pop!(nodes, id)
{{ans, vectorized_names}, nodes} = Map.pop!(nodes, id)
%T{data: %Expr{op: op, args: args}} = ans
{gs, grads} = Map.pop(grads, id)

{args, ans} =
if vectorized_names != [] do
args =
Enum.map(args, fn
%T{} = arg ->
revectorize_node(arg, vectorized_names)

opt ->
opt
end)

ans = Nx.vectorize(ans, vectorized_names)
{args, ans}
else
{args, ans}
end

case gs do
nil ->
{nodes, grads}
@@ -213,6 +265,22 @@
end
end

defp compute_arg_vectorized_names(%{vectorized_axes: vectorized_axes}, []),
do: Keyword.keys(vectorized_axes)

defp compute_arg_vectorized_names(
%{vectorized_axes: vectorized_axes, names: names},
parent_names
) do
Keyword.keys(vectorized_axes) ++ Enum.filter(names, &(&1 in parent_names))
end

defp revectorize_node(node, vectorized_names) do
vectorized_names = compute_arg_vectorized_names(node, vectorized_names)

Nx.vectorize(node, vectorized_names)
end

defp update_grads(:elem, [%{type: {:tuple, size}} = tuple, pos], _ans, g, _to_grad_ids, grads) do
update_in(grads[tuple.data.id], fn tuple ->
tuple = tuple || Tuple.duplicate([], size)
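
To make the traversal changes concrete, a small usage sketch (illustrative only; it mirrors the first test added below): nodes are now stored by `parents_tree/2` together with their vectorized axis names and revectorized on demand while walking back through the graph, so a grad taken over a mix of vectorized and non-vectorized inputs comes out vectorized over the appropriate axes.

# Illustrative only; see the test diff below for the actual assertion.
x = Nx.tensor([[1, 2, 3], [4, 5, 6]]) |> Nx.vectorize(:x)  # vectorized input
y = 1                                                       # plain scalar input

# Differentiating with respect to the scalar y through the vectorized x:
# each vectorized entry contributes 3.0, and the result stays vectorized over :x.
grad = Nx.Defn.grad(y, fn y -> Nx.add(x, y) end)
# grad == Nx.tensor([3.0, 3.0]) |> Nx.vectorize([:x])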
86 changes: 81 additions & 5 deletions nx/test/nx/defn/grad_test.exs
@@ -4256,21 +4256,97 @@ defmodule Nx.Defn.GradTest do
end

describe "vectorization" do
test "supports vectorization" do
test "supports combination of vectorized and non-vectorized tensors" do
x = Nx.tensor([[1, 2, 3], [4, 5, 6]]) |> Nx.vectorize(:x)
y = 1

grad = Nx.Defn.grad(y, fn y -> Nx.add(x, y) end)

assert grad == Nx.tensor([3.0, 3.0]) |> Nx.vectorize([:x])
end

test "supports combination of vectorized and non-vectorized tensors over composed function" do
x = Nx.tensor([[1, 2, 3], [4, 5, 6]], names: [:x, :y]) |> Nx.vectorize(:x)
y = 1

grad = Nx.Defn.grad(y, fn y -> Nx.add(y, Nx.sin(x)) end)
assert grad == Nx.tensor([3.0, 3.0]) |> Nx.vectorize([:x])

grad = Nx.Defn.grad(x, fn x -> Nx.add(y, Nx.sin(x)) end)
assert grad == Nx.cos(x)
end

# Skipping this as it's not supported yet.
@tag :skip
test "edge case where the same name changes meaning" do
x = Nx.tensor([[1], [2], [3]]) |> Nx.vectorize(x: 3)

grad =
Nx.Defn.grad(x, fn t ->
devec = Nx.devectorize(t, keep_names: true)
new_axis = Nx.reshape(devec, {1, 3, 1}, names: [:x, nil, nil])

Nx.vectorize(new_axis, x: 1)
end)

assert grad == Nx.tensor([[1], [1], [1]]) |> Nx.vectorize(x: 3)
end

test "supports heterogenous vectorization combinations" do
x = Nx.tensor([[1, 2, 3], [4, 5, 6]])
y = Nx.tensor([10, 20])

# first case: y is vectorized scalar, x is vectorized vectors, different vectorized axis names
# expected result: equivalent to fully broadcasting one tensor onto the other
x_vec = Nx.vectorize(x, :x)
y_vec = Nx.vectorize(y, :y)
{grad_x_vec, grad_y_vec} = Nx.Defn.grad({x_vec, y_vec}, fn {a, b} -> Nx.multiply(a, b) end)

grad_fun = fn x, y ->
Nx.Defn.grad({x, y}, fn {a, b} -> Nx.multiply(a, b) end)
end

{grad_x_vec, grad_y_vec} = grad_fun.(x_vec, y_vec)

# Explicit assertion on the results
assert grad_x_vec ==
Nx.tensor([[30.0, 30.0, 30.0], [30.0, 30.0, 30.0]])
|> Nx.vectorize(x_vec.vectorized_axes)
Nx.tensor([
[
[10.0, 10.0, 10.0],
[20.0, 20.0, 20.0]
],
[
[10.0, 10.0, 10.0],
[20.0, 20.0, 20.0]
]
])
|> Nx.vectorize([:x, :y])

assert grad_y_vec ==
Nx.tensor([
[6.0, 6.0],
[15.0, 15.0]
])
|> Nx.vectorize([:x, :y])

assert grad_y_vec == Nx.tensor([21.0, 21.0]) |> Nx.vectorize(y_vec.vectorized_axes)
# Conceptual assertion: the result should be equivalent to calling Nx.Defn.grad with
# each cross-entry of the combined vectors [(x0, y0), (x0, y1), (x1, y0), (x1, y1)]

{x0y0_wrt_x, x0y0_wrt_y} = grad_fun.(x[0], y[0])
{x0y1_wrt_x, x0y1_wrt_y} = grad_fun.(x[0], y[1])
{x1y0_wrt_x, x1y0_wrt_y} = grad_fun.(x[1], y[0])
{x1y1_wrt_x, x1y1_wrt_y} = grad_fun.(x[1], y[1])

assert grad_x_vec ==
[x0y0_wrt_x, x0y1_wrt_x, x1y0_wrt_x, x1y1_wrt_x]
|> Nx.stack()
|> Nx.reshape({2, 2, 3})
|> Nx.vectorize([:x, :y])

assert grad_y_vec ==
[x0y0_wrt_y, x0y1_wrt_y, x1y0_wrt_y, x1y1_wrt_y]
|> Nx.stack()
|> Nx.reshape({2, 2})
|> Nx.vectorize([:x, :y])

# second case: y is vectorized scalar, x is vectorized vectors, same vectorized axis name
# expected result: equivalent to "row-wise" broadcasting