Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[benchmarks] Add support for AMP execution. #6447

Merged
merged 3 commits into from
Feb 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion benchmarks/experiment_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def run_single_config(self):

# Repeat the experiment and accumulate metrics.
last_output = None
with benchmark_model.pick_grad():
with benchmark_model.pick_context():
accumulated_metrics = OrderedDict()
for repeat_iteration in range(self._args.repeat):
metrics, last_output = self.run_once_and_gather_metrics(
Expand Down
22 changes: 20 additions & 2 deletions benchmarks/torchbench_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import functools
import gc
import contextlib
import importlib
import logging
import os
Expand All @@ -9,6 +10,9 @@
import torch.nn as nn
from torch._dynamo.testing import collect_results, reduce_to_scalar_loss
from torch._dynamo.utils import clone_inputs
import torch_xla
import torch_xla.amp
import torch_xla.core.xla_model as xm
import types
import yaml
from util import move_to_device, set_cwd
Expand Down Expand Up @@ -403,8 +407,7 @@ def default_precision_flag(self):
return 'XLA_USE_BF16'

if self.is_cuda_precision_amp():
raise ValueError(
f"AMP for PT/XLA:GPU is not implemented yet for torchbench models")
return None

if self.is_cuda_precision_fp32():
logger.warning("Sticking with the default fp32 precision.")
Expand All @@ -427,6 +430,21 @@ def pick_grad(self):
elif self.benchmark_experiment.test == "train":
return torch.enable_grad()

def pick_amp(self):
if (self.benchmark_experiment.accelerator == "cuda" and
self.is_cuda_precision_amp()):
if self.benchmark_experiment.xla:
return torch_xla.amp.autocast(xm.xla_device())
else:
return torch.cuda.amp.autocast()
return contextlib.nullcontext()

def pick_context(self):
stack = contextlib.ExitStack()
stack.enter_context(self.pick_amp())
stack.enter_context(self.pick_grad())
return stack

def compute_loss(self, pred):
"""Reduce the output of a model to get scalar loss"""
if isinstance(pred, torch.Tensor):
Expand Down
Loading