Skip to content

Commit

Permalink
fix circular import problems
Browse files Browse the repository at this point in the history
  • Loading branch information
mlee03 authored and mlee03 committed Dec 19, 2023
1 parent 1314117 commit 4c72d51
Show file tree
Hide file tree
Showing 34 changed files with 94 additions and 120 deletions.
5 changes: 5 additions & 0 deletions driver/examples/configs/baroclinic_c12_write_restart.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,8 @@ physics_config:
hydrostatic: false
nwat: 6
do_qa: true

grid_config:
type: generated
config:
eta_file: "tests/main/input/eta79.nc"
9 changes: 4 additions & 5 deletions driver/pace/driver/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,10 @@
import pace.driver
import pace.dsl
import pace.stencils
import pace.util
import pace.util.grid
from pace.dsl.dace.orchestration import dace_inhibitor
from pace.fv3core.dycore_state import DycoreState
from pace.util.constants import RGRAV
from pace.util.grid.helper import GridData

from .state import DriverState

Expand All @@ -28,7 +27,7 @@ def store(self, time: Union[datetime, timedelta], state: DriverState):
...

@abc.abstractmethod
def store_grid(self, grid_data: pace.util.grid.GridData):
def store_grid(self, grid_data: GridData):
...

@abc.abstractmethod
Expand Down Expand Up @@ -198,7 +197,7 @@ def _get_z_select_state(self, state: DycoreState):
z_select_state.update(zselect.select_data(state))
return z_select_state

def store_grid(self, grid_data: pace.util.grid.GridData):
def store_grid(self, grid_data: GridData):
zarr_grid = {
"lat": grid_data.lat,
"lon": grid_data.lon,
Expand All @@ -218,7 +217,7 @@ class NullDiagnostics(Diagnostics):
def store(self, time: Union[datetime, timedelta], state: DriverState):
pass

def store_grid(self, grid_data: pace.util.grid.GridData):
def store_grid(self, grid_data: GridData):
pass

def cleanup(self):
Expand Down
15 changes: 5 additions & 10 deletions driver/pace/driver/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
import pace.dsl
import pace.physics
import pace.stencils
import pace.util
import pace.util.grid
from pace import fv3core
from pace.driver.safety_checks import SafetyChecker
from pace.dsl.dace.dace_config import DaceConfig
Expand All @@ -29,6 +27,7 @@
CubedSphereCommunicator,
TileCommunicator,
)
from pace.util.grid.helper import DampingCoefficients, DriverGridData, GridData
from pace.util.logging import pace_log

