
TorchForecastingModel model precision #1987

Closed
evanwrm opened this issue Sep 8, 2023 · 3 comments · Fixed by #2046

evanwrm commented Sep 8, 2023

I had a few questions around precision for TorchForecastingModels:

  • It seems we can only choose between float64 and float32. Could we use half precision (float16 or bfloat16)? Setting it through pl_trainer_kwargs appears to be discouraged (see the sketch below).
  • When loading a model, the precision seems to be reset to float64 (visible in train_sample), regardless of the dtype of the training series.
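
For reference, this is roughly the configuration I mean (a sketch only; "16-mixed" is Lightning's spelling for mixed fp16 and is not something darts currently accepts):

from darts.models import TFTModel

# sketch: forwarding Lightning's precision flag through pl_trainer_kwargs;
# darts currently raises on precisions other than 32/64, hence this issue
model = TFTModel(
    input_chunk_length=4,
    output_chunk_length=4,
    pl_trainer_kwargs={"precision": "16-mixed"},
)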

Context:

import torch
import numpy as np

from darts import models
from darts.utils.likelihood_models import QuantileRegression
from darts.datasets import AirPassengersDataset

# load a small example dataset
series_air = AirPassengersDataset().load().astype(np.float32)  # can we use float16?

# use every available GPU when CUDA is present, otherwise fall back to CPU
tft_accelerator = (
    {"accelerator": "gpu", "devices": -1}
    if torch.cuda.is_available()
    else {"accelerator": "cpu"}
)
def create_model(model_name: str):
    return models.TFTModel(
        input_chunk_length=4,
        output_chunk_length=4,
        categorical_embedding_sizes={}, # <-- error if not set + loading weights
        likelihood=QuantileRegression(), # <-- error if not set + loading weights
        n_epochs=10,
        model_name=model_name,
        work_dir="models",
        force_reset=True,
        save_checkpoints=True,
        pl_trainer_kwargs={**tft_accelerator},
    )

model = create_model("TFTModel")
model.fit(
    [series_air, series_air], 
    past_covariates=[series_air, series_air],
    future_covariates=[series_air, series_air],
)

model = create_model("TFTModel_finetune")
model.load_weights_from_checkpoint( # <-- train_sample now float64 regardless of series dtype
    model_name="TFTModel",
    work_dir="models",
    best=False,
)

model.predict(
    n=4,
    series=series_air[:-4], 
    past_covariates=series_air,
    future_covariates=series_air,
)
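
To make the loading issue concrete, inspecting the model's internal train_sample after load_weights_from_checkpoint shows the promotion (a diagnostic sketch; train_sample is a private attribute, so its exact layout may vary):

# the training series were float32, yet the restored samples come back float64
print([s.dtype for s in model.train_sample if s is not None])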

System:
Python version: 3.11.4
darts version: 0.25.0

evanwrm added the triage label on Sep 8, 2023
Eliotdoesprogramming (Contributor) commented Sep 9, 2023

Precision will be limited by what Lightning supports. AFAIK half precision IS supported, but only on a GPU that supports it.

https://pytorch-lightning.readthedocs.io/en/1.8.6/guides/speed.html#mixed-precision-16-bit-training
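
For example, in plain Lightning (independent of darts; this needs a CUDA GPU with fp16 support, and on Lightning 2.x the flag is spelled precision="16-mixed"):

from pytorch_lightning import Trainer

# 16-bit mixed-precision trainer; raises at construction without a usable GPU
trainer = Trainer(accelerator="gpu", devices=1, precision=16)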

evanwrm (Author) commented Sep 9, 2023

I suppose my request was for darts to display a warning instead of raising an error when precision is set this way (e.g. #860), along the lines of the sketch below.
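
Something like this hypothetical check (not darts' current code; the allow-list below is made up for illustration):

import warnings

precision = "16-mixed"
supported = {"32", "64"}  # hypothetical allow-list
if precision not in supported:
    # warn and continue instead of raising
    warnings.warn(f"precision={precision!r} is untested with darts; proceeding anyway.")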

To clarify my use case:
I have a moderately sized dataset (~5M samples). With a batch size of 64, a TFT model, and an RTX 3080 (Ampere), I get roughly:

  • fp64: ~3it/s
  • fp32: ~10it/s

I was hoping to keep working within darts while using mixed-precision training :)
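
(Aside, and not a darts API: on Ampere cards, enabling TF32 matmuls can recover part of that gap while keeping float32 weights; a one-line PyTorch sketch:)

import torch

# allow TF32 tensor cores for float32 matmuls on Ampere and newer GPUs
torch.set_float32_matmul_precision("high")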

Hsinfu (Contributor) commented Nov 6, 2023

The precision bug could be fixed by #2046.
