diff --git a/scripts/backtest_sites.py b/scripts/backtest_sites.py index ca6c5313..87048b9c 100644 --- a/scripts/backtest_sites.py +++ b/scripts/backtest_sites.py @@ -70,9 +70,9 @@ # checkpoint on the val set model_chckpoint_dir = "PLACEHOLDER" -revision = None -token = None -model_id = None +hf_revision = None +hf_token = None +hf_model_id = None # Forecasts will be made for all available init times between these start_datetime = "2022-05-08 00:00" @@ -477,9 +477,9 @@ def main(config: DictConfig): # Create a dataloader for the concurrent batches and use multiprocessing dataloader = DataLoader(batch_pipe, **dataloader_kwargs) # Load the PVNet model - if model_chckpoint_dir is not None: + if model_chckpoint_dir: model, *_ = get_model_from_checkpoints([model_chckpoint_dir], val_best=True) - elif model_id is not None: + elif model_id: model = load_model_from_hf(model_id, revision, token) else: raise ValueError("Provide a model checkpoint or a HuggingFace model")