diff --git a/jax/_src/api.py b/jax/_src/api.py
index e4ead66236b9..cb621a7c1cd5 100644
--- a/jax/_src/api.py
+++ b/jax/_src/api.py
@@ -61,7 +61,7 @@
     flatten_axes, donation_vector, rebase_donate_argnums,
     _ensure_index, _ensure_index_tuple, apply_flat_fun_nokwargs,
     check_callable, tracing_debug_info,
-    result_paths, flat_out_axes, debug_info_final)
+    result_paths, flat_out_axes)
 from jax._src.lax import lax as lax_internal
 from jax._src.lib import jax_jit
 from jax._src.lib import xla_client as xc
@@ -1392,15 +1392,15 @@ def _get_global_axis_size(local_axis_size: int, in_devices, backend_name: str,
   return global_axis_size
 
 
-def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
+def _prepare_pmap(fun: Callable, in_axes, out_axes, static_broadcasted_tuple,
                   donate_tuple, in_devices, backend_name,
                   axis_size, args, kwargs):
   if in_devices is not None and len(in_devices) == 0:
     raise ValueError("'devices' argument to pmap must be non-empty, or None.")
 
   dbg = tracing_debug_info(
-      'pmap', fun, args, kwargs,
-      static_argnums=static_broadcasted_tuple)
+      "pmap", fun, args, kwargs,
+      static_argnums=static_broadcasted_tuple)
 
   f = lu.wrap_init(fun)
   if static_broadcasted_tuple:
@@ -1450,9 +1450,10 @@ def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
   local_axis_size = _mapped_axis_size(fun, in_tree, args, in_axes_flat, "pmap")
 
   f, res_paths = result_paths(f)
+  dbg = dbg.add_result_paths(res_paths)
+  f = lu.add_debug_info(f, dbg)
   f, out_axes_thunk = flat_out_axes(f, out_axes)
   flat_fun, out_tree = flatten_fun(f, in_tree)
-  flat_fun = debug_info_final(flat_fun, dbg, res_paths)
 
   is_explicit_global_axis_size = axis_size is not None
   global_axis_size = _get_global_axis_size(local_axis_size, in_devices,
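Note on the `_prepare_pmap` change above: the `TracingDebugInfo` is now built eagerly from the user function, and the result paths, which are unknown until tracing finishes, are attached afterwards as a thunk. A minimal standalone sketch of that pattern follows; it mimics `jax._src.linear_util.TracingDebugInfo` rather than importing it, with field names as in this PR:

```python
# Standalone sketch (not JAX's actual class) of the add_result_paths pattern.
from __future__ import annotations
from typing import Callable, NamedTuple

class TracingDebugInfo(NamedTuple):
  traced_for: str                # e.g. "pmap" or "jit"
  func_src_info: str             # e.g. "my_f at example.py:1"
  arg_names: tuple[str, ...]
  result_paths_thunk: Callable[[], tuple[str, ...]] | None

  def add_result_paths(self, thunk: Callable[[], tuple[str, ...]]
                       ) -> TracingDebugInfo:
    # Result paths may be attached exactly once, after tracing fills them in.
    assert self.result_paths_thunk is None
    return self._replace(result_paths_thunk=thunk)

dbg = TracingDebugInfo("pmap", "my_f at example.py:1", ("x",), None)
dbg = dbg.add_result_paths(lambda: ("result",))
assert dbg.result_paths_thunk() == ("result",)
```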
diff --git a/jax/_src/api_util.py b/jax/_src/api_util.py
index 5d7e4b3d8c2f..5832db24536f 100644
--- a/jax/_src/api_util.py
+++ b/jax/_src/api_util.py
@@ -705,6 +705,7 @@ def result_paths(_fun, _store, *args, **kwargs):
   _store.store([keystr(path) for path, _ in generate_key_paths(ans)])
   return ans
 
+# TODO(necula): simplify this function; all it needs to do is add the trace_debug to the Jaxpr
 def add_jaxpr_debug_info(jaxpr: core.Jaxpr,
                          trace_debug: TracingDebugInfo | None,
                          result_paths: tuple[str, ...] | None = None,
@@ -712,7 +713,8 @@ def add_jaxpr_debug_info(jaxpr: core.Jaxpr,
   """Add debug info to jaxpr, given trace-time debug info and result paths."""
   if trace_debug is None:
     return jaxpr
-  assert (result_paths is not None) ^ (trace_debug.result_paths_thunk is not None)
+
+  # assert (result_paths is not None) ^ (trace_debug.result_paths_thunk is not None)
   if result_paths is None:
     result_paths = trace_debug.result_paths_thunk()  # type: ignore
   debug_info = core.JaxprDebugInfo(
diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py
index 8f3e4c3fbd16..f4fe0edbf274 100644
--- a/jax/_src/checkify.py
+++ b/jax/_src/checkify.py
@@ -27,6 +27,7 @@
 from jax.experimental import shard_map
 
 from jax._src import api
+from jax._src import api_util
 from jax._src import ad_checkpoint
 from jax._src import linear_util as lu
 from jax._src import config
@@ -39,7 +40,6 @@
 from jax._src import traceback_util
 from jax._src import tree_util as jtu
 from jax._src.ad_util import SymbolicZero
-from jax._src.api_util import flatten_fun
 from jax._src.interpreters import ad
 from jax._src.interpreters import batching
 from jax._src.interpreters import mlir
@@ -1202,8 +1202,10 @@ def checked_fun(*args, **kwargs):
     in_tree = jtu.tree_structure(((), {}))
     closed_f = lambda: f(*args, **kwargs)
     # stage:
-    fun_, out_tree = flatten_fun(lu.wrap_init(closed_f), in_tree)
-    debug = pe.tracing_debug_info(closed_f, in_tree, out_tree, False, 'checkify')
+    debug = api_util.tracing_debug_info("checkify", f, args, kwargs)
+    fun_, out_tree = api_util.flatten_fun(lu.wrap_init(closed_f,
+                                                       debug_info=debug),
+                                          in_tree)
     jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(fun_, (), debug)
     jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr_))
     # checkify:
diff --git a/jax/_src/custom_dce.py b/jax/_src/custom_dce.py
index 152fee605a92..be07bc55b3aa 100644
--- a/jax/_src/custom_dce.py
+++ b/jax/_src/custom_dce.py
@@ -128,7 +128,11 @@ def __call__(self, *args, **kwargs):
     )
     rule_name = util.fun_name(self.dce_rule)
     debug = api_util.tracing_debug_info("custom_dce", self.fun,
-                                        args, {}, static_argnums=self.static_argnums)
+                                        args, {},
+                                        static_argnums=self.static_argnums)
+    debug_rule = api_util.tracing_debug_info("custom_dce_rule", self.dce_rule,
+                                             args, {},
+                                             static_argnums=self.static_argnums)
     args = api_util.resolve_kwargs(self.fun, args, kwargs)
     if self.static_argnums:
       static_argnums = set(self.static_argnums)
@@ -171,11 +175,8 @@ def dce_jaxpr_thunk(
           out_avals,
       )
       assert self.dce_rule is not None
-      debug = pe.tracing_debug_info(
-          self.dce_rule, in_tree, rule_out_tree, False, "custom_dce_rule"
-      )
       dce_jaxpr, _, dce_consts, () = pe.trace_to_jaxpr_dynamic(
-          flat_rule, in_avals, debug
+          flat_rule, in_avals, debug_rule
       )
 
       # This second round of DCE is used to work out which inputs are actually
@@ -190,7 +191,7 @@ def dce_jaxpr_thunk(
 
       return core.ClosedJaxpr(dce_jaxpr, dce_consts), used_ins
 
-    jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
+    jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug_rule)
     closed_call = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr))
     out_avals = closed_call.out_avals
     out_flat = custom_dce_p.bind(
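The checkify change above is internal plumbing: debug info is now computed by `api_util.tracing_debug_info` from the original function and its actual arguments, instead of being reconstructed from flattened in/out trees. The public API is unchanged, as this standard usage sketch (nothing PR-specific) illustrates:

```python
import jax.numpy as jnp
from jax.experimental import checkify

def f(x):
  # A user-level check that checkify stages into the jaxpr.
  checkify.check(x > 0, "x must be positive, got {x}", x=x)
  return jnp.log(x)

checked_f = checkify.checkify(f)
err, out = checked_f(jnp.float32(2.0))
err.throw()   # no error: the input is positive
print(out)    # ~0.6931
```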
diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py
index dd81d8d4a552..897faa7a2dc9 100644
--- a/jax/_src/interpreters/partial_eval.py
+++ b/jax/_src/interpreters/partial_eval.py
@@ -1940,7 +1940,8 @@ def process_call(self, call_primitive, f: lu.WrappedFun,
     self.frame.add_eqn(eqn)
     return [t for t, (_, keep) in zip(out_tracers, out_type) if keep]
 
-  def process_map(self, map_primitive, f, tracers, params):
+  def process_map(self, map_primitive, f: lu.WrappedFun,
+                  tracers: Sequence[core.Tracer], params):
     tracers = map(self.to_jaxpr_tracer, tracers)
     in_avals = [t.aval for t in tracers]
     axis_name, axis_size = params['axis_name'], params['axis_size']
@@ -1949,8 +1950,7 @@ def process_map(self, map_primitive, f: lu.WrappedFun,
                        for a, in_axis in zip(in_avals, params['in_axes'])]
     with core.extend_axis_env_nd([(axis_name, params["global_axis_size"])]):
       jaxpr, reduced_out_avals, consts, () = trace_to_jaxpr_dynamic(
-          f, reduced_in_avals,
-          debug_info=tracing_debug_info_final(f, map_primitive.name))
+          f, reduced_in_avals, f.debug_info)
     ordered_effects = effects.ordered_effects.filter_in(jaxpr.effects)
     if ordered_effects:
       raise ValueError("Ordered effects not supported for "
diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index fd9c05d0cb74..2181644c44e8 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -652,6 +652,7 @@ class ParallelCallableInfo:
   in_axes: Iterable[int | None]
   out_axes_thunk: Callable[[], Sequence[int | None]]
   avals: Sequence[core.AbstractValue]
+  debug_info: api_util.TracingDebugInfo
 
   @cached_property
   def local_devices(self):
@@ -722,8 +723,8 @@ def stage_parallel_callable(
       "Finished tracing + transforming {fun_name} for pmap in {elapsed_time} sec",
       fun_name=fun.__name__, event=dispatch.JAXPR_TRACE_EVENT):
     jaxpr, out_sharded_avals, consts, _ = pe.trace_to_jaxpr_dynamic(
-        fun, sharded_avals, pe.tracing_debug_info_final(fun, "pmap"))
-  jaxpr = api_util.add_jaxpr_debug_info(jaxpr, orig_fun.debug_info)
+        fun, sharded_avals, pci.debug_info)
+  jaxpr = api_util.add_jaxpr_debug_info(jaxpr, pci.debug_info)
 
   assert len(out_sharded_avals) == len(pci.out_axes), (
       len(out_sharded_avals), len(pci.out_axes))
@@ -757,7 +758,7 @@ def get_pmap_jaxpr(
 
   pci = ParallelCallableInfo(
       name, backend, axis_name, axis_size, global_axis_size, devices,
-      in_axes, out_axes_thunk, avals)
+      in_axes, out_axes_thunk, avals, fun.debug_info)
   with core.extend_axis_env_nd([(axis_name, axis_size)]):
     jaxpr, consts, replicas, shards = stage_parallel_callable(pci, fun)
   jaxpr = core.remove_named_axis_effects(jaxpr, {axis_name})
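The `partial_eval.py` and `pxla.py` hunks share one theme: stop recomputing debug info at trace time via `tracing_debug_info_final`, and instead carry the `TracingDebugInfo` captured at the API boundary, here as a new field on `ParallelCallableInfo`. A toy illustration of the carry-it-along shape (hypothetical names, not JAX's classes):

```python
# Toy sketch: staging metadata carries debug info captured up front,
# so the tracer consumes it directly instead of rebuilding it.
from dataclasses import dataclass

@dataclass(frozen=True)
class DebugInfo:
  traced_for: str
  arg_names: tuple[str, ...]

@dataclass(frozen=True)
class StageInfo:
  name: str
  debug_info: DebugInfo | None  # captured once at the API boundary

def stage(info: StageInfo) -> str:
  dbg = info.debug_info
  return f"traced_for={dbg.traced_for}, arg_names={dbg.arg_names}"

print(stage(StageInfo("my_f", DebugInfo("pmap", ("my_x",)))))
# traced_for=pmap, arg_names=('my_x',)
```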
diff --git a/jax/_src/linear_util.py b/jax/_src/linear_util.py
index e1f5efd70121..9549fe902b1a 100644
--- a/jax/_src/linear_util.py
+++ b/jax/_src/linear_util.py
@@ -71,7 +71,7 @@ def trans1(static_arg, *dynamic_args, **kwargs):
 from jax._src import config
 from jax._src import core
 from jax._src import traceback_util
-from jax._src.util import curry, cache_clearing_funs
+from jax._src.util import curry, cache_clearing_funs, HashableFunction
 
 traceback_util.register_exclusion(__file__)
@@ -175,11 +175,11 @@ def wrap(self, gen, gen_static_args,
     if out_store is None:
       return WrappedFun(self.f, partial(gen, self.f_transformed, *gen_static_args),
                         ((gen, gen_static_args),) + self.transforms,
-                        (out_store,) + self.stores, self.params, None, None)
+                        (out_store,) + self.stores, self.params, None, self.debug_info)
     else:
       return WrappedFun(self.f,
                         partial(gen, self.f_transformed, out_store, *gen_static_args),
                         ((gen, gen_static_args),) + self.transforms,
-                        (out_store,) + self.stores, self.params, None, None)
+                        (out_store,) + self.stores, self.params, None, self.debug_info)
 
   def populate_stores(self, stores):
     """Copy the values from the `stores` into `self.stores`."""
@@ -282,6 +282,12 @@ def from_jaxpr(cls, jaxpr: core.ClosedJaxpr) -> TracingDebugInfo | None:
                             jaxpr_dbg.arg_names,
                             lambda: jaxpr_dbg.result_paths)
 
+  def add_result_paths(self, result_paths_thunk: Callable[[], tuple[str, ...]]
+                       ) -> TracingDebugInfo:
+    assert self.result_paths_thunk is None
+    return self._replace(result_paths_thunk=HashableFunction(result_paths_thunk,
+                                                             closure=()))
+
 def wrap_init(f: Callable, params=None, *,
               debug_info: TracingDebugInfo | None = None) -> WrappedFun:
   """Wraps function `f` as a `WrappedFun`, suitable for transformation."""
diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py
index 09f3fb892523..37a8e39b4001 100644
--- a/jax/_src/pjit.py
+++ b/jax/_src/pjit.py
@@ -569,6 +569,7 @@ def _infer_params_impl(
   f = lu.wrap_init(fun)
   f, res_paths = result_paths(f)
+  dbg = dbg and dbg.add_result_paths(result_paths_thunk=res_paths)
   f, dyn_args = argnums_partial_except(f, ji.static_argnums, args,
                                        allow_invalid=True)
   del args
 
@@ -1293,7 +1294,7 @@ def _create_pjit_jaxpr(
     in_type: core.InputType | Sequence[core.AbstractValue],
     attr_data: int,
     debug_info: lu.TracingDebugInfo,
-    out_paths: Callable,
+    result_paths: Callable,
     ignored_inline: IgnoreKey
 ) -> tuple[core.ClosedJaxpr, list[Any], list[core.AbstractValue],
            list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]:
@@ -1305,19 +1306,18 @@ def _create_pjit_jaxpr(
   with dispatch.log_elapsed_time(
       "Finished tracing + transforming {fun_name} for pjit in {elapsed_time:.9f} sec",
       fun_name=fun.__name__, event=dispatch.JAXPR_TRACE_EVENT):
-    pe_debug = debug_info and pe.tracing_debug_info_final(fun, debug_info.traced_for)
     if config.dynamic_shapes.value:
       jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic2(
-          lu.annotate(fun, cast(core.InputType, in_type)), debug_info=pe_debug)
+          lu.annotate(fun, cast(core.InputType, in_type)), debug_info=debug_info)
       attrs_tracked = []
     else:
       jaxpr, global_out_avals, consts, attrs_tracked = pe.trace_to_jaxpr_dynamic(
-          fun, in_type, debug_info=pe_debug)
+          fun, in_type, debug_info=debug_info)
   # assert attr_data is sentinel or attr_data matches attrs_tracked
 
   # TODO(dougalm,mattjj): enable debug info with attrs_tracked
   if not config.dynamic_shapes.value and not attrs_tracked:
-    jaxpr = add_jaxpr_debug_info(jaxpr, debug_info, out_paths())
+    jaxpr = add_jaxpr_debug_info(jaxpr, debug_info, result_paths())
 
   if config.debug_key_reuse.value:
     # Import here to avoid circular imports
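In `_create_pjit_jaxpr`, `result_paths` stays a thunk because the output structure only exists after the traced function has run. A simplified sketch of that idea follows, using a hypothetical `with_result_paths` helper rather than JAX's `api_util.result_paths` (which is an `lu.transformation` backed by a store):

```python
from typing import Any, Callable
import jax.tree_util as jtu

def with_result_paths(fun: Callable[..., Any]
                      ) -> tuple[Callable[..., Any], Callable[[], tuple[str, ...]]]:
  cell: list[tuple[str, ...]] = []
  def wrapped(*args, **kwargs):
    out = fun(*args, **kwargs)
    # Record flat output paths once the output pytree exists.
    leaves_with_paths, _ = jtu.tree_flatten_with_path(out)
    cell.append(tuple(jtu.keystr(path) for path, _ in leaves_with_paths))
    return out
  # The thunk is valid only after `wrapped` has run, mirroring how pjit
  # consults result_paths() only after tracing finishes.
  return wrapped, lambda: cell[0]

f, res_paths = with_result_paths(lambda x: {"y": x, "z": (x, x)})
f(1.0)
print(res_paths())  # ("['y']", "['z'][0]", "['z'][1]")
```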
diff --git a/tests/debug_info_test.py b/tests/debug_info_test.py
index ab95b03e8758..0adbb3ca8cd9 100644
--- a/tests/debug_info_test.py
+++ b/tests/debug_info_test.py
@@ -630,8 +630,7 @@ def my_g(a, d=1):
         ],
         expected_tracer_debug_infos=[
             "traced_for=jit, fun=my_g, arg_names=('a',)",
-            # TODO(necula): bad arg name
-            "traced_for=jit, fun=my_f, arg_names=('args[0]',)"
+            "traced_for=jit, fun=my_f, arg_names=('a',)"
         ])
 
@@ -1044,7 +1043,7 @@ def my_f(my_x):
             "None",  # TODO(necula): missing tracer debug info
         ],
         expected_tracer_debug_infos=[
-            "traced_for=xla_pmap, fun=my_f, arg_names=('my_x',)",
+            "traced_for=pmap, fun=my_f, arg_names=('my_x',)",
         ],
         check_lowering=False,  # TODO(necula): warning during lowering
     )
@@ -1076,7 +1075,8 @@ def my_f(x):
             # TODO(necula): some Jaxprs without debug info
             'None'],
         expected_tracer_debug_infos=[
-            "traced_for=custom_dce, fun=my_g, arg_names=('x',)"
+            # TODO(necula): bad arg_names
+            "traced_for=custom_dce_rule, fun=my_g_dce, arg_names=('args[1]',)",
         ])
 
   def test_custom_dce_consts(self):
@@ -1087,7 +1087,7 @@ def my_f(x):
       return np.eye(1) * jnp.sin(x), jnp.cos(x)
 
     @my_f.def_dce
-    def rule(used_outs, x):
+    def my_rule(used_outs, x):
       leaked_tracers.append(x)
       return (
           np.full((1, 1), 2.0) * jnp.exp(x) if used_outs[0] else None,
@@ -1103,7 +1103,8 @@ def rule(used_outs, x):
             # TODO(necula): some Jaxprs without debug info
             'None'],
         expected_tracer_debug_infos=[
-            "traced_for=custom_dce, fun=my_f, arg_names=('x',)"
+            # TODO(necula): bad arg_names
+            "traced_for=custom_dce_rule, fun=my_rule, arg_names=('args[0]',)",
         ])
 
   def test_custom_linear_solve_complex(self):
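For reference, the kind of strings these tests assert against can be inspected on a traced jaxpr. A hedged illustration using the public `jax.make_jaxpr` (the `debug_info` attribute is internal and may be absent or `None` on some JAX versions):

```python
import jax
import jax.numpy as jnp

def my_f(my_x):
  return jnp.sin(my_x)

jaxpr = jax.make_jaxpr(my_f)(1.0)
dbg = jaxpr.jaxpr.debug_info  # internal attribute; may be None
if dbg is not None:
  print(dbg.arg_names)        # e.g. ('my_x',)
```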