Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor/ndsl #57

Merged
merged 20 commits into from
Jan 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
8 changes: 2 additions & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,6 @@ repos:
fv3core/pace/fv3core/stencils/fv_subgridz.py |
fv3core/tests/conftest.py
)$
- id: mypy
name: mypy-util
args: [--config-file, setup.cfg]
files: ^util
- id: mypy
name: mypy-stencils
args: [--config-file, setup.cfg]
Expand All @@ -53,10 +49,10 @@ repos:
- id: mypy
name: mypy-dsl
args: [--config-file, setup.cfg]
files: dsl
files: ndsl
exclude: |
(?x)^(
dsl/pace/dsl/gt4py_utils.py |
ndsl/ndsl/gt4py_utils.py |
)$
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v2.3.0
Expand Down
8 changes: 4 additions & 4 deletions driver/examples/stencil_signatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@

import yaml

import ndsl.dsl
import ndsl.util
import pace.driver
import pace.dsl
import pace.util


def has_stencils(object):
for name in dir(object):
try:
stencil_found = isinstance(getattr(object, name), pace.dsl.FrozenStencil)
stencil_found = isinstance(getattr(object, name), ndsl.dsl.FrozenStencil)
except (AttributeError, RuntimeError):
stencil_found = False
if stencil_found:
Expand All @@ -26,7 +26,7 @@ def report_stencils(obj, file: Optional[TextIO]):
print(f"module {module.__name__}, class {obj.__class__.__name__}:", file=file)
all_access_names = collections.defaultdict(list)
for name, value in obj.__dict__.items():
if isinstance(value, pace.dsl.FrozenStencil):
if isinstance(value, ndsl.dsl.FrozenStencil):
print(f" stencil {name}:", file=file)
for arg_name, field_info in value.stencil_object.field_info.items():
if field_info is None:
Expand Down
20 changes: 9 additions & 11 deletions driver/pace/driver/comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,11 @@
import os
from typing import Any, ClassVar, List

import pace.driver
import pace.dsl
import pace.stencils
import pace.util
import pace.util.grid
from pace.util.caching_comm import CachingCommReader, CachingCommWriter
from pace.util.comm import Comm
import ndsl.stencils
import ndsl.util
import ndsl.util.grid
from ndsl.util.caching_comm import CachingCommReader, CachingCommWriter
from ndsl.util.comm import Comm

from .registry import Registry

Expand Down Expand Up @@ -86,7 +84,7 @@ class MPICommConfig(CreatesComm):
"""

def get_comm(self):
return pace.util.MPIComm()
return ndsl.util.MPIComm()

def cleanup(self, comm):
pass
Expand All @@ -113,7 +111,7 @@ class NullCommConfig(CreatesComm):
fill_value: float = 0.0

def get_comm(self):
return pace.util.NullComm(
return ndsl.util.NullComm(
rank=self.rank, total_ranks=self.total_ranks, fill_value=self.fill_value
)

Expand Down Expand Up @@ -144,7 +142,7 @@ class WriterCommConfig(CreatesComm):
def get_comm(self) -> CachingCommWriter:
underlying = MPICommConfig().get_comm()
if underlying.Get_rank() in self.ranks:
return pace.util.CachingCommWriter(underlying)
return ndsl.util.CachingCommWriter(underlying)
else:
return underlying

Expand Down Expand Up @@ -181,7 +179,7 @@ class ReaderCommConfig(CreatesComm):

def get_comm(self) -> CachingCommReader:
with open(os.path.join(self.path, f"comm_{self.rank}.pkl"), "rb") as f:
return pace.util.CachingCommReader.load(f)
return ndsl.util.CachingCommReader.load(f)

def cleanup(self, comm: CachingCommWriter):
pass
19 changes: 9 additions & 10 deletions driver/pace/driver/configs/comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,11 @@

import dacite

import pace.driver
import pace.dsl
import pace.stencils
import pace.util
import pace.util.grid
from pace.util.caching_comm import CachingCommReader, CachingCommWriter
import ndsl.dsl
import ndsl.stencils
import ndsl.util
import ndsl.util.grid
from ndsl.util.caching_comm import CachingCommReader, CachingCommWriter


class CreatesComm(abc.ABC):
Expand Down Expand Up @@ -85,7 +84,7 @@ class MPICommConfig(CreatesComm):
"""

def get_comm(self):
return pace.util.MPIComm()
return ndsl.util.MPIComm()

def cleanup(self, comm):
pass
Expand All @@ -112,7 +111,7 @@ class NullCommConfig(CreatesComm):
fill_value: float

