Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Formatting Data with Tests #25

Merged
merged 7 commits into from
Jan 3, 2025
6 changes: 5 additions & 1 deletion neso_solar_consumer/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,11 @@ def get_forecast():
"""
resource_id = "example_resource_id" # Replace with the actual resource ID.
limit = 100 # Number of records to fetch.
columns = ["DATE_GMT", "TIME_GMT", "EMBEDDED_SOLAR_FORECAST"] # Relevant columns to extract.
columns = [
"DATE_GMT",
"TIME_GMT",
"EMBEDDED_SOLAR_FORECAST",
] # Relevant columns to extract.
rename_columns = {
"DATE_GMT": "start_utc",
"TIME_GMT": "end_utc",
Expand Down
2 changes: 1 addition & 1 deletion neso_solar_consumer/fetch_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,4 +98,4 @@ def fetch_data_using_sql(

except Exception as e:
print(f"An error occurred: {e}")
return pd.DataFrame()
return pd.DataFrame()
66 changes: 66 additions & 0 deletions neso_solar_consumer/format_forecast.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import logging
from datetime import datetime, timezone
import pandas as pd
from nowcasting_datamodel.models import ForecastSQL, ForecastValue
from nowcasting_datamodel.read.read import (
    get_latest_input_data_last_updated,
    get_location,
)
from nowcasting_datamodel.read.read_models import get_model

# Configure logging (set to INFO for production; use DEBUG during debugging)
# NOTE(review): basicConfig at import time configures the process-wide root
# logger; consider leaving configuration to the application entry point.
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s"
)
logger = logging.getLogger(__name__)

def format_to_forecast_sql(
    data: pd.DataFrame, model_tag: str, model_version: str, session
) -> list:
    """
    Convert a DataFrame of NESO solar forecasts into a single ForecastSQL object.

    Rows with a missing timestamp/value, or a timestamp that fails ISO-8601
    parsing, are skipped with a warning; all remaining rows become
    ForecastValue entries on one ForecastSQL.

    Args:
        data: DataFrame with "start_utc" (ISO-8601 string) and
            "solar_forecast_kw" (numeric, kilowatts) columns.
        model_tag: Name of the ML model to look up in the database.
        model_version: Version of the ML model to look up.
        session: SQLAlchemy session bound to the forecast database.

    Returns:
        list: A single-element list containing the built ForecastSQL object.
    """
    logger.info("Starting format_to_forecast_sql process...")

    # Step 1: Retrieve model metadata
    model = get_model(name=model_tag, version=model_version, session=session)
    input_data_last_updated = get_latest_input_data_last_updated(session=session)

    # Step 2: Fetch or create the location (gsp_id=0 -> national forecast)
    location = get_location(session=session, gsp_id=0)

    # Step 3: Process all rows into ForecastValue objects
    forecast_values = []
    for _, row in data.iterrows():
        # Skip rows that are missing either the timestamp or the value.
        if pd.isnull(row["start_utc"]) or pd.isnull(row["solar_forecast_kw"]):
            logger.warning("Skipping row due to missing data: %s", row)
            continue

        try:
            # NOTE(review): assumes start_utc is a naive ISO-8601 string in
            # UTC; .replace() would clobber any offset already present —
            # TODO confirm upstream format.
            target_time = datetime.fromisoformat(row["start_utc"]).replace(
                tzinfo=timezone.utc
            )
        except ValueError:
            logger.warning(
                "Invalid datetime format: %s. Skipping row.", row["start_utc"]
            )
            continue

        forecast_value = ForecastValue(
            target_time=target_time,
            # NESO publishes kW; the datamodel stores MW.
            expected_power_generation_megawatts=row["solar_forecast_kw"]
            / 1000,
        ).to_orm()
        forecast_values.append(forecast_value)

    # Step 4: Create a single ForecastSQL object holding all values
    forecast = ForecastSQL(
        model=model,
        forecast_creation_time=datetime.now(tz=timezone.utc),
        location=location,
        input_data_last_updated=input_data_last_updated,
        forecast_values=forecast_values,
        historic=False,
    )
    logger.info("Created ForecastSQL object with %d values.", len(forecast_values))

    # Return a single ForecastSQL object in a list
    return [forecast]
9 changes: 7 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,15 @@ build-backend = "setuptools.build_meta"
name = "neso_solar_consumer"
version = "0.1"
dependencies = [
"pandas"
"pandas",
"sqlalchemy",
"nowcasting_datamodel==1.5.56",
"testcontainers"
]

[project.optional-dependencies]
dev = [
"pytest","black","ruff"
"pytest",
"black",
"ruff"
]
84 changes: 84 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import pytest
from typing import Generator
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from nowcasting_datamodel.models.base import Base_Forecast
from nowcasting_datamodel.models import MLModelSQL
from testcontainers.postgres import PostgresContainer

# Shared Test Configuration Constants
# NESO CKAN dataset resource holding the embedded solar forecast data.
RESOURCE_ID = "db6c038f-98af-4570-ab60-24d71ebd0ae5"
# Number of records fetched per test (kept small so tests stay fast).
LIMIT = 5
# Raw NESO column names to extract from the dataset.
COLUMNS = ["DATE_GMT", "TIME_GMT", "EMBEDDED_SOLAR_FORECAST"]
# Mapping from raw NESO column names to the names used internally.
RENAME_COLUMNS = {
    "DATE_GMT": "start_utc",
    "TIME_GMT": "end_utc",
    "EMBEDDED_SOLAR_FORECAST": "solar_forecast_kw",
}
# Dummy ML model identity inserted into the test database (see db_session).
MODEL_NAME = "real_data_model"
MODEL_VERSION = "1.0"


@pytest.fixture(scope="session")
def postgres_container():
    """
    Fixture to spin up a PostgreSQL container for the entire test session.

    Uses `testcontainers` to start a fresh PostgreSQL container and yields
    its connection URL for use in other fixtures. The container is stopped
    automatically when the session ends.

    Yields:
        str: SQLAlchemy-compatible connection URL for the running container.
    """
    # Entering the context manager starts the container; an explicit
    # postgres.start() call here would be redundant.
    with PostgresContainer("postgres:15.5") as postgres:
        yield postgres.get_connection_url()


@pytest.fixture(scope="function")
def db_session(postgres_container) -> Generator:
    """
    Fixture to set up and tear down a PostgreSQL database session for testing.

    This fixture:
    - Connects to the PostgreSQL container provided by `postgres_container`.
    - Drops and recreates the schema so every test starts from a clean slate.
    - Inserts a dummy ML model row for test purposes.
    - Closes the session and disposes of the engine after each test.

    Args:
        postgres_container (str): The dynamic connection URL provided by
            PostgresContainer.

    Yields:
        Session: A SQLAlchemy session object bound to the fresh schema.
    """
    # Use the dynamic connection URL
    engine = create_engine(postgres_container)
    Base_Forecast.metadata.drop_all(engine)  # Drop all tables to ensure a clean slate
    Base_Forecast.metadata.create_all(engine)  # Recreate the tables

    # Establish session
    Session = sessionmaker(bind=engine)
    session = Session()

    # Insert a dummy model for testing. The tables were just recreated, so
    # no pre-existing rows need deleting first.
    model = MLModelSQL(name=MODEL_NAME, version=MODEL_VERSION)
    session.add(model)
    session.commit()

    yield session  # Provide the session to the test

    # Cleanup: close session and dispose of engine
    session.close()
    engine.dispose()


@pytest.fixture(scope="session")
def test_config():
    """
    Fixture exposing the shared test configuration constants as a dict.

    Returns:
        dict: resource id, record limit, column lists/renames, and model
        metadata used across the test suite.
    """
    keys = (
        "resource_id",
        "limit",
        "columns",
        "rename_columns",
        "model_name",
        "model_version",
    )
    values = (
        RESOURCE_ID,
        LIMIT,
        COLUMNS,
        RENAME_COLUMNS,
        MODEL_NAME,
        MODEL_VERSION,
    )
    return dict(zip(keys, values))
50 changes: 29 additions & 21 deletions tests/test_fetch_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,49 +26,57 @@

"""

