
test_autocast_torch_bf16 fails if PyTorch is compiled with CUDA support. #6085

Closed
ysiraichi opened this issue Dec 9, 2023 · 2 comments

@ysiraichi
Collaborator

🐛 Bug

Running the test_autocast_torch_bf16 test produces the following error if PyTorch was compiled with CUDA support:

$ python test/test_autocast.py -v -k test_autocast_torch_bf16
test_autocast_torch_bf16 (__main__.TestAutocastCuda) ... ERROR

======================================================================
ERROR: test_autocast_torch_bf16 (__main__.TestAutocastCuda)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "test/test_autocast.py", line 391, in test_autocast_torch_bf16
    self._run_autocast_outofplace(
  File "test/test_autocast.py", line 278, in _run_autocast_outofplace
    with autocast(xm.xla_device(), dtype=autocast_dtype):
  File "xla/torch_xla/amp/autocast_mode.py", line 45, in __init__
    super().__init__(
  File "torch/amp/autocast_mode.py", line 306, in __init__
    raise RuntimeError(
RuntimeError: Current CUDA Device does not support bfloat16. Please switch dtype to float16.

----------------------------------------------------------------------
Ran 1 test in 0.140s

FAILED (errors=1)

Environment

Additional Context

Blocking: #6070

@ysiraichi
Collaborator Author

The main problem here is that torch.cuda.is_bf16_supported() returns False, while torch.tensor([1.], dtype=torch.bfloat16, device=xm.xla_device()) works.
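The mismatch above suggests a pragmatic last-resort check: instead of trusting the capability query alone, attempt to materialize a bfloat16 tensor and see whether it succeeds. A minimal sketch of that probe (the helper name `bf16_works` is hypothetical, not part of any PyTorch API); it accepts any device string, so it also runs on CPU:

```python
import torch


def bf16_works(device: str) -> bool:
    # Last-resort probe: try to create a bfloat16 tensor on the target
    # device. Success implies the backend can at least represent the
    # dtype, even when torch.cuda.is_bf16_supported() reports otherwise.
    try:
        t = torch.tensor([1.0], dtype=torch.bfloat16, device=device)
        return t.dtype == torch.bfloat16
    except RuntimeError:
        return False


print(bf16_works("cpu"))
```

This mirrors the direction the fix below takes: falling back to an empirical check when the static capability query under-reports.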

@JackCaoG
Collaborator

@yeounoh FYI

pytorchmergebot pushed a commit to pytorch/pytorch that referenced this issue Jan 1, 2024
…#115924)

Fix: #115900 pytorch/xla#6085

This PR adds a last resort for testing BF16 support on CUDA. This is necessary on GPUs
such as the RTX 2060, where `torch.cuda.is_bf16_supported()` returns False even though we can
successfully create a BF16 tensor on CUDA.

Before this PR:

```python
>>> torch.cuda.is_bf16_supported()
False
>>> torch.tensor([1.], dtype=torch.bfloat16, device="cuda")
tensor([...], device='cuda:0', dtype=torch.bfloat16)
```

After this PR:

```python
>>> torch.cuda.is_bf16_supported()
True
>>> torch.tensor([1.], dtype=torch.bfloat16, device="cuda")
tensor([...], device='cuda:0', dtype=torch.bfloat16)
```

Pull Request resolved: #115924
Approved by: https://github.com/jansel