Load From Checkpoint Bug #17593

Closed
thisistejaspandey opened this issue May 8, 2023 · 2 comments
Labels
bug, duplicate, ver: 2.0.x

Comments

@thisistejaspandey

Bug description

I'm trying to run inference on a model using PyTorch. Loading the checkpoint with Lightning's load_from_checkpoint gives random results. Loading from the state_dict instead gives a constant value for one of the outputs during inference.

checkpoint_path = "path_to_checkpoint"

# Works, but one output is a constant.
checkpoint = torch.load(checkpoint_path, map_location="cuda")
model.load_state_dict(checkpoint["state_dict"])

# Too much error during inference. 
# model.load_from_checkpoint(checkpoint_path)

model.cuda().eval()

What version are you seeing the problem on?

v2.0

How to reproduce the bug

No response

Error messages and logs

No response

Environment

Inference environment
  • CUDA:
    • GPU:
      • NVIDIA GeForce RTX 3090
    • available: True
    • version: 11.7
  • Lightning:
    • lightning: 2.0.2
    • lightning-cloud: 0.5.34
    • lightning-utilities: 0.8.0
    • pytorch-lightning: 2.0.1.post0
    • pytorch-triton: 2.1.0+7d1a95b046
    • torch: 2.1.0.dev20230508+cu117
    • torchaudio: 2.1.0.dev20230508+cu117
    • torchmetrics: 0.11.4
    • torchvision: 0.16.0.dev20230508+cu117
  • Packages:
    • absl-py: 1.4.0
    • addict: 2.4.0
    • aiohttp: 3.8.4
    • aiosignal: 1.3.1
    • alabaster: 0.7.13
    • anyio: 3.6.1
    • appdirs: 1.4.4
    • argcomplete: 2.0.0
    • argon2-cffi: 21.3.0
    • argon2-cffi-bindings: 21.2.0
    • arrow: 1.2.3
    • astroid: 2.15.3
    • asttokens: 2.0.5
    • async-timeout: 4.0.2
    • attrs: 22.2.0
    • babel: 2.10.3
    • backcall: 0.2.0
    • beautifulsoup4: 4.11.1
    • black: 23.1.0
    • bleach: 5.0.1
    • blessed: 1.20.0
    • blessings: 1.7
    • boto3: 1.26.118
    • botocore: 1.29.118
    • cachetools: 5.3.0
    • certifi: 2022.12.7
    • cffi: 1.15.1
    • charset-normalizer: 3.1.0
    • click: 8.1.3
    • colorama: 0.4.6
    • commonmark: 0.9.1
    • contourpy: 1.0.7
    • croniter: 1.3.14
    • cycler: 0.11.0
    • cython: 0.29.33
    • dateutils: 0.6.12
    • dcnv2: 0.1
    • debugpy: 1.6.0
    • decorator: 5.1.1
    • deepdiff: 6.3.0
    • defusedxml: 0.7.1
    • deprecation: 2.1.0
    • descartes: 1.1.0
    • dill: 0.3.6
    • distro: 1.8.0
    • docker-pycreds: 0.4.0
    • docutils: 0.18.1
    • easydict: 1.10
    • einops: 0.6.0
    • entrypoints: 0.4
    • exceptiongroup: 1.1.0
    • executing: 0.8.3
    • fastapi: 0.88.0
    • fastjsonschema: 2.15.3
    • filelock: 3.8.2
    • fire: 0.5.0
    • flake8: 6.0.0
    • flake8-docstrings: 1.7.0
    • fonttools: 4.38.0
    • frozenlist: 1.3.3
    • fsspec: 2023.4.0
    • gitdb: 4.0.10
    • gitpython: 3.1.31
    • gmplot: 1.4.1
    • google-auth: 2.16.0
    • google-auth-oauthlib: 0.4.6
    • grpcio: 1.51.1
    • h11: 0.14.0
    • idna: 3.4
    • imageio: 2.25.0
    • imagesize: 1.4.1
    • importlib-metadata: 4.12.0
    • importlib-resources: 5.10.2
    • iniconfig: 2.0.0
    • inquirer: 3.1.3
    • ipykernel: 6.15.0
    • ipython: 8.4.0
    • ipython-genutils: 0.2.0
    • ipywidgets: 7.7.1
    • isort: 5.12.0
    • itsdangerous: 2.1.2
    • jedi: 0.18.1
    • jinja2: 3.1.2
    • jmespath: 1.0.1
    • joblib: 1.1.0
    • json5: 0.9.8
    • jsonschema: 4.17.3
    • jupyter: 1.0.0
    • jupyter-client: 7.3.4
    • jupyter-console: 6.4.4
    • jupyter-core: 4.10.0
    • jupyter-packaging: 0.12.2
    • jupyter-server: 1.18.1
    • jupyterlab: 3.4.3
    • jupyterlab-pygments: 0.2.2
    • jupyterlab-server: 2.15.0
    • jupyterlab-widgets: 1.1.1
    • kiwisolver: 1.4.4
    • lazy-object-proxy: 1.9.0
    • lightning: 2.0.2
    • lightning-cloud: 0.5.34
    • lightning-utilities: 0.8.0
    • llvmlite: 0.36.0
    • loguru: 0.7.0
    • lyft-dataset-sdk: 0.0.8
    • markdown: 3.4.1
    • markdown-it-py: 2.2.0
    • markupsafe: 2.1.1
    • matplotlib: 3.6.3
    • matplotlib-inline: 0.1.3
    • mccabe: 0.7.0
    • mdurl: 0.1.2
    • mistune: 0.8.4
    • mmcls: 0.25.0
    • mmcv-full: 1.6.0
    • mmdet: 2.28.0
    • mmdet3d: 1.0.0rc6
    • mmsegmentation: 0.30.0
    • model-index: 0.1.11
    • motmetrics: 1.4.0
    • mpmath: 1.2.1
    • multidict: 6.0.4
    • mypy: 1.2.0
    • mypy-extensions: 1.0.0
    • nanoid: 2.0.0
    • nbclassic: 0.4.0
    • nbclient: 0.6.6
    • nbconvert: 6.5.0
    • nbformat: 5.4.0
    • nest-asyncio: 1.5.5
    • networkx: 2.2
    • ninja: 1.11.1
    • notebook: 6.4.12
    • notebook-shim: 0.1.0
    • numba: 0.53.0
    • numpy: 1.23.5
    • nuscenes-devkit: 1.1.9
    • oauthlib: 3.2.2
    • onnx: 1.13.1
    • onnx-simplifier: 0.4.10
    • open3d: 0.15.2
    • openai: 0.27.2
    • opencv-python: 4.7.0.68
    • openmim: 0.3.5
    • ordered-set: 4.1.0
    • orjson: 3.8.6
    • packaging: 23.0
    • pandas: 1.3.5
    • pandocfilters: 1.5.0
    • parso: 0.8.3
    • pathspec: 0.11.0
    • pathtools: 0.1.2
    • pexpect: 4.8.0
    • pickleshare: 0.7.5
    • pillow: 9.2.0
    • pip: 23.1.2
    • pipx: 1.1.0
    • pkgutil-resolve-name: 1.3.10
    • platformdirs: 3.1.1
    • plotly: 5.13.0
    • pluggy: 1.0.0
    • plyfile: 0.7.4
    • prefetch-generator: 1.0.3
    • prettytable: 3.6.0
    • progress: 1.6
    • prometheus-client: 0.14.1
    • prompt-toolkit: 3.0.38
    • protobuf: 3.20.3
    • psutil: 5.9.4
    • ptyprocess: 0.7.0
    • pure-eval: 0.2.2
    • pyasn1: 0.4.8
    • pyasn1-modules: 0.2.8
    • pycocotools: 2.0.6
    • pycodestyle: 2.10.0
    • pycparser: 2.21
    • pydantic: 1.10.7
    • pydocstyle: 6.3.0
    • pyflakes: 3.0.1
    • pygame: 2.1.2
    • pygments: 2.14.0
    • pyjwt: 2.6.0
    • pylint: 2.17.2
    • pyparsing: 3.0.9
    • pyquaternion: 0.9.9
    • pyrsistent: 0.19.3
    • pysocks: 1.7.1
    • pyte: 0.8.1
    • pytest: 7.2.1
    • python-dateutil: 2.8.2
    • python-editor: 1.0.4
    • python-multipart: 0.0.6
    • pytorch-lightning: 2.0.1.post0
    • pytorch-triton: 2.1.0+7d1a95b046
    • pytz: 2022.7.1
    • pywavelets: 1.4.1
    • pyyaml: 6.0
    • pyzmq: 23.2.0
    • qtconsole: 5.4.0
    • qtpy: 2.3.0
    • readchar: 4.0.5
    • requests: 2.28.2
    • requests-oauthlib: 1.3.1
    • rich: 12.6.0
    • rsa: 4.9
    • s3transfer: 0.6.0
    • scalabel: 0.3.0
    • scikit-image: 0.19.3
    • scikit-learn: 1.1.1
    • scipy: 1.10.0
    • seaborn: 0.12.2
    • send2trash: 1.8.0
    • sentry-sdk: 1.20.0
    • setproctitle: 1.3.2
    • setuptools: 63.1.0
    • shapely: 2.0.1
    • shutup: 0.2.0
    • six: 1.16.0
    • smmap: 5.0.0
    • sniffio: 1.2.0
    • snowballstemmer: 2.2.0
    • soupsieve: 2.3.2.post1
    • sphinx: 6.1.3
    • sphinx-rtd-theme: 1.2.0
    • sphinxcontrib-applehelp: 1.0.4
    • sphinxcontrib-devhelp: 1.0.2
    • sphinxcontrib-htmlhelp: 2.0.1
    • sphinxcontrib-jquery: 4.1
    • sphinxcontrib-jsmath: 1.0.1
    • sphinxcontrib-qthelp: 1.0.3
    • sphinxcontrib-serializinghtml: 1.1.5
    • stack-data: 0.3.0
    • starlette: 0.22.0
    • starsessions: 1.3.0
    • sympy: 1.11.1
    • tabulate: 0.9.0
    • tenacity: 8.1.0
    • tensorboard: 2.11.2
    • tensorboard-data-server: 0.6.1
    • tensorboard-plugin-wit: 1.8.1
    • termcolor: 2.2.0
    • terminado: 0.15.0
    • terminaltables: 3.1.10
    • thop: 0.1.1.post2209072238
    • threadpoolctl: 3.1.0
    • tifffile: 2023.1.23.1
    • tinycss2: 1.1.1
    • toml: 0.10.2
    • tomli: 2.0.1
    • tomlkit: 0.11.0
    • torch: 2.1.0.dev20230508+cu117
    • torchaudio: 2.1.0.dev20230508+cu117
    • torchmetrics: 0.11.4
    • torchvision: 0.16.0.dev20230508+cu117
    • tornado: 6.2
    • tqdm: 4.65.0
    • traitlets: 5.3.0
    • trimesh: 2.35.39
    • typer: 0.7.0
    • typing-extensions: 4.5.0
    • undervolt: 0.3.0
    • urllib3: 1.26.15
    • userpath: 1.8.0
    • uvicorn: 0.21.1
    • wandb: 0.15.0
    • wcwidth: 0.2.6
    • webencodings: 0.5.1
    • websocket-client: 1.3.3
    • websockets: 11.0.2
    • werkzeug: 2.2.2
    • wheel: 0.37.1
    • widgetsnbextension: 3.6.1
    • wrapt: 1.15.0
    • xmltodict: 0.13.0
    • yacs: 0.1.8
    • yapf: 0.32.0
    • yarl: 1.8.2
    • yolox: 0.3.0
    • zipp: 3.12.0
  • System:
    • OS: Linux
    • architecture:
      • 64bit
      • ELF
    • processor: x86_64
    • python: 3.8.16
    • release: 5.15.0-69-generic
    • version: #76~20.04.1-Ubuntu SMP Mon Mar 20 15:54:19 UTC 2023
Training environment
  • CUDA:
    • GPU: A40
    • available: False
    • version: 11.7
  • Lightning:
    • lightning: 2.0.2
    • lightning-cloud: 0.5.34
    • lightning-utilities: 0.8.0
    • pytorch-lightning: 2.0.2
    • pytorch-triton: 2.1.0+7d1a95b046
    • torch: 2.1.0.dev20230506+cu117
    • torchaudio: 2.0.0
    • torchmetrics: 0.11.4
    • torchvision: 0.15.0
  • Packages:
    • aiohttp: 3.8.4
    • aiosignal: 1.3.1
    • anyio: 3.6.2
    • appdirs: 1.4.4
    • arrow: 1.2.3
    • async-timeout: 4.0.2
    • attrs: 23.1.0
    • beautifulsoup4: 4.12.2
    • blessed: 1.20.0
    • brotlipy: 0.7.0
    • certifi: 2022.12.7
    • cffi: 1.15.1
    • charset-normalizer: 2.0.4
    • click: 8.1.3
    • croniter: 1.3.14
    • cryptography: 39.0.1
    • dateutils: 0.6.12
    • deepdiff: 6.3.0
    • docker-pycreds: 0.4.0
    • einops: 0.6.1
    • fastapi: 0.88.0
    • filelock: 3.9.0
    • frozenlist: 1.3.3
    • fsspec: 2023.4.0
    • gitdb: 4.0.10
    • gitpython: 3.1.31
    • gmpy2: 2.1.2
    • h11: 0.14.0
    • idna: 3.4
    • inquirer: 3.1.3
    • itsdangerous: 2.1.2
    • jinja2: 3.1.2
    • lightning: 2.0.2
    • lightning-cloud: 0.5.34
    • lightning-utilities: 0.8.0
    • markdown-it-py: 2.2.0
    • markupsafe: 2.1.2
    • mdurl: 0.1.2
    • mkl-fft: 1.3.1
    • mkl-random: 1.2.2
    • mkl-service: 2.4.0
    • mpmath: 1.2.1
    • multidict: 6.0.4
    • networkx: 3.0rc1
    • numpy: 1.24.3
    • opencv-python: 4.7.0.72
    • ordered-set: 4.1.0
    • packaging: 23.1
    • pathtools: 0.1.2
    • pillow: 9.4.0
    • pip: 23.0.1
    • protobuf: 4.22.3
    • psutil: 5.9.5
    • pycparser: 2.21
    • pydantic: 1.10.7
    • pygments: 2.15.1
    • pyjwt: 2.6.0
    • pyopenssl: 23.0.0
    • pysocks: 1.7.1
    • python-dateutil: 2.8.2
    • python-editor: 1.0.4
    • python-multipart: 0.0.6
    • pytorch-lightning: 2.0.2
    • pytorch-triton: 2.1.0+7d1a95b046
    • pytz: 2023.3
    • pyyaml: 6.0
    • readchar: 4.0.5
    • requests: 2.29.0
    • rich: 13.3.5
    • scipy: 1.10.1
    • sentry-sdk: 1.21.1
    • setproctitle: 1.3.2
    • setuptools: 66.0.0
    • six: 1.16.0
    • smmap: 5.0.0
    • sniffio: 1.3.0
    • soupsieve: 2.4.1
    • starlette: 0.22.0
    • starsessions: 1.3.0
    • sympy: 1.11.1
    • termcolor: 2.3.0
    • torch: 2.1.0.dev20230506+cu117
    • torchaudio: 2.0.0
    • torchmetrics: 0.11.4
    • torchvision: 0.15.0
    • tqdm: 4.65.0
    • traitlets: 5.9.0
    • triton: 2.0.0
    • typing-extensions: 4.4.0
    • urllib3: 1.26.15
    • uvicorn: 0.22.0
    • wandb: 0.15.1
    • wcwidth: 0.2.6
    • websocket-client: 1.5.1
    • websockets: 11.0.2
    • wheel: 0.38.4
    • yarl: 1.9.2
  • System:
    • OS: Linux
    • architecture:
      • 64bit
      • ELF
    • processor: x86_64
    • python: 3.10.11
    • release: 3.10.0-1160.80.1.el7.x86_64
    • version: #1 SMP Tue Nov 8 15:48:59 UTC 2022

More info

No response

thisistejaspandey added the bug and needs triage labels on May 8, 2023
@thisistejaspandey
Author

# Works fine.
checkpoint_path = "path_to_checkpoint"   
trainer.validate(model, datamodule=data, ckpt_path=checkpoint_path)

# Doesn't work.
model.load_from_checkpoint(checkpoint_path)
trainer.validate(model, datamodule=data)

@awaelchli
Contributor

Hi @thisistejaspandey
See the resolution here: #18169

model.load_from_checkpoint(checkpoint_path)

needs to be

model = YourModelClass.load_from_checkpoint(checkpoint_path)

We are working in #18169 to make this clearer.
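The root cause is that load_from_checkpoint is a classmethod that builds and returns a new model; calling it on an existing instance leaves that instance's weights untouched, so the later trainer.validate(model, ...) runs on the untrained model. The sketch below reproduces only that return-a-new-instance behavior in plain Python. TinyModel and the in-memory checkpoint dict are hypothetical stand-ins for a real LightningModule and checkpoint file, not Lightning's actual loading logic.

```python
class TinyModel:
    """Hypothetical stand-in for a LightningModule (not Lightning's real code)."""

    def __init__(self, weight=0.0):
        # 0.0 plays the role of freshly initialized, untrained weights.
        self.weight = weight

    @classmethod
    def load_from_checkpoint(cls, checkpoint):
        # Like the real classmethod, this constructs and RETURNS a new instance.
        # It never modifies the object it happens to be called on.
        return cls(weight=checkpoint["state_dict"]["weight"])


# Hypothetical in-memory checkpoint, shaped like {"state_dict": {...}}.
checkpoint = {"state_dict": {"weight": 3.5}}

# Pitfall from the issue: the loaded instance is returned and then discarded.
model = TinyModel()
model.load_from_checkpoint(checkpoint)
print(model.weight)  # prints 0.0 (still the untrained value)

# Fix from the reply above: call it on the class and keep the result.
model = TinyModel.load_from_checkpoint(checkpoint)
print(model.weight)  # prints 3.5
```

In real code the same pattern applies: assign the return value (model = YourModelClass.load_from_checkpoint(checkpoint_path)) and then call model.eval() before running inference.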

awaelchli closed this as not planned on Aug 14, 2023
awaelchli added the duplicate label and removed the needs triage label on Aug 14, 2023