diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index 7c814ea3..db56946f 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -62,8 +62,8 @@ def make_clean_data_config(input_path, output_path, placeholder="PLACEHOLDER"): if "nwp" in config["input_data"]: for source in config["input_data"]["nwp"]: - if config["input_data"]["nwp"][source][f"nwp_zarr_path"] != "": - config["input_data"]["nwp"][source][f"nwp_zarr_path"] = f"{placeholder}.zarr" + if config["input_data"]["nwp"][source]["nwp_zarr_path"] != "": + config["input_data"]["nwp"][source]["nwp_zarr_path"] = f"{placeholder}.zarr" if "pv" in config["input_data"]: for d in config["input_data"]["pv"]["pv_files_groups"]: diff --git a/pvnet/models/multimodal/multimodal.py b/pvnet/models/multimodal/multimodal.py index 86254c8b..f4145649 100644 --- a/pvnet/models/multimodal/multimodal.py +++ b/pvnet/models/multimodal/multimodal.py @@ -343,7 +343,11 @@ def forward(self, x): if self.include_sun: sun = torch.cat( - (x[BatchKey.gsp_solar_azimuth], x[BatchKey.gsp_solar_elevation]), dim=1 + ( + x[BatchKey[f"{self.target_key_name}_solar_azimuth"]], + x[BatchKey[f"{self.target_key_name}_solar_elevation"]], + ), + dim=1, ).float() sun = self.sun_fc1(sun) modes["sun"] = sun