diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index 5930d214904f..b5bf9aff467b 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -427,16 +427,13 @@ def to_block_mapping( f"Block spec for {origin} has block_shape: {block_aval.shape}" ) + fake_index_map_args, fake_index_map_kwargs = \ + index_map_tree.unflatten([False] * index_map_tree.num_leaves) + debug = api_util.tracing_debug_info("pallas_call index_map", + index_map_func, fake_index_map_args, + fake_index_map_kwargs) flat_index_map_fun, index_map_out_tree_thunk = api_util.flatten_fun( - lu.wrap_init(index_map_func), index_map_tree - ) - debug = pe.tracing_debug_info( - index_map_func, - index_map_tree, - index_map_out_tree_thunk, - False, - "pallas_call index_map", - ) + lu.wrap_init(index_map_func, debug_info=debug), index_map_tree) index_map_src_info = NameAndSrcInfo.from_pallas_call( None, debug and debug.func_src_info # type: ignore ) diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index fe1b25fe71b7..097654fba337 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -1345,8 +1345,8 @@ def _ensure_2d_error_shape(arg): jaxpr_flat_avals, jaxpr_in_tree = tree_util.tree_flatten(retrace_in_avals) wrapped_kernel_with_err, out_tree_thunk = api_util.flatten_fun_nokwargs( lu.wrap_init(checked_kernel_fn), jaxpr_in_tree) - debug = pe.tracing_debug_info( - checked_kernel_fn, jaxpr_in_tree, out_tree_thunk, False, "checkify_pallas") + debug = api_util.tracing_debug_info("checkify_pallas", checked_kernel_fn, + retrace_in_avals, {}) with pallas_core.tracing_grid_env(grid_mapping.grid, ()): final_jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic( wrapped_kernel_with_err, jaxpr_flat_avals, debug) @@ -1415,7 +1415,8 @@ def _trace_kernel_to_jaxpr( wrapped_kernel_fun = primitives.wrap_with_transforms( wrapped_kernel_fun, kernel_in_transforms ) - debug = pe.tracing_debug_info(fun, kernel_in_tree, out_tree_thunk, False, "pallas_call") + fake_kernel_args = kernel_in_tree.unflatten(kernel_avals) + debug = api_util.tracing_debug_info("pallas_call", fun, fake_kernel_args, {}) with grid_mapping.trace_env(): jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_kernel_fun, kernel_avals, debug) diff --git a/tests/debug_info_test.py b/tests/debug_info_test.py index 39baa3f4e18f..ab95b03e8758 100644 --- a/tests/debug_info_test.py +++ b/tests/debug_info_test.py @@ -1191,11 +1191,7 @@ def my_index_map(i, j): # TODO(necula): missing Jaxpr debug info "None"], expected_tracer_debug_infos=[ - # TODO(necula): arg_names seem to be wrong - # One tracer from every index map - "traced_for=pallas_call index_map, fun=my_index_map, arg_names=('i[0]', 'i[1]')", - "traced_for=pallas_call index_map, fun=my_index_map, arg_names=('i[0]', 'i[1]')", - "traced_for=pallas_call index_map, fun=my_index_map, arg_names=('i[0]', 'i[1]')", + "traced_for=pallas_call index_map, fun=my_index_map, arg_names=('i', 'j')", "traced_for=pallas_call, fun=my_kernel, arg_names=('x_ref', 'y_ref', 'o_ref')", ], check_lowering=False, # We need interpret mode on CPU. TODO(necula)