Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

site torch dataset #74

Merged
merged 26 commits into from
Nov 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
2a4acfc
first try at site workflow
peterdudfield Oct 25, 2024
941cf03
remove side_ids, rename sites --> site, system_id --> site_id
peterdudfield Oct 25, 2024
acdad11
add site dropout
peterdudfield Oct 25, 2024
658f0fd
refactor minutes transfer
peterdudfield Oct 25, 2024
9ad0975
Add comment
peterdudfield Oct 25, 2024
5ba4cd8
Merge branch 'main' into issue/site-pipeline
peterdudfield Oct 28, 2024
0964a48
add Legacy support
peterdudfield Oct 28, 2024
e4882d5
Update ocf_data_sampler/config/model.py
peterdudfield Oct 31, 2024
021cf41
reanme Sites->Site
peterdudfield Oct 31, 2024
df7b13a
Merge commit 'e4882d5df4fb082bff39918501d4dcb21b88b51f' into issue/si…
peterdudfield Oct 31, 2024
3a85983
remove legacy format to script
peterdudfield Oct 31, 2024
ae80011
Update scripts/refactor_site.py
peterdudfield Nov 4, 2024
972bb0f
Update ocf_data_sampler/torch_datasets/site.py
peterdudfield Nov 4, 2024
f930ff5
Update ocf_data_sampler/torch_datasets/site.py
peterdudfield Nov 4, 2024
7d00ab5
Update ocf_data_sampler/select/space_slice_for_dataset.py
peterdudfield Nov 4, 2024
7dc2432
Update ocf_data_sampler/torch_datasets/process_and_combine.py
peterdudfield Nov 4, 2024
e75f7ad
Update ocf_data_sampler/load/load_dataset.py
peterdudfield Nov 4, 2024
e7a7a22
Update scripts/refactor_site.py
peterdudfield Nov 4, 2024
3a666c4
Some of PR comments
peterdudfield Nov 4, 2024
b59aee2
refactor
peterdudfield Nov 4, 2024
1d9317c
add sanity checks
peterdudfield Nov 4, 2024
f9d49d3
tidy imports
peterdudfield Nov 4, 2024
8660732
add tests
peterdudfield Nov 4, 2024
61f697b
add get_sample back in + test
peterdudfield Nov 4, 2024
7c03a81
Update ocf_data_sampler/torch_datasets/site.py
peterdudfield Nov 5, 2024
158fee1
Update ocf_data_sampler/torch_datasets/site.py
peterdudfield Nov 5, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions ocf_data_sampler/config/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,39 @@ class TimeResolutionMixin(Base):
)


class Site(DataSourceMixin, TimeResolutionMixin, DropoutMixin):
    """Site configuration model"""

    # Path to the NetCDF file(s) with the per-site power timeseries
    file_path: str = Field(
        ...,
        description="The NetCDF files holding the power timeseries.",
    )
    # Path to the CSV describing the power systems (site metadata)
    metadata_file_path: str = Field(
        ...,
        description="The CSV files describing power system",
    )

    @field_validator("forecast_minutes")
    def forecast_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
        """Check forecast length requested will give stable number of timesteps"""
        # time_resolution_minutes is absent from info.data if it failed its own
        # validation - skip the check rather than raising a confusing KeyError
        time_resolution = info.data.get("time_resolution_minutes")
        if time_resolution is not None and v % time_resolution != 0:
            message = "Forecast duration must be divisible by time resolution"
            logger.error(message)
            # ValueError (not bare Exception) so pydantic folds it into ValidationError
            raise ValueError(message)
        return v

    @field_validator("history_minutes")
    def history_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
        """Check history length requested will give stable number of timesteps"""
        time_resolution = info.data.get("time_resolution_minutes")
        if time_resolution is not None and v % time_resolution != 0:
            message = "History duration must be divisible by time resolution"
            logger.error(message)
            raise ValueError(message)
        return v

    # TODO validate the netcdf for sites
    # TODO validate the csv for metadata
peterdudfield marked this conversation as resolved.
Show resolved Hide resolved

class Satellite(DataSourceMixin, TimeResolutionMixin, DropoutMixin):
"""Satellite configuration model"""

Expand Down Expand Up @@ -240,6 +273,7 @@ class InputData(Base):
satellite: Optional[Satellite] = None
nwp: Optional[MultiNWP] = None
gsp: Optional[GSP] = None
site: Optional[Site] = None


