From 11f573b0341b461a584faae4fd82ea4b2bbffc69 Mon Sep 17 00:00:00 2001 From: spaulins-usgs Date: Wed, 20 Mar 2024 14:08:32 -0700 Subject: [PATCH] fix(get_package and model_time): #2117, #2118 (#2123) * fix(get_package and model_time): #2117, #2118 get_package now allows you to get package only by name or type, instead of always searching for a package both by name and type model_time displays the correct steady state array and no longer gets confused if packages are named similar to the package type it is searching for * fix(resolve merge conflict) --------- Co-authored-by: scottrp <45947939+scottrp@users.noreply.github.com> --- autotest/regression/test_mf6.py | 11 +++++++++++ flopy/mf6/mfbase.py | 31 ++++++++++++++++++------------- flopy/mf6/mfmodel.py | 18 ++++++++++++------ flopy/mf6/mfsimbase.py | 2 +- 4 files changed, 42 insertions(+), 20 deletions(-) diff --git a/autotest/regression/test_mf6.py b/autotest/regression/test_mf6.py index 9c98504631..01fe275b4b 100644 --- a/autotest/regression/test_mf6.py +++ b/autotest/regression/test_mf6.py @@ -1946,6 +1946,17 @@ def test005_create_tests_advgw_tidal(function_tmpdir, example_data_path): assert value[0][0] == "ghb- 2-6-10" assert found_flows and found_obs + # check model.time steady state and transient + sto_package = model.get_package("sto") + sto_package.steady_state.set_data({0: True, 1: False, 2: False, 3: False}) + sto_package.transient.set_data({0: False, 1: True, 2: True, 3: True}) + flopy.mf6.ModflowGwfdrn(model, pname="storm") + ss = model.modeltime.steady_state + assert ss[0] + assert not ss[1] + assert not ss[2] + assert not ss[3] + # clean up sim.delete_output_files() diff --git a/flopy/mf6/mfbase.py b/flopy/mf6/mfbase.py index a1e6b3caf6..437509b641 100644 --- a/flopy/mf6/mfbase.py +++ b/flopy/mf6/mfbase.py @@ -624,7 +624,7 @@ def _rename_package(self, package, new_name): ) main_dict[new_key] = main_dict.pop(key) - def get_package(self, name=None): + def get_package(self, name=None, type_only=False, name_only=False): """ Finds a package by package name, package key, package type, or partial package name. returns either a single package, a list of packages, @@ -633,7 +633,11 @@ def get_package(self, name=None): Parameters ---------- name : str - Name of the package, 'RIV', 'LPF', etc. + Name or type of the package, 'my-riv-1, 'RIV', 'LPF', etc. + type_only : bool + Search for package by type only + name_only : bool + Search for package by name only Returns ------- @@ -644,11 +648,11 @@ def get_package(self, name=None): return self._packagelist[:] # search for full package name - if name.lower() in self.package_name_dict: + if name.lower() in self.package_name_dict and not type_only: return self.package_name_dict[name.lower()] # search for package type - if name.lower() in self.package_type_dict: + if name.lower() in self.package_type_dict and not name_only: if len(self.package_type_dict[name.lower()]) == 0: return None elif len(self.package_type_dict[name.lower()]) == 1: @@ -657,18 +661,19 @@ def get_package(self, name=None): return self.package_type_dict[name.lower()] # search for file name - if name.lower() in self.package_filename_dict: + if name.lower() in self.package_filename_dict and not type_only: return self.package_filename_dict[name.lower()] # search for partial and case-insensitive package name - for pp in self._packagelist: - if pp.package_name is not None: - # get first package of the type requested - package_name = pp.package_name.lower() - if len(package_name) > len(name): - package_name = package_name[0 : len(name)] - if package_name.lower() == name.lower(): - return pp + if not type_only: + for pp in self._packagelist: + if pp.package_name is not None: + # get first package of the type requested + package_name = pp.package_name.lower() + if len(package_name) > len(name): + package_name = package_name[0 : len(name)] + if package_name.lower() == name.lower(): + return pp return None diff --git a/flopy/mf6/mfmodel.py b/flopy/mf6/mfmodel.py index df9124e766..0d765cd804 100644 --- a/flopy/mf6/mfmodel.py +++ b/flopy/mf6/mfmodel.py @@ -264,25 +264,31 @@ def modeltime(self): simulation. """ - tdis = self.simulation.get_package("tdis") + tdis = self.simulation.get_package("tdis", type_only=True) period_data = tdis.perioddata.get_data() # build steady state data - sto = self.get_package("sto") + sto = self.get_package("sto", type_only=True) if sto is None: steady = np.full((len(period_data["perlen"])), True, dtype=bool) else: steady = np.full((len(period_data["perlen"])), False, dtype=bool) ss_periods = sto.steady_state.get_active_key_dict() + for period, val in ss_periods.items(): + if val: + ss_periods[period] = sto.steady_state.get_data(period) tr_periods = sto.transient.get_active_key_dict() + for period, val in tr_periods.items(): + if val: + tr_periods[period] = sto.transient.get_data(period) if ss_periods: last_ss_value = False # loop through steady state array for index, value in enumerate(steady): # resolve if current index is steady state or transient - if index in ss_periods: + if index in ss_periods and ss_periods[index]: last_ss_value = True - elif index in tr_periods: + elif index in tr_periods and tr_periods[index]: last_ss_value = False if last_ss_value is True: steady[index] = True @@ -830,7 +836,7 @@ def load_base( sim_data = simulation.simulation_data if ftype == "dis" and not sim_data.max_columns_user_set: # set column wrap to ncol - dis = instance.get_package("dis") + dis = instance.get_package("dis", type_only=True) if dis is not None and hasattr(dis, "ncol"): sim_data.max_columns_of_data = dis.ncol.get_data() sim_data.max_columns_user_set = False @@ -1242,7 +1248,7 @@ def get_steadystate_list(self): ss_list.append(True) index += 1 - storage = self.get_package("sto") + storage = self.get_package("sto", type_only=True) if storage is not None: tr_keys = storage.transient.get_keys(True) ss_keys = storage.steady_state.get_keys(True) diff --git a/flopy/mf6/mfsimbase.py b/flopy/mf6/mfsimbase.py index f817f4f532..e84c4dfc51 100644 --- a/flopy/mf6/mfsimbase.py +++ b/flopy/mf6/mfsimbase.py @@ -1535,7 +1535,7 @@ def write_simulation( if not sim_data.max_columns_user_set: # search for dis packages for model in self._models.values(): - dis = model.get_package("dis") + dis = model.get_package("dis", type_only=True) if dis is not None and hasattr(dis, "ncol"): sim_data.max_columns_of_data = dis.ncol.get_data() sim_data.max_columns_user_set = False