-
Notifications
You must be signed in to change notification settings - Fork 200
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix: broadcast vectors for grad calculation #1535
Changes from 9 commits
394a12d
414726b
a08d0fd
2f7c5f1
d87ffa1
7fbdffd
db0b6f0
20cc168
22b9a24
8f60a71
a026c0e
37408b3
380c330
1487b17
a832956
24b9ea5
d0f93c9
b89111d
9075ef0
fcc4e10
affdc90
add0134
8d94bc0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -94,13 +94,13 @@ defmodule Nx.Defn.Expr do | |
def metadata(expr, metadata) when is_map(metadata) do | ||
case to_container_expr(expr) do | ||
%{data: %{context: context}} = res -> | ||
expr(res, context, :metadata, [Nx.devectorize(expr), metadata]) | ||
expr(res, context, :metadata, [expr, metadata]) | ||
|
||
t when is_tuple(t) -> | ||
context = elem(t, 0).data.context | ||
|
||
tuple( | ||
expr(tuple_out(tuple_size(t)), context, :metadata, [Nx.devectorize(expr), metadata]), | ||
expr(tuple_out(tuple_size(t)), context, :metadata, [expr, metadata]), | ||
Tuple.to_list(t) | ||
) | ||
end | ||
|
@@ -1394,6 +1394,11 @@ defmodule Nx.Defn.Expr do | |
|
||
## Constant helpers and related optimizations | ||
|
||
defp constant(%{vectorized_axes: [_ | _]} = out, number) do | ||
out = %{out | names: Enum.map(out.names, fn _ -> nil end)} | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I don't think this part should be done here; we should preserve the names. Sorry for the confusion. |
||
tensor(Nx.fill(out, number, type: out.type)) | ||
end | ||
|
||
defp constant(%{shape: shape, type: type} = out, number) do | ||
number = | ||
cond do | ||
|
@@ -1661,11 +1666,11 @@ defmodule Nx.Defn.Expr do | |
|
||
defp counter_to_name(counter), do: [?a + counter] | ||
|
||
defp to_type_shape(%{type: type, shape: shape}) do | ||
brackets = | ||
shape | ||
|> Tuple.to_list() | ||
|> Enum.map(&[?[, Integer.to_string(&1), ?]]) | ||
defp to_type_shape(%{vectorized_axes: vectorized_axes, type: type, shape: shape}) do | ||
axes = | ||
Keyword.values(vectorized_axes) ++ Tuple.to_list(shape) | ||
|
||
brackets = Enum.map(axes, &[?[, Integer.to_string(&1), ?]]) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Should we revert? 🤔 |
||
|
||
IO.iodata_to_binary([Nx.Type.to_string(type) | brackets]) | ||
end | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,11 +6,10 @@ defmodule Nx.Defn.Grad do | |
|
||
def transform(to_grad, fun, transform) do | ||
{to_grad, ids} = | ||
Composite.traverse(to_grad, %{}, fn to_grad, ids -> | ||
to_grad = | ||
Expr.metadata(to_grad, %{__MODULE__ => :to_grad}) | ||
|
||
{to_grad, Map.put(ids, to_grad.data.id, :stop)} | ||
Composite.traverse(to_grad, %{}, fn node, ids -> | ||
node = Expr.metadata(node, %{__MODULE__ => :to_grad}) | ||
ids = Map.put(ids, node.data.id, :stop) | ||
{node, ids} | ||
end) | ||
|
||
# Collect all IDs in the function environment and mark | ||
|
@@ -19,36 +18,26 @@ defmodule Nx.Defn.Grad do | |
{:env, env} = Function.info(fun, :env) | ||
ids = stop_grads(env, ids) | ||
|
||
# save vectorized axes before devectorizing | ||
expr = to_grad |> fun.() | ||
expr = fun.(to_grad) | ||
|
||
transformed_expr = | ||
expr |> transform.() |> validate_expr!() | ||
|
||
transformed_expr = transform.(expr) |> validate_expr!() |> Nx.devectorize(keep_names: false) | ||
{parents, nodes} = parents_tree(transformed_expr, ids) | ||
|
||
to_grad_ids = {to_grad, ids} | ||
grads = %{transformed_expr.data.id => [constant(1.0, transformed_expr)]} | ||
|
||
{graded, _} = | ||
Composite.traverse( | ||
to_grad, | ||
{nodes, grads}, | ||
fn %{vectorized_axes: vectorized_axes} = node, acc -> | ||
node | ||
|> Nx.devectorize(keep_names: false) | ||
|> to_grad(to_grad_ids, parents, acc) | ||
|> then(fn {node, acc} -> | ||
{Nx.vectorize(node, vectorized_axes), acc} | ||
end) | ||
end | ||
) | ||
Composite.traverse(to_grad, {nodes, grads}, &to_grad(&1, to_grad_ids, parents, &2)) | ||
|
||
{expr, graded} | ||
end | ||
|
||
defp constant(float, shape) do | ||
shape = Nx.shape(shape) | ||
defp constant(float, %T{shape: shape} = t) do | ||
names = List.duplicate(nil, tuple_size(shape)) | ||
Expr.constant(%T{shape: shape, type: {:f, 32}, names: names}, float, []) | ||
|
||
Expr.constant(%T{t | names: names, type: {:f, 32}}, float, []) | ||
end | ||
|
||
defp validate_expr!(%T{data: %Expr{}} = expr) do | ||
|
@@ -338,6 +327,8 @@ defmodule Nx.Defn.Grad do | |
@verify_grad Application.compile_env(:nx, :verify_grad, false) | ||
|
||
defp update_grads(op, args, ans, g, _to_grad_ids, grads) do | ||
args = revectorize_args(args, ans) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I would prefer not to revectorize everything on every operation. Is there any chance we could do in … There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Lines like this one make it so that |
||
|
||
pairs = grad(op, args, ans, g) | ||
|
||
if @verify_grad do | ||
|
@@ -1343,9 +1334,77 @@ defmodule Nx.Defn.Grad do | |
|
||
## General helpers | ||
|
||
defp unbroadcast(%{shape: shape} = x, res, %{shape: shape}), do: {x, res} | ||
# Rebuilds the vectorized axes of each tensor in `args`, using the axis names
# of `ans` as the reference. Tensor args get new `vectorized_axes`, `names` and
# `shape` fields; any other arg is returned unchanged.
defp revectorize_args(args, ans) do | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's only apply this if args has more than one element and there are vectorized axes. Also please test |
||
# Vectorized axis names of `ans` followed by its regular axis names. A nil
# name is replaced by the placeholder `{:ans, idx}`.
names_ans = | ||
Enum.with_index(Keyword.keys(ans.vectorized_axes) ++ ans.names, fn name, idx -> | ||
if(name, do: name, else: {:ans, idx}) | ||
end) | ||
|
||
for arg <- args do | ||
case arg do | ||
%T{names: names} -> | ||
# Named axes become `{name, idx}`; unnamed axes stay nil.
names = Enum.with_index(names, fn name, idx -> if(name, do: {name, idx}) end) | ||
|
||
# Collect (in reverse) the names of `arg` that also appear in `names_ans`,
# and count them in `offset`.
{vectorized_axes, offset} = | ||
names | ||
|> Enum.reduce({[], 0}, fn | ||
nil, acc -> | ||
acc | ||
|
||
{name, _idx}, {acc, count} -> | ||
if name in names_ans do | ||
{[name | acc], count + 1} | ||
else | ||
{acc, count} | ||
end | ||
end) | ||
|
||
axes_names = Enum.reverse(vectorized_axes) | ||
|
||
# NOTE(review): names are matched at any position above, but the shape and
# names below are split at the first `offset` entries. This presumably
# assumes the matching names are always the leading axes — confirm.
{vec_shape_list, shape_list} = arg.shape |> Tuple.to_list() |> Enum.split(offset) | ||
|
||
vectorized_axes = | ||
Enum.zip(axes_names, vec_shape_list) | ||
|
||
# Overwrites whatever `vectorized_axes` the arg already had.
%{ | ||
arg | ||
| vectorized_axes: vectorized_axes, | ||
names: Enum.drop(arg.names, offset), | ||
shape: List.to_tuple(shape_list) | ||
} | ||
|
||
arg -> | ||
arg | ||
end | ||
end | ||
end | ||
|
||
# Reorders the vectorized axes of `res` relative to those of `x`, re-vectorizes
# it with the axis names of `x`, and then delegates to `grad_broadcast/3`.
defp unbroadcast(x, res, ans) do | ||
vectorized_axes_x = Keyword.keys(x.vectorized_axes) | ||
vectorized_axes_ans = Keyword.keys(ans.vectorized_axes) | ||
|
||
# Indices of the vectorized axes of `ans`, sorted by whether the axis name is
# also in `x`. `false` sorts before `true`, so axes that `x` lacks come first.
# NOTE(review): confirm this is the order `Nx.vectorize/2` expects below.
permutation = | ||
vectorized_axes_ans | ||
|> Enum.with_index() | ||
|> Enum.sort_by(fn {axis, _idx} -> axis in vectorized_axes_x end) | ||
|> Enum.map(fn {_, idx} -> idx end) | ||
|
||
num_vectorized_axes = length(permutation) | ||
|
||
# Indices that follow the vectorized axes; the `//1` step makes this an empty
# list when `Nx.rank(res)` is 0.
inner_axes = Enum.to_list(num_vectorized_axes..(num_vectorized_axes + Nx.rank(res) - 1)//1) | ||
|
||
# NOTE(review): presumably `res` has the same vectorized axes as `ans`, so the
# permutation indices line up with the devectorized layout — confirm.
res = | ||
res | ||
|> Nx.devectorize() | ||
|> Nx.transpose(axes: permutation ++ inner_axes) | ||
|> Nx.vectorize(vectorized_axes_x) | ||
|
||
grad_broadcast(x, res, ans) | ||
end | ||
|
||
defp grad_broadcast(%{shape: shape} = x, res, %{shape: shape}), do: {x, res} | ||
|
||
defp unbroadcast(%{shape: shape} = x, res, %{shape: new_shape}) do | ||
defp grad_broadcast(%{shape: shape} = x, res, %{shape: new_shape}) do | ||
axes = Nx.Shape.broadcast_axes(shape, new_shape) | ||
{x, grad_broadcast(x, new_shape, axes, res)} | ||
end | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Revert. devectorize with keep_names.