Commit
NWP test data loading correctly
confusedmatrix committed Feb 14, 2024
1 parent 0b758d4 commit a0f0df3
Showing 7 changed files with 87 additions and 48 deletions.
33 changes: 20 additions & 13 deletions india_forecast_app/models/pvnet/model.py
@@ -14,10 +14,10 @@
import torch
from ocf_datapipes.batch import stack_np_examples_into_batch
from ocf_datapipes.training.pvnet import construct_sliced_data_pipeline as pv_base_pipeline
-from ocf_datapipes.training.windnet import DictDatasetIterDataPipe
+from ocf_datapipes.training.windnet import DictDatasetIterDataPipe, split_dataset_dict_dp
from ocf_datapipes.training.windnet import construct_sliced_data_pipeline as wind_base_pipeline
from ocf_datapipes.utils import Location
-from ocf_datapipes.utils.utils import combine_to_single_dataset
+from ocf_datapipes.utils.utils import combine_to_single_dataset, uncombine_from_single_dataset

Check failure on line 20 (GitHub Actions / lint_and_test): Ruff (F401) india_forecast_app/models/pvnet/model.py:20:39: `ocf_datapipes.utils.utils.combine_to_single_dataset` imported but unused
Check failure on line 20 (GitHub Actions / lint_and_test): Ruff (F401) india_forecast_app/models/pvnet/model.py:20:66: `ocf_datapipes.utils.utils.uncombine_from_single_dataset` imported but unused
from pvnet.data.utils import batch_to_tensor, copy_batch_to_device
from pvnet.models.base_model import BaseModel as PVNetBaseModel
from torch.utils.data import DataLoader
@@ -39,7 +39,7 @@
PV_MODEL_VERSION = os.getenv("PV_MODEL_VERSION",
default="d194488203375e766253f0d2961010356de52eb9")

-BATCH_SIZE = 10
+BATCH_SIZE = 1

log = logging.getLogger(__name__)

@@ -164,10 +164,8 @@ def _create_dataloader(self):
populate_data_config_sources(data_config_filename, populated_data_config_filename)

# Location and time datapipes
-# TODO not sure if this is the correct way to set these up...
location_pipe = IterableWrapper([Location(coordinate_system="lon_lat", x=72.6399, y=26.4499)])

Check failure on line 167 (GitHub Actions / lint_and_test): Ruff (E501) india_forecast_app/models/pvnet/model.py:167:101: Line too long (102 > 100)
t0_datapipe = IterableWrapper([self.t0])
-# t0_datapipe = IterableWrapper([self.t0]).repeat(len(location_pipe))

location_pipe = location_pipe.sharding_filter()
t0_datapipe = t0_datapipe.sharding_filter()
@@ -181,14 +179,24 @@ def _create_dataloader(self):
t0_datapipe=t0_datapipe
)
)
-log.info(next(iter(base_datapipe_dict["nwp"]["ecmwf"])))

-# TODO figure out why this is an empty dataset
-log.info(next(iter(base_datapipe_dict["wind"])).to_pandas())

-base_datapipe = DictDatasetIterDataPipe(
+base_datapipe = (DictDatasetIterDataPipe(
{k: v for k, v in base_datapipe_dict.items() if k != "config"},
-).map(combine_to_single_dataset)
+)
+.map(split_dataset_dict_dp))

+# log.info(next(iter(next(iter(base_datapipe))["nwp"]['ecmwf'])))
+# log.info(next(iter(next(iter(base_datapipe))["wind"])))

+batch_datapipe = (
+base_datapipe
+.windnet_convert_to_numpy_batch()
+.batch(BATCH_SIZE)
+.map(stack_np_examples_into_batch)
+.map(batch_to_tensor)
+)

+# log.info(next(iter(batch_datapipe)))
else:
base_datapipe = (
pv_base_pipeline(
@@ -198,8 +206,7 @@ def _create_dataloader(self):
production=True
)
)

-batch_datapipe = base_datapipe.batch(BATCH_SIZE).map(stack_np_examples_into_batch)
+    batch_datapipe = base_datapipe.batch(BATCH_SIZE).map(stack_np_examples_into_batch)

n_workers = os.cpu_count() - 1

4 changes: 2 additions & 2 deletions india_forecast_app/models/pvnet/utils.py
@@ -36,7 +36,7 @@ def populate_data_config_sources(input_path, output_path):
production_paths = {
"wind": {
"filename": wind_netcdf_path,
"wind_metadata_filename": wind_metadata_path
"metadata_filename": wind_metadata_path
},
"nwp": {
"ecmwf": nwp_path
@@ -56,7 +56,7 @@ def populate_data_config_sources(input_path, output_path):
assert "wind" in production_paths, "Missing production path: wind"
wind_config["wind_files_groups"][0]["wind_filename"] = production_paths["wind"]['filename']
wind_config["wind_files_groups"][0]["wind_metadata_filename"] = (
-production_paths)["wind"]['wind_metadata_filename']
+production_paths)["wind"]['metadata_filename']

log.info(config)
with open(output_path, 'w') as outfile:
4 changes: 2 additions & 2 deletions tests/conftest.py
@@ -129,7 +129,7 @@ def nwp_data(tmp_path_factory, time_before_present):
)

