diff --git a/jax/_src/api.py b/jax/_src/api.py
index e4ead66236b9..cb621a7c1cd5 100644
--- a/jax/_src/api.py
+++ b/jax/_src/api.py
@@ -61,7 +61,7 @@
     flatten_axes, donation_vector,
     rebase_donate_argnums, _ensure_index, _ensure_index_tuple,
     apply_flat_fun_nokwargs, check_callable, tracing_debug_info,
-    result_paths, flat_out_axes, debug_info_final)
+    result_paths, flat_out_axes)
 from jax._src.lax import lax as lax_internal
 from jax._src.lib import jax_jit
 from jax._src.lib import xla_client as xc
@@ -1392,15 +1392,15 @@ def _get_global_axis_size(local_axis_size: int, in_devices, backend_name: str,
   return global_axis_size
 
 
-def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
+def _prepare_pmap(fun: Callable, in_axes, out_axes, static_broadcasted_tuple,
                   donate_tuple, in_devices, backend_name,
                   axis_size, args, kwargs):
   if in_devices is not None and len(in_devices) == 0:
     raise ValueError("'devices' argument to pmap must be non-empty, or None.")
 
   dbg = tracing_debug_info(
-      'pmap', fun, args, kwargs,
-       static_argnums=static_broadcasted_tuple)
+      "pmap", fun, args, kwargs,
+      static_argnums=static_broadcasted_tuple)
 
   f = lu.wrap_init(fun)
   if static_broadcasted_tuple:
@@ -1450,9 +1450,10 @@ def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
   local_axis_size = _mapped_axis_size(fun, in_tree, args, in_axes_flat, "pmap")
 
   f, res_paths = result_paths(f)
+  dbg = dbg.add_result_paths(res_paths)
+  f = lu.add_debug_info(f, dbg)
   f, out_axes_thunk = flat_out_axes(f, out_axes)
   flat_fun, out_tree = flatten_fun(f, in_tree)
-  flat_fun = debug_info_final(flat_fun, dbg, res_paths)
 
   is_explicit_global_axis_size = axis_size is not None
   global_axis_size = _get_global_axis_size(local_axis_size, in_devices,
diff --git a/jax/_src/api_util.py b/jax/_src/api_util.py
index 5d7e4b3d8c2f..5832db24536f 100644
--- a/jax/_src/api_util.py
+++ b/jax/_src/api_util.py
@@ -705,6 +705,7 @@ def result_paths(_fun, _store, *args, **kwargs):
   _store.store([keystr(path) for path, _ in generate_key_paths(ans)])
   return ans
 
+# TODO(necula): simplify this function, all it needs is to add the trace_debug to the Jaxpr
 def add_jaxpr_debug_info(jaxpr: core.Jaxpr,
                          trace_debug: TracingDebugInfo | None,
                          result_paths: tuple[str, ...] | None = None,
@@ -712,7 +713,8 @@ def add_jaxpr_debug_info(jaxpr: core.Jaxpr,
   """Add debug info to jaxpr, given trace-time debug info and result paths."""
   if trace_debug is None:
     return jaxpr
-  assert (result_paths is not None) ^ (trace_debug.result_paths_thunk is not None)
+  # NOTE(review): check disabled, presumably because both may now be set; confirm.
+  # assert (result_paths is not None) ^ (trace_debug.result_paths_thunk is not None)
   if result_paths is None:
     result_paths = trace_debug.result_paths_thunk()  # type: ignore
   debug_info = core.JaxprDebugInfo(
diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py
index 8f3e4c3fbd16..f4fe0edbf274 100644
--- a/jax/_src/checkify.py
+++ b/jax/_src/checkify.py
@@ -27,6 +27,7 @@
 
 from jax.experimental import shard_map
 from jax._src import api
+from jax._src import api_util
 from jax._src import ad_checkpoint
 from jax._src import linear_util as lu
 from jax._src import config
@@ -39,7 +40,6 @@
 from jax._src import traceback_util
 from jax._src import tree_util as jtu
 from jax._src.ad_util import SymbolicZero
-from jax._src.api_util import flatten_fun
 from jax._src.interpreters import ad
 from jax._src.interpreters import batching
 from jax._src.interpreters import mlir
@@ -1202,8 +1202,10 @@ def checked_fun(*args, **kwargs):
     in_tree = jtu.tree_structure(((), {}))
     closed_f = lambda: f(*args, **kwargs)
     # stage:
-    fun_, out_tree = flatten_fun(lu.wrap_init(closed_f), in_tree)
-    debug = pe.tracing_debug_info(closed_f, in_tree, out_tree, False, 'checkify')
+    debug = api_util.tracing_debug_info("checkify", f, args, kwargs)
+    fun_, out_tree = api_util.flatten_fun(lu.wrap_init(closed_f,
+                                                       debug_info=debug),
+                                          in_tree)
     jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(fun_, (), debug)
     jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr_))
     # checkify:
diff --git a/jax/_src/custom_dce.py b/jax/_src/custom_dce.py
index 152fee605a92..be07bc55b3aa 100644
--- a/jax/_src/custom_dce.py
+++ b/jax/_src/custom_dce.py
@@ -128,7 +128,11 @@ def __call__(self, *args, **kwargs):
       )
     rule_name = util.fun_name(self.dce_rule)
     debug = api_util.tracing_debug_info("custom_dce", self.fun,
-                                        args, {}, static_argnums=self.static_argnums)
+                                        args, {},
+                                        static_argnums=self.static_argnums)
+    debug_rule = api_util.tracing_debug_info("custom_dce_rule", self.dce_rule,
+                                             args, {},
+                                             static_argnums=self.static_argnums)
     args = api_util.resolve_kwargs(self.fun, args, kwargs)
     if self.static_argnums:
       static_argnums = set(self.static_argnums)
