Skip to content

Commit

Permalink
Move Active Physics Schemes to Config (#44)
Browse files Browse the repository at this point in the history
* initial commit, need to adapt and run tests

* revising scheme name

* tests pass

* update history

* linting

* changing typehints for physics schemes to enum instead of str

* driver now works with physics config enum, tests pass

* fixed tests

* missed one
  • Loading branch information
oelbert authored Jan 30, 2024
1 parent 34eeea4 commit e7fd3b4
Show file tree
Hide file tree
Showing 21 changed files with 192 additions and 47 deletions.
2 changes: 1 addition & 1 deletion docs/physics/state.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,6 @@ You can initialize a zero-filled PhysicsState and MicrophysicsState from other P

>>> quantity_factory = QuantityFactory.from_backend(sizer=sizer, backend="numpy")
>>> physics_state = PhysicsState.init_zeros(
... quantity_factory=quantity_factory, active_packages=["microphysics"]
... quantity_factory=quantity_factory, schemes=["GFS_microphysics"]
... )
>>> microphysics_state = physics_state.microphysics
98 changes: 98 additions & 0 deletions driver/examples/configs/baroclinic_c12_explicit_physics.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
stencil_config:
compilation_config:
backend: numpy
rebuild: false
validate_args: true
format_source: false
device_sync: false
initialization:
type: analytic
config:
case: baroclinic
performance_config:
collect_performance: true
experiment_name: c12_baroclinic
nx_tile: 12
nz: 79
dt_atmos: 225
minutes: 15
layout:
- 1
- 1
diagnostics_config:
path: output
output_format: netcdf
names:
- u
- v
- ua
- va
- pt
- delp
- qvapor
- qliquid
- qice
- qrain
- qsnow
- qgraupel
z_select:
- level: 65
names:
- pt
dycore_config:
a_imp: 1.0
beta: 0.
consv_te: 0.
d2_bg: 0.
d2_bg_k1: 0.2
d2_bg_k2: 0.1
d4_bg: 0.15
d_con: 1.0
d_ext: 0.0
dddmp: 0.5
delt_max: 0.002
do_sat_adj: true
do_vort_damp: true
fill: true
hord_dp: 6
hord_mt: 6
hord_tm: 6
hord_tr: 8
hord_vt: 6
hydrostatic: false
k_split: 1
ke_bg: 0.
kord_mt: 9
kord_tm: -9
kord_tr: 9
kord_wz: 9
n_split: 1
nord: 3
nwat: 6
p_fac: 0.05
rf_cutoff: 3000.
rf_fast: true
tau: 10.
vtdm4: 0.06
z_tracer: true
do_qa: true
tau_i2s: 1000.
tau_g2v: 1200.
ql_gen: 0.001
ql_mlt: 0.002
qs_mlt: 0.000001
qi_lim: 1.0
dw_ocean: 0.1
dw_land: 0.15
icloud_f: 0
tau_l2v: 300.
tau_v2l: 90.
fv_sg_adj: 0
n_sponge: 48

physics_config:
hydrostatic: false
nwat: 6
do_qa: true
schemes:
- GFS_microphysics
5 changes: 4 additions & 1 deletion driver/pace/driver/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ def get_driver_state(
damping_coefficients=damping_coefficients,
driver_grid_data=driver_grid_data,
grid_data=grid_data,
schemes=self.physics_config.schemes,
)

@classmethod
Expand Down Expand Up @@ -327,6 +328,9 @@ def write_for_restart(
config_dict["initialization"]["type"] = "restart"
config_dict["initialization"]["config"]["start_time"] = time
config_dict["initialization"]["config"]["path"] = restart_path
# convert physics package enum to str
schemes = [scheme.value for scheme in config_dict["physics_config"]["schemes"]]
config_dict["physics_config"]["schemes"] = schemes
# restart config doesn't have 'case'
if "case" in config_dict["initialization"]["config"].keys():
del config_dict["initialization"]["config"]["case"]
Expand Down Expand Up @@ -508,7 +512,6 @@ def exit_instead_of_build(self):
quantity_factory=self.quantity_factory,
grid_data=self.state.grid_data,
namelist=self.config.physics_config,
active_packages=["microphysics"],
)
else:
# Make sure those are set to None to raise any issues
Expand Down
16 changes: 13 additions & 3 deletions driver/pace/driver/initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import pathlib
from datetime import datetime
from typing import Callable, ClassVar, Type, TypeVar
from typing import Callable, ClassVar, List, Type, TypeVar

import f90nml

Expand Down Expand Up @@ -40,6 +40,7 @@ def get_driver_state(
damping_coefficients: pace.util.grid.DampingCoefficients,
driver_grid_data: pace.util.grid.DriverGridData,
grid_data: pace.util.grid.GridData,
schemes: List[pace.physics.PHYSICS_PACKAGES],
) -> DriverState:
...

Expand Down Expand Up @@ -77,13 +78,15 @@ def get_driver_state(
damping_coefficients: pace.util.grid.DampingCoefficients,
driver_grid_data: pace.util.grid.DriverGridData,
grid_data: pace.util.grid.GridData,
schemes: List[pace.physics.PHYSICS_PACKAGES],
) -> DriverState:
return self.config.get_driver_state(
quantity_factory=quantity_factory,
communicator=communicator,
damping_coefficients=damping_coefficients,
driver_grid_data=driver_grid_data,
grid_data=grid_data,
schemes=schemes,
)

@classmethod
Expand All @@ -109,6 +112,7 @@ def get_driver_state(
damping_coefficients: pace.util.grid.DampingCoefficients,
driver_grid_data: pace.util.grid.DriverGridData,
grid_data: pace.util.grid.GridData,
schemes: List[pace.physics.PHYSICS_PACKAGES],
) -> DriverState:
dycore_state = analytic_init.init_analytic_state(
analytic_init_case=self.case,
Expand All @@ -120,7 +124,7 @@ def get_driver_state(
comm=communicator,
)
physics_state = pace.physics.PhysicsState.init_zeros(
quantity_factory=quantity_factory, active_packages=["microphysics"]
quantity_factory=quantity_factory, schemes=schemes
)
tendency_state = TendencyState.init_zeros(
quantity_factory=quantity_factory,
Expand Down Expand Up @@ -152,6 +156,7 @@ def get_driver_state(
damping_coefficients: pace.util.grid.DampingCoefficients,
driver_grid_data: pace.util.grid.DriverGridData,
grid_data: pace.util.grid.GridData,
schemes: List[pace.physics.PHYSICS_PACKAGES],
) -> DriverState:
state = _restart_driver_state(
self.path,
Expand All @@ -161,6 +166,7 @@ def get_driver_state(
damping_coefficients,
driver_grid_data,
grid_data,
schemes,
)

return state
Expand Down Expand Up @@ -201,6 +207,7 @@ def get_driver_state(
damping_coefficients: pace.util.grid.DampingCoefficients,
driver_grid_data: pace.util.grid.DriverGridData,
grid_data: pace.util.grid.GridData,
schemes: List[pace.physics.PHYSICS_PACKAGES],
) -> DriverState:
state = _restart_driver_state(
self.path,
Expand All @@ -210,6 +217,7 @@ def get_driver_state(
damping_coefficients,
driver_grid_data,
grid_data,
schemes,
)

_update_fortran_restart_pe_peln(state)
Expand Down Expand Up @@ -272,6 +280,7 @@ def get_driver_state(
damping_coefficients: pace.util.grid.DampingCoefficients,
driver_grid_data: pace.util.grid.DriverGridData,
grid_data: pace.util.grid.GridData,
schemes: List[pace.physics.PHYSICS_PACKAGES],
) -> DriverState:
backend = quantity_factory.zeros(
dims=[pace.util.X_DIM, pace.util.Y_DIM], units="unknown"
Expand All @@ -280,7 +289,7 @@ def get_driver_state(
dycore_state = self._initialize_dycore_state(communicator, backend)
physics_state = pace.physics.PhysicsState.init_zeros(
quantity_factory=quantity_factory,
active_packages=["microphysics"],
schemes=schemes,
)
tendency_state = TendencyState.init_zeros(quantity_factory=quantity_factory)

Expand Down Expand Up @@ -349,6 +358,7 @@ def get_driver_state(
damping_coefficients: pace.util.grid.DampingCoefficients,
driver_grid_data: pace.util.grid.DriverGridData,
grid_data: pace.util.grid.GridData,
schemes: List[pace.physics.PHYSICS_PACKAGES],
) -> DriverState:
return DriverState(
dycore_state=self.dycore_state,
Expand Down
9 changes: 6 additions & 3 deletions driver/pace/driver/state.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import dataclasses
from dataclasses import fields
from typing import List

import xarray as xr

Expand Down Expand Up @@ -75,6 +76,7 @@ def load_state_from_restart(
damping_coefficients: pace.util.grid.DampingCoefficients,
driver_grid_data: pace.util.grid.DriverGridData,
grid_data: pace.util.grid.GridData,
schemes: List[pace.physics.PHYSICS_PACKAGES],
) -> "DriverState":
comm = driver_config.comm_config.get_comm()
communicator = pace.util.Communicator.from_layout(
Expand Down Expand Up @@ -102,6 +104,7 @@ def load_state_from_restart(
damping_coefficients=damping_coefficients,
driver_grid_data=driver_grid_data,
grid_data=grid_data,
schemes=schemes,
)
return state

Expand Down Expand Up @@ -176,6 +179,7 @@ def _restart_driver_state(
damping_coefficients: pace.util.grid.DampingCoefficients,
driver_grid_data: pace.util.grid.DriverGridData,
grid_data: pace.util.grid.GridData,
schemes: List[pace.physics.PHYSICS_PACKAGES],
):
fs = pace.util.get_fs(path)

Expand All @@ -197,12 +201,11 @@ def _restart_driver_state(
"restart_dycore_state",
)

active_packages = ["microphysics"]
physics_state = pace.physics.PhysicsState.init_zeros(
quantity_factory=quantity_factory, active_packages=active_packages
quantity_factory=quantity_factory, schemes=schemes
)

physics_state.__post_init__(quantity_factory, active_packages)
physics_state.__post_init__(quantity_factory, schemes)
tendency_state = TendencyState.init_zeros(
quantity_factory=quantity_factory,
)
Expand Down
2 changes: 2 additions & 0 deletions driver/tests/mpi/test_restart.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import pace.util
from pace.driver import DriverConfig
from pace.driver.state import DriverState
from pace.physics import PHYSICS_PACKAGES


# The packages we import will import MPI, causing an MPI init, but we don't actually
Expand Down Expand Up @@ -65,6 +66,7 @@ def test_restart():
damping_coefficients=damping_coefficients,
driver_grid_data=driver_grid_data,
grid_data=grid_data,
schemes=[PHYSICS_PACKAGES["GFS_microphysics"]],
)

assert isinstance(driver_state, DriverState)
Expand Down
8 changes: 2 additions & 6 deletions fv3core/pace/fv3core/initialization/analytic_init.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,11 @@
from enum import Enum, EnumMeta
from enum import Enum

import pace.util as fv3util
from pace.fv3core.dycore_state import DycoreState
from pace.util import MetaEnumStr
from pace.util.grid import GridData


class MetaEnumStr(EnumMeta):
def __contains__(cls, item):
return item in cls.__members__.keys()


class Cases(Enum, metaclass=MetaEnumStr):
baroclinic = "baroclinic"
tropicalcyclone = "tropicalcyclone"
Expand Down
2 changes: 1 addition & 1 deletion physics/pace/physics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from ._config import PhysicsConfig
from ._config import PHYSICS_PACKAGES, PhysicsConfig
from .physics_state import PhysicsState
from .stencils.microphysics import Microphysics
from .stencils.physics import Physics
Expand Down
20 changes: 18 additions & 2 deletions physics/pace/physics/_config.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
import dataclasses
from typing import Optional, Tuple
from enum import Enum, unique
from typing import List, Optional, Tuple

import f90nml

from pace.util import Namelist, NamelistDefaults
from pace.util import MetaEnumStr, Namelist, NamelistDefaults


DEFAULT_INT = 0
DEFAULT_BOOL = False
DEFAULT_SCHEMES = ["GFS_microphysics"]


@unique
class PHYSICS_PACKAGES(Enum, metaclass=MetaEnumStr):
GFS_microphysics = "GFS_microphysics"


@dataclasses.dataclass
Expand All @@ -18,6 +25,7 @@ class PhysicsConfig:
npy: int = DEFAULT_INT
npz: int = DEFAULT_INT
nwat: int = DEFAULT_INT
schemes: List = None
do_qa: bool = DEFAULT_BOOL
c_cracw: float = NamelistDefaults.c_cracw
c_paut: float = NamelistDefaults.c_paut
Expand Down Expand Up @@ -100,6 +108,14 @@ class PhysicsConfig:
namelist_override: Optional[str] = None

def __post_init__(self):
if self.schemes is None:
self.schemes = DEFAULT_SCHEMES
package_schemes = []
for scheme in self.schemes:
if scheme not in PHYSICS_PACKAGES:
raise NotImplementedError(f"{scheme} physics scheme not implemented")
package_schemes.append(PHYSICS_PACKAGES[scheme])
self.schemes = package_schemes
if self.namelist_override is not None:
try:
f90_nml = f90nml.read(self.namelist_override)
Expand Down
Loading

0 comments on commit e7fd3b4

Please sign in to comment.