def get_comm(self):
return pace.util.NullComm(
return ndsl.util.NullComm(
rank=self.rank, total_ranks=self.total_ranks, fill_value=self.fill_value
)

Expand Down Expand Up @@ -143,7 +142,7 @@ class WriterCommConfig(CreatesComm):
def get_comm(self) -> CachingCommWriter:
underlying = MPICommConfig().get_comm()
if underlying.Get_rank() in self.ranks:
return pace.util.CachingCommWriter(underlying)
return ndsl.util.CachingCommWriter(underlying)
else:
return underlying

Expand Down Expand Up @@ -180,7 +179,7 @@ class ReaderCommConfig(CreatesComm):

def get_comm(self) -> CachingCommReader:
with open(os.path.join(self.path, f"comm_{self.rank}.pkl"), "rb") as f:
return pace.util.CachingCommReader.load(f)
return ndsl.util.CachingCommReader.load(f)

def cleanup(self, comm: CachingCommWriter):
pass
39 changes: 19 additions & 20 deletions driver/pace/driver/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,13 @@
from datetime import datetime, timedelta
from typing import List, Optional, Union

import pace.driver
import pace.dsl
import pace.stencils
import pace.util
import pace.util.grid
from pace.dsl.dace.orchestration import dace_inhibitor
import ndsl.dsl
import ndsl.stencils
import ndsl.util
import ndsl.util.grid
from ndsl.dsl.dace.orchestration import dace_inhibitor
from ndsl.util.constants import RGRAV
from pace.fv3core.dycore_state import DycoreState
from pace.util.constants import RGRAV

from .state import DriverState

Expand All @@ -28,7 +27,7 @@ def store(self, time: Union[datetime, timedelta], state: DriverState):
...

@abc.abstractmethod
def store_grid(self, grid_data: pace.util.grid.GridData):
def store_grid(self, grid_data: ndsl.util.grid.GridData):
...

@abc.abstractmethod
Expand All @@ -48,14 +47,14 @@ def select_data(self, state: DycoreState):
raise ValueError(f"Invalid state variable {name} for level select")
assert len(getattr(state, name).dims) > 2
if getattr(state, name).dims[2] != (
pace.util.Z_DIM or pace.util.Z_INTERFACE_DIM
ndsl.util.Z_DIM or ndsl.util.Z_INTERFACE_DIM
):
raise ValueError(
f"z_select only works for state variables with dimension (x, y, z). \
\n {name} has dimension {getattr(state, name).dims}"
)
var_name = f"{name}_z{self.level}"
output[var_name] = pace.util.Quantity(
output[var_name] = ndsl.util.Quantity(
getattr(state, name).data[:, :, self.level],
dims=getattr(state, name).dims[0:2],
origin=getattr(state, name).origin[0:2],
Expand Down Expand Up @@ -100,7 +99,7 @@ def __post_init__(self):
f"got {self.output_format}"
)

def diagnostics_factory(self, communicator: pace.util.Communicator) -> Diagnostics:
def diagnostics_factory(self, communicator: ndsl.util.Communicator) -> Diagnostics:
"""
Create a diagnostics object.

Expand All @@ -111,18 +110,18 @@ def diagnostics_factory(self, communicator: pace.util.Communicator) -> Diagnosti
if self.path is None:
diagnostics: Diagnostics = NullDiagnostics()
else:
fs = pace.util.get_fs(self.path)
fs = ndsl.util.get_fs(self.path)
if not fs.exists(self.path):
fs.makedirs(self.path, exist_ok=True)
if self.output_format == "zarr":
store = zarr_storage.DirectoryStore(path=self.path)
monitor: pace.util.Monitor = pace.util.ZarrMonitor(
monitor: ndsl.util.Monitor = ndsl.util.ZarrMonitor(
store=store,
partitioner=communicator.partitioner,
mpi_comm=communicator.comm,
)
elif self.output_format == "netcdf":
monitor = pace.util.NetCDFMonitor(
monitor = ndsl.util.NetCDFMonitor(
path=self.path,
communicator=communicator,
time_chunk_size=self.time_chunk_size,
Expand All @@ -146,7 +145,7 @@ class MonitorDiagnostics(Diagnostics):

def __init__(
self,
monitor: pace.util.Monitor,
monitor: ndsl.util.Monitor,
names: List[str],
derived_names: List[str],
z_select: List[ZSelect],
Expand Down Expand Up @@ -198,7 +197,7 @@ def _get_z_select_state(self, state: DycoreState):
z_select_state.update(zselect.select_data(state))
return z_select_state

def store_grid(self, grid_data: pace.util.grid.GridData):
def store_grid(self, grid_data: ndsl.util.grid.GridData):
zarr_grid = {
"lat": grid_data.lat,
"lon": grid_data.lon,
Expand All @@ -218,15 +217,15 @@ class NullDiagnostics(Diagnostics):
def store(self, time: Union[datetime, timedelta], state: DriverState):
pass

def store_grid(self, grid_data: pace.util.grid.GridData):
def store_grid(self, grid_data: ndsl.util.grid.GridData):
pass

def cleanup(self):
pass


def _compute_column_integral(
name: str, q_in: pace.util.Quantity, delp: pace.util.Quantity
name: str, q_in: ndsl.util.Quantity, delp: ndsl.util.Quantity
):
"""
Compute column integrated mixing ratio (e.g., total liquid water path)
Expand All @@ -237,12 +236,12 @@ def _compute_column_integral(
delp: pressure thickness of atmospheric layer
"""
assert len(q_in.dims) > 2
if q_in.dims[2] != pace.util.Z_DIM:
if q_in.dims[2] != ndsl.util.Z_DIM:
raise NotImplementedError(
"this function assumes the z-dimension is the third dimension"
)
k_slice = slice(q_in.origin[2], q_in.origin[2] + q_in.extent[2])
column_integral = pace.util.Quantity(
column_integral = ndsl.util.Quantity(
RGRAV
* q_in.np.sum(q_in.data[:, :, k_slice] * delp.data[:, :, k_slice], axis=2),
dims=tuple(q_in.dims[:2]) + tuple(q_in.dims[3:]),
Expand Down
Loading
Loading