diff --git a/benchmarks/torchbench_model.py b/benchmarks/torchbench_model.py index a3aa5185c6e..bd95f780b00 100644 --- a/benchmarks/torchbench_model.py +++ b/benchmarks/torchbench_model.py @@ -230,7 +230,7 @@ def load_benchmark(self): # torchbench uses `xla` as device instead of `tpu` if device := self.benchmark_experiment.accelerator == 'tpu': - device = 'xla' + device = str(self.benchmark_experiment.get_device()) return benchmark_cls( test=self.benchmark_experiment.test, device=device,