diff --git a/PyStemmusScope/bmi/implementation.py b/PyStemmusScope/bmi/implementation.py index cff64e29..c64ca575 100644 --- a/PyStemmusScope/bmi/implementation.py +++ b/PyStemmusScope/bmi/implementation.py @@ -1,6 +1,5 @@ """BMI wrapper for the STEMMUS_SCOPE model.""" import os -from dataclasses import dataclass from pathlib import Path from typing import Literal from typing import Protocol @@ -9,48 +8,11 @@ import numpy as np from bmipy.bmi import Bmi from PyStemmusScope.bmi.utils import InapplicableBmiMethods +from PyStemmusScope.bmi.utils import nested_set +from PyStemmusScope.bmi.variable_reference import VARIABLES from PyStemmusScope.config_io import read_config -@dataclass -class BmiVariable: - """Holds all info to inform the BMI implementation.""" - - name: str - dtype: str - input: bool - output: bool - units: str - grid: int - - -VARIABLES: tuple[BmiVariable, ...] = ( - # name dtype input ouput units grid - # atmospheric vars: - BmiVariable("respiration", "float64", False, True, "cm s-1", 0), - BmiVariable("evaporation_total", "float64", False, True, "cm s-1", 0), - - # soil vars: - BmiVariable("soil_temperature", "float64", True, True, "degC", 1), - BmiVariable("soil_moisture", "float64", True, True, "m3 m-3", 1), - BmiVariable("soil_root_water_uptake", "float64", False, True, "cm s-1", 0), - - # surface runoff - BmiVariable("surface_runoff_total", "float64", False, True, "cm s-1", 0), - BmiVariable("surface_runoff_hortonian", "float64", False, True, "cm s-1", 0), - BmiVariable("surface_runoff_dunnian", "float64", False, True, "cm s-1", 0), - - # groundwater vars (STEMMUS_SCOPE) - BmiVariable("groundwater_root_water_uptake", "float64", False, True, "cm s-1", 0), - BmiVariable("groundwater_recharge", "float64", False, True, "cm s-1", 0), - - # groundwater (coupling) vars - BmiVariable("groundwater_coupling_enabled", "bool", True, False, "-", 0), - BmiVariable("groundwater_head_bottom_layer", "float64", True, False, "cm", 0), - BmiVariable("groundwater_temperature", "float64", True, False, "degC", 0), - BmiVariable("groundwater_elevation_top_aquifer", "float64", True, False, "cm", 0), -) - MODEL_INPUT_VARNAMES: tuple[str, ...] = tuple( var.name for var in VARIABLES if var.input ) @@ -67,6 +29,8 @@ class BmiVariable: VARNAME_GRID: dict[str, int] = {var.name: var.grid for var in VARIABLES} +VARNAME_LOC: dict[str, list[str]] = {var.name: var.loc for var in VARIABLES} + NO_STATE_MSG = ( "The model state is not available. Please run `.update()` before requesting " "\nthis model info. If you did run .update() before, something seems to have " @@ -92,57 +56,34 @@ def load_state(config: dict) -> h5py.File: return h5py.File(matfile, mode="a") -def get_variable(state: h5py.File, varname: str) -> np.ndarray: # noqa: PLR0911 PLR0912 C901 +def get_variable( + state: h5py.File, varname: str +) -> np.ndarray: # noqa: PLR0911 PLR0912 C901 """Get a variable from the model state. Args: state: STEMMUS_SCOPE model state varname: Variable name """ - # atmospheric vars - if varname == "respiration": - return state["fluxes"]["Resp"][0] - elif varname == "evaporation_total": - return state["EVAP"][0] + if varname not in MODEL_VARNAMES: + msg = "Unknown variable name" + raise ValueError(msg) - # soil vars + # deviating implemetation: elif varname == "soil_temperature": return state["TT"][0, :-1] - elif varname == "soil_moisture": - return state["SoilVariables"]["Theta_U"][0] - elif varname == "soil_root_water_uptake": - return state["RWUs"][0] - - # surface runoff - elif varname == "surface_runoff_total": - return state["RS"][0] - elif varname == "surface_runoff_dunnian": - return state["ForcingData"]["R_Dunn"][0] - elif varname == "surface_runoff_hortonian": - return state["ForcingData"]["R_Hort"][0] - - # groundwater vars - elif varname == "groundwater_root_water_uptake": - return state["RWUg"][0] - elif varname == "groundwater_recharge": - return state["gwfluxes"]["recharge"][0] - - # groundwater coupling variables: - elif varname == "groundwater_coupling_enabled": - return state["GroundwaterSettings"]["GroundwaterCoupling"][0].astype(bool) - elif varname == "groundwater_head_bottom_layer": - return state["GroundwaterSettings"]["headBotmLayer"][0] - elif varname == "groundwater_temperature": - return state["GroundwaterSettings"]["tempBotm"][0] - elif varname == "groundwater_elevation_top_aquifer": - return state["GroundwaterSettings"]["toplevel"][0] - else: - if varname in MODEL_VARNAMES: - msg = "Varname is missing in get_variable! Contact devs." - else: - msg = "Unknown variable name" - raise ValueError(msg) + # default implementation: + _s = state + for _loc in VARNAME_LOC[varname]: + _s = _s.get(_loc) + + if VARNAME_GRID[varname] == 0: + return _s[0].astype(VARNAME_DTYPE[varname]) + + # something's gone wrong: + msg = "Varname is missing in get_variable! Contact devs." + raise ValueError(msg) def set_variable( @@ -168,29 +109,21 @@ def set_variable( else: vals = value + if varname in MODEL_OUTPUT_VARNAMES and varname not in MODEL_INPUT_VARNAMES: + msg = "This variable is a model output variable only. You cannot set it." + raise ValueError(msg) + elif varname not in MODEL_INPUT_VARNAMES: + msg = "Uknown variable name" + raise ValueError(msg) + + # deviating implementations: if varname == "soil_temperature": state["TT"][0, :-1] = vals - elif varname == "soil_moisture": - state["SoilVariables"]["Theta_U"][0] = vals - - # groundwater coupling variables: elif varname == "groundwater_coupling_enabled": state["GroundwaterSettings"]["GroundwaterCoupling"][0] = vals.astype("float") - elif varname == "groundwater_head_bottom_layer": - state["GroundwaterSettings"]["headBotmLayer"][0] = vals - elif varname == "groundwater_temperature": - state["GroundwaterSettings"]["tempBotm"][0] = vals - elif varname == "groundwater_elevation_top_aquifer": - state["GroundwaterSettings"]["toplevel"][0] = vals - + # default: else: - if varname in MODEL_OUTPUT_VARNAMES and varname not in MODEL_INPUT_VARNAMES: - msg = "This variable is a model output variable only. You cannot set it." - elif varname in MODEL_VARNAMES: - msg = "Varname is missing in set_variable! Contact devs." - else: - msg = "Uknown variable name" - raise ValueError(msg) + nested_set(state, VARNAME_LOC[varname] + [0], vals) return state diff --git a/PyStemmusScope/bmi/utils.py b/PyStemmusScope/bmi/utils.py index 6586708a..7fde7867 100644 --- a/PyStemmusScope/bmi/utils.py +++ b/PyStemmusScope/bmi/utils.py @@ -1,4 +1,6 @@ """Utilities for the STEMMUS_SCOPE Basic Model Interface.""" +from typing import Any +from typing import Union import numpy as np @@ -64,3 +66,18 @@ def get_grid_nodes_per_face( ) -> np.ndarray: """Get the number of nodes for each face.""" raise NotImplementedError(INAPPLICABLE_GRID_METHOD_MSG) + + +def nested_set(dic: dict, keys: Union[list, tuple], value: Any) -> None: + """Set a value in a nested dictionary programatically. + + E.g.: dict[keys[0]][keys[1]] = value + + Args: + dic: Dictionary to be modified. + keys: Iterable of keys that are used to find the right value. + value: The new value. + """ + for key in keys[:-1]: + dic = dic.setdefault(key, {}) + dic[keys[-1]] = value diff --git a/PyStemmusScope/bmi/variable_reference.py b/PyStemmusScope/bmi/variable_reference.py new file mode 100644 index 00000000..123050c2 --- /dev/null +++ b/PyStemmusScope/bmi/variable_reference.py @@ -0,0 +1,150 @@ +"""Variable reference to inform the BMI implementation.""" +from dataclasses import dataclass + + +@dataclass +class BmiVariable: + """Holds all info to inform the BMI implementation.""" + + name: str + dtype: str + input: bool + output: bool + units: str + grid: int + loc: list[str] + + +VARIABLES: tuple[BmiVariable, ...] = ( + # atmospheric vars: + BmiVariable( + name="respiration", + dtype="float64", + input=False, + output=True, + units="cm s-1", + grid=0, + loc=["fluxes", "Resp"], + ), + BmiVariable( + name="evaporation_total", + dtype="float64", + input=False, + output=True, + units="cm s-1", + grid=0, + loc=["EVAP"], + ), + # soil vars: + BmiVariable( + name="soil_temperature", + dtype="float64", + input=True, + output=True, + units="degC", + grid=1, + loc=["TT"], + ), + BmiVariable( + name="soil_moisture", + dtype="float64", + input=True, + output=True, + units="m3 m-3", + grid=1, + loc=["SoilVariables", "Theta_U"], + ), + BmiVariable( + name="soil_root_water_uptake", + dtype="float64", + input=False, + output=True, + units="cm s-1", + grid=0, + loc=["RWUs"], + ), + # surface runoff + BmiVariable( + name="surface_runoff_total", + dtype="float64", + input=False, + output=True, + units="cm s-1", + grid=0, + loc=["RS"], + ), + BmiVariable( + name="surface_runoff_hortonian", + dtype="float64", + input=False, + output=True, + units="cm s-1", + grid=0, + loc=["ForcingData", "R_Dunn"], + ), + BmiVariable( + name="surface_runoff_dunnian", + dtype="float64", + input=False, + output=True, + units="cm s-1", + grid=0, + loc=["ForcingData", "R_Hort"], + ), + # groundwater vars (STEMMUS_SCOPE) + BmiVariable( + name="groundwater_root_water_uptake", + dtype="float64", + input=False, + output=True, + units="cm s-1", + grid=0, + loc=["RWUg"], + ), + BmiVariable( + name="groundwater_recharge", + dtype="float64", + input=False, + output=True, + units="cm s-1", + grid=0, + loc=["gwfluxes", "recharge"], + ), + # groundwater (coupling) vars + BmiVariable( + name="groundwater_coupling_enabled", + dtype="bool", + input=True, + output=False, + units="-", + grid=0, + loc=["GroundwaterSettings", "GroundwaterCoupling"], + ), + BmiVariable( + name="groundwater_head_bottom_layer", + dtype="float64", + input=True, + output=False, + units="cm", + grid=0, + loc=["GroundwaterSettings", "headBotmLayer"], + ), + BmiVariable( + name="groundwater_temperature", + dtype="float64", + input=True, + output=False, + units="degC", + grid=0, + loc=["GroundwaterSettings", "tempBotm"], + ), + BmiVariable( + name="groundwater_elevation_top_aquifer", + dtype="float64", + input=True, + output=False, + units="cm", + grid=0, + loc=["GroundwaterSettings", "toplevel"], + ), +)