class Configuration(Base):
Expand Down
55 changes: 55 additions & 0 deletions ocf_data_sampler/load/load_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
""" Loads all data sources """
peterdudfield marked this conversation as resolved.
Show resolved Hide resolved
import xarray as xr
peterdudfield marked this conversation as resolved.
Show resolved Hide resolved

from ocf_data_sampler.config import Configuration
from ocf_data_sampler.load.gsp import open_gsp
from ocf_data_sampler.load.nwp import open_nwp
from ocf_data_sampler.load.satellite import open_sat_data
from ocf_data_sampler.load.site import open_site


def get_dataset_dict(config: Configuration) -> dict[str, xr.DataArray | dict[str, xr.DataArray]]:
    """Construct dictionary of all of the input data sources

    Args:
        config: Configuration describing which data sources to load

    Returns:
        Dictionary mapping "gsp", "sat" and "site" to a DataArray each, and
        "nwp" to a sub-dictionary of NWP-provider-name -> DataArray.
        Sources absent from the config are absent from the dictionary.
    """

    in_config = config.input_data

    datasets_dict = {}

    # Load GSP data unless the path is None
    if in_config.gsp and in_config.gsp.gsp_zarr_path:
        da_gsp = open_gsp(zarr_path=in_config.gsp.gsp_zarr_path).compute()

        # Remove national GSP (gsp_id=0); keep only the regional GSPs
        datasets_dict["gsp"] = da_gsp.sel(gsp_id=slice(1, None))

    # Load NWP data if in config
    if in_config.nwp:

        datasets_dict["nwp"] = {}
        for nwp_source, nwp_config in in_config.nwp.items():

            da_nwp = open_nwp(nwp_config.nwp_zarr_path, provider=nwp_config.nwp_provider)

            # Restrict to the channels requested in the config
            da_nwp = da_nwp.sel(channel=list(nwp_config.nwp_channels))

            datasets_dict["nwp"][nwp_source] = da_nwp

    # Load satellite data if in config
    if in_config.satellite:
        sat_config = in_config.satellite

        da_sat = open_sat_data(sat_config.satellite_zarr_path)

        da_sat = da_sat.sel(channel=list(sat_config.satellite_channels))

        datasets_dict["sat"] = da_sat

    # Load site data if in config
    if in_config.site:
        da_sites = open_site(in_config.site)
        datasets_dict["site"] = da_sites

    return datasets_dict
7 changes: 5 additions & 2 deletions ocf_data_sampler/load/nwp/providers/ecmwf.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
)



def open_ifs(zarr_path: Path | str | list[Path] | list[str]) -> xr.DataArray:
"""
Opens the ECMWF IFS NWP data
Expand All @@ -27,10 +26,14 @@ def open_ifs(zarr_path: Path | str | list[Path] | list[str]) -> xr.DataArray:
ds = ds.rename(
{
"init_time": "init_time_utc",
"variable": "channel",
}
)

# LEGACY SUPPORT
# rename variable to channel if it exists
if "variable" in ds:
ds = ds.rename({"variable": "channel"})

peterdudfield marked this conversation as resolved.
Show resolved Hide resolved
# Check the timestamps are unique and increasing
check_time_unique_increasing(ds.init_time_utc)

Expand Down
30 changes: 30 additions & 0 deletions ocf_data_sampler/load/site.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import pandas as pd
import xarray as xr
import numpy as np

from ocf_data_sampler.config.model import Site


def open_site(sites_config: Site) -> xr.DataArray:
    """Load site generation data and attach per-site metadata as coordinates.

    Args:
        sites_config: Site configuration with ``file_path`` (NetCDF generation
            timeseries) and ``metadata_file_path`` (CSV of per-site metadata,
            indexed by ``site_id`` with latitude/longitude columns).

    Returns:
        The ``generation_kw`` DataArray with latitude, longitude and
        capacity_kwp coordinates attached.

    Raises:
        ValueError: If capacities are non-finite or non-positive, or if the
            site ids in the metadata are not unique.
    """

    # Load site generation xr.Dataset
    data_ds = xr.open_dataset(sites_config.file_path)

    # Load site metadata, indexed by site id
    metadata_df = pd.read_csv(sites_config.metadata_file_path, index_col="site_id")

    # Sanity checks - raised explicitly (not assert) so they survive `python -O`
    if not np.isfinite(data_ds.capacity_kwp.values).all():
        raise ValueError("Site capacities (capacity_kwp) contain non-finite values")
    if not (data_ds.capacity_kwp.values > 0).all():
        raise ValueError("Site capacities (capacity_kwp) must all be positive")
    if not metadata_df.index.is_unique:
        raise ValueError("Site ids in the metadata file must be unique")

    # Add latitude/longitude from the metadata, and capacity, as coordinates
    ds = data_ds.assign_coords(
        latitude=(metadata_df.latitude.to_xarray()),
        longitude=(metadata_df.longitude.to_xarray()),
        capacity_kwp=data_ds.capacity_kwp,
    )

    return ds.generation_kw
peterdudfield marked this conversation as resolved.
Show resolved Hide resolved


1 change: 1 addition & 0 deletions ocf_data_sampler/numpy_batch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@
from .nwp import convert_nwp_to_numpy_batch, NWPBatchKey
from .satellite import convert_satellite_to_numpy_batch, SatelliteBatchKey
from .sun_position import make_sun_position_numpy_batch
from .site import convert_site_to_numpy_batch

29 changes: 29 additions & 0 deletions ocf_data_sampler/numpy_batch/site.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""Convert site to Numpy Batch"""

