Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Doubly periodic dycore config #19

Closed
wants to merge 32 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
52d1cf9
initial commit, first version of d2a2c_vect
oelbert Jul 14, 2023
42bd5d9
doubly periodic implementation for a2b_ord4
oelbert Jul 14, 2023
90ab231
doubly-periodic implementations of update_dwinds_physics and updatedzc
oelbert Jul 14, 2023
0eb042e
fixing domains, initial dp xppm, yppm, xyp, ytp, divergence_corner, c_sw
oelbert Jul 17, 2023
d7907fd
d_sw, smag_corner initial doubly periodic config done
oelbert Jul 18, 2023
912ac38
removed asserts, initial doubly periodic grid should be supported?
oelbert Jul 18, 2023
232d269
c2l and some config cleanup
oelbert Jul 18, 2023
7f6b97f
maybe this will work for the driver?
oelbert Aug 8, 2023
6258b4d
add umax to grid_config
oelbert Aug 8, 2023
969c7e0
updating namelist, adding test config for driver init
oelbert Aug 10, 2023
7fb4471
Merge branch 'main' into feature/dp_dycore
oelbert Aug 11, 2023
90e7160
debugging driver init with dp grid
oelbert Aug 14, 2023
06c5426
fix varname
oelbert Aug 14, 2023
ed401b1
rework grid type to be in grid config
oelbert Aug 15, 2023
b07ce84
test fixes
oelbert Aug 15, 2023
afb02b2
merging bugfix changes
oelbert Aug 15, 2023
1d7b279
bugfixes
oelbert Aug 16, 2023
38ec8d5
fixing dp a2b
oelbert Aug 17, 2023
0c44c60
need to disable a2b_ord4 test for gridtype 4, exploring more of d2a2c
oelbert Aug 17, 2023
190512d
workaround for d2a2c on dp domain
oelbert Aug 17, 2023
f49b3a1
remove breakpoint
oelbert Aug 17, 2023
5f46c1b
add attrs to divergence damping
oelbert Aug 17, 2023
2354291
correcting types
oelbert Aug 17, 2023
3068c83
small cleanup
oelbert Aug 18, 2023
ba5d5e1
merge main
oelbert Aug 18, 2023
509e550
changing type enforcement on communicators, mocking single rank excha…
oelbert Aug 22, 2023
b06af86
prolly not gonna push, making one rank tests work
oelbert Aug 23, 2023
7560069
Merge branch 'feature/dp_dycore' of github.com:oelbert/pace into feat…
oelbert Aug 23, 2023
b98408a
why test no work
oelbert Aug 28, 2023
567400e
undo silly
oelbert Aug 30, 2023
e37484c
reconfigure tests for doubly periodic domains
oelbert Aug 30, 2023
1174563
Merge branch 'feature/dp_dycore' of github.com:oelbert/pace into feat…
oelbert Aug 30, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 26 additions & 13 deletions driver/pace/driver/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,11 @@

# TODO: move update_atmos_state into pace.driver
from pace.stencils import update_atmos_state
from pace.util.communicator import CubedSphereCommunicator
from pace.util.communicator import (
Communicator,
CubedSphereCommunicator,
TileCommunicator,
)
from pace.util.logging import pace_log

