diff --git a/india_forecast_app/models/all_models.yaml b/india_forecast_app/models/all_models.yaml
index 96cf09d..cf6538b 100644
--- a/india_forecast_app/models/all_models.yaml
+++ b/india_forecast_app/models/all_models.yaml
@@ -32,6 +32,12 @@ models:
     version: d71104620f0b0bdd3eeb63cafecd2a49032ae0f7
     client: ruvnl
     asset_type: pv
+  - name: pvnet_india_ecmwf_mo_gfs
+    type: pvnet
+    id: openclimatefix/pvnet_india
+    version: 7a179d1f8349d99cb2a30a48e5be17d59b6f2b16
+    client: ruvnl
+    asset_type: pv
   # Ad client solar
   - name: pvnet_ad_sites
     type: pvnet
diff --git a/india_forecast_app/models/pvnet/model.py b/india_forecast_app/models/pvnet/model.py
index 63dcc0c..f540356 100644
--- a/india_forecast_app/models/pvnet/model.py
+++ b/india_forecast_app/models/pvnet/model.py
@@ -380,10 +380,23 @@ def _create_dataloader(self):
            )
        else:
+
+            # This is a bit of a hack for ocf_datapipes:
+            # the normalisation constants differ between
+            # the first RUVNL PV model, this second RUVNL
+            # PV model, and the AD models. When moving to
+            # ocf-data-sampler, we should think carefully about how this is done.
+            if self.name == 'pvnet_india_ecmwf_mo_gfs':
+                new_normalisation_constants = True
+            else:
+                new_normalisation_constants = False
+            log.debug(f"Using new normalisation constants: {new_normalisation_constants}")
+
            base_datapipe_dict = pv_base_pipeline(
                config_filename=populated_data_config_filename,
                location_pipe=location_pipe,
                t0_datapipe=t0_datapipe,
+                new_normalisation_constants=new_normalisation_constants,
            )

            base_datapipe = DictDatasetIterDataPipe(
diff --git a/india_forecast_app/models/pvnet/utils.py b/india_forecast_app/models/pvnet/utils.py
index 9dfbb89..f367edf 100644
--- a/india_forecast_app/models/pvnet/utils.py
+++ b/india_forecast_app/models/pvnet/utils.py
@@ -156,9 +156,11 @@ def process_and_cache_nwp(nwp_config: NWPProcessAndCacheConfig):

    if nwp_config.source == "mo_global":

+        # Commented this out for the moment, as different models use different MO global variables.
        # only select the variables we need
-        nwp_channels = list(nwp_config.config.nwp_channels)
-        ds = ds.sel(variable=nwp_channels)
+        # nwp_channels = list(nwp_config.config.nwp_channels)
+        # log.info(f"Selecting NWP channels {nwp_channels} for mo_global data")
+        # ds = ds.sel(variable=nwp_channels)

        # get directory of file
        regrid_coords = os.path.dirname(nwp.__file__)
diff --git a/poetry.lock b/poetry.lock
index 967065d..4338817 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
 [[package]]
 name = "absl-py"
@@ -3798,13 +3798,13 @@ numpy = "*"

 [[package]]
 name = "ocf-datapipes"
-version = "3.3.52"
+version = "3.3.55"
 description = "Pytorch Datapipes built for use in Open Climate Fix's forecasting work"
 optional = false
 python-versions = "*"
 files = [
-    {file = "ocf_datapipes-3.3.52-py3-none-any.whl", hash = "sha256:889d5d0a0d91868783a85c414a9860d721305942a57ce5713c7c14fb38e9cf6f"},
-    {file = "ocf_datapipes-3.3.52.tar.gz", hash = "sha256:eaf149ad672d423595afc45c5d58bbac1e537d1714e388ecf702d477754de61e"},
+    {file = "ocf_datapipes-3.3.55-py3-none-any.whl", hash = "sha256:d4712aef35a13974eccbbe7aca420de318fc84f0112a8bb05729f5a610965d49"},
+    {file = "ocf_datapipes-3.3.55.tar.gz", hash = "sha256:1a9db14fffcdc902ea4f649d5d2ed47d33f5ee80b9803f44278abf142a530ae3"},
 ]

 [package.dependencies]
@@ -3818,18 +3818,18 @@ fsspec = "*"
 geopandas = "*"
 gitpython = "*"
 h5netcdf = "*"
-nowcasting-datamodel = ">=1.5.30"
+nowcasting_datamodel = ">=1.5.30"
 numpy = "*"
-ocf-blosc2 = "*"
+ocf_blosc2 = "*"
 pandas = "*"
 pathy = "*"
 pvlib = "*"
 pvlive-api = "*"
-pyaml-env = "*"
+pyaml_env = "*"
 pydantic = "*"
 pyproj = "*"
 pyresample = "*"
-pytorch-lightning = "*"
+pytorch_lightning = "*"
 rioxarray = "*"
 scipy = "*"
 torch = ">=2.0.0,<2.5.0"
@@ -7017,4 +7017,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools",

 [metadata]
 lock-version = "2.0"
 python-versions = "^3.11"
-content-hash = "ce3e5b7c5eb197d7e643996fe596b70ba281397e28acbfbe5f70df2f51f915dd"
+content-hash = "bb7863a314a25e025cdf8c88a97eeae78793cf0aa71834cd2121218ce0f37b47"
diff --git a/pyproject.toml b/pyproject.toml
index 91dc56d..e876817 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -14,7 +14,7 @@ pvnet = "3.0.64"
 pytz = "^2024.1"
 numpy = "^1.26.4"
 huggingface-hub = "0.20.3"
-ocf-datapipes = "3.3.52"
+ocf-datapipes = "3.3.55"
 pyogrio = "0.8.0" # 0.9.0 seems to cause an error at the moment
 torch = [
     {url="https://download.pytorch.org/whl/cpu/torch-2.2.1%2Bcpu-cp311-cp311-linux_x86_64.whl", markers="platform_system == \"Linux\" and platform_machine == \"x86_64\""},
diff --git a/tests/conftest.py b/tests/conftest.py
index 8291877..87a35e5 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -166,7 +166,7 @@ def generation_db_values(db_session, sites, init_timestamp):
 def generation_db_values_only_wind(db_session, sites, init_timestamp):
     """Create some fake generations"""

-    n = 20*25  # 25 hours of readings
+    n = 20 * 25  # 25 hours of readings
     start_times = [init_timestamp - dt.timedelta(minutes=x * 3) for x in range(n)]

     # remove some of the most recent readings (to simulate missing timestamps)
@@ -361,12 +361,9 @@ def nwp_mo_global_data(tmp_path_factory, time_before_present):

     # Load dataset which only contains coordinates, but no data
     ds = xr.open_zarr(
-        f"{os.path.dirname(os.path.abspath(__file__))}/test_data/nwp-no-data_gfs.zarr"
+        f"{os.path.dirname(os.path.abspath(__file__))}/test_data/nwp-no-data.zarr"
     )

-    # rename dimension init_time_utc to init_time
-    ds = ds.rename({"init_time_utc": "init_time"})
-
     # Last t0 to at least 4 hours ago and floor to 3-hour interval
     t0_datetime_utc = time_before_present(dt.timedelta(hours=0)).floor("3h")
     t0_datetime_utc = t0_datetime_utc - dt.timedelta(hours=4)
@@ -389,10 +386,21 @@ def nwp_mo_global_data(tmp_path_factory, time_before_present):
         ds[v].encoding.clear()

     # change variables values to for MO global
-    ds.variable.values[0:3] = ["temperature_sl", "wind_u_component_10m", "wind_v_component_10m"]
+    ds.variable.values[0:10] = [
+        "temperature_sl",
"wind_u_component_10m", + "wind_v_component_10m", + "downward_shortwave_radiation_flux_gl", + "cloud_cover_high", + "cloud_cover_low", + "cloud_cover_medium", + "relative_humidity_sl", + "snow_depth_gl", + "visibility_sl", + ] # interpolate 3 hourly step to 1 hour steps - steps = pd.TimedeltaIndex(np.arange(49) * 3600 * 1e9, freq='infer') + steps = pd.TimedeltaIndex(np.arange(49) * 3600 * 1e9, freq="infer") ds = ds.interp(step=steps, method="linear") ds["mo_global"] = xr.DataArray( diff --git a/tests/models/test_pydantic_models.py b/tests/models/test_pydantic_models.py index 3f23330..818a646 100644 --- a/tests/models/test_pydantic_models.py +++ b/tests/models/test_pydantic_models.py @@ -5,7 +5,7 @@ def test_get_all_models(): """Test for getting all models""" models = get_all_models() - assert len(models.models) == 10 + assert len(models.models) == 11 def test_get_all_models_client(): diff --git a/tests/test_app.py b/tests/test_app.py index 9623732..8347a07 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -161,9 +161,9 @@ def test_app( assert result.exit_code == 0 if write_to_db: - assert db_session.query(ForecastSQL).count() == init_n_forecasts + 5 * 2 - assert db_session.query(ForecastValueSQL).count() == init_n_forecast_values + (5 * 2 * 192) - assert db_session.query(MLModelSQL).count() == 5 * 2 + assert db_session.query(ForecastSQL).count() == init_n_forecasts + 6 * 2 + assert db_session.query(ForecastValueSQL).count() == init_n_forecast_values + (6 * 2 * 192) + assert db_session.query(MLModelSQL).count() == 6 * 2 else: assert db_session.query(ForecastSQL).count() == init_n_forecasts assert db_session.query(ForecastValueSQL).count() == init_n_forecast_values @@ -183,8 +183,8 @@ def test_app_no_pv_data( result = run_click_script(app, args) assert result.exit_code == 0 - assert db_session.query(ForecastSQL).count() == init_n_forecasts + 2 * 5 - assert db_session.query(ForecastValueSQL).count() == init_n_forecast_values + (2 * 5 * 192) + assert db_session.query(ForecastSQL).count() == init_n_forecasts + 2 * 6 + assert db_session.query(ForecastValueSQL).count() == init_n_forecast_values + (2 * 6 * 192) @pytest.mark.requires_hf_token