Skip to content

Commit

Permalink
Add set/get value at indices methods
Browse files Browse the repository at this point in the history
  • Loading branch information
BSchilperoort committed Nov 29, 2023
1 parent 8576160 commit 87a2b0d
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 20 deletions.
30 changes: 22 additions & 8 deletions PyStemmusScope/bmi.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,24 +90,33 @@ def get_variable(state: h5py.File, varname: str) -> np.ndarray:
raise ValueError(msg)


def set_variable(state: h5py.File, varname: str, value: np.ndarray) -> dict:
def set_variable(
state: h5py.File, varname: str, value: np.ndarray, inds: Union[np.ndarray, None] = None
) -> dict:
"""Set a variable in the model state.
Args:
state: Model state.
varname: Variable name.
value: New value for the variable.
inds: (Optional) at which indices you want to set the variable values.
Returns:
Updated model state.
"""
if varname == "respiration":
state["fluxes"]["Resp"][0] = value
elif varname == "soil_temperature":
state["TT"][0, :-1] = value
if inds is not None:
vals = get_variable(state, varname)
vals[inds] = value
else:
vals = value

if varname == "soil_temperature":
state["TT"][0, :-1] = vals
else:
if varname in MODEL_OUTPUT_VARNAMES and varname not in MODEL_INPUT_VARNAMES:
msg = "This variable is a model output variable only. You cannot set it."
if varname in MODEL_VARNAMES:
msg = "Varname is missing in get_variable! Contact devs."
msg = "Varname is missing in set_variable! Contact devs."
else:
msg = "Uknown variable name"
raise ValueError(msg)
Expand Down Expand Up @@ -333,7 +342,10 @@ def get_value_at_indices(
Returns:
Value of the model variable at the given location.
"""
raise NotImplementedError()
if self.state is None:
raise ValueError(NO_STATE_MSG)
dest[:] = get_variable(self.state, name)[inds]
return dest

def set_value(self, name: str, src: np.ndarray) -> None:
"""Specify a new value for a model variable.
Expand All @@ -360,7 +372,9 @@ def set_value_at_indices(
src : array_like
The new value for the specified variable.
"""
raise NotImplementedError()
if self.state is None:
raise ValueError(NO_STATE_MSG)
self.state = set_variable(self.state, name, src, inds)

### GRID INFO ###
def get_grid_rank(self, grid: int) -> int:
Expand Down
Loading

0 comments on commit 87a2b0d

Please sign in to comment.