Skip to content

Commit

Permalink
GEOS GridTools stencils build override (#27)
Browse files Browse the repository at this point in the history
* Stencil build override for GEOS

* Deactivate warnings if PACE_LOGLEVEL is > WARNING

* Better log level
  • Loading branch information
FlorianDeconinck authored Sep 8, 2023
1 parent e11a5ed commit 0ba419a
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 14 deletions.
2 changes: 1 addition & 1 deletion dsl/pace/dsl/dace/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def set_distributed_caches(config: "DaceConfig"):
verb = "reading"

gt_config.cache_settings["dir_name"] = get_cache_directory(config.code_path)
pace.util.pace_log.critical(
pace.util.pace_log.info(
f"[{orchestration_mode}] Rank {config.my_rank} "
f"{verb} cache {gt_config.cache_settings['dir_name']}"
)
82 changes: 69 additions & 13 deletions fv3core/pace/fv3core/initialization/geos_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,76 @@
import enum
import logging
import os
from datetime import timedelta
from typing import Dict, List, Tuple

import f90nml
import numpy as np
from gt4py.cartesian.config import build_settings as gt_build_settings
from mpi4py import MPI

import pace.util
from pace import fv3core
from pace.driver.performance.collector import PerformanceCollector
from pace.dsl.dace import DaceConfig, orchestrate
from pace.dsl.dace import orchestrate
from pace.dsl.dace.build import set_distributed_caches
from pace.dsl.dace.dace_config import DaceConfig, DaCeOrchestration
from pace.dsl.gt4py_utils import is_gpu_backend
from pace.dsl.typing import floating_point_precision
from pace.util._optional_imports import cupy as cp
from pace.util.logging import pace_log


class StencilBackendCompilerOverride:
    """Override the Pace global stencil JIT to allow for 9-rank build
    on any setup.

    This is a workaround that requires knowing _exactly_ when the build is
    happening. Used as a context manager, it leverages the DaCe build system
    to override the cache name and build the 9 codepaths required, while
    every other rank waits.

    Synchronization: ranks whose ``config.do_compile`` is falsy block on a
    barrier in ``__enter__``; ranks whose ``config.do_compile`` is truthy run
    the ``with`` body first and reach their barrier in ``__exit__``. Every
    participating rank therefore calls ``Barrier`` exactly once.

    This should be removed when we refactor the GT JIT to distribute building
    much more efficiently.

    Args:
        comm: communicator whose ranks are synchronized around the build.
        config: DaCe configuration. It is passed to ``set_distributed_caches``
            and its ``do_compile`` attribute selects which ranks compile.
    """

    def __init__(self, comm: MPI.Intracomm, config: DaceConfig):
        self.comm = comm
        self.config = config

        # Orchestration or mono-node is not concerned
        self.no_op = self.config.is_dace_orchestrated() or self.comm.Get_size() == 1

        # We abuse the DaCe build system: the cache directories are set up
        # while the config is temporarily flagged as `Build`.
        if not self.no_op:
            config._orchestrate = DaCeOrchestration.Build
            try:
                set_distributed_caches(config)
            finally:
                # Always leave the temporary `Build` mode, even if setting up
                # the caches raised, so the config is not left half-modified.
                # NOTE(review): `Python` is presumably the mode the config was
                # in, since `is_dace_orchestrated()` returned False - confirm.
                config._orchestrate = DaCeOrchestration.Python

        # We remove warnings from the stencils compiling when in critical and/or
        # error
        if pace_log.level > logging.WARNING:
            extra_compile_args = gt_build_settings["extra_compile_args"]
            for compiler in ("cxx", "cuda"):
                # The build settings are global: only add the flag once so that
                # repeated instantiation does not pile up duplicates.
                if "-w" not in extra_compile_args[compiler]:
                    extra_compile_args[compiler].append("-w")

    def __enter__(self):
        """Let compiling ranks proceed and hold every other rank on a barrier."""
        if self.no_op:
            return
        if self.config.do_compile:
            pace_log.info(f"Stencil backend compiles on {self.comm.Get_rank()}")
        else:
            pace_log.info(f"Stencil backend waits on {self.comm.Get_rank()}")
            self.comm.Barrier()

    def __exit__(self, type, value, traceback):
        """Release the waiting ranks once the compiling ranks leave the block.

        Returns None, so an exception raised inside the ``with`` body is not
        suppressed.
        """
        if self.no_op:
            return
        if not self.config.do_compile:
            pace_log.info(f"Stencil backend read cache on {self.comm.Get_rank()}")
        else:
            pace_log.info(f"Stencil backend compiled on {self.comm.Get_rank()}")
            # Matches the barrier the non-compiling ranks entered in __enter__.
            self.comm.Barrier()


@enum.unique
class MemorySpace(enum.Enum):
HOST = 0
Expand Down Expand Up @@ -113,17 +168,18 @@ def __init__(
metric_terms
)

self.dynamical_core = fv3core.DynamicalCore(
comm=self.communicator,
grid_data=grid_data,
stencil_factory=stencil_factory,
quantity_factory=quantity_factory,
damping_coefficients=damping_coefficients,
config=self.dycore_config,
timestep=timedelta(seconds=self.dycore_state.bdt),
phis=self.dycore_state.phis,
state=self.dycore_state,
)
with StencilBackendCompilerOverride(MPI.COMM_WORLD, stencil_config.dace_config):
self.dynamical_core = fv3core.DynamicalCore(
comm=self.communicator,
grid_data=grid_data,
stencil_factory=stencil_factory,
quantity_factory=quantity_factory,
damping_coefficients=damping_coefficients,
config=self.dycore_config,
timestep=timedelta(seconds=self.dycore_state.bdt),
phis=self.dycore_state.phis,
state=self.dycore_state,
)

self._fortran_mem_space = fortran_mem_space
self._pace_mem_space = (
Expand Down Expand Up @@ -154,7 +210,7 @@ def __init__(
f" orchestration : {self._is_orchestrated}\n"
f" sizer : {sizer.nx}x{sizer.ny}x{sizer.nz}"
f"(halo: {sizer.n_halo})\n"
f" {device_ordinal_info}"
f" device ordinal : {device_ordinal_info}\n"
f" Nvidia MPS : {MPS_is_on}"
)

Expand Down

0 comments on commit 0ba419a

Please sign in to comment.