Skip to content

Commit

Permalink
style: apply Black formatting to ensure consistent code style
Browse files Browse the repository at this point in the history
Ran Black on the codebase to standardize formatting for this PR. No functional changes were made.
  • Loading branch information
siddharth7113 committed Jan 2, 2025
1 parent c9bab47 commit f4ef298
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 19 deletions.
6 changes: 5 additions & 1 deletion neso_solar_consumer/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,11 @@ def get_forecast():
"""
resource_id = "example_resource_id" # Replace with the actual resource ID.
limit = 100 # Number of records to fetch.
columns = ["DATE_GMT", "TIME_GMT", "EMBEDDED_SOLAR_FORECAST"] # Relevant columns to extract.
columns = [
"DATE_GMT",
"TIME_GMT",
"EMBEDDED_SOLAR_FORECAST",
] # Relevant columns to extract.
rename_columns = {
"DATE_GMT": "start_utc",
"TIME_GMT": "end_utc",
Expand Down
2 changes: 1 addition & 1 deletion neso_solar_consumer/fetch_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,4 +98,4 @@ def fetch_data_using_sql(

except Exception as e:
print(f"An error occurred: {e}")
return pd.DataFrame()
return pd.DataFrame()
24 changes: 18 additions & 6 deletions neso_solar_consumer/format_forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,22 @@
from datetime import datetime, timezone
import pandas as pd
from nowcasting_datamodel.models import ForecastSQL, ForecastValue
from nowcasting_datamodel.read.read import get_latest_input_data_last_updated, get_location
from nowcasting_datamodel.read.read import (
get_latest_input_data_last_updated,
get_location,
)
from nowcasting_datamodel.read.read_models import get_model

# Configure logging (set to INFO for production; use DEBUG during debugging)
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logging.basicConfig(
level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s"
)
logger = logging.getLogger(__name__)


def format_to_forecast_sql(data: pd.DataFrame, model_tag: str, model_version: str, session) -> list:
def format_to_forecast_sql(
data: pd.DataFrame, model_tag: str, model_version: str, session
) -> list:
logger.info("Starting format_to_forecast_sql process...")

# Step 1: Retrieve model metadata
Expand All @@ -28,14 +35,19 @@ def format_to_forecast_sql(data: pd.DataFrame, model_tag: str, model_version: st
continue

try:
target_time = datetime.fromisoformat(row["start_utc"]).replace(tzinfo=timezone.utc)
target_time = datetime.fromisoformat(row["start_utc"]).replace(
tzinfo=timezone.utc
)
except ValueError:
logger.warning(f"Invalid datetime format: {row['start_utc']}. Skipping row.")
logger.warning(
f"Invalid datetime format: {row['start_utc']}. Skipping row."
)
continue

forecast_value = ForecastValue(
target_time=target_time,
expected_power_generation_megawatts=row["solar_forecast_kw"] / 1000, # Convert to MW
expected_power_generation_megawatts=row["solar_forecast_kw"]
/ 1000, # Convert to MW
).to_orm()
forecast_values.append(forecast_value)

Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,4 +81,4 @@ def test_config():
"rename_columns": RENAME_COLUMNS,
"model_name": MODEL_NAME,
"model_version": MODEL_VERSION,
}
}
21 changes: 16 additions & 5 deletions tests/test_fetch_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
pytest tests/test_fetch_data.py -k "fetch_data"
"""

from neso_solar_consumer.fetch_data import fetch_data, fetch_data_using_sql


Expand All @@ -48,8 +49,12 @@ def test_fetch_data_sql(test_config):
"""
Test the fetch_data_using_sql function to ensure it fetches and processes data correctly via SQL.
"""
sql_query = f'SELECT * FROM "{test_config["resource_id"]}" LIMIT {test_config["limit"]}'
df_sql = fetch_data_using_sql(sql_query, test_config["columns"], test_config["rename_columns"])
sql_query = (
f'SELECT * FROM "{test_config["resource_id"]}" LIMIT {test_config["limit"]}'
)
df_sql = fetch_data_using_sql(
sql_query, test_config["columns"], test_config["rename_columns"]
)
assert not df_sql.empty, "fetch_data_using_sql returned an empty DataFrame!"
assert set(df_sql.columns) == set(
test_config["rename_columns"].values()
Expand All @@ -60,12 +65,18 @@ def test_data_consistency(test_config):
"""
Validate that the data fetched by fetch_data and fetch_data_using_sql are consistent.
"""
sql_query = f'SELECT * FROM "{test_config["resource_id"]}" LIMIT {test_config["limit"]}'
sql_query = (
f'SELECT * FROM "{test_config["resource_id"]}" LIMIT {test_config["limit"]}'
)
df_api = fetch_data(
test_config["resource_id"],
test_config["limit"],
test_config["columns"],
test_config["rename_columns"],
)
df_sql = fetch_data_using_sql(sql_query, test_config["columns"], test_config["rename_columns"])
assert df_api.equals(df_sql), "Data from fetch_data and fetch_data_using_sql are inconsistent!"
df_sql = fetch_data_using_sql(
sql_query, test_config["columns"], test_config["rename_columns"]
)
assert df_api.equals(
df_sql
), "Data from fetch_data and fetch_data_using_sql are inconsistent!"
14 changes: 9 additions & 5 deletions tests/test_format_forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,15 @@ def test_format_to_forecast_sql_real(db_session, test_config):
forecast = forecasts[0]

# Filter out invalid rows from the DataFrame (like your function would)
valid_data = data.drop_duplicates(subset=["start_utc", "end_utc", "solar_forecast_kw"])
valid_data = data.drop_duplicates(
subset=["start_utc", "end_utc", "solar_forecast_kw"]
)
valid_data = valid_data[
valid_data["start_utc"].apply(lambda x: isinstance(x, str) and len(x) > 0) # Valid datetime
valid_data["start_utc"].apply(
lambda x: isinstance(x, str) and len(x) > 0
) # Valid datetime
]

assert len(forecast.forecast_values) == len(valid_data), (
f"Mismatch in ForecastValue entries! Expected {len(valid_data)} but got {len(forecast.forecast_values)}."
)
assert len(forecast.forecast_values) == len(
valid_data
), f"Mismatch in ForecastValue entries! Expected {len(valid_data)} but got {len(forecast.forecast_values)}."

0 comments on commit f4ef298

Please sign in to comment.