@@ -171,11 +175,8 @@ def dce_jaxpr_thunk(
           out_avals,
       )
       assert self.dce_rule is not None
-      debug = pe.tracing_debug_info(
-          self.dce_rule, in_tree, rule_out_tree, False, "custom_dce_rule"
-      )
       dce_jaxpr, _, dce_consts, () = pe.trace_to_jaxpr_dynamic(
-          flat_rule, in_avals, debug
+          flat_rule, in_avals, debug_rule
       )
 
       # This second round of DCE is used to work out which inputs are actually
@@ -190,7 +191,7 @@ def dce_jaxpr_thunk(
 
       return core.ClosedJaxpr(dce_jaxpr, dce_consts), used_ins
 
-    jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
+    jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)  # `debug` was built from self.fun
     closed_call = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr))
     out_avals = closed_call.out_avals
     out_flat = custom_dce_p.bind(
diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py
index dd81d8d4a552..897faa7a2dc9 100644
--- a/jax/_src/interpreters/partial_eval.py
+++ b/jax/_src/interpreters/partial_eval.py
@@ -1940,7 +1940,8 @@ def process_call(self, call_primitive, f: lu.WrappedFun,
     self.frame.add_eqn(eqn)
     return [t for t, (_, keep) in zip(out_tracers, out_type) if keep]
 
-  def process_map(self, map_primitive, f, tracers, params):
+  def process_map(self, map_primitive, f: lu.WrappedFun,
+                  tracers: Sequence[core.Tracer], params):
     tracers = map(self.to_jaxpr_tracer, tracers)
     in_avals = [t.aval for t in tracers]
     axis_name, axis_size = params['axis_name'], params['axis_size']
@@ -1949,8 +1950,7 @@ def process_map(self, map_primitive, f, tracers, params):
                         for a, in_axis in zip(in_avals, params['in_axes'])]
     with core.extend_axis_env_nd([(axis_name, params["global_axis_size"])]):
       jaxpr, reduced_out_avals, consts, () = trace_to_jaxpr_dynamic(
-          f, reduced_in_avals,
-          debug_info=tracing_debug_info_final(f, map_primitive.name))
+          f, reduced_in_avals, f.debug_info)
       ordered_effects = effects.ordered_effects.filter_in(jaxpr.effects)
       if ordered_effects:
         raise ValueError("Ordered effects not supported for "
diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index fd9c05d0cb74..2181644c44e8 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -652,6 +652,7 @@ class ParallelCallableInfo:
   in_axes: Iterable[int | None]
   out_axes_thunk: Callable[[], Sequence[int | None]]
   avals: Sequence[core.AbstractValue]
+  debug_info: api_util.TracingDebugInfo
 
   @cached_property
   def local_devices(self):
@@ -722,8 +723,8 @@ def stage_parallel_callable(
         "Finished tracing + transforming {fun_name} for pmap in {elapsed_time} sec",
         fun_name=fun.__name__, event=dispatch.JAXPR_TRACE_EVENT):
       jaxpr, out_sharded_avals, consts, _ = pe.trace_to_jaxpr_dynamic(
-          fun, sharded_avals, pe.tracing_debug_info_final(fun, "pmap"))
-  jaxpr = api_util.add_jaxpr_debug_info(jaxpr, orig_fun.debug_info)
+          fun, sharded_avals, pci.debug_info)
+  jaxpr = api_util.add_jaxpr_debug_info(jaxpr, pci.debug_info)
 
   assert len(out_sharded_avals) == len(pci.out_axes), (
       len(out_sharded_avals), len(pci.out_axes))
