Skip to content

Commit

Permalink
[better_errors] Refactor more uses of pe.tracing_debug_info (part 3)
Browse files Browse the repository at this point in the history
We replace uses of `pe.tracing_debug_info` with with `api_util.tracing_debug_info`,
which uses the actual args and kwargs, instead of `in_tree` to manufacture fake
args and kwargs. This ends up being more accurate, especially for `arg_names`;
see changes in debug_info_tests.py.
This means that we have to construct the debug info further upstream, before
flattening args. This will later help populate debug info in `WrappedFun` and
`Jaxpr`.

This is part 3 of a series (following jax-ml#26097, jax-ml#26099) for jit, pmap, checkify, custom_dce.
  • Loading branch information
gnecula committed Jan 26, 2025
1 parent cedb516 commit 4c3bbcf
Show file tree
Hide file tree
Showing 9 changed files with 49 additions and 35 deletions.
11 changes: 6 additions & 5 deletions jax/_src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
flatten_axes, donation_vector,
rebase_donate_argnums, _ensure_index, _ensure_index_tuple,
apply_flat_fun_nokwargs, check_callable, tracing_debug_info,
result_paths, flat_out_axes, debug_info_final)
result_paths, flat_out_axes)
from jax._src.lax import lax as lax_internal
from jax._src.lib import jax_jit
from jax._src.lib import xla_client as xc
Expand Down Expand Up @@ -1392,15 +1392,15 @@ def _get_global_axis_size(local_axis_size: int, in_devices, backend_name: str,
return global_axis_size


def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
def _prepare_pmap(fun: Callable, in_axes, out_axes, static_broadcasted_tuple,
donate_tuple, in_devices, backend_name,
axis_size, args, kwargs):
if in_devices is not None and len(in_devices) == 0:
raise ValueError("'devices' argument to pmap must be non-empty, or None.")

dbg = tracing_debug_info(
'pmap', fun, args, kwargs,
static_argnums=static_broadcasted_tuple)
"pmap", fun, args, kwargs,
static_argnums=static_broadcasted_tuple)

f = lu.wrap_init(fun)
if static_broadcasted_tuple:
Expand Down Expand Up @@ -1450,9 +1450,10 @@ def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
local_axis_size = _mapped_axis_size(fun, in_tree, args, in_axes_flat, "pmap")

f, res_paths = result_paths(f)
dbg = dbg.add_result_paths(res_paths)
f = lu.add_debug_info(f, dbg)
f, out_axes_thunk = flat_out_axes(f, out_axes)
flat_fun, out_tree = flatten_fun(f, in_tree)
flat_fun = debug_info_final(flat_fun, dbg, res_paths)

