diff --git a/neso_solar_consumer/app.py b/neso_solar_consumer/app.py index b0a7761..cdb2ce2 100644 --- a/neso_solar_consumer/app.py +++ b/neso_solar_consumer/app.py @@ -30,7 +30,11 @@ def get_forecast(): """ resource_id = "example_resource_id" # Replace with the actual resource ID. limit = 100 # Number of records to fetch. - columns = ["DATE_GMT", "TIME_GMT", "EMBEDDED_SOLAR_FORECAST"] # Relevant columns to extract. + columns = [ + "DATE_GMT", + "TIME_GMT", + "EMBEDDED_SOLAR_FORECAST", + ] # Relevant columns to extract. rename_columns = { "DATE_GMT": "start_utc", "TIME_GMT": "end_utc", diff --git a/neso_solar_consumer/fetch_data.py b/neso_solar_consumer/fetch_data.py index 0267af8..c2dd8b0 100644 --- a/neso_solar_consumer/fetch_data.py +++ b/neso_solar_consumer/fetch_data.py @@ -98,4 +98,4 @@ def fetch_data_using_sql( except Exception as e: print(f"An error occurred: {e}") - return pd.DataFrame() \ No newline at end of file + return pd.DataFrame() diff --git a/neso_solar_consumer/format_forecast.py b/neso_solar_consumer/format_forecast.py index f1348a9..64c6fda 100644 --- a/neso_solar_consumer/format_forecast.py +++ b/neso_solar_consumer/format_forecast.py @@ -2,15 +2,22 @@ from datetime import datetime, timezone import pandas as pd from nowcasting_datamodel.models import ForecastSQL, ForecastValue -from nowcasting_datamodel.read.read import get_latest_input_data_last_updated, get_location +from nowcasting_datamodel.read.read import ( + get_latest_input_data_last_updated, + get_location, +) from nowcasting_datamodel.read.read_models import get_model # Configure logging (set to INFO for production; use DEBUG during debugging) -logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s") +logging.basicConfig( + level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s" +) logger = logging.getLogger(__name__) -def format_to_forecast_sql(data: pd.DataFrame, model_tag: str, model_version: str, session) -> list: +def format_to_forecast_sql( + data: pd.DataFrame, model_tag: str, model_version: str, session +) -> list: logger.info("Starting format_to_forecast_sql process...") # Step 1: Retrieve model metadata @@ -28,14 +35,19 @@ def format_to_forecast_sql(data: pd.DataFrame, model_tag: str, model_version: st continue try: - target_time = datetime.fromisoformat(row["start_utc"]).replace(tzinfo=timezone.utc) + target_time = datetime.fromisoformat(row["start_utc"]).replace( + tzinfo=timezone.utc + ) except ValueError: - logger.warning(f"Invalid datetime format: {row['start_utc']}. Skipping row.") + logger.warning( + f"Invalid datetime format: {row['start_utc']}. Skipping row." + ) continue forecast_value = ForecastValue( target_time=target_time, - expected_power_generation_megawatts=row["solar_forecast_kw"] / 1000, # Convert to MW + expected_power_generation_megawatts=row["solar_forecast_kw"] + / 1000, # Convert to MW ).to_orm() forecast_values.append(forecast_value) diff --git a/tests/conftest.py b/tests/conftest.py index 1ba90d1..eef84f1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -81,4 +81,4 @@ def test_config(): "rename_columns": RENAME_COLUMNS, "model_name": MODEL_NAME, "model_version": MODEL_VERSION, - } \ No newline at end of file + } diff --git a/tests/test_fetch_data.py b/tests/test_fetch_data.py index e641a7d..2aa6edf 100644 --- a/tests/test_fetch_data.py +++ b/tests/test_fetch_data.py @@ -25,6 +25,7 @@ pytest tests/test_fetch_data.py -k "fetch_data" """ + from neso_solar_consumer.fetch_data import fetch_data, fetch_data_using_sql @@ -48,8 +49,12 @@ def test_fetch_data_sql(test_config): """ Test the fetch_data_using_sql function to ensure it fetches and processes data correctly via SQL. """ - sql_query = f'SELECT * FROM "{test_config["resource_id"]}" LIMIT {test_config["limit"]}' - df_sql = fetch_data_using_sql(sql_query, test_config["columns"], test_config["rename_columns"]) + sql_query = ( + f'SELECT * FROM "{test_config["resource_id"]}" LIMIT {test_config["limit"]}' + ) + df_sql = fetch_data_using_sql( + sql_query, test_config["columns"], test_config["rename_columns"] + ) assert not df_sql.empty, "fetch_data_using_sql returned an empty DataFrame!" assert set(df_sql.columns) == set( test_config["rename_columns"].values() @@ -60,12 +65,18 @@ def test_data_consistency(test_config): """ Validate that the data fetched by fetch_data and fetch_data_using_sql are consistent. """ - sql_query = f'SELECT * FROM "{test_config["resource_id"]}" LIMIT {test_config["limit"]}' + sql_query = ( + f'SELECT * FROM "{test_config["resource_id"]}" LIMIT {test_config["limit"]}' + ) df_api = fetch_data( test_config["resource_id"], test_config["limit"], test_config["columns"], test_config["rename_columns"], ) - df_sql = fetch_data_using_sql(sql_query, test_config["columns"], test_config["rename_columns"]) - assert df_api.equals(df_sql), "Data from fetch_data and fetch_data_using_sql are inconsistent!" \ No newline at end of file + df_sql = fetch_data_using_sql( + sql_query, test_config["columns"], test_config["rename_columns"] + ) + assert df_api.equals( + df_sql + ), "Data from fetch_data and fetch_data_using_sql are inconsistent!" diff --git a/tests/test_format_forecast.py b/tests/test_format_forecast.py index 1a55b99..64dfdff 100644 --- a/tests/test_format_forecast.py +++ b/tests/test_format_forecast.py @@ -26,11 +26,15 @@ def test_format_to_forecast_sql_real(db_session, test_config): forecast = forecasts[0] # Filter out invalid rows from the DataFrame (like your function would) - valid_data = data.drop_duplicates(subset=["start_utc", "end_utc", "solar_forecast_kw"]) + valid_data = data.drop_duplicates( + subset=["start_utc", "end_utc", "solar_forecast_kw"] + ) valid_data = valid_data[ - valid_data["start_utc"].apply(lambda x: isinstance(x, str) and len(x) > 0) # Valid datetime + valid_data["start_utc"].apply( + lambda x: isinstance(x, str) and len(x) > 0 + ) # Valid datetime ] - assert len(forecast.forecast_values) == len(valid_data), ( - f"Mismatch in ForecastValue entries! Expected {len(valid_data)} but got {len(forecast.forecast_values)}." - ) + assert len(forecast.forecast_values) == len( + valid_data + ), f"Mismatch in ForecastValue entries! Expected {len(valid_data)} but got {len(forecast.forecast_values)}."