Skip to content

Commit

Permalink
add extract_jax for torch_xla2
Browse files Browse the repository at this point in the history
  • Loading branch information
zpcore committed May 15, 2024
1 parent be40604 commit a6ba2c1
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 12 deletions.
3 changes: 2 additions & 1 deletion benchmarks/benchmark_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def load_experiment(self, experiment_config):
accelerator=accelerator,
xla=xla,
xla_flags=xla_flags,
torch_xla2=self._args.torch_xla2,
torch_xla2=torch_xla2,
dynamo=dynamo,
test=test,
batch_size=batch_size)
Expand Down Expand Up @@ -200,6 +200,7 @@ def to_dict(self):
d["accelerator_model"] = self.accelerator_model
d["xla"] = self.xla
d["xla_flags"] = self.xla_flags
d["torch_xla2"] = self.torch_xla2
d["dynamo"] = self.dynamo
d["test"] = self.test
d["batch_size"] = self.batch_size
Expand Down
21 changes: 11 additions & 10 deletions benchmarks/benchmark_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,25 +123,26 @@ def prepare_for_experiment(self, dynamo_compilation_opts):
raise NotImplementedError

if self.benchmark_experiment.torch_xla2:
# for torch_xla2, we export model to FX graph and move weights to JAX device
import torch_xla2.export
import torch_xla2
import jax
import jax.numpy as jnp
if benchmark_experiment.torch_xla2 == 'torch_export':
if self.benchmark_experiment.torch_xla2 == 'torch_export':
# for torch_xla2, we export model to FX graph and move weights to JAX device
exported = torch.export.export(self.module, self.example_inputs)
weights, jax_func = torch_xla2.export.exported_program_to_jax(exported)
elif benchmark_experiment.torch_xla2 == 'extract_jax':
jax_func = jax.jit(jax_func)
device = jax.devices()[0]
weights = pytree.tree_map_only(jnp.ndarray,
lambda x: jax.device_put(x, device),
weights)
elif self.benchmark_experiment.torch_xla2 == 'extract_jax':
weights, jax_func = torch_xla2.extract_jax(self.module)
jax_func = jax.jit(jax_func)
else:
raise ValueError("torch_xla2 option unavailable")
jax_func = jax.jit(jax_func)
device = jax.devices()[0]
weights = pytree.tree_map_only(jnp.ndarray,
lambda x: jax.device_put(x, device),
weights)
# map the module function to jax_func
self.module = lambda x: jax_func(weights, x)

self.module = lambda x: jax_func(weights, (x,))
self.example_inputs = move_to_device(self.example_inputs, device,
self.benchmark_experiment.torch_xla2)
else:
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/experiment_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ def run_once_and_gather_metrics(self, benchmark_experiment, benchmark_model,

# Reset state and sync.
reset_rng_state(benchmark_experiment)
if self._args.torch_xla2:
if benchmark_experiment.torch_xla2:
for inputs in inputs_list:
self._mark_step(benchmark_experiment, inputs)
else:
Expand Down

0 comments on commit a6ba2c1

Please sign in to comment.