# Last t0 to at least 2 hours ago and floor to 3-hour interval
-t0_datetime_utc = (time_before_present(dt.timedelta(hours=2))
+t0_datetime_utc = (time_before_present(dt.timedelta(hours=0))
.floor(dt.timedelta(hours=3)))
ds.init_time.values[:] = pd.date_range(
t0_datetime_utc - dt.timedelta(hours=3 * (len(ds.init_time) - 1)),
@@ -182,7 +182,7 @@ def wind_data(tmp_path_factory, time_before_present):
ds = xr.open_dataset(netcdf_source_path)

# Set t0 to at least 2 hours ago and floor to 15-min interval
-t0_datetime_utc = (time_before_present(dt.timedelta(hours=2))
+t0_datetime_utc = (time_before_present(dt.timedelta(hours=0))
.floor(dt.timedelta(minutes=15)))
ds.time_utc.values[:] = pd.date_range(
t0_datetime_utc - dt.timedelta(minutes=15 * (len(ds.time_utc) - 1)),
3 changes: 1 addition & 2 deletions tests/test_app.py
@@ -44,11 +44,10 @@ def test_get_model(asset_type, nwp_data, wind_data):
def test_run_model(db_session, asset_type, nwp_data, wind_data, caplog):
"""Test for running PV and wind models"""

-caplog.set_level('INFO')
+caplog.set_level('DEBUG')

model = PVNetModel if asset_type == "wind" else DummyModel

timestamp = dt.datetime(year=2024, month=2, day=14, hour=11, minute=0, second=0, microsecond=0)
forecast = run_model(
model=model(asset_type, timestamp=dt.datetime.now(tz=None)),
# model=model(asset_type, timestamp=timestamp),
89 changes: 61 additions & 28 deletions tests/test_data/inspect_data.ipynb
@@ -2,16 +2,28 @@
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 28,
"id": "initial_id",
"metadata": {
"collapsed": true,
"ExecuteTime": {
"end_time": "2024-02-14T11:43:09.199372Z",
"start_time": "2024-02-14T11:43:08.541738Z"
"end_time": "2024-02-14T16:18:21.547026Z",
"start_time": "2024-02-14T16:18:19.852866Z"
}
},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: numcodecs in /Users/chris/Library/Caches/pypoetry/virtualenvs/india-forecast-app-_WmL7Pf1-py3.11/lib/python3.11/site-packages (0.12.1)\r\n",
"Requirement already satisfied: numpy>=1.7 in /Users/chris/Library/Caches/pypoetry/virtualenvs/india-forecast-app-_WmL7Pf1-py3.11/lib/python3.11/site-packages (from numcodecs) (1.26.4)\r\n",
"\r\n",
"\u001B[1m[\u001B[0m\u001B[34;49mnotice\u001B[0m\u001B[1;39;49m]\u001B[0m\u001B[39;49m A new release of pip is available: \u001B[0m\u001B[31;49m23.3.1\u001B[0m\u001B[39;49m -> \u001B[0m\u001B[32;49m24.0\u001B[0m\r\n",
"\u001B[1m[\u001B[0m\u001B[34;49mnotice\u001B[0m\u001B[1;39;49m]\u001B[0m\u001B[39;49m To update, run: \u001B[0m\u001B[32;49mpip install --upgrade pip\u001B[0m\r\n"
]
}
],
"source": [
"import xarray as xr"
]
@@ -23,30 +35,47 @@
"name": "stdout",
"output_type": "stream",
"text": [
"<xarray.DataArray 'time_utc' (time_utc: 163680)>\n",
"array(['2019-06-15T09:45:00.000000000', '2019-06-15T10:00:00.000000000',\n",
" '2019-06-15T10:15:00.000000000', ..., '2024-02-14T09:00:00.000000000',\n",
" '2024-02-14T09:15:00.000000000', '2024-02-14T09:30:00.000000000'],\n",
" dtype='datetime64[ns]')\n",
"Coordinates:\n",
" * time_utc (time_utc) datetime64[ns] 2019-06-15T09:45:00 ... 2024-02-14T09...\n"
"ItemsView(<xarray.Dataset>\n",
"Dimensions: (pv__time_utc: 197, pv__pv_system_id: 1,\n",
" nwp-ecmwf__latitude: 168,\n",
" nwp-ecmwf__longitude: 168,\n",
" nwp-ecmwf__channel: 12,\n",
" nwp-ecmwf__target_time_utc: 50)\n",
"Coordinates: (12/13)\n",
" * pv__time_utc (pv__time_utc) datetime64[ns] 2023-02-27T10:0...\n",
" * pv__pv_system_id (pv__pv_system_id) int64 0\n",
" pv__observed_capacity_wp (pv__pv_system_id) float64 ...\n",
" pv__nominal_capacity_wp (pv__pv_system_id) float64 ...\n",
" pv__ml_id (pv__pv_system_id) float64 ...\n",
" pv__longitude (pv__pv_system_id) float64 ...\n",
" ... ...\n",
" * nwp-ecmwf__latitude (nwp-ecmwf__latitude) float64 30.65 ... 22.3\n",
" * nwp-ecmwf__longitude (nwp-ecmwf__longitude) float64 70.45 ... 78.8\n",
" * nwp-ecmwf__channel (nwp-ecmwf__channel) <U5 'hcc' 'lcc' ... 'dswrf'\n",
" * nwp-ecmwf__target_time_utc (nwp-ecmwf__target_time_utc) datetime64[ns] 2...\n",
" nwp-ecmwf__init_time_utc datetime64[ns] ...\n",
" nwp-ecmwf__step (nwp-ecmwf__target_time_utc) timedelta64[ns] ...\n",
"Data variables:\n",
" pv (pv__time_utc, pv__pv_system_id) float32 ...\n",
" nwp-ecmwf (nwp-ecmwf__target_time_utc, nwp-ecmwf__channel, nwp-ecmwf__latitude, nwp-ecmwf__longitude) float32 ...)\n"
]
}
],
"source": [
"ds = xr.open_dataset(\"../../data/wind/wind_data.nc\")\n",
"ds = xr.open_dataset(\"~/Downloads/000000.nc\")\n",
"# ds = xr.open_dataset(\"../../data/wind/wind_data.nc\")\n",
"# ds = xr.open_dataset(\"wind/wind_data.nc\") # Test data\n",
"print(ds[\"time_utc\"])"
"print(ds.items())"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-02-14T11:43:10.640355Z",
"start_time": "2024-02-14T11:43:09.958557Z"
"end_time": "2024-02-14T16:39:41.780845Z",
"start_time": "2024-02-14T16:39:41.691233Z"
}
},
"id": "8ef26c400e24207b",
"execution_count": 3
"execution_count": 36
},
{
"cell_type": "code",
@@ -55,30 +84,34 @@
"name": "stdout",
"output_type": "stream",
"text": [
"<xarray.DataArray 'init_time' (init_time: 1)>\n",
"array(['2024-02-14T09:00:00.000000000'], dtype='datetime64[ns]')\n",
"<xarray.Dataset>\n",
"Dimensions: (init_time: 1, latitude: 221, longitude: 221, step: 85,\n",
" variable: 17)\n",
"Coordinates:\n",
" * init_time (init_time) datetime64[ns] 2024-02-14T09:00:00\n",
"Attributes:\n",
" long_name: initial time of forecast\n",
" standard_name: forecast_reference_time\n"
" * init_time (init_time) datetime64[ns] 2023-05-21\n",
" * latitude (latitude) float64 31.0 30.95 30.9 30.85 ... 20.1 20.05 20.0\n",
" * longitude (longitude) float64 68.0 68.05 68.1 68.15 ... 78.9 78.95 79.0\n",
" * step (step) timedelta64[ns] 00:00:00 01:00:00 ... 3 days 12:00:00\n",
" * variable (variable) object 'dlwrf' 'dswrf' 'duvrs' ... 'v10' 'v100' 'v200'\n",
"Data variables:\n",
" *empty*\n"
]
}
],
"source": [
"ds = xr.open_zarr(\"../../data/nwp.zarr\")\n",
"# ds = xr.open_zarr(\"nwp.zarr\") # test data\n",
"print(ds['init_time'])"
"# ds = xr.open_zarr(\"../../data/nwp.zarr\")\n",
"ds = xr.open_zarr(\"nwp.zarr.bak\") # test data\n",
"print(ds)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-02-14T11:43:12.064264Z",
"start_time": "2024-02-14T11:43:11.971095Z"
"end_time": "2024-02-14T16:41:53.106778Z",
"start_time": "2024-02-14T16:41:53.095116Z"
}
},
"id": "3f6edf56ab0fb0f7",
"execution_count": 4
"execution_count": 38
},
{
"cell_type": "code",
2 changes: 1 addition & 1 deletion tests/test_data/nwp.zarr/.zmetadata
@@ -149,4 +149,4 @@
}
},
"zarr_consolidated_format": 1
}
}
Binary file modified tests/test_data/nwp.zarr/init_time/0
