Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Formatting and Saving Forecast Data with Tests #24

Draft
wants to merge 15 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 3 additions & 7 deletions neso_solar_consumer/fetch_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,7 @@
import pandas as pd


def fetch_data(
resource_id: str, limit: int, columns: list, rename_columns: dict
) -> pd.DataFrame:
def fetch_data(resource_id: str, limit: int, columns: list, rename_columns: dict) -> pd.DataFrame:
"""
Fetch data from the NESO API and process it into a Pandas DataFrame.

Expand Down Expand Up @@ -55,9 +53,7 @@ def fetch_data(
return pd.DataFrame()


def fetch_data_using_sql(
sql_query: str, columns: list, rename_columns: dict
) -> pd.DataFrame:
def fetch_data_using_sql(sql_query: str, columns: list, rename_columns: dict) -> pd.DataFrame:
"""
Fetch data from the NESO API using an SQL query, process it, and return specific columns with renamed headers.

Expand Down Expand Up @@ -98,4 +94,4 @@ def fetch_data_using_sql(

except Exception as e:
print(f"An error occurred: {e}")
return pd.DataFrame()
return pd.DataFrame()
71 changes: 71 additions & 0 deletions neso_solar_consumer/format_forecast.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import logging
from datetime import datetime, timezone
import pandas as pd
from nowcasting_datamodel.models import ForecastSQL, ForecastValue
from nowcasting_datamodel.read.read import get_latest_input_data_last_updated, get_location
from nowcasting_datamodel.read.read_models import get_model

# Configure logging (set to INFO for production; use DEBUG during debugging)
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger(__name__)


def format_to_forecast_sql(data: pd.DataFrame, model_tag: str, model_version: str, session) -> list:
"""
Convert NESO solar forecast data into a single ForecastSQL object.

Args:
data (pd.DataFrame): Input DataFrame with forecast data.
model_tag (str): The model name/tag.
model_version (str): The model version.
session: SQLAlchemy session for database access.

Returns:
list: A list containing one ForecastSQL object.
"""
logger.info("Starting format_to_forecast_sql process...")

# Step 1: Retrieve model metadata
model = get_model(name=model_tag, version=model_version, session=session)
logger.debug(f"Model Retrieved: {model}")

# Step 2: Fetch input data last updated timestamp
input_data_last_updated = get_latest_input_data_last_updated(session=session)
logger.debug(f"Input Data Last Updated: {input_data_last_updated}")

# Step 3: Fetch or create the location using get_location
location = get_location(session=session, gsp_id=0)
logger.debug(f"Location Retrieved or Created: {location}")

# Step 4: Process rows into ForecastValue objects
forecast_values = []
for _, row in data.iterrows():
if pd.isnull(row["start_utc"]) or pd.isnull(row["solar_forecast_kw"]):
logger.debug("Skipping row due to missing data")
continue

target_time = datetime.fromisoformat(row["start_utc"]).replace(tzinfo=timezone.utc)
forecast_value = ForecastValue(
target_time=target_time,
expected_power_generation_megawatts=row["solar_forecast_kw"],
).to_orm()
forecast_values.append(forecast_value)
logger.debug(f"Forecast Value Created: {forecast_value}")

if not forecast_values:
logger.warning("No valid forecast values found in the data. Exiting.")
return []

# Step 5: Create a single ForecastSQL object
forecast = ForecastSQL(
model=model,
forecast_creation_time=datetime.now(tz=timezone.utc),
location=location, # Directly using the location from get_location
input_data_last_updated=input_data_last_updated,
forecast_values=forecast_values,
historic=False,
)
logger.debug(f"ForecastSQL Object Created: {forecast}")

logger.info("ForecastSQL object successfully added to session and flushed.")
siddharth7113 marked this conversation as resolved.
Show resolved Hide resolved
return [forecast]
18 changes: 18 additions & 0 deletions neso_solar_consumer/save_forecasts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# save_forecasts.py

from nowcasting_datamodel.save.save import save


def save_forecasts_to_db(forecasts: list, session):
"""
Save a list of ForecastSQL objects to the database using the nowcasting_datamodel `save` function.

Parameters:
forecasts (list): The list of ForecastSQL objects to save.
session (Session): SQLAlchemy session for database access.
"""
save(
forecasts=forecasts,
session=session,
save_to_last_seven_days=True, # Save forecasts to the last seven days table
)
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ build-backend = "setuptools.build_meta"
name = "neso_solar_consumer"
version = "0.1"
dependencies = [
"pandas"
"pandas","sqlalchemy","nowcasting_datamodel"
]

[project.optional-dependencies]
dev = [
"pytest","black","ruff"
]
"pytest", "black", "ruff"
]
72 changes: 72 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import pytest
from typing import Generator
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from nowcasting_datamodel.models.base import Base_Forecast
from nowcasting_datamodel.models import MLModelSQL

