Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change Pydantic from_orm to model_validate (#284) #285

Merged
merged 9 commits into from
Jul 11, 2024
6 changes: 4 additions & 2 deletions nowcasting_datamodel/models/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ def convert_list_forecast_value_seven_days_sql_to_list_forecast(
for forecast_value_sql in forecast_values_sql:
gsp_id = forecast_value_sql.forecast.location.gsp_id

forecast_value: ForecastValue = ForecastValue.from_orm(forecast_value_sql)
forecast_value: ForecastValue = ForecastValue.model_validate(
forecast_value_sql, from_attributes=True
)

if gsp_id in forecasts_by_gsp.keys():
forecasts_by_gsp[gsp_id].forecast_values.append(forecast_value)
Expand All @@ -50,7 +52,7 @@ def convert_list_forecast_value_seven_days_sql_to_list_forecast(
forecast_values=[forecast_value_sql],
historic=forecast_value_sql.forecast.historic,
)
forecast = Forecast.from_orm(forecast)
forecast = Forecast.model_validate(forecast, from_attributes=True)
forecasts_by_gsp[gsp_id] = forecast

forecasts = [forecast for forecast in forecasts_by_gsp.values()]
Expand Down
77 changes: 66 additions & 11 deletions nowcasting_datamodel/models/forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,28 @@ def to_orm(self) -> ForecastValueSQL:
@classmethod
def from_orm(cls, obj: ForecastValueSQL):
"""Make sure _adjust_mw is transfered also"""
m = super().from_orm(obj=obj)
m = super().model_validate(obj=obj, from_attributes=True)

# this is because from orm doesnt copy over '_' variables.
# But we don't want to expose this in the API
default_value = 0.0
if hasattr(obj, "adjust_mw"):
adjust_mw = obj.adjust_mw
if not adjust_mw or np.isnan(adjust_mw):
adjust_mw = default_value
m._adjust_mw = adjust_mw
else:
m._adjust_mw = default_value

if hasattr(obj, "properties"):
m._properties = obj.properties

return m

@classmethod
def model_validate(cls, obj: ForecastValueSQL, from_attributes: bool | None = None):
"""Make sure _adjust_mw is transfered also"""
m = super().model_validate(obj=obj, from_attributes=from_attributes)

# this is because from orm doesnt copy over '_' variables.
# But we don't want to expose this in the API
Expand Down Expand Up @@ -497,20 +518,36 @@ def to_orm(self) -> ForecastSQL:

@classmethod
def from_orm(cls, forecast_sql: ForecastSQL):
"""Method to make Forecast object from ForecastSQL,
"""Method to make Forecast object from ForecastSQL"""
# do normal transform
return Forecast(
forecast_creation_time=forecast_sql.forecast_creation_time,
location=Location.model_validate(forecast_sql.location, from_attributes=True),
input_data_last_updated=InputDataLastUpdated.model_validate(
forecast_sql.input_data_last_updated, from_attributes=True
),
forecast_values=[
ForecastValue.model_validate(forecast_value, from_attributes=True)
for forecast_value in forecast_sql.forecast_values
],
historic=forecast_sql.historic,
model=MLModel.model_validate(forecast_sql.model),
)

but move 'forecast_values_latest' to 'forecast_values'
This is useful as we want the API to still present a Forecast object.
"""
@classmethod
def model_validate(cls, forecast_sql: ForecastSQL, from_attributes: bool | None = None):
"""Method to make Forecast object from ForecastSQL"""
# do normal transform
return Forecast(
forecast_creation_time=forecast_sql.forecast_creation_time,
location=Location.from_orm(forecast_sql.location),
input_data_last_updated=InputDataLastUpdated.from_orm(
forecast_sql.input_data_last_updated
location=Location.model_validate(
forecast_sql.location, from_attributes=from_attributes
),
input_data_last_updated=InputDataLastUpdated.model_validate(
forecast_sql.input_data_last_updated, from_attributes=from_attributes
),
forecast_values=[
ForecastValue.from_orm(forecast_value)
ForecastValue.model_validate(forecast_value, from_attributes=from_attributes)
for forecast_value in forecast_sql.forecast_values
],
historic=forecast_sql.historic,
Expand All @@ -525,11 +562,29 @@ def from_orm_latest(cls, forecast_sql: ForecastSQL):
This is useful as we want the API to still present a Forecast object.
"""
# do normal transform
forecast = cls.from_orm(forecast_sql)
forecast = cls.model_validate(forecast_sql, from_attributes=True)

# move 'forecast_values_latest' to 'forecast_values'
forecast.forecast_values = [
ForecastValue.model_validate(forecast_value, from_attributes=True)
for forecast_value in forecast_sql.forecast_values_latest
]

return forecast

@classmethod
def model_validate_latest(cls, forecast_sql: ForecastSQL, from_attributes: bool | None = None):
"""Method to make Forecast object from ForecastSQL,

but move 'forecast_values_latest' to 'forecast_values'
This is useful as we want the API to still present a Forecast object.
"""
# do normal transform
forecast = cls.model_validate(forecast_sql, from_attributes=from_attributes)

# move 'forecast_values_latest' to 'forecast_values'
forecast.forecast_values = [
ForecastValue.from_orm(forecast_value)
ForecastValue.model_validate(forecast_value, from_attributes=from_attributes)
for forecast_value in forecast_sql.forecast_values_latest
]

Expand Down
7 changes: 5 additions & 2 deletions nowcasting_datamodel/national.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,10 @@ def make_national_forecast(
gsp_id = forecast.location.gsp_id

one_gsp = pd.DataFrame(
[ForecastValue.from_orm(value).dict() for value in forecast.forecast_values]
[
ForecastValue.model_validate(value, from_attributes=True).model_dump()
for value in forecast.forecast_values
]
)
adjusts_mw = [f.adjust_mw for f in forecast.forecast_values]
one_gsp["gps_id"] = gsp_id
Expand Down Expand Up @@ -98,6 +101,6 @@ def make_national_forecast(
)

# validate
_ = Forecast.from_orm(national_forecast)
_ = Forecast.model_validate(national_forecast, from_attributes=True)

return national_forecast
4 changes: 3 additions & 1 deletion nowcasting_datamodel/save/adjust.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,9 @@ def reduce_metric_values_to_correct_forecast_horizon(
f"Reducing metric values to correct forecast horizon {datetime_now=} {hours_ahead=}"
)

latest_me_df = pd.DataFrame([MetricValue.from_orm(m).dict() for m in latest_me])
latest_me_df = pd.DataFrame(
[MetricValue.model_validate(m, from_attributes=True).model_dump() for m in latest_me]
)
if len(latest_me_df) == 0:
# no latest ME values, so just making an empty dataframe
latest_me_df = pd.DataFrame(columns=["forecast_horizon_minutes", "time_of_day", "value"])
Expand Down
36 changes: 21 additions & 15 deletions tests/models/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,51 +19,51 @@
def test_adjust_forecasts(forecasts):
forecasts[0].forecast_values[0].expected_power_generation_megawatts = 10.0
forecasts[0].forecast_values[0].adjust_mw = 1.23
forecasts = [Forecast.from_orm(f) for f in forecasts]
forecasts = [Forecast.model_validate(f, from_attributes=True) for f in forecasts]

assert forecasts[0].forecast_values[0]._adjust_mw == 1.23

forecasts[0].adjust(limit=1.22)
assert forecasts[0].forecast_values[0].expected_power_generation_megawatts == 8.78
assert "expected_power_generation_megawatts" in forecasts[0].forecast_values[0].dict()
assert "_adjust_mw" not in forecasts[0].forecast_values[0].dict()
assert "expected_power_generation_megawatts" in forecasts[0].forecast_values[0].model_dump()
assert "_adjust_mw" not in forecasts[0].forecast_values[0].model_dump()


def test_adjust_forecast_neg(forecasts):
forecasts[0].forecast_values[0].expected_power_generation_megawatts = 10.0
forecasts[0].forecast_values[0].adjust_mw = -1.23
forecasts = [Forecast.from_orm(f) for f in forecasts]
forecasts = [Forecast.model_validate(f, from_attributes=True) for f in forecasts]

forecasts[0].adjust(limit=1.22)
assert forecasts[0].forecast_values[0].expected_power_generation_megawatts == 11.22
assert "expected_power_generation_megawatts" in forecasts[0].forecast_values[0].dict()
assert "_adjust_mw" not in forecasts[0].forecast_values[0].dict()
assert "expected_power_generation_megawatts" in forecasts[0].forecast_values[0].model_dump()
assert "_adjust_mw" not in forecasts[0].forecast_values[0].model_dump()


def test_adjust_forecast_below_zero(forecasts):
v = forecasts[0].forecast_values[0].expected_power_generation_megawatts
forecasts[0].forecast_values[0].adjust_mw = v + 100
forecasts = [Forecast.from_orm(f) for f in forecasts]
forecasts = [Forecast.model_validate(f, from_attributes=True) for f in forecasts]
forecasts[0].forecast_values[0]._properties = {"10": v - 100}

forecasts[0].adjust(limit=v * 3)

assert forecasts[0].forecast_values[0].expected_power_generation_megawatts == 0.0
assert forecasts[0].forecast_values[0]._properties["10"] == 0.0
assert "expected_power_generation_megawatts" in forecasts[0].forecast_values[0].dict()
assert "_adjust_mw" not in forecasts[0].forecast_values[0].dict()
assert "expected_power_generation_megawatts" in forecasts[0].forecast_values[0].model_dump()
assert "_adjust_mw" not in forecasts[0].forecast_values[0].model_dump()


def test_adjust_many_forecasts(forecasts):
forecasts[0].forecast_values[0].adjust_mw = 1.23
forecasts = [Forecast.from_orm(f) for f in forecasts]
forecasts = [Forecast.model_validate(f, from_attributes=True) for f in forecasts]
m = ManyForecasts(forecasts=forecasts)
m.adjust()


def test_normalize_forecasts(forecasts):
v = forecasts[0].forecast_values[0].expected_power_generation_megawatts
forecasts_all = [Forecast.from_orm(f) for f in forecasts]
forecasts_all = [Forecast.model_validate(f, from_attributes=True) for f in forecasts]

forecasts_all[0].normalize()
assert (
Expand All @@ -76,7 +76,7 @@ def test_normalize_forecasts(forecasts):


def test_normalize_forecasts_no_installed_capacity(forecasts):
forecast = Forecast.from_orm(forecasts[0])
forecast = Forecast.model_validate(forecasts[0], from_attributes=True)
forecast.location.installed_capacity_mw = None

v = forecast.forecast_values[0].expected_power_generation_megawatts
Expand All @@ -98,7 +98,7 @@ def test_status_validation():
def test_status_orm():
status = Status(message="testing", status="warning")
ormed_status = status.to_orm()
status_orm = Status.from_orm(ormed_status)
status_orm = Status.model_validate(ormed_status, from_attributes=True)

assert status_orm.message == status.message
assert status_orm.status == status.status
Expand All @@ -123,6 +123,12 @@ def test_forecast_latest_to_pydantic(forecast_sql):
forecast = Forecast.from_orm_latest(forecast_sql=forecast_sql)
assert forecast.forecast_values[0] == ForecastValue.from_orm(f1)

forecast = Forecast.model_validate(forecast_sql, from_attributes=True)
assert forecast.forecast_values[0] != ForecastValue.model_validate(f1, from_attributes=True)

forecast = Forecast.model_validate_latest(forecast_sql=forecast_sql, from_attributes=True)
assert forecast.forecast_values[0] == ForecastValue.model_validate(f1, from_attributes=True)


def test_forecast_value_from_orm(forecast_sql):
forecast_sql = forecast_sql[0]
Expand All @@ -131,7 +137,7 @@ def test_forecast_value_from_orm(forecast_sql):
target_time=datetime(2023, 1, 1, 0, 30), expected_power_generation_megawatts=1
)

actual = ForecastValue.from_orm(f)
actual = ForecastValue.model_validate(f, from_attributes=True)
expected = ForecastValue(
target_time=datetime(2023, 1, 1, 0, 30, tzinfo=timezone.utc),
expected_power_generation_megawatts=1.0,
Expand All @@ -148,7 +154,7 @@ def test_forecast_value_from_orm_from_adjust_mw_nan(forecast_sql, null_value):
)
f.adjust_mw = null_value

actual = ForecastValue.from_orm(f)
actual = ForecastValue.model_validate(f, from_attributes=True)
expected = ForecastValue(
target_time=datetime(2023, 1, 1, 0, 30, tzinfo=timezone.utc),
expected_power_generation_megawatts=1.0,
Expand Down
14 changes: 9 additions & 5 deletions tests/read/test_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def test_get_forecast(db_session, forecasts):
assert forecast_read.forecast_values[0] == forecasts[-1].forecast_values[0]

_ = Forecast.from_orm(forecast_read)
_ = Forecast.model_validate(forecast_read, from_attributes=True)


def test_read_gsp_id(db_session, forecasts):
Expand Down Expand Up @@ -125,6 +126,7 @@ def test_get_forecast_values_gsp_id(db_session, forecasts):
)

_ = ForecastValue.from_orm(forecast_values_read[0])
_ = ForecastValue.model_validate(forecast_values_read[0], from_attributes=True)

assert len(forecast_values_read) == N_FAKE_FORECASTS

Expand Down Expand Up @@ -152,7 +154,7 @@ def test_get_forecast_values_latest_gsp_id(db_session):
forecast_values_read = get_forecast_values_latest(
session=db_session, gsp_id=f1[0].location.gsp_id
)
_ = ForecastValue.from_orm(forecast_values_read[0])
_ = ForecastValue.model_validate(forecast_values_read[0], from_attributes=True)

assert len(forecast_values_read) == 2
assert forecast_values_read[0].gsp_id == f1[0].location.gsp_id
Expand Down Expand Up @@ -220,7 +222,7 @@ def test_get_forecast_values_gsp_id_latest(db_session):
start_datetime=datetime(2024, 1, 2, tzinfo=timezone.utc),
)

_ = ForecastValue.from_orm(forecast_values_read[0])
_ = ForecastValue.model_validate(forecast_values_read[0], from_attributes=True)

assert len(forecast_values_read) == 16 # only getting forecast ahead

Expand All @@ -244,7 +246,7 @@ def test_get_forecast_values_start_and_creation(db_session):
created_utc_limit=datetime(2024, 1, 1, tzinfo=timezone.utc),
)

_ = ForecastValue.from_orm(forecast_values_read[0])
_ = ForecastValue.model_validate(forecast_values_read[0], from_attributes=True)

assert len(forecast_values_read) == 76 # only getting forecast ahead

Expand Down Expand Up @@ -402,15 +404,17 @@ def test_get_national_latest_forecast(db_session):


def test_get_pv_system(db_session_pv):
pv_system = PVSystem.from_orm(make_fake_pv_system())
pv_system = PVSystem.model_validate(make_fake_pv_system(), from_attributes=True)
save_pv_system(session=db_session_pv, pv_system=pv_system)

pv_system_get = get_pv_system(
session=db_session_pv, provider=pv_system.provider, pv_system_id=pv_system.pv_system_id
)
# this get defaulted to True when adding to the database
pv_system.correct_data = True
assert PVSystem.from_orm(pv_system) == PVSystem.from_orm(pv_system_get)
assert PVSystem.model_validate(pv_system, from_attributes=True) == PVSystem.model_validate(
pv_system_get, from_attributes=True
)


def test_get_latest_input_data_last_updated_multiple_entries(db_session):
Expand Down
5 changes: 4 additions & 1 deletion tests/read/test_read_gsp.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,10 @@ def test_get_gsp_yield_by_location(db_session):
assert locations_with_gsp_yields[0].gsp_id == 1
assert len(locations_with_gsp_yields[0].gsp_yields) == 2

locations = [LocationWithGSPYields.from_orm(location) for location in locations_with_gsp_yields]
locations = [
LocationWithGSPYields.model_validate(location, from_attributes=True)
for location in locations_with_gsp_yields
]
assert len(locations[0].gsp_yields) == 2
assert locations_with_gsp_yields[0].gsp_yields[0].datetime_utc.tzinfo == timezone.utc

Expand Down
2 changes: 1 addition & 1 deletion tests/read/test_read_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,4 @@ def test_get_model(db_session):
assert model_read_1.name == model_read_2.name
assert model_read_1.version == model_read_2.version

_ = MLModel.from_orm(model_read_2)
_ = MLModel.model_validate(model_read_2, from_attributes=True)
12 changes: 6 additions & 6 deletions tests/test_fake.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,29 +36,29 @@ def test_make_fake_intensity():

def test_make_fake_location():
location_sql: LocationSQL = make_fake_location(1)
location = Location.from_orm(location_sql)
location = Location.model_validate(location_sql, from_attributes=True)
_ = Location.to_orm(location)


def test_make_fake_input_data_last_updated():
input_sql: InputDataLastUpdatedSQL = make_fake_input_data_last_updated()
input = InputDataLastUpdated.from_orm(input_sql)
input = InputDataLastUpdated.model_validate(input_sql, from_attributes=True)
_ = InputDataLastUpdated.to_orm(input)


def test_make_fake_forecast_value():
target = datetime(2023, 1, 1, tzinfo=timezone.utc)

forecast_value_sql: ForecastValueSQL = make_fake_forecast_value(target_time=target)
forecast_value = ForecastValue.from_orm(forecast_value_sql)
forecast_value = ForecastValue.model_validate(forecast_value_sql, from_attributes=True)
_ = ForecastValue.to_orm(forecast_value)


def test_make_fake_forecast(db_session):
forecast_sql: ForecastSQL = make_fake_forecast(gsp_id=1, session=db_session)
forecast = Forecast.from_orm(forecast_sql)
forecast = Forecast.model_validate(forecast_sql, from_attributes=True)
forecast_sql = Forecast.to_orm(forecast)
_ = Forecast.from_orm(forecast_sql)
_ = Forecast.model_validate(forecast_sql, from_attributes=True)

from sqlalchemy import text

Expand All @@ -78,7 +78,7 @@ def test_make_fake_forecasts(db_session):

def test_make_national_fake_forecast(db_session):
forecast_sql: ForecastSQL = make_fake_national_forecast(session=db_session)
forecast = Forecast.from_orm(forecast_sql)
forecast = Forecast.model_validate(forecast_sql, from_attributes=True)
_ = Forecast.to_orm(forecast)


Expand Down
Loading
Loading