Skip to content

Commit

Permalink
Merge pull request #356 from openclimatefix/time-divisibility
Browse files Browse the repository at this point in the history
Enforce forecast and history duration divisibility by time resolution
  • Loading branch information
AUdaltsova authored Aug 12, 2024
2 parents a6eec2b + ddae643 commit 9ff98d2
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 18 deletions.
89 changes: 71 additions & 18 deletions ocf_datapipes/config/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import git
import numpy as np
from pathy import Pathy
from pydantic import BaseModel, Field, RootModel, model_validator, validator
from pydantic import BaseModel, Field, RootModel, ValidationInfo, field_validator, model_validator

# nowcasting_dataset imports
from ocf_datapipes.utils.consts import (
Expand Down Expand Up @@ -93,8 +93,8 @@ class DataSourceMixin(Base):

log_level: str = Field(
"DEBUG",
description="The logging level for this data source. T"
"his is the default value and can be set in each data source",
description="The logging level for this data source. "
"This is the default value and can be set in each data source",
)

@property
Expand Down Expand Up @@ -139,16 +139,16 @@ class DropoutMixin(Base):

dropout_fraction: float = Field(0, description="Chance of dropout being applied to each sample")

@validator("dropout_timedeltas_minutes")
def dropout_timedeltas_minutes_negative(cls, v):
@field_validator("dropout_timedeltas_minutes")
def dropout_timedeltas_minutes_negative(cls, v: List[int]) -> List[int]:
"""Validate 'dropout_timedeltas_minutes'"""
if v is not None:
for m in v:
assert m <= 0
return v

@validator("dropout_fraction")
def dropout_fraction_valid(cls, v):
@field_validator("dropout_fraction")
def dropout_fraction_valid(cls, v: float) -> float:
"""Validate 'dropout_fraction'"""
assert 0 <= v <= 1
return v
Expand All @@ -169,8 +169,8 @@ class SystemDropoutMixin(Base):
system_dropout_fraction_min: float = Field(0, description="Min chance of system dropout")
system_dropout_fraction_max: float = Field(0, description="Max chance of system dropout")

@validator("system_dropout_fraction_min", "system_dropout_fraction_max")
def validate_system_dropout_fractions(cls, v):
@field_validator("system_dropout_fraction_min", "system_dropout_fraction_max")
def validate_system_dropout_fractions(cls, v: float):
"""Validate dropout fraction values"""
assert 0 <= v <= 1
return v
Expand All @@ -192,8 +192,8 @@ class TimeResolutionMixin(Base):
"Note that this needs to be divisible by 5.",
)

