Commit
Add PyTorch version check for FA backend on AMD GPUs (#35813)
Disable FA backend for SDPA on AMD GPUs (PyTorch < 2.4.1)
mht-sharma authored Jan 22, 2025
1 parent 3b97705 commit fdcc62c
Showing 1 changed file with 1 addition and 0 deletions.
1 change: 1 addition & 0 deletions src/transformers/modeling_utils.py
@@ -1546,6 +1546,7 @@ def _autoset_attn_implementation(
             torch.version.hip is not None
             and config._attn_implementation == "sdpa"
             and torch.cuda.device_count() > 1
+            and version.parse(torch.__version__) < version.parse("2.4.1")
         ):
             logger.warning_once(
                 "Using the `SDPA` attention implementation on multi-gpu setup with ROCM may lead to performance issues due to the FA backend. Disabling it to use alternative backends."
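The condition in this hunk can be read in isolation as a small predicate. The sketch below is illustrative only and is not code from the commit: the helper name `should_disable_flash_sdp` and its parameters are hypothetical stand-ins for `torch.version.hip`, `config._attn_implementation`, `torch.cuda.device_count()` and `torch.__version__`, so it runs without PyTorch installed. It assumes that `version` in the diff refers to `packaging.version`.

```python
from typing import Optional

from packaging import version


def should_disable_flash_sdp(
    hip_version: Optional[str],
    attn_implementation: str,
    device_count: int,
    torch_version: str,
) -> bool:
    """Hypothetical helper mirroring the four-part condition in the diff.

    Returns True only for a ROCm build, with SDPA attention, on more than
    one GPU, and with a PyTorch version older than 2.4.1.
    """
    return (
        hip_version is not None  # stands in for torch.version.hip
        and attn_implementation == "sdpa"
        and device_count > 1
        # Local build tags such as "+rocm6.1" are handled by packaging:
        # "2.4.1+rocm6.1" is not considered older than "2.4.1".
        and version.parse(torch_version) < version.parse("2.4.1")
    )


if __name__ == "__main__":
    # Older ROCm build on two GPUs: the workaround would still apply.
    print(should_disable_flash_sdp("6.1", "sdpa", 2, "2.4.0+rocm6.0"))
    # PyTorch 2.4.1 or newer: the new check skips the workaround.
    print(should_disable_flash_sdp("6.1", "sdpa", 2, "2.4.1+rocm6.1"))
```

The hunk is cut off before the body that follows the warning, so what runs when the predicate is true is not shown here. Judging only from the warning text, it disables the FA backend for SDPA.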
