Skip to content

Commit

Permalink
[better_errors] Refactor more uses of pe.tracing_debug_info (part 2)
Browse files Browse the repository at this point in the history
We replace uses of `pe.tracing_debug_info` with with `api_util.tracing_debug_info`,
which uses the actual args and kwargs, instead of `in_tree` to manufacture fake
args and kwargs. This ends up being more accurate, especially for `arg_names`;
see changes in debug_info_tests.py.
This means that we have to construct the debug info further upstream, before
flattening args. This will later help populate debug info in `WrappedFun` and
`Jaxpr`.

This is part 2 of a series (following #26097) for Pallas.
  • Loading branch information
gnecula committed Jan 26, 2025
1 parent b004848 commit cedb516
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 17 deletions.
15 changes: 6 additions & 9 deletions jax/_src/pallas/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,16 +427,13 @@ def to_block_mapping(
f"Block spec for {origin} has block_shape: {block_aval.shape}"
)

fake_index_map_args, fake_index_map_kwargs = \
index_map_tree.unflatten([False] * index_map_tree.num_leaves)
debug = api_util.tracing_debug_info("pallas_call index_map",
index_map_func, fake_index_map_args,
fake_index_map_kwargs)
flat_index_map_fun, index_map_out_tree_thunk = api_util.flatten_fun(
lu.wrap_init(index_map_func), index_map_tree
)
debug = pe.tracing_debug_info(
index_map_func,
index_map_tree,
index_map_out_tree_thunk,
False,
"pallas_call index_map",
)
lu.wrap_init(index_map_func, debug_info=debug), index_map_tree)
index_map_src_info = NameAndSrcInfo.from_pallas_call(
None, debug and debug.func_src_info # type: ignore
)
Expand Down
7 changes: 4 additions & 3 deletions jax/_src/pallas/pallas_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -1345,8 +1345,8 @@ def _ensure_2d_error_shape(arg):
jaxpr_flat_avals, jaxpr_in_tree = tree_util.tree_flatten(retrace_in_avals)
wrapped_kernel_with_err, out_tree_thunk = api_util.flatten_fun_nokwargs(
lu.wrap_init(checked_kernel_fn), jaxpr_in_tree)
debug = pe.tracing_debug_info(
checked_kernel_fn, jaxpr_in_tree, out_tree_thunk, False, "checkify_pallas")
debug = api_util.tracing_debug_info("checkify_pallas", checked_kernel_fn,
retrace_in_avals, {})
with pallas_core.tracing_grid_env(grid_mapping.grid, ()):
final_jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
wrapped_kernel_with_err, jaxpr_flat_avals, debug)
Expand Down Expand Up @@ -1415,7 +1415,8 @@ def _trace_kernel_to_jaxpr(
wrapped_kernel_fun = primitives.wrap_with_transforms(
wrapped_kernel_fun, kernel_in_transforms
)
debug = pe.tracing_debug_info(fun, kernel_in_tree, out_tree_thunk, False, "pallas_call")
fake_kernel_args = kernel_in_tree.unflatten(kernel_avals)
debug = api_util.tracing_debug_info("pallas_call", fun, fake_kernel_args, {})
with grid_mapping.trace_env():
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_kernel_fun,
kernel_avals, debug)
Expand Down
6 changes: 1 addition & 5 deletions tests/debug_info_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1191,11 +1191,7 @@ def my_index_map(i, j):
# TODO(necula): missing Jaxpr debug info
"None"],
expected_tracer_debug_infos=[
# TODO(necula): arg_names seem to be wrong
# One tracer from every index map
"traced_for=pallas_call index_map, fun=my_index_map, arg_names=('i[0]', 'i[1]')",
"traced_for=pallas_call index_map, fun=my_index_map, arg_names=('i[0]', 'i[1]')",
"traced_for=pallas_call index_map, fun=my_index_map, arg_names=('i[0]', 'i[1]')",
"traced_for=pallas_call index_map, fun=my_index_map, arg_names=('i', 'j')",
"traced_for=pallas_call, fun=my_kernel, arg_names=('x_ref', 'y_ref', 'o_ref')",
],
check_lowering=False, # We need interpret mode on CPU. TODO(necula)
Expand Down

0 comments on commit cedb516

Please sign in to comment.