StepLR doesn't work as expected after loading from checkpoint using Trainer.fit(ckpt_path=...) #17296

Closed
rafathasan opened this issue Apr 6, 2023 · 5 comments
Labels: bug, checkpointing, ver: 2.0.x

rafathasan commented Apr 6, 2023

Bug description

The title speaks for itself. When trainer.fit(ckpt_path=...) is called with a checkpoint, StepLR stops working: the lr is no longer changed by the lr scheduler. I have provided easily reproducible code, so no further explanation should be needed.

What version are you seeing the problem on?

2.0+

How to reproduce the bug
#!/opt/conda/bin/python
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.optim.lr_scheduler import StepLR

class DemoModel(pl.LightningModule):
    def __init__(self, hidden_dim=64, learning_rate=2e-4):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.learning_rate = learning_rate

        self.fc1 = nn.Linear(28 * 28, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.softmax(self.fc2(x), dim=1)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log('train_loss', loss)
        return loss

    def on_train_epoch_end(self):
        self.log("lr", self.optimizers().param_groups[0]['lr'],  prog_bar=True, sync_dist=True)
        print(f"lr => {self.optimizers().param_groups[0]['lr']}")

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log('val_loss', loss)

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log('test_loss', loss)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        lr_scheduler = StepLR(optimizer, step_size=1, gamma=0.9)

        return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}

    def prepare_data(self):
        # Download dataset
        MNIST(root='data/', train=True, download=True)
        MNIST(root='data/', train=False, download=True)

    def setup(self, stage=None):
        # Assign train/val datasets for use in dataloaders
        if stage == 'fit' or stage is None:
            mnist_full = MNIST(root='data/', train=True, transform=transforms.ToTensor())
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == 'test' or stage is None:
            self.mnist_test = MNIST(root='data/', train=False, transform=transforms.ToTensor())

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=512, num_workers=4)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=512, num_workers=4)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=64, num_workers=4)


# Initialize the model
model = DemoModel()

# Initialize the ModelCheckpoint callback
checkpoint_callback = ModelCheckpoint(
    dirpath='./checkpoints',
    save_last=True,
)

# Initialize the trainer
trainer = pl.Trainer(
    devices="0,1,2,3",
    accelerator="cuda",
    strategy="ddp",
    max_epochs=5,
    callbacks=[checkpoint_callback],
    log_every_n_steps=1,
)

# Train the model with ModelCheckpoint callback
trainer.fit(model)

print("################################## loading checkpoint #############################################")

# Initialize the model
model = DemoModel()

# Initialize the trainer
trainer = pl.Trainer(devices="0,1,2,3", accelerator="cuda", strategy="ddp", max_epochs=10, log_every_n_steps=1)

# Resume training from the saved checkpoint
trainer.fit(model, ckpt_path="./checkpoints/last.ckpt")
Error messages and logs
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:67: UserWarning: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
  warning_cache.warn(
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/4
Initializing distributed: GLOBAL_RANK: 3, MEMBER: 4/4
Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/4
Initializing distributed: GLOBAL_RANK: 2, MEMBER: 3/4
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 4 processes
----------------------------------------------------------------------------------------------------

LOCAL_RANK: 1 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
LOCAL_RANK: 3 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
LOCAL_RANK: 2 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]

  | Name | Type   | Params
--------------------------------
0 | fc1  | Linear | 50.2 K
1 | fc2  | Linear | 650   
--------------------------------
50.9 K    Trainable params
0         Non-trainable params
50.9 K    Total params
0.204     Total estimated model params size (MB)
Sanity Checking DataLoader 0: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  4.86it/s]/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:432: PossibleUserWarning: It is recommended to use `self.log('val_loss', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
  warning_cache.warn(
Epoch 0: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 27/27 [00:01<00:00, 21.20it/s, v_num=1]lr => 0.00018                                                                                                                                                                                                       
lr => 0.00018
lr => 0.00018
lr => 0.00018
Epoch 1: 100%|████████████████lr => 0.000162███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 27/27 [00:00<00:00, 27.05it/s, v_num=1, lr=0.00018]
lr => 0.000162t [00:00, ?it/s]
Epoch 1: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 27/27 [00:01<00:00, 20.63it/s, v_num=1, lr=0.00018]lr => 0.000162                                                                                                                                                                                                      
lr => 0.000162
Epoch 2: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 27/27 [00:01<00:00, 26.38it/s, v_num=1, lr=0.000162lr => 0.000145800000000000020%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 328.04it/s]
Epoch 2: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 27/27 [00:01<00:00, 20.76it/s, v_num=1, lr=0.000162]lr => 0.00014580000000000002                                                                                                                                                                                        
lr => 0.00014580000000000002
lr => 0.00014580000000000002
Epoch 3: 100%|████████████████lr => 0.00013122000000000003████████████████████████████████████████████████████████████████████████████████████████████████████| 27/27 [00:01<00:00, 26.05it/s, v_num=1, lr=0.000146]
Validation: 0it [00:00, ?it/s]
Epoch 3: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 27/27 [00:01<00:00, 19.85it/s, v_num=1, lr=0.000146]lr => 0.00013122000000000003                                                                                                                                                                                        
lr => 0.00013122000000000003
Epoch 4: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 27/27 [00:01<00:00, 26.99it/s, v_num=1, lr=0.000131lr => 0.000118098000000000033%|██████████████████████████████████████████████████▎                                                                                                    | 1/3 [00:00<00:00, 279.14it/s]
lr => 0.00011809800000000003

Epoch 4: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 27/27 [00:01<00:00, 20.67it/s, v_num=1, lr=0.000131]lr => 0.00011809800000000003                                                                                                                                                                                        
`Trainer.fit` stopped: `max_epochs=5` reached.
Epoch 4: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 27/27 [00:01<00:00, 20.42it/s, v_num=1, lr=0.000131]
################################## loading checkpoint #############################################
################################## loading checkpoint #############################################
################################## loading checkpoint #############################################
################################## loading checkpoint #############################################
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Restoring states from the checkpoint path at ./checkpoints/last.ckpt
/opt/conda/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:337: UserWarning: The dirpath has changed from '/src/notebooks/checkpoints' to '/src/notebooks/lightning_logs/version_2/checkpoints', therefore `best_model_score`, `kth_best_model_path`, `kth_value`, `last_model_path` and `best_k_models` won't be reloaded. Only `best_model_path` will be reloaded.
  warnings.warn(
LOCAL_RANK: 1 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
/opt/conda/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:337: UserWarning: The dirpath has changed from '/src/notebooks/checkpoints' to '/src/notebooks/lightning_logs/version_2/checkpoints', therefore `best_model_score`, `kth_best_model_path`, `kth_value`, `last_model_path` and `best_k_models` won't be reloaded. Only `best_model_path` will be reloaded.
  warnings.warn(
LOCAL_RANK: 3 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
/opt/conda/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:337: UserWarning: The dirpath has changed from '/src/notebooks/checkpoints' to '/src/notebooks/lightning_logs/version_2/checkpoints', therefore `best_model_score`, `kth_best_model_path`, `kth_value`, `last_model_path` and `best_k_models` won't be reloaded. Only `best_model_path` will be reloaded.
  warnings.warn(
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
/opt/conda/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:337: UserWarning: The dirpath has changed from '/src/notebooks/checkpoints' to '/src/notebooks/lightning_logs/version_2/checkpoints', therefore `best_model_score`, `kth_best_model_path`, `kth_value`, `last_model_path` and `best_k_models` won't be reloaded. Only `best_model_path` will be reloaded.
  warnings.warn(
LOCAL_RANK: 2 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]

  | Name | Type   | Params
--------------------------------
0 | fc1  | Linear | 50.2 K
1 | fc2  | Linear | 650   
--------------------------------
50.9 K    Trainable params
0         Non-trainable params
50.9 K    Total params
0.204     Total estimated model params size (MB)
Restored all states from the checkpoint at ./checkpoints/last.ckpt
Epoch 5: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 27/27 [00:01<00:00, 21.26it/s, v_num=2]lr => 0.0002                                                                                                                                                                                                        
lr => 0.0002
lr => 0.0002
lr => 0.0002
Epoch 6: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 27/27 [00:01<00:00, 21.34it/s, v_num=2, lr=0.0002]lr => 0.0002                                                                                                                                                                                                        
lr => 0.0002
lr => 0.0002
lr => 0.0002
Epoch 7: 100%|████████████████lr => 0.0002██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 27/27 [00:01<00:00, 26.33it/s, v_num=2, lr=0.0002]
Epoch 7: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 27/27 [00:01<00:00, 20.50it/s, v_num=2, lr=0.0002]lr => 0.0002                                                                                                                                                                                                        
lr => 0.0002
lr => 0.0002
Epoch 8: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 27/27 [00:01<00:00, 21.41it/s, v_num=2, lr=0.0002]lr => 0.0002                                                                                                                                                                                                        
lr => 0.0002
lr => 0.0002
lr => 0.0002
Epoch 9: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 27/27 [00:01<00:00, 26.95it/s, v_num=2, lr=0.0002]
lr => 0.0002  0%|                                                                                                                                                                             | 0/3 [00:00<?, ?it/slr => 0.0002DataLoader 0: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 124.44it/s]
Epoch 9: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 27/27 [00:01<00:00, 20.53it/s, v_num=2, lr=0.0002]lr => 0.0002                                                                                                                                                                                                        
lr => 0.0002
`Trainer.fit` stopped: `max_epochs=10` reached.
Epoch 9: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 27/27 [00:01<00:00, 20.23it/s, v_num=2, lr=0.0002]
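
For reference, StepLR with step_size=1 and gamma=0.9 multiplies the learning rate by 0.9 after every epoch, so the value expected after n completed epochs is 2e-4 * 0.9**n. The sketch below only reproduces that arithmetic in plain Python. The helper name expected_step_lr is made up for illustration and is not part of PyTorch or Lightning. Its output can be compared against the `lr =>` lines in the logs above.

```python
def expected_step_lr(base_lr: float, gamma: float, step_size: int, epochs_done: int) -> float:
    """Closed-form StepLR value after `epochs_done` scheduler steps (hypothetical helper)."""
    return base_lr * gamma ** (epochs_done // step_size)


# Values the resumed run (epochs 5 to 9) would be expected to print,
# assuming the scheduler state from the first 5 epochs is carried over.
resumed = [expected_step_lr(2e-4, 0.9, 1, n) for n in range(6, 11)]
for n, lr in zip(range(6, 11), resumed):
    print(f"after {n} epochs: lr ~ {lr:.6e}")
```

In the logs above, the resumed run prints 0.0002 at every epoch instead, which is the value for n = 0.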

Environment

Current environment
  • CUDA:
    - GPU:
    - Tesla K80
    - Tesla K80
    - Tesla K80
    - Tesla K80
    - Tesla K80
    - Tesla K80
    - Tesla K80
    - Tesla K80
    - available: True
    - version: 11.7
  • Lightning:
    - lightning-utilities: 0.8.0
    - pytorch-lightning: 2.0.0
    - torch: 2.0.0
    - torchdata: 0.6.0
    - torchelastic: 0.2.2
    - torchmetrics: 0.11.4
    - torchsummary: 1.5.1
    - torchtext: 0.15.1
    - torchvision: 0.15.1
  • Packages:
    - aiohttp: 3.8.4
    - aiosignal: 1.3.1
    - antlr4-python3-runtime: 4.9.3
    - anyio: 3.6.2
    - appdirs: 1.4.4
    - argon2-cffi: 21.3.0
    - argon2-cffi-bindings: 21.2.0
    - arrow: 1.2.3
    - asttokens: 2.0.5
    - astunparse: 1.6.3
    - async-timeout: 4.0.2
    - attrs: 22.1.0
    - autopep8: 2.0.2
    - backcall: 0.2.0
    - beautifulsoup4: 4.11.1
    - bleach: 6.0.0
    - brotlipy: 0.7.0
    - certifi: 2022.9.24
    - cffi: 1.15.1
    - chardet: 4.0.0
    - charset-normalizer: 2.0.4
    - click: 8.1.3
    - cmake: 3.26.1
    - comm: 0.1.3
    - conda: 22.11.1
    - conda-build: 3.23.3
    - conda-package-handling: 1.9.0
    - contourpy: 1.0.7
    - cryptography: 38.0.1
    - cycler: 0.11.0
    - debugpy: 1.6.6
    - decorator: 5.1.1
    - defusedxml: 0.7.1
    - dnspython: 2.2.1
    - docker-pycreds: 0.4.0
    - docopt: 0.6.2
    - exceptiongroup: 1.0.4
    - executing: 0.8.3
    - expecttest: 0.1.4
    - fastjsonschema: 2.16.3
    - filelock: 3.6.0
    - flit-core: 3.6.0
    - fonttools: 4.39.2
    - fqdn: 1.5.1
    - frozenlist: 1.3.3
    - fsspec: 2023.3.0
    - future: 0.18.2
    - gdown: 4.7.1
    - gitdb: 4.0.10
    - gitpython: 3.1.31
    - glob2: 0.7
    - hypothesis: 6.61.0
    - idna: 3.4
    - ipykernel: 6.22.0
    - ipython: 8.11.0
    - ipython-genutils: 0.2.0
    - ipywidgets: 8.0.5
    - isoduration: 20.11.0
    - jedi: 0.18.1
    - jinja2: 3.1.2
    - joblib: 1.2.0
    - jsonpointer: 2.3
    - jsonschema: 4.17.3
    - jupyter: 1.0.0
    - jupyter-client: 8.1.0
    - jupyter-console: 6.6.3
    - jupyter-core: 5.3.0
    - jupyter-events: 0.6.3
    - jupyter-server: 2.5.0
    - jupyter-server-terminals: 0.4.4
    - jupyterlab-pygments: 0.2.2
    - jupyterlab-widgets: 3.0.6
    - kiwisolver: 1.4.4
    - libarchive-c: 2.9
    - lightning-utilities: 0.8.0
    - lit: 16.0.0
    - markupsafe: 2.0.1
    - matplotlib: 3.7.1
    - matplotlib-inline: 0.1.6
    - mistune: 2.0.5
    - mkl-fft: 1.3.1
    - mkl-random: 1.2.2
    - mkl-service: 2.4.0
    - mpmath: 1.2.1
    - multidict: 6.0.4
    - nbclassic: 0.5.3
    - nbclient: 0.7.2
    - nbconvert: 7.2.10
    - nbformat: 5.8.0
    - nest-asyncio: 1.5.6
    - networkx: 3.0
    - notebook: 6.5.3
    - notebook-shim: 0.2.2
    - numpy: 1.24.2
    - nvidia-cublas-cu11: 11.10.3.66
    - nvidia-cuda-cupti-cu11: 11.7.101
    - nvidia-cuda-nvrtc-cu11: 11.7.99
    - nvidia-cuda-runtime-cu11: 11.7.99
    - nvidia-cudnn-cu11: 8.5.0.96
    - nvidia-cufft-cu11: 10.9.0.58
    - nvidia-curand-cu11: 10.2.10.91
    - nvidia-cusolver-cu11: 11.4.0.1
    - nvidia-cusparse-cu11: 11.7.4.91
    - nvidia-nccl-cu11: 2.14.3
    - nvidia-nvtx-cu11: 11.7.91
    - omegaconf: 2.3.0
    - onedrivedownloader: 1.1.3
    - packaging: 23.0
    - pandocfilters: 1.5.0
    - parso: 0.8.3
    - pathtools: 0.1.2
    - pexpect: 4.8.0
    - pickleshare: 0.7.5
    - pillow: 9.4.0
    - pip: 22.3.1
    - pipreqs: 0.4.11
    - pkginfo: 1.8.3
    - platformdirs: 3.2.0
    - pluggy: 1.0.0
    - prometheus-client: 0.16.0
    - prompt-toolkit: 3.0.38
    - protobuf: 4.22.1
    - psutil: 5.9.0
    - ptyprocess: 0.7.0
    - pure-eval: 0.2.2
    - pycodestyle: 2.10.0
    - pycosat: 0.6.4
    - pycparser: 2.21
    - pygments: 2.11.2
    - pyopenssl: 22.0.0
    - pyparsing: 3.0.9
    - pyrsistent: 0.19.3
    - pysocks: 1.7.1
    - python-dateutil: 2.8.2
    - python-etcd: 0.4.5
    - python-json-logger: 2.0.7
    - pytorch-lightning: 2.0.0
    - pytz: 2022.1
    - pyyaml: 6.0
    - pyzmq: 25.0.2
    - qtconsole: 5.4.1
    - qtpy: 2.3.0
    - requests: 2.28.1
    - rfc3339-validator: 0.1.4
    - rfc3986-validator: 0.1.1
    - ruamel.yaml: 0.17.21
    - ruamel.yaml.clib: 0.2.6
    - scikit-learn: 1.2.2
    - scipy: 1.10.1
    - send2trash: 1.8.0
    - sentry-sdk: 1.17.0
    - setproctitle: 1.3.2
    - setuptools: 65.5.0
    - six: 1.16.0
    - smmap: 5.0.0
    - sniffio: 1.3.0
    - sortedcontainers: 2.4.0
    - soupsieve: 2.3.2.post1
    - stack-data: 0.2.0
    - sympy: 1.11.1
    - terminado: 0.17.1
    - thop: 0.1.1.post2209072238
    - threadpoolctl: 3.1.0
    - tinycss2: 1.2.1
    - toml: 0.10.2
    - tomli: 2.0.1
    - toolz: 0.12.0
    - torch: 2.0.0
    - torchdata: 0.6.0
    - torchelastic: 0.2.2
    - torchmetrics: 0.11.4
    - torchsummary: 1.5.1
    - torchtext: 0.15.1
    - torchvision: 0.15.1
    - tornado: 6.2
    - tqdm: 4.65.0
    - traitlets: 5.7.1
    - triton: 2.0.0
    - types-dataclasses: 0.6.6
    - typing-extensions: 4.4.0
    - uri-template: 1.2.0
    - urllib3: 1.26.13
    - wandb: 0.14.0
    - wcwidth: 0.2.5
    - webcolors: 1.12
    - webencodings: 0.5.1
    - websocket-client: 1.5.1
    - wheel: 0.37.1
    - widgetsnbextension: 4.0.6
    - yarg: 0.1.9
    - yarl: 1.8.2
  • System:
    - OS: Linux
    - architecture:
    - 64bit
    -
    - processor: x86_64
    - python: 3.10.8
    - version: #76~20.04.1-Ubuntu SMP Mon Mar 20 15:54:19 UTC 2023

More info

No response

cc @awaelchli

rafathasan added the bug and needs triage labels Apr 6, 2023
rafathasan changed the title from "StepLR doesn't work as expected after loading from checkpoint using Trainer(ckpt_path=...)" to "StepLR doesn't work as expected after loading from checkpoint using Trainer.fit(ckpt_path=...)" Apr 6, 2023
ryan597 (Contributor) commented Apr 8, 2023

I do see the lr update correctly if I read it with either of these calls instead:

    def on_train_epoch_end(self):
        self.log("lr", self.lr_schedulers().get_last_lr()[0],  prog_bar=True, sync_dist=True)
        print(f"lr => {self.lr_schedulers().state_dict()['_last_lr']}")
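
A minimal plain-PyTorch sketch (no Lightning involved, with a single dummy parameter standing in for the model) suggests why these two calls are reasonable places to read the value. Without any wrapper in between, the optimizer's param_groups, scheduler.get_last_lr() and the scheduler's state_dict()['_last_lr'] all report the same learning rate after each step.

```python
import torch
from torch.optim.lr_scheduler import StepLR

# Dummy parameter standing in for the model (assumption for illustration)
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=2e-4)
scheduler = StepLR(optimizer, step_size=1, gamma=0.9)

readings = []
for epoch in range(3):
    param.grad = torch.zeros_like(param)  # fake gradient so step() has work to do
    optimizer.step()
    scheduler.step()  # one scheduler step per epoch
    readings.append((
        optimizer.param_groups[0]["lr"],        # read from the raw optimizer
        scheduler.get_last_lr()[0],             # read from the scheduler
        scheduler.state_dict()["_last_lr"][0],  # read from the scheduler state
    ))

for from_optimizer, from_get_last_lr, from_state_dict in readings:
    print(from_optimizer, from_get_last_lr, from_state_dict)
```

If the three columns agree here but the value read through self.optimizers() does not, the discrepancy is more likely in how the optimizer is accessed than in the scheduler itself.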

I have checked that the optimizer hook uses the correct learning rate (i.e. the lr keeps being stepped) by looking at this step in lightning/pytorch/loops/optimization/automatic.py

https://github.com/Lightning-AI/lightning/blob/fd4697c62c059fc7b9946e84d91625ecb6efdbe5/src/lightning/pytorch/loops/optimization/automatic.py#L237-L271

and checking the lr with
print("OPT PARMS", optimizer.optimizer.param_groups[0]['lr'])

The trainer keeps track of its own lr scheduler configs through trainer.lr_scheduler_configs, as seen in lightning/pytorch/trainer/connectors/checkpoint_connector.py

https://github.com/Lightning-AI/lightning/blob/fd4697c62c059fc7b9946e84d91625ecb6efdbe5/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py#L383-L391

So the problem seems to be that the object returned by self.optimizers() is not updated, which means that reading self.optimizers().param_groups[0]['lr'] gives a stale value.

Edit:
I looked further into this and found that with use_pl_optimizer=False the values read from the optimizer are correct again:

def on_train_epoch_end(self):
        print(f"lr => {self.optimizers(use_pl_optimizer=False).param_groups[0]['lr']}")
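
One way to picture this observation is a wrapper that copies the optimizer's attributes when it is created. If restoring a checkpoint then rebinds param_groups on the underlying optimizer to a new list, the wrapper keeps pointing at the old one. The toy classes below are hypothetical stand-ins written for illustration, not Lightning's or PyTorch's actual code, and whether this is the real cause is an assumption that the thread does not confirm.

```python
class ToyOptimizer:
    """Hypothetical stand-in for an optimizer; not torch.optim code."""

    def __init__(self, lr: float) -> None:
        self.param_groups = [{"lr": lr}]

    def load_state_dict(self, state: dict) -> None:
        # Assumption: restoring state rebinds param_groups to a NEW list
        self.param_groups = [dict(group) for group in state["param_groups"]]


class ToyWrapper:
    """Hypothetical wrapper that snapshots the optimizer's attributes once."""

    def __init__(self, optimizer: ToyOptimizer) -> None:
        self.__dict__.update(vars(optimizer))  # copies references at wrap time
        self.optimizer = optimizer


optimizer = ToyOptimizer(lr=2e-4)  # fresh optimizer before resuming
wrapper = ToyWrapper(optimizer)    # wrapper created before the restore
optimizer.load_state_dict({"param_groups": [{"lr": 2e-4 * 0.9 ** 5}]})

for _ in range(5):  # the scheduler keeps decaying the real optimizer
    optimizer.param_groups[0]["lr"] *= 0.9

print("via wrapper:  ", wrapper.param_groups[0]["lr"])
print("via optimizer:", wrapper.optimizer.param_groups[0]["lr"])
```

In this toy the wrapper keeps reporting the initial 0.0002 while the underlying optimizer has decayed to roughly 6.97e-05, which mirrors the difference between self.optimizers() and self.optimizers(use_pl_optimizer=False) described above.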

However, I still get different losses when I train for 10 epochs in a single run than when I train for 5 epochs, resume from the checkpoint, and train for another 5, so something may be affecting the RNG state differently.

Further Testing Script

import os

import lightning.pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint
# import pytorch_lightning as pl
# from pytorch_lightning.callbacks import ModelCheckpoint
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.optim.lr_scheduler import StepLR


class DemoModel(pl.LightningModule):
    def __init__(self, hidden_dim=64, learning_rate=2e-4):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.learning_rate = learning_rate

        self.fc1 = nn.Linear(28 * 28, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.softmax(self.fc2(x), dim=1)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log('train_loss', loss, prog_bar=True)
        return loss

    def on_train_epoch_end(self):
        print(f"lr => {self.optimizers(use_pl_optimizer=False).param_groups[0]['lr']}")

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log('val_loss', loss)

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log('test_loss', loss)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        lr_scheduler = StepLR(optimizer, step_size=1, gamma=0.9)

        return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}

    def prepare_data(self):
        # Download dataset
        MNIST(root='data/', train=True, download=True)
        MNIST(root='data/', train=False, download=True)

    def setup(self, stage=None):
        # Assign train/val datasets for use in dataloaders
        if stage == 'fit' or stage is None:
            mnist_full = MNIST(root='data/', train=True, transform=transforms.ToTensor())
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == 'test' or stage is None:
            self.mnist_test = MNIST(root='data/', train=False, transform=transforms.ToTensor())

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=512, num_workers=4, shuffle=False, pin_memory=True, persistent_workers=True)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=512, num_workers=4)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=64, num_workers=4)


def test_load(fit, refit):
    # Initialize the ModelCheckpoint callback
    pl.seed_everything(100)
    torch.manual_seed(100)

    checkpoint_callback = ModelCheckpoint(
        dirpath='./checkpoints',
        save_last=True
    )

    if fit:
        model = DemoModel()
        trainer = pl.Trainer(devices=1,
                             accelerator="cuda",
                             strategy="auto",
                             max_epochs=fit,
                             callbacks=[checkpoint_callback],
                             log_every_n_steps=1,
                             enable_model_summary=False
                             )
        trainer.fit(model)

    if refit:
        print("## loading checkpoint ##")

        model = DemoModel()
        trainer = pl.Trainer(devices=1,
                             accelerator="cuda",
                             strategy="auto",
                             max_epochs=10 if refit == "ckpt_path" else 5,
                             callbacks=[checkpoint_callback],
                             log_every_n_steps=1,
                             enable_model_summary=False
                             )
        if refit == "from_checkpoint":
            model = DemoModel.load_from_checkpoint("./checkpoints/last.ckpt", map_location=torch.device('cuda'))
            trainer.fit(model)
        if refit == "ckpt_path":
            trainer.fit(model, ckpt_path="./checkpoints/last.ckpt")

    os.remove("./checkpoints/last.ckpt")


if __name__ == "__main__":
    torch.set_float32_matmul_precision('high')

    print("FROM CHECKPOINT")
    test_load(5, "from_checkpoint")
    print("\n\nCKPT PATH")
    test_load(5, "ckpt_path")
    print("\n\nCONTINUOUS")
    test_load(10, False)

Further Testing Logs

FROM CHECKPOINT
Epoch 0: 100%|████████████████████████| 108/108 [00:00<00:00, 243.36it/s, v_num=165, loss=2.050]lr => 0.00018                                                                                   
Epoch 1: 100%|████████████████████████| 108/108 [00:00<00:00, 260.28it/s, v_num=165, loss=1.850]lr => 0.000162                                                                                  
Epoch 2: 100%|████████████████████████| 108/108 [00:00<00:00, 250.01it/s, v_num=165, loss=1.740]lr => 0.00014580000000000002                                                                    
Epoch 3: 100%|████████████████████████| 108/108 [00:00<00:00, 220.72it/s, v_num=165, loss=1.670]lr => 0.00013122000000000003                                                                    
Epoch 4: 100%|████████████████████████| 108/108 [00:00<00:00, 249.70it/s, v_num=165, loss=1.650]lr => 0.00011809800000000003                                                                    
Epoch 4: 100%|████████████████████████| 108/108 [00:00<00:00, 247.65it/s, v_num=165, loss=1.650]
## loading checkpoint ##
Epoch 0: 100%|████████████████████████| 108/108 [00:00<00:00, 231.37it/s, v_num=166, loss=1.610]lr => 0.00018                                                                                   
Epoch 1: 100%|████████████████████████| 108/108 [00:00<00:00, 237.92it/s, v_num=166, loss=1.600]lr => 0.000162                                                                                  
Epoch 2: 100%|████████████████████████| 108/108 [00:00<00:00, 239.84it/s, v_num=166, loss=1.590]lr => 0.00014580000000000002                                                                    
Epoch 3: 100%|████████████████████████| 108/108 [00:00<00:00, 249.72it/s, v_num=166, loss=1.580]lr => 0.00013122000000000003                                                                    
Epoch 4: 100%|████████████████████████| 108/108 [00:00<00:00, 234.87it/s, v_num=166, loss=1.570]lr => 0.00011809800000000003                                                                    
Epoch 4: 100%|████████████████████████| 108/108 [00:00<00:00, 232.78it/s, v_num=166, loss=1.570]


CKPT PATH
Epoch 0: 100%|████████████████████████| 108/108 [00:00<00:00, 235.67it/s, v_num=167, loss=2.050]lr => 0.00018                                                                                   
Epoch 1: 100%|████████████████████████| 108/108 [00:00<00:00, 222.51it/s, v_num=167, loss=1.850]lr => 0.000162                                                                                  
Epoch 2: 100%|████████████████████████| 108/108 [00:00<00:00, 236.36it/s, v_num=167, loss=1.740]lr => 0.00014580000000000002                                                                    
Epoch 3: 100%|████████████████████████| 108/108 [00:00<00:00, 242.70it/s, v_num=167, loss=1.670]lr => 0.00013122000000000003                                                                    
Epoch 4: 100%|████████████████████████| 108/108 [00:00<00:00, 230.25it/s, v_num=167, loss=1.650]lr => 0.00011809800000000003                                                                    
Epoch 4: 100%|████████████████████████| 108/108 [00:00<00:00, 228.53it/s, v_num=167, loss=1.650]
## loading checkpoint ##
Epoch 5: 100%|████████████████████████| 108/108 [00:00<00:00, 213.97it/s, v_num=168, loss=1.590]lr => 0.00010628820000000004                                                                    
Epoch 6: 100%|████████████████████████| 108/108 [00:00<00:00, 234.14it/s, v_num=168, loss=1.580]lr => 9.565938000000004e-05                                                                     
Epoch 7: 100%|████████████████████████| 108/108 [00:00<00:00, 216.25it/s, v_num=168, loss=1.580]lr => 8.609344200000004e-05                                                                     
Epoch 8: 100%|████████████████████████| 108/108 [00:00<00:00, 245.35it/s, v_num=168, loss=1.570]lr => 7.748409780000004e-05                                                                     
Epoch 9: 100%|████████████████████████| 108/108 [00:00<00:00, 229.21it/s, v_num=168, loss=1.570]lr => 6.973568802000003e-05                                                                     
Epoch 9: 100%|████████████████████████| 108/108 [00:00<00:00, 227.43it/s, v_num=168, loss=1.570]


CONTINUOUS
Epoch 0: 100%|████████████████████████| 108/108 [00:00<00:00, 203.86it/s, v_num=169, loss=2.050]lr => 0.00018                                                                                   
Epoch 1: 100%|████████████████████████| 108/108 [00:00<00:00, 230.33it/s, v_num=169, loss=1.850]lr => 0.000162                                                                                  
Epoch 2: 100%|████████████████████████| 108/108 [00:00<00:00, 243.14it/s, v_num=169, loss=1.740]lr => 0.00014580000000000002                                                                    
Epoch 3: 100%|████████████████████████| 108/108 [00:00<00:00, 219.94it/s, v_num=169, loss=1.670]lr => 0.00013122000000000003                                                                    
Epoch 4: 100%|████████████████████████| 108/108 [00:00<00:00, 245.71it/s, v_num=169, loss=1.650]lr => 0.00011809800000000003                                                                    
Epoch 5: 100%|████████████████████████| 108/108 [00:00<00:00, 229.97it/s, v_num=169, loss=1.630]lr => 0.00010628820000000004                                                                    
Epoch 6: 100%|████████████████████████| 108/108 [00:00<00:00, 235.31it/s, v_num=169, loss=1.620]lr => 9.565938000000004e-05                                                                     
Epoch 7: 100%|████████████████████████| 108/108 [00:00<00:00, 249.66it/s, v_num=169, loss=1.610]lr => 8.609344200000004e-05                                                                     
Epoch 8: 100%|████████████████████████| 108/108 [00:00<00:00, 240.83it/s, v_num=169, loss=1.600]lr => 7.748409780000004e-05                                                                     
Epoch 9: 100%|████████████████████████| 108/108 [00:00<00:00, 245.06it/s, v_num=169, loss=1.600]lr => 6.973568802000003e-05                                                                     
Epoch 9: 100%|████████████████████████| 108/108 [00:00<00:00, 243.11it/s, v_num=169, loss=1.600]
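
As a sanity check on the numbers above: the CONTINUOUS log is consistent with StepLR's closed form `lr = base_lr * gamma ** (epoch // step_size)`. The sketch below is plain Python with no Lightning or PyTorch involved. `BASE_LR` is the `learning_rate=2e-4` default from the repro; `GAMMA = 0.9` and `STEP_SIZE = 1` are assumptions inferred from the printed values (0.00018 / 0.0002 = 0.9, one change per epoch), and `expected_lr` is a helper name made up for this illustration.

```python
import math

BASE_LR = 2e-4   # learning_rate default from the repro
GAMMA = 0.9      # assumed: inferred from 0.00018 / 0.0002 in the log
STEP_SIZE = 1    # assumed: the LR changes after every epoch in the log


def expected_lr(epochs_done: int) -> float:
    """Closed-form StepLR value after `epochs_done` scheduler steps."""
    return BASE_LR * GAMMA ** (epochs_done // STEP_SIZE)


# LR printed at the end of epochs 0..9 in the CONTINUOUS run.
logged = [
    0.00018, 0.000162, 0.00014580000000000002, 0.00013122000000000003,
    0.00011809800000000003, 0.00010628820000000004, 9.565938000000004e-05,
    8.609344200000004e-05, 7.748409780000004e-05, 6.973568802000003e-05,
]

for epoch, lr in enumerate(logged):
    # Compare with a relative tolerance: the logged values carry the rounding
    # error of repeated multiplication, the closed form does not.
    assert math.isclose(lr, expected_lr(epoch + 1), rel_tol=1e-9), (epoch, lr)
```

Under these assumptions, a run that correctly resumes after epoch 4 should print `expected_lr(6)` first, roughly 0.000106, which is what the CKPT PATH run shows after loading the checkpoint.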

@rafathasan
Author

rafathasan commented Apr 20, 2023

@ryan597 I have a question. With self.automatic_optimization=False set, I have to call loss.backward(), self.optimizers().step() and self.optimizers().zero_grad() manually. In that case, calling self.optimizers() points to the correct optimizer without explicitly passing use_pl_optimizer=False. So the question is: doesn't this make things more ambiguous?

    def __init__(self, hidden_dim=64, learning_rate=2e-4):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.learning_rate = learning_rate

        self.fc1 = nn.Linear(28 * 28, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 10)
        # Switch to manual optimization: training_step drives the optimizer itself.
        self.automatic_optimization = False

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log('train_loss', loss)

        loss.backward()

        self.optimizers().step()
        self.optimizers().zero_grad()

        return loss

    def on_train_epoch_end(self):
        # Step the scheduler manually, then read the LR from the unwrapped optimizer.
        self.lr_schedulers().step()
        lr = self.optimizers(use_pl_optimizer=False).param_groups[0]['lr']
        self.log("lr", lr, prog_bar=True, sync_dist=True)
        print(f"lr => {lr}")

logs

outputs
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:lightning_fabric.utilities.distributed:Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/1
INFO:pytorch_lightning.utilities.rank_zero:----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 1 processes
----------------------------------------------------------------------------------------------------

INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name | Type   | Params
--------------------------------
0 | fc1  | Linear | 50.2 K
1 | fc2  | Linear | 650   
--------------------------------
50.9 K    Trainable params
0         Non-trainable params
50.9 K    Total params
0.204     Total estimated model params size (MB)
/usr/local/lib/python3.9/dist-packages/torch/utils/data/dataloader.py:561: UserWarning: This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
  warnings.warn(_create_warning_msg(
/usr/local/lib/python3.9/dist-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:432: PossibleUserWarning: It is recommended to use `self.log('val_loss', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
  warning_cache.warn(
Epoch 4: 100% 108/108 [00:07<00:00, 15.23it/s, v_num=9, lr=0.000131]
lr => 0.00018
lr => 0.000162
lr => 0.00014580000000000002
lr => 0.00013122000000000003
lr => 0.00011809800000000003
INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=5` reached.
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
################################## loading checkpoint #############################################
INFO:lightning_fabric.utilities.distributed:Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/1
INFO:pytorch_lightning.utilities.rank_zero:----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 1 processes
----------------------------------------------------------------------------------------------------

INFO:pytorch_lightning.utilities.rank_zero:Restoring states from the checkpoint path at ./checkpoints/last.ckpt
/usr/local/lib/python3.9/dist-packages/pytorch_lightning/callbacks/model_checkpoint.py:337: UserWarning: The dirpath has changed from '/content/checkpoints' to '/content/lightning_logs/version_10/checkpoints', therefore `best_model_score`, `kth_best_model_path`, `kth_value`, `last_model_path` and `best_k_models` won't be reloaded. Only `best_model_path` will be reloaded.
  warnings.warn(
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name | Type   | Params
--------------------------------
0 | fc1  | Linear | 50.2 K
1 | fc2  | Linear | 650   
--------------------------------
50.9 K    Trainable params
0         Non-trainable params
50.9 K    Total params
0.204     Total estimated model params size (MB)
INFO:pytorch_lightning.utilities.rank_zero:Restored all states from the checkpoint at ./checkpoints/last.ckpt
/usr/local/lib/python3.9/dist-packages/torch/utils/data/dataloader.py:561: UserWarning: This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
  warnings.warn(_create_warning_msg(
/usr/local/lib/python3.9/dist-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:432: PossibleUserWarning: It is recommended to use `self.log('val_loss', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
  warning_cache.warn(
Epoch 9: 100% 108/108 [00:07<00:00, 14.95it/s, v_num=10, lr=7.75e-5]
lr => 0.00010628820000000004
lr => 9.565938000000004e-05
lr => 8.609344200000004e-05
lr => 7.748409780000004e-05
lr => 6.973568802000003e-05
INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=10` reached.

@ryan597
Contributor

ryan597 commented Apr 20, 2023

@rafathasan I haven't looked specifically at that, but if you have told it you want to do manual optimization, then it won't connect/wrap the PL optimizers, so it's going to return the plain optimizer without needing to set use_pl_optimizer=False.

I do agree though, you should be getting the same LR whether or not you pass use_pl_optimizer=False.

@rafathasan
Author

rafathasan commented Apr 20, 2023

@ryan597 I think I should clarify my question a bit further. When I read the LR with self.optimizers().param_groups[0]['lr'], the problem still persists even with self.automatic_optimization=False set. It only works with self.optimizers(use_pl_optimizer=False).param_groups[0]['lr']. So my question is: how come I can drive the optimizer correctly by manually calling self.optimizers().step() and self.optimizers().zero_grad() without passing use_pl_optimizer=False, but self.optimizers().param_groups[0]['lr'] does not give the right value?

@Borda Borda changed the title StepLR doesn't work as expected after loading from checkpoint using Trainer.fit(ckpt_path=...) StepLR doesn't work as expected after loading from checkpoint using Trainer.fit(ckpt_path=...) May 3, 2023
@Lightning-AI Lightning-AI deleted a comment from rafathasan May 3, 2023
@Borda Borda added checkpointing Related to checkpointing tuner and removed needs triage Waiting to be triaged by maintainers labels May 3, 2023
@awaelchli awaelchli removed the tuner label May 5, 2023
@awaelchli
Contributor

awaelchli commented Sep 20, 2023

Hey everyone. This was fixed in #18280 and released in 2.0.7. But no worries, the scheduler and optimizer were always correctly reloaded. The only bug was that the optimizer wrapper returned by self.optimizers() had an outdated state; the internal optimizer always had the correct state, and that's the one used for training. The PR I linked makes sure the wrapper correctly represents the state of the user's optimizer.

I'm closing the issue because I was able to use the provided repro (thanks a ton, way to go!) to verify the fix. Cheers!
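
For readers who want to see the "outdated wrapper, correct internal optimizer" distinction in isolation, here is a minimal plain-PyTorch sketch. It is not Lightning's code and does not claim to reproduce Lightning's internals. `build`, `run_epochs` and `stale_groups` are names invented for this illustration, and `step_size=1`, `gamma=0.9` are inferred from the logs rather than quoted from the repro. It also assumes that `Optimizer.load_state_dict` installs a new `param_groups` list, which is how current PyTorch releases behave as far as I know.

```python
import torch
from torch.optim.lr_scheduler import StepLR


def build():
    """Fresh optimizer and scheduler, as if the script had been restarted."""
    param = torch.nn.Parameter(torch.zeros(1))
    param.grad = torch.zeros(1)  # dummy gradient so optimizer.step() has work to do
    optimizer = torch.optim.SGD([param], lr=2e-4)
    scheduler = StepLR(optimizer, step_size=1, gamma=0.9)  # assumed settings
    return optimizer, scheduler


def run_epochs(optimizer, scheduler, n):
    """Step once per 'epoch' and record the LR, mirroring the printed log."""
    lrs = []
    for _ in range(n):
        optimizer.step()
        scheduler.step()
        lrs.append(optimizer.param_groups[0]["lr"])
    return lrs


# First run: 5 epochs, then capture state the way a checkpoint would.
opt_a, sched_a = build()
first = run_epochs(opt_a, sched_a, 5)
saved = {"optimizer": opt_a.state_dict(), "scheduler": sched_a.state_dict()}

# Second run: rebuild, keep a reference taken BEFORE restoring, then restore.
opt_b, sched_b = build()
stale_groups = opt_b.param_groups  # stand-in for a view created too early
opt_b.load_state_dict(saved["optimizer"])
sched_b.load_state_dict(saved["scheduler"])
resumed = run_epochs(opt_b, sched_b, 5)

print("live :", resumed)                # keeps decaying from where epoch 4 left off
print("stale:", stale_groups[0]["lr"])  # still the initial 2e-4
```

The live optimizer continues the schedule exactly as an uninterrupted run would, while anything still holding the pre-restore `param_groups` keeps reporting the initial value. That matches the symptom in this thread: training was fine, only the value read through the stale view was wrong.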
