diff --git a/benchmarks/experiment_runner.py b/benchmarks/experiment_runner.py index 0c38a48a8f3..58df7f6a8d2 100644 --- a/benchmarks/experiment_runner.py +++ b/benchmarks/experiment_runner.py @@ -242,7 +242,7 @@ def run_single_config(self): # Repeat the experiment and accumulate metrics. last_output = None - with benchmark_model.pick_grad(): + with benchmark_model.pick_context(): accumulated_metrics = OrderedDict() for repeat_iteration in range(self._args.repeat): metrics, last_output = self.run_once_and_gather_metrics( diff --git a/benchmarks/torchbench_model.py b/benchmarks/torchbench_model.py index 20cd72d937a..eb7ed0e24c1 100644 --- a/benchmarks/torchbench_model.py +++ b/benchmarks/torchbench_model.py @@ -1,5 +1,6 @@ import functools import gc +import contextlib import importlib import logging import os @@ -9,6 +10,9 @@ import torch.nn as nn from torch._dynamo.testing import collect_results, reduce_to_scalar_loss from torch._dynamo.utils import clone_inputs +import torch_xla +import torch_xla.amp +import torch_xla.core.xla_model as xm import types import yaml from util import move_to_device, set_cwd @@ -407,8 +411,7 @@ def default_precision_flag(self): return 'XLA_USE_BF16' if self.is_cuda_precision_amp(): - raise ValueError( - f"AMP for PT/XLA:GPU is not implemented yet for torchbench models") + return None if self.is_cuda_precision_fp32(): logger.warning("Sticking with the default fp32 precision.") @@ -431,6 +434,21 @@ def pick_grad(self): elif self.benchmark_experiment.test == "train": return torch.enable_grad() + def pick_amp(self): + if (self.benchmark_experiment.accelerator == "cuda" and + self.is_cuda_precision_amp()): + if self.benchmark_experiment.xla: + return torch_xla.amp.autocast(xm.xla_device()) + else: + return torch.cuda.amp.autocast() + return contextlib.nullcontext() + + def pick_context(self): + stack = contextlib.ExitStack() + stack.enter_context(self.pick_amp()) + stack.enter_context(self.pick_grad()) + return stack + def compute_loss(self, pred): """Reduce the output of a model to get scalar loss""" if isinstance(pred, torch.Tensor):