Skip to content

Commit

Permalink
chore: revert some changes
Browse files Browse the repository at this point in the history
  • Loading branch information
polvalente committed Sep 16, 2024
1 parent 2f7c5f1 commit d87ffa1
Showing 1 changed file with 6 additions and 16 deletions.
22 changes: 6 additions & 16 deletions nx/lib/nx/defn/grad.ex
Original file line number Diff line number Diff line change
Expand Up @@ -12,26 +12,17 @@ defmodule Nx.Defn.Grad do
{node, ids}
end)

expr = fun.(to_grad)

transformed_expr =
expr |> transform.() |> validate_expr!()

# |> Nx.devectorize(keep_names: false)

# to_grad =
# Composite.traverse(to_grad, fn node ->
# [_expr, node] = Nx.broadcast_vectors([expr, node])
# # ids = Map.put(ids, node.data.id, :stop)
# # {node, ids}
# node
# end)

# Collect all IDs in the function environment and mark
# them as stop grads. This is an optimization to avoid
# traversing trees when not necessary.
{:env, env} = Function.info(fun, :env)
ids = stop_grads(env, ids)

expr = fun.(to_grad)

transformed_expr =
expr |> transform.() |> validate_expr!()

{parents, nodes} = parents_tree(transformed_expr, ids)

to_grad_ids = {to_grad, ids}
Expand All @@ -58,7 +49,6 @@ defmodule Nx.Defn.Grad do
defp constant(float, shape) do
case shape do
%T{vectorized_axes: [_ | _]} = t ->
# [_expr, t] = Nx.broadcast_vectors([shape, float], align_ranks: false)
Expr.tensor(Nx.fill(t, float, type: :f32))

t ->
Expand Down

0 comments on commit d87ffa1

Please sign in to comment.