Skip to content

Commit

Permalink
support bf16 (#25879)
Browse files Browse the repository at this point in the history
* added bf16 support

* added cuda availability check

* applied make style, quality
  • Loading branch information
etemadiamd authored Nov 2, 2023
1 parent af3de8d commit 7adaefe
Showing 1 changed file with 1 addition and 20 deletions.
21 changes: 1 addition & 20 deletions src/transformers/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,26 +305,7 @@ def is_torch_bf16_gpu_available():

import torch

# since currently no utility function is available we build our own.
# some bits come from https://github.com/pytorch/pytorch/blob/2289a12f21c54da93bf5d696e3f9aea83dd9c10d/torch/testing/_internal/common_cuda.py#L51
# with additional check for torch version
# to succeed: (torch is required to be >= 1.10 anyway)
# 1. the hardware needs to support bf16 (GPU arch >= Ampere, or CPU)
# 2. if using gpu, CUDA >= 11
# 3. torch.autocast exists
# XXX: one problem here is that it may give invalid results on mixed gpus setup, so it's
# really only correct for the 0th gpu (or currently set default device if different from 0)
if torch.cuda.is_available() and torch.version.cuda is not None:
if torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8:
return False
if int(torch.version.cuda.split(".")[0]) < 11:
return False
if not hasattr(torch.cuda.amp, "autocast"):
return False
else:
return False

return True
return torch.cuda.is_available() and torch.cuda.is_bf16_supported()


def is_torch_bf16_cpu_available():
Expand Down

0 comments on commit 7adaefe

Please sign in to comment.