is_explicit_global_axis_size = axis_size is not None
global_axis_size = _get_global_axis_size(local_axis_size, in_devices,
Expand Down
4 changes: 3 additions & 1 deletion jax/_src/api_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -705,14 +705,16 @@ def result_paths(_fun, _store, *args, **kwargs):
_store.store([keystr(path) for path, _ in generate_key_paths(ans)])
return ans

# TODO(necula): simplify this function, all it needs is to add the trace_debug to the Jaxpr
def add_jaxpr_debug_info(jaxpr: core.Jaxpr,
trace_debug: TracingDebugInfo | None,
result_paths: tuple[str, ...] | None = None,
) -> core.Jaxpr:
"""Add debug info to jaxpr, given trace-time debug info and result paths."""
if trace_debug is None:
return jaxpr
assert (result_paths is not None) ^ (trace_debug.result_paths_thunk is not None)

# assert (result_paths is not None) ^ (trace_debug.result_paths_thunk is not None)
if result_paths is None:
result_paths = trace_debug.result_paths_thunk() # type: ignore
debug_info = core.JaxprDebugInfo(
Expand Down
8 changes: 5 additions & 3 deletions jax/_src/checkify.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

from jax.experimental import shard_map
from jax._src import api
from jax._src import api_util
from jax._src import ad_checkpoint
from jax._src import linear_util as lu
from jax._src import config
Expand All @@ -39,7 +40,6 @@
from jax._src import traceback_util
from jax._src import tree_util as jtu
from jax._src.ad_util import SymbolicZero
from jax._src.api_util import flatten_fun
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
Expand Down Expand Up @@ -1202,8 +1202,10 @@ def checked_fun(*args, **kwargs):
in_tree = jtu.tree_structure(((), {}))
closed_f = lambda: f(*args, **kwargs)
# stage:
fun_, out_tree = flatten_fun(lu.wrap_init(closed_f), in_tree)
debug = pe.tracing_debug_info(closed_f, in_tree, out_tree, False, 'checkify')
debug = api_util.tracing_debug_info("checkify", f, args, kwargs)
fun_, out_tree = api_util.flatten_fun(lu.wrap_init(closed_f,
debug_info=debug),
in_tree)
jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(fun_, (), debug)
jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr_))
# checkify:
Expand Down
13 changes: 7 additions & 6 deletions jax/_src/custom_dce.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,11 @@ def __call__(self, *args, **kwargs):
)
rule_name = util.fun_name(self.dce_rule)
debug = api_util.tracing_debug_info("custom_dce", self.fun,
args, {}, static_argnums=self.static_argnums)
args, {},
static_argnums=self.static_argnums)
debug_rule = api_util.tracing_debug_info("custom_dce_rule", self.dce_rule,
args, {},
static_argnums=self.static_argnums)
args = api_util.resolve_kwargs(self.fun, args, kwargs)
if self.static_argnums:
static_argnums = set(self.static_argnums)
Expand Down Expand Up @@ -171,11 +175,8 @@ def dce_jaxpr_thunk(
out_avals,
)
assert self.dce_rule is not None
debug = pe.tracing_debug_info(
self.dce_rule, in_tree, rule_out_tree, False, "custom_dce_rule"
)
dce_jaxpr, _, dce_consts, () = pe.trace_to_jaxpr_dynamic(
flat_rule, in_avals, debug
flat_rule, in_avals, debug_rule
)

# This second round of DCE is used to work out which inputs are actually
Expand All @@ -190,7 +191,7 @@ def dce_jaxpr_thunk(

return core.ClosedJaxpr(dce_jaxpr, dce_consts), used_ins

jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug_rule)
closed_call = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr))
out_avals = closed_call.out_avals
out_flat = custom_dce_p.bind(
Expand Down
6 changes: 3 additions & 3 deletions jax/_src/interpreters/partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -1940,7 +1940,8 @@ def process_call(self, call_primitive, f: lu.WrappedFun,
self.frame.add_eqn(eqn)
return [t for t, (_, keep) in zip(out_tracers, out_type) if keep]

def process_map(self, map_primitive, f, tracers, params):
def process_map(self, map_primitive, f: lu.WrappedFun,
tracers: Sequence[core.Tracer], params):
tracers = map(self.to_jaxpr_tracer, tracers)
in_avals = [t.aval for t in tracers]
axis_name, axis_size = params['axis_name'], params['axis_size']
Expand All @@ -1949,8 +1950,7 @@ def process_map(self, map_primitive, f, tracers, params):
for a, in_axis in zip(in_avals, params['in_axes'])]
with core.extend_axis_env_nd([(axis_name, params["global_axis_size"])]):
jaxpr, reduced_out_avals, consts, () = trace_to_jaxpr_dynamic(
f, reduced_in_avals,
debug_info=tracing_debug_info_final(f, map_primitive.name))
f, reduced_in_avals, f.debug_info)
ordered_effects = effects.ordered_effects.filter_in(jaxpr.effects)
if ordered_effects:
raise ValueError("Ordered effects not supported for "
Expand Down
7 changes: 4 additions & 3 deletions jax/_src/interpreters/pxla.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,6 +652,7 @@ class ParallelCallableInfo:
in_axes: Iterable[int | None]
out_axes_thunk: Callable[[], Sequence[int | None]]
avals: Sequence[core.AbstractValue]
debug_info: api_util.TracingDebugInfo

@cached_property
def local_devices(self):
Expand Down Expand Up @@ -722,8 +723,8 @@ def stage_parallel_callable(
"Finished tracing + transforming {fun_name} for pmap in {elapsed_time} sec",
fun_name=fun.__name__, event=dispatch.JAXPR_TRACE_EVENT):
jaxpr, out_sharded_avals, consts, _ = pe.trace_to_jaxpr_dynamic(
fun, sharded_avals, pe.tracing_debug_info_final(fun, "pmap"))
jaxpr = api_util.add_jaxpr_debug_info(jaxpr, orig_fun.debug_info)
fun, sharded_avals, pci.debug_info)
jaxpr = api_util.add_jaxpr_debug_info(jaxpr, pci.debug_info)

