diff --git a/neso_solar_consumer/app.py b/neso_solar_consumer/app.py
index b0a7761..cdb2ce2 100644
--- a/neso_solar_consumer/app.py
+++ b/neso_solar_consumer/app.py
@@ -30,7 +30,11 @@ def get_forecast():
     """
     resource_id = "example_resource_id"  # Replace with the actual resource ID.
     limit = 100  # Number of records to fetch.
-    columns = ["DATE_GMT", "TIME_GMT", "EMBEDDED_SOLAR_FORECAST"]  # Relevant columns to extract.
+    columns = [
+        "DATE_GMT",
+        "TIME_GMT",
+        "EMBEDDED_SOLAR_FORECAST",
+    ]  # Relevant columns to extract.
     rename_columns = {
         "DATE_GMT": "start_utc",
         "TIME_GMT": "end_utc",
diff --git a/neso_solar_consumer/fetch_data.py b/neso_solar_consumer/fetch_data.py
index 0267af8..c2dd8b0 100644
--- a/neso_solar_consumer/fetch_data.py
+++ b/neso_solar_consumer/fetch_data.py
@@ -98,4 +98,4 @@ def fetch_data_using_sql(
     except Exception as e:
         print(f"An error occurred: {e}")
 
-    return pd.DataFrame()
\ No newline at end of file
+    return pd.DataFrame()
diff --git a/neso_solar_consumer/format_forecast.py b/neso_solar_consumer/format_forecast.py
new file mode 100644
index 0000000..64c6fda
--- /dev/null
+++ b/neso_solar_consumer/format_forecast.py
@@ -0,0 +1,66 @@
+import logging
+from datetime import datetime, timezone
+import pandas as pd
+from nowcasting_datamodel.models import ForecastSQL, ForecastValue
+from nowcasting_datamodel.read.read import (
+    get_latest_input_data_last_updated,
+    get_location,
+)
+from nowcasting_datamodel.read.read_models import get_model
+
+# Configure logging (set to INFO for production; use DEBUG during debugging)
+logging.basicConfig(
+    level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s"
+)
+logger = logging.getLogger(__name__)
+
+
+def format_to_forecast_sql(
+    data: pd.DataFrame, model_tag: str, model_version: str, session
+) -> list:
+    """
+    Convert a DataFrame of NESO solar forecasts into a single ForecastSQL object.
+
+    Args:
+        data: DataFrame with "start_utc" and "solar_forecast_kw" columns.
+        model_tag: Name of the model to look up in the database.
+        model_version: Version of the model to look up.
+        session: SQLAlchemy session connected to the forecast database.
+
+    Returns:
+        A list containing one ForecastSQL object.
+    """
+    logger.info("Starting format_to_forecast_sql process...")
+
+    # Step 1: Retrieve model metadata
+    model = get_model(name=model_tag, version=model_version, session=session)
+    input_data_last_updated = get_latest_input_data_last_updated(session=session)
+
+    # Step 2: Fetch or create the location
+    location = get_location(session=session, gsp_id=0)  # National forecast
+
+    # Step 3: Process all rows into ForecastValue objects
+    forecast_values = []
+    for _, row in data.iterrows():
+        if pd.isnull(row["start_utc"]) or pd.isnull(row["solar_forecast_kw"]):
+            logger.warning(f"Skipping row due to missing data: {row}")
+            continue
+
+        try:
+            target_time = datetime.fromisoformat(row["start_utc"]).replace(
+                tzinfo=timezone.utc
+            )
+        except ValueError:
+            logger.warning(
+                f"Invalid datetime format: {row['start_utc']}. Skipping row."
+            )
+            continue
+
+        forecast_value = ForecastValue(
+            target_time=target_time,
+            expected_power_generation_megawatts=row["solar_forecast_kw"]
+            / 1000,  # Convert kW to MW
+        ).to_orm()
+        forecast_values.append(forecast_value)
+
+    # Step 4: Create a single ForecastSQL object
+    forecast = ForecastSQL(
+        model=model,
+        forecast_creation_time=datetime.now(tz=timezone.utc),
+        location=location,
+        input_data_last_updated=input_data_last_updated,
+        forecast_values=forecast_values,
+        historic=False,
+    )
+    logger.info(f"Created ForecastSQL object with {len(forecast_values)} values.")
+
+    # Return a single ForecastSQL object in a list
+    return [forecast]
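
For reviewers, a minimal usage sketch of the new `format_to_forecast_sql` (not part of this diff). `DB_URL` and the sample rows are placeholders, and the sketch assumes the target database already carries the nowcasting_datamodel schema:

```python
# Hypothetical usage sketch. DB_URL and the sample rows are placeholders.
import pandas as pd
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

from neso_solar_consumer.format_forecast import format_to_forecast_sql

DB_URL = "postgresql://user:pass@localhost:5432/forecasts"  # placeholder

engine = create_engine(DB_URL)
Session = sessionmaker(bind=engine)

# Two half-hourly rows in the shape fetch_data produces after renaming.
data = pd.DataFrame(
    {
        "start_utc": ["2024-01-01T00:00:00", "2024-01-01T00:30:00"],
        "end_utc": ["2024-01-01T00:30:00", "2024-01-01T01:00:00"],
        "solar_forecast_kw": [1500.0, 1800.0],  # converted to MW internally
    }
)

with Session() as session:
    forecasts = format_to_forecast_sql(data, "real_data_model", "1.0", session)
    session.add_all(forecasts)
    session.commit()
```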
diff --git a/pyproject.toml b/pyproject.toml
index 81bb507..8965d0b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -6,10 +6,15 @@ build-backend = "setuptools.build_meta"
 name = "neso_solar_consumer"
 version = "0.1"
 dependencies = [
-    "pandas"
+    "pandas",
+    "sqlalchemy",
+    "nowcasting_datamodel==1.5.56",
+    "testcontainers"
 ]
 
 [project.optional-dependencies]
 dev = [
-    "pytest","black","ruff"
+    "pytest",
+    "black",
+    "ruff"
 ]
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 0000000..eef84f1
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,84 @@
+import pytest
+from typing import Generator
+from sqlalchemy import create_engine
+from sqlalchemy.orm import sessionmaker
+from nowcasting_datamodel.models.base import Base_Forecast
+from nowcasting_datamodel.models import MLModelSQL
+from testcontainers.postgres import PostgresContainer
+
+# Shared Test Configuration Constants
+RESOURCE_ID = "db6c038f-98af-4570-ab60-24d71ebd0ae5"
+LIMIT = 5
+COLUMNS = ["DATE_GMT", "TIME_GMT", "EMBEDDED_SOLAR_FORECAST"]
+RENAME_COLUMNS = {
+    "DATE_GMT": "start_utc",
+    "TIME_GMT": "end_utc",
+    "EMBEDDED_SOLAR_FORECAST": "solar_forecast_kw",
+}
+MODEL_NAME = "real_data_model"
+MODEL_VERSION = "1.0"
+
+
+@pytest.fixture(scope="session")
+def postgres_container():
+    """
+    Fixture to spin up a PostgreSQL container for the entire test session.
+
+    This fixture uses `testcontainers` to start a fresh PostgreSQL container and
+    provides the connection URL dynamically for use in other fixtures.
+    """
+    # The context manager starts the container; no explicit start() is needed.
+    with PostgresContainer("postgres:15.5") as postgres:
+        yield postgres.get_connection_url()
+
+
+@pytest.fixture(scope="function")
+def db_session(postgres_container) -> Generator:
+    """
+    Fixture to set up and tear down a PostgreSQL database session for testing.
+
+    This fixture:
+    - Connects to the PostgreSQL container provided by `postgres_container`.
+    - Creates a fresh database schema before each test.
+    - Adds a dummy ML model for test purposes.
+    - Tears down the database session and cleans up resources after each test.
+
+    Args:
+        postgres_container (str): The dynamic connection URL provided by PostgresContainer.
+
+    Returns:
+        Generator: A SQLAlchemy session object.
+    """
+    # Use the dynamic connection URL
+    engine = create_engine(postgres_container)
+    Base_Forecast.metadata.drop_all(engine)  # Drop all tables to ensure a clean slate
+    Base_Forecast.metadata.create_all(engine)  # Recreate the tables
+
+    # Establish session
+    Session = sessionmaker(bind=engine)
+    session = Session()
+
+    # Insert a dummy model for testing
+    session.query(MLModelSQL).delete()  # Clean up any pre-existing data
+    model = MLModelSQL(name=MODEL_NAME, version=MODEL_VERSION)
+    session.add(model)
+    session.commit()
+
+    yield session  # Provide the session to the test
+
+    # Cleanup: close session and dispose of engine
+    session.close()
+    engine.dispose()
+
+
+@pytest.fixture(scope="session")
+def test_config():
+    """
+    Fixture to provide shared test configuration constants.
+
+    Returns:
+        dict: A dictionary of test configuration values.
+    """
+    return {
+        "resource_id": RESOURCE_ID,
+        "limit": LIMIT,
+        "columns": COLUMNS,
+        "rename_columns": RENAME_COLUMNS,
+        "model_name": MODEL_NAME,
+        "model_version": MODEL_VERSION,
+    }
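
A hypothetical example (not in this diff) of how the fixtures compose: any test that takes `db_session` and `test_config` as arguments gets a fresh schema plus the seeded dummy model.

```python
from nowcasting_datamodel.models import MLModelSQL


def test_dummy_model_is_seeded(db_session, test_config):
    # The db_session fixture seeds exactly one MLModelSQL row per test.
    model = db_session.query(MLModelSQL).one()
    assert model.name == test_config["model_name"]
    assert model.version == test_config["model_version"]
```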
+ """ + # Use the dynamic connection URL + engine = create_engine(postgres_container) + Base_Forecast.metadata.drop_all(engine) # Drop all tables to ensure a clean slate + Base_Forecast.metadata.create_all(engine) # Recreate the tables + + # Establish session + Session = sessionmaker(bind=engine) + session = Session() + + # Insert a dummy model for testing + session.query(MLModelSQL).delete() # Clean up any pre-existing data + model = MLModelSQL(name=MODEL_NAME, version=MODEL_VERSION) + session.add(model) + session.commit() + + yield session # Provide the session to the test + + # Cleanup: close session and dispose of engine + session.close() + engine.dispose() + + +@pytest.fixture(scope="session") +def test_config(): + """ + Fixture to provide shared test configuration constants. + Returns: + dict: A dictionary of test configuration values. + """ + return { + "resource_id": RESOURCE_ID, + "limit": LIMIT, + "columns": COLUMNS, + "rename_columns": RENAME_COLUMNS, + "model_name": MODEL_NAME, + "model_version": MODEL_VERSION, + } diff --git a/tests/test_fetch_data.py b/tests/test_fetch_data.py index dc7ae79..2aa6edf 100644 --- a/tests/test_fetch_data.py +++ b/tests/test_fetch_data.py @@ -26,49 +26,57 @@ """ -import pytest from neso_solar_consumer.fetch_data import fetch_data, fetch_data_using_sql -resource_id = "db6c038f-98af-4570-ab60-24d71ebd0ae5" -limit = 5 -columns = ["DATE_GMT", "TIME_GMT", "EMBEDDED_SOLAR_FORECAST"] -rename_columns = { - "DATE_GMT": "start_utc", - "TIME_GMT": "end_utc", - "EMBEDDED_SOLAR_FORECAST": "solar_forecast_kw", -} -sql_query = f'SELECT * from "{resource_id}" LIMIT {limit}' - - -def test_fetch_data_api(): +def test_fetch_data_api(test_config): """ Test the fetch_data function to ensure it fetches and processes data correctly via API. """ - df_api = fetch_data(resource_id, limit, columns, rename_columns) + df_api = fetch_data( + test_config["resource_id"], + test_config["limit"], + test_config["columns"], + test_config["rename_columns"], + ) assert not df_api.empty, "fetch_data returned an empty DataFrame!" assert set(df_api.columns) == set( - rename_columns.values() + test_config["rename_columns"].values() ), "Column names do not match after renaming!" -def test_fetch_data_sql(): +def test_fetch_data_sql(test_config): """ Test the fetch_data_using_sql function to ensure it fetches and processes data correctly via SQL. """ - df_sql = fetch_data_using_sql(sql_query, columns, rename_columns) + sql_query = ( + f'SELECT * FROM "{test_config["resource_id"]}" LIMIT {test_config["limit"]}' + ) + df_sql = fetch_data_using_sql( + sql_query, test_config["columns"], test_config["rename_columns"] + ) assert not df_sql.empty, "fetch_data_using_sql returned an empty DataFrame!" assert set(df_sql.columns) == set( - rename_columns.values() + test_config["rename_columns"].values() ), "Column names do not match after renaming!" -def test_data_consistency(): +def test_data_consistency(test_config): """ Validate that the data fetched by fetch_data and fetch_data_using_sql are consistent. 
""" - df_api = fetch_data(resource_id, limit, columns, rename_columns) - df_sql = fetch_data_using_sql(sql_query, columns, rename_columns) + sql_query = ( + f'SELECT * FROM "{test_config["resource_id"]}" LIMIT {test_config["limit"]}' + ) + df_api = fetch_data( + test_config["resource_id"], + test_config["limit"], + test_config["columns"], + test_config["rename_columns"], + ) + df_sql = fetch_data_using_sql( + sql_query, test_config["columns"], test_config["rename_columns"] + ) assert df_api.equals( df_sql ), "Data from fetch_data and fetch_data_using_sql are inconsistent!" diff --git a/tests/test_format_forecast.py b/tests/test_format_forecast.py new file mode 100644 index 0000000..64dfdff --- /dev/null +++ b/tests/test_format_forecast.py @@ -0,0 +1,40 @@ +from neso_solar_consumer.fetch_data import fetch_data +from neso_solar_consumer.format_forecast import format_to_forecast_sql +from nowcasting_datamodel.models import ForecastSQL, ForecastValue + + +def test_format_to_forecast_sql_real(db_session, test_config): + """ + Test `format_to_forecast_sql` with real data fetched from the NESO API. + """ + # Step 1: Fetch data from the API + data = fetch_data( + test_config["resource_id"], + test_config["limit"], + test_config["columns"], + test_config["rename_columns"], + ) + assert not data.empty, "fetch_data returned an empty DataFrame!" + + # Step 2: Format the data into a ForecastSQL object + forecasts = format_to_forecast_sql( + data, test_config["model_name"], test_config["model_version"], db_session + ) + assert len(forecasts) == 1, f"Expected 1 ForecastSQL object, got {len(forecasts)}" + + # Step 3: Validate the ForecastSQL content + forecast = forecasts[0] + + # Filter out invalid rows from the DataFrame (like your function would) + valid_data = data.drop_duplicates( + subset=["start_utc", "end_utc", "solar_forecast_kw"] + ) + valid_data = valid_data[ + valid_data["start_utc"].apply( + lambda x: isinstance(x, str) and len(x) > 0 + ) # Valid datetime + ] + + assert len(forecast.forecast_values) == len( + valid_data + ), f"Mismatch in ForecastValue entries! Expected {len(valid_data)} but got {len(forecast.forecast_values)}."