@@ -757,7 +758,7 @@ def get_pmap_jaxpr(
 
   pci = ParallelCallableInfo(
       name, backend, axis_name, axis_size, global_axis_size, devices,
-      in_axes, out_axes_thunk, avals)
+      in_axes, out_axes_thunk, avals, fun.debug_info)
   with core.extend_axis_env_nd([(axis_name, axis_size)]):
     jaxpr, consts, replicas, shards = stage_parallel_callable(pci, fun)
   jaxpr = core.remove_named_axis_effects(jaxpr, {axis_name})
diff --git a/jax/_src/linear_util.py b/jax/_src/linear_util.py
index e1f5efd70121..9549fe902b1a 100644
--- a/jax/_src/linear_util.py
+++ b/jax/_src/linear_util.py
@@ -71,7 +71,7 @@ def trans1(static_arg, *dynamic_args, **kwargs):
 from jax._src import config
 from jax._src import core
 from jax._src import traceback_util
-from jax._src.util import curry, cache_clearing_funs
+from jax._src.util import curry, cache_clearing_funs, HashableFunction
 
 
 traceback_util.register_exclusion(__file__)
@@ -175,11 +175,11 @@ def wrap(self, gen, gen_static_args,
     if out_store is None:
       return WrappedFun(self.f, partial(gen, self.f_transformed, *gen_static_args),
                         ((gen, gen_static_args),) + self.transforms,
-                        (out_store,) + self.stores, self.params, None, None)
+                        (out_store,) + self.stores, self.params, None, self.debug_info)
     else:
       return WrappedFun(self.f, partial(gen, self.f_transformed, out_store, *gen_static_args),
                         ((gen, gen_static_args),) + self.transforms,
-                        (out_store,) + self.stores, self.params, None, None)
+                        (out_store,) + self.stores, self.params, None, self.debug_info)
 
   def populate_stores(self, stores):
     """Copy the values from the `stores` into `self.stores`."""
@@ -282,6 +282,12 @@ def from_jaxpr(cls, jaxpr: core.ClosedJaxpr) -> TracingDebugInfo | None:
                             jaxpr_dbg.arg_names,
                             lambda: jaxpr_dbg.result_paths)
 
+  def add_result_paths(self, result_paths_thunk: Callable[[], tuple[str, ...]]
+                       ) -> TracingDebugInfo:  # result of `_replace` with the thunk set
+    assert self.result_paths_thunk is None  # expects no thunk to be set yet
+    return self._replace(result_paths_thunk=HashableFunction(result_paths_thunk,
+                                                             closure=()))  # NOTE(review): closure=() — confirm intended equality/hash semantics
+
 def wrap_init(f: Callable, params=None, *,
               debug_info: TracingDebugInfo | None = None) -> WrappedFun:
   """Wraps function `f` as a `WrappedFun`, suitable for transformation."""
diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py
index 09f3fb892523..37a8e39b4001 100644
--- a/jax/_src/pjit.py
+++ b/jax/_src/pjit.py
@@ -569,6 +569,7 @@ def _infer_params_impl(
 
   f = lu.wrap_init(fun)
   f, res_paths = result_paths(f)
+  dbg = dbg and dbg.add_result_paths(result_paths_thunk=res_paths)
   f, dyn_args = argnums_partial_except(f, ji.static_argnums, args, allow_invalid=True)
   del args
 
@@ -1293,7 +1294,7 @@ def _create_pjit_jaxpr(
     in_type: core.InputType | Sequence[core.AbstractValue],
     attr_data: int,
     debug_info: lu.TracingDebugInfo,
-    out_paths: Callable,
+    result_paths: Callable,
     ignored_inline: IgnoreKey
 ) -> tuple[core.ClosedJaxpr, list[Any], list[core.AbstractValue],
            list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]:
@@ -1305,19 +1306,18 @@ def _create_pjit_jaxpr(
   with dispatch.log_elapsed_time(
       "Finished tracing + transforming {fun_name} for pjit in {elapsed_time:.9f} sec",
       fun_name=fun.__name__, event=dispatch.JAXPR_TRACE_EVENT):
-    pe_debug = debug_info and pe.tracing_debug_info_final(fun, debug_info.traced_for)
     if config.dynamic_shapes.value:
       jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic2(
-          lu.annotate(fun, cast(core.InputType, in_type)), debug_info=pe_debug)
+          lu.annotate(fun, cast(core.InputType, in_type)), debug_info=debug_info)
       attrs_tracked = []
     else:
       jaxpr, global_out_avals, consts, attrs_tracked = pe.trace_to_jaxpr_dynamic(
-          fun, in_type, debug_info=pe_debug)
+          fun, in_type, debug_info=debug_info)
       # assert attr_data is sentinel or attr_data matches attrs_tracked
 
   # TODO(dougalm,mattjj): enable debug info with attrs_tracked
   if not config.dynamic_shapes.value and not attrs_tracked:
-    jaxpr = add_jaxpr_debug_info(jaxpr, debug_info, out_paths())
+    jaxpr = add_jaxpr_debug_info(jaxpr, debug_info, result_paths())
 
   if config.debug_key_reuse.value:
     # Import here to avoid circular imports
diff --git a/tests/debug_info_test.py b/tests/debug_info_test.py
index ab95b03e8758..0adbb3ca8cd9 100644
--- a/tests/debug_info_test.py
+++ b/tests/debug_info_test.py
@@ -630,8 +630,7 @@ def my_g(a, d=1):
         ],
         expected_tracer_debug_infos=[
             "traced_for=jit, fun=my_g, arg_names=('a',)",
-            # TODO(necula): bad arg name
-            "traced_for=jit, fun=my_f, arg_names=('args[0]',)"
+            "traced_for=jit, fun=my_f, arg_names=('a',)"
         ])
 
 
@@ -1044,7 +1043,7 @@ def my_f(my_x):
             "None",  # TODO(necula): missing tracer debug info
         ],
         expected_tracer_debug_infos=[
-            "traced_for=xla_pmap, fun=my_f, arg_names=('my_x',)",
+            "traced_for=pmap, fun=my_f, arg_names=('my_x',)",
         ],
         check_lowering=False,  # TODO(necula): warning during lowering
     )
@@ -1076,7 +1075,8 @@ def my_f(x):
             # TODO(necula): some Jaxprs without debug info
             'None'],
         expected_tracer_debug_infos=[
-            "traced_for=custom_dce, fun=my_g, arg_names=('x',)"
+            # TODO(necula): bad arg_names
+            "traced_for=custom_dce_rule, fun=my_g_dce, arg_names=('args[1]',)",
         ])
 
   def test_custom_dce_consts(self):
@@ -1087,7 +1087,7 @@ def my_f(x):
       return np.eye(1) * jnp.sin(x), jnp.cos(x)
 
     @my_f.def_dce
-    def rule(used_outs, x):
+    def my_rule(used_outs, x):
       leaked_tracers.append(x)
       return (
           np.full((1, 1), 2.0) * jnp.exp(x) if used_outs[0] else None,
@@ -1103,7 +1103,8 @@ def rule(used_outs, x):
             # TODO(necula): some Jaxprs without debug info
             'None'],
         expected_tracer_debug_infos=[
-            "traced_for=custom_dce, fun=my_f, arg_names=('x',)"
+            # TODO(necula): bad arg_names
+            "traced_for=custom_dce_rule, fun=my_rule, arg_names=('args[0]',)",
         ])
 
   def test_custom_linear_solve_complex(self):