from . import diagnostics
Expand Down Expand Up @@ -165,11 +164,7 @@ def get_grid(
self,
communicator: pace.util.Communicator,
quantity_factory: Optional[pace.util.QuantityFactory] = None,
) -> Tuple[
pace.util.grid.DampingCoefficients,
pace.util.grid.DriverGridData,
pace.util.grid.GridData,
]:
) -> Tuple[DampingCoefficients, DriverGridData, GridData]:
if quantity_factory is None:
sizer = pace.util.SubtileGridSizer.from_tile_params(
nx_tile=self.nx_tile,
Expand All @@ -193,9 +188,9 @@ def get_grid(
def get_driver_state(
self,
communicator: pace.util.Communicator,
damping_coefficients: pace.util.grid.DampingCoefficients,
driver_grid_data: pace.util.grid.DriverGridData,
grid_data: pace.util.grid.GridData,
damping_coefficients: DampingCoefficients,
driver_grid_data: DriverGridData,
grid_data: GridData,
quantity_factory: Optional[pace.util.QuantityFactory] = None,
stencil_factory: Optional[pace.dsl.StencilFactory] = None,
) -> DriverState:
Expand Down
12 changes: 5 additions & 7 deletions driver/pace/driver/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,17 @@
import pace.util.grid
from pace.stencils.testing import TranslateGrid
from pace.util import Communicator, QuantityFactory
from pace.util.grid import (
DampingCoefficients,
DriverGridData,
GridData,
MetricTerms,
direct_transform,
)
from pace.util.grid.helper import (
AngleGridData,
ContravariantGridData,
DampingCoefficients,
DriverGridData,
GridData,
HorizontalGridData,
MetricTerms,
VerticalGridData,
)
from pace.util.grid.stretch_transformation import direct_transform
from pace.util.logging import pace_log
from pace.util.namelist import Namelist

Expand Down
51 changes: 25 additions & 26 deletions driver/pace/driver/initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,13 @@
import pace.fv3core.initialization.analytic_init as analytic_init
import pace.physics
import pace.stencils
import pace.util
import pace.util.grid
from pace import fv3core
from pace.dsl.dace.orchestration import DaceConfig
from pace.dsl.stencil import StencilFactory
from pace.dsl.stencil_config import CompilationConfig
from pace.fv3core.testing import TranslateFVDynamics
from pace.stencils.testing import TranslateGrid
from pace.util.grid.helper import DampingCoefficients, DriverGridData, GridData
from pace.util.namelist import Namelist

from .registry import Registry
Expand All @@ -37,9 +36,9 @@ def get_driver_state(
self,
quantity_factory: pace.util.QuantityFactory,
communicator: pace.util.Communicator,
damping_coefficients: pace.util.grid.DampingCoefficients,
driver_grid_data: pace.util.grid.DriverGridData,
grid_data: pace.util.grid.GridData,
damping_coefficients: DampingCoefficients,
driver_grid_data: DriverGridData,
grid_data: GridData,
) -> DriverState:
...

Expand Down Expand Up @@ -74,9 +73,9 @@ def get_driver_state(
self,
quantity_factory: pace.util.QuantityFactory,
communicator: pace.util.Communicator,
damping_coefficients: pace.util.grid.DampingCoefficients,
driver_grid_data: pace.util.grid.DriverGridData,
grid_data: pace.util.grid.GridData,
damping_coefficients: DampingCoefficients,
driver_grid_data: DriverGridData,
grid_data: GridData,
) -> DriverState:
return self.config.get_driver_state(
quantity_factory=quantity_factory,
Expand Down Expand Up @@ -106,9 +105,9 @@ def get_driver_state(
self,
quantity_factory: pace.util.QuantityFactory,
communicator: pace.util.Communicator,
damping_coefficients: pace.util.grid.DampingCoefficients,
driver_grid_data: pace.util.grid.DriverGridData,
grid_data: pace.util.grid.GridData,
damping_coefficients: DampingCoefficients,
driver_grid_data: DriverGridData,
grid_data: GridData,
) -> DriverState:
dycore_state = analytic_init.init_analytic_state(
analytic_init_case=self.case,
Expand Down Expand Up @@ -149,9 +148,9 @@ def get_driver_state(
self,
quantity_factory: pace.util.QuantityFactory,
communicator: pace.util.Communicator,
damping_coefficients: pace.util.grid.DampingCoefficients,
driver_grid_data: pace.util.grid.DriverGridData,
grid_data: pace.util.grid.GridData,
damping_coefficients: DampingCoefficients,
driver_grid_data: DriverGridData,
grid_data: GridData,
) -> DriverState:
state = _restart_driver_state(
self.path,
Expand Down Expand Up @@ -198,9 +197,9 @@ def get_driver_state(
self,
quantity_factory: pace.util.QuantityFactory,
communicator: pace.util.Communicator,
damping_coefficients: pace.util.grid.DampingCoefficients,
driver_grid_data: pace.util.grid.DriverGridData,
grid_data: pace.util.grid.GridData,
damping_coefficients: DampingCoefficients,
driver_grid_data: DriverGridData,
grid_data: GridData,
) -> DriverState:
state = _restart_driver_state(
self.path,
Expand Down Expand Up @@ -269,9 +268,9 @@ def get_driver_state(
self,
quantity_factory: pace.util.QuantityFactory,
communicator: pace.util.Communicator,
damping_coefficients: pace.util.grid.DampingCoefficients,
driver_grid_data: pace.util.grid.DriverGridData,
grid_data: pace.util.grid.GridData,
damping_coefficients: DampingCoefficients,
driver_grid_data: DriverGridData,
grid_data: GridData,
) -> DriverState:
backend = quantity_factory.zeros(
dims=[pace.util.X_DIM, pace.util.Y_DIM], units="unknown"
Expand Down Expand Up @@ -337,18 +336,18 @@ class PredefinedStateInit(Initializer):
dycore_state: fv3core.DycoreState
physics_state: pace.physics.PhysicsState
tendency_state: TendencyState
grid_data: pace.util.grid.GridData
damping_coefficients: pace.util.grid.DampingCoefficients
driver_grid_data: pace.util.grid.DriverGridData
grid_data: GridData
damping_coefficients: DampingCoefficients
driver_grid_data: DriverGridData
start_time: datetime = datetime(2016, 8, 1)

def get_driver_state(
self,
quantity_factory: pace.util.QuantityFactory,
communicator: pace.util.Communicator,
damping_coefficients: pace.util.grid.DampingCoefficients,
driver_grid_data: pace.util.grid.DriverGridData,
grid_data: pace.util.grid.GridData,
damping_coefficients: DampingCoefficients,
driver_grid_data: DriverGridData,
grid_data: GridData,
) -> DriverState:
return DriverState(
dycore_state=self.dycore_state,
Expand Down
21 changes: 10 additions & 11 deletions driver/pace/driver/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@

import pace.dsl.gt4py_utils as gt_utils
import pace.physics
import pace.util
import pace.util.grid
from pace import fv3core
from pace.dsl.typing import Float
from pace.util.grid.helper import DampingCoefficients, DriverGridData, GridData


@dataclasses.dataclass()
Expand Down Expand Up @@ -60,9 +59,9 @@ class DriverState:
dycore_state: fv3core.DycoreState
physics_state: pace.physics.PhysicsState
tendency_state: TendencyState
grid_data: pace.util.grid.GridData
damping_coefficients: pace.util.grid.DampingCoefficients
driver_grid_data: pace.util.grid.DriverGridData
grid_data: GridData
damping_coefficients: DampingCoefficients
driver_grid_data: DriverGridData

# TODO: the driver_config argument here isn't type hinted from
# import due to a circular dependency. This can be fixed by refactoring
Expand All @@ -72,9 +71,9 @@ def load_state_from_restart(
cls,
restart_path: str,
driver_config,
damping_coefficients: pace.util.grid.DampingCoefficients,
driver_grid_data: pace.util.grid.DriverGridData,
grid_data: pace.util.grid.GridData,
damping_coefficients: DampingCoefficients,
driver_grid_data: DriverGridData,
grid_data: GridData,
) -> "DriverState":
comm = driver_config.comm_config.get_comm()
communicator = pace.util.Communicator.from_layout(
Expand Down Expand Up @@ -173,9 +172,9 @@ def _restart_driver_state(
rank: int,
quantity_factory: pace.util.QuantityFactory,
communicator: pace.util.Communicator,
damping_coefficients: pace.util.grid.DampingCoefficients,
driver_grid_data: pace.util.grid.DriverGridData,
grid_data: pace.util.grid.GridData,
damping_coefficients: DampingCoefficients,
driver_grid_data: DriverGridData,
grid_data: GridData,
):
fs = pace.util.get_fs(path)

Expand Down
2 changes: 1 addition & 1 deletion fv3core/pace/fv3core/initialization/analytic_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pace.util as fv3util
from pace.fv3core.dycore_state import DycoreState
from pace.util.grid import GridData
from pace.util.grid.helper import GridData


class MetaEnumStr(EnumMeta):
Expand Down
7 changes: 5 additions & 2 deletions fv3core/pace/fv3core/initialization/init_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,11 @@
import pace.util.constants as constants
from pace.dsl.typing import Float
from pace.fv3core.dycore_state import DycoreState
from pace.util.grid import lon_lat_midpoint
from pace.util.grid.gnomonic import get_lonlat_vect, get_unit_vector_direction
from pace.util.grid.gnomonic import (
get_lonlat_vect,
get_unit_vector_direction,
lon_lat_midpoint,
)


# maximum windspeed amplitude - close to windspeed of zonal-mean time-mean
Expand Down
2 changes: 1 addition & 1 deletion fv3core/pace/fv3core/stencils/a2b_ord4.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from pace.dsl.typing import Float, FloatField, FloatFieldI, FloatFieldIJ
from pace.fv3core.stencils.basic_operations import copy_defn
from pace.util import X_DIM, X_INTERFACE_DIM, Y_DIM, Y_INTERFACE_DIM, Z_DIM
from pace.util.grid import GridData
from pace.util.grid.helper import GridData


# comact 4-pt cubic interpolation
Expand Down
2 changes: 1 addition & 1 deletion fv3core/pace/fv3core/stencils/c_sw.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from pace.fv3core.stencils.d2a2c_vect import DGrid2AGrid2CGridVectors
from pace.stencils import corners
from pace.util import X_DIM, X_INTERFACE_DIM, Y_DIM, Y_INTERFACE_DIM, Z_DIM
from pace.util.grid import GridData
from pace.util.grid.helper import GridData


def zero_delpc_ptc(delpc: FloatField, ptc: FloatField):
Expand Down
2 changes: 1 addition & 1 deletion fv3core/pace/fv3core/stencils/d2a2c_vect.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from pace.fv3core.stencils.a2b_ord4 import a1, a2, lagrange_x_func, lagrange_y_func
from pace.stencils import corners
from pace.util import X_DIM, Y_DIM, Z_DIM
from pace.util.grid import GridData
from pace.util.grid.helper import GridData


c1 = -2.0 / 14.0
Expand Down
2 changes: 1 addition & 1 deletion fv3core/pace/fv3core/stencils/d_sw.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from pace.fv3core.stencils.xtp_u import advect_u_along_x
from pace.fv3core.stencils.ytp_v import advect_v_along_y
from pace.util import X_DIM, X_INTERFACE_DIM, Y_DIM, Y_INTERFACE_DIM, Z_DIM
from pace.util.grid import DampingCoefficients, GridData
from pace.util.grid.helper import DampingCoefficients, GridData


dcon_threshold = 1e-5
Expand Down
2 changes: 1 addition & 1 deletion fv3core/pace/fv3core/stencils/del2cubed.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from pace.dsl.typing import Float, FloatField, FloatFieldIJ, cast_to_index3d
from pace.fv3core.stencils.basic_operations import copy_defn
from pace.util import X_DIM, X_INTERFACE_DIM, Y_DIM, Y_INTERFACE_DIM, Z_DIM
from pace.util.grid import DampingCoefficients
from pace.util.grid.helper import DampingCoefficients


#
Expand Down
2 changes: 1 addition & 1 deletion fv3core/pace/fv3core/stencils/delnflux.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from pace.dsl.stencil import StencilFactory, get_stencils_with_varied_bounds
from pace.dsl.typing import Float, FloatField, FloatFieldIJ, FloatFieldK
from pace.util import X_DIM, X_INTERFACE_DIM, Y_DIM, Y_INTERFACE_DIM, Z_DIM
from pace.util.grid import DampingCoefficients
from pace.util.grid.helper import DampingCoefficients


def calc_damp(
Expand Down
2 changes: 1 addition & 1 deletion fv3core/pace/fv3core/stencils/divergence_damping.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
)
from pace.fv3core.stencils.d2a2c_vect import contravariant
from pace.util import X_DIM, X_INTERFACE_DIM, Y_DIM, Y_INTERFACE_DIM, Z_DIM
from pace.util.grid import DampingCoefficients, GridData
from pace.util.grid.helper import DampingCoefficients, GridData


@gtscript.function
Expand Down
2 changes: 1 addition & 1 deletion fv3core/pace/fv3core/stencils/dyn_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
Z_DIM,
Z_INTERFACE_DIM,
)
from pace.util.grid import DampingCoefficients, GridData
from pace.util.grid.helper import DampingCoefficients, GridData


HUGE_R = 1.0e40
Expand Down
2 changes: 1 addition & 1 deletion fv3core/pace/fv3core/stencils/fv_dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from pace.fv3core.stencils.remapping import LagrangianToEulerian
from pace.stencils.c2l_ord import CubedToLatLon
from pace.util import X_DIM, Y_DIM, Z_INTERFACE_DIM, Timer, constants
from pace.util.grid import DampingCoefficients, GridData
from pace.util.grid.helper import DampingCoefficients, GridData
from pace.util.logging import pace_log
from pace.util.mpi import MPI

Expand Down
2 changes: 1 addition & 1 deletion fv3core/pace/fv3core/stencils/fvtp2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from pace.fv3core.stencils.xppm import XPiecewiseParabolic
from pace.fv3core.stencils.yppm import YPiecewiseParabolic
from pace.util import X_DIM, Y_DIM, Z_DIM
from pace.util.grid import DampingCoefficients, GridData
from pace.util.grid.helper import DampingCoefficients, GridData


@gtscript.function
Expand Down
Loading

0 comments on commit 4c72d51

Please sign in to comment.