diff --git a/india_forecast_app/models/pvnet/model.py b/india_forecast_app/models/pvnet/model.py index 8f5f538..9f2d2b3 100644 --- a/india_forecast_app/models/pvnet/model.py +++ b/india_forecast_app/models/pvnet/model.py @@ -95,8 +95,9 @@ def predict(self, site_id: str, timestamp: dt.datetime): # Run batch through model device_batch = copy_batch_to_device(batch_to_tensor(batch), DEVICE) preds = self.model(device_batch).detach().cpu().numpy() + # filter out night time - if self.asset_type == SiteAssetType.pv: + if self.asset_type == SiteAssetType.pv.name: preds = set_night_time_zeros(batch, preds) # Store predictions @@ -173,17 +174,12 @@ def predict(self, site_id: str, timestamp: dt.datetime): 0.0, ] log.debug(f"Previous values are {values_df['forecast_power_kw']}") - zero_values = values_df["forecast_power_kw"] == 0 for idx in range(8): values_df["forecast_power_kw"][idx] -= ( values_df["forecast_power_kw"][idx] - final_gen_points ) * smooth_values[final_gen_index + idx] log.debug(f"New values are {values_df['forecast_power_kw']}") - if self.asset_type == "solar": - # make sure previous zero values are still zero - values_df["forecast_power_kw"][zero_values] = 0 - if self.asset_type == "wind": # Smooth with a 1 hour rolling window # Only smooth the wind else we introduce too much of a lag in the solar @@ -244,9 +240,7 @@ def _prepare_data_sources(self): # if generation_da is still empty make nans if len(generation_da) == 0: cols = [str(col) for col in self.generation_data["data"].columns] - generation_df = pd.DataFrame( - index=forecast_timesteps, columns=cols, data=0.0001 - ) + generation_df = pd.DataFrame(index=forecast_timesteps, columns=cols, data=0.0001) generation_da = generation_df.to_xarray() generation_da.to_netcdf(wind_netcdf_path, engine="h5netcdf") diff --git a/india_forecast_app/models/pvnet/utils.py b/india_forecast_app/models/pvnet/utils.py index e4c9448..bce9158 100644 --- a/india_forecast_app/models/pvnet/utils.py +++ b/india_forecast_app/models/pvnet/utils.py @@ -51,7 +51,7 @@ def populate_data_config_sources(input_path, output_path): "wind": {"filename": wind_netcdf_path, "metadata_filename": wind_metadata_path}, "pv": {"filename": pv_netcdf_path, "metadata_filename": pv_metadata_path}, "nwp": {"ecmwf": nwp_ecmwf_path, "gfs": nwp_gfs_path}, - "satellite": {"filepath": satellite_path} + "satellite": {"filepath": satellite_path}, } if "nwp" in config["input_data"]: @@ -89,7 +89,7 @@ def populate_data_config_sources(input_path, output_path): def process_and_cache_nwp(source_nwp_path: str, dest_nwp_path: str): """Reads zarr file, renames t variable to t2m and saves zarr to new destination""" - log.info(f'Processing and caching NWP data for {source_nwp_path}') + log.info(f"Processing and caching NWP data for {source_nwp_path}") # Load dataset from source ds = xr.open_zarr(source_nwp_path) @@ -105,7 +105,7 @@ def process_and_cache_nwp(source_nwp_path: str, dest_nwp_path: str): is_gfs = "gfs" in source_nwp_path.lower() - if not is_gfs: # this is for ECMWF NWP + if not is_gfs: # this is for ECMWF NWP # Rename t variable to t2m variables = list(ds.variable.values) new_variables = [] @@ -119,12 +119,12 @@ def process_and_cache_nwp(source_nwp_path: str, dest_nwp_path: str): else: new_variables.append(var) ds.__setitem__("variable", new_variables) - + # Hack to resolve some NWP data format differences between providers elif is_gfs: data_var = ds[list(ds.data_vars.keys())[0]] # # Use .to_dataset() to split the data variable based on 'variable' dim - ds = data_var.to_dataset(dim='variable') + ds = data_var.to_dataset(dim="variable") ds = ds.rename({"t2m": "t"}) # Save destination path ds.to_zarr(dest_nwp_path, mode="a") @@ -136,18 +136,23 @@ def download_satellite_data(satellite_source_file_path: str) -> None: # download satellite data fs = fsspec.open(satellite_source_file_path).fs if fs.exists(satellite_source_file_path): - log.info(f"Downloading satellite data from {satellite_source_file_path} " - f"to sat_15_min.zarr.zip") + log.info( + f"Downloading satellite data from {satellite_source_file_path} " + f"to sat_15_min.zarr.zip" + ) fs.get(satellite_source_file_path, "sat_15_min.zarr.zip") log.info(f"Unzipping sat_15_min.zarr.zip to {satellite_path}") os.system(f"unzip -qq sat_15_min.zarr.zip -d {satellite_path}") else: log.error(f"Could not find satellite data at {satellite_source_file_path}") + def set_night_time_zeros(batch, preds, sun_elevation_limit=0.0): """ Set all predictions to zero for night time values """ + + log.debug("Setting night time values to zero") # get sun elevation values and if less 0, set to 0 if BatchKey.wind_solar_elevation in batch.keys(): key = BatchKey.wind_solar_elevation @@ -167,7 +172,8 @@ def set_night_time_zeros(batch, preds, sun_elevation_limit=0.0): sun_elevation = sun_elevation.detach().cpu().numpy() # expand dimension from (1,197) to (1,197,7), 7 is due to the number plevels - sun_elevation = np.repeat(sun_elevation[:, :, np.newaxis], 7, axis=2) + n_plevels = preds.shape[2] + sun_elevation = np.repeat(sun_elevation[:, :, np.newaxis], n_plevels, axis=2) # only take future time steps sun_elevation = sun_elevation[:, batch[t0_key] + 1 :, :] preds[sun_elevation < sun_elevation_limit] = 0