Skip to content

Commit

Permalink
refactor(PackageContainer): compose not inherit, deprecate methods (#…
Browse files Browse the repository at this point in the history
…2324)

First steps toward shrinking the inheritance hierarchy. Components (simulation/model/package) now have a PackageContainer instead of being one. APIs which are clearly public remain the same. Visibility is tweaked for a few methods which were previously "private" (leading underscore) but seem user-facing. A number of methods which seem internal-only and/or are redundant with other APIs are deprecated.
  • Loading branch information
deltamarnix authored Oct 14, 2024
1 parent 82ec160 commit f378f84
Show file tree
Hide file tree
Showing 9 changed files with 476 additions and 182 deletions.
45 changes: 22 additions & 23 deletions autotest/regression/test_mf6.py
Original file line number Diff line number Diff line change
Expand Up @@ -4497,22 +4497,21 @@ def test006_2models_mvr(function_tmpdir, example_data_path):
exg_pkg.exchangedata.set_data(exg_data)

# test getting packages
pkg_dict = parent_model.package_dict
assert len(pkg_dict) == 6
pkg_names = parent_model.package_names
assert len(pkg_names) == 6
pkg_list = parent_model.get_package()
assert len(pkg_list) == 6
# confirm that this is a copy of the original dictionary with references
# to the packages
del pkg_dict[pkg_names[0]]
assert len(pkg_dict) == 5
pkg_dict = parent_model.package_dict
assert len(pkg_dict) == 6

old_val = pkg_dict["dis"].nlay.get_data()
pkg_dict["dis"].nlay = 22
pkg_dict = parent_model.package_dict
assert pkg_dict["dis"].nlay.get_data() == 22
pkg_dict["dis"].nlay = old_val
del pkg_list[0]
assert len(pkg_list) == 5
pkg_list = parent_model.get_package()
assert len(pkg_list) == 6

dis_pkg = parent_model.get_package("dis")
old_val = dis_pkg.nlay.get_data()
dis_pkg.nlay = 22
pkg_list = parent_model.get_package()
assert dis_pkg.nlay.get_data() == 22
dis_pkg.nlay = old_val

# write simulation again
save_folder = function_tmpdir / "save"
Expand Down Expand Up @@ -4560,8 +4559,8 @@ def test006_2models_mvr(function_tmpdir, example_data_path):
model = sim.get_model(model_name)
for package in model_package_check:
assert (
package in model.package_type_dict
or package in sim.package_type_dict
model.get_package(package, type_only=True) is not None
or sim.get_package(package, type_only=True) is not None
) == (package in load_only or f"{package}6" in load_only)
assert (len(sim._exchange_files) > 0) == (
"gwf6-gwf6" in load_only or "gwf-gwf" in load_only
Expand All @@ -4577,10 +4576,10 @@ def test006_2models_mvr(function_tmpdir, example_data_path):
)
model_parent = sim.get_model("parent")
model_child = sim.get_model("child")
assert "oc" not in model_parent.package_type_dict
assert "oc" in model_child.package_type_dict
assert "npf" in model_parent.package_type_dict
assert "npf" not in model_child.package_type_dict
assert model_parent.get_package("oc") is None
assert model_child.get_package("oc") is not None
assert model_parent.get_package("npf") is not None
assert model_child.get_package("npf") is None

