diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py
index 5930d214904f..b5bf9aff467b 100644
--- a/jax/_src/pallas/core.py
+++ b/jax/_src/pallas/core.py
@@ -427,16 +427,13 @@ def to_block_mapping(
           f"Block spec for {origin} has block_shape: {block_aval.shape}"
       )
 
+    fake_index_map_args, fake_index_map_kwargs = \
+        index_map_tree.unflatten([False] * index_map_tree.num_leaves)
+    debug = api_util.tracing_debug_info("pallas_call index_map",
+                                        index_map_func, fake_index_map_args,
+                                        fake_index_map_kwargs)
     flat_index_map_fun, index_map_out_tree_thunk = api_util.flatten_fun(
-        lu.wrap_init(index_map_func), index_map_tree
-    )
-    debug = pe.tracing_debug_info(
-        index_map_func,
-        index_map_tree,
-        index_map_out_tree_thunk,
-        False,
-        "pallas_call index_map",
-    )
+      lu.wrap_init(index_map_func, debug_info=debug), index_map_tree)
     index_map_src_info = NameAndSrcInfo.from_pallas_call(
         None, debug and debug.func_src_info  # type: ignore
     )
diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py
index fe1b25fe71b7..097654fba337 100644
--- a/jax/_src/pallas/pallas_call.py
+++ b/jax/_src/pallas/pallas_call.py
@@ -1345,8 +1345,8 @@ def _ensure_2d_error_shape(arg):
   jaxpr_flat_avals, jaxpr_in_tree = tree_util.tree_flatten(retrace_in_avals)
   wrapped_kernel_with_err, out_tree_thunk = api_util.flatten_fun_nokwargs(
       lu.wrap_init(checked_kernel_fn), jaxpr_in_tree)
-  debug = pe.tracing_debug_info(
-    checked_kernel_fn, jaxpr_in_tree, out_tree_thunk, False, "checkify_pallas")
+  debug = api_util.tracing_debug_info("checkify_pallas", checked_kernel_fn,
+                                      retrace_in_avals, {})
   with pallas_core.tracing_grid_env(grid_mapping.grid, ()):
     final_jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
         wrapped_kernel_with_err, jaxpr_flat_avals, debug)
@@ -1415,7 +1415,8 @@ def _trace_kernel_to_jaxpr(
   wrapped_kernel_fun = primitives.wrap_with_transforms(
       wrapped_kernel_fun, kernel_in_transforms
   )
-  debug = pe.tracing_debug_info(fun, kernel_in_tree, out_tree_thunk, False, "pallas_call")
+  fake_kernel_args = kernel_in_tree.unflatten(kernel_avals)
+  debug = api_util.tracing_debug_info("pallas_call", fun, fake_kernel_args, {})
   with grid_mapping.trace_env():
     jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_kernel_fun,
                                                      kernel_avals, debug)
diff --git a/tests/debug_info_test.py b/tests/debug_info_test.py
index 39baa3f4e18f..ab95b03e8758 100644
--- a/tests/debug_info_test.py
+++ b/tests/debug_info_test.py
@@ -1191,11 +1191,7 @@ def my_index_map(i, j):
             # TODO(necula): missing Jaxpr debug info
             "None"],
         expected_tracer_debug_infos=[
-            # TODO(necula): arg_names seem to be wrong
-            # One tracer from every index map
-            "traced_for=pallas_call index_map, fun=my_index_map, arg_names=('i[0]', 'i[1]')",
-            "traced_for=pallas_call index_map, fun=my_index_map, arg_names=('i[0]', 'i[1]')",
-            "traced_for=pallas_call index_map, fun=my_index_map, arg_names=('i[0]', 'i[1]')",
+            "traced_for=pallas_call index_map, fun=my_index_map, arg_names=('i', 'j')",
             "traced_for=pallas_call, fun=my_kernel, arg_names=('x_ref', 'y_ref', 'o_ref')",
         ],
         check_lowering=False,  # We need interpret mode on CPU. TODO(necula)