StepLR doesn't work as expected after loading from checkpoint using Trainer.fit(ckpt_path=...) #17296
I can get the lr updates with both of these calls instead:

```python
def on_train_epoch_end(self):
    self.log("lr", self.lr_schedulers().get_last_lr()[0], prog_bar=True, sync_dist=True)
    print(f"lr => {self.lr_schedulers().state_dict()['_last_lr']}")
```

I have checked that the optimizer hook uses the correct learning rate (i.e. it continues to step the lr) by looking at that step and checking the lr there. The trainer keeps track of its own optimizer configs, so the problem seems to be in the method you call to read the lr.

Edit:

```python
def on_train_epoch_end(self):
    print(f"lr => {self.optimizers(use_pl_optimizer=False).param_groups[0]['lr']}")
```

But I still get different losses if I run the first training for 10 epochs versus running 5 epochs, fitting from the checkpoint, and going another 5 epochs, so something may be affecting the RNG differently.

Further Testing Script
```python
import os

import lightning.pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint
# import pytorch_lightning as pl
# from pytorch_lightning.callbacks import ModelCheckpoint
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.optim.lr_scheduler import StepLR


class DemoModel(pl.LightningModule):
    def __init__(self, hidden_dim=64, learning_rate=2e-4):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.learning_rate = learning_rate
        self.fc1 = nn.Linear(28 * 28, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.softmax(self.fc2(x), dim=1)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log('train_loss', loss, prog_bar=True)
        return loss

    def on_train_epoch_end(self):
        print(f"lr => {self.optimizers(use_pl_optimizer=False).param_groups[0]['lr']}")

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log('val_loss', loss)

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log('test_loss', loss)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        lr_scheduler = StepLR(optimizer, step_size=1, gamma=0.9)
        return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}

    def prepare_data(self):
        # Download dataset
        MNIST(root='data/', train=True, download=True)
        MNIST(root='data/', train=False, download=True)

    def setup(self, stage=None):
        # Assign train/val datasets for use in dataloaders
        if stage == 'fit' or stage is None:
            mnist_full = MNIST(root='data/', train=True, transform=transforms.ToTensor())
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
        # Assign test dataset for use in dataloader(s)
        if stage == 'test' or stage is None:
            self.mnist_test = MNIST(root='data/', train=False, transform=transforms.ToTensor())

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=512, num_workers=4, shuffle=False, pin_memory=True, persistent_workers=True)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=512, num_workers=4)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=64, num_workers=4)


def test_load(fit, refit):
    # Initialize the ModelCheckpoint callback
    pl.seed_everything(100)
    torch.manual_seed(100)
    checkpoint_callback = ModelCheckpoint(
        dirpath='./checkpoints',
        save_last=True
    )
    if fit:
        model = DemoModel()
        trainer = pl.Trainer(devices=1,
                             accelerator="cuda",
                             strategy="auto",
                             max_epochs=fit,
                             callbacks=[checkpoint_callback],
                             log_every_n_steps=1,
                             enable_model_summary=False)
        trainer.fit(model)
    if refit:
        print("## loading checkpoint ##")
        model = DemoModel()
        trainer = pl.Trainer(devices=1,
                             accelerator="cuda",
                             strategy="auto",
                             max_epochs=10 if refit == "ckpt_path" else 5,
                             callbacks=[checkpoint_callback],
                             log_every_n_steps=1,
                             enable_model_summary=False)
        if refit == "from_checkpoint":
            model = DemoModel.load_from_checkpoint("./checkpoints/last.ckpt", map_location=torch.device('cuda'))
            trainer.fit(model)
        if refit == "ckpt_path":
            trainer.fit(model, ckpt_path="./checkpoints/last.ckpt")
    os.remove("./checkpoints/last.ckpt")


if __name__ == "__main__":
    torch.set_float32_matmul_precision('high')
    print("FROM CHECKPOINT")
    test_load(5, "from_checkpoint")
    print("\n\nCKPT PATH")
    test_load(5, "ckpt_path")
    print("\n\nCONTINUOUS")
    test_load(10, False)
```

Further Testing Logs
(logs omitted)
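For a quicker check than rerunning the whole MNIST script, a minimal sketch of the comparison described in this comment could read the lr from all three places side by side inside the hook. It assumes automatic optimization with a single optimizer/scheduler pair; the variable names are illustrative and not from the original comment:

```python
def on_train_epoch_end(self):
    # lr as tracked by the scheduler itself
    sched_lr = self.lr_schedulers().get_last_lr()[0]
    # lr read from the raw (unwrapped) torch optimizer
    raw_lr = self.optimizers(use_pl_optimizer=False).param_groups[0]["lr"]
    # lr read from Lightning's optimizer wrapper
    wrapped_lr = self.optimizers().param_groups[0]["lr"]
    print(f"scheduler: {sched_lr} | raw optimizer: {raw_lr} | wrapped optimizer: {wrapped_lr}")
```

If the reported bug is present, the wrapped value is the one expected to lag behind the other two after resuming with `ckpt_path`.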
@ryan597 I have a question: with manual optimization, what do these calls log and output?
@rafathasan I haven't looked specifically at that, but if you have told it you want manual optimization then it won't connect/wrap the PL optimizers, so it's going to return the plain optimizer without needing to set `use_pl_optimizer=False`. I do agree though, you should be getting the same LR regardless of which one you read.
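To illustrate the manual-optimization case mentioned here, a rough sketch of a LightningModule that steps its own optimizer and scheduler might look like the following; the class name and layer sizes are illustrative and not from this thread:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import StepLR
import lightning.pytorch as pl


class ManualOptModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # we step the optimizer ourselves
        self.layer = nn.Linear(28 * 28, 10)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()  # with manual optimization, this is the optimizer to step by hand
        x, y = batch
        loss = F.cross_entropy(self.layer(x.view(x.size(0), -1)), y)
        opt.zero_grad()
        self.manual_backward(loss)
        opt.step()
        return loss

    def on_train_epoch_end(self):
        # with manual optimization, the scheduler is also stepped by hand
        self.lr_schedulers().step()

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=2e-4)
        return {"optimizer": optimizer, "lr_scheduler": StepLR(optimizer, step_size=1, gamma=0.9)}
```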
@ryan597 I think I should clarify my question a bit further about how I'm trying to get the lr.
Hey everyone. This was fixed in #18280 and released in 2.0.7. But no worries, the scheduler and optimizer were always correctly reloaded; the only bug was that the optimizer wrapper returned by `self.optimizers()` reported a stale learning rate. I'm closing the issue because I was able to use the provided repro (thanks a ton, way to go!) to verify the fix. Cheers!
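For anyone who wants to double-check their own run, a small sketch that inspects the saved checkpoint directly could look like this; the path is a placeholder, and it assumes a checkpoint written by the testing script above with its StepLR settings:

```python
import torch

# Path is a placeholder; point it at a checkpoint written by the testing script above.
ckpt = torch.load("./checkpoints/last.ckpt", map_location="cpu")

# Lightning stores optimizer and scheduler state under these keys in the checkpoint dict.
optimizer_lr = ckpt["optimizer_states"][0]["param_groups"][0]["lr"]
scheduler_last_lr = ckpt["lr_schedulers"][0]["_last_lr"][0]

print(f"optimizer lr in checkpoint: {optimizer_lr}")
print(f"scheduler _last_lr in checkpoint: {scheduler_last_lr}")
# With StepLR(step_size=1, gamma=0.9) and base lr 2e-4 after 5 epochs, both should be
# around 2e-4 * 0.9**5 (the exact exponent depends on when the checkpoint was written
# relative to the scheduler step).
```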
Bug description
The title speaks for itself: when Trainer.fit(ckpt_path=...) is called with a checkpoint, StepLR breaks and the lr is no longer changed by the scheduler. I have provided highly reproducible code, so no detailed explanation should be required.
What version are you seeing the problem on?
2.0+
How to reproduce the bug
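A condensed sketch of the resume pattern being reported (reusing DemoModel, ModelCheckpoint, and pl from the testing script above; paths and epoch counts are illustrative):

```python
# First run: train for 5 epochs and keep the last checkpoint.
model = DemoModel()
trainer = pl.Trainer(max_epochs=5, callbacks=[ModelCheckpoint(dirpath="./checkpoints", save_last=True)])
trainer.fit(model)

# Second run: resume from the checkpoint via ckpt_path and continue to epoch 10.
model = DemoModel()
trainer = pl.Trainer(max_epochs=10)
trainer.fit(model, ckpt_path="./checkpoints/last.ckpt")
# Reported behavior: after resuming, the logged lr appears stuck instead of
# continuing to decay with the StepLR schedule.
```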
Error messages and logs
Environment
Current environment
- GPU:
- Tesla K80
- Tesla K80
- Tesla K80
- Tesla K80
- Tesla K80
- Tesla K80
- Tesla K80
- Tesla K80
- available: True
- version: 11.7
- lightning-utilities: 0.8.0
- pytorch-lightning: 2.0.0
- torch: 2.0.0
- torchdata: 0.6.0
- torchelastic: 0.2.2
- torchmetrics: 0.11.4
- torchsummary: 1.5.1
- torchtext: 0.15.1
- torchvision: 0.15.1
- aiohttp: 3.8.4
- aiosignal: 1.3.1
- antlr4-python3-runtime: 4.9.3
- anyio: 3.6.2
- appdirs: 1.4.4
- argon2-cffi: 21.3.0
- argon2-cffi-bindings: 21.2.0
- arrow: 1.2.3
- asttokens: 2.0.5
- astunparse: 1.6.3
- async-timeout: 4.0.2
- attrs: 22.1.0
- autopep8: 2.0.2
- backcall: 0.2.0
- beautifulsoup4: 4.11.1
- bleach: 6.0.0
- brotlipy: 0.7.0
- certifi: 2022.9.24
- cffi: 1.15.1
- chardet: 4.0.0
- charset-normalizer: 2.0.4
- click: 8.1.3
- cmake: 3.26.1
- comm: 0.1.3
- conda: 22.11.1
- conda-build: 3.23.3
- conda-package-handling: 1.9.0
- contourpy: 1.0.7
- cryptography: 38.0.1
- cycler: 0.11.0
- debugpy: 1.6.6
- decorator: 5.1.1
- defusedxml: 0.7.1
- dnspython: 2.2.1
- docker-pycreds: 0.4.0
- docopt: 0.6.2
- exceptiongroup: 1.0.4
- executing: 0.8.3
- expecttest: 0.1.4
- fastjsonschema: 2.16.3
- filelock: 3.6.0
- flit-core: 3.6.0
- fonttools: 4.39.2
- fqdn: 1.5.1
- frozenlist: 1.3.3
- fsspec: 2023.3.0
- future: 0.18.2
- gdown: 4.7.1
- gitdb: 4.0.10
- gitpython: 3.1.31
- glob2: 0.7
- hypothesis: 6.61.0
- idna: 3.4
- ipykernel: 6.22.0
- ipython: 8.11.0
- ipython-genutils: 0.2.0
- ipywidgets: 8.0.5
- isoduration: 20.11.0
- jedi: 0.18.1
- jinja2: 3.1.2
- joblib: 1.2.0
- jsonpointer: 2.3
- jsonschema: 4.17.3
- jupyter: 1.0.0
- jupyter-client: 8.1.0
- jupyter-console: 6.6.3
- jupyter-core: 5.3.0
- jupyter-events: 0.6.3
- jupyter-server: 2.5.0
- jupyter-server-terminals: 0.4.4
- jupyterlab-pygments: 0.2.2
- jupyterlab-widgets: 3.0.6
- kiwisolver: 1.4.4
- libarchive-c: 2.9
- lightning-utilities: 0.8.0
- lit: 16.0.0
- markupsafe: 2.0.1
- matplotlib: 3.7.1
- matplotlib-inline: 0.1.6
- mistune: 2.0.5
- mkl-fft: 1.3.1
- mkl-random: 1.2.2
- mkl-service: 2.4.0
- mpmath: 1.2.1
- multidict: 6.0.4
- nbclassic: 0.5.3
- nbclient: 0.7.2
- nbconvert: 7.2.10
- nbformat: 5.8.0
- nest-asyncio: 1.5.6
- networkx: 3.0
- notebook: 6.5.3
- notebook-shim: 0.2.2
- numpy: 1.24.2
- nvidia-cublas-cu11: 11.10.3.66
- nvidia-cuda-cupti-cu11: 11.7.101
- nvidia-cuda-nvrtc-cu11: 11.7.99
- nvidia-cuda-runtime-cu11: 11.7.99
- nvidia-cudnn-cu11: 8.5.0.96
- nvidia-cufft-cu11: 10.9.0.58
- nvidia-curand-cu11: 10.2.10.91
- nvidia-cusolver-cu11: 11.4.0.1
- nvidia-cusparse-cu11: 11.7.4.91
- nvidia-nccl-cu11: 2.14.3
- nvidia-nvtx-cu11: 11.7.91
- omegaconf: 2.3.0
- onedrivedownloader: 1.1.3
- packaging: 23.0
- pandocfilters: 1.5.0
- parso: 0.8.3
- pathtools: 0.1.2
- pexpect: 4.8.0
- pickleshare: 0.7.5
- pillow: 9.4.0
- pip: 22.3.1
- pipreqs: 0.4.11
- pkginfo: 1.8.3
- platformdirs: 3.2.0
- pluggy: 1.0.0
- prometheus-client: 0.16.0
- prompt-toolkit: 3.0.38
- protobuf: 4.22.1
- psutil: 5.9.0
- ptyprocess: 0.7.0
- pure-eval: 0.2.2
- pycodestyle: 2.10.0
- pycosat: 0.6.4
- pycparser: 2.21
- pygments: 2.11.2
- pyopenssl: 22.0.0
- pyparsing: 3.0.9
- pyrsistent: 0.19.3
- pysocks: 1.7.1
- python-dateutil: 2.8.2
- python-etcd: 0.4.5
- python-json-logger: 2.0.7
- pytorch-lightning: 2.0.0
- pytz: 2022.1
- pyyaml: 6.0
- pyzmq: 25.0.2
- qtconsole: 5.4.1
- qtpy: 2.3.0
- requests: 2.28.1
- rfc3339-validator: 0.1.4
- rfc3986-validator: 0.1.1
- ruamel.yaml: 0.17.21
- ruamel.yaml.clib: 0.2.6
- scikit-learn: 1.2.2
- scipy: 1.10.1
- send2trash: 1.8.0
- sentry-sdk: 1.17.0
- setproctitle: 1.3.2
- setuptools: 65.5.0
- six: 1.16.0
- smmap: 5.0.0
- sniffio: 1.3.0
- sortedcontainers: 2.4.0
- soupsieve: 2.3.2.post1
- stack-data: 0.2.0
- sympy: 1.11.1
- terminado: 0.17.1
- thop: 0.1.1.post2209072238
- threadpoolctl: 3.1.0
- tinycss2: 1.2.1
- toml: 0.10.2
- tomli: 2.0.1
- toolz: 0.12.0
- torch: 2.0.0
- torchdata: 0.6.0
- torchelastic: 0.2.2
- torchmetrics: 0.11.4
- torchsummary: 1.5.1
- torchtext: 0.15.1
- torchvision: 0.15.1
- tornado: 6.2
- tqdm: 4.65.0
- traitlets: 5.7.1
- triton: 2.0.0
- types-dataclasses: 0.6.6
- typing-extensions: 4.4.0
- uri-template: 1.2.0
- urllib3: 1.26.13
- wandb: 0.14.0
- wcwidth: 0.2.5
- webcolors: 1.12
- webencodings: 0.5.1
- websocket-client: 1.5.1
- wheel: 0.37.1
- widgetsnbextension: 4.0.6
- yarg: 0.1.9
- yarl: 1.8.2
- OS: Linux
- architecture:
- 64bit
-
- processor: x86_64
- python: 3.10.8
- version: #76~20.04.1-Ubuntu SMP Mon Mar 20 15:54:19 UTC 2023
More info
No response
cc @awaelchli