feat: add functionality and testing for solar forecast data processing
- Added `format_to_forecast_sql` in `format_forecast.py` to convert NESO solar forecast data into a `ForecastSQL` object.
  - Fetches model metadata, location, and input data last updated timestamp.
  - Processes rows to create `ForecastValue` objects and aggregates them into a single `ForecastSQL` object.
  - Logs key steps for improved traceability and debugging.

- Added `test_format_to_forecast_sql_real` in `test_format_forecast.py` to validate the functionality of `format_to_forecast_sql`.
  - Fetches real data using `fetch_data` and ensures it is correctly formatted into a `ForecastSQL` object.
  - Performs validation to ensure the number of `ForecastValue` entries matches the filtered DataFrame.

- Added `conftest.py` to provide shared test fixtures for the test suite:
  - `postgres_container` fixture spins up a PostgreSQL container for isolated testing using `testcontainers`.
  - `db_session` fixture sets up a clean database schema and a test ML model for each test function.
  - `test_config` fixture provides shared configuration constants for tests.

Together these changes provide a pipeline for converting NESO solar forecast data into database-ready `ForecastSQL` objects, plus an isolated PostgreSQL-backed test setup to verify that conversion.
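The intended end-to-end flow looks roughly like the sketch below. This is illustrative only and not part of the commit: `fetch_data`'s positional arguments and the constants mirror `tests/conftest.py`, and `session` is assumed to be a SQLAlchemy session bound to the forecast database.

```python
from neso_solar_consumer.fetch_data import fetch_data
from neso_solar_consumer.format_forecast import format_to_forecast_sql

# Constants mirrored from tests/conftest.py
RESOURCE_ID = "db6c038f-98af-4570-ab60-24d71ebd0ae5"
COLUMNS = ["DATE_GMT", "TIME_GMT", "EMBEDDED_SOLAR_FORECAST"]
RENAME_COLUMNS = {
    "DATE_GMT": "start_utc",
    "TIME_GMT": "end_utc",
    "EMBEDDED_SOLAR_FORECAST": "solar_forecast_kw",
}


def run_once(session) -> None:
    """Fetch NESO data, format it, and persist the resulting forecast."""
    data = fetch_data(RESOURCE_ID, 5, COLUMNS, RENAME_COLUMNS)
    forecasts = format_to_forecast_sql(data, "real_data_model", "1.0", session)
    session.add_all(forecasts)  # a list containing one ForecastSQL object
    session.commit()
```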
siddharth7113 committed Jan 2, 2025
1 parent 5a32560 commit 0ee35ad
Showing 3 changed files with 174 additions and 0 deletions.
54 changes: 54 additions & 0 deletions neso_solar_consumer/format_forecast.py
@@ -0,0 +1,54 @@
import logging
from datetime import datetime, timezone
import pandas as pd
from nowcasting_datamodel.models import ForecastSQL, ForecastValue
from nowcasting_datamodel.read.read import get_latest_input_data_last_updated, get_location
from nowcasting_datamodel.read.read_models import get_model

# Configure logging (set to INFO for production; use DEBUG during debugging)
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger(__name__)


def format_to_forecast_sql(data: pd.DataFrame, model_tag: str, model_version: str, session) -> list:
    """Convert NESO solar forecast data into a list containing a single ForecastSQL object.

    `data` must provide `start_utc` and `solar_forecast_kw` columns; `session` is a
    SQLAlchemy session connected to the forecast database.
    """
    logger.info("Starting format_to_forecast_sql process...")

# Step 1: Retrieve model metadata
model = get_model(name=model_tag, version=model_version, session=session)
input_data_last_updated = get_latest_input_data_last_updated(session=session)

# Step 2: Fetch or create the location
location = get_location(session=session, gsp_id=0) # National forecast

# Step 3: Process all rows into ForecastValue objects
forecast_values = []
for _, row in data.iterrows():
if pd.isnull(row["start_utc"]) or pd.isnull(row["solar_forecast_kw"]):
logger.warning(f"Skipping row due to missing data: {row}")
continue

try:
target_time = datetime.fromisoformat(row["start_utc"]).replace(tzinfo=timezone.utc)
except ValueError:
logger.warning(f"Invalid datetime format: {row['start_utc']}. Skipping row.")
continue

forecast_value = ForecastValue(
target_time=target_time,
expected_power_generation_megawatts=row["solar_forecast_kw"] / 1000, # Convert to MW
).to_orm()
forecast_values.append(forecast_value)

# Step 4: Create a single ForecastSQL object
forecast = ForecastSQL(
model=model,
forecast_creation_time=datetime.now(tz=timezone.utc),
location=location,
input_data_last_updated=input_data_last_updated,
forecast_values=forecast_values,
historic=False,
)
logger.info(f"Created ForecastSQL object with {len(forecast_values)} values.")

# Return a single ForecastSQL object in a list
return [forecast]
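For a sense of the unit handling above, here is a minimal sketch of an input DataFrame (actually running the commented call still requires a database session with the target model available): a row with `solar_forecast_kw = 2500` becomes `expected_power_generation_megawatts = 2.5`.

```python
import pandas as pd

# Two half-hourly rows in the shape format_to_forecast_sql expects.
df = pd.DataFrame(
    {
        "start_utc": ["2025-01-01T00:00:00", "2025-01-01T00:30:00"],
        "end_utc": ["2025-01-01T00:30:00", "2025-01-01T01:00:00"],
        "solar_forecast_kw": [2500.0, 3100.0],
    }
)

# With a session whose database holds the target model:
# forecasts = format_to_forecast_sql(df, "real_data_model", "1.0", session)
# forecasts[0].forecast_values[0].expected_power_generation_megawatts  # -> 2.5
```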
84 changes: 84 additions & 0 deletions tests/conftest.py
@@ -0,0 +1,84 @@
import pytest
from typing import Generator
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from nowcasting_datamodel.models.base import Base_Forecast
from nowcasting_datamodel.models import MLModelSQL
from testcontainers.postgres import PostgresContainer

# Shared Test Configuration Constants
RESOURCE_ID = "db6c038f-98af-4570-ab60-24d71ebd0ae5"
LIMIT = 5
COLUMNS = ["DATE_GMT", "TIME_GMT", "EMBEDDED_SOLAR_FORECAST"]
RENAME_COLUMNS = {
"DATE_GMT": "start_utc",
"TIME_GMT": "end_utc",
"EMBEDDED_SOLAR_FORECAST": "solar_forecast_kw",
}
MODEL_NAME = "real_data_model"
MODEL_VERSION = "1.0"


@pytest.fixture(scope="session")
def postgres_container():
"""
Fixture to spin up a PostgreSQL container for the entire test session.
This fixture uses `testcontainers` to start a fresh PostgreSQL container and provides
the connection URL dynamically for use in other fixtures.
"""
    with PostgresContainer("postgres:15.5") as postgres:
        # The context manager starts the container, so no explicit start() call is needed.
        yield postgres.get_connection_url()


@pytest.fixture(scope="function")
def db_session(postgres_container) -> Generator:
"""
Fixture to set up and tear down a PostgreSQL database session for testing.
This fixture:
- Connects to the PostgreSQL container provided by `postgres_container`.
- Creates a fresh database schema before each test.
- Adds a dummy ML model for test purposes.
- Tears down the database session and cleans up resources after each test.
Args:
postgres_container (str): The dynamic connection URL provided by PostgresContainer.
Returns:
Generator: A SQLAlchemy session object.
"""
# Use the dynamic connection URL
engine = create_engine(postgres_container)
Base_Forecast.metadata.drop_all(engine) # Drop all tables to ensure a clean slate
Base_Forecast.metadata.create_all(engine) # Recreate the tables

# Establish session
Session = sessionmaker(bind=engine)
session = Session()

# Insert a dummy model for testing
session.query(MLModelSQL).delete() # Clean up any pre-existing data
model = MLModelSQL(name=MODEL_NAME, version=MODEL_VERSION)
session.add(model)
session.commit()

yield session # Provide the session to the test

# Cleanup: close session and dispose of engine
session.close()
engine.dispose()


@pytest.fixture(scope="session")
def test_config():
"""
Fixture to provide shared test configuration constants.
Returns:
dict: A dictionary of test configuration values.
"""
return {
"resource_id": RESOURCE_ID,
"limit": LIMIT,
"columns": COLUMNS,
"rename_columns": RENAME_COLUMNS,
"model_name": MODEL_NAME,
"model_version": MODEL_VERSION,
}
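Other tests in the suite can use these fixtures simply by naming them as arguments; a hypothetical example (not part of this commit) that checks the wiring:

```python
from nowcasting_datamodel.models import MLModelSQL


def test_fixtures_are_wired(db_session, test_config):
    # db_session seeds exactly one ML model (MODEL_NAME / MODEL_VERSION above).
    assert db_session.query(MLModelSQL).count() == 1
    assert test_config["model_name"] == "real_data_model"
```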
36 changes: 36 additions & 0 deletions tests/test_format_forecast.py
@@ -0,0 +1,36 @@
from neso_solar_consumer.fetch_data import fetch_data
from neso_solar_consumer.format_forecast import format_to_forecast_sql
from nowcasting_datamodel.models import ForecastSQL


def test_format_to_forecast_sql_real(db_session, test_config):
"""
Test `format_to_forecast_sql` with real data fetched from the NESO API.
"""
# Step 1: Fetch data from the API
data = fetch_data(
test_config["resource_id"],
test_config["limit"],
test_config["columns"],
test_config["rename_columns"],
)
assert not data.empty, "fetch_data returned an empty DataFrame!"

# Step 2: Format the data into a ForecastSQL object
forecasts = format_to_forecast_sql(
data, test_config["model_name"], test_config["model_version"], db_session
)
assert len(forecasts) == 1, f"Expected 1 ForecastSQL object, got {len(forecasts)}"

    # Step 3: Validate the ForecastSQL content
    forecast = forecasts[0]
    assert isinstance(forecast, ForecastSQL), f"Expected ForecastSQL, got {type(forecast)}"

    # Filter the DataFrame down to the rows expected to survive formatting
    # (drop duplicates and rows without a usable start_utc)
valid_data = data.drop_duplicates(subset=["start_utc", "end_utc", "solar_forecast_kw"])
valid_data = valid_data[
valid_data["start_utc"].apply(lambda x: isinstance(x, str) and len(x) > 0) # Valid datetime
]

assert len(forecast.forecast_values) == len(valid_data), (
f"Mismatch in ForecastValue entries! Expected {len(valid_data)} but got {len(forecast.forecast_values)}."
)
