Skip to content

Commit

Permalink
[better_errors] Refactor more uses of partial_eval.tracing_debug_info (part 1)
Browse files Browse the repository at this point in the history

We replace those uses with api_util.tracing_debug_info, which means we
have to move the call further upstream. But this is better because we
have the actual args and kwargs, and we can do a better job, especially
for `arg_names`.

This is part 1 of a series, for: cond, switch, while, scan, composite,
custom_dce, custom_root, custom_linear_solve, saved_residuals.
  • Loading branch information
gnecula committed Jan 27, 2025
1 parent 381da3c commit c8a6053
Show file tree
Hide file tree
Showing 12 changed files with 147 additions and 85 deletions.
7 changes: 4 additions & 3 deletions jax/_src/ad_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,22 +439,23 @@ def _trace_to_jaxpr(fun: Callable,

### Utilities

def saved_residuals(f, *args, **kwargs) -> list[tuple[core.AbstractValue, str]]:
def saved_residuals(f: Callable,
*args, **kwargs) -> list[tuple[core.AbstractValue, str]]:
in_leaves, in_tree = tree_flatten((args, kwargs))

def f_(*args):
args, kwargs = tree_unflatten(in_tree, args)
return f(*args, **kwargs)

debug_info = api_util.tracing_debug_info("saved_residuals", f, args, kwargs)
out = api.make_jaxpr(lambda *args: api.linearize(f_, *args)[1],
return_shape=True)(*in_leaves)
assert isinstance(out, tuple)
jaxpr_, out_shape = out
jaxpr = jaxpr_.jaxpr
out_tree = lambda: tree_structure(out_shape)
assert len(jaxpr.invars) == len(in_leaves)
dbg = pe.tracing_debug_info(f, in_tree, out_tree, True, "saved_residuals")
return _saved_residuals(jaxpr, dbg.arg_names)
return _saved_residuals(jaxpr, debug_info.arg_names)

def _saved_residuals(jaxpr: core.Jaxpr,
arg_names: tuple[str | None, ...]) -> list[tuple[core.AbstractValue, str]]:
Expand Down
3 changes: 2 additions & 1 deletion jax/_src/api_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ def apply_flat_fun(fun, io_tree, *py_args):
return tree_unflatten(out_tree, ans)

@lu.transformation_with_aux2
def flatten_fun_nokwargs(f, store, in_tree, *args_flat):
def flatten_fun_nokwargs(f: Callable, store: lu.Store,
in_tree: PyTreeDef, *args_flat):
py_args = tree_unflatten(in_tree, args_flat)
ans = f(*py_args)
ans, out_tree = tree_flatten(ans)
Expand Down
7 changes: 5 additions & 2 deletions jax/_src/custom_batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from jax._src import traceback_util
from jax._src import tree_util
from jax._src import util
from jax._src import api_util
from jax._src.api_util import flatten_fun_nokwargs, resolve_kwargs
from jax._src.interpreters import ad
from jax._src.interpreters import batching
Expand Down Expand Up @@ -147,10 +148,12 @@ def __call__(self, *args, **kwargs):
raise AttributeError(
f"No batching rule defined for custom_vmap function {fun_name} "
"using def_vmap.")
debug = api_util.tracing_debug_info("custom_vmap", self.fun, args, {})
args_flat, in_tree = tree_flatten(args)
flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(self.fun), in_tree)
flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(self.fun,
debug_info=debug),
in_tree)
in_avals = [core.get_aval(x) for x in args_flat]
debug = pe.tracing_debug_info(self.fun, in_tree, out_tree, False, "custom_vmap")
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
in_tree = treedef_tuple((tree_structure(consts), in_tree))
Expand Down
9 changes: 4 additions & 5 deletions jax/_src/custom_dce.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,14 +127,16 @@ def __call__(self, *args, **kwargs):
"def_dce."
)
rule_name = util.fun_name(self.dce_rule)
debug = api_util.tracing_debug_info("custom_dce", self.fun,
args, {}, static_argnums=self.static_argnums)
args = api_util.resolve_kwargs(self.fun, args, kwargs)
if self.static_argnums:
static_argnums = set(self.static_argnums)
for i in static_argnums:
check_for_tracers(args[i])
dyn_argnums = [i for i in range(len(args)) if i not in static_argnums]
fun, dyn_args = api_util.argnums_partial(
lu.wrap_init(self.fun),
lu.wrap_init(self.fun, debug_info=debug),
dyn_argnums,
args,
require_static_args_hashable=False,
Expand All @@ -144,7 +146,7 @@ def __call__(self, *args, **kwargs):
lu.wrap_init(self.dce_rule), static_args
)
else:
fun = lu.wrap_init(self.fun)
fun = lu.wrap_init(self.fun, debug_info=debug)
dce_rule = lu.wrap_init(self.dce_rule)
dyn_args = args