import pytest
from neso_solar_consumer.fetch_data import fetch_data, fetch_data_using_sql


resource_id = "db6c038f-98af-4570-ab60-24d71ebd0ae5"
limit = 5
columns = ["DATE_GMT", "TIME_GMT", "EMBEDDED_SOLAR_FORECAST"]
rename_columns = {
"DATE_GMT": "start_utc",
"TIME_GMT": "end_utc",
"EMBEDDED_SOLAR_FORECAST": "solar_forecast_kw",
}
sql_query = f'SELECT * from "{resource_id}" LIMIT {limit}'


def test_fetch_data_api():
def test_fetch_data_api(test_config):
"""
Test the fetch_data function to ensure it fetches and processes data correctly via API.
"""
df_api = fetch_data(resource_id, limit, columns, rename_columns)
df_api = fetch_data(
test_config["resource_id"],
test_config["limit"],
test_config["columns"],
test_config["rename_columns"],
)
assert not df_api.empty, "fetch_data returned an empty DataFrame!"
assert set(df_api.columns) == set(
rename_columns.values()
test_config["rename_columns"].values()
), "Column names do not match after renaming!"


def test_fetch_data_sql():
def test_fetch_data_sql(test_config):
"""
Test the fetch_data_using_sql function to ensure it fetches and processes data correctly via SQL.
"""
df_sql = fetch_data_using_sql(sql_query, columns, rename_columns)
sql_query = (
f'SELECT * FROM "{test_config["resource_id"]}" LIMIT {test_config["limit"]}'
)
df_sql = fetch_data_using_sql(
sql_query, test_config["columns"], test_config["rename_columns"]
)
assert not df_sql.empty, "fetch_data_using_sql returned an empty DataFrame!"
assert set(df_sql.columns) == set(
rename_columns.values()
test_config["rename_columns"].values()
), "Column names do not match after renaming!"


