forked from openclimatefix/neso-solar-consumer
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add functionality and testing for solar forecast data processing
- Added `format_to_forecast_sql` in `format_forecast.py` to convert NESO solar forecast data into a `ForecastSQL` object. - Fetches model metadata, location, and input data last updated timestamp. - Processes rows to create `ForecastValue` objects and aggregates them into a single `ForecastSQL` object. - Logs key steps for improved traceability and debugging. - Added `test_format_to_forecast_sql_real` in `test_format_forecast.py` to validate the functionality of `format_to_forecast_sql`. - Fetches real data using `fetch_data` and ensures it is correctly formatted into a `ForecastSQL` object. - Performs validation to ensure the number of `ForecastValue` entries matches the filtered DataFrame. - Added `conftest.py` to provide shared test fixtures for the test suite: - `postgres_container` fixture spins up a PostgreSQL container for isolated testing using `testcontainers`. - `db_session` fixture sets up a clean database schema and a test ML model for each test function. - `test_config` fixture provides shared configuration constants for tests. These changes establish a robust framework for processing and testing solar forecast data effectively.
- Loading branch information
1 parent
5a32560
commit 0ee35ad
Showing
3 changed files
with
174 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
import logging | ||
from datetime import datetime, timezone | ||
import pandas as pd | ||
from nowcasting_datamodel.models import ForecastSQL, ForecastValue | ||
from nowcasting_datamodel.read.read import get_latest_input_data_last_updated, get_location | ||
from nowcasting_datamodel.read.read_models import get_model | ||
|
||
# Configure logging (set to INFO for production; use DEBUG during debugging) | ||
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s") | ||
logger = logging.getLogger(__name__) | ||
|
||
|
||
def format_to_forecast_sql(data: pd.DataFrame, model_tag: str, model_version: str, session) -> list:
    """
    Convert NESO solar forecast rows into a single ForecastSQL object.

    Args:
        data: DataFrame with at least "start_utc" (ISO-8601 datetime string)
            and "solar_forecast_kw" (numeric, kilowatts) columns.
        model_tag: Name of the ML model to look up in the database.
        model_version: Version of the ML model to look up.
        session: SQLAlchemy session bound to the forecast database.

    Returns:
        list: A single-element list containing one ForecastSQL object that
        aggregates every valid row as a ForecastValue. Rows with missing or
        unparseable values are skipped with a warning.
    """
    logger.info("Starting format_to_forecast_sql process...")

    # Step 1: Retrieve model metadata and the latest input-data timestamp
    model = get_model(name=model_tag, version=model_version, session=session)
    input_data_last_updated = get_latest_input_data_last_updated(session=session)

    # Step 2: Fetch or create the location (gsp_id=0 is the national forecast)
    location = get_location(session=session, gsp_id=0)

    # Step 3: Process all rows into ForecastValue objects
    forecast_values = []
    for _, row in data.iterrows():
        if pd.isnull(row["start_utc"]) or pd.isnull(row["solar_forecast_kw"]):
            # Lazy %-args so the row repr is only built when the message is emitted.
            logger.warning("Skipping row due to missing data: %s", row)
            continue

        try:
            # TypeError covers non-string values that pd.isnull does not catch.
            target_time = datetime.fromisoformat(row["start_utc"])
        except (TypeError, ValueError):
            logger.warning("Invalid datetime format: %s. Skipping row.", row["start_utc"])
            continue
        # Attach UTC only when the string carried no offset; previously an
        # explicit offset would have been silently overwritten with UTC.
        if target_time.tzinfo is None:
            target_time = target_time.replace(tzinfo=timezone.utc)

        try:
            power_mw = float(row["solar_forecast_kw"]) / 1000  # Convert kW to MW
        except (TypeError, ValueError):
            logger.warning("Invalid forecast value: %s. Skipping row.", row["solar_forecast_kw"])
            continue

        forecast_value = ForecastValue(
            target_time=target_time,
            expected_power_generation_megawatts=power_mw,
        ).to_orm()
        forecast_values.append(forecast_value)

    # Step 4: Create a single ForecastSQL object aggregating all values
    forecast = ForecastSQL(
        model=model,
        forecast_creation_time=datetime.now(tz=timezone.utc),
        location=location,
        input_data_last_updated=input_data_last_updated,
        forecast_values=forecast_values,
        historic=False,
    )
    logger.info("Created ForecastSQL object with %d values.", len(forecast_values))

    # Return a single ForecastSQL object in a list (callers expect a list)
    return [forecast]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
