Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[better_errors] Refactor more uses of pe.tracing_debug_info (part 1) #26097

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

gnecula
Copy link
Collaborator

@gnecula gnecula commented Jan 25, 2025

We replace uses of pe.tracing_debug_info with with api_util.tracing_debug_info,
which uses the actual args and kwargs, instead of in_tree to manufacture fake
args and kwargs. This ends up being more accurate, especially for arg_names;
see changes in debug_info_tests.py.
This means that we have to construct the debug info further upstream, before
flattening args. This will later help populate debug info in WrappedFun and
Jaxpr.

This is part 1 of a series, for: cond, switch, while, scan, composite,
custom_dce, custom_root, custom_linear_solve, saved_residuals.

@gnecula gnecula force-pushed the debug_info_no_pe_debug_info branch from 92ad1a9 to 27b51e5 Compare January 25, 2025 14:56
@gnecula gnecula self-assigned this Jan 25, 2025
@gnecula gnecula added the pull ready Ready for copybara import and testing label Jan 25, 2025
@gnecula gnecula force-pushed the debug_info_no_pe_debug_info branch 2 times, most recently from 7dc9bb6 to f72f39e Compare January 25, 2025 15:06
@gnecula gnecula changed the title [better_errors] Refactor more uses of partial_eval.tracing_debug_info [better_errors] Refactor more uses of partial_eval.tracing_debug_info (part 1) Jan 25, 2025
@gnecula gnecula force-pushed the debug_info_no_pe_debug_info branch 2 times, most recently from b5e69d6 to ce5ba76 Compare January 25, 2025 17:07
gnecula added a commit to gnecula/jax that referenced this pull request Jan 25, 2025
… (part 2)

We replace those uses with api_util.tracing_debug_info, which means we
have to move the call further upstream. But this is better because we
have the actual args and kwargs, and we can do a better job, especially
for arg_names.

This is part 2 of a series (following jax-ml#26097) for Pallas.
gnecula added a commit to gnecula/jax that referenced this pull request Jan 25, 2025
… (part 3)

We replace those uses with api_util.tracing_debug_info, which means we
have to move the call further upstream. But this is better because we
have the actual args and kwargs, and we can do a better job, especially
for arg_names.

This is part 3 of a series (following jax-ml#26097, xxx) for jit, checkify, custom_dce.
@gnecula gnecula force-pushed the debug_info_no_pe_debug_info branch from ce5ba76 to 21a82e6 Compare January 26, 2025 06:13
gnecula added a commit to gnecula/jax that referenced this pull request Jan 26, 2025
… (part 2)

We replace those uses with api_util.tracing_debug_info, which means we
have to move the call further upstream. But this is better because we
have the actual args and kwargs, and we can do a better job, especially
for arg_names.

This is part 2 of a series (following jax-ml#26097) for Pallas.
@gnecula gnecula force-pushed the debug_info_no_pe_debug_info branch 3 times, most recently from 2029c18 to b004848 Compare January 26, 2025 07:57
@gnecula gnecula requested review from mattjj and dfm January 26, 2025 07:57
@gnecula gnecula changed the title [better_errors] Refactor more uses of partial_eval.tracing_debug_info (part 1) [better_errors] Refactor more uses of pe.tracing_debug_info (part 1) Jan 26, 2025
gnecula added a commit to gnecula/jax that referenced this pull request Jan 26, 2025
We replace uses of `pe.tracing_debug_info` with with `api_util.tracing_debug_info`,
which uses the actual args and kwargs, instead of `in_tree` to manufacture fake
args and kwargs. This ends up being more accurate, especially for `arg_names`;
see changes in debug_info_tests.py.
This means that we have to construct the debug info further upstream, before
flattening args. This will later help populate debug info in `WrappedFun` and
`Jaxpr`.

This is part 2 of a series (following jax-ml#26097) for Pallas.
gnecula added a commit to gnecula/jax that referenced this pull request Jan 26, 2025
We replace uses of `pe.tracing_debug_info` with with `api_util.tracing_debug_info`,
which uses the actual args and kwargs, instead of `in_tree` to manufacture fake
args and kwargs. This ends up being more accurate, especially for `arg_names`;
see changes in debug_info_tests.py.
This means that we have to construct the debug info further upstream, before
flattening args. This will later help populate debug info in `WrappedFun` and
`Jaxpr`.

This is part 3 of a series (following jax-ml#26097, jax-ml#26099) for jit, pmap, checkify, custom_dce.
gnecula added a commit to gnecula/jax that referenced this pull request Jan 26, 2025
We replace uses of `pe.tracing_debug_info` with with `api_util.tracing_debug_info`,
which uses the actual args and kwargs, instead of `in_tree` to manufacture fake
args and kwargs. This ends up being more accurate, especially for `arg_names`;
see changes in debug_info_tests.py.
This means that we have to construct the debug info further upstream, before
flattening args. This will later help populate debug info in `WrappedFun` and
`Jaxpr`.

This is part 3 of a series (following jax-ml#26097, jax-ml#26099) for jit, pmap, checkify, custom_dce.
@gnecula gnecula force-pushed the debug_info_no_pe_debug_info branch 6 times, most recently from c8a6053 to d833b03 Compare January 27, 2025 10:08
gnecula added a commit to gnecula/jax that referenced this pull request Jan 27, 2025
We replace uses of `pe.tracing_debug_info` with with `api_util.tracing_debug_info`,
which uses the actual args and kwargs, instead of `in_tree` to manufacture fake
args and kwargs. This ends up being more accurate, especially for `arg_names`;
see changes in debug_info_tests.py.
This means that we have to construct the debug info further upstream, before
flattening args. This will later help populate debug info in `WrappedFun` and
`Jaxpr`.

This is part 2 of a series (following jax-ml#26097) for Pallas.
gnecula added a commit to gnecula/jax that referenced this pull request Jan 27, 2025
We replace uses of `pe.tracing_debug_info` with with `api_util.tracing_debug_info`,
which uses the actual args and kwargs, instead of `in_tree` to manufacture fake
args and kwargs. This ends up being more accurate, especially for `arg_names`;
see changes in debug_info_tests.py.
This means that we have to construct the debug info further upstream, before
flattening args. This will later help populate debug info in `WrappedFun` and
`Jaxpr`.

This is part 3 of a series (following jax-ml#26097, jax-ml#26099) for jit, pmap, checkify, custom_dce.
… (part 1)

We replace those uses with api_util.tracing_debug_info, which means we
have to move the call further upstream. But this is better because we
have the actual args and kwargs, and we can do a better job, especially
for `arg_names`.

This is part 1 of a series, for: cond, switch, while, scan, composite,
custom_dce, custom_root, custom_linear_solve, saved_residuals.
@gnecula gnecula force-pushed the debug_info_no_pe_debug_info branch from d833b03 to 7361d17 Compare January 27, 2025 10:59
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant