TypeError when trying to fit a TemporalFusionTransformer model #17458

jsejdija opened this issue Apr 24, 2023 · 3 comments

Labels: bug (Something isn't working), needs triage (Waiting to be triaged by maintainers)


Bug description

Trying to fit a TemporalFusionTransformer raises a TypeError.

What version are you seeing the problem on?


How to reproduce the bug

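trainer = pl.Trainer(
    devices=1, accelerator="gpu",
    # remaining arguments were not included in the report
)

tft = TemporalFusionTransformer.from_dataset(
    # arguments were not included in the report
)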

Error messages and logs

TypeError: `model` must be a `LightningModule` or `torch._dynamo.OptimizedModule`, got `TemporalFusionTransformer`
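
A possible explanation, offered as an assumption rather than a confirmed diagnosis: the environment below lists both `lightning` 2.0.1 and `pytorch-lightning` 2.0.1.post0. These two packages define separate `LightningModule` classes, so a model built on one would fail the `isinstance` check in a `Trainer` imported from the other. The helper below inspects a class's method resolution order and reports which of the two packages it inherits from. The names `lightning_roots` and `same_lightning_package` are hypothetical and not part of either library; only the standard library is used.

```python
def lightning_roots(cls):
    """Return the Lightning top-level packages found in the MRO of cls."""
    known = {"lightning", "pytorch_lightning"}
    roots = set()
    for base in cls.__mro__:
        # The first dotted component of __module__ is the top-level package.
        top = base.__module__.split(".")[0]
        if top in known:
            roots.add(top)
    return roots


def same_lightning_package(model_cls, trainer_cls):
    """True if the model and the trainer share at least one Lightning package."""
    return bool(lightning_roots(model_cls) & lightning_roots(trainer_cls))
```

To check a real setup, one could call `lightning_roots(TemporalFusionTransformer)` and `lightning_roots(type(trainer))` and compare the results. If the two sets do not overlap, the model and the trainer come from different packages.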


Current environment
  • CUDA:
    - GPU:
        - NVIDIA A100-PCIE-40GB
    - available: True
    - version: 11.7
  • Lightning:
    - lightning: 2.0.1
    - lightning-cloud: 0.5.32
    - lightning-utilities: 0.8.0
    - pytorch-forecasting: 1.0.0
    - pytorch-lightning: 2.0.1.post0
    - pytorch-optimizer: 2.5.1
    - torch: 2.0.0
    - torchmetrics: 0.11.4
  • Packages:
    - absl-py: 1.4.0
    - aiohttp: 3.8.4
    - aiosignal: 1.3.1
    - alembic: 1.10.3
    - anyio: 3.6.2
    - argon2-cffi: 21.3.0
    - argon2-cffi-bindings: 21.2.0
    - arrow: 1.2.3
    - asttokens: 2.2.1
    - astunparse: 1.6.3
    - async-timeout: 4.0.2
    - attrs: 22.2.0
    - backcall: 0.2.0
    - backports.functools-lru-cache: 1.6.4
    - beautifulsoup4: 4.12.2
    - bleach: 6.0.0
    - blessed: 1.20.0
    - cachetools: 5.3.0
    - certifi: 2022.12.7
    - cffi: 1.15.1
    - charset-normalizer: 3.1.0
    - click: 8.1.3
    - cmaes: 0.9.1
    - cmake: 3.26.3
    - colorlog: 6.7.0
    - comm: 0.1.3
    - contourpy: 1.0.7
    - convertdate: 2.4.0
    - croniter: 1.3.10
    - cubinlinker: 0.2.2
    - cuda-python: 11.8.1
    - cudf: 23.4.0
    - cupy: 11.6.0
    - cycler: 0.11.0
    - dateutils: 0.6.12
    - debugpy: 1.6.7
    - decorator: 5.1.1
    - deepdiff: 6.3.0
    - defusedxml: 0.7.1
    - dnspython: 2.3.0
    - email-validator: 1.3.1
    - entrypoints: 0.4
    - executing: 1.2.0
    - fastapi: 0.88.0
    - fastavro: 1.7.3
    - fastjsonschema: 2.16.3
    - fastrlock: 0.8
    - filelock: 3.11.0
    - flatbuffers: 23.3.3
    - flit-core: 3.8.0
    - fonttools: 4.39.3
    - frozenlist: 1.3.3
    - fsspec: 2023.4.0
    - gast: 0.4.0
    - google-auth: 2.17.1
    - google-auth-oauthlib: 1.0.0
    - google-pasta: 0.2.0
    - greenlet: 2.0.2
    - grpcio: 1.53.0
    - h11: 0.14.0
    - h5py: 3.8.0
    - hijri-converter: 2.2.4
    - holidays: 0.23
    - httpcore: 0.17.0
    - httptools: 0.5.0
    - httpx: 0.24.0
    - hupper: 1.12
    - idna: 3.4
    - importlib-metadata: 6.6.0
    - importlib-resources: 5.12.0
    - inquirer: 3.1.3
    - ipykernel: 6.22.0
    - ipython: 8.12.0
    - ipython-genutils: 0.2.0
    - ipywidgets: 8.0.6
    - itsdangerous: 2.1.2
    - jedi: 0.18.2
    - jinja2: 3.1.2
    - joblib: 1.2.0
    - jsonschema: 4.17.3
    - jupyter: 1.0.0
    - jupyter-client: 8.2.0
    - jupyter-console: 6.6.3
    - jupyter-core: 5.3.0
    - jupyter-events: 0.6.3
    - jupyter-server: 2.5.0
    - jupyter-server-terminals: 0.4.4
    - jupyterlab-pygments: 0.2.2
    - jupyterlab-widgets: 3.0.7
    - kiwisolver: 1.4.4
    - korean-lunar-calendar: 0.3.1
    - libclang: 16.0.0
    - lightning: 2.0.1
    - lightning-cloud: 0.5.32
    - lightning-utilities: 0.8.0
    - lit: 16.0.1
    - llvmlite: 0.39.1
    - mako: 1.2.4
    - markdown: 3.4.3
    - markdown-it-py: 2.2.0
    - markupsafe: 2.1.2
    - matplotlib: 3.7.1
    - matplotlib-inline: 0.1.6
    - mdurl: 0.1.2
    - mistune: 2.0.5
    - mpmath: 1.3.0
    - multidict: 6.0.4
    - nbclassic: 0.5.5
    - nbclient: 0.7.3
    - nbconvert: 7.3.1
    - nbformat: 5.8.0
    - nest-asyncio: 1.5.6
    - networkx: 3.1
    - notebook: 6.5.4
    - notebook-shim: 0.2.2
    - numba: 0.56.4
    - numpy: 1.23.5
    - nvidia-cublas-cu11:
    - nvidia-cuda-cupti-cu11: 11.7.101
    - nvidia-cuda-nvrtc-cu11: 11.7.99
    - nvidia-cuda-runtime-cu11: 11.7.99
    - nvidia-cudnn-cu11:
    - nvidia-cufft-cu11:
    - nvidia-curand-cu11:
    - nvidia-cusolver-cu11:
    - nvidia-cusparse-cu11:
    - nvidia-nccl-cu11: 2.14.3
    - nvidia-nvtx-cu11: 11.7.91
    - nvtx: 0.2.5
    - oauthlib: 3.2.2
    - opt-einsum: 3.3.0
    - optuna: 3.1.1
    - ordered-set: 4.1.0
    - orjson: 3.8.10
    - packaging: 23.1
    - pandas: 1.5.3
    - pandocfilters: 1.5.0
    - parso: 0.8.3
    - pastedeploy: 3.0.1
    - patsy: 0.5.3
    - pexpect: 4.8.0
    - pickleshare: 0.7.5
    - pillow: 9.5.0
    - pip: 23.1.1
    - pkgutil-resolve-name: 1.3.10
    - plaster: 1.0
    - plaster-pastedeploy: 0.7
    - platformdirs: 3.2.0
    - ply: 3.11
    - prometheus-client: 0.16.0
    - prompt-toolkit: 3.0.38
    - protobuf: 4.21.12
    - psutil: 5.9.5
    - ptxcompiler: 0.7.0
    - ptyprocess: 0.7.0
    - pure-eval: 0.2.2
    - pyarrow: 10.0.1
    - pyasn1: 0.4.8
    - pyasn1-modules: 0.2.8
    - pycparser: 2.21
    - pydantic: 1.10.7
    - pygments: 2.15.1
    - pyjwt: 2.6.0
    - pymeeus: 0.5.12
    - pyparsing: 3.0.9
    - pyqt5: 5.15.7
    - pyqt5-sip: 12.11.0
    - pyramid: 2.0.1
    - pyrsistent: 0.19.3
    - python-dateutil: 2.8.2
    - python-dotenv: 1.0.0
    - python-editor: 1.0.4
    - python-json-logger: 2.0.7
    - python-multipart: 0.0.6
    - pytorch-forecasting: 1.0.0
    - pytorch-lightning: 2.0.1.post0
    - pytorch-optimizer: 2.5.1
    - pytz: 2023.3
    - pyyaml: 6.0
    - pyzmq: 25.0.2
    - qtconsole: 5.4.2
    - qtpy: 2.3.1
    - rapids: 0.0.1
    - readchar: 4.0.5
    - requests: 2.28.2
    - requests-oauthlib: 1.3.1
    - rfc3339-validator: 0.1.4
    - rfc3986-validator: 0.1.1
    - rich: 13.3.3
    - rmm: 23.4.0
    - rsa: 4.9
    - scikit-learn: 1.2.2
    - scipy: 1.10.1
    - send2trash: 1.8.0
    - setuptools: 67.7.1
    - sip: 6.7.9
    - six: 1.16.0
    - sniffio: 1.3.0
    - soupsieve: 2.3.2.post1
    - sqlalchemy: 2.0.9
    - stack-data: 0.6.2
    - starlette: 0.22.0
    - starsessions: 1.3.0
    - statsmodels: 0.13.5
    - sympy: 1.11.1
    - tensorboard: 2.12.2
    - tensorboard-data-server: 0.7.0
    - tensorboard-plugin-wit: 1.8.1
    - tensorflow-io-gcs-filesystem: 0.32.0
    - termcolor: 2.2.0
    - terminado: 0.17.1
    - threadpoolctl: 3.1.0
    - tinycss2: 1.2.1
    - toml: 0.10.2
    - tomli: 2.0.1
    - torch: 2.0.0
    - torchmetrics: 0.11.4
    - tornado: 6.3
    - tqdm: 4.65.0
    - traitlets: 5.9.0
    - translationstring: 1.4
    - triton: 2.0.0
    - typing-extensions: 4.5.0
    - ujson: 5.7.0
    - urllib3: 1.26.15
    - uvicorn: 0.21.1
    - uvloop: 0.17.0
    - venusian: 3.0.0
    - watchfiles: 0.19.0
    - wcwidth: 0.2.6
    - webencodings: 0.5.1
    - webob: 1.8.7
    - websocket-client: 1.5.1
    - websockets: 11.0.1
    - werkzeug: 2.2.3
    - wheel: 0.40.0
    - widgetsnbextension: 4.0.7
    - yarl: 1.8.2
    - zipp: 3.15.0
    - zope.deprecation: 4.4.0
    - zope.interface: 6.0
  • System:
    - OS: Linux
    - architecture:
        - 64bit
        - ELF
    - processor: x86_64
    - python: 3.10.10
    - version: #76-Ubuntu SMP Fri Mar 17 17:19:29 UTC 2023

More info

No response

How can this be fixed? I ran into the same problem.

Was there a solution for this? I am running into the same error

ruuttt commented Dec 22, 2023

You can find a solution here
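
The linked solution is not preserved in this copy of the thread. A fix commonly reported for this error message, stated here as an assumption and not as the content of that link, is to import the `Trainer` from the same package the model class is built on, for example `import lightning.pytorch as pl` instead of `import pytorch_lightning as pl`. The sketch below picks the matching module name from the model's base classes. The function names are hypothetical, and `load_trainer_module` only works if the chosen package is installed.

```python
import importlib


def matching_trainer_module(model_cls):
    """Return the name of the module whose Trainer matches model_cls."""
    for base in model_cls.__mro__:
        top = base.__module__.split(".")[0]
        if top == "lightning":
            return "lightning.pytorch"
        if top == "pytorch_lightning":
            return "pytorch_lightning"
    raise TypeError(
        f"{model_cls.__name__} does not inherit from a Lightning class"
    )


def load_trainer_module(model_cls):
    """Import and return the module selected by matching_trainer_module."""
    return importlib.import_module(matching_trainer_module(model_cls))
```

With this, `pl = load_trainer_module(TemporalFusionTransformer)` followed by `pl.Trainer(devices=1, accelerator="gpu")` would build the trainer from the same package as the model. In ordinary scripts a fixed import line is simpler; the dynamic lookup is only meant to make the mismatch explicit.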