def test_data_consistency():
def test_data_consistency(test_config):
"""
Validate that the data fetched by fetch_data and fetch_data_using_sql are consistent.
"""
df_api = fetch_data(resource_id, limit, columns, rename_columns)
df_sql = fetch_data_using_sql(sql_query, columns, rename_columns)
sql_query = (
f'SELECT * FROM "{test_config["resource_id"]}" LIMIT {test_config["limit"]}'
)
df_api = fetch_data(
test_config["resource_id"],
test_config["limit"],
test_config["columns"],
test_config["rename_columns"],
)
df_sql = fetch_data_using_sql(
sql_query, test_config["columns"], test_config["rename_columns"]
)
assert df_api.equals(
df_sql
), "Data from fetch_data and fetch_data_using_sql are inconsistent!"
40 changes: 40 additions & 0 deletions tests/test_format_forecast.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from neso_solar_consumer.fetch_data import fetch_data
from neso_solar_consumer.format_forecast import format_to_forecast_sql
from nowcasting_datamodel.models import ForecastSQL, ForecastValue


def test_format_to_forecast_sql_real(db_session, test_config):
    """
    Test `format_to_forecast_sql` with real data fetched from the NESO API.

    The expected ForecastValue count is recomputed with the same skip rules
    the formatter applies (missing values / unparseable timestamps), rather
    than de-duplicating rows — the formatter never drops duplicates, so a
    dedupe-based expectation could diverge from its actual output.
    """
    # Local imports keep this helper self-contained within the test module.
    from datetime import datetime

    import pandas as pd

    # Step 1: Fetch data from the API
    data = fetch_data(
        test_config["resource_id"],
        test_config["limit"],
        test_config["columns"],
        test_config["rename_columns"],
    )
    assert not data.empty, "fetch_data returned an empty DataFrame!"

    # Step 2: Format the data into a ForecastSQL object
    forecasts = format_to_forecast_sql(
        data, test_config["model_name"], test_config["model_version"], db_session
    )
    assert len(forecasts) == 1, f"Expected 1 ForecastSQL object, got {len(forecasts)}"

    # Step 3: Validate the ForecastSQL content
    forecast = forecasts[0]

    def _is_valid(row) -> bool:
        # Mirror format_to_forecast_sql: skip nulls and non-ISO timestamps.
        if pd.isnull(row["start_utc"]) or pd.isnull(row["solar_forecast_kw"]):
            return False
        try:
            datetime.fromisoformat(row["start_utc"])
        except ValueError:
            return False
        return True

    expected_count = int(data.apply(_is_valid, axis=1).sum())

    assert len(forecast.forecast_values) == expected_count, (
        f"Mismatch in ForecastValue entries! Expected {expected_count} "
        f"but got {len(forecast.forecast_values)}."
    )
Loading