Skip to content

Commit

Permalink
update based on feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
zpcore committed May 14, 2024
1 parent 4d59e43 commit dcab385
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 12 deletions.
3 changes: 2 additions & 1 deletion benchmarks/benchmark_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def _is_available(self, experiment_config):
exclude_tags=()):
return False

# torch_xla2 doesn't run with dynamo.
if cfg_dynamo is not None and self._args.torch_xla2:
return False

Expand Down Expand Up @@ -151,7 +152,7 @@ def update_process_env(self, process_env):
process_env.pop("XLA_FLAGS", None)

if self.use_torch_xla2:
process_env.pop("JAX_PLATFORMS", self.accelerator.upper())
process_env["JAX_PLATFORMS"] = self.accelerator.lower()

if self.xla == "PJRT":
process_env["PJRT_DEVICE"] = self.accelerator.upper()
Expand Down
9 changes: 5 additions & 4 deletions benchmarks/benchmark_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,12 +137,13 @@ def prepare_for_experiment(self, dynamo_compilation_opts):
weights)
# map the module function to jax_func
self.module = lambda x: jax_func(weights, x)
self.example_inputs = move_to_device(
self.example_inputs, device, self.benchmark_experiment.use_torch_xla2)
else:
self.module = self.module.to(self.device)

self.example_inputs = move_to_device(
self.example_inputs, self.device,
self.benchmark_experiment.use_torch_xla2)
self.example_inputs = move_to_device(
self.example_inputs, self.device,
self.benchmark_experiment.use_torch_xla2)

if self.benchmark_experiment.dynamo:
compilation_opts = dynamo_compilation_opts.copy()
Expand Down
15 changes: 9 additions & 6 deletions benchmarks/experiment_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,11 @@ def run_once_and_gather_metrics(self, benchmark_experiment, benchmark_model,

# Reset state and sync.
reset_rng_state(benchmark_experiment)
self._mark_step(benchmark_experiment, inputs_list[0])
if self._args.use_torch_xla2:
for inputs in inputs_list:
self._mark_step(benchmark_experiment, inputs)
else:
self._mark_step(benchmark_experiment)
self._synchronize(benchmark_experiment)
met.clear_all()
dynamo_utils.counters.clear()
Expand All @@ -307,7 +311,7 @@ def loop(pytorch_profile=None, iter_fn=None):
total_timing += timing

# Mark step.
self._mark_step(benchmark_experiment, inputs_list[i])
self._mark_step(benchmark_experiment, output)
if pytorch_profile is not None:
pytorch_profile.step()

Expand Down Expand Up @@ -414,11 +418,11 @@ def _prepare_inputs(self, example_inputs, should_randomize_input):
inputs_list.append(inputs)
return inputs_list

def _mark_step(self, benchmark_experiment, tensor_to_check=None):
def _mark_step(self, benchmark_experiment, tensors_to_check=None):
if benchmark_experiment.xla:
if benchmark_experiment.use_torch_xla2:
assert tensor_to_check is not None, "torch_xla2 requires input tensor to block_until_ready"
for t in tensor_to_check:
assert tensors_to_check is not None, "torch_xla2 requires input tensor to block_until_ready"
for t in tensors_to_check:
t.block_until_ready()
else:
xm.mark_step()
Expand Down Expand Up @@ -867,7 +871,6 @@ def __str__(self):
parser.add_argument(
"--torch-xla2",
action="store_true",
default=False,
help="Choose to use torch_xla2 or not.",
)
parser.add_argument(
Expand Down
3 changes: 2 additions & 1 deletion benchmarks/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ def move_to_device(item, device, use_torch_xla2: bool = False):
if use_torch_xla2:
import torch_xla2
import jax
move_to_device_func = lambda t: jax.device_put(torch_xla2.tensor.t2j(t))
move_to_device_func = lambda t: jax.device_put(
torch_xla2.tensor.t2j(t), device)
else:
move_to_device_func = lambda t: t.to(device)
return pytree.tree_map_only(torch.Tensor, move_to_device_func, item)
Expand Down

0 comments on commit dcab385

Please sign in to comment.