# Shared Test Configuration Constants
TEST_DB_URL = "postgresql://postgres:12345@localhost/testdb"
RESOURCE_ID = "db6c038f-98af-4570-ab60-24d71ebd0ae5"
LIMIT = 5
COLUMNS = ["DATE_GMT", "TIME_GMT", "EMBEDDED_SOLAR_FORECAST"]
RENAME_COLUMNS = {
"DATE_GMT": "start_utc",
"TIME_GMT": "end_utc",
"EMBEDDED_SOLAR_FORECAST": "solar_forecast_kw",
}
MODEL_NAME = "real_data_model"
MODEL_VERSION = "1.0"


@pytest.fixture(scope="function")
def db_session() -> Generator:
"""
Fixture to set up and tear down a PostgreSQL database session for testing.

This fixture:
- Creates a fresh database schema before each test.
- Adds a dummy ML model for test purposes.
- Tears down the database session and cleans up resources after each test.

Returns:
Generator: A SQLAlchemy session object.
"""
# Create database engine and tables
engine = create_engine(TEST_DB_URL)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you should still do with PostgresContainer("postgres:15.5") as postgres: see other examples here https://github.com/openclimatefix/pv-site-datamodel/blob/main/tests/conftest.py#L26

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @peterdudfield,I’m still looking into this and will make the necessary changes. I’m sorry if it’s taking too long—would it be okay to give me a few more days to finalize it? I’ll update you as soon as it’s done.

Base_Forecast.metadata.drop_all(engine) # Drop all tables to ensure a clean slate
Base_Forecast.metadata.create_all(engine) # Recreate the tables

# Establish session
Session = sessionmaker(bind=engine)
session = Session()

# Insert a dummy model for testing
session.query(MLModelSQL).delete() # Clean up any pre-existing data
model = MLModelSQL(name=MODEL_NAME, version=MODEL_VERSION)
session.add(model)
session.commit()

yield session # Provide the session to the test

# Cleanup: close session and dispose of engine
session.close()
engine.dispose()


@pytest.fixture(scope="session")
def test_config():
"""
Fixture to provide shared test configuration constants.

Returns:
dict: A dictionary of test configuration values.
"""
return {
"resource_id": RESOURCE_ID,
"limit": LIMIT,
"columns": COLUMNS,
"rename_columns": RENAME_COLUMNS,
"model_name": MODEL_NAME,
"model_version": MODEL_VERSION,
}
47 changes: 22 additions & 25 deletions tests/test_fetch_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,50 +25,47 @@
pytest tests/test_fetch_data.py -k "fetch_data"

