Issue in Manual optimisation, during self.manual_backward call #19810
Comments
I removed excess code, made a new Conda environment with just pytorch-lightning and tensorboard installed, and was able to replicate the same issue even with lightning version 2.2.3. I have edited the issue above to reflect this.
@pranavrao-qure Here is the same code in plain PyTorch (no Lightning), showing that it results in the same error:

```python
import math
import torch
from torch import nn, GradScaler
from torch.utils.data import TensorDataset, DataLoader


class TestModule(nn.Module):
    def __init__(self, in_dim=512, out_dim=16):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.query = nn.Linear(self.in_dim, self.out_dim, bias=True)

    def forward(self, input):
        return self.query(input)


def get_mini_batches(input):
    num_mb = math.ceil(16 / 4)
    return torch.chunk(input, num_mb)


if __name__ == '__main__':
    test_data_loader = DataLoader(TensorDataset(torch.randn(512, 512)), batch_size=16, shuffle=False)
    test_module_obj = TestModule()
    scaler = GradScaler(device="cpu")

    with torch.autocast(device_type='cpu', dtype=torch.float16):
        batch = next(iter(test_data_loader))
        input = batch[0]
        mini_batches = get_mini_batches(input)
        mb_model_output_list = list()

        # First pass: forward over the mini-batches without gradients
        with torch.no_grad():
            for mb in mini_batches:
                mb_model_output_list.append(test_module_obj(mb).mean().detach())
        all_loss = sum(mb_model_output_list)

        test_module_obj.train()
        test_module_obj.requires_grad_(True)
        torch.set_grad_enabled(True)
        assert torch.is_grad_enabled()
        assert all(p.requires_grad for p in test_module_obj.parameters())

        # Second pass: grad-enabled forward + backward per mini-batch
        for _, mb in enumerate(mini_batches):
            mb_model_output = test_module_obj(mb).mean()
            scaler.scale(mb_model_output).backward()
```

When you use

```python
with torch.no_grad(), torch.autocast(device_type=self.device.type, enabled=False):
    ...
```

the error does not occur. This seems to be a quirk of PyTorch and how these context managers interact. There is nothing that could be done on the Lightning side, to my knowledge.
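For concreteness, this is how that workaround would look when applied to the no_grad pass of the repro above (a sketch only, reusing the variable names from that script; the device type is hard-coded to "cpu" here because there is no LightningModule providing self.device):

```python
# Sketch: disable autocast for the no_grad embedding pass, so the autocast
# weight-cast cache is not populated with requires_grad=False FP16 copies.
with torch.no_grad(), torch.autocast(device_type="cpu", enabled=False):
    for mb in mini_batches:
        mb_model_output_list.append(test_module_obj(mb).mean().detach())
```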
It also works if you do this, which means that PyTorch Lightning encloses the entire module in torch.autocast when you use precision=16 (which is mixed precision by default).

The behavior you observe happens because you do both a no_grad forward pass and a grad-enabled forward pass within the same autocast context. In the no_grad forward pass, FP16 copies of the parameters are created and cached. Because it is a no_grad context, these FP16 copies are created with requires_grad=False. When you run net(input) again in a grad-enabled way, you are still within the same autocast context, so the cache is live and the FP16 copies are not recreated (net's FP16-listed ops use the cached copies instead). Since these cached copies have requires_grad=False, net(input) does not build an autograd graph, and the output ends up with requires_grad=False.
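This caching behavior can be isolated in a few lines of plain PyTorch (a minimal sketch, not from the original thread, mirroring the CPU/float16 setup of the repro above):

```python
import torch
from torch import nn

net = nn.Linear(8, 4)
x = torch.randn(2, 8)

with torch.autocast(device_type="cpu", dtype=torch.float16):
    with torch.no_grad():
        net(x)  # caches FP16 parameter copies with requires_grad=False
    out_same = net(x)  # same autocast context: the cached copies are reused
    print(out_same.requires_grad)  # False -> backward() would fail here

with torch.autocast(device_type="cpu", dtype=torch.float16):
    out_fresh = net(x)  # fresh autocast context: copies recreated under grad mode
    print(out_fresh.requires_grad)  # True
```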
Bug description
I have set automatic_optimization to False and am using self.manual_backward to compute and populate the gradients. The code breaks during the self.manual_backward call, raising "RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn". I have posted the code below for replicating the issue.
The issue does not arise when I set args['use_minibatch_clip_loss'] = False, or when I set args['batch_size'] = args['minibatch_size'] = 16. I suspect the issue only arises when I try to call backward after running the model under torch.no_grad().
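The pattern can be sketched as follows (illustrative only; the module and names are made up and this is not the original script). With precision=16, the training step runs inside Lightning's autocast context, so the no_grad pass and the grad-enabled pass share it:

```python
# Illustrative sketch (hypothetical names): manual optimization with a no_grad
# mini-batch pass followed by manual_backward, run with precision="16-mixed".
import torch
from torch import nn
import pytorch_lightning as pl


class MiniBatchModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False
        self.net = nn.Linear(512, 16)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        opt.zero_grad()
        mini_batches = torch.chunk(batch[0], 4)

        # First pass: embeddings without gradients
        with torch.no_grad():
            _ = [self.net(mb).mean() for mb in mini_batches]

        # Second pass: grad-enabled forward + manual backward per mini-batch
        for mb in mini_batches:
            loss = self.net(mb).mean()
            self.manual_backward(loss)  # RuntimeError: element 0 of tensors does not require grad
        opt.step()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)
```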
What version are you seeing the problem on?
v2.2
How to reproduce the bug
Error messages and logs
Environment
- GPU:
- NVIDIA A100 80GB PCIe
- NVIDIA A100 80GB PCIe
- NVIDIA A100 80GB PCIe
- NVIDIA A100 80GB PCIe
- available: True
- version: 12.1
- lightning: 2.2.3
- lightning-utilities: 0.11.2
- pytorch-lightning: 2.2.3
- torch: 2.3.0
- torchmetrics: 1.3.2
- absl-py: 2.1.0
- aiohttp: 3.9.5
- aiosignal: 1.3.1
- attrs: 23.2.0
- filelock: 3.13.4
- frozenlist: 1.4.1
- fsspec: 2024.3.1
- grpcio: 1.62.2
- idna: 3.7
- jinja2: 3.1.3
- lightning: 2.2.3
- lightning-utilities: 0.11.2
- markdown: 3.6
- markupsafe: 2.1.5
- mpmath: 1.3.0
- multidict: 6.0.5
- networkx: 3.3
- numpy: 1.26.4
- nvidia-cublas-cu12: 12.1.3.1
- nvidia-cuda-cupti-cu12: 12.1.105
- nvidia-cuda-nvrtc-cu12: 12.1.105
- nvidia-cuda-runtime-cu12: 12.1.105
- nvidia-cudnn-cu12: 8.9.2.26
- nvidia-cufft-cu12: 11.0.2.54
- nvidia-curand-cu12: 10.3.2.106
- nvidia-cusolver-cu12: 11.4.5.107
- nvidia-cusparse-cu12: 12.1.0.106
- nvidia-nccl-cu12: 2.20.5
- nvidia-nvjitlink-cu12: 12.4.127
- nvidia-nvtx-cu12: 12.1.105
- packaging: 24.0
- pip: 23.3.1
- protobuf: 5.26.1
- pytorch-lightning: 2.2.3
- pyyaml: 6.0.1
- setuptools: 68.2.2
- six: 1.16.0
- sympy: 1.12
- tensorboard: 2.16.2
- tensorboard-data-server: 0.7.2
- torch: 2.3.0
- torchmetrics: 1.3.2
- tqdm: 4.66.2
- triton: 2.3.0
- typing-extensions: 4.11.0
- werkzeug: 3.0.2
- wheel: 0.41.2
- yarl: 1.9.4
- OS: Linux
- architecture:
- 64bit
- ELF
- processor: x86_64
- python: 3.11.9
- release: 5.15.0-69-generic
- version: #76~20.04.1-Ubuntu SMP Mon Mar 20 15:54:19 UTC 2023
More info
I am training a Vision-Language model with a CLIP loss. The batch size I want to use is large, which requires computing the embeddings in mini-batches and then computing the gradients in mini-batches, as done in the repo https://github.com/Zasder3/train-CLIP/tree/main (see https://github.com/Zasder3/train-CLIP/blob/79d4c7960072047a9e0d39335ab60dcb150640c3/models/wrapper.py#L64-L109).
The issue arose when I implemented a similar algorithm for my use case and tried to train it. I have tried to isolate the problem as much as I could and produce a simple script that reproduces the same error.
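For context, the mini-batch gradient pattern referenced above looks roughly like this (a hedged sketch with made-up names, not the code from the linked repository): full-batch embeddings are first computed under torch.no_grad(), then each mini-batch is re-encoded with gradients and spliced into the cached embeddings so the full-batch CLIP loss can be backpropagated one chunk at a time.

```python
# Hedged sketch of the mini-batch CLIP-gradient pattern (illustrative names,
# temperature/scaling omitted); not the code from the linked repository.
import torch
import torch.nn.functional as F


def minibatch_clip_backward(image_encoder, text_encoder, images, texts, chunk_size):
    img_chunks = torch.split(images, chunk_size)
    txt_chunks = torch.split(texts, chunk_size)

    # 1) Cache full-batch embeddings without building an autograd graph.
    with torch.no_grad():
        img_emb = [F.normalize(image_encoder(c), dim=-1) for c in img_chunks]
        txt_emb = [F.normalize(text_encoder(c), dim=-1) for c in txt_chunks]

    target = torch.arange(images.size(0), device=images.device)

    # 2) Re-encode one chunk at a time with gradients, splice it into the cached
    #    embeddings, and backprop the full-batch contrastive loss. Gradients
    #    accumulate across chunks without ever holding the full-batch graph.
    for i in range(len(img_chunks)):
        img_live = [F.normalize(image_encoder(c), dim=-1) if j == i else e
                    for j, (c, e) in enumerate(zip(img_chunks, img_emb))]
        txt_live = [F.normalize(text_encoder(c), dim=-1) if j == i else e
                    for j, (c, e) in enumerate(zip(txt_chunks, txt_emb))]
        logits = torch.cat(img_live) @ torch.cat(txt_live).t()
        loss = 0.5 * (F.cross_entropy(logits, target) + F.cross_entropy(logits.t(), target))
        loss.backward()
```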
cc @carmocca @justusschock @awaelchli