from . import diagnostics
Expand Down Expand Up @@ -90,6 +94,7 @@ class DriverConfig:
nz: int
layout: Tuple[int, int]
dt_atmos: float
grid_type: Optional[int] = 0
grid_config: GridInitializerSelector = dataclasses.field(
default_factory=lambda: GridInitializerSelector(
type="generated", config=GeneratedGridConfig()
Expand Down Expand Up @@ -158,7 +163,7 @@ def apply_tendencies(self) -> bool:

def get_grid(
self,
communicator: pace.util.CubedSphereCommunicator,
communicator: pace.util.Communicator,
quantity_factory: Optional[pace.util.QuantityFactory] = None,
) -> Tuple[
pace.util.grid.DampingCoefficients,
Expand Down Expand Up @@ -187,7 +192,7 @@ def get_grid(

def get_driver_state(
self,
communicator: pace.util.CubedSphereCommunicator,
communicator: pace.util.Communicator,
damping_coefficients: pace.util.grid.DampingCoefficients,
driver_grid_data: pace.util.grid.DriverGridData,
grid_data: pace.util.grid.GridData,
Expand All @@ -213,7 +218,7 @@ def get_driver_state(
if stencil_factory is None:
grid_indexing = (
pace.dsl.stencil.GridIndexing.from_sizer_and_communicator(
sizer=sizer, cube=communicator
sizer=sizer, comm=communicator
)
)
stencil_factory = pace.dsl.StencilFactory(
Expand Down Expand Up @@ -275,7 +280,7 @@ def from_dict(cls, kwargs: Dict[str, Any]) -> "DriverConfig":
)
grid_type = kwargs["grid_config"].config.grid_type
# Copy grid_type to the DycoreConfig if it's not the default value
if grid_type != 0:
if grid_type != 0:
kwargs["dycore_config"].grid_type = grid_type

if (
Expand Down Expand Up @@ -404,11 +409,19 @@ def __init__(
if self.config.performance_config.collect_communication
else None
)
communicator = CubedSphereCommunicator.from_layout(
comm=self.comm,
layout=self.config.layout,
timer=comm_timer,
)
communicator: Communicator
if self.config.grid_type <= 3:
communicator = CubedSphereCommunicator.from_layout(
comm=self.comm,
layout=self.config.layout,
timer=comm_timer,
)
else:
communicator = TileCommunicator.from_layout(
comm=self.comm,
layout=self.config.layout,
timer=comm_timer,
)
self._update_driver_config_with_communicator(communicator)

if self.config.stencil_config.compilation_config.run_mode == RunMode.Build:
Expand Down Expand Up @@ -544,7 +557,7 @@ def exit_instead_of_build(self):
pace_log.info("initialization of the object done")

def _update_driver_config_with_communicator(
self, communicator: CubedSphereCommunicator
self, communicator: Communicator
) -> None:
dace_config = DaceConfig(
communicator=communicator,
Expand Down Expand Up @@ -707,7 +720,7 @@ def log_subtile_location(partitioner: pace.util.TilePartitioner, rank: int):

def _setup_factories(
config: DriverConfig,
communicator: pace.util.CubedSphereCommunicator,
communicator: pace.util.Communicator,
stencil_compare_comm,
) -> Tuple[pace.util.QuantityFactory, pace.dsl.StencilFactory]:
"""
Expand Down Expand Up @@ -735,7 +748,7 @@ def _setup_factories(
)

grid_indexing = pace.dsl.stencil.GridIndexing.from_sizer_and_communicator(
sizer=sizer, cube=communicator
sizer=sizer, comm=communicator
)
quantity_factory = pace.util.QuantityFactory.from_backend(
sizer, backend=config.stencil_config.compilation_config.backend
Expand Down
14 changes: 7 additions & 7 deletions driver/pace/driver/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import pace.stencils
import pace.util.grid
from pace.stencils.testing import TranslateGrid
from pace.util import CubedSphereCommunicator, QuantityFactory
from pace.util import Communicator, QuantityFactory
from pace.util.grid import (
DampingCoefficients,
DriverGridData,
Expand All @@ -35,7 +35,7 @@ class GridInitializer(abc.ABC):
def get_grid(
self,
quantity_factory: pace.util.QuantityFactory,
communicator: pace.util.CubedSphereCommunicator,
communicator: pace.util.Communicator,
) -> Tuple[DampingCoefficients, DriverGridData, GridData]:
...

Expand All @@ -62,7 +62,7 @@ def register(cls, type_name):
def get_grid(
self,
quantity_factory: QuantityFactory,
communicator: CubedSphereCommunicator,
communicator: Communicator,
) -> Tuple[DampingCoefficients, DriverGridData, GridData]:
return self.config.get_grid(
quantity_factory=quantity_factory, communicator=communicator
Expand Down Expand Up @@ -103,7 +103,7 @@ class GeneratedGridConfig(GridInitializer):
def get_grid(
self,
quantity_factory: QuantityFactory,
communicator: CubedSphereCommunicator,
communicator: Communicator,
) -> Tuple[DampingCoefficients, DriverGridData, GridData]:

metric_terms = MetricTerms(
Expand Down Expand Up @@ -158,7 +158,7 @@ def _f90_namelist(self) -> f90nml.Namelist:
def _namelist(self) -> Namelist:
return Namelist.from_f90nml(self._f90_namelist)

def _serializer(self, communicator: pace.util.CubedSphereCommunicator):
def _serializer(self, communicator: pace.util.Communicator):
import serialbox

serializer = serialbox.Serializer(
Expand All @@ -170,7 +170,7 @@ def _serializer(self, communicator: pace.util.CubedSphereCommunicator):

def _get_serialized_grid(
self,
communicator: pace.util.CubedSphereCommunicator,
communicator: pace.util.Communicator,
backend: str,
) -> pace.stencils.testing.grid.Grid: # type: ignore
ser = self._serializer(communicator)
Expand All @@ -182,7 +182,7 @@ def _get_serialized_grid(
def get_grid(
self,
quantity_factory: QuantityFactory,
communicator: CubedSphereCommunicator,
communicator: Communicator,
) -> Tuple[DampingCoefficients, DriverGridData, GridData]:

backend = quantity_factory.empty(
Expand Down
25 changes: 13 additions & 12 deletions driver/pace/driver/initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def start_time(self) -> datetime:
def get_driver_state(
self,
quantity_factory: pace.util.QuantityFactory,
communicator: pace.util.CubedSphereCommunicator,
communicator: pace.util.Communicator,
damping_coefficients: pace.util.grid.DampingCoefficients,
driver_grid_data: pace.util.grid.DriverGridData,
grid_data: pace.util.grid.GridData,
Expand Down Expand Up @@ -74,7 +74,7 @@ def start_time(self) -> datetime:
def get_driver_state(
self,
quantity_factory: pace.util.QuantityFactory,
communicator: pace.util.CubedSphereCommunicator,
communicator: pace.util.Communicator,
damping_coefficients: pace.util.grid.DampingCoefficients,
driver_grid_data: pace.util.grid.DriverGridData,
grid_data: pace.util.grid.GridData,
Expand Down Expand Up @@ -105,11 +105,12 @@ class BaroclinicInit(Initializer):
def get_driver_state(
self,
quantity_factory: pace.util.QuantityFactory,
communicator: pace.util.CubedSphereCommunicator,
communicator: pace.util.Communicator,
damping_coefficients: pace.util.grid.DampingCoefficients,
driver_grid_data: pace.util.grid.DriverGridData,
grid_data: pace.util.grid.GridData,
) -> DriverState:
assert isinstance(communicator, pace.util.CubedSphereCommunicator)
dycore_state = baroclinic_init.init_baroclinic_state(
grid_data=grid_data,
quantity_factory=quantity_factory,
Expand Down Expand Up @@ -149,12 +150,12 @@ class TropicalCycloneConfig(Initializer):
def get_driver_state(
self,
quantity_factory: pace.util.QuantityFactory,
communicator: pace.util.CubedSphereCommunicator,
communicator: pace.util.Communicator,
damping_coefficients: pace.util.grid.DampingCoefficients,
driver_grid_data: pace.util.grid.DriverGridData,
grid_data: pace.util.grid.GridData,
) -> DriverState:

assert isinstance(communicator, pace.util.CubedSphereCommunicator)
dycore_state = tc_init.init_tc_state(
grid_data=grid_data,
quantity_factory=quantity_factory,
Expand Down Expand Up @@ -198,7 +199,7 @@ class RestartInit(Initializer):
def get_driver_state(
self,
quantity_factory: pace.util.QuantityFactory,
communicator: pace.util.CubedSphereCommunicator,
communicator: pace.util.Communicator,
damping_coefficients: pace.util.grid.DampingCoefficients,
driver_grid_data: pace.util.grid.DriverGridData,
grid_data: pace.util.grid.GridData,
Expand Down Expand Up @@ -247,7 +248,7 @@ def start_time(self) -> datetime:
def get_driver_state(
self,
quantity_factory: pace.util.QuantityFactory,
communicator: pace.util.CubedSphereCommunicator,
communicator: pace.util.Communicator,
damping_coefficients: pace.util.grid.DampingCoefficients,
driver_grid_data: pace.util.grid.DriverGridData,
grid_data: pace.util.grid.GridData,
Expand Down Expand Up @@ -296,7 +297,7 @@ def _namelist(self) -> Namelist:

def _get_serialized_grid(
self,
communicator: pace.util.CubedSphereCommunicator,
communicator: pace.util.Communicator,
backend: str,
) -> pace.stencils.testing.grid.Grid: # type: ignore
ser = self._serializer(communicator)
Expand All @@ -305,7 +306,7 @@ def _get_serialized_grid(
).python_grid()
return grid

def _serializer(self, communicator: pace.util.CubedSphereCommunicator):
def _serializer(self, communicator: pace.util.Communicator):
import serialbox

serializer = serialbox.Serializer(
Expand All @@ -318,7 +319,7 @@ def _serializer(self, communicator: pace.util.CubedSphereCommunicator):
def get_driver_state(
self,
quantity_factory: pace.util.QuantityFactory,
communicator: pace.util.CubedSphereCommunicator,
communicator: pace.util.Communicator,
damping_coefficients: pace.util.grid.DampingCoefficients,
driver_grid_data: pace.util.grid.DriverGridData,
grid_data: pace.util.grid.GridData,
Expand All @@ -345,7 +346,7 @@ def get_driver_state(

def _initialize_dycore_state(
self,
communicator: pace.util.CubedSphereCommunicator,
communicator: pace.util.Communicator,
backend: str,
) -> fv3core.DycoreState:

Expand Down Expand Up @@ -396,7 +397,7 @@ class PredefinedStateInit(Initializer):
def get_driver_state(
self,
quantity_factory: pace.util.QuantityFactory,
communicator: pace.util.CubedSphereCommunicator,
communicator: pace.util.Communicator,
damping_coefficients: pace.util.grid.DampingCoefficients,
driver_grid_data: pace.util.grid.DriverGridData,
grid_data: pace.util.grid.GridData,
Expand Down
4 changes: 2 additions & 2 deletions driver/pace/driver/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def load_state_from_restart(
grid_data: pace.util.grid.GridData,
) -> "DriverState":
comm = driver_config.comm_config.get_comm()
communicator = pace.util.CubedSphereCommunicator.from_layout(
communicator = pace.util.Communicator.from_layout(
comm=comm, layout=driver_config.layout
)
sizer = pace.util.SubtileGridSizer.from_tile_params(
Expand Down Expand Up @@ -172,7 +172,7 @@ def _restart_driver_state(
path: str,
rank: int,
quantity_factory: pace.util.QuantityFactory,
communicator: pace.util.CubedSphereCommunicator,
communicator: pace.util.Communicator,
damping_coefficients: pace.util.grid.DampingCoefficients,
driver_grid_data: pace.util.grid.DriverGridData,
grid_data: pace.util.grid.GridData,
Expand Down
2 changes: 1 addition & 1 deletion dsl/pace/dsl/dace/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def unblock_waiting_tiles(comm, sdfg_path: str) -> None:
comm.send(sdfg_path, dest=tile * tilesize + comm.Get_rank())


def get_target_rank(rank: int, partitioner: pace.util.CubedSpherePartitioner):
def get_target_rank(rank: int, partitioner: pace.util.Partitioner):
"""From my rank & the current partitioner we determine which
rank we should read from.
For all layout >= 3,3 this presumes build has been done on a
Expand Down
4 changes: 2 additions & 2 deletions dsl/pace/dsl/dace/dace_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from pace.dsl.gt4py_utils import is_gpu_backend
from pace.util._optional_imports import cupy as cp
from pace.util.communicator import CubedSphereCommunicator
from pace.util.communicator import Communicator


# This can be turned on to revert compilation for orchestration
Expand Down Expand Up @@ -57,7 +57,7 @@ def __call__(self):
class DaceConfig:
def __init__(
self,
communicator: Optional[CubedSphereCommunicator],
communicator: Optional[Communicator],
backend: str,
tile_nx: int = 0,
tile_nz: int = 0,
Expand Down
4 changes: 2 additions & 2 deletions dsl/pace/dsl/dace/wrapped_halo_exchange.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import List, Optional

from pace.dsl.dace.orchestration import dace_inhibitor
from pace.util.communicator import CubedSphereCommunicator
from pace.util.communicator import Communicator
from pace.util.halo_updater import HaloUpdater


Expand All @@ -21,7 +21,7 @@ def __init__(
state,
qty_x_names: List[str],
qty_y_names: List[str] = None,
comm: Optional[CubedSphereCommunicator] = None,
comm: Optional[Communicator] = None,
) -> None:
self._updater = updater
self._state = state
Expand Down
10 changes: 5 additions & 5 deletions dsl/pace/dsl/stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,18 +595,18 @@ def domain(self, domain):

@classmethod
def from_sizer_and_communicator(
cls, sizer: pace.util.GridSizer, cube: pace.util.CubedSphereCommunicator
cls, sizer: pace.util.GridSizer, comm: pace.util.Communicator
) -> "GridIndexing":
# TODO: if this class is refactored to split off the *_edge booleans,
# this init routine can be refactored to require only a GridSizer
domain = cast(
Tuple[int, int, int],
sizer.get_extent([pace.util.X_DIM, pace.util.Y_DIM, pace.util.Z_DIM]),
)
south_edge = cube.tile.partitioner.on_tile_bottom(cube.rank)
north_edge = cube.tile.partitioner.on_tile_top(cube.rank)
west_edge = cube.tile.partitioner.on_tile_left(cube.rank)
east_edge = cube.tile.partitioner.on_tile_right(cube.rank)
south_edge = comm.tile.partitioner.on_tile_bottom(comm.rank)
north_edge = comm.tile.partitioner.on_tile_top(comm.rank)
west_edge = comm.tile.partitioner.on_tile_left(comm.rank)
east_edge = comm.tile.partitioner.on_tile_right(comm.rank)
return cls(
domain=domain,
n_halo=sizer.n_halo,
Expand Down
Loading
Loading