Expand Down Expand Up @@ -188,9 +190,6 @@ def dce_jaxpr_thunk(

return core.ClosedJaxpr(dce_jaxpr, dce_consts), used_ins

debug = pe.tracing_debug_info(
self.fun, in_tree, out_tree, False, "custom_dce"
)
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
closed_call = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr))
out_avals = closed_call.out_avals
Expand Down
7 changes: 4 additions & 3 deletions jax/_src/interpreters/batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from jax._src.core import Trace, Tracer, TraceTag, AxisName
from jax._src.interpreters import partial_eval as pe
from jax._src.tree_util import (tree_unflatten, tree_flatten,
register_pytree_node)
register_pytree_node, PyTreeDef)
from jax._src.typing import Array
from jax._src.util import (unzip2, safe_map, safe_zip, split_list,
canonicalize_axis, moveaxis, as_hashable_function,
Expand Down Expand Up @@ -328,7 +328,8 @@ def is_vmappable(x: Any) -> bool:
return type(x) is Jumble or type(x) in vmappables

@lu.transformation_with_aux2
def flatten_fun_for_vmap(f, store, in_tree, *args_flat):
def flatten_fun_for_vmap(f: Callable,
store: lu.Store, in_tree: PyTreeDef, *args_flat):
py_args, py_kwargs = tree_unflatten(in_tree, args_flat)
ans = f(*py_args, **py_kwargs)
ans, out_tree = tree_flatten(ans, is_leaf=is_vmappable)
Expand Down Expand Up @@ -591,7 +592,7 @@ def _batch_outer(f, axis_data, in_dims, *in_vals):
return outs

@lu.transformation2
def _batch_inner(f, axis_data, out_dim_dests, tag, in_dims, *in_vals):
def _batch_inner(f: Callable, axis_data, out_dim_dests, tag, in_dims, *in_vals):
in_dims = in_dims() if callable(in_dims) else in_dims
with core.take_current_trace() as parent_trace:
trace = BatchTrace(parent_trace, tag, axis_data)
Expand Down
45 changes: 27 additions & 18 deletions jax/_src/lax/control_flow/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from functools import partial
from typing import Any

from jax._src import api_util
from jax._src import core
from jax._src import linear_util as lu
from jax._src.lax import lax
Expand All @@ -28,9 +29,8 @@
from jax._src import state
from jax._src import util
from jax._src.util import weakref_lru_cache, safe_map, partition_list
from jax.api_util import flatten_fun_nokwargs
from jax._src.interpreters import partial_eval as pe
from jax.tree_util import tree_map, tree_unflatten, keystr
from jax.tree_util import tree_map, tree_unflatten, keystr, PyTreeDef
from jax._src.tree_util import equality_errors_pytreedef

map, unsafe_map = safe_map, map
Expand All @@ -50,41 +50,50 @@ def _typecheck_param(prim, param, name, msg_required, pred):
raise core.JaxprTypeError(msg)

@weakref_lru_cache
def _initial_style_open_jaxpr(fun: Callable, in_tree, in_avals,
primitive_name: str | None = None):
wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
debug = pe.tracing_debug_info(fun, in_tree, out_tree, False,
primitive_name or "<unknown>")
def _initial_style_open_jaxpr(fun: Callable,
in_tree: PyTreeDef,
in_avals: Sequence[core.AbstractValue],
debug_info: api_util.TracingDebugInfo):
wrapped_fun, out_tree = api_util.flatten_fun_nokwargs(
lu.wrap_init(fun, debug_info=debug_info),
in_tree)
jaxpr, _, consts, attrs_tracked = pe.trace_to_jaxpr_dynamic(
wrapped_fun, in_avals, debug)
wrapped_fun, in_avals, debug_info)
return jaxpr, consts, out_tree(), attrs_tracked

@weakref_lru_cache
def _initial_style_jaxpr(fun: Callable, in_tree, in_avals,
primitive_name: str | None = None):
def _initial_style_jaxpr(fun: Callable,
in_tree: PyTreeDef,
in_avals: Sequence[core.AbstractValue],
debug_info: api_util.TracingDebugInfo):
assert debug_info is not None
jaxpr, consts, out_tree, () = _initial_style_open_jaxpr(
fun, in_tree, in_avals, primitive_name)
fun, in_tree, in_avals, debug_info)
closed_jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr))
return closed_jaxpr, consts, out_tree

