Skip to content

Commit

Permalink
import the jax import in class
Browse files Browse the repository at this point in the history
  • Loading branch information
zpcore committed May 17, 2024
1 parent 3ed01aa commit 4aadc07
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions benchmarks/experiment_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ def __init__(self, args):
self.model_loader = ModelLoader(self._args)
else:
raise NotImplementedError

if benchmark_experiment.torch_xla2:
import jax
self.jax = jax

self.output_dir = os.path.abspath(self._args.output_dirname)
os.makedirs(self.output_dir, exist_ok=True)
Expand Down Expand Up @@ -420,11 +424,8 @@ def _prepare_inputs(self, example_inputs, should_randomize_input):
def _mark_step(self, benchmark_experiment, tensors_to_check=None):
if benchmark_experiment.xla:
if benchmark_experiment.torch_xla2:
# jax module should be cached and we expect this to cause 0 overhead.
# We didn't import the module globally since torch_xla2 is still in experimental stage.
import jax
assert tensors_to_check is not None, "torch_xla2 requires input tensor to block_until_ready"
jax.block_until_ready(tensors_to_check)
self.jax.block_until_ready(tensors_to_check)
else:
xm.mark_step()

Expand Down

0 comments on commit 4aadc07

Please sign in to comment.