assert len(out_sharded_avals) == len(pci.out_axes), (
len(out_sharded_avals), len(pci.out_axes))
Expand Down Expand Up @@ -757,7 +758,7 @@ def get_pmap_jaxpr(

pci = ParallelCallableInfo(
name, backend, axis_name, axis_size, global_axis_size, devices,
in_axes, out_axes_thunk, avals)
in_axes, out_axes_thunk, avals, fun.debug_info)
with core.extend_axis_env_nd([(axis_name, axis_size)]):
jaxpr, consts, replicas, shards = stage_parallel_callable(pci, fun)
jaxpr = core.remove_named_axis_effects(jaxpr, {axis_name})
Expand Down
12 changes: 9 additions & 3 deletions jax/_src/linear_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def trans1(static_arg, *dynamic_args, **kwargs):
from jax._src import config
from jax._src import core
from jax._src import traceback_util
from jax._src.util import curry, cache_clearing_funs
from jax._src.util import curry, cache_clearing_funs, HashableFunction


traceback_util.register_exclusion(__file__)
Expand Down Expand Up @@ -175,11 +175,11 @@ def wrap(self, gen, gen_static_args,
if out_store is None:
return WrappedFun(self.f, partial(gen, self.f_transformed, *gen_static_args),
((gen, gen_static_args),) + self.transforms,
(out_store,) + self.stores, self.params, None, None)
(out_store,) + self.stores, self.params, None, self.debug_info)
else:
return WrappedFun(self.f, partial(gen, self.f_transformed, out_store, *gen_static_args),
((gen, gen_static_args),) + self.transforms,
(out_store,) + self.stores, self.params, None, None)
(out_store,) + self.stores, self.params, None, self.debug_info)

def populate_stores(self, stores):
"""Copy the values from the `stores` into `self.stores`."""
Expand Down Expand Up @@ -282,6 +282,12 @@ def from_jaxpr(cls, jaxpr: core.ClosedJaxpr) -> TracingDebugInfo | None:
jaxpr_dbg.arg_names,
lambda: jaxpr_dbg.result_paths)

def add_result_paths(self, result_paths_thunk: Callable[[], tuple[str, ...]]
) -> TracingDebugInfo:
assert self.result_paths_thunk is None
return self._replace(result_paths_thunk=HashableFunction(result_paths_thunk,
closure=()))