# test running a runnable load_only case
sim = MFSimulation.load(
Expand Down Expand Up @@ -4652,9 +4651,9 @@ def test001e_uzf_3lay(function_tmpdir, example_data_path):
sim.set_sim_path(function_tmpdir)
model = sim.get_model()
for package in model_package_check:
assert (package in model.package_type_dict) == (
package in load_only or f"{package}6" in load_only
)
assert (
model.get_package(package, type_only=True) is not None
) == (package in load_only or f"{package}6" in load_only)
# test running a runnable load_only case
sim = MFSimulation.load(
model_name, "mf6", "mf6", pth, load_only=load_only_lists[0]
Expand Down
12 changes: 2 additions & 10 deletions autotest/test_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,6 @@ def model_is_copy(m1, m2):
if k in [
"_packagelist",
"_package_paths",
"package_key_dict",
"package_type_dict",
"package_name_dict",
"package_filename_dict",
"_ftype_num_dict",
]:
continue
Expand Down Expand Up @@ -97,17 +93,13 @@ def package_is_copy(pk1, pk2):
if k in [
"_child_package_groups",
"_data_list",
"_packagelist",
"_simulation_data",
"simulation_data",
"blocks",
"dimensions",
"package_key_dict",
"package_name_dict",
"package_filename_dict",
"package_type_dict",
"post_block_comments",
"simulation_data",
"structure",
"_package_container",
]:
continue
elif isinstance(v, MFPackage):
Expand Down
18 changes: 9 additions & 9 deletions autotest/test_model_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,19 +467,19 @@ def test_empty_packages(function_tmpdir):
m0 = new_sim.get_model(f"{base_name}_0")
m1 = new_sim.get_model(f"{base_name}_1")

if "chd_0" in m0.package_dict:
raise AssertionError(f"Empty CHD file written to {base_name}_0 model")

if "wel_0" in m1.package_dict:
raise AssertionError(f"Empty WEL file written to {base_name}_1 model")
assert not m0.get_package(
name="chd_0"
), f"Empty CHD file written to {base_name}_0 model"
assert not m1.get_package(
name="wel_0"
), f"Empty WEL file written to {base_name}_1 model"

mvr_status0 = m0.sfr.mover.array
mvr_status1 = m0.sfr.mover.array

if not mvr_status0 or not mvr_status1:
raise AssertionError(
"Mover status being overwritten in options splitting"
)
assert (
mvr_status0 and mvr_status1
), "Mover status being overwritten in options splitting"


@requires_exe("mf6")
Expand Down
43 changes: 13 additions & 30 deletions flopy/mf6/mfbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from pathlib import Path
from shutil import copyfile
from typing import Union
from warnings import warn


# internal handled exceptions
Expand Down Expand Up @@ -454,24 +453,13 @@ class PackageContainer:
modflow_models = []
models_by_type = {}

def __init__(self, simulation_data, name):
self.type = "PackageContainer"
self.simulation_data = simulation_data
self.name = name
self._packagelist = []
def __init__(self, simulation_data):
self._simulation_data = simulation_data
self.packagelist = []
self.package_type_dict = {}
self.package_name_dict = {}
self.package_filename_dict = {}

@property
def package_key_dict(self):
warnings.warn(
"package_key_dict has been deprecated, use "
"package_type_dict instead",
category=DeprecationWarning,
)
return self.package_type_dict

@staticmethod
def package_list():
"""Static method that returns the list of available packages.
Expand Down Expand Up @@ -554,9 +542,9 @@ def package_names(self):
"""Returns a list of package names."""
return list(self.package_name_dict.keys())

def _add_package(self, package, path):
def add_package(self, package):
# put in packages list and update lookup dictionaries
self._packagelist.append(package)
self.packagelist.append(package)
if package.package_name is not None:
self.package_name_dict[package.package_name.lower()] = package
if package.filename is not None:
Expand All @@ -565,9 +553,9 @@ def _add_package(self, package, path):
self.package_type_dict[package.package_type.lower()] = []
self.package_type_dict[package.package_type.lower()].append(package)

def _remove_package(self, package):
if package in self._packagelist:
self._packagelist.remove(package)
def remove_package(self, package):
if package in self.packagelist:
self.packagelist.remove(package)
if (
package.package_name is not None
and package.package_name.lower() in self.package_name_dict
Expand All @@ -587,7 +575,7 @@ def _remove_package(self, package):

# collect keys of items to be removed from main dictionary
items_to_remove = []
for key in self.simulation_data.mfdata:
for key in self._simulation_data.mfdata:
is_subkey = True
for pitem, ditem in zip(package.path, key):
if pitem != ditem:
Expand All @@ -598,7 +586,7 @@ def _remove_package(self, package):

# remove items from main dictionary
for key in items_to_remove:
del self.simulation_data.mfdata[key]
del self._simulation_data.mfdata[key]

def _rename_package(self, package, new_name):
# fix package_name_dict key
Expand All @@ -609,7 +597,7 @@ def _rename_package(self, package, new_name):
del self.package_name_dict[package.package_name.lower()]
self.package_name_dict[new_name.lower()] = package
# get keys to fix in main dictionary
main_dict = self.simulation_data.mfdata
main_dict = self._simulation_data.mfdata
items_to_fix = []
for key in main_dict:
is_subkey = True
Expand Down Expand Up @@ -648,7 +636,7 @@ def get_package(self, name=None, type_only=False, name_only=False):
"""
if name is None:
return self._packagelist[:]
return self.packagelist[:]

# search for full package name
if name.lower() in self.package_name_dict and not type_only:
Expand All @@ -669,7 +657,7 @@ def get_package(self, name=None, type_only=False, name_only=False):

# search for partial and case-insensitive package name
if not type_only:
for pp in self._packagelist:
for pp in self.packagelist:
if pp.package_name is not None:
# get first package of the type requested
package_name = pp.package_name.lower()
Expand All @@ -680,11 +668,6 @@ def get_package(self, name=None, type_only=False, name_only=False):

return None

def register_package(self, package):
"""Base method for registering a package. Should be overridden."""
path = (package.package_name,)
return (path, None)

@staticmethod
def _load_only_dict(load_only):
if load_only is None:
Expand Down
Loading

0 comments on commit f378f84

Please sign in to comment.