diff --git a/rules-engine/src/rules_engine/parser.py b/rules-engine/src/rules_engine/parser.py index cb75a647..006b8a27 100644 --- a/rules-engine/src/rules_engine/parser.py +++ b/rules-engine/src/rules_engine/parser.py @@ -123,7 +123,7 @@ def _parse_gas_bill_eversource(data: str) -> NaturalGasBillingInput: records = [] for row in reader: parsed_row = _GasBillRowEversource(row) - period_end_date = datetime.strptime(parsed_row.read_date, "%m/%d/%Y").date() + period_end_date = datetime.strptime(parsed_row.read_date, "%m/%d/%Y") # Calculate period_start_date using the end date and number of days in the bill # Care should be taken here to avoid off-by-one errors period_start_date = period_end_date - timedelta( @@ -166,8 +166,8 @@ def _parse_gas_bill_national_grid(data: str) -> NaturalGasBillingInput: for row in reader: parsed_row = _GasBillRowNationalGrid(row) - period_start_date = datetime.strptime(parsed_row.start_date, "%m/%d/%Y").date() - period_end_date = datetime.strptime(parsed_row.end_date, "%m/%d/%Y").date() + period_start_date = datetime.strptime(parsed_row.start_date, "%m/%d/%Y") + period_end_date = datetime.strptime(parsed_row.end_date, "%m/%d/%Y") record = NaturalGasBillingRecordInput( period_start_date=period_start_date, diff --git a/rules-engine/src/rules_engine/pydantic_models.py b/rules-engine/src/rules_engine/pydantic_models.py index 1677b3cf..6c98921f 100644 --- a/rules-engine/src/rules_engine/pydantic_models.py +++ b/rules-engine/src/rules_engine/pydantic_models.py @@ -3,7 +3,7 @@ """ from dataclasses import dataclass -from datetime import date +from datetime import date, datetime from enum import Enum from functools import cached_property from typing import Annotated, Any, Literal, Optional, Sequence @@ -82,7 +82,7 @@ class DhwInput(BaseModel): class OilPropaneBillingRecordInput(BaseModel): """From Oil-Propane tab""" - period_end_date: date = Field(description="Oil-Propane!B") + period_end_date: datetime = Field(description="Oil-Propane!B") gallons: float = Field(description="Oil-Propane!C") inclusion_override: Optional[bool] = Field(description="Oil-Propane!F") @@ -91,14 +91,14 @@ class OilPropaneBillingInput(BaseModel): """From Oil-Propane tab. Container for holding all rows of the billing input table.""" records: Sequence[OilPropaneBillingRecordInput] - preceding_delivery_date: date = Field(description="Oil-Propane!B6") + preceding_delivery_date: datetime = Field(description="Oil-Propane!B6") class NaturalGasBillingRecordInput(BaseModel): """From Natural Gas tab. A single row of the Billing input table.""" - period_start_date: date = Field(description="Natural Gas!A") - period_end_date: date = Field(description="Natural Gas!B") + period_start_date: datetime = Field(description="Natural Gas!A") + period_end_date: datetime = Field(description="Natural Gas!B") usage_therms: float = Field(description="Natural Gas!D") inclusion_override: Optional[bool] = Field(description="Natural Gas!E") @@ -111,7 +111,7 @@ class NaturalGasBillingInput(BaseModel): # Suppress mypy error when computed_field is used with cached_property; see https://github.com/python/mypy/issues/1362 @computed_field # type: ignore[misc] @cached_property - def overall_start_date(self) -> date: + def overall_start_date(self) -> datetime: if len(self.records) == 0: raise ValueError( "Natural gas billing records cannot be empty." @@ -119,7 +119,7 @@ def overall_start_date(self) -> date: + "Try again with non-empty natural gas billing records." ) - min_date = date.max + min_date = datetime.max for record in self.records: min_date = min(min_date, record.period_start_date) return min_date @@ -127,7 +127,7 @@ def overall_start_date(self) -> date: # Suppress mypy error when computed_field is used with cached_property; see https://github.com/python/mypy/issues/1362 @computed_field # type: ignore[misc] @cached_property - def overall_end_date(self) -> date: + def overall_end_date(self) -> datetime: if len(self.records) == 0: raise ValueError( "Natural gas billing records cannot be empty." @@ -135,7 +135,7 @@ def overall_end_date(self) -> date: + "Try again with non-empty natural gas billing records." ) - max_date = date.min + max_date = datetime.min for record in self.records: max_date = max(max_date, record.period_end_date) return max_date @@ -150,8 +150,8 @@ class NormalizedBillingPeriodRecordBase(BaseModel): model_config = ConfigDict(validate_assignment=True) - period_start_date: date = Field(frozen=True) - period_end_date: date = Field(frozen=True) + period_start_date: datetime = Field(frozen=True) + period_end_date: datetime = Field(frozen=True) usage: float = Field(frozen=True) inclusion_override: bool = Field(frozen=True) @@ -172,7 +172,7 @@ class NormalizedBillingPeriodRecord(NormalizedBillingPeriodRecordBase): class TemperatureInput(BaseModel): - dates: list[date] + dates: list[datetime] temperatures: list[float] diff --git a/rules-engine/tests/test_rules_engine/test_engine.py b/rules-engine/tests/test_rules_engine/test_engine.py index 1f29fbbe..17ae1e0f 100644 --- a/rules-engine/tests/test_rules_engine/test_engine.py +++ b/rules-engine/tests/test_rules_engine/test_engine.py @@ -1,4 +1,4 @@ -from datetime import date +from datetime import date, datetime from typing import Any import pytest @@ -18,8 +18,8 @@ ) dummy_billing_period_record = NormalizedBillingPeriodRecordBase( - period_start_date=date(2024, 1, 1), - period_end_date=date(2024, 2, 1), + period_start_date=datetime(2024, 1, 1), + period_end_date=datetime(2024, 2, 1), usage=1.0, inclusion_override=False, ) @@ -183,7 +183,10 @@ def sample_temp_inputs() -> TemperatureInput: ], } - return TemperatureInput(**temperature_dict) + return TemperatureInput( + temperatures=temperature_dict["temperatures"], + dates=[datetime.fromisoformat(x) for x in temperature_dict["dates"]], + ) @pytest.fixture() @@ -233,8 +236,18 @@ def sample_normalized_billing_periods() -> list[NormalizedBillingPeriodRecordBas }, ] + # billing_periods = [ + # NormalizedBillingPeriodRecordBase(**x) for x in billing_periods_dict + # ] + billing_periods = [ - NormalizedBillingPeriodRecordBase(**x) for x in billing_periods_dict + NormalizedBillingPeriodRecordBase( + period_start_date=datetime.fromisoformat(x["period_start_date"]), + period_end_date=datetime.fromisoformat(x["period_end_date"]), + usage=x["usage"], + inclusion_override=x["inclusion_override"], + ) + for x in billing_periods_dict ] return billing_periods diff --git a/rules-engine/tests/test_rules_engine/test_parser.py b/rules-engine/tests/test_rules_engine/test_parser.py index 6cbdedfa..6b9ee2df 100644 --- a/rules-engine/tests/test_rules_engine/test_parser.py +++ b/rules-engine/tests/test_rules_engine/test_parser.py @@ -1,5 +1,5 @@ import pathlib -from datetime import date +from datetime import date, datetime import pytest @@ -33,8 +33,8 @@ def _validate_eversource(result): # from excel: 11/19/2021,12/17/2021,29,124,,1,4.28,3.82 second_row = result.records[1] - assert second_row.period_start_date == date(2021, 11, 19) - assert second_row.period_end_date == date(2021, 12, 17) + assert second_row.period_start_date == datetime(2021, 11, 19) + assert second_row.period_end_date == datetime(2021, 12, 17) assert isinstance(second_row.usage_therms, float) assert second_row.usage_therms == 124 assert second_row.inclusion_override == None @@ -50,8 +50,8 @@ def _validate_national_grid(result): # from excel: 11/6/2020,12/3/2020,28,36,,1,1.29,0.99 second_row = result.records[1] - assert second_row.period_start_date == date(2020, 11, 5) - assert second_row.period_end_date == date(2020, 12, 3) + assert second_row.period_start_date == datetime(2020, 11, 5) + assert second_row.period_end_date == datetime(2020, 12, 3) assert isinstance(second_row.usage_therms, float) assert second_row.usage_therms == 36 assert second_row.inclusion_override == None diff --git a/rules-engine/tests/test_rules_engine/test_pydantic_models.py b/rules-engine/tests/test_rules_engine/test_pydantic_models.py index 1ca88ce9..9e2d8351 100644 --- a/rules-engine/tests/test_rules_engine/test_pydantic_models.py +++ b/rules-engine/tests/test_rules_engine/test_pydantic_models.py @@ -1,4 +1,4 @@ -from datetime import date +from datetime import date, datetime import pytest @@ -10,14 +10,14 @@ _EXAMPLE_VALID_RECORDS = NaturalGasBillingInput( records=[ NaturalGasBillingRecordInput( - period_start_date=date(2020, 1, 1), - period_end_date=date(2020, 1, 31), + period_start_date=datetime(2020, 1, 1), + period_end_date=datetime(2020, 1, 31), usage_therms=10, inclusion_override=None, ), NaturalGasBillingRecordInput( - period_start_date=date(2020, 2, 1), - period_end_date=date(2020, 2, 28), + period_start_date=datetime(2020, 2, 1), + period_end_date=datetime(2020, 2, 28), usage_therms=10, inclusion_override=None, ), @@ -30,14 +30,14 @@ def test_natural_gas_billing_input_overall_start_date(): - expected_overall_start_date = date(2020, 1, 1) + expected_overall_start_date = datetime(2020, 1, 1) actual_overall_start_date = _EXAMPLE_VALID_RECORDS.overall_start_date assert expected_overall_start_date == actual_overall_start_date def test_natural_gas_billing_input_overall_end_date(): - expected_overall_end_date = date(2020, 2, 28) + expected_overall_end_date = datetime(2020, 2, 28) actual_overall_end_date = _EXAMPLE_VALID_RECORDS.overall_end_date assert expected_overall_end_date == actual_overall_end_date diff --git a/rules-engine/tests/test_rules_engine/test_utils.py b/rules-engine/tests/test_rules_engine/test_utils.py index be42d2c0..57c622f4 100644 --- a/rules-engine/tests/test_rules_engine/test_utils.py +++ b/rules-engine/tests/test_rules_engine/test_utils.py @@ -160,8 +160,8 @@ def load_fuel_billing_example_input( raise ValueError("Unsupported fuel type.") -def _parse_date(value: str) -> date: - return datetime.strptime(value.split(maxsplit=1)[0], "%Y-%m-%d").date() +def _parse_date(value: str) -> datetime: + return datetime.strptime(value.split(maxsplit=1)[0], "%Y-%m-%d") def load_temperature_data(path: Path, weather_station: str) -> TemperatureInput: @@ -172,7 +172,7 @@ def load_temperature_data(path: Path, weather_station: str) -> TemperatureInput: row: Any for row in reader: - dates.append(datetime.strptime(row["Date"], "%Y-%m-%d").date()) + dates.append(datetime.strptime(row["Date"], "%Y-%m-%d")) temperatures.append(row[weather_station]) return TemperatureInput(dates=dates, temperatures=temperatures)