OneCycleLR scheduler state not restored from checkpoint file #15462

Closed · vhewes opened this issue Nov 2, 2022 · 2 comments
Labels: bug (Something isn't working), optimizer


vhewes commented Nov 2, 2022

Bug description

I recently adapted a network architecture to a LightningModule, and I find that when resuming an in-progress training from a checkpoint file, the state of the OneCycleLR scheduler is not properly restored. I've tested with version 1.8.0 and confirmed that the issue persists.

The example pasted below runs a full end-to-end training of a toy model and dataset, then trains the same model halfway to completion, and finally loads the checkpoint file from the halfway-trained model and trains it the rest of the way. It plots the learning rate in both cases, demonstrating that in the resumed case the learning rate scheduler's internal state is not successfully restored from the checkpoint file.

How to reproduce the bug

from typing import List
import torch
import pytorch_lightning as pl
import pandas as pd
import matplotlib.pyplot as plt

class ToyDataset(torch.utils.data.Dataset):
    """Toy dataset of random 5-feature inputs with binary targets."""

    def __init__(self):
        super().__init__()

    def __len__(self) -> int:
        return 1000

    def __getitem__(self, idx: int) -> List[torch.Tensor]:
        return [ torch.rand(5), torch.randint(high=2, size=[1]) ]

class ToyModel(pl.LightningModule):
    """Toy binary classifier with a OneCycleLR schedule stepped per batch."""

    def __init__(self):
        super().__init__()

        self.net = torch.nn.Sequential(
            torch.nn.Linear(in_features=5,
                            out_features=1),
            torch.nn.Sigmoid())

        self.hparams.learning_rate = 0.1

        self.loss_func = torch.nn.BCELoss()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

    def training_step(self,
                      batch: torch.Tensor,
                      batch_idx: int) -> torch.Tensor:
        x, y = batch
        x = self(x)
        loss = self.loss_func(x, y.float())
        # log the current learning rate so it can be plotted from the CSV logs
        self.log('learning rate', self.optimizers().state_dict()['param_groups'][0]['lr'])
        return loss

    def configure_optimizers(self) -> tuple:
        # print the trainer's estimate of total optimizer steps, for reference
        print(self.trainer.estimated_stepping_batches)
        optimizer = torch.optim.SGD(self.parameters(),
                                    lr=self.hparams.learning_rate,
                                    momentum=0.9)
        onecycle = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=self.hparams.learning_rate,
            final_div_factor=1e6,
            total_steps=100)
        # step the scheduler every batch rather than every epoch
        return [optimizer], {'scheduler': onecycle, 'interval': 'step'}

def main():

    data = torch.utils.data.DataLoader(ToyDataset(), batch_size=200)
    model = ToyModel()

    # train all the way
    logger = pl.loggers.CSVLogger(save_dir='.',
                                  name='toy',
                                  version='full')
    trainer = pl.Trainer(accelerator='cpu',
                         max_epochs=20,
                         log_every_n_steps=1,
                         logger=logger)
    trainer.fit(model, train_dataloaders=data)
    
    # train halfway
    logger = pl.loggers.CSVLogger(save_dir='.',
                                  name='toy',
                                  version='halfway')
    trainer = pl.Trainer(accelerator='cpu',
                         max_epochs=10,
                         log_every_n_steps=1,
                         logger=logger)
    trainer.fit(model, train_dataloaders=data)
    
    # load checkpoint and train the rest of the way
    logger = pl.loggers.CSVLogger(save_dir='.',
                                  name='toy',
                                  version='resumed')
    trainer = pl.Trainer(accelerator='cpu',
                         max_epochs=20,
                         log_every_n_steps=1,
                         logger=logger)
    trainer.fit(model,
                train_dataloaders=data,
                ckpt_path='toy/halfway/checkpoints/epoch=9-step=50.ckpt')

    # compare the learning rates of the full and resumed training instances
    full = pd.read_csv('toy/full/metrics.csv')
    halfway = pd.read_csv('toy/halfway/metrics.csv')
    resumed = pd.read_csv('toy/resumed/metrics.csv')
    combined = pd.concat([halfway, resumed])
    ax = full.plot(x='step', y='learning rate', ylabel='learning rate', label='full')
    combined.plot(x='step', y='learning rate', label='resumed', ax=ax)
    plt.show()

if __name__ == '__main__':

    main()
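
One way to confirm what the halfway checkpoint actually contains is to inspect it directly: Lightning stores scheduler state in the checkpoint under the 'lr_schedulers' key. A minimal sketch, assuming the checkpoint path produced by the script above:

import torch

# load the halfway checkpoint on CPU and print the saved scheduler state;
# for OneCycleLR this should include fields such as 'last_epoch' and '_step_count'
ckpt = torch.load('toy/halfway/checkpoints/epoch=9-step=50.ckpt', map_location='cpu')
print(ckpt['lr_schedulers'])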

Error messages and logs

No response

Environment

  • CUDA:
    • GPU: None
    • available: False
    • version: None
  • Lightning:
    • pytorch-lightning: 1.7.7
    • torch: 1.13.0
    • torch-geometric: 2.1.0.post1
    • torch-scatter: 2.0.9
    • torchmetrics: 0.10.1
  • Packages:
    • absl-py: 1.3.0
    • aiohttp: 3.8.3
    • aiosignal: 1.2.0
    • anyio: 3.6.2
    • appnope: 0.1.3
    • argon2-cffi: 21.3.0
    • argon2-cffi-bindings: 21.2.0
    • asttokens: 2.0.8
    • async-timeout: 4.0.2
    • attrs: 22.1.0
    • awkward: 1.10.1
    • babel: 2.10.3
    • backcall: 0.2.0
    • backports.functools-lru-cache: 1.6.4
    • beautifulsoup4: 4.11.1
    • bleach: 5.0.1
    • blinker: 1.5
    • brotlipy: 0.7.0
    • build: 0.7.0
    • cached-property: 1.5.2
    • cachetools: 5.2.0
    • certifi: 2022.9.24
    • cffi: 1.15.1
    • charset-normalizer: 2.1.1
    • click: 8.1.3
    • colorama: 0.4.6
    • conda: 22.9.0
    • conda-package-handling: 1.9.0
    • contourpy: 1.0.6
    • cryptography: 38.0.2
    • cycler: 0.11.0
    • debugpy: 1.6.3
    • decorator: 5.1.1
    • defusedxml: 0.7.1
    • deprecated: 1.2.13
    • entrypoints: 0.4
    • executing: 1.1.1
    • fastjsonschema: 2.16.2
    • flit-core: 3.7.1
    • fonttools: 4.38.0
    • frozenlist: 1.3.1
    • fsspec: 2022.10.0
    • google-auth: 2.14.0
    • google-auth-oauthlib: 0.4.6
    • grpcio: 1.49.1
    • h5py: 3.7.0
    • hepunits: 2.3.0
    • idna: 3.4
    • importlib-metadata: 5.0.0
    • importlib-resources: 5.10.0
    • ipykernel: 6.17.0
    • ipython: 8.6.0
    • ipython-genutils: 0.2.0
    • ipywidgets: 8.0.2
    • jedi: 0.18.1
    • jinja2: 3.1.2
    • joblib: 1.2.0
    • json5: 0.9.5
    • jsonschema: 4.16.0
    • jupyter-client: 7.4.4
    • jupyter-core: 4.11.1
    • jupyter-server: 1.21.0
    • jupyterlab: 3.5.0
    • jupyterlab-pygments: 0.2.2
    • jupyterlab-server: 2.16.2
    • jupyterlab-widgets: 3.0.3
    • kiwisolver: 1.4.4
    • libmambapy: 0.27.0
    • llvmlite: 0.39.1
    • lz4: 4.0.2
    • mamba: 0.27.0
    • markdown: 3.4.1
    • markupsafe: 2.1.1
    • matplotlib: 3.6.0
    • matplotlib-inline: 0.1.6
    • mistune: 2.0.4
    • mpi4py: 3.1.3
    • multidict: 6.0.2
    • munkres: 1.1.4
    • nbclassic: 0.4.7
    • nbclient: 0.7.0
    • nbconvert: 7.2.3
    • nbformat: 5.7.0
    • nest-asyncio: 1.5.6
    • notebook: 6.4.12
    • notebook-shim: 0.2.0
    • nugraph: 0.1.0
    • numba: 0.56.3
    • numpy: 1.23.4
    • oauthlib: 3.2.2
    • packaging: 21.3
    • pandas: 1.5.1
    • pandocfilters: 1.5.0
    • parso: 0.8.3
    • particle: 0.20.1
    • patsy: 0.5.3
    • pep517: 0.12.0
    • pexpect: 4.8.0
    • pickleshare: 0.7.5
    • pillow: 9.2.0
    • pip: 22.3
    • pkgutil-resolve-name: 1.3.10
    • plotly: 5.11.0
    • prometheus-client: 0.15.0
    • prompt-toolkit: 3.0.31
    • protobuf: 4.21.9
    • psutil: 5.9.3
    • ptyprocess: 0.7.0
    • pure-eval: 0.2.2
    • pyasn1: 0.4.8
    • pyasn1-modules: 0.2.7
    • pycosat: 0.6.4
    • pycparser: 2.21
    • pydeprecate: 0.3.2
    • pygments: 2.13.0
    • pyjwt: 2.6.0
    • pyopenssl: 22.1.0
    • pyparsing: 3.0.9
    • pyrsistent: 0.18.1
    • pysocks: 1.7.1
    • python-dateutil: 2.8.2
    • pytorch-lightning: 1.7.7
    • pytz: 2022.5
    • pyu2f: 0.1.5
    • pyyaml: 6.0
    • pyzmq: 24.0.1
    • requests: 2.28.1
    • requests-oauthlib: 1.3.1
    • rsa: 4.9
    • ruamel-yaml-conda: 0.15.80
    • scikit-learn: 1.1.3
    • scipy: 1.9.3
    • seaborn: 0.12.1
    • send2trash: 1.8.0
    • setuptools: 65.5.0
    • six: 1.16.0
    • sniffio: 1.3.0
    • soupsieve: 2.3.2.post1
    • stack-data: 0.5.1
    • statsmodels: 0.13.2
    • tenacity: 8.1.0
    • tensorboard: 2.10.1
    • tensorboard-data-server: 0.6.0
    • tensorboard-plugin-wit: 1.8.1
    • terminado: 0.17.0
    • threadpoolctl: 3.1.0
    • tinycss2: 1.2.1
    • tomli: 2.0.1
    • toolz: 0.12.0
    • torch: 1.13.0
    • torch-geometric: 2.1.0.post1
    • torch-scatter: 2.0.9
    • torchmetrics: 0.10.1
    • tornado: 6.2
    • tqdm: 4.64.1
    • traitlets: 5.5.0
    • typing-extensions: 4.4.0
    • unicodedata2: 15.0.0
    • uproot: 4.3.7
    • urllib3: 1.26.11
    • wcwidth: 0.2.5
    • webencodings: 0.5.1
    • websocket-client: 1.4.1
    • werkzeug: 2.2.2
    • wheel: 0.37.1
    • widgetsnbextension: 4.0.3
    • wrapt: 1.14.1
    • xrootd: 5.5.1
    • xxhash: 0.0.0
    • yarl: 1.8.1
    • zipp: 3.10.0
    • zstandard: 0.19.0
  • System:
    • OS: Darwin
    • architecture:
      • 64bit
    • processor: arm
    • python: 3.10.6
    • version: Darwin Kernel Version 21.4.0: Mon Feb 21 20:35:58 PST 2022; root:xnu-8020.101.4~2/RELEASE_ARM64_T6000

More info

The environment provided above is my local machine, where I constructed the toy example, but I also observe the same issue on an NVIDIA GPU cluster and in HPC environments, so it is not localised to a specific architecture.

cc @rohitgr7

@vhewes vhewes added the needs triage Waiting to be triaged by maintainers label Nov 2, 2022
rohitgr7 (Contributor) commented Nov 2, 2022

Okay, I know the issue. Just to unblock you, can you use:

self.log('learning rate', self.trainer.optimizers[0].state_dict()['param_groups'][0]['lr'])
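
For context: self.optimizers() inside a LightningModule returns a LightningOptimizer wrapper, whereas self.trainer.optimizers[0] is the underlying torch.optim.SGD instance, which is presumably why reading the learning rate through the trainer sidesteps the stale value. A minimal sketch of the adjusted training_step, assuming the toy module above:

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = self.loss_func(self(x), y.float())
        # read the lr from the raw optimizer held by the trainer, per the workaround above
        lr = self.trainer.optimizers[0].state_dict()['param_groups'][0]['lr']
        self.log('learning rate', lr)
        return loss

(Lightning's built-in LearningRateMonitor callback is the more idiomatic way to log learning rates, though the manual self.log call is what exposes the issue here.)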

@rohitgr7 rohitgr7 added bug Something isn't working optimizer and removed needs triage Waiting to be triaged by maintainers labels Nov 2, 2022
@rohitgr7 rohitgr7 self-assigned this Nov 2, 2022
@rohitgr7 rohitgr7 added this to the v1.8.x milestone Nov 2, 2022
@Borda Borda self-assigned this Nov 7, 2022
@Borda Borda modified the milestones: v1.8.x, v1.9 Jan 6, 2023
@Borda Borda modified the milestones: v1.9, v1.9.x Jan 16, 2023
awaelchli (Contributor) commented

This was fixed in #18280.
See my full reply here on another issue: #17296 (comment)

@awaelchli awaelchli modified the milestones: v1.9.x, 2.0.x Sep 20, 2023