diff --git a/neso_solar_consumer/format_forecast.py b/neso_solar_consumer/format_forecast.py
new file mode 100644
index 0000000..f1348a9
--- /dev/null
+++ b/neso_solar_consumer/format_forecast.py
@@ -0,0 +1,54 @@
+import logging
+from datetime import datetime, timezone
+import pandas as pd
+from nowcasting_datamodel.models import ForecastSQL, ForecastValue
+from nowcasting_datamodel.read.read import get_latest_input_data_last_updated, get_location
+from nowcasting_datamodel.read.read_models import get_model
+
+# Configure logging (set to INFO for production; use DEBUG during debugging)
+logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
+logger = logging.getLogger(__name__)
+
+
+def format_to_forecast_sql(data: pd.DataFrame, model_tag: str, model_version: str, session) -> list:
+    logger.info("Starting format_to_forecast_sql process...")
+
+    # Step 1: Retrieve model metadata
+    model = get_model(name=model_tag, version=model_version, session=session)
+    input_data_last_updated = get_latest_input_data_last_updated(session=session)
+
+    # Step 2: Fetch or create the location
+    location = get_location(session=session, gsp_id=0)  # National forecast
+
+    # Step 3: Process all rows into ForecastValue objects
+    forecast_values = []
+    for _, row in data.iterrows():
+        if pd.isnull(row["start_utc"]) or pd.isnull(row["solar_forecast_kw"]):
+            logger.warning(f"Skipping row due to missing data: {row}")
+            continue
+
+        try:
+            target_time = datetime.fromisoformat(row["start_utc"]).replace(tzinfo=timezone.utc)
+        except ValueError:
+            logger.warning(f"Invalid datetime format: {row['start_utc']}. Skipping row.")
+            continue
+
+        forecast_value = ForecastValue(
+            target_time=target_time,
+            expected_power_generation_megawatts=row["solar_forecast_kw"] / 1000,  # Convert to MW
+        ).to_orm()
+        forecast_values.append(forecast_value)
+
+    # Step 4: Create a single ForecastSQL object
+    forecast = ForecastSQL(
+        model=model,
+        forecast_creation_time=datetime.now(tz=timezone.utc),
+        location=location,
+        input_data_last_updated=input_data_last_updated,
+        forecast_values=forecast_values,
+        historic=False,
+    )
+    logger.info(f"Created ForecastSQL object with {len(forecast_values)} values.")
+
+    # Return a single ForecastSQL object in a list
+    return [forecast]
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 0000000..1ba90d1
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,84 @@
+import pytest
+from typing import Generator
+from sqlalchemy import create_engine
+from sqlalchemy.orm import sessionmaker
+from nowcasting_datamodel.models.base import Base_Forecast
+from nowcasting_datamodel.models import MLModelSQL
+from testcontainers.postgres import PostgresContainer
+
+# Shared Test Configuration Constants
+RESOURCE_ID = "db6c038f-98af-4570-ab60-24d71ebd0ae5"
+LIMIT = 5
+COLUMNS = ["DATE_GMT", "TIME_GMT", "EMBEDDED_SOLAR_FORECAST"]
+RENAME_COLUMNS = {
+    "DATE_GMT": "start_utc",
+    "TIME_GMT": "end_utc",
+    "EMBEDDED_SOLAR_FORECAST": "solar_forecast_kw",
+}
+MODEL_NAME = "real_data_model"
+MODEL_VERSION = "1.0"
+
+
+@pytest.fixture(scope="session")
+def postgres_container():
+    """
+    Fixture to spin up a PostgreSQL container for the entire test session.
+    This fixture uses `testcontainers` to start a fresh PostgreSQL container and provides
+    the connection URL dynamically for use in other fixtures.
+    """
+    with PostgresContainer("postgres:15.5") as postgres:
+        postgres.start()
+        yield postgres.get_connection_url()
+
+
+@pytest.fixture(scope="function")
+def db_session(postgres_container) -> Generator:
+    """
+    Fixture to set up and tear down a PostgreSQL database session for testing.
+    This fixture:
+    - Connects to the PostgreSQL container provided by `postgres_container`.
+    - Creates a fresh database schema before each test.
+    - Adds a dummy ML model for test purposes.
+    - Tears down the database session and cleans up resources after each test.
+    Args:
+        postgres_container (str): The dynamic connection URL provided by PostgresContainer.
+    Returns:
+        Generator: A SQLAlchemy session object.
+    """
+    # Use the dynamic connection URL
+    engine = create_engine(postgres_container)
+    Base_Forecast.metadata.drop_all(engine)  # Drop all tables to ensure a clean slate
+    Base_Forecast.metadata.create_all(engine)  # Recreate the tables
+
+    # Establish session
+    Session = sessionmaker(bind=engine)
+    session = Session()
+
+    # Insert a dummy model for testing
+    session.query(MLModelSQL).delete()  # Clean up any pre-existing data
+    model = MLModelSQL(name=MODEL_NAME, version=MODEL_VERSION)
+    session.add(model)
+    session.commit()
+
+    yield session  # Provide the session to the test
+
+    # Cleanup: close session and dispose of engine
+    session.close()
+    engine.dispose()
+
+
+@pytest.fixture(scope="session")
+def test_config():
+    """
+    Fixture to provide shared test configuration constants.
+    Returns:
+        dict: A dictionary of test configuration values.
+    """
+    return {
+        "resource_id": RESOURCE_ID,
+        "limit": LIMIT,
+        "columns": COLUMNS,
+        "rename_columns": RENAME_COLUMNS,
+        "model_name": MODEL_NAME,
+        "model_version": MODEL_VERSION,
+    }
\ No newline at end of file
diff --git a/tests/test_format_forecast.py b/tests/test_format_forecast.py
new file mode 100644
index 0000000..1a55b99
--- /dev/null
+++ b/tests/test_format_forecast.py
@@ -0,0 +1,36 @@
+from neso_solar_consumer.fetch_data import fetch_data
+from neso_solar_consumer.format_forecast import format_to_forecast_sql
+from nowcasting_datamodel.models import ForecastSQL, ForecastValue
+
+
+def test_format_to_forecast_sql_real(db_session, test_config):
+    """
+    Test `format_to_forecast_sql` with real data fetched from the NESO API.
+    """
+    # Step 1: Fetch data from the API
+    data = fetch_data(
+        test_config["resource_id"],
+        test_config["limit"],
+        test_config["columns"],
+        test_config["rename_columns"],
+    )
+    assert not data.empty, "fetch_data returned an empty DataFrame!"
+
+    # Step 2: Format the data into a ForecastSQL object
+    forecasts = format_to_forecast_sql(
+        data, test_config["model_name"], test_config["model_version"], db_session
+    )
+    assert len(forecasts) == 1, f"Expected 1 ForecastSQL object, got {len(forecasts)}"
+
+    # Step 3: Validate the ForecastSQL content
+    forecast = forecasts[0]
+
+    # Filter out invalid rows from the DataFrame (like your function would)
+    valid_data = data.drop_duplicates(subset=["start_utc", "end_utc", "solar_forecast_kw"])
+    valid_data = valid_data[
+        valid_data["start_utc"].apply(lambda x: isinstance(x, str) and len(x) > 0)  # Valid datetime
+    ]
+
+    assert len(forecast.forecast_values) == len(valid_data), (
+        f"Mismatch in ForecastValue entries! Expected {len(valid_data)} but got {len(forecast.forecast_values)}."
+    )