@validator("time_resolution_minutes")
def forecast_minutes_divide_by_5(cls, v):
@field_validator("time_resolution_minutes")
def forecast_minutes_divide_by_5(cls, v: int) -> int:
"""Validate 'forecast_minutes'"""
assert v % 5 == 0, f"The time resolution ({v}) is not divisible by 5"
return v
Expand Down Expand Up @@ -257,7 +257,6 @@ class Wind(DataSourceMixin, TimeResolutionMixin, XYDimensionalNames, DropoutMixi
None,
description="List of the ML IDs of the Wind systems you'd like to filter to.",
)
time_resolution_minutes: int = Field(15, description="The temporal resolution (in minutes).")
wind_image_size_meters_height: int = METERS_PER_ROI
wind_image_size_meters_width: int = METERS_PER_ROI
n_wind_systems_per_example: int = Field(
Expand Down Expand Up @@ -286,6 +285,24 @@ class Wind(DataSourceMixin, TimeResolutionMixin, XYDimensionalNames, DropoutMixi
"Note that this needs to be divisible by 5.",
)

@field_validator("forecast_minutes")
def forecast_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
"""Check forecast length requested will give stable number of timesteps"""
if v % info.data["time_resolution_minutes"] != 0:
message = "Forecast duration must be divisible by time resolution"
logger.error(message)
raise Exception(message)
return v

@field_validator("history_minutes")
def history_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
"""Check history length requested will give stable number of timesteps"""
if v % info.data["time_resolution_minutes"] != 0:
message = "History duration must be divisible by time resolution"
logger.error(message)
raise Exception(message)
return v


class PVFiles(BaseModel):
"""Model to hold pv file and metadata file"""
Expand All @@ -305,8 +322,8 @@ class PVFiles(BaseModel):

label: Optional[str] = Field(providers[0], description="Label of where the pv data came from")

@validator("label")
def v_label0(cls, v):
@field_validator("label")
def v_label0(cls, v: str) -> str:
"""Validate 'label'"""
if v not in providers:
message = f"provider {v} not in {providers}"
Expand Down Expand Up @@ -385,6 +402,24 @@ def model_validation(cls, v):

return v

@field_validator("forecast_minutes")
def forecast_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
"""Check forecast length requested will give stable number of timesteps"""
if v % info.data["time_resolution_minutes"] != 0:
message = "Forecast duration must be divisible by time resolution"
logger.error(message)
raise Exception(message)
return v

@field_validator("history_minutes")
def history_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
"""Check history length requested will give stable number of timesteps"""
if v % info.data["time_resolution_minutes"] != 0:
message = "History duration must be divisible by time resolution"
logger.error(message)
raise Exception(message)
return v


class Sensor(DataSourceMixin, TimeResolutionMixin, XYDimensionalNames):
"""PV configuration model"""
Expand Down Expand Up @@ -599,15 +634,33 @@ class NWP(DataSourceMixin, TimeResolutionMixin, XYDimensionalNames, DropoutMixin
0.1, description="The number of degrees to coarsen the NWP data to"
)

@validator("nwp_provider")
def validate_nwp_provider(cls, v):
@field_validator("nwp_provider")
def validate_nwp_provider(cls, v: str) -> str:
"""Validate 'nwp_provider'"""
if v.lower() not in NWP_PROVIDERS:
message = f"NWP provider {v} is not in {NWP_PROVIDERS}"
logger.warning(message)
assert Exception(message)
return v

@field_validator("forecast_minutes")
def forecast_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
"""Check forecast length requested will give stable number of timesteps"""
if v % info.data["time_resolution_minutes"] != 0:
message = "Forecast duration must be divisible by time resolution"
logger.error(message)
raise Exception(message)
return v

@field_validator("history_minutes")
def history_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
"""Check history length requested will give stable number of timesteps"""
if v % info.data["time_resolution_minutes"] != 0:
message = "History duration must be divisible by time resolution"
logger.error(message)
raise Exception(message)
return v


class MultiNWP(RootModel):
"""Configuration for multiple NWPs"""
Expand Down Expand Up @@ -668,13 +721,13 @@ class GSP(DataSourceMixin, TimeResolutionMixin, DropoutMixin):
"Note that this needs to be divisible by 5.",
)

@validator("history_minutes")
@field_validator("history_minutes")
def history_minutes_divide_by_30(cls, v):
"""Validate 'history_minutes'"""
assert v % 30 == 0 # this means it also divides by 5
return v

@validator("forecast_minutes")
@field_validator("forecast_minutes")
def forecast_minutes_divide_by_30(cls, v):
"""Validate 'forecast_minutes'"""
assert v % 30 == 0 # this means it also divides by 5
Expand Down
24 changes: 24 additions & 0 deletions tests/config/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,27 @@ def test_config_git(configuration_filename):
assert type(config.git.message) == str
assert type(config.git.hash) == str
assert type(config.git.committed_date) == datetime


def test_incorrect_forecast_minutes():
"""
Check a forecast length no divisible by time resolution causes error
"""

configuration = Configuration()
configuration.input_data = configuration.input_data.set_all_to_defaults()
configuration.input_data.wind.forecast_minutes = 1111
with pytest.raises(Exception):
_ = Configuration(**configuration.dict())


def test_incorrect_history_minutes():
"""
Check a forecast length no divisible by time resolution causes error
"""

configuration = Configuration()
configuration.input_data = configuration.input_data.set_all_to_defaults()
configuration.input_data.wind.history_minutes = 1111
with pytest.raises(Exception):
_ = Configuration(**configuration.dict())

0 comments on commit 9ff98d2

Please sign in to comment.