I had a few questions around precision for TorchForecastingModels:
1. It seems we can only choose between float64 and float32. Could we use half precision (float16 or bfloat16)? Setting it through `pl_trainer_kwargs` seems to be discouraged (see the first sketch after the repro below).
2. When loading the model, the precision seems to be reset to float64 (`train_sample`), regardless of the series dtype (see the second sketch below).
Context:
```python
import torch
import numpy as np

from darts import models
from darts.utils.likelihood_models import QuantileRegression
from darts.datasets import AirPassengersDataset

# create fake dataset
series_air = AirPassengersDataset().load().astype(np.float32)  # can we use float16?

tft_accelerator = (
    {"accelerator": "gpu", "devices": -1}
    if torch.cuda.is_available()
    else {"accelerator": "cpu"}
)


def create_model(model_name: str):
    return models.TFTModel(
        input_chunk_length=4,
        output_chunk_length=4,
        categorical_embedding_sizes={},  # <-- error if not set + loading weights
        likelihood=QuantileRegression(),  # <-- error if not set + loading weights
        n_epochs=10,
        model_name=model_name,
        work_dir="models",
        force_reset=True,
        save_checkpoints=True,
        pl_trainer_kwargs={**tft_accelerator},
    )


model = create_model("TFTModel")
model.fit(
    [series_air, series_air],
    past_covariates=[series_air, series_air],
    future_covariates=[series_air, series_air],
)

model = create_model("TFTModel_finetune")
model.load_weights_from_checkpoint(  # <-- train_sample now float64 regardless of series dtype
    model_name="TFTModel",
    work_dir="models",
    best=False,
)

model.predict(
    n=4,
    series=series_air[:-4],
    past_covariates=series_air,
    future_covariates=series_air,
)
```
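For question 1, this is what I mean by setting precision through `pl_trainer_kwargs` (a sketch only: `"16-mixed"` / `"bf16-mixed"` are the PyTorch Lightning >= 2.0 spellings of the `precision` flag, older versions use `16` / `"bf16"`; darts seems to derive precision from the series dtype, so this may be rejected):

```python
# Sketch: forward Lightning's `precision` flag through pl_trainer_kwargs.
# series_air is float32, so darts may refuse half precision here.
model_fp16 = models.TFTModel(
    input_chunk_length=4,
    output_chunk_length=4,
    n_epochs=10,
    model_name="TFTModel_fp16",
    work_dir="models",
    force_reset=True,
    pl_trainer_kwargs={
        **tft_accelerator,
        "precision": "16-mixed",  # or "bf16-mixed" on supported hardware
    },
)
model_fp16.fit(series_air)  # <-- presumably fails/warns here, since darts
                            #     infers precision from the series dtype
```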
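And for question 2, this is how I'm observing the dtype reset after loading the weights (`train_sample` is an internal attribute of `TorchForecastingModel`, so there may be a better way to inspect this):

```python
# After load_weights_from_checkpoint above, the stored training sample is
# float64 even though series_air was created as float32:
print([x.dtype for x in model.train_sample if x is not None])
# -> float64 entries; I expected float32
```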
System:
Python version: 3.11.4
darts version: 0.25.0