From 27b51e5772568ae4fce09e73d45c896b9bffab15 Mon Sep 17 00:00:00 2001 From: George Necula Date: Sat, 25 Jan 2025 07:16:25 +0200 Subject: [PATCH] [better_errors] Refactor more uses of partial_eval.tracing_debug_info We replace those uses with api_util.tracing_debug_info, which means we have to move the call further upstream. But this is better because we have the actual args and kwargs, and we can do a better job, especially for `arg_names`. --- jax/_src/ad_checkpoint.py | 7 ++-- jax/_src/api_util.py | 3 +- jax/_src/custom_batching.py | 7 +++- jax/_src/custom_dce.py | 9 ++--- jax/_src/interpreters/batching.py | 7 ++-- jax/_src/interpreters/partial_eval.py | 5 +++ jax/_src/lax/control_flow/common.py | 46 ++++++++++++++--------- jax/_src/lax/control_flow/conditionals.py | 22 ++++++----- jax/_src/lax/control_flow/for_loop.py | 8 ++-- jax/_src/lax/control_flow/loops.py | 9 +++-- jax/_src/lax/lax.py | 11 ++++-- tests/debug_info_test.py | 15 ++++---- 12 files changed, 89 insertions(+), 60 deletions(-) diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py index ae755e5ef92d..64978fa97217 100644 --- a/jax/_src/ad_checkpoint.py +++ b/jax/_src/ad_checkpoint.py @@ -439,13 +439,15 @@ def _trace_to_jaxpr(fun: Callable, ### Utilities -def saved_residuals(f, *args, **kwargs) -> list[tuple[core.AbstractValue, str]]: +def saved_residuals(f: Callable, + *args, **kwargs) -> list[tuple[core.AbstractValue, str]]: in_leaves, in_tree = tree_flatten((args, kwargs)) def f_(*args): args, kwargs = tree_unflatten(in_tree, args) return f(*args, **kwargs) + debug_info = api_util.tracing_debug_info("saved_residuals", f, args, kwargs) out = api.make_jaxpr(lambda *args: api.linearize(f_, *args)[1], return_shape=True)(*in_leaves) assert isinstance(out, tuple) @@ -453,8 +455,7 @@ def f_(*args): jaxpr = jaxpr_.jaxpr out_tree = lambda: tree_structure(out_shape) assert len(jaxpr.invars) == len(in_leaves) - dbg = pe.tracing_debug_info(f, in_tree, out_tree, True, "saved_residuals") - return _saved_residuals(jaxpr, dbg.arg_names) + return _saved_residuals(jaxpr, debug_info.arg_names) def _saved_residuals(jaxpr: core.Jaxpr, arg_names: tuple[str | None, ...]) -> list[tuple[core.AbstractValue, str]]: diff --git a/jax/_src/api_util.py b/jax/_src/api_util.py index aa424ff01dbf..5d7e4b3d8c2f 100644 --- a/jax/_src/api_util.py +++ b/jax/_src/api_util.py @@ -84,7 +84,8 @@ def apply_flat_fun(fun, io_tree, *py_args): return tree_unflatten(out_tree, ans) @lu.transformation_with_aux2 -def flatten_fun_nokwargs(f, store, in_tree, *args_flat): +def flatten_fun_nokwargs(f: Callable, store: lu.Store, + in_tree: PyTreeDef, *args_flat): py_args = tree_unflatten(in_tree, args_flat) ans = f(*py_args) ans, out_tree = tree_flatten(ans) diff --git a/jax/_src/custom_batching.py b/jax/_src/custom_batching.py index f748b28a7fd1..ce7973ba800e 100644 --- a/jax/_src/custom_batching.py +++ b/jax/_src/custom_batching.py @@ -28,6 +28,7 @@ from jax._src import traceback_util from jax._src import tree_util from jax._src import util +from jax._src import api_util from jax._src.api_util import flatten_fun_nokwargs, resolve_kwargs from jax._src.interpreters import ad from jax._src.interpreters import batching @@ -147,10 +148,12 @@ def __call__(self, *args, **kwargs): raise AttributeError( f"No batching rule defined for custom_vmap function {fun_name} " "using def_vmap.") + debug = api_util.tracing_debug_info("custom_vmap", self.fun, args, {}) args_flat, in_tree = tree_flatten(args) - flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(self.fun), in_tree) + flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(self.fun, + debug_info=debug), + in_tree) in_avals = [core.get_aval(x) for x in args_flat] - debug = pe.tracing_debug_info(self.fun, in_tree, out_tree, False, "custom_vmap") jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug) closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ()) in_tree = treedef_tuple((tree_structure(consts), in_tree)) diff --git a/jax/_src/custom_dce.py b/jax/_src/custom_dce.py index 3a827f215979..152fee605a92 100644 --- a/jax/_src/custom_dce.py +++ b/jax/_src/custom_dce.py @@ -127,6 +127,8 @@ def __call__(self, *args, **kwargs): "def_dce." ) rule_name = util.fun_name(self.dce_rule) + debug = api_util.tracing_debug_info("custom_dce", self.fun, + args, {}, static_argnums=self.static_argnums) args = api_util.resolve_kwargs(self.fun, args, kwargs) if self.static_argnums: static_argnums = set(self.static_argnums) @@ -134,7 +136,7 @@ def __call__(self, *args, **kwargs): check_for_tracers(args[i]) dyn_argnums = [i for i in range(len(args)) if i not in static_argnums] fun, dyn_args = api_util.argnums_partial( - lu.wrap_init(self.fun), + lu.wrap_init(self.fun, debug_info=debug), dyn_argnums, args, require_static_args_hashable=False, @@ -144,7 +146,7 @@ def __call__(self, *args, **kwargs): lu.wrap_init(self.dce_rule), static_args ) else: - fun = lu.wrap_init(self.fun) + fun = lu.wrap_init(self.fun, debug_info=debug) dce_rule = lu.wrap_init(self.dce_rule) dyn_args = args @@ -188,9 +190,6 @@ def dce_jaxpr_thunk( return core.ClosedJaxpr(dce_jaxpr, dce_consts), used_ins - debug = pe.tracing_debug_info( - self.fun, in_tree, out_tree, False, "custom_dce" - ) jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug) closed_call = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr)) out_avals = closed_call.out_avals diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index d5fb5f9856a2..84f8c5520caf 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -32,7 +32,7 @@ from jax._src.core import Trace, Tracer, TraceTag, AxisName from jax._src.interpreters import partial_eval as pe from jax._src.tree_util import (tree_unflatten, tree_flatten, - register_pytree_node) + register_pytree_node, PyTreeDef) from jax._src.typing import Array from jax._src.util import (unzip2, safe_map, safe_zip, split_list, canonicalize_axis, moveaxis, as_hashable_function, @@ -328,7 +328,8 @@ def is_vmappable(x: Any) -> bool: return type(x) is Jumble or type(x) in vmappables @lu.transformation_with_aux2 -def flatten_fun_for_vmap(f, store, in_tree, *args_flat): +def flatten_fun_for_vmap(f: Callable, + store: lu.Store, in_tree: PyTreeDef, *args_flat): py_args, py_kwargs = tree_unflatten(in_tree, args_flat) ans = f(*py_args, **py_kwargs) ans, out_tree = tree_flatten(ans, is_leaf=is_vmappable) @@ -591,7 +592,7 @@ def _batch_outer(f, axis_data, in_dims, *in_vals): return outs @lu.transformation2 -def _batch_inner(f, axis_data, out_dim_dests, tag, in_dims, *in_vals): +def _batch_inner(f: Callable, axis_data, out_dim_dests, tag, in_dims, *in_vals): in_dims = in_dims() if callable(in_dims) else in_dims with core.take_current_trace() as parent_trace: trace = BatchTrace(parent_trace, tag, axis_data) diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index dd81d8d4a552..38a5770730f3 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -2119,6 +2119,11 @@ def tracing_debug_info( has_kwargs: bool, traced_for: str ) -> lu.TracingDebugInfo: + # if traced_for not in ("jit", "pmap", "xla_pmap", "checkify", "scan", + # "", "while_cond", "while_loop", + # "custom_linear_solve", "closed_call", + # "pallas_call", "pallas_call index_map", "checkify_pallas"): + # assert False, (traced_for, fn) # DO_NOT_SUBMIT # TODO(necula): we should not need this function, and can use api_util.tracing_debug_info instead # We just have to make sure we grad the debugging information when we have # the unflattened args diff --git a/jax/_src/lax/control_flow/common.py b/jax/_src/lax/control_flow/common.py index 04393ccbf67a..eb1c29d046fa 100644 --- a/jax/_src/lax/control_flow/common.py +++ b/jax/_src/lax/control_flow/common.py @@ -20,6 +20,7 @@ from functools import partial from typing import Any +from jax._src import api_util from jax._src import core from jax._src import linear_util as lu from jax._src.lax import lax @@ -28,9 +29,8 @@ from jax._src import state from jax._src import util from jax._src.util import weakref_lru_cache, safe_map, partition_list -from jax.api_util import flatten_fun_nokwargs from jax._src.interpreters import partial_eval as pe -from jax.tree_util import tree_map, tree_unflatten, keystr +from jax.tree_util import tree_map, tree_unflatten, keystr, PyTreeDef from jax._src.tree_util import equality_errors_pytreedef map, unsafe_map = safe_map, map @@ -50,41 +50,51 @@ def _typecheck_param(prim, param, name, msg_required, pred): raise core.JaxprTypeError(msg) @weakref_lru_cache -def _initial_style_open_jaxpr(fun: Callable, in_tree, in_avals, - primitive_name: str | None = None): - wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree) - debug = pe.tracing_debug_info(fun, in_tree, out_tree, False, - primitive_name or "") +def _initial_style_open_jaxpr(fun: Callable, + in_tree: PyTreeDef, + in_avals: Sequence[core.AbstractValue], + debug_info: api_util.TracingDebugInfo | None = None): + wrapped_fun, out_tree = api_util.flatten_fun_nokwargs( + lu.wrap_init(fun, debug_info=debug_info), + in_tree) jaxpr, _, consts, attrs_tracked = pe.trace_to_jaxpr_dynamic( - wrapped_fun, in_avals, debug) + wrapped_fun, in_avals, debug_info) return jaxpr, consts, out_tree(), attrs_tracked @weakref_lru_cache -def _initial_style_jaxpr(fun: Callable, in_tree, in_avals, - primitive_name: str | None = None): +def _initial_style_jaxpr(fun: Callable, + in_tree: PyTreeDef, + in_avals: Sequence[core.AbstractValue], + # TODO(necula): remove primitive_name + primitive_name: str | None = None, + debug_info: api_util.TracingDebugInfo | None = None): jaxpr, consts, out_tree, () = _initial_style_open_jaxpr( - fun, in_tree, in_avals, primitive_name) + fun, in_tree, in_avals, debug_info) closed_jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr)) return closed_jaxpr, consts, out_tree -def _initial_style_jaxpr_attrs(fun: Callable, in_tree, in_avals, - primitive_name: str | None = None): +def _initial_style_jaxpr_attrs(fun: Callable, + in_tree: PyTreeDef, + in_avals: Sequence[core.AbstractValue], + debug_info: api_util.TracingDebugInfo | None = None): jaxpr, consts, out_tree, attrs_tracked = _initial_style_open_jaxpr( - fun, in_tree, in_avals, primitive_name) + fun, in_tree, in_avals, debug_info) closed_jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr)) return closed_jaxpr, consts, out_tree, attrs_tracked def _initial_style_jaxprs_with_common_consts( - funs: Sequence[Callable], in_tree, in_avals, primitive_name: str): + funs: Sequence[Callable], + in_tree: PyTreeDef, in_avals: Sequence[core.AbstractValue], *, + debug_infos: Sequence[api_util.TracingDebugInfo | None]): # When staging the branches of a conditional into jaxprs, constants are # extracted from each branch and converted to jaxpr arguments. To use the # staged jaxprs as the branches to a conditional *primitive*, we need for # their (input) signatures to match. This function "joins" the staged jaxprs: # for each one, it makes another that accepts *all* constants, but only uses # those that it needs (dropping the rest). - - jaxpr_data = [_initial_style_open_jaxpr(fn, in_tree, in_avals, primitive_name) - for fn in funs] + jaxpr_data = [_initial_style_open_jaxpr(fn, in_tree, in_avals, + debug_info) + for fn, debug_info in zip(funs, debug_infos)] if not jaxpr_data: return [], [], [] diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index 17116e857671..e54a7ea07410 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -25,8 +25,7 @@ from jax.tree_util import tree_flatten, tree_unflatten from jax._src import ad_util -from jax._src.api_util import ( - _check_no_aliased_ref_args, _check_no_aliased_closed_over_refs) +from jax._src import api_util from jax._src import config from jax._src import core from jax._src import dispatch @@ -135,17 +134,18 @@ def switch(index, branches, *operands): if (config.disable_jit.value and core.is_concrete(index)): return branches[int(index)](*operands) + dbgs = [api_util.tracing_debug_info("switch", branch, operands, {}) + for branch in branches] ops, ops_tree = tree_flatten(operands) ops_avals = tuple(map(core.get_aval, ops)) if config.mutable_array_checks.value: - dbg = pe.tracing_debug_info(branches[0], ops_tree, None, False, 'switch') # type: ignore - _check_no_aliased_ref_args(dbg, ops_avals, ops) + api_util._check_no_aliased_ref_args(dbgs[0], ops_avals, ops) jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts( - branches, ops_tree, ops_avals, primitive_name='switch') + branches, ops_tree, ops_avals, debug_infos=dbgs) if config.mutable_array_checks.value: - _check_no_aliased_closed_over_refs(dbg, (*jaxprs[0].consts, *consts), ops) + api_util._check_no_aliased_closed_over_refs(dbgs[0], (*jaxprs[0].consts, *consts), ops) for i, (out_tree, jaxpr) in enumerate(zip(out_trees[1:], jaxprs[1:])): _check_tree_and_avals("branch 0 output", out_trees[0], jaxprs[0].out_avals, @@ -238,13 +238,15 @@ def cond(pred, true_fun, false_fun, *operands): ops_avals = tuple(map(core.get_aval, ops)) if config.mutable_array_checks.value: - dbg = pe.tracing_debug_info(true_fun, ops_tree, None, False, 'cond') # type: ignore - _check_no_aliased_ref_args(dbg, ops_avals, ops) + dbg_true_fun = api_util.tracing_debug_info("cond", true_fun, operands, {}) + api_util._check_no_aliased_ref_args(dbg_true_fun, ops_avals, ops) + dbg_false_fun = api_util.tracing_debug_info("cond", false_fun, operands, {}) jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts( - (true_fun, false_fun), ops_tree, ops_avals, 'cond') + (true_fun, false_fun), ops_tree, ops_avals, + debug_infos=[dbg_true_fun, dbg_false_fun]) true_jaxpr, false_jaxpr = jaxprs if config.mutable_array_checks.value: - _check_no_aliased_closed_over_refs(dbg, (*true_jaxpr.consts, *consts), ops) + api_util._check_no_aliased_closed_over_refs(dbg_true_fun, (*true_jaxpr.consts, *consts), ops) out_tree, false_out_tree = out_trees if any(isinstance(out_aval, AbstractRef) for out_aval in diff --git a/jax/_src/lax/control_flow/for_loop.py b/jax/_src/lax/control_flow/for_loop.py index c1e8782b68ef..f3b8fbce4715 100644 --- a/jax/_src/lax/control_flow/for_loop.py +++ b/jax/_src/lax/control_flow/for_loop.py @@ -21,7 +21,7 @@ from typing import Any, Generic, TypeVar from jax import lax -from jax.api_util import flatten_fun_nokwargs +from jax._src import api_util from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir @@ -73,7 +73,7 @@ class Ref(Generic[T]): pass def _trace_to_jaxpr_with_refs(f, state_tree: PyTreeDef, state_avals: Sequence[core.AbstractValue] ) -> tuple[core.Jaxpr, list[Any], PyTreeDef]: - f, out_tree_thunk = flatten_fun_nokwargs( + f, out_tree_thunk = api_util.flatten_fun_nokwargs( lu.wrap_init(f), treedef_tuple((tree_structure(0), state_tree))) jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic( f, state_avals) @@ -195,10 +195,10 @@ def scan(f: Callable[[Carry, X], tuple[Carry, Y]], def _create_jaxpr(init): init_flat = tree_leaves(init) _, in_tree = tree_flatten((init, xs)) - + dbg = api_util.tracing_debug_info("scan", f, (init, xs), {}) carry_avals = tuple(map(core.get_aval, init_flat)) jaxpr, _, out_tree = _initial_style_jaxpr( - f, in_tree, carry_avals + x_avals, "scan") + f, in_tree, carry_avals + x_avals, "scan", debug_info=dbg) return jaxpr, out_tree jaxpr, out_tree = _create_jaxpr(init) _, ys_avals = tree_unflatten(out_tree, jaxpr.out_avals) diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 5a2a5608d732..2ba6c332aedf 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -283,8 +283,9 @@ def _create_jaxpr(init): init_flat, init_tree = tree_flatten(init) in_flat, in_tree = tree_flatten((init, xs)) carry_avals = tuple(_map(core.get_aval, init_flat)) + dbg = api_util.tracing_debug_info("scan", f, (init, xs), {}) jaxpr, consts, out_tree, attrs_tracked = _initial_style_jaxpr_attrs( - f, in_tree, (*carry_avals, *x_avals), "scan") + f, in_tree, (*carry_avals, *x_avals), debug_info=dbg) if config.mutable_array_checks.value: _check_no_aliased_closed_over_refs(dbg, (*jaxpr.consts, *consts), in_flat) out_tree_children = out_tree.children() @@ -1355,10 +1356,12 @@ def while_loop(cond_fun, body_fun, init_val): def _create_jaxpr(init_val): init_vals, in_tree = tree_flatten((init_val,)) init_avals = tuple(_map(core.get_aval, init_vals)) + cond_dbg = api_util.tracing_debug_info("while_cond", cond_fun, (init_val,), {}) cond_jaxpr, cond_consts, cond_tree = _initial_style_jaxpr( - cond_fun, in_tree, init_avals, "while_cond") + cond_fun, in_tree, init_avals, "while_cond", debug_info=cond_dbg) + body_dbg = api_util.tracing_debug_info("while_body", body_fun, (init_val,), {}) body_jaxpr, body_consts, body_tree = _initial_style_jaxpr( - body_fun, in_tree, init_avals, "while_loop") + body_fun, in_tree, init_avals, "while_body", debug_info=body_dbg) if not treedef_is_leaf(cond_tree) or len(cond_jaxpr.out_avals) != 1: msg = "cond_fun must return a boolean scalar, but got pytree {}." raise TypeError(msg.format(cond_tree)) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 650c1b620cf1..0458cd66e5ae 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -658,9 +658,12 @@ def clamp(min: ArrayLike, x: ArrayLike, max: ArrayLike) -> Array: @weakref_lru_cache -def _trace_composite_to_jaxpr(fun, in_tree, in_avals, name: str): +def _trace_composite_to_jaxpr(fun: Callable, + in_tree: tree_util.PyTreeDef, + in_avals: Sequence[core.AbstractValue], + name: str, + debug_info: api_util.TracingDebugInfo): flat_fun, out_tree = api_util.flatten_fun_nokwargs(lu.wrap_init(fun), in_tree) - debug_info = pe.tracing_debug_info(fun, in_tree, out_tree, False, "composite") jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug_info) if any(isinstance(c, core.Tracer) for c in consts): raise UnexpectedTracerError( @@ -736,10 +739,12 @@ def composite( """ @functools.wraps(decomposition) def _decorator(*args, **kwargs): + debug_info = api_util.tracing_debug_info("composite", decomposition, + args, kwargs) flat_args, in_tree = tree_util.tree_flatten(args) in_avals = tuple(core.get_aval(x) for x in flat_args) closed_jaxpr, out_tree = _trace_composite_to_jaxpr( - partial(decomposition, **kwargs), in_tree, in_avals, name + partial(decomposition, **kwargs), in_tree, in_avals, name, debug_info ) out_flat = composite_p.bind( *flat_args, diff --git a/tests/debug_info_test.py b/tests/debug_info_test.py index c2bd8bab78d5..87395986a3aa 100644 --- a/tests/debug_info_test.py +++ b/tests/debug_info_test.py @@ -770,10 +770,10 @@ def my_branch0(x0): return x0 def my_branch1(x1): leaked_tracers.append(x1) - return x1 + return x1 + 1 def my_branch2(x2): leaked_tracers.append(x2) - return x2 + return x2 + 2 return lax.switch(x, [my_branch0, my_branch1, my_branch2], x) self._check_tracers_and_jaxprs( @@ -884,11 +884,11 @@ def my_body(b): leaked_tracers=leaked_tracers, expected_jaxpr_debug_infos=[ "traced_for=jit, fun=my_f, arg_names=('x',), result_paths=('',)", - # TODO(necula): some Jaxprs without debug info - 'None'], + 'None', # TODO(necula): some missing debug info + ], expected_tracer_debug_infos=[ "traced_for=while_cond, fun=my_cond, arg_names=('a',)", - "traced_for=while_loop, fun=my_body, arg_names=('b',)" + "traced_for=while_body, fun=my_body, arg_names=('b',)", ]) def test_scan(self): @@ -1071,8 +1071,7 @@ def my_f(x): # TODO(necula): some Jaxprs without debug info 'None'], expected_tracer_debug_infos=[ - # TODO(necula): bad arg_names - "traced_for=custom_dce, fun=my_g, arg_names=('args[0]',)" + "traced_for=custom_dce, fun=my_g, arg_names=('x',)" ]) def test_custom_linear_solve_complex(self): @@ -1109,7 +1108,7 @@ def tr_solve(matvec, x): expected_tracer_debug_infos=[ # TODO(necula): we don't see any leaks from tr_solve? "None", # TODO(necula): there are missing debug info - re.compile(r"traced_for=custom_linear_solve, fun=f at .*control_flow/solves.py:.*, arg_names=\('x',\)"), + # re.compile(r"traced_for=custom_linear_solve, fun=f at .*control_flow/solves.py:.*, arg_names=\('x',\)"), ]) def test_custom_root_errors(self):