import xarray as xr


class SiteBatchKey:
    # String keys used to index the numpy batch dictionary produced by
    # convert_site_to_numpy_batch (and downstream site batch consumers).

    generation = "site"
    site_capacity_kwp = "site_capacity_kwp"
    site_time_utc = "site_time_utc"
    site_t0_idx = "site_t0_idx"
    site_solar_azimuth = "site_solar_azimuth"
    site_solar_elevation = "site_solar_elevation"
    site_id = "site_id"


def convert_site_to_numpy_batch(da: xr.DataArray, t0_idx: int | None = None) -> dict:
    """Convert from Xarray to NumpyBatch"""

    # Capacity is taken from the first timestep; times are converted to floats
    capacities = da.isel(time_utc=0)["capacity_kwp"].values
    times = da["time_utc"].values.astype(float)

    example = {}
    example[SiteBatchKey.generation] = da.values
    example[SiteBatchKey.site_capacity_kwp] = capacities
    example[SiteBatchKey.site_time_utc] = times

    # t0 index is optional and only included when supplied
    if t0_idx is not None:
        example[SiteBatchKey.site_t0_idx] = t0_idx

    return example
9 changes: 8 additions & 1 deletion ocf_data_sampler/select/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,8 @@

from .fill_time_periods import fill_time_periods
from .find_contiguous_time_periods import (
find_contiguous_t0_periods,
intersection_of_multiple_dataframes_of_periods,
)
from .location import Location
from .spatial_slice_for_dataset import slice_datasets_by_space
from .time_slice_for_dataset import slice_datasets_by_time
3 changes: 2 additions & 1 deletion ocf_data_sampler/select/dropout.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
""" Functions for simulating dropout in time series data """
import numpy as np
import pandas as pd
import xarray as xr


def draw_dropout_time(
t0: pd.Timestamp,
dropout_timedeltas: list[pd.Timedelta] | None,
dropout_timedeltas: list[pd.Timedelta] | pd.Timedelta | None,
dropout_frac: float = 0,
):