def wrap_init(f: Callable, params=None, *,
debug_info: TracingDebugInfo | None = None) -> WrappedFun:
"""Wraps function `f` as a `WrappedFun`, suitable for transformation."""
Expand Down
10 changes: 5 additions & 5 deletions jax/_src/pjit.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,7 @@ def _infer_params_impl(

f = lu.wrap_init(fun)
f, res_paths = result_paths(f)
dbg = dbg and dbg.add_result_paths(result_paths_thunk=res_paths)
f, dyn_args = argnums_partial_except(f, ji.static_argnums, args, allow_invalid=True)
del args

Expand Down Expand Up @@ -1293,7 +1294,7 @@ def _create_pjit_jaxpr(
in_type: core.InputType | Sequence[core.AbstractValue],
attr_data: int,
debug_info: lu.TracingDebugInfo,
out_paths: Callable,
result_paths: Callable,
ignored_inline: IgnoreKey
) -> tuple[core.ClosedJaxpr, list[Any], list[core.AbstractValue],
list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]:
Expand All @@ -1305,19 +1306,18 @@ def _create_pjit_jaxpr(
with dispatch.log_elapsed_time(
"Finished tracing + transforming {fun_name} for pjit in {elapsed_time:.9f} sec",
fun_name=fun.__name__, event=dispatch.JAXPR_TRACE_EVENT):
pe_debug = debug_info and pe.tracing_debug_info_final(fun, debug_info.traced_for)
if config.dynamic_shapes.value:
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic2(
lu.annotate(fun, cast(core.InputType, in_type)), debug_info=pe_debug)
lu.annotate(fun, cast(core.InputType, in_type)), debug_info=debug_info)
attrs_tracked = []
else:
jaxpr, global_out_avals, consts, attrs_tracked = pe.trace_to_jaxpr_dynamic(
fun, in_type, debug_info=pe_debug)
fun, in_type, debug_info=debug_info)
# assert attr_data is sentinel or attr_data matches attrs_tracked

# TODO(dougalm,mattjj): enable debug info with attrs_tracked
if not config.dynamic_shapes.value and not attrs_tracked:
jaxpr = add_jaxpr_debug_info(jaxpr, debug_info, out_paths())
jaxpr = add_jaxpr_debug_info(jaxpr, debug_info, result_paths())

if config.debug_key_reuse.value:
# Import here to avoid circular imports
Expand Down
13 changes: 7 additions & 6 deletions tests/debug_info_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,8 +630,7 @@ def my_g(a, d=1):
],
expected_tracer_debug_infos=[
"traced_for=jit, fun=my_g, arg_names=('a',)",
# TODO(necula): bad arg name
"traced_for=jit, fun=my_f, arg_names=('args[0]',)"
"traced_for=jit, fun=my_f, arg_names=('a',)"
])


Expand Down Expand Up @@ -1044,7 +1043,7 @@ def my_f(my_x):
"None", # TODO(necula): missing tracer debug info
],
expected_tracer_debug_infos=[
"traced_for=xla_pmap, fun=my_f, arg_names=('my_x',)",
"traced_for=pmap, fun=my_f, arg_names=('my_x',)",
],
check_lowering=False, # TODO(necula): warning during lowering
)
Expand Down Expand Up @@ -1076,7 +1075,8 @@ def my_f(x):
# TODO(necula): some Jaxprs without debug info
'None'],
expected_tracer_debug_infos=[
"traced_for=custom_dce, fun=my_g, arg_names=('x',)"
# TODO(necula): bad arg_names
"traced_for=custom_dce_rule, fun=my_g_dce, arg_names=('args[1]',)",
])

def test_custom_dce_consts(self):
Expand All @@ -1087,7 +1087,7 @@ def my_f(x):
return np.eye(1) * jnp.sin(x), jnp.cos(x)

@my_f.def_dce
def rule(used_outs, x):
def my_rule(used_outs, x):
leaked_tracers.append(x)
return (
np.full((1, 1), 2.0) * jnp.exp(x) if used_outs[0] else None,
Expand All @@ -1103,7 +1103,8 @@ def rule(used_outs, x):
# TODO(necula): some Jaxprs without debug info
'None'],
expected_tracer_debug_infos=[
"traced_for=custom_dce, fun=my_f, arg_names=('x',)"
# TODO(necula): bad arg_names
"traced_for=custom_dce_rule, fun=my_rule, arg_names=('args[0]',)",
])

def test_custom_linear_solve_complex(self):
Expand Down

0 comments on commit 4c3bbcf

Please sign in to comment.