Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[better_errors] Refactor more uses of pe.tracing_debug_info (part 2) #26099

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

gnecula
Copy link
Collaborator

@gnecula gnecula commented Jan 25, 2025

We replace uses of pe.tracing_debug_info with with api_util.tracing_debug_info,
which uses the actual args and kwargs, instead of in_tree to manufacture fake
args and kwargs. This ends up being more accurate, especially for arg_names;
see changes in debug_info_tests.py.
This means that we have to construct the debug info further upstream, before
flattening args. This will later help populate debug info in WrappedFun and
Jaxpr.

This is part 2 of a series (following #26097) for Pallas.

@gnecula gnecula self-assigned this Jan 25, 2025
@gnecula gnecula force-pushed the debug_info_no_pe_debug_info_2 branch from ef783cd to 1203839 Compare January 26, 2025 06:14
@gnecula gnecula changed the title [better_errors] Refactor more uses of partial_eval.tracing_debug_info (part 2) [better_errors] Refactor more uses of pe.tracing_debug_info (part 2) Jan 26, 2025
@gnecula gnecula force-pushed the debug_info_no_pe_debug_info_2 branch from 1203839 to cedb516 Compare January 26, 2025 08:03
gnecula added a commit to gnecula/jax that referenced this pull request Jan 26, 2025
We replace uses of `pe.tracing_debug_info` with with `api_util.tracing_debug_info`,
which uses the actual args and kwargs, instead of `in_tree` to manufacture fake
args and kwargs. This ends up being more accurate, especially for `arg_names`;
see changes in debug_info_tests.py.
This means that we have to construct the debug info further upstream, before
flattening args. This will later help populate debug info in `WrappedFun` and
`Jaxpr`.

This is part 3 of a series (following jax-ml#26097, jax-ml#26099) for jit, pmap, checkify, custom_dce.
gnecula added a commit to gnecula/jax that referenced this pull request Jan 26, 2025
We replace uses of `pe.tracing_debug_info` with with `api_util.tracing_debug_info`,
which uses the actual args and kwargs, instead of `in_tree` to manufacture fake
args and kwargs. This ends up being more accurate, especially for `arg_names`;
see changes in debug_info_tests.py.
This means that we have to construct the debug info further upstream, before
flattening args. This will later help populate debug info in `WrappedFun` and
`Jaxpr`.

This is part 3 of a series (following jax-ml#26097, jax-ml#26099) for jit, pmap, checkify, custom_dce.
@gnecula gnecula added the pull ready Ready for copybara import and testing label Jan 26, 2025
… (part 1)

We replace those uses with api_util.tracing_debug_info, which means we
have to move the call further upstream. But this is better because we
have the actual args and kwargs, and we can do a better job, especially
for `arg_names`.

This is part 1 of a series, for: cond, switch, while, scan, composite,
custom_dce, custom_root, custom_linear_solve, saved_residuals.
We replace uses of `pe.tracing_debug_info` with with `api_util.tracing_debug_info`,
which uses the actual args and kwargs, instead of `in_tree` to manufacture fake
args and kwargs. This ends up being more accurate, especially for `arg_names`;
see changes in debug_info_tests.py.
This means that we have to construct the debug info further upstream, before
flattening args. This will later help populate debug info in `WrappedFun` and
`Jaxpr`.

This is part 2 of a series (following jax-ml#26097) for Pallas.
@gnecula gnecula force-pushed the debug_info_no_pe_debug_info_2 branch from cedb516 to 09fe6b1 Compare January 27, 2025 10:20
gnecula added a commit to gnecula/jax that referenced this pull request Jan 27, 2025
We replace uses of `pe.tracing_debug_info` with with `api_util.tracing_debug_info`,
which uses the actual args and kwargs, instead of `in_tree` to manufacture fake
args and kwargs. This ends up being more accurate, especially for `arg_names`;
see changes in debug_info_tests.py.
This means that we have to construct the debug info further upstream, before
flattening args. This will later help populate debug info in `WrappedFun` and
`Jaxpr`.

This is part 3 of a series (following jax-ml#26097, jax-ml#26099) for jit, pmap, checkify, custom_dce.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant