fix: broadcast vectors for grad calculation #1535

Merged · 23 commits merged on Sep 24, 2024

Changes from 9 commits

Commits
394a12d
fix: broadcast vectors for grad calculation
polvalente Sep 15, 2024
414726b
fix attempt
polvalente Sep 15, 2024
a08d0fd
test: make core tests pass
polvalente Sep 16, 2024
2f7c5f1
fix: inspect vectorized axes as usual
polvalente Sep 16, 2024
d87ffa1
chore: revert some changes
polvalente Sep 16, 2024
7fbdffd
chore: remove commented code
polvalente Sep 16, 2024
db0b6f0
chore: remove stray comments
polvalente Sep 16, 2024
20cc168
chore: remove more stray comments
polvalente Sep 16, 2024
22b9a24
refactor: support vectorized constant
polvalente Sep 16, 2024
8f60a71
test: add x * sin(y) grad test
polvalente Sep 16, 2024
a026c0e
feat: revectorize args only when strictly necessary
polvalente Sep 24, 2024
37408b3
fix: correctness of revectorize_args over possible kw_args functions
polvalente Sep 24, 2024
380c330
refactor: simpler revectorization of nodes
polvalente Sep 24, 2024
1487b17
refactor: revectorize in place
polvalente Sep 24, 2024
a832956
fix: propagate vectorized axes throughout the recursion'
polvalente Sep 24, 2024
24b9ea5
chore: revert some code due to code review
polvalente Sep 24, 2024
d0f93c9
chore: revert unbroadcast
polvalente Sep 24, 2024
b89111d
chore: remove some devectorization occurences
polvalente Sep 24, 2024
9075ef0
chore: simplify vectorized axes calculation
polvalente Sep 24, 2024
fcc4e10
chore: remove another superfluous devectorize
polvalente Sep 24, 2024
affdc90
Merge branch 'main' into pv-fix/vectorized-grad
polvalente Sep 24, 2024
add0134
Update nx/lib/nx/defn/grad.ex
polvalente Sep 24, 2024
8d94bc0
chore: format
polvalente Sep 24, 2024
10 changes: 7 additions & 3 deletions nx/lib/nx.ex
@@ -5420,9 +5420,13 @@ defmodule Nx do
{_, [], 0} ->
fun.(left, right)

{[devec_left, devec_right], canonical_vectorized_axes, _offset} ->
devec_left
|> fun.(devec_right)
{[devec_left, devec_right], canonical_vectorized_axes, offset} ->
leading_names = Keyword.keys(canonical_vectorized_axes)
l = %{devec_left | names: leading_names ++ Enum.drop(devec_left.names, offset)}
r = %{devec_right | names: leading_names ++ Enum.drop(devec_right.names, offset)}

l
|> fun.(r)
|> vectorize(canonical_vectorized_axes)
end
end
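
A hedged illustration of the broadcast behavior this path enables at the public API level, relying on Nx's vectorized-axis broadcasting by name (the tensors and values below are illustrative only):

x = Nx.tensor([[1, 2, 3], [4, 5, 6]]) |> Nx.vectorize(:x)
y = Nx.tensor([10, 20]) |> Nx.vectorize(:y)

# The result is vectorized over [x: 2, y: 2] with inner shape {3};
# entry {i, j} combines row i of x with scalar j of y.
Nx.add(x, y)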
19 changes: 12 additions & 7 deletions nx/lib/nx/defn/expr.ex
@@ -94,13 +94,13 @@ defmodule Nx.Defn.Expr do
def metadata(expr, metadata) when is_map(metadata) do
case to_container_expr(expr) do
%{data: %{context: context}} = res ->
expr(res, context, :metadata, [Nx.devectorize(expr), metadata])
expr(res, context, :metadata, [expr, metadata])

t when is_tuple(t) ->
context = elem(t, 0).data.context

tuple(
expr(tuple_out(tuple_size(t)), context, :metadata, [Nx.devectorize(expr), metadata]),
expr(tuple_out(tuple_size(t)), context, :metadata, [expr, metadata]),

Review comment (Collaborator): Revert. devectorize with keep_names. (A short illustration of keep_names follows this hunk.)

Tuple.to_list(t)
)
end
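
A minimal sketch of the keep_names behavior the reviewer refers to (the tensor and axis name are assumptions; devectorize/2 keeps vectorized axis names as leading tensor names unless told otherwise):

t = Nx.iota({2, 3}) |> Nx.vectorize(:rows)

Nx.devectorize(t)                     # shape {2, 3}, names [:rows, nil]
Nx.devectorize(t, keep_names: false)  # shape {2, 3}, names [nil, nil]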
@@ -1394,6 +1394,11 @@ defmodule Nx.Defn.Expr do

## Constant helpers and related optimizations

defp constant(%{vectorized_axes: [_ | _]} = out, number) do
out = %{out | names: Enum.map(out.names, fn _ -> nil end)}

Review comment (Collaborator): I don't think this part should be done here, we should preserve the names. Sorry for the confusion.

tensor(Nx.fill(out, number, type: out.type))
end

defp constant(%{shape: shape, type: type} = out, number) do
number =
cond do
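
A rough sketch of what the new vectorized constant/2 clause above relies on: Nx.fill materializes the output template, preserving its vectorized axes (the tensor, axis name, and value below are illustrative assumptions):

out = Nx.iota({2, 3}, type: {:f, 32}) |> Nx.vectorize(:batch)

# f32 tensor of ones with inner shape {3}, vectorized over batch: 2
Nx.fill(out, 1.0, type: {:f, 32})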
@@ -1661,11 +1666,11 @@

defp counter_to_name(counter), do: [?a + counter]

defp to_type_shape(%{type: type, shape: shape}) do
brackets =
shape
|> Tuple.to_list()
|> Enum.map(&[?[, Integer.to_string(&1), ?]])
defp to_type_shape(%{vectorized_axes: vectorized_axes, type: type, shape: shape}) do
axes =
Keyword.values(vectorized_axes) ++ Tuple.to_list(shape)

brackets = Enum.map(axes, &[?[, Integer.to_string(&1), ?]])

Review comment (Collaborator): Should we revert? 🤔


IO.iodata_to_binary([Nx.Type.to_string(type) | brackets])
end
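
For reference, a small sketch of the string this helper now renders, using the same iodata construction as above (the vectorized axis and shape are illustrative):

type = {:f, 32}
vectorized_axes = [batch: 2]
shape = {3}

axes = Keyword.values(vectorized_axes) ++ Tuple.to_list(shape)
brackets = Enum.map(axes, &[?[, Integer.to_string(&1), ?]])

IO.iodata_to_binary([Nx.Type.to_string(type) | brackets])
# => "f32[2][3]" (previously the vectorized axis size was not shown)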
109 changes: 84 additions & 25 deletions nx/lib/nx/defn/grad.ex
@@ -6,11 +6,10 @@ defmodule Nx.Defn.Grad do

def transform(to_grad, fun, transform) do
{to_grad, ids} =
Composite.traverse(to_grad, %{}, fn to_grad, ids ->
to_grad =
Expr.metadata(to_grad, %{__MODULE__ => :to_grad})

{to_grad, Map.put(ids, to_grad.data.id, :stop)}
Composite.traverse(to_grad, %{}, fn node, ids ->
node = Expr.metadata(node, %{__MODULE__ => :to_grad})
ids = Map.put(ids, node.data.id, :stop)
{node, ids}
end)

# Collect all IDs in the function environment and mark
Expand All @@ -19,36 +18,26 @@ defmodule Nx.Defn.Grad do
{:env, env} = Function.info(fun, :env)
ids = stop_grads(env, ids)

# save vectorized axes before devectorizing
expr = to_grad |> fun.()
expr = fun.(to_grad)

transformed_expr =
expr |> transform.() |> validate_expr!()

transformed_expr = transform.(expr) |> validate_expr!() |> Nx.devectorize(keep_names: false)
{parents, nodes} = parents_tree(transformed_expr, ids)

to_grad_ids = {to_grad, ids}
grads = %{transformed_expr.data.id => [constant(1.0, transformed_expr)]}

{graded, _} =
Composite.traverse(
to_grad,
{nodes, grads},
fn %{vectorized_axes: vectorized_axes} = node, acc ->
node
|> Nx.devectorize(keep_names: false)
|> to_grad(to_grad_ids, parents, acc)
|> then(fn {node, acc} ->
{Nx.vectorize(node, vectorized_axes), acc}
end)
end
)
Composite.traverse(to_grad, {nodes, grads}, &to_grad(&1, to_grad_ids, parents, &2))

{expr, graded}
end
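
For orientation, a hedged usage sketch of the entry point this transform backs, mirroring the first vectorization test added in this PR: the vectorized :x axis now survives the gradient computation.

x = Nx.tensor([[1, 2, 3], [4, 5, 6]]) |> Nx.vectorize(:x)
y = 1

# grad == Nx.tensor([3.0, 3.0]) |> Nx.vectorize([:x])
grad = Nx.Defn.grad(y, fn y -> Nx.add(x, y) end)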

defp constant(float, shape) do
shape = Nx.shape(shape)
defp constant(float, %T{shape: shape} = t) do
names = List.duplicate(nil, tuple_size(shape))
Expr.constant(%T{shape: shape, type: {:f, 32}, names: names}, float, [])

Expr.constant(%T{t | names: names, type: {:f, 32}}, float, [])
end

defp validate_expr!(%T{data: %Expr{}} = expr) do
@@ -338,6 +327,8 @@ defmodule Nx.Defn.Grad do
@verify_grad Application.compile_env(:nx, :verify_grad, false)

defp update_grads(op, args, ans, g, _to_grad_ids, grads) do
args = revectorize_args(args, ans)

Review comment (Collaborator): I would prefer not to revectorize everything on every operation. Is there any chance we could do it in broadcast only?

Reply (Contributor, Author): Lines like this one make it so that g is vectorized and y is unvectorized but has axes with the same name, so things break there:

[unbroadcast(x, Nx.multiply(g, y), ans), unbroadcast(y, Nx.multiply(g, x), ans)]


pairs = grad(op, args, ans, g)

if @verify_grad do
@@ -1343,9 +1334,77 @@ defmodule Nx.Defn.Grad do

## General helpers

defp unbroadcast(%{shape: shape} = x, res, %{shape: shape}), do: {x, res}
defp revectorize_args(args, ans) do

Review comment (Collaborator): Let's only apply this if args has more than one element and there are vectorized axes. Also, please test x * sin(y) where y is vectorized. (A sketch of such a test follows this helper.)

names_ans =
Enum.with_index(Keyword.keys(ans.vectorized_axes) ++ ans.names, fn name, idx ->
if(name, do: name, else: {:ans, idx})
end)

for arg <- args do
case arg do
%T{names: names} ->
names = Enum.with_index(names, fn name, idx -> if(name, do: {name, idx}) end)

{vectorized_axes, offset} =
names
|> Enum.reduce({[], 0}, fn
nil, acc ->
acc

{name, _idx}, {acc, count} ->
if name in names_ans do
{[name | acc], count + 1}
else
{acc, count}
end
end)

axes_names = Enum.reverse(vectorized_axes)

{vec_shape_list, shape_list} = arg.shape |> Tuple.to_list() |> Enum.split(offset)

vectorized_axes =
Enum.zip(axes_names, vec_shape_list)

%{
arg
| vectorized_axes: vectorized_axes,
names: Enum.drop(arg.names, offset),
shape: List.to_tuple(shape_list)
}

arg ->
arg
end
end
end
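
A minimal sketch of the x * sin(y) test requested in the review comment above; commit 8f60a71 adds the real test, and the tensors and expected values here are only illustrative assumptions:

x = Nx.tensor([1.0, 2.0, 3.0])
y = Nx.tensor([0.5, 1.5]) |> Nx.vectorize(:y)

{grad_x, grad_y} =
  Nx.Defn.grad({x, y}, fn {a, b} -> Nx.multiply(a, Nx.sin(b)) end)

# Analytically, d/dx_i = sin(y) for each entry of x (vectorized over :y),
# and d/dy = Nx.sum(x) * cos(y) = 6 * cos(y) (also vectorized over :y).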

defp unbroadcast(x, res, ans) do
vectorized_axes_x = Keyword.keys(x.vectorized_axes)
vectorized_axes_ans = Keyword.keys(ans.vectorized_axes)

permutation =
vectorized_axes_ans
|> Enum.with_index()
|> Enum.sort_by(fn {axis, _idx} -> axis in vectorized_axes_x end)
|> Enum.map(fn {_, idx} -> idx end)

num_vectorized_axes = length(permutation)

inner_axes = Enum.to_list(num_vectorized_axes..(num_vectorized_axes + Nx.rank(res) - 1)//1)

res =
res
|> Nx.devectorize()
|> Nx.transpose(axes: permutation ++ inner_axes)
|> Nx.vectorize(vectorized_axes_x)

grad_broadcast(x, res, ans)
end

defp grad_broadcast(%{shape: shape} = x, res, %{shape: shape}), do: {x, res}

defp unbroadcast(%{shape: shape} = x, res, %{shape: new_shape}) do
defp grad_broadcast(%{shape: shape} = x, res, %{shape: new_shape}) do
axes = Nx.Shape.broadcast_axes(shape, new_shape)
{x, grad_broadcast(x, new_shape, axes, res)}
end
59 changes: 54 additions & 5 deletions nx/test/nx/defn/grad_test.exs
@@ -4238,21 +4238,70 @@ defmodule Nx.Defn.GradTest do
end

describe "vectorization" do
test "supports vectorization" do
test "supports combination of vectorized and non-vectorized tensors" do
x = Nx.tensor([[1, 2, 3], [4, 5, 6]]) |> Nx.vectorize(:x)
y = 1

grad = Nx.Defn.grad(y, fn y -> Nx.add(x, y) end)

assert grad == Nx.tensor([3.0, 3.0]) |> Nx.vectorize([:x])
end

test "supports heterogenous vectorization combinations" do
x = Nx.tensor([[1, 2, 3], [4, 5, 6]])
y = Nx.tensor([10, 20])

# first case: y is vectorized scalar, x is vectorized vectors, different vectorized axis names
# expected result: equivalent to fully broadcasting one tensor onto the other
x_vec = Nx.vectorize(x, :x)
y_vec = Nx.vectorize(y, :y)
{grad_x_vec, grad_y_vec} = Nx.Defn.grad({x_vec, y_vec}, fn {a, b} -> Nx.multiply(a, b) end)

grad_fun = fn x, y ->
Nx.Defn.grad({x, y}, fn {a, b} -> Nx.multiply(a, b) end)
end

{grad_x_vec, grad_y_vec} = grad_fun.(x_vec, y_vec)

# Explicit assertion on the results
assert grad_x_vec ==
Nx.tensor([[30.0, 30.0, 30.0], [30.0, 30.0, 30.0]])
|> Nx.vectorize(x_vec.vectorized_axes)
Nx.tensor([
[
[10.0, 10.0, 10.0],
[20.0, 20.0, 20.0]
],
[
[10.0, 10.0, 10.0],
[20.0, 20.0, 20.0]
]
])
|> Nx.vectorize([:x, :y])

assert grad_y_vec ==
Nx.tensor([
[6.0, 6.0],
[15.0, 15.0]
])
|> Nx.vectorize([:x, :y])

# Conceptual assertion: the result should be equivalent to calling Nx.Defn.grad with
# each cross-entry of the combined vectors [(x0, y0), (x0, y1), (x1, y0), (x1, y1)]

assert grad_y_vec == Nx.tensor([21.0, 21.0]) |> Nx.vectorize(y_vec.vectorized_axes)
{x0y0_wrt_x, x0y0_wrt_y} = grad_fun.(x[0], y[0])
{x0y1_wrt_x, x0y1_wrt_y} = grad_fun.(x[0], y[1])
{x1y0_wrt_x, x1y0_wrt_y} = grad_fun.(x[1], y[0])
{x1y1_wrt_x, x1y1_wrt_y} = grad_fun.(x[1], y[1])

assert grad_x_vec ==
[x0y0_wrt_x, x0y1_wrt_x, x1y0_wrt_x, x1y1_wrt_x]
|> Nx.stack()
|> Nx.reshape({2, 2, 3})
|> Nx.vectorize([:x, :y])

assert grad_y_vec ==
[x0y0_wrt_y, x0y1_wrt_y, x1y0_wrt_y, x1y1_wrt_y]
|> Nx.stack()
|> Nx.reshape({2, 2})
|> Nx.vectorize([:x, :y])

# second case: y is vectorized scalar, x is vectorized vectors, same vectorized axis name
# expected result: equivalent to "row-wise" broadcasting