Expand Down
44 changes: 43 additions & 1 deletion ocf_data_sampler/select/geospatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,23 @@ def lon_lat_to_osgb(
return _lon_lat_to_osgb(xx=x, yy=y)


def lon_lat_to_geostationary_area_coords(
    longitude: Union[Number, np.ndarray],
    latitude: Union[Number, np.ndarray],
    xr_data: xr.DataArray,
) -> tuple[Union[Number, np.ndarray], Union[Number, np.ndarray]]:
    """Loads geostationary area and transformation from lat-lon to geostationary coords

    Args:
        longitude: longitude
        latitude: latitude
        xr_data: xarray object with geostationary area

    Returns:
        Geostationary coords: x, y
    """
    # Delegate to the shared helper, declaring the input coordinate system as WGS84
    return coordinates_to_geostationary_area_coords(longitude, latitude, xr_data, WGS84)

def osgb_to_geostationary_area_coords(
x: Union[Number, np.ndarray],
y: Union[Number, np.ndarray],
Expand All @@ -70,6 +87,31 @@ def osgb_to_geostationary_area_coords(
Returns:
Geostationary coords: x, y
"""

return coordinates_to_geostationary_area_coords(x, y, xr_data, OSGB36)



def coordinates_to_geostationary_area_coords(
x: Union[Number, np.ndarray],
y: Union[Number, np.ndarray],
xr_data: xr.DataArray,
crs_from: int
) -> tuple[Union[Number, np.ndarray], Union[Number, np.ndarray]]:
"""Loads geostationary area and transformation from respective coordiates to geostationary coords

Args:
x: osgb east-west, or latitude
y: osgb north-south, or longitude
xr_data: xarray object with geostationary area
crs_from: the cordiates system of x,y

Returns:
Geostationary coords: x, y
"""

assert crs_from in [OSGB36, WGS84], f"Unrecognized coordinate system: {crs_from}"

# Only load these if using geostationary projection
import pyresample

Expand All @@ -80,7 +122,7 @@ def osgb_to_geostationary_area_coords(
)
geostationary_crs = geostationary_area_definition.crs
osgb_to_geostationary = pyproj.Transformer.from_crs(
crs_from=OSGB36, crs_to=geostationary_crs, always_xy=True
crs_from=crs_from, crs_to=geostationary_crs, always_xy=True
).transform
return osgb_to_geostationary(xx=x, yy=y)

Expand Down
10 changes: 8 additions & 2 deletions ocf_data_sampler/select/select_spatial_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from ocf_data_sampler.select.location import Location
from ocf_data_sampler.select.geospatial import (
lon_lat_to_osgb,
lon_lat_to_geostationary_area_coords,
osgb_to_geostationary_area_coords,
osgb_to_lon_lat,
spatial_coord_type,
Expand Down Expand Up @@ -101,7 +102,7 @@ def _get_idx_of_pixel_closest_to_poi(

def _get_idx_of_pixel_closest_to_poi_geostationary(
da: xr.DataArray,
center_osgb: Location,
center: Location,
) -> Location:
"""
Return x and y index location of pixel at center of region of interest.
Expand All @@ -116,7 +117,12 @@ def _get_idx_of_pixel_closest_to_poi_geostationary(

_, x_dim, y_dim = spatial_coord_type(da)

x, y = osgb_to_geostationary_area_coords(x=center_osgb.x, y=center_osgb.y, xr_data=da)
if center.coordinate_system == 'osgb':
x, y = osgb_to_geostationary_area_coords(x=center.x, y=center.y, xr_data=da)
elif center.coordinate_system == 'lon_lat':
x, y = lon_lat_to_geostationary_area_coords(longitude=center.x, latitude=center.y, xr_data=da)
else:
x,y = center.x, center.y
center_geostationary = Location(x=x, y=y, coordinate_system="geostationary")

# Check that the requested point lies within the data
Expand Down
53 changes: 53 additions & 0 deletions ocf_data_sampler/select/spatial_slice_for_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
""" Functions for selecting data around a given location """
from ocf_data_sampler.config import Configuration
from ocf_data_sampler.select.location import Location
from ocf_data_sampler.select.select_spatial_slice import select_spatial_slice_pixels


def slice_datasets_by_space(
    datasets_dict: dict,
    location: Location,
    config: Configuration,
) -> dict:
    """Slice the dictionary of input data sources around a given location

    Args:
        datasets_dict: Dictionary of the input data sources
        location: The location to sample around
        config: Configuration object.
    """

    assert set(datasets_dict.keys()).issubset({"nwp", "sat", "gsp", "site"})

    sliced = {}

    # NWP: spatial pixel slice per provider, sized from each provider's config
    if "nwp" in datasets_dict:
        nwp_slices = {}
        for nwp_key, nwp_config in config.input_data.nwp.items():
            nwp_slices[nwp_key] = select_spatial_slice_pixels(
                datasets_dict["nwp"][nwp_key],
                location,
                height_pixels=nwp_config.nwp_image_size_pixels_height,
                width_pixels=nwp_config.nwp_image_size_pixels_width,
            )
        sliced["nwp"] = nwp_slices

    # Satellite: spatial pixel slice sized from the satellite config
    if "sat" in datasets_dict:
        sat_config = config.input_data.satellite
        sliced["sat"] = select_spatial_slice_pixels(
            datasets_dict["sat"],
            location,
            height_pixels=sat_config.satellite_image_size_pixels_height,
            width_pixels=sat_config.satellite_image_size_pixels_width,
        )

    # GSP and site: select the single series matching the location's id
    if "gsp" in datasets_dict:
        sliced["gsp"] = datasets_dict["gsp"].sel(gsp_id=location.id)

    if "site" in datasets_dict:
        sliced["site"] = datasets_dict["site"].sel(site_id=location.id)

    return sliced
Loading
Loading