Skip to content

Commit

Permalink
fix(get_package and model_time): #2117, #2118 (#2123)
Browse files Browse the repository at this point in the history
* fix(get_package and model_time): #2117, #2118

get_package now allows you to get package only by name or type, instead of always searching for a package both by name and type

model_time displays the correct steady state array and no longer gets confused if packages are named similar to the package type it is searching for

* fix(resolve merge conflict)

---------

Co-authored-by: scottrp <[email protected]>
  • Loading branch information
spaulins-usgs and scottrp authored Mar 20, 2024
1 parent 00b3d1c commit 11f573b
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 20 deletions.
11 changes: 11 additions & 0 deletions autotest/regression/test_mf6.py
Original file line number Diff line number Diff line change
Expand Up @@ -1946,6 +1946,17 @@ def test005_create_tests_advgw_tidal(function_tmpdir, example_data_path):
assert value[0][0] == "ghb- 2-6-10"
assert found_flows and found_obs

# check model.time steady state and transient
sto_package = model.get_package("sto")
sto_package.steady_state.set_data({0: True, 1: False, 2: False, 3: False})
sto_package.transient.set_data({0: False, 1: True, 2: True, 3: True})
flopy.mf6.ModflowGwfdrn(model, pname="storm")
ss = model.modeltime.steady_state
assert ss[0]
assert not ss[1]
assert not ss[2]
assert not ss[3]

# clean up
sim.delete_output_files()

Expand Down
31 changes: 18 additions & 13 deletions flopy/mf6/mfbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,7 +624,7 @@ def _rename_package(self, package, new_name):
)
main_dict[new_key] = main_dict.pop(key)

def get_package(self, name=None):
def get_package(self, name=None, type_only=False, name_only=False):
"""
Finds a package by package name, package key, package type, or partial
package name. returns either a single package, a list of packages,
Expand All @@ -633,7 +633,11 @@ def get_package(self, name=None):
Parameters
----------
name : str
Name of the package, 'RIV', 'LPF', etc.
Name or type of the package, 'my-riv-1, 'RIV', 'LPF', etc.
type_only : bool
Search for package by type only
name_only : bool
Search for package by name only
Returns
-------
Expand All @@ -644,11 +648,11 @@ def get_package(self, name=None):
return self._packagelist[:]

# search for full package name
if name.lower() in self.package_name_dict:
if name.lower() in self.package_name_dict and not type_only:
return self.package_name_dict[name.lower()]

# search for package type
if name.lower() in self.package_type_dict:
if name.lower() in self.package_type_dict and not name_only:
if len(self.package_type_dict[name.lower()]) == 0:
return None
elif len(self.package_type_dict[name.lower()]) == 1:
Expand All @@ -657,18 +661,19 @@ def get_package(self, name=None):
return self.package_type_dict[name.lower()]

# search for file name
if name.lower() in self.package_filename_dict:
if name.lower() in self.package_filename_dict and not type_only:
return self.package_filename_dict[name.lower()]

# search for partial and case-insensitive package name
for pp in self._packagelist:
if pp.package_name is not None:
# get first package of the type requested
package_name = pp.package_name.lower()
if len(package_name) > len(name):
package_name = package_name[0 : len(name)]
if package_name.lower() == name.lower():
return pp
if not type_only:
for pp in self._packagelist:
if pp.package_name is not None:
# get first package of the type requested
package_name = pp.package_name.lower()
if len(package_name) > len(name):
package_name = package_name[0 : len(name)]
if package_name.lower() == name.lower():
return pp

return None

Expand Down
18 changes: 12 additions & 6 deletions flopy/mf6/mfmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,25 +264,31 @@ def modeltime(self):
simulation.
"""
tdis = self.simulation.get_package("tdis")
tdis = self.simulation.get_package("tdis", type_only=True)
period_data = tdis.perioddata.get_data()

# build steady state data
sto = self.get_package("sto")
sto = self.get_package("sto", type_only=True)
if sto is None:
steady = np.full((len(period_data["perlen"])), True, dtype=bool)
else:
steady = np.full((len(period_data["perlen"])), False, dtype=bool)
ss_periods = sto.steady_state.get_active_key_dict()
for period, val in ss_periods.items():
if val:
ss_periods[period] = sto.steady_state.get_data(period)
tr_periods = sto.transient.get_active_key_dict()
for period, val in tr_periods.items():
if val:
tr_periods[period] = sto.transient.get_data(period)
if ss_periods:
last_ss_value = False
# loop through steady state array
for index, value in enumerate(steady):
# resolve if current index is steady state or transient
if index in ss_periods:
if index in ss_periods and ss_periods[index]:
last_ss_value = True
elif index in tr_periods:
elif index in tr_periods and tr_periods[index]:
last_ss_value = False
if last_ss_value is True:
steady[index] = True
Expand Down Expand Up @@ -830,7 +836,7 @@ def load_base(
sim_data = simulation.simulation_data
if ftype == "dis" and not sim_data.max_columns_user_set:
# set column wrap to ncol
dis = instance.get_package("dis")
dis = instance.get_package("dis", type_only=True)
if dis is not None and hasattr(dis, "ncol"):
sim_data.max_columns_of_data = dis.ncol.get_data()
sim_data.max_columns_user_set = False
Expand Down Expand Up @@ -1242,7 +1248,7 @@ def get_steadystate_list(self):
ss_list.append(True)
index += 1

storage = self.get_package("sto")
storage = self.get_package("sto", type_only=True)
if storage is not None:
tr_keys = storage.transient.get_keys(True)
ss_keys = storage.steady_state.get_keys(True)
Expand Down
2 changes: 1 addition & 1 deletion flopy/mf6/mfsimbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -1535,7 +1535,7 @@ def write_simulation(
if not sim_data.max_columns_user_set:
# search for dis packages
for model in self._models.values():
dis = model.get_package("dis")
dis = model.get_package("dis", type_only=True)
if dis is not None and hasattr(dis, "ncol"):
sim_data.max_columns_of_data = dis.ncol.get_data()
sim_data.max_columns_user_set = False
Expand Down

0 comments on commit 11f573b

Please sign in to comment.