
TypeError when trying to fit a TemporalFusionTransformer model #17458

Closed
jsejdija opened this issue Apr 24, 2023 · 3 comments
Labels: bug (Something isn't working), needs triage (Waiting to be triaged by maintainers)


@jsejdija

Bug description

When trying to fit a TemporalFusionTransformer model, `trainer.fit` raises a TypeError.

What version are you seeing the problem on?

2.0+

How to reproduce the bug

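import pytorch_lightning as pl  # assumed import; the report does not show its imports
from pytorch_forecasting import TemporalFusionTransformer
from pytorch_forecasting.metrics import SMAPE

# `training` (a TimeSeriesDataSet), `train_dataloader` and `val_dataloader`
# are assumed to be defined earlier in the script.
trainer = pl.Trainer(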
	max_epochs=10,
	devices=1, accelerator="gpu",
	enable_model_summary=True,
	gradient_clip_val=0.25,
	limit_train_batches=10
)

tft = TemporalFusionTransformer.from_dataset(
	training,
	lstm_layers=1,
	hidden_size=16,
	attention_head_size=2,
	dropout=0.2,
	hidden_continuous_size=8,
	output_size=1,
	loss=SMAPE(),
	log_interval=10,
	reduce_on_plateau_patience=4
)

trainer.fit(
	tft,
	train_dataloaders=train_dataloader,
	val_dataloaders=val_dataloader,
)

Error messages and logs

TypeError: `model` must be a `LightningModule` or `torch._dynamo.OptimizedModule`, got `TemporalFusionTransformer`
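A likely cause, assuming the script imports `pytorch_lightning as pl`: pytorch-forecasting 1.0.0 builds its models on `lightning.pytorch`'s `LightningModule`, while `pytorch_lightning.Trainer` checks the model against its own, separate `LightningModule` class. The `isinstance` check therefore fails even though the model is a Lightning module, and both packages are indeed installed in the environment listed below. The usual fix is to import the Trainer from the same namespace as the model, i.e. `import lightning.pytorch as pl`. The sketch below is a hypothetical diagnostic (the helpers `lightning_namespace` and `check_trainer_compat` are not part of either library) that reports which namespace a model's `LightningModule` base class comes from; it uses stand-in classes so it runs without Lightning installed.

```python
from typing import Optional


def lightning_namespace(model: object) -> Optional[str]:
    """Return the top-level package that defines the LightningModule base
    class of `model` (e.g. "lightning" or "pytorch_lightning"), or None if
    no class named LightningModule appears in its MRO."""
    for cls in type(model).__mro__:
        if cls.__name__ == "LightningModule":
            # __module__ looks like "lightning.pytorch.core.module"; keep the root package.
            return cls.__module__.split(".")[0]
    return None


def check_trainer_compat(model: object, trainer_package: str) -> None:
    """Raise a descriptive TypeError if `model` and the Trainer come from
    different Lightning namespaces (the mismatch assumed above)."""
    model_pkg = lightning_namespace(model)
    if model_pkg is not None and model_pkg != trainer_package:
        raise TypeError(
            f"model subclasses {model_pkg}'s LightningModule, but the Trainer "
            f"comes from {trainer_package}; import both from the same package"
        )


if __name__ == "__main__":
    # Hypothetical stand-ins mimicking lightning.pytorch, so no install is needed.
    class LightningModule:
        __module__ = "lightning.pytorch.core.module"

    class TemporalFusionTransformer(LightningModule):
        pass

    tft = TemporalFusionTransformer()
    print(lightning_namespace(tft))  # prints: lightning
    check_trainer_compat(tft, "lightning")  # same namespace, so no error
```

With the real objects, `lightning_namespace(tft)` would be expected to return `"lightning"` for pytorch-forecasting 1.0.0; if it does, building the Trainer with `lightning.pytorch` instead of `pytorch_lightning` should avoid the mismatch.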

Environment

Current environment
  • CUDA:
    - GPU:
    - NVIDIA A100-PCIE-40GB
    - available: True
    - version: 11.7
  • Lightning:
    - lightning: 2.0.1
    - lightning-cloud: 0.5.32
    - lightning-utilities: 0.8.0
    - pytorch-forecasting: 1.0.0
    - pytorch-lightning: 2.0.1.post0
    - pytorch-optimizer: 2.5.1
    - torch: 2.0.0
    - torchmetrics: 0.11.4
  • Packages:
    - absl-py: 1.4.0
    - aiohttp: 3.8.4
    - aiosignal: 1.3.1
    - alembic: 1.10.3
    - anyio: 3.6.2
    - argon2-cffi: 21.3.0
    - argon2-cffi-bindings: 21.2.0
    - arrow: 1.2.3
    - asttokens: 2.2.1
    - astunparse: 1.6.3
    - async-timeout: 4.0.2
    - attrs: 22.2.0
    - backcall: 0.2.0
    - backports.functools-lru-cache: 1.6.4
    - beautifulsoup4: 4.12.2
    - bleach: 6.0.0
    - blessed: 1.20.0
    - cachetools: 5.3.0
    - certifi: 2022.12.7
    - cffi: 1.15.1
    - charset-normalizer: 3.1.0
    - click: 8.1.3
    - cmaes: 0.9.1
    - cmake: 3.26.3
    - colorlog: 6.7.0
    - comm: 0.1.3
    - contourpy: 1.0.7
    - convertdate: 2.4.0
    - croniter: 1.3.10
    - cubinlinker: 0.2.2
    - cuda-python: 11.8.1
    - cudf: 23.4.0
    - cupy: 11.6.0
    - cycler: 0.11.0
    - dateutils: 0.6.12
    - debugpy: 1.6.7
    - decorator: 5.1.1
    - deepdiff: 6.3.0
    - defusedxml: 0.7.1
    - dnspython: 2.3.0
    - email-validator: 1.3.1
    - entrypoints: 0.4
    - executing: 1.2.0
    - fastapi: 0.88.0
    - fastavro: 1.7.3
    - fastjsonschema: 2.16.3
    - fastrlock: 0.8
    - filelock: 3.11.0
    - flatbuffers: 23.3.3
    - flit-core: 3.8.0
    - fonttools: 4.39.3
    - frozenlist: 1.3.3
    - fsspec: 2023.4.0
    - gast: 0.4.0
    - google-auth: 2.17.1
    - google-auth-oauthlib: 1.0.0
    - google-pasta: 0.2.0
    - greenlet: 2.0.2
    - grpcio: 1.53.0
    - h11: 0.14.0
    - h5py: 3.8.0
    - hijri-converter: 2.2.4
    - holidays: 0.23
    - httpcore: 0.17.0
    - httptools: 0.5.0
    - httpx: 0.24.0
    - hupper: 1.12
    - idna: 3.4
    - importlib-metadata: 6.6.0
    - importlib-resources: 5.12.0
    - inquirer: 3.1.3
    - ipykernel: 6.22.0
    - ipython: 8.12.0
    - ipython-genutils: 0.2.0
    - ipywidgets: 8.0.6
    - itsdangerous: 2.1.2
    - jedi: 0.18.2
    - jinja2: 3.1.2
    - joblib: 1.2.0
    - jsonschema: 4.17.3
    - jupyter: 1.0.0
    - jupyter-client: 8.2.0
    - jupyter-console: 6.6.3
    - jupyter-core: 5.3.0
    - jupyter-events: 0.6.3
    - jupyter-server: 2.5.0
    - jupyter-server-terminals: 0.4.4
    - jupyterlab-pygments: 0.2.2
    - jupyterlab-widgets: 3.0.7
    - kiwisolver: 1.4.4
    - korean-lunar-calendar: 0.3.1
    - libclang: 16.0.0
    - lightning: 2.0.1
    - lightning-cloud: 0.5.32
    - lightning-utilities: 0.8.0
    - lit: 16.0.1
    - llvmlite: 0.39.1
    - mako: 1.2.4
    - markdown: 3.4.3
    - markdown-it-py: 2.2.0
    - markupsafe: 2.1.2
    - matplotlib: 3.7.1
    - matplotlib-inline: 0.1.6
    - mdurl: 0.1.2
    - mistune: 2.0.5
    - mpmath: 1.3.0
    - multidict: 6.0.4
    - nbclassic: 0.5.5
    - nbclient: 0.7.3
    - nbconvert: 7.3.1
    - nbformat: 5.8.0
    - nest-asyncio: 1.5.6
    - networkx: 3.1
    - notebook: 6.5.4
    - notebook-shim: 0.2.2
    - numba: 0.56.4
    - numpy: 1.23.5
    - nvidia-cublas-cu11: 11.10.3.66
    - nvidia-cuda-cupti-cu11: 11.7.101
    - nvidia-cuda-nvrtc-cu11: 11.7.99
    - nvidia-cuda-runtime-cu11: 11.7.99
    - nvidia-cudnn-cu11: 8.5.0.96
    - nvidia-cufft-cu11: 10.9.0.58
    - nvidia-curand-cu11: 10.2.10.91
    - nvidia-cusolver-cu11: 11.4.0.1
    - nvidia-cusparse-cu11: 11.7.4.91
    - nvidia-nccl-cu11: 2.14.3
    - nvidia-nvtx-cu11: 11.7.91
    - nvtx: 0.2.5
    - oauthlib: 3.2.2
    - opt-einsum: 3.3.0
    - optuna: 3.1.1
    - ordered-set: 4.1.0
    - orjson: 3.8.10
    - packaging: 23.1
    - pandas: 1.5.3
    - pandocfilters: 1.5.0
    - parso: 0.8.3
    - pastedeploy: 3.0.1
    - patsy: 0.5.3
    - pexpect: 4.8.0
    - pickleshare: 0.7.5
    - pillow: 9.5.0
    - pip: 23.1.1
    - pkgutil-resolve-name: 1.3.10
    - plaster: 1.0
    - plaster-pastedeploy: 0.7
    - platformdirs: 3.2.0
    - ply: 3.11
    - prometheus-client: 0.16.0
    - prompt-toolkit: 3.0.38
    - protobuf: 4.21.12
    - psutil: 5.9.5
    - ptxcompiler: 0.7.0
    - ptyprocess: 0.7.0
    - pure-eval: 0.2.2
    - pyarrow: 10.0.1
    - pyasn1: 0.4.8
    - pyasn1-modules: 0.2.8
    - pycparser: 2.21
    - pydantic: 1.10.7
    - pygments: 2.15.1
    - pyjwt: 2.6.0
    - pymeeus: 0.5.12
    - pyparsing: 3.0.9
    - pyqt5: 5.15.7
    - pyqt5-sip: 12.11.0
    - pyramid: 2.0.1
    - pyrsistent: 0.19.3
    - python-dateutil: 2.8.2
    - python-dotenv: 1.0.0
    - python-editor: 1.0.4
    - python-json-logger: 2.0.7
    - python-multipart: 0.0.6
    - pytorch-forecasting: 1.0.0
    - pytorch-lightning: 2.0.1.post0
    - pytorch-optimizer: 2.5.1
    - pytz: 2023.3
    - pyyaml: 6.0
    - pyzmq: 25.0.2
    - qtconsole: 5.4.2
    - qtpy: 2.3.1
    - rapids: 0.0.1
    - readchar: 4.0.5
    - requests: 2.28.2
    - requests-oauthlib: 1.3.1
    - rfc3339-validator: 0.1.4
    - rfc3986-validator: 0.1.1
    - rich: 13.3.3
    - rmm: 23.4.0
    - rsa: 4.9
    - scikit-learn: 1.2.2
    - scipy: 1.10.1
    - send2trash: 1.8.0
    - setuptools: 67.7.1
    - sip: 6.7.9
    - six: 1.16.0
    - sniffio: 1.3.0
    - soupsieve: 2.3.2.post1
    - sqlalchemy: 2.0.9
    - stack-data: 0.6.2
    - starlette: 0.22.0
    - starsessions: 1.3.0
    - statsmodels: 0.13.5
    - sympy: 1.11.1
    - tensorboard: 2.12.2
    - tensorboard-data-server: 0.7.0
    - tensorboard-plugin-wit: 1.8.1
    - tensorflow-io-gcs-filesystem: 0.32.0
    - termcolor: 2.2.0
    - terminado: 0.17.1
    - threadpoolctl: 3.1.0
    - tinycss2: 1.2.1
    - toml: 0.10.2
    - tomli: 2.0.1
    - torch: 2.0.0
    - torchmetrics: 0.11.4
    - tornado: 6.3
    - tqdm: 4.65.0
    - traitlets: 5.9.0
    - translationstring: 1.4
    - triton: 2.0.0
    - typing-extensions: 4.5.0
    - ujson: 5.7.0
    - urllib3: 1.26.15
    - uvicorn: 0.21.1
    - uvloop: 0.17.0
    - venusian: 3.0.0
    - watchfiles: 0.19.0
    - wcwidth: 0.2.6
    - webencodings: 0.5.1
    - webob: 1.8.7
    - websocket-client: 1.5.1
    - websockets: 11.0.1
    - werkzeug: 2.2.3
    - wheel: 0.40.0
    - widgetsnbextension: 4.0.7
    - yarl: 1.8.2
    - zipp: 3.15.0
    - zope.deprecation: 4.4.0
    - zope.interface: 6.0
  • System:
    - OS: Linux
    - architecture:
    - 64bit
    - ELF
    - processor: x86_64
    - python: 3.10.10
    - version: #76-Ubuntu SMP Fri Mar 17 17:19:29 UTC 2023
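The environment above lists both `lightning` 2.0.1 and `pytorch-lightning` 2.0.1.post0, which is what makes two distinct `LightningModule` classes available. A minimal sketch for checking this on another machine, using only the standard library's `importlib.metadata`; the helper names `installed_versions` and `has_both_lightning_namespaces` are hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional

# Distributions relevant to this issue (names as they appear in the report).
DISTRIBUTIONS = ("lightning", "pytorch-lightning", "pytorch-forecasting", "torch")


def installed_versions(names: Iterable[str] = DISTRIBUTIONS) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if absent."""
    result: Dict[str, Optional[str]] = {}
    for name in names:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


def has_both_lightning_namespaces(versions: Dict[str, Optional[str]]) -> bool:
    """True when both the unified `lightning` package and the standalone
    `pytorch-lightning` package are installed."""
    return (
        versions.get("lightning") is not None
        and versions.get("pytorch-lightning") is not None
    )


if __name__ == "__main__":
    found = installed_versions()
    for name, ver in found.items():
        print(f"{name}: {ver or 'not installed'}")
    if has_both_lightning_namespaces(found):
        print("Both Lightning namespaces are installed; import Trainer and model from the same one.")
```

Having both installed is not an error by itself; it only matters when the Trainer and the model are imported from different namespaces.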

More info

No response

@jsejdija added the bug (Something isn't working) and needs triage (Waiting to be triaged by maintainers) labels on Apr 24, 2023
@arrow9577

How do I fix this? I ran into the same problem.

@hgersten5

Was there a solution for this? I am running into the same error.

@ruuttt

ruuttt commented Dec 22, 2023

You can find a solution here


4 participants