"""

import pytest
from neso_solar_consumer.fetch_data import fetch_data, fetch_data_using_sql


resource_id = "db6c038f-98af-4570-ab60-24d71ebd0ae5"
limit = 5
columns = ["DATE_GMT", "TIME_GMT", "EMBEDDED_SOLAR_FORECAST"]
rename_columns = {
"DATE_GMT": "start_utc",
"TIME_GMT": "end_utc",
"EMBEDDED_SOLAR_FORECAST": "solar_forecast_kw",
}
sql_query = f'SELECT * from "{resource_id}" LIMIT {limit}'


def test_fetch_data_api():
def test_fetch_data_api(test_config):
"""
Test the fetch_data function to ensure it fetches and processes data correctly via API.
"""
df_api = fetch_data(resource_id, limit, columns, rename_columns)
df_api = fetch_data(
test_config["resource_id"],
test_config["limit"],
test_config["columns"],
test_config["rename_columns"],
)
assert not df_api.empty, "fetch_data returned an empty DataFrame!"
assert set(df_api.columns) == set(
rename_columns.values()
test_config["rename_columns"].values()
), "Column names do not match after renaming!"


def test_fetch_data_sql():
def test_fetch_data_sql(test_config):
"""
Test the fetch_data_using_sql function to ensure it fetches and processes data correctly via SQL.
"""
df_sql = fetch_data_using_sql(sql_query, columns, rename_columns)
sql_query = f'SELECT * FROM "{test_config["resource_id"]}" LIMIT {test_config["limit"]}'
df_sql = fetch_data_using_sql(sql_query, test_config["columns"], test_config["rename_columns"])
assert not df_sql.empty, "fetch_data_using_sql returned an empty DataFrame!"
assert set(df_sql.columns) == set(
rename_columns.values()
test_config["rename_columns"].values()
), "Column names do not match after renaming!"


def test_data_consistency():
def test_data_consistency(test_config):
"""
Validate that the data fetched by fetch_data and fetch_data_using_sql are consistent.
"""
df_api = fetch_data(resource_id, limit, columns, rename_columns)
df_sql = fetch_data_using_sql(sql_query, columns, rename_columns)
assert df_api.equals(
df_sql
), "Data from fetch_data and fetch_data_using_sql are inconsistent!"
sql_query = f'SELECT * FROM "{test_config["resource_id"]}" LIMIT {test_config["limit"]}'
df_api = fetch_data(
test_config["resource_id"],
test_config["limit"],
test_config["columns"],
test_config["rename_columns"],
)
df_sql = fetch_data_using_sql(sql_query, test_config["columns"], test_config["rename_columns"])
assert df_api.equals(df_sql), "Data from fetch_data and fetch_data_using_sql are inconsistent!"
27 changes: 27 additions & 0 deletions tests/test_format_forecast.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from neso_solar_consumer.fetch_data import fetch_data
from neso_solar_consumer.format_forecast import format_to_forecast_sql
from nowcasting_datamodel.models import ForecastSQL, ForecastValue


def test_format_to_forecast_sql_real(db_session, test_config):
"""
Test `format_to_forecast_sql` with real data fetched from the NESO API.
"""
# Step 1: Fetch mock data from the API
data = fetch_data(
test_config["resource_id"],
test_config["limit"],
test_config["columns"],
test_config["rename_columns"],
)
assert not data.empty, "fetch_data returned an empty DataFrame!"

# Step 2: Format the data into a ForecastSQL object
forecasts = format_to_forecast_sql(
data, test_config["model_name"], test_config["model_version"], db_session
)
assert len(forecasts) == 1, "More than one ForecastSQL object was created!"

# Step 3: Validate ForecastSQL content
forecast = forecasts[0]
assert len(forecast.forecast_values) == len(data), "Mismatch in ForecastValue entries!"
51 changes: 51 additions & 0 deletions tests/test_save_forecasts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""
Integration Test for Fetching, Formatting, and Saving Forecast Data

This script validates the integration of fetching real data, formatting it into ForecastSQL objects,
and saving it to the database.
"""

import pytest
from nowcasting_datamodel.models import ForecastSQL
from neso_solar_consumer.fetch_data import fetch_data
from neso_solar_consumer.format_forecast import format_to_forecast_sql
from neso_solar_consumer.save_forecasts import save_forecasts_to_db


def test_save_real_forecasts(db_session, test_config):
"""
Integration test: Fetch real data, format it into forecasts, and save to the database.

Steps:
1. Fetch real data from the NESO API.
2. Format the data into ForecastSQL objects.
3. Save the forecasts to the database.
4. Verify that the forecasts are correctly saved.
"""
# Step 1: Fetch real data
df = fetch_data(
test_config["resource_id"],
test_config["limit"],
test_config["columns"],
test_config["rename_columns"],
)
assert not df.empty, "fetch_data returned an empty DataFrame!"

# Step 2: Format data into ForecastSQL objects
forecasts = format_to_forecast_sql(
df, test_config["model_name"], test_config["model_version"], db_session
)
assert forecasts, "No forecasts were generated from the fetched data!"

# Step 3: Save forecasts to the database
save_forecasts_to_db(forecasts, db_session)

# Step 4: Verify forecasts are saved in the database
saved_forecast = db_session.query(ForecastSQL).first()
assert saved_forecast is not None, "No forecast was saved to the database!"
assert saved_forecast.model.name == test_config["model_name"], "Model name does not match!"
assert len(saved_forecast.forecast_values) > 0, "No forecast values were saved!"

# Debugging Output (Optional)
print("Forecast successfully saved to the database.")
print(f"Number of forecast values: {len(saved_forecast.forecast_values)}")
Loading