
Incorrect Checkpoint storing path when using WandbLogger #17298

Closed
leng-yue opened this issue Apr 7, 2023 · 1 comment · Fixed by #17818
Assignees
Labels
bug (Something isn't working) · logger: wandb (Weights & Biases) · ver: 2.0.x
Milestone

Comments

@leng-yue (Contributor)

leng-yue commented Apr 7, 2023

Bug description

Instead of saving checkpoints to test/[ID]/checkpoints, the example below saves them to test/version_None/checkpoints.
The issue can be worked around by accessing wandb_logger.experiment before creating the Trainer, which forces the wandb run to be created and makes checkpoints land in the correct folder.
This suggests that ModelCheckpoint's setup hook resolves the checkpoint directory before the wandb experiment (and hence the run ID) exists.
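The `version_None` path is consistent with how `ModelCheckpoint` builds its default directory from `logger.version`. Below is a minimal sketch of that resolution logic; the function name and simplified signature are mine for illustration, not Lightning's actual API:

```python
import os

def resolve_ckpt_dirpath(save_dir, name, version):
    # Simplified sketch of ModelCheckpoint's default dirpath resolution:
    # string versions (e.g. a wandb run ID) are used verbatim, while any
    # other value -- including None, before the wandb run exists -- is
    # rendered as "version_<value>".
    version_str = version if isinstance(version, str) else f"version_{version}"
    return os.path.join(save_dir, str(name), version_str, "checkpoints")

# Before the wandb experiment is created, the logger's version is None:
print(resolve_ckpt_dirpath(".", "test", None))        # ./test/version_None/checkpoints
# After creation it returns the run ID, giving the expected path:
print(resolve_ckpt_dirpath(".", "test", "tpp1ssbn"))  # ./test/tpp1ssbn/checkpoints
```

This also explains why touching `wandb_logger.experiment` before constructing the `Trainer` works around the bug: it forces the run, and therefore a non-None `version`, to exist before the checkpoint directory is resolved.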

What version are you seeing the problem on?

2.0+

How to reproduce the bug

# Reference: https://gist.github.com/rain1024/8ea4c2f56aa4c9ba0e1cbf35edb68eca

import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from torch.nn import MSELoss
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset


class SimpleDataset(Dataset):
    def __init__(self):
        X = np.arange(10000)
        y = X * 2
        X = [[_] for _ in X]
        y = [[_] for _ in y]
        self.X = torch.Tensor(X)
        self.y = torch.Tensor(y)

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        return {"X": self.X[idx], "y": self.y[idx]}


class MyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(1, 1)
        self.criterion = MSELoss()

    def forward(self, inputs_id, labels=None):
        outputs = self.fc(inputs_id)
        loss = 0

        if labels is not None:
            loss = self.criterion(outputs, labels)
            self.log("mse_loss", loss)

        return loss, outputs

    def train_dataloader(self):
        dataset = SimpleDataset()

        return DataLoader(dataset, batch_size=1000, num_workers=12)

    def training_step(self, batch, batch_idx):
        input_ids = batch["X"]
        labels = batch["y"]
        loss, outputs = self(input_ids, labels)
        return {"loss": loss}

    def configure_optimizers(self):
        optimizer = Adam(self.parameters())
        return optimizer


if __name__ == "__main__":
    wandb_logger = WandbLogger(project="test")
    model = MyModel()
    trainer = pl.Trainer(
        devices=[0],
        logger=wandb_logger,
        callbacks=[
            ModelCheckpoint(
                filename="{epoch}-{step}",
            )
        ],
        max_steps=200,
    )

    trainer.fit(model)

Error messages and logs

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
./test/version_None/checkpoints
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name      | Type    | Params
--------------------------------------
0 | fc        | Linear  | 2     
1 | criterion | MSELoss | 0     
--------------------------------------
2         Trainable params
0         Non-trainable params
2         Total params
0.000     Total estimated model params size (MB)
/home/lengyue/miniconda3/envs/fish-diffusion/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py:280: PossibleUserWarning: The number of training batches (10) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
  rank_zero_warn(
Epoch 4: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 29.23it/s]wandb: Currently logged in as: lengyue. Use `wandb login --relogin` to force relogin
wandb: wandb version 0.14.1 is available!  To upgrade, please run:
wandb:  $ pip install wandb --upgrade
wandb: Tracking run with wandb version 0.13.11
wandb: Run data is saved locally in ./wandb/run-20230406_203908-tpp1ssbn
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run fast-lake-1
wandb: ⭐️ View project at https://wandb.ai/lengyue/test
wandb: 🚀 View run at https://wandb.ai/lengyue/test/runs/tpp1ssbn
Epoch 19: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 25.55it/s, v_num=ssbn]`Trainer.fit` stopped: `max_steps=200` reached.
Epoch 19: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 25.34it/s, v_num=ssbn]
wandb: Waiting for W&B process to finish... (success).
wandb: - 0.014 MB of 0.014 MB uploaded (0.000 MB deduped)
wandb: Run history:
wandb:               epoch ▁▃▆█
wandb:            mse_loss █▆▃▁
wandb: trainer/global_step ▁▃▆█
wandb: 
wandb: Run summary:
wandb:               epoch 19
wandb:            mse_loss 170994864.0
wandb: trainer/global_step 199
wandb: 
wandb: 🚀 View run fast-lake-1 at: https://wandb.ai/lengyue/test/runs/tpp1ssbn
wandb: Synced 7 W&B file(s), 0 media file(s), 3 artifact file(s) and 1 other file(s)
wandb: Find logs at: ./wandb/run-20230406_203908-tpp1ssbn/logs

Environment

Current environment
  • CUDA:
    - GPU:
    - NVIDIA GeForce RTX 3090
    - NVIDIA GeForce RTX 3090
    - available: True
    - version: 11.8
  • Lightning:
    - lightning-cloud: 0.5.32
    - lightning-utilities: 0.8.0
    - pytorch-lightning: 2.0.1
    - torch: 2.0.0+cu118
    - torchaudio: 2.0.1+cu118
    - torchcrepe: 0.0.17
    - torchmetrics: 0.11.4
    - torchvision: 0.15.1+cu118
  • Packages:
    - absl-py: 1.4.0
    - addict: 2.4.0
    - aiofiles: 23.1.0
    - aiohttp: 3.8.4
    - aiosignal: 1.3.1
    - alabaster: 0.7.13
    - altair: 4.2.2
    - antlr4-python3-runtime: 4.9.3
    - anyio: 3.6.2
    - appdirs: 1.4.4
    - arrow: 1.2.3
    - asgiref: 3.6.0
    - async-generator: 1.10
    - async-timeout: 4.0.2
    - attrs: 22.2.0
    - audioread: 3.0.0
    - babel: 2.12.1
    - beautifulsoup4: 4.11.1
    - black: 22.12.0
    - blessed: 1.20.0
    - brotlipy: 0.7.0
    - build: 0.10.0
    - cachecontrol: 0.12.11
    - cachetools: 5.3.0
    - certifi: 2022.12.7
    - cffi: 1.15.1
    - cfgv: 3.3.1
    - charset-normalizer: 3.1.0
    - cleo: 2.0.1
    - click: 8.1.3
    - cloudpickle: 2.2.1
    - cmake: 3.26.1
    - codecov: 2.1.12
    - colorama: 0.4.6
    - coloredlogs: 15.0.1
    - commonmark: 0.9.1
    - contourpy: 1.0.7
    - coverage: 6.5.0
    - crashtest: 0.4.1
    - croniter: 1.3.8
    - cryptography: 39.0.1
    - cycler: 0.11.0
    - cython: 0.29.34
    - dateutils: 0.6.12
    - decorator: 5.1.1
    - deepdiff: 6.2.3
    - demucs: 4.0.0
    - deprecated: 1.2.13
    - diffq: 0.2.3
    - distlib: 0.3.6
    - dnspython: 2.3.0
    - docker-pycreds: 0.4.0
    - docutils: 0.19
    - dora-search: 0.1.11
    - dulwich: 0.21.3
    - einops: 0.6.0
    - email-validator: 1.3.1
    - encodec: 0.1.1
    - entrypoints: 0.4
    - exceptiongroup: 1.1.1
    - fastapi: 0.88.0
    - ffmpeg-python: 0.2.0
    - ffmpy: 0.3.0
    - filelock: 3.10.7
    - fish-audio-preprocess: 0.1.10
    - fish-diffusion: 0.1.0
    - flask: 2.2.3
    - flask-cors: 3.0.10
    - flatbuffers: 23.3.3
    - flit-core: 3.8.0
    - fonttools: 4.39.3
    - frozenlist: 1.3.3
    - fsspec: 2023.3.0
    - furo: 2022.12.7
    - future: 0.18.3
    - gitdb: 4.0.10
    - gitpython: 3.1.31
    - gmpy2: 2.1.2
    - google-auth: 2.17.1
    - google-auth-oauthlib: 1.0.0
    - gpustat: 1.0.0
    - gradio: 3.24.1
    - gradio-client: 0.0.5
    - greenlet: 2.0.1
    - grpcio: 1.53.0
    - h11: 0.14.0
    - html5lib: 1.1
    - httpcore: 0.16.3
    - httptools: 0.5.0
    - httpx: 0.23.3
    - huggingface-hub: 0.13.3
    - humanfriendly: 10.0
    - identify: 2.5.22
    - idna: 3.4
    - imagesize: 1.4.1
    - importlib-metadata: 6.0.0
    - iniconfig: 2.0.0
    - inquirer: 3.1.2
    - installer: 0.6.0
    - isort: 5.12.0
    - itsdangerous: 2.1.2
    - jaconv: 0.3.4
    - jaraco.classes: 3.2.3
    - jinja2: 3.1.2
    - joblib: 1.2.0
    - jsonschema: 4.17.3
    - julius: 0.2.7
    - keyring: 23.13.1
    - kiwisolver: 1.4.4
    - lameenc: 1.4.2
    - libf0: 1.0.2
    - librosa: 0.9.1
    - lightning-cloud: 0.5.32
    - lightning-utilities: 0.8.0
    - linkify-it-py: 2.0.0
    - lit: 16.0.0
    - livereload: 2.6.3
    - llvmlite: 0.39.1
    - lockfile: 0.12.2
    - loguru: 0.6.0
    - markdown: 3.4.3
    - markdown-it-py: 2.2.0
    - markupsafe: 2.1.1
    - matplotlib: 3.7.1
    - mdit-py-plugins: 0.3.3
    - mdurl: 0.1.2
    - memray: 1.7.0
    - mkl-fft: 1.3.1
    - mkl-random: 1.2.2
    - mkl-service: 2.4.0
    - mmengine: 0.4.0
    - more-itertools: 9.1.0
    - mpmath: 1.3.0
    - msgpack: 1.0.4
    - multidict: 6.0.4
    - mypy-extensions: 1.0.0
    - myst-parser: 0.18.1
    - natsort: 8.3.1
    - networkx: 2.8.4
    - nodeenv: 1.7.0
    - numba: 0.56.4
    - numpy: 1.23.5
    - nvidia-ml-py: 11.495.46
    - oauthlib: 3.2.2
    - omegaconf: 2.3.0
    - onnx: 1.12.0
    - onnxruntime: 1.14.1
    - openai-whisper: 20230124
    - opencv-python: 4.7.0.72
    - openunmix: 1.2.1
    - ordered-set: 4.1.0
    - orjson: 3.8.9
    - outcome: 1.2.0
    - packaging: 23.0
    - pandas: 1.5.3
    - pathspec: 0.11.1
    - pathtools: 0.1.2
    - pillow: 9.5.0
    - pip: 23.0.1
    - pkginfo: 1.9.6
    - platformdirs: 2.6.2
    - playwright: 1.30.0
    - pluggy: 1.0.0
    - poetry: 1.4.0
    - poetry-core: 1.5.1
    - poetry-plugin-export: 1.3.0
    - pooch: 1.7.0
    - praat-parselmouth: 0.4.3
    - pre-commit: 3.2.2
    - protobuf: 4.22.1
    - psutil: 5.9.4
    - py: 1.11.0
    - pyasn1: 0.4.8
    - pyasn1-modules: 0.2.8
    - pycparser: 2.21
    - pydantic: 1.10.7
    - pydub: 0.25.1
    - pyee: 9.0.4
    - pygments: 2.14.0
    - pyjwt: 2.6.0
    - pykakasi: 2.2.1
    - pyloudnorm: 0.1.1
    - pympler: 1.0.1
    - pyopenssl: 23.0.0
    - pyparsing: 3.0.9
    - pypinyin: 0.48.0
    - pyproject-hooks: 1.0.0
    - pyrsistent: 0.19.3
    - pysocks: 1.7.1
    - pysoundfile: 0.9.0.post1
    - pytest: 7.2.2
    - pytest-asyncio: 0.20.3
    - pytest-cov: 4.0.0
    - pytest-doctestplus: 0.12.1
    - pytest-forked: 1.4.0
    - pytest-rerunfailures: 10.3
    - pytest-timeout: 2.1.0
    - python-dateutil: 2.8.2
    - python-dotenv: 1.0.0
    - python-editor: 1.0.4
    - python-multipart: 0.0.6
    - pytorch-lightning: 2.0.1
    - pytz: 2023.3
    - pyworld: 0.3.2
    - pyyaml: 6.0
    - rapidfuzz: 2.13.7
    - readchar: 4.0.5
    - regex: 2023.3.23
    - requests: 2.28.1
    - requests-mock: 1.10.0
    - requests-oauthlib: 1.3.1
    - requests-toolbelt: 0.10.1
    - resampy: 0.4.2
    - retrying: 1.3.4
    - rfc3986: 1.5.0
    - rich: 13.3.3
    - richuru: 0.1.1
    - rsa: 4.9
    - scikit-learn: 1.2.2
    - scipy: 1.9.3
    - semantic-version: 2.10.0
    - sentry-sdk: 1.18.0
    - setproctitle: 1.3.2
    - setuptools: 67.6.1
    - shellingham: 1.5.1
    - six: 1.16.0
    - smmap: 5.0.0
    - sniffio: 1.3.0
    - snowballstemmer: 2.2.0
    - sortedcontainers: 2.4.0
    - soundfile: 0.11.0
    - soupsieve: 2.3.2.post1
    - sphinx: 5.3.0
    - sphinx-autobuild: 2021.3.14
    - sphinx-basic-ng: 1.0.0b1
    - sphinxcontrib-applehelp: 1.0.4
    - sphinxcontrib-devhelp: 1.0.2
    - sphinxcontrib-htmlhelp: 2.0.1
    - sphinxcontrib-jsmath: 1.0.1
    - sphinxcontrib-qthelp: 1.0.3
    - sphinxcontrib-serializinghtml: 1.1.5
    - sqlalchemy: 1.4.41
    - sqlalchemy2-stubs: 0.0.2a32
    - sqlmodel: 0.0.8
    - starlette: 0.22.0
    - starsessions: 1.3.0
    - submitit: 1.4.5
    - sympy: 1.11.1
    - tensorboard: 2.12.1
    - tensorboard-data-server: 0.7.0
    - tensorboard-plugin-wit: 1.8.1
    - termcolor: 2.2.0
    - textgrid: 1.5
    - threadpoolctl: 3.1.0
    - tokenizers: 0.13.2
    - tomli: 2.0.1
    - tomlkit: 0.11.6
    - toolz: 0.12.0
    - torch: 2.0.0+cu118
    - torchaudio: 2.0.1+cu118
    - torchcrepe: 0.0.17
    - torchmetrics: 0.11.4
    - torchvision: 0.15.1+cu118
    - tornado: 6.2
    - tqdm: 4.65.0
    - traitlets: 5.8.1
    - transformers: 4.27.4
    - treetable: 0.2.5
    - trio: 0.21.0
    - triton: 2.0.0
    - trove-classifiers: 2023.2.8
    - typing-extensions: 4.5.0
    - uc-micro-py: 1.0.1
    - ujson: 5.7.0
    - urllib3: 1.26.13
    - uvicorn: 0.21.1
    - uvloop: 0.17.0
    - virtualenv: 20.19.0
    - wandb: 0.13.11
    - watchgod: 0.8.2
    - wcwidth: 0.2.6
    - webencodings: 0.5.1
    - websocket-client: 1.5.1
    - websockets: 11.0
    - werkzeug: 2.2.3
    - wheel: 0.40.0
    - wrapt: 1.15.0
    - yapf: 0.32.0
    - yarl: 1.8.2
  • System:
    - OS: Linux
    - architecture:
    - 64bit
    - ELF
    - processor: x86_64
    - python: 3.10.10
    - version: #76-Ubuntu SMP Fri Mar 17 17:19:29 UTC 2023

More info

No response

cc @awaelchli @morganmcg1 @borisdayma @scottire @parambharat

@leng-yue added the bug (Something isn't working) and needs triage (Waiting to be triaged by maintainers) labels Apr 7, 2023
@leng-yue (Contributor, Author)

I was unable to replicate this bug in a newly installed environment. However, updating Lightning in the existing environment did not resolve the issue. I will continue debugging and provide further updates.

@awaelchli added the logger: wandb (Weights & Biases) label and removed the needs triage (Waiting to be triaged by maintainers) label Jun 12, 2023
@awaelchli awaelchli self-assigned this Jun 12, 2023
@awaelchli awaelchli added this to the 2.0.x milestone Jun 12, 2023