diff --git a/benchmarks/benchmark_experiment.py b/benchmarks/benchmark_experiment.py
index a8201a5210d..1b8fa941f25 100644
--- a/benchmarks/benchmark_experiment.py
+++ b/benchmarks/benchmark_experiment.py
@@ -22,6 +22,8 @@ def list_experiment_configs(self):
         "xla": [None, "PJRT", "XRT"],
         "xla_flags": [None],
         "dynamo": [None, "inductor", "openxla_eval", "openxla"],
+        "torch_xla2": ["extract_jax",
+                       "torch_export"],  # options only apply to torch_xla2
         "test": ["eval", "train"],
     }
 
@@ -30,6 +32,9 @@ def list_experiment_configs(self):
       config_choices["accelerator"] = list(set(self._args.accelerator))
     if self._args.xla:
       config_choices["xla"] = list(map(parse_none_str, set(self._args.xla)))
+    if self._args.torch_xla2:
+      config_choices["torch_xla2"] = list(
+          map(parse_none_str, set(self._args.torch_xla2)))
     if self._args.dynamo:
       config_choices["dynamo"] = list(
           map(parse_none_str, set(self._args.dynamo)))
@@ -72,7 +77,7 @@ def _is_available(self, experiment_config):
                            exclude_tags=()):
       return False
 
-    # torch_xla2 doesn't run with dynamo.
+    # torch_xla2 doesn't support dynamo at this time.
     if cfg_dynamo is not None and self._args.torch_xla2:
       return False
 
@@ -114,34 +119,29 @@ def load_experiment(self, experiment_config):
     dynamo = experiment_config["dynamo"]
     test = experiment_config["test"]
     batch_size = experiment_config.get("batch_size", self._args.batch_size)
+    torch_xla2 = experiment_config["torch_xla2"]
     return BenchmarkExperiment(
         accelerator=accelerator,
         xla=xla,
         xla_flags=xla_flags,
+        torch_xla2=torch_xla2,
         dynamo=dynamo,
         test=test,
-        batch_size=batch_size,
-        use_torch_xla2=self._args.torch_xla2)
+        batch_size=batch_size)
 
 
 class BenchmarkExperiment:
 
-  def __init__(self,
-               accelerator,
-               xla,
-               xla_flags,
-               dynamo,
-               test,
-               batch_size,
-               use_torch_xla2: bool = False):
+  def __init__(self, accelerator, xla, xla_flags, torch_xla2, dynamo, test,
+               batch_size):
     self.accelerator = accelerator
     self.xla = xla
     self.xla_flags = xla_flags
+    self.torch_xla2 = torch_xla2
     self.dynamo = dynamo
     self.test = test
     self.batch_size = batch_size
     self.accelerator_model = get_accelerator_model(self.accelerator)
-    self.use_torch_xla2 = use_torch_xla2
 
   def update_process_env(self, process_env):
 
@@ -151,7 +151,7 @@ def update_process_env(self, process_env):
     process_env.pop("XRT_TPU_CONFIG", None)
     process_env.pop("XLA_FLAGS", None)
 
-    if self.use_torch_xla2:
+    if self.torch_xla2:
       process_env["JAX_PLATFORMS"] = self.accelerator.lower()
 
     if self.xla == "PJRT":
diff --git a/benchmarks/benchmark_model.py b/benchmarks/benchmark_model.py
index 9187650dfc2..efb1b7741bc 100644
--- a/benchmarks/benchmark_model.py
+++ b/benchmarks/benchmark_model.py
@@ -122,14 +122,19 @@ def prepare_for_experiment(self, dynamo_compilation_opts):
     else:
       raise NotImplementedError
 
-    if self.benchmark_experiment.use_torch_xla2:
-      # for torch_xla2, we export model to FX graph and move weights to JAX device
+    if self.benchmark_experiment.torch_xla2:
+      # for torch_xla2, convert the model to a JAX function and move its weights to the JAX device
       import torch_xla2.export
       import torch_xla2
       import jax
       import jax.numpy as jnp
-      exported = torch.export.export(self.module, self.example_inputs)
-      weights, jax_func = torch_xla2.export.exported_program_to_jax(exported)
+      if self.benchmark_experiment.torch_xla2 == 'torch_export':
+        exported = torch.export.export(self.module, self.example_inputs)
+        weights, jax_func = torch_xla2.export.exported_program_to_jax(exported)
+      elif self.benchmark_experiment.torch_xla2 == 'extract_jax':
+        weights, jax_func = torch_xla2.extract_jax(self.module)
+      else:
+        raise ValueError("torch_xla2 option unavailable")
       jax_func = jax.jit(jax_func)
       device = jax.devices()[0]
       weights = pytree.tree_map_only(jnp.ndarray,
@@ -137,13 +142,12 @@ def prepare_for_experiment(self, dynamo_compilation_opts):
                                      weights)
       # map the module function to jax_func
       self.module = lambda x: jax_func(weights, x)
-      self.example_inputs = move_to_device(
-          self.example_inputs, device, self.benchmark_experiment.use_torch_xla2)
+      self.example_inputs = move_to_device(self.example_inputs, device,
+                                           self.benchmark_experiment.torch_xla2)
     else:
       self.module = self.module.to(self.device)
-      self.example_inputs = move_to_device(
-          self.example_inputs, self.device,
-          self.benchmark_experiment.use_torch_xla2)
+      self.example_inputs = move_to_device(self.example_inputs, self.device,
+                                           self.benchmark_experiment.torch_xla2)
 
     if self.benchmark_experiment.dynamo:
       compilation_opts = dynamo_compilation_opts.copy()
diff --git a/benchmarks/experiment_runner.py b/benchmarks/experiment_runner.py
index bf46955e97f..fa2f3e80d46 100644
--- a/benchmarks/experiment_runner.py
+++ b/benchmarks/experiment_runner.py
@@ -284,7 +284,7 @@ def run_once_and_gather_metrics(self, benchmark_experiment, benchmark_model,
 
     # Reset state and sync.
     reset_rng_state(benchmark_experiment)
-    if self._args.use_torch_xla2:
+    if self._args.torch_xla2:
       for inputs in inputs_list:
         self._mark_step(benchmark_experiment, inputs)
     else:
@@ -420,7 +420,7 @@ def _prepare_inputs(self, example_inputs, should_randomize_input):
 
   def _mark_step(self, benchmark_experiment, tensors_to_check=None):
     if benchmark_experiment.xla:
-      if benchmark_experiment.use_torch_xla2:
+      if benchmark_experiment.torch_xla2:
         assert tensors_to_check is not None, "torch_xla2 requires input tensor to block_until_ready"
         for t in tensors_to_check:
           t.block_until_ready()
@@ -870,8 +870,9 @@ def __str__(self):
   )
   parser.add_argument(
       "--torch-xla2",
-      action="store_true",
-      help="Choose to use torch_xla2 or not.",
+      choices=["extract_jax", "torch_export"],
+      action="append",
+      help="Enable torch_xla2 and select which conversion mode to use.",
   )
   parser.add_argument(
       "--disable-tf32",
diff --git a/benchmarks/util.py b/benchmarks/util.py
index 2f8089ccc4a..88bf452bbdd 100644
--- a/benchmarks/util.py
+++ b/benchmarks/util.py
@@ -50,7 +50,7 @@ def reset_rng_state(benchmark_experiment=None):
   torch.manual_seed(1337)
   random.seed(1337)
   np.random.seed(1337)
-  if benchmark_experiment is not None and benchmark_experiment.xla is not None and not benchmark_experiment.use_torch_xla2:
+  if benchmark_experiment is not None and benchmark_experiment.xla is not None and benchmark_experiment.torch_xla2 is None:
     device = benchmark_experiment.get_device()
     xm.set_rng_state(1337, str(device))
 
@@ -76,8 +76,8 @@ def is_xla_device_available(devkind):
   return r.returncode == 0
 
 
-def move_to_device(item, device, use_torch_xla2: bool = False):
-  if use_torch_xla2:
+def move_to_device(item, device, torch_xla2):
+  if torch_xla2:
     import torch_xla2
     import jax
     move_to_device_func = lambda t: jax.device_put(
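
Note (not part of the patch): a minimal sketch of the two torch_xla2 conversion modes wired up above. The toy module and input shapes are hypothetical; it assumes torch_xla2 exposes extract_jax and export.exported_program_to_jax with the calling conventions used in benchmark_model.py.

# Sketch of the two --torch-xla2 modes (hypothetical toy model).
import torch
import jax
import torch_xla2
import torch_xla2.export

module = torch.nn.Linear(8, 4).eval()  # stand-in for a benchmark model
example_inputs = (torch.randn(2, 8),)

mode = "extract_jax"  # or "torch_export"
if mode == "torch_export":
  # Trace into an ExportedProgram, then lower it to a JAX function plus weights.
  exported = torch.export.export(module, example_inputs)
  weights, jax_func = torch_xla2.export.exported_program_to_jax(exported)
else:
  # Pull the weights and a JAX-callable function directly from the module.
  weights, jax_func = torch_xla2.extract_jax(module)

# jax_func(weights, args) is a pure function; jit-compile it once and reuse it.
jax_func = jax.jit(jax_func)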