Support torch_xla2 benchmarking using torchbench #7013

Merged
merged 20 commits into from
May 16, 2024

Conversation

zpcore
Collaborator

@zpcore zpcore commented May 1, 2024

Summary

Integrate torch_xla2 testing into the benchmark scripts.

Details

This PR ports the torch_xla2 run script, based on the file linked here, into the existing benchmark script.

How to run

To run the torch_xla2 benchmarks, first install torch_xla2 by following the instructions in experimental/torch_xla2/README.md.

torch_xla2 does not currently work with dynamo. To switch to torch_xla2, append the --torch-xla2 flag, e.g.:

export JAX_PLATFORMS=TPU;
python experiment_runner.py \
--suite-name=torchbench \
--accelerator=tpu \
--progress-bar  \
--xla=PJRT  \
--test=eval \
--filter=dcgan \
--torch-xla2
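
When comparing several models or toggling the flag on and off, it can help to assemble the command above programmatically. The following is a minimal sketch, not part of this PR: the helper names `build_benchmark_command` and `build_env` are hypothetical, and only the flags shown in the command above are used.

```python
import os
import shlex


def build_benchmark_command(model, test="eval", use_torch_xla2=True):
    """Return the argv list for the benchmark command shown above.

    Hypothetical helper: it only reuses the flags from the PR description.
    """
    cmd = [
        "python",
        "experiment_runner.py",
        "--suite-name=torchbench",
        "--accelerator=tpu",
        "--progress-bar",
        "--xla=PJRT",
        f"--test={test}",
        f"--filter={model}",
    ]
    if use_torch_xla2:
        # Per the PR description, this flag switches the run to torch_xla2.
        cmd.append("--torch-xla2")
    return cmd


def build_env(base_env=None):
    """Copy the environment and set JAX_PLATFORMS as in the shell example."""
    env = dict(os.environ if base_env is None else base_env)
    env["JAX_PLATFORMS"] = "TPU"
    return env


cmd = build_benchmark_command("dcgan")
print(shlex.join(cmd))
```

To actually launch the run, the list can be passed to `subprocess.run(cmd, env=build_env(), check=True)` from the directory that contains `experiment_runner.py`.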

In practice, we need to make sure JAX and torch_xla use the same PJRT version. I updated (backported) the openxla pin for torch_xla so that it uses PJRT version 0.47.
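
As a first sanity check before debugging a PJRT mismatch, it may help to record which package versions are installed. The sketch below is an assumption-laden illustration: the distribution names in `PACKAGES` are guesses about how the packages are installed, and package versions do not report the PJRT API version directly, so this only narrows down which builds are in play.

```python
from importlib import metadata

# Distribution names are assumptions about the local installation.
PACKAGES = ("jax", "jaxlib", "torch", "torch_xla")


def installed_versions(packages=PACKAGES):
    """Map each distribution name to its installed version, or None."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # Package is not installed in this environment.
            versions[name] = None
    return versions


report = installed_versions()
for name, version in report.items():
    print(f"{name:10s} {version or 'not installed'}")
```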

Sample result

I tried a simple model, dcgan, on a TPU v5p-8. torch_xla2 shows the lowest median total time and the lowest compile time of the three configurations:

| benchmark | platform | torch_xla version | backend | median_total_time (s) | compile_time (s) |
| --- | --- | --- | --- | --- | --- |
| dcgan (eval) | v5-8 | torch_xla | LTC | 0.0022 | 1.5729 |
| dcgan (eval) | v5-8 | torch_xla | openxla | 0.0005 | 1.8661 |
| dcgan (eval) | v5-8 | torch_xla2 | jax.jit | 0.0003911869862349704 | 1.464099046002957 |
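
To put the table in relative terms, the short sketch below computes each configuration's speedup over the LTC baseline from the median total times listed above. The numbers are copied from the table; the dictionary layout and the `speedup_over` helper are illustrative only, and a single small model is not enough to generalize from.

```python
# Median total times in seconds, copied from the table above.
median_total_time = {
    "torch_xla / LTC": 0.0022,
    "torch_xla / openxla": 0.0005,
    "torch_xla2 / jax.jit": 0.0003911869862349704,
}


def speedup_over(baseline, times):
    """Return baseline_time / time for every configuration."""
    base = times[baseline]
    return {name: base / t for name, t in times.items()}


speedups = speedup_over("torch_xla / LTC", median_total_time)
for name, factor in speedups.items():
    print(f"{name:22s} {factor:.2f}x vs LTC")
```

On this one dcgan eval run, torch_xla2 with jax.jit comes out roughly 5.6x faster than LTC and roughly 1.3x faster than the openxla backend in median total time.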

@zpcore zpcore marked this pull request as ready for review May 2, 2024 07:46
@zpcore zpcore merged commit 9e18935 into master May 16, 2024
19 of 20 checks passed
@zpcore zpcore deleted the piz/xla2_bm branch May 16, 2024 00:08