def _initial_style_jaxpr_attrs(fun: Callable, in_tree, in_avals,
primitive_name: str | None = None):
def _initial_style_jaxpr_attrs(fun: Callable,
in_tree: PyTreeDef,
in_avals: Sequence[core.AbstractValue],
debug_info: api_util.TracingDebugInfo):
jaxpr, consts, out_tree, attrs_tracked = _initial_style_open_jaxpr(
fun, in_tree, in_avals, primitive_name)
fun, in_tree, in_avals, debug_info)
closed_jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr))
return closed_jaxpr, consts, out_tree, attrs_tracked

def _initial_style_jaxprs_with_common_consts(
funs: Sequence[Callable], in_tree, in_avals, primitive_name: str):
funs: Sequence[Callable],
in_tree: PyTreeDef, in_avals: Sequence[core.AbstractValue],
debug_infos: Sequence[api_util.TracingDebugInfo]):
# When staging the branches of a conditional into jaxprs, constants are
# extracted from each branch and converted to jaxpr arguments. To use the
# staged jaxprs as the branches to a conditional *primitive*, we need for
# their (input) signatures to match. This function "joins" the staged jaxprs:
# for each one, it makes another that accepts *all* constants, but only uses
# those that it needs (dropping the rest).

jaxpr_data = [_initial_style_open_jaxpr(fn, in_tree, in_avals, primitive_name)
for fn in funs]
jaxpr_data = [_initial_style_open_jaxpr(fn, in_tree, in_avals,
debug_info)
for fn, debug_info in zip(funs, debug_infos)]
if not jaxpr_data:
return [], [], []

Expand Down
22 changes: 12 additions & 10 deletions jax/_src/lax/control_flow/conditionals.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@

from jax.tree_util import tree_flatten, tree_unflatten
from jax._src import ad_util
from jax._src.api_util import (
_check_no_aliased_ref_args, _check_no_aliased_closed_over_refs)
from jax._src import api_util
from jax._src import config
from jax._src import core
from jax._src import dispatch
Expand Down Expand Up @@ -135,17 +134,18 @@ def switch(index, branches, *operands):
if (config.disable_jit.value and core.is_concrete(index)):
return branches[int(index)](*operands)

dbgs = [api_util.tracing_debug_info("switch", branch, operands, {})
for branch in branches]
ops, ops_tree = tree_flatten(operands)
ops_avals = tuple(map(core.get_aval, ops))

if config.mutable_array_checks.value:
dbg = pe.tracing_debug_info(branches[0], ops_tree, None, False, 'switch') # type: ignore
_check_no_aliased_ref_args(dbg, ops_avals, ops)
api_util._check_no_aliased_ref_args(dbgs[0], ops_avals, ops)

jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
branches, ops_tree, ops_avals, primitive_name='switch')
branches, ops_tree, ops_avals, dbgs)
if config.mutable_array_checks.value:
_check_no_aliased_closed_over_refs(dbg, (*jaxprs[0].consts, *consts), ops)
api_util._check_no_aliased_closed_over_refs(dbgs[0], (*jaxprs[0].consts, *consts), ops)
for i, (out_tree, jaxpr) in enumerate(zip(out_trees[1:], jaxprs[1:])):
_check_tree_and_avals("branch 0 output",
out_trees[0], jaxprs[0].out_avals,
Expand Down Expand Up @@ -237,14 +237,16 @@ def cond(pred, true_fun, false_fun, *operands):
ops, ops_tree = tree_flatten(operands)
ops_avals = tuple(map(core.get_aval, ops))

dbg_true_fun = api_util.tracing_debug_info("cond", true_fun, operands, {})
if config.mutable_array_checks.value:
dbg = pe.tracing_debug_info(true_fun, ops_tree, None, False, 'cond') # type: ignore
_check_no_aliased_ref_args(dbg, ops_avals, ops)
api_util._check_no_aliased_ref_args(dbg_true_fun, ops_avals, ops)
dbg_false_fun = api_util.tracing_debug_info("cond", false_fun, operands, {})
jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
(true_fun, false_fun), ops_tree, ops_avals, 'cond')
(true_fun, false_fun), ops_tree, ops_avals,
[dbg_true_fun, dbg_false_fun])
true_jaxpr, false_jaxpr = jaxprs
if config.mutable_array_checks.value:
_check_no_aliased_closed_over_refs(dbg, (*true_jaxpr.consts, *consts), ops)
api_util._check_no_aliased_closed_over_refs(dbg_true_fun, (*true_jaxpr.consts, *consts), ops)

out_tree, false_out_tree = out_trees
if any(isinstance(out_aval, AbstractRef) for out_aval in
Expand Down
8 changes: 4 additions & 4 deletions jax/_src/lax/control_flow/for_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from typing import Any, Generic, TypeVar

from jax import lax
from jax.api_util import flatten_fun_nokwargs
from jax._src import api_util
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
Expand Down Expand Up @@ -73,7 +73,7 @@ class Ref(Generic[T]): pass
def _trace_to_jaxpr_with_refs(f, state_tree: PyTreeDef,
state_avals: Sequence[core.AbstractValue]
) -> tuple[core.Jaxpr, list[Any], PyTreeDef]:
f, out_tree_thunk = flatten_fun_nokwargs(
f, out_tree_thunk = api_util.flatten_fun_nokwargs(
lu.wrap_init(f), treedef_tuple((tree_structure(0), state_tree)))
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
f, state_avals)
Expand Down Expand Up @@ -195,10 +195,10 @@ def scan(f: Callable[[Carry, X], tuple[Carry, Y]],
def _create_jaxpr(init):
init_flat = tree_leaves(init)
_, in_tree = tree_flatten((init, xs))

dbg = api_util.tracing_debug_info("scan", f, (init, xs), {})
carry_avals = tuple(map(core.get_aval, init_flat))
jaxpr, _, out_tree = _initial_style_jaxpr(
f, in_tree, carry_avals + x_avals, "scan")
f, in_tree, carry_avals + x_avals, dbg)
return jaxpr, out_tree
jaxpr, out_tree = _create_jaxpr(init)
_, ys_avals = tree_unflatten(out_tree, jaxpr.out_avals)
Expand Down
14 changes: 8 additions & 6 deletions jax/_src/lax/control_flow/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,21 +272,21 @@ def scan(f, init, xs, length=None):

xs_avals = [core.get_aval(x) for x in xs_flat]
x_avals = [core.mapped_aval(length, 0, aval) for aval in xs_avals]
dbg_body = api_util.tracing_debug_info("scan", f, (init, xs), {})

if config.mutable_array_checks.value:
in_flat, in_tree = tree_flatten((init, xs))
dbg = pe.tracing_debug_info(f, in_tree, None, False, 'scan') # type: ignore
in_avals = tuple(_map(core.get_aval, in_flat))
_check_no_aliased_ref_args(dbg, in_avals, in_flat)
_check_no_aliased_ref_args(dbg_body, in_avals, in_flat)

def _create_jaxpr(init):
init_flat, init_tree = tree_flatten(init)
in_flat, in_tree = tree_flatten((init, xs))
carry_avals = tuple(_map(core.get_aval, init_flat))
jaxpr, consts, out_tree, attrs_tracked = _initial_style_jaxpr_attrs(
f, in_tree, (*carry_avals, *x_avals), "scan")
f, in_tree, (*carry_avals, *x_avals), debug_info=dbg_body)
if config.mutable_array_checks.value:
_check_no_aliased_closed_over_refs(dbg, (*jaxpr.consts, *consts), in_flat)
_check_no_aliased_closed_over_refs(dbg_body, (*jaxpr.consts, *consts), in_flat)
out_tree_children = out_tree.children()
if len(out_tree_children) != 2:
msg = "scan body output must be a pair, got {}."
Expand Down Expand Up @@ -1355,10 +1355,12 @@ def while_loop(cond_fun, body_fun, init_val):
def _create_jaxpr(init_val):
init_vals, in_tree = tree_flatten((init_val,))
init_avals = tuple(_map(core.get_aval, init_vals))
cond_dbg = api_util.tracing_debug_info("while_cond", cond_fun, (init_val,), {})
cond_jaxpr, cond_consts, cond_tree = _initial_style_jaxpr(
cond_fun, in_tree, init_avals, "while_cond")
cond_fun, in_tree, init_avals, cond_dbg)
body_dbg = api_util.tracing_debug_info("while_body", body_fun, (init_val,), {})
body_jaxpr, body_consts, body_tree = _initial_style_jaxpr(
body_fun, in_tree, init_avals, "while_loop")
body_fun, in_tree, init_avals, body_dbg)
if not treedef_is_leaf(cond_tree) or len(cond_jaxpr.out_avals) != 1:
msg = "cond_fun must return a boolean scalar, but got pytree {}."
raise TypeError(msg.format(cond_tree))
Expand Down
Loading

0 comments on commit c8a6053

Please sign in to comment.