import pytest | ||
from typing import Generator | ||
from sqlalchemy import create_engine | ||
from sqlalchemy.orm import sessionmaker | ||
from nowcasting_datamodel.models.base import Base_Forecast | ||
from nowcasting_datamodel.models import MLModelSQL | ||
from testcontainers.postgres import PostgresContainer | ||
|
||
# Shared Test Configuration Constants
RESOURCE_ID = "db6c038f-98af-4570-ab60-24d71ebd0ae5"  # NESO API dataset resource id
LIMIT = 5  # small row limit keeps test fetches fast
COLUMNS = ["DATE_GMT", "TIME_GMT", "EMBEDDED_SOLAR_FORECAST"]  # raw columns requested from the API
# Mapping from raw NESO column names to the names the pipeline expects.
# NOTE(review): DATE_GMT -> "start_utc" and TIME_GMT -> "end_utc" maps a date
# column to a start and a time column to an end — confirm against fetch_data.
RENAME_COLUMNS = {
    "DATE_GMT": "start_utc",
    "TIME_GMT": "end_utc",
    "EMBEDDED_SOLAR_FORECAST": "solar_forecast_kw",
}
MODEL_NAME = "real_data_model"  # seeded into the DB by the db_session fixture
MODEL_VERSION = "1.0"  # version of the seeded test model
|
||
|
||
@pytest.fixture(scope="session")
def postgres_container():
    """
    Spin up a PostgreSQL container for the entire test session.

    Uses `testcontainers` to start a fresh PostgreSQL 15.5 container and yields
    its connection URL for use in other fixtures. The container is stopped
    automatically when the session ends.

    Yields:
        str: SQLAlchemy-compatible connection URL for the running container.
    """
    # The `with` block already starts the container on entry and stops it on
    # exit; the previous explicit `postgres.start()` call was redundant.
    with PostgresContainer("postgres:15.5") as postgres:
        yield postgres.get_connection_url()
|
||
|
||
@pytest.fixture(scope="function")
def db_session(postgres_container) -> Generator:
    """
    Provide a fresh PostgreSQL database session for a single test.

    Connects to the session-scoped PostgreSQL container, rebuilds the forecast
    schema from scratch, seeds one ML model row, and hands the session to the
    test. The session and engine are released once the test finishes.

    Args:
        postgres_container (str): Connection URL from the `postgres_container` fixture.

    Returns:
        Generator: Yields a SQLAlchemy session object.
    """
    engine = create_engine(postgres_container)

    # Rebuild the schema so every test starts from a clean slate.
    Base_Forecast.metadata.drop_all(engine)
    Base_Forecast.metadata.create_all(engine)

    session_factory = sessionmaker(bind=engine)
    session = session_factory()

    # Seed a single ML model row for tests to look up, clearing any leftovers.
    session.query(MLModelSQL).delete()
    session.add(MLModelSQL(name=MODEL_NAME, version=MODEL_VERSION))
    session.commit()

    yield session

    # Teardown: release the session and all pooled connections.
    session.close()
    engine.dispose()
|
||
|
||
@pytest.fixture(scope="session")
def test_config():
    """
    Provide shared test configuration constants.

    Returns:
        dict: Configuration values used across the test suite.
    """
    return dict(
        resource_id=RESOURCE_ID,
        limit=LIMIT,
        columns=COLUMNS,
        rename_columns=RENAME_COLUMNS,
        model_name=MODEL_NAME,
        model_version=MODEL_VERSION,
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
from neso_solar_consumer.fetch_data import fetch_data | ||
from neso_solar_consumer.format_forecast import format_to_forecast_sql | ||
from nowcasting_datamodel.models import ForecastSQL, ForecastValue | ||
|
||
|
||
def test_format_to_forecast_sql_real(db_session, test_config):
    """
    End-to-end check of `format_to_forecast_sql` using real NESO API data.
    """
    # Step 1: Pull live data from the API.
    df = fetch_data(
        test_config["resource_id"],
        test_config["limit"],
        test_config["columns"],
        test_config["rename_columns"],
    )
    assert not df.empty, "fetch_data returned an empty DataFrame!"

    # Step 2: Convert the rows into ForecastSQL objects.
    forecasts = format_to_forecast_sql(
        df, test_config["model_name"], test_config["model_version"], db_session
    )
    assert len(forecasts) == 1, f"Expected 1 ForecastSQL object, got {len(forecasts)}"

    # Step 3: Validate the content of the single ForecastSQL object.
    forecast = forecasts[0]

    # Approximate the rows the formatter keeps: drop duplicates, then keep only
    # rows whose start_utc is a non-empty string.
    # NOTE(review): the formatter itself does not deduplicate — confirm this
    # filter really mirrors its row-skipping rules.
    expected_rows = df.drop_duplicates(subset=["start_utc", "end_utc", "solar_forecast_kw"])
    expected_rows = expected_rows[
        expected_rows["start_utc"].apply(lambda v: isinstance(v, str) and len(v) > 0)
    ]

    assert len(forecast.forecast_values) == len(expected_rows), (
        f"Mismatch in ForecastValue entries! Expected {len(expected_rows)} but got {len(forecast.forecast_values)}."
    )