From e7fd3b4e4a583448a823908e3de1964271be003e Mon Sep 17 00:00:00 2001 From: Oliver Elbert Date: Tue, 30 Jan 2024 13:05:12 -0500 Subject: [PATCH] Move Active Physics Schemes to Config (#44) * initial commit, need to adapt and run tests * revising scheme name * tests pass * update history * linting * changing typehints for physics schemes to enum instead of str * driver now works with physics config enum, tests pass * fixed tests * missed one --- docs/physics/state.rst | 2 +- .../baroclinic_c12_explicit_physics.yaml | 98 +++++++++++++++++++ driver/pace/driver/driver.py | 5 +- driver/pace/driver/initialization.py | 16 ++- driver/pace/driver/state.py | 9 +- driver/tests/mpi/test_restart.py | 2 + .../fv3core/initialization/analytic_init.py | 8 +- physics/pace/physics/__init__.py | 2 +- physics/pace/physics/_config.py | 20 +++- physics/pace/physics/physics_state.py | 24 +++-- physics/pace/physics/stencils/physics.py | 23 +++-- .../savepoint/translate/translate_driver.py | 5 +- .../translate/translate_gfs_physics_driver.py | 6 +- .../translate/translate_microphysics.py | 3 +- tests/main/driver/test_example_configs.py | 1 + tests/main/driver/test_restart_fortran.py | 2 + tests/main/driver/test_restart_serial.py | 2 + tests/main/physics/test_integration.py | 3 +- util/HISTORY.md | 1 + util/pace/util/__init__.py | 1 + util/pace/util/utils.py | 6 ++ 21 files changed, 192 insertions(+), 47 deletions(-) create mode 100644 driver/examples/configs/baroclinic_c12_explicit_physics.yaml diff --git a/docs/physics/state.rst b/docs/physics/state.rst index d658171e..1b603887 100644 --- a/docs/physics/state.rst +++ b/docs/physics/state.rst @@ -38,6 +38,6 @@ You can initialize a zero-filled PhysicsState and MicrophysicsState from other P >>> quantity_factory = QuantityFactory.from_backend(sizer=sizer, backend="numpy") >>> physics_state = PhysicsState.init_zeros( - ... quantity_factory=quantity_factory, active_packages=["microphysics"] + ... quantity_factory=quantity_factory, schemes=["GFS_microphysics"] ... ) >>> microphysics_state = physics_state.microphysics diff --git a/driver/examples/configs/baroclinic_c12_explicit_physics.yaml b/driver/examples/configs/baroclinic_c12_explicit_physics.yaml new file mode 100644 index 00000000..ceed306b --- /dev/null +++ b/driver/examples/configs/baroclinic_c12_explicit_physics.yaml @@ -0,0 +1,98 @@ +stencil_config: + compilation_config: + backend: numpy + rebuild: false + validate_args: true + format_source: false + device_sync: false +initialization: + type: analytic + config: + case: baroclinic +performance_config: + collect_performance: true + experiment_name: c12_baroclinic +nx_tile: 12 +nz: 79 +dt_atmos: 225 +minutes: 15 +layout: + - 1 + - 1 +diagnostics_config: + path: output + output_format: netcdf + names: + - u + - v + - ua + - va + - pt + - delp + - qvapor + - qliquid + - qice + - qrain + - qsnow + - qgraupel + z_select: + - level: 65 + names: + - pt +dycore_config: + a_imp: 1.0 + beta: 0. + consv_te: 0. + d2_bg: 0. + d2_bg_k1: 0.2 + d2_bg_k2: 0.1 + d4_bg: 0.15 + d_con: 1.0 + d_ext: 0.0 + dddmp: 0.5 + delt_max: 0.002 + do_sat_adj: true + do_vort_damp: true + fill: true + hord_dp: 6 + hord_mt: 6 + hord_tm: 6 + hord_tr: 8 + hord_vt: 6 + hydrostatic: false + k_split: 1 + ke_bg: 0. + kord_mt: 9 + kord_tm: -9 + kord_tr: 9 + kord_wz: 9 + n_split: 1 + nord: 3 + nwat: 6 + p_fac: 0.05 + rf_cutoff: 3000. + rf_fast: true + tau: 10. + vtdm4: 0.06 + z_tracer: true + do_qa: true + tau_i2s: 1000. + tau_g2v: 1200. + ql_gen: 0.001 + ql_mlt: 0.002 + qs_mlt: 0.000001 + qi_lim: 1.0 + dw_ocean: 0.1 + dw_land: 0.15 + icloud_f: 0 + tau_l2v: 300. + tau_v2l: 90. + fv_sg_adj: 0 + n_sponge: 48 + +physics_config: + hydrostatic: false + nwat: 6 + do_qa: true + schemes: + - GFS_microphysics diff --git a/driver/pace/driver/driver.py b/driver/pace/driver/driver.py index c8d1490f..24197621 100644 --- a/driver/pace/driver/driver.py +++ b/driver/pace/driver/driver.py @@ -231,6 +231,7 @@ def get_driver_state( damping_coefficients=damping_coefficients, driver_grid_data=driver_grid_data, grid_data=grid_data, + schemes=self.physics_config.schemes, ) @classmethod @@ -327,6 +328,9 @@ def write_for_restart( config_dict["initialization"]["type"] = "restart" config_dict["initialization"]["config"]["start_time"] = time config_dict["initialization"]["config"]["path"] = restart_path + # convert physics package enum to str + schemes = [scheme.value for scheme in config_dict["physics_config"]["schemes"]] + config_dict["physics_config"]["schemes"] = schemes # restart config doesn't have 'case' if "case" in config_dict["initialization"]["config"].keys(): del config_dict["initialization"]["config"]["case"] @@ -508,7 +512,6 @@ def exit_instead_of_build(self): quantity_factory=self.quantity_factory, grid_data=self.state.grid_data, namelist=self.config.physics_config, - active_packages=["microphysics"], ) else: # Make sure those are set to None to raise any issues diff --git a/driver/pace/driver/initialization.py b/driver/pace/driver/initialization.py index 3cf52376..04b08d3d 100644 --- a/driver/pace/driver/initialization.py +++ b/driver/pace/driver/initialization.py @@ -3,7 +3,7 @@ import os import pathlib from datetime import datetime -from typing import Callable, ClassVar, Type, TypeVar +from typing import Callable, ClassVar, List, Type, TypeVar import f90nml @@ -40,6 +40,7 @@ def get_driver_state( damping_coefficients: pace.util.grid.DampingCoefficients, driver_grid_data: pace.util.grid.DriverGridData, grid_data: pace.util.grid.GridData, + schemes: List[pace.physics.PHYSICS_PACKAGES], ) -> DriverState: ... @@ -77,6 +78,7 @@ def get_driver_state( damping_coefficients: pace.util.grid.DampingCoefficients, driver_grid_data: pace.util.grid.DriverGridData, grid_data: pace.util.grid.GridData, + schemes: List[pace.physics.PHYSICS_PACKAGES], ) -> DriverState: return self.config.get_driver_state( quantity_factory=quantity_factory, @@ -84,6 +86,7 @@ def get_driver_state( damping_coefficients=damping_coefficients, driver_grid_data=driver_grid_data, grid_data=grid_data, + schemes=schemes, ) @classmethod @@ -109,6 +112,7 @@ def get_driver_state( damping_coefficients: pace.util.grid.DampingCoefficients, driver_grid_data: pace.util.grid.DriverGridData, grid_data: pace.util.grid.GridData, + schemes: List[pace.physics.PHYSICS_PACKAGES], ) -> DriverState: dycore_state = analytic_init.init_analytic_state( analytic_init_case=self.case, @@ -120,7 +124,7 @@ def get_driver_state( comm=communicator, ) physics_state = pace.physics.PhysicsState.init_zeros( - quantity_factory=quantity_factory, active_packages=["microphysics"] + quantity_factory=quantity_factory, schemes=schemes ) tendency_state = TendencyState.init_zeros( quantity_factory=quantity_factory, @@ -152,6 +156,7 @@ def get_driver_state( damping_coefficients: pace.util.grid.DampingCoefficients, driver_grid_data: pace.util.grid.DriverGridData, grid_data: pace.util.grid.GridData, + schemes: List[pace.physics.PHYSICS_PACKAGES], ) -> DriverState: state = _restart_driver_state( self.path, @@ -161,6 +166,7 @@ def get_driver_state( damping_coefficients, driver_grid_data, grid_data, + schemes, ) return state @@ -201,6 +207,7 @@ def get_driver_state( damping_coefficients: pace.util.grid.DampingCoefficients, driver_grid_data: pace.util.grid.DriverGridData, grid_data: pace.util.grid.GridData, + schemes: List[pace.physics.PHYSICS_PACKAGES], ) -> DriverState: state = _restart_driver_state( self.path, @@ -210,6 +217,7 @@ def get_driver_state( damping_coefficients, driver_grid_data, grid_data, + schemes, ) _update_fortran_restart_pe_peln(state) @@ -272,6 +280,7 @@ def get_driver_state( damping_coefficients: pace.util.grid.DampingCoefficients, driver_grid_data: pace.util.grid.DriverGridData, grid_data: pace.util.grid.GridData, + schemes: List[pace.physics.PHYSICS_PACKAGES], ) -> DriverState: backend = quantity_factory.zeros( dims=[pace.util.X_DIM, pace.util.Y_DIM], units="unknown" @@ -280,7 +289,7 @@ def get_driver_state( dycore_state = self._initialize_dycore_state(communicator, backend) physics_state = pace.physics.PhysicsState.init_zeros( quantity_factory=quantity_factory, - active_packages=["microphysics"], + schemes=schemes, ) tendency_state = TendencyState.init_zeros(quantity_factory=quantity_factory) @@ -349,6 +358,7 @@ def get_driver_state( damping_coefficients: pace.util.grid.DampingCoefficients, driver_grid_data: pace.util.grid.DriverGridData, grid_data: pace.util.grid.GridData, + schemes: List[pace.physics.PHYSICS_PACKAGES], ) -> DriverState: return DriverState( dycore_state=self.dycore_state, diff --git a/driver/pace/driver/state.py b/driver/pace/driver/state.py index 54241b1d..93c99b55 100644 --- a/driver/pace/driver/state.py +++ b/driver/pace/driver/state.py @@ -1,5 +1,6 @@ import dataclasses from dataclasses import fields +from typing import List import xarray as xr @@ -75,6 +76,7 @@ def load_state_from_restart( damping_coefficients: pace.util.grid.DampingCoefficients, driver_grid_data: pace.util.grid.DriverGridData, grid_data: pace.util.grid.GridData, + schemes: List[pace.physics.PHYSICS_PACKAGES], ) -> "DriverState": comm = driver_config.comm_config.get_comm() communicator = pace.util.Communicator.from_layout( @@ -102,6 +104,7 @@ def load_state_from_restart( damping_coefficients=damping_coefficients, driver_grid_data=driver_grid_data, grid_data=grid_data, + schemes=schemes, ) return state @@ -176,6 +179,7 @@ def _restart_driver_state( damping_coefficients: pace.util.grid.DampingCoefficients, driver_grid_data: pace.util.grid.DriverGridData, grid_data: pace.util.grid.GridData, + schemes: List[pace.physics.PHYSICS_PACKAGES], ): fs = pace.util.get_fs(path) @@ -197,12 +201,11 @@ def _restart_driver_state( "restart_dycore_state", ) - active_packages = ["microphysics"] physics_state = pace.physics.PhysicsState.init_zeros( - quantity_factory=quantity_factory, active_packages=active_packages + quantity_factory=quantity_factory, schemes=schemes ) - physics_state.__post_init__(quantity_factory, active_packages) + physics_state.__post_init__(quantity_factory, schemes) tendency_state = TendencyState.init_zeros( quantity_factory=quantity_factory, ) diff --git a/driver/tests/mpi/test_restart.py b/driver/tests/mpi/test_restart.py index 5c2ccde9..92c6f72b 100644 --- a/driver/tests/mpi/test_restart.py +++ b/driver/tests/mpi/test_restart.py @@ -11,6 +11,7 @@ import pace.util from pace.driver import DriverConfig from pace.driver.state import DriverState +from pace.physics import PHYSICS_PACKAGES # The packages we import will import MPI, causing an MPI init, but we don't actually @@ -65,6 +66,7 @@ def test_restart(): damping_coefficients=damping_coefficients, driver_grid_data=driver_grid_data, grid_data=grid_data, + schemes=[PHYSICS_PACKAGES["GFS_microphysics"]], ) assert isinstance(driver_state, DriverState) diff --git a/fv3core/pace/fv3core/initialization/analytic_init.py b/fv3core/pace/fv3core/initialization/analytic_init.py index e8f6b07e..544c6283 100644 --- a/fv3core/pace/fv3core/initialization/analytic_init.py +++ b/fv3core/pace/fv3core/initialization/analytic_init.py @@ -1,15 +1,11 @@ -from enum import Enum, EnumMeta +from enum import Enum import pace.util as fv3util from pace.fv3core.dycore_state import DycoreState +from pace.util import MetaEnumStr from pace.util.grid import GridData -class MetaEnumStr(EnumMeta): - def __contains__(cls, item): - return item in cls.__members__.keys() - - class Cases(Enum, metaclass=MetaEnumStr): baroclinic = "baroclinic" tropicalcyclone = "tropicalcyclone" diff --git a/physics/pace/physics/__init__.py b/physics/pace/physics/__init__.py index 6fdaf68d..8fa30674 100644 --- a/physics/pace/physics/__init__.py +++ b/physics/pace/physics/__init__.py @@ -1,4 +1,4 @@ -from ._config import PhysicsConfig +from ._config import PHYSICS_PACKAGES, PhysicsConfig from .physics_state import PhysicsState from .stencils.microphysics import Microphysics from .stencils.physics import Physics diff --git a/physics/pace/physics/_config.py b/physics/pace/physics/_config.py index 58d4d274..4ce3715a 100644 --- a/physics/pace/physics/_config.py +++ b/physics/pace/physics/_config.py @@ -1,13 +1,20 @@ import dataclasses -from typing import Optional, Tuple +from enum import Enum, unique +from typing import List, Optional, Tuple import f90nml -from pace.util import Namelist, NamelistDefaults +from pace.util import MetaEnumStr, Namelist, NamelistDefaults DEFAULT_INT = 0 DEFAULT_BOOL = False +DEFAULT_SCHEMES = ["GFS_microphysics"] + + +@unique +class PHYSICS_PACKAGES(Enum, metaclass=MetaEnumStr): + GFS_microphysics = "GFS_microphysics" @dataclasses.dataclass @@ -18,6 +25,7 @@ class PhysicsConfig: npy: int = DEFAULT_INT npz: int = DEFAULT_INT nwat: int = DEFAULT_INT + schemes: List = None do_qa: bool = DEFAULT_BOOL c_cracw: float = NamelistDefaults.c_cracw c_paut: float = NamelistDefaults.c_paut @@ -100,6 +108,14 @@ class PhysicsConfig: namelist_override: Optional[str] = None def __post_init__(self): + if self.schemes is None: + self.schemes = DEFAULT_SCHEMES + package_schemes = [] + for scheme in self.schemes: + if scheme not in PHYSICS_PACKAGES: + raise NotImplementedError(f"{scheme} physics scheme not implemented") + package_schemes.append(PHYSICS_PACKAGES[scheme]) + self.schemes = package_schemes if self.namelist_override is not None: try: f90_nml = f90nml.read(self.namelist_override) diff --git a/physics/pace/physics/physics_state.py b/physics/pace/physics/physics_state.py index d764ac04..83463998 100644 --- a/physics/pace/physics/physics_state.py +++ b/physics/pace/physics/physics_state.py @@ -8,6 +8,8 @@ from pace.dsl.typing import Float from pace.physics.stencils.microphysics import MicrophysicsState +from ._config import PHYSICS_PACKAGES + @dataclass() class PhysicsState: @@ -281,13 +283,15 @@ class PhysicsState: } ) quantity_factory: InitVar[pace.util.QuantityFactory] - active_packages: InitVar[List[str]] + schemes: InitVar[List[PHYSICS_PACKAGES]] def __post_init__( - self, quantity_factory: pace.util.QuantityFactory, active_packages: List[str] + self, + quantity_factory: pace.util.QuantityFactory, + schemes: List[PHYSICS_PACKAGES], ): # storage for tendency variables not in PhysicsState - if "microphysics" in active_packages: + if "GFS_microphysics" in [scheme.value for scheme in schemes]: tendency = quantity_factory.zeros( [pace.util.X_DIM, pace.util.Y_DIM, pace.util.Z_DIM], "unknown", @@ -317,7 +321,9 @@ def __post_init__( self.microphysics = None @classmethod - def init_zeros(cls, quantity_factory, active_packages: List[str]) -> "PhysicsState": + def init_zeros( + cls, quantity_factory, schemes: List[PHYSICS_PACKAGES] + ) -> "PhysicsState": initial_arrays = {} for _field in fields(cls): if "dims" in _field.metadata.keys(): @@ -329,7 +335,7 @@ def init_zeros(cls, quantity_factory, active_packages: List[str]) -> "PhysicsSta return cls( **initial_arrays, quantity_factory=quantity_factory, - active_packages=active_packages, + schemes=schemes, ) @classmethod @@ -338,7 +344,7 @@ def init_from_storages( storages: Mapping[str, Any], sizer: pace.util.GridSizer, quantity_factory: pace.util.QuantityFactory, - active_packages: List[str], + schemes: List[PHYSICS_PACKAGES], ) -> "PhysicsState": inputs: Dict[str, pace.util.Quantity] = {} for _field in fields(cls): @@ -352,15 +358,13 @@ def init_from_storages( extent=sizer.get_extent(dims), ) inputs[_field.name] = quantity - return cls( - **inputs, quantity_factory=quantity_factory, active_packages=active_packages - ) + return cls(**inputs, quantity_factory=quantity_factory, schemes=schemes) @property def xr_dataset(self): data_vars = {} for name, field_info in self.__dataclass_fields__.items(): - if name not in ["quantity_factory", "active_packages"]: + if name not in ["quantity_factory", "schemes"]: if issubclass(field_info.type, pace.util.Quantity): dims = [ f"{dim_name}_{name}" for dim_name in field_info.metadata["dims"] diff --git a/physics/pace/physics/stencils/physics.py b/physics/pace/physics/stencils/physics.py index a4dfdb2f..2de7d0a8 100644 --- a/physics/pace/physics/stencils/physics.py +++ b/physics/pace/physics/stencils/physics.py @@ -1,5 +1,3 @@ -from typing import List - import gt4py.cartesian.gtscript as gtscript from gt4py.cartesian.gtscript import ( BACKWARD, @@ -10,7 +8,6 @@ interval, log, ) -from typing_extensions import Literal import pace.util import pace.util.constants as constants @@ -24,10 +21,7 @@ from pace.util import X_DIM, Y_DIM, Z_DIM from pace.util.grid import GridData -from .._config import PhysicsConfig - - -PHYSICS_PACKAGES = Literal["microphysics"] +from .._config import PHYSICS_PACKAGES, PhysicsConfig def atmos_phys_driver_statein( @@ -208,8 +202,13 @@ def __init__( quantity_factory: pace.util.QuantityFactory, grid_data: GridData, namelist: PhysicsConfig, - active_packages: List[Literal[PHYSICS_PACKAGES]], ): + schemes = [scheme.value for scheme in namelist.schemes] + for scheme in schemes: + if scheme not in PHYSICS_PACKAGES: + raise NotImplementedError( + f"{scheme} is not an implemented physics parameterization" + ) orchestrate( obj=self, config=stencil_factory.config.dace_config, @@ -249,8 +248,8 @@ def make_quantity(): "pktop": self._pktop, }, ) - if "microphysics" in active_packages: - self._do_microphysics = True + if "GFS_microphysics" in schemes: + self._gfs_microphysics = True self._prepare_microphysics = stencil_factory.from_origin_domain( func=prepare_microphysics, origin=grid_indexing.origin_compute(), @@ -267,7 +266,7 @@ def make_quantity(): stencil_factory, quantity_factory, grid_data, namelist=namelist ) else: - self._do_microphysics = False + self._gfs_microphysics = False def _setup_statein(self): self._NQ = 8 # state.nq_tot - spec.namelist.dnats @@ -311,7 +310,7 @@ def __call__(self, physics_state: PhysicsState, timestep: float): physics_state.phii, physics_state.phil, ) - if self._do_microphysics: + if self._gfs_microphysics: self._prepare_microphysics( physics_state.dz, physics_state.phii, diff --git a/physics/tests/savepoint/translate/translate_driver.py b/physics/tests/savepoint/translate/translate_driver.py index 90da4883..6910df18 100644 --- a/physics/tests/savepoint/translate/translate_driver.py +++ b/physics/tests/savepoint/translate/translate_driver.py @@ -8,7 +8,7 @@ # but also, driver tests should not be in physics from pace.fv3core.testing.translate_fvdynamics import TranslateFVDynamics from pace.fv3core.testing.validation import enable_selective_validation -from pace.physics import PhysicsConfig, PhysicsState +from pace.physics import PHYSICS_PACKAGES, PhysicsConfig, PhysicsState from pace.util.namelist import Namelist @@ -43,7 +43,8 @@ def compute_parallel(self, inputs, communicator): sizer, backend=self.stencil_config.compilation_config.backend ) physics_state = PhysicsState.init_zeros( - quantity_factory=quantity_factory, active_packages=["microphysics"] + quantity_factory=quantity_factory, + schemes=[PHYSICS_PACKAGES["GFS_microphysics"]], ) tendency_state = TendencyState.init_zeros( quantity_factory=quantity_factory, diff --git a/physics/tests/savepoint/translate/translate_gfs_physics_driver.py b/physics/tests/savepoint/translate/translate_gfs_physics_driver.py index 2feb3df9..ab8f870b 100644 --- a/physics/tests/savepoint/translate/translate_gfs_physics_driver.py +++ b/physics/tests/savepoint/translate/translate_gfs_physics_driver.py @@ -2,6 +2,7 @@ import pace.dsl.gt4py_utils as utils import pace.util as util +from pace.physics import PHYSICS_PACKAGES from pace.physics.stencils.physics import Physics, PhysicsState from pace.stencils import update_atmos_state from pace.stencils.testing.translate_physics import TranslatePhysicsFortranData2Py @@ -129,18 +130,17 @@ def compute(self, inputs): quantity_factory = util.QuantityFactory.from_backend( sizer, self.stencil_factory.backend ) - active_packages = ["microphysics"] + schemes = [PHYSICS_PACKAGES["GFS_microphysics"]] physics_state = PhysicsState( **inputs, quantity_factory=quantity_factory, - active_packages=active_packages, + schemes=schemes, ) physics = Physics( self.stencil_factory, self.grid.quantity_factory, self.grid.grid_data, self.namelist, - active_packages=active_packages, ) # TODO, self.namelist doesn't have fv_sg_adj because it is PhysicsConfig # either move where GFSPhysicsDriver starts, or pass the full namelist or diff --git a/physics/tests/savepoint/translate/translate_microphysics.py b/physics/tests/savepoint/translate/translate_microphysics.py index d3259c1a..49ebc4dc 100644 --- a/physics/tests/savepoint/translate/translate_microphysics.py +++ b/physics/tests/savepoint/translate/translate_microphysics.py @@ -5,6 +5,7 @@ import pace.dsl.gt4py_utils as utils import pace.util from pace.dsl.typing import Float +from pace.physics import PHYSICS_PACKAGES from pace.physics.stencils.microphysics import Microphysics from pace.physics.stencils.physics import PhysicsState from pace.stencils.testing.translate_physics import TranslatePhysicsFortranData2Py @@ -88,7 +89,7 @@ def compute(self, inputs): inputs, sizer=sizer, quantity_factory=quantity_factory, - active_packages=["microphysics"], + schemes=[PHYSICS_PACKAGES["GFS_microphysics"]], ) microphysics = Microphysics( self.stencil_factory, quantity_factory, self.grid.grid_data, self.namelist diff --git a/tests/main/driver/test_example_configs.py b/tests/main/driver/test_example_configs.py index 1fc5dec1..5a9dcbcd 100644 --- a/tests/main/driver/test_example_configs.py +++ b/tests/main/driver/test_example_configs.py @@ -14,6 +14,7 @@ TESTED_CONFIGS: List[str] = [ "baroclinic_c12.yaml", "baroclinic_c12_dp.yaml", + "baroclinic_c12_explicit_physics.yaml", "baroclinic_c12_comm_read.yaml", "baroclinic_c12_comm_write.yaml", "baroclinic_c12_null_comm.yaml", diff --git a/tests/main/driver/test_restart_fortran.py b/tests/main/driver/test_restart_fortran.py index d1accdaa..518ceed1 100644 --- a/tests/main/driver/test_restart_fortran.py +++ b/tests/main/driver/test_restart_fortran.py @@ -6,6 +6,7 @@ import pace.driver import pace.util from pace.driver.initialization import FortranRestartInit +from pace.physics import PHYSICS_PACKAGES from pace.util import ( CubedSphereCommunicator, CubedSpherePartitioner, @@ -60,6 +61,7 @@ def test_state_from_fortran_restart(): damping_coefficients=damping_coefficients, driver_grid_data=driver_grid_data, grid_data=grid_data, + schemes=[PHYSICS_PACKAGES["GFS_microphysics"]], ) ds = xr.open_dataset(os.path.join(restart_dir, "fv_core.res.tile1.nc")) np.testing.assert_array_equal( diff --git a/tests/main/driver/test_restart_serial.py b/tests/main/driver/test_restart_serial.py index 51ce615a..4c4542f5 100644 --- a/tests/main/driver/test_restart_serial.py +++ b/tests/main/driver/test_restart_serial.py @@ -10,6 +10,7 @@ from pace.driver import CreatesComm, DriverConfig from pace.driver.driver import RestartConfig from pace.driver.initialization import AnalyticInit +from pace.physics import PHYSICS_PACKAGES from pace.util.null_comm import NullComm @@ -81,6 +82,7 @@ def test_restart_save_to_disk(): damping_coefficients=damping_coefficients, driver_grid_data=driver_grid_data, grid_data=grid_data, + schemes=[PHYSICS_PACKAGES["GFS_microphysics"]], ) time = datetime(2016, 1, 1, 0, 0, 0) diff --git a/tests/main/physics/test_integration.py b/tests/main/physics/test_integration.py index 8d86f80b..84cd6d31 100644 --- a/tests/main/physics/test_integration.py +++ b/tests/main/physics/test_integration.py @@ -72,10 +72,9 @@ def setup_physics(): quantity_factory, grid_data, physics_config, - active_packages=["microphysics"], ) physics_state = pace.physics.PhysicsState.init_zeros( - quantity_factory, active_packages=["microphysics"] + quantity_factory, schemes=[pace.physics.PHYSICS_PACKAGES["GFS_microphysics"]] ) random = np.random.RandomState(0) for field in fields(pace.physics.PhysicsState): diff --git a/util/HISTORY.md b/util/HISTORY.md index 972d7167..0d54eca5 100644 --- a/util/HISTORY.md +++ b/util/HISTORY.md @@ -4,6 +4,7 @@ History latest ------ +- Added `MetaEnumStr` to utils to make enums more functional - Added `fill_for_translate_test` to MetricTerms to fill fields with NaNs only when required for testing - Added `init_cartesian` method to MetricTerms to handle grid generation for orthogonal grids - Added `from_layout` and `size` methods to TileCommunicator and Communicator diff --git a/util/pace/util/__init__.py b/util/pace/util/__init__.py index 8137ab77..deb3b081 100644 --- a/util/pace/util/__init__.py +++ b/util/pace/util/__init__.py @@ -70,6 +70,7 @@ from .quantity import Quantity, QuantityMetadata from .time import FMS_TO_CFTIME_TYPE, datetime64_to_datetime from .units import UnitsError, ensure_equal_units, units_are_equal +from .utils import MetaEnumStr __version__ = "0.10.0" diff --git a/util/pace/util/utils.py b/util/pace/util/utils.py index 737a55a0..9854609d 100644 --- a/util/pace/util/utils.py +++ b/util/pace/util/utils.py @@ -1,3 +1,4 @@ +from enum import EnumMeta from typing import Iterable, Sequence, Tuple, TypeVar, Union import numpy as np @@ -21,6 +22,11 @@ T = TypeVar("T") +class MetaEnumStr(EnumMeta): + def __contains__(cls, item) -> bool: + return item in cls.__members__.keys() + + def list_by_dims( dims: Sequence[str], horizontal_list: Sequence[T], non_horizontal_value: T ) -> Tuple[T, ...]: