Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Automated Runtime Land Block Elimination #263

Merged
merged 10 commits into from
Dec 20, 2023
138 changes: 118 additions & 20 deletions config_src/drivers/nuopc_cap/mom_cap.F90
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ module MOM_cap_mod
use MOM_domains, only: MOM_infra_init, MOM_infra_end
use MOM_file_parser, only: get_param, log_version, param_file_type, close_param_file
use MOM_get_input, only: get_MOM_input, directories
use MOM_domains, only: pass_var
use MOM_domains, only: pass_var, pe_here
use MOM_error_handler, only: MOM_error, FATAL, is_root_pe
use MOM_grid, only: ocean_grid_type, get_global_grid_size
use MOM_ocean_model_nuopc, only: ice_ocean_boundary_type
Expand All @@ -29,6 +29,7 @@ module MOM_cap_mod
use MOM_cap_methods, only: med2mod_areacor, state_diagnose
use MOM_cap_methods, only: ChkErr
use MOM_ensemble_manager, only: ensemble_manager_init
use MOM_coms, only: sum_across_PEs

#ifdef CESMCOUPLED
use shr_log_mod, only: shr_log_setLogUnit
Expand Down Expand Up @@ -826,6 +827,7 @@ subroutine InitializeRealize(gcomp, importState, exportState, clock, rc)
type(ocean_grid_type) , pointer :: ocean_grid
type(ocean_internalstate_wrapper) :: ocean_internalstate
integer :: npet, ntiles
integer :: npes ! number of PEs (from FMS).
integer :: nxg, nyg, cnt
integer :: isc,iec,jsc,jec
integer, allocatable :: xb(:),xe(:),yb(:),ye(:),pe(:)
Expand All @@ -852,6 +854,8 @@ subroutine InitializeRealize(gcomp, importState, exportState, clock, rc)
integer :: lsize
integer :: ig,jg, ni,nj,k
integer, allocatable :: gindex(:) ! global index space
integer, allocatable :: gindex_ocn(:) ! global index space for ocean cells (excl. masked cells)
integer, allocatable :: gindex_elim(:) ! global index space for eliminated cells
character(len=128) :: fldname
character(len=256) :: cvalue
character(len=256) :: frmt ! format specifier for several error msgs
Expand All @@ -875,6 +879,11 @@ subroutine InitializeRealize(gcomp, importState, exportState, clock, rc)
real(ESMF_KIND_R8) :: min_areacor_glob(2)
real(ESMF_KIND_R8) :: max_areacor_glob(2)
character(len=*), parameter :: subname='(MOM_cap:InitializeRealize)'
integer :: niproc, njproc
integer :: ip, jp, pe_ix
integer :: num_elim_blocks ! number of blocks to be eliminated
integer :: num_elim_cells_global, num_elim_cells_local, num_elim_cells_remaining
integer, allocatable :: cell_mask(:,:)
!--------------------------------

rc = ESMF_SUCCESS
Expand Down Expand Up @@ -919,19 +928,19 @@ subroutine InitializeRealize(gcomp, importState, exportState, clock, rc)
rc = ESMF_FAILURE
call ESMF_LogWrite(subname//' ntiles must be 1', ESMF_LOGMSG_ERROR)
endif
ntiles = mpp_get_domain_npes(ocean_public%domain)
write(tmpstr,'(a,1i6)') subname//' ntiles = ',ntiles
npes = mpp_get_domain_npes(ocean_public%domain)
write(tmpstr,'(a,1i6)') subname//' npes = ',npes
call ESMF_LogWrite(trim(tmpstr), ESMF_LOGMSG_INFO)

!---------------------------------
! get start and end indices of each tile and their PET
!---------------------------------

allocate(xb(ntiles),xe(ntiles),yb(ntiles),ye(ntiles),pe(ntiles))
allocate(xb(npes),xe(npes),yb(npes),ye(npes),pe(npes))
call mpp_get_compute_domains(ocean_public%domain, xbegin=xb, xend=xe, ybegin=yb, yend=ye)
call mpp_get_pelist(ocean_public%domain, pe)
if (dbug > 1) then
do n = 1,ntiles
do n = 1,npes
write(tmpstr,'(a,6i6)') subname//' tiles ',n,pe(n),xb(n),xe(n),yb(n),ye(n)
call ESMF_LogWrite(trim(tmpstr), ESMF_LOGMSG_INFO)
enddo
Expand All @@ -953,17 +962,102 @@ subroutine InitializeRealize(gcomp, importState, exportState, clock, rc)
call get_global_grid_size(ocean_grid, ni, nj)
lsize = ( ocean_grid%iec - ocean_grid%isc + 1 ) * ( ocean_grid%jec - ocean_grid%jsc + 1 )

! Create the global index space for the computational domain
allocate(gindex(lsize))
k = 0
do j = ocean_grid%jsc, ocean_grid%jec
jg = j + ocean_grid%jdg_offset
do i = ocean_grid%isc, ocean_grid%iec
ig = i + ocean_grid%idg_offset
k = k + 1 ! Increment position within gindex
gindex(k) = ni * (jg - 1) + ig
num_elim_blocks = 0
num_elim_cells_global = 0
num_elim_cells_local = 0
num_elim_cells_remaining = 0

! Compute the number of eliminated blocks (specified in MOM_mask_table)
if (associated(ocean_grid%Domain%maskmap)) then
njproc = size(ocean_grid%Domain%maskmap, 1)
niproc = size(ocean_grid%Domain%maskmap, 2)

do ip = 1, niproc
do jp = 1, njproc
if (.not. ocean_grid%Domain%maskmap(jp,ip)) then
num_elim_blocks = num_elim_blocks+1
endif
enddo
enddo
enddo
endif

! Apply land block elimination to ESMF gindex
! (Here we assume that each processor is assigned a single tile. If a multi-tile implementation is added
! to the MOM6 NUOPC cap in the future, the code below must be updated accordingly.)
if (num_elim_blocks>0) then

allocate(cell_mask(ni, nj), source=0)
allocate(gindex_ocn(lsize))
k = 0
do j = ocean_grid%jsc, ocean_grid%jec
jg = j + ocean_grid%jdg_offset
do i = ocean_grid%isc, ocean_grid%iec
ig = i + ocean_grid%idg_offset
k = k + 1 ! Increment position within gindex
gindex_ocn(k) = ni * (jg - 1) + ig
cell_mask(ig, jg) = 1
enddo
enddo
call sum_across_PEs(cell_mask, ni*nj)

if (maxval(cell_mask) /= 1 ) then
call MOM_error(FATAL, "Encountered cells shared by multiple PEs while attempting to determine masked cells.")
endif

num_elim_cells_global = ni * nj - sum(cell_mask)
num_elim_cells_local = num_elim_cells_global / npes

if (pe_here() == pe(npes)) then
! assign all remaining cells to the last PE.
num_elim_cells_remaining = num_elim_cells_global - num_elim_cells_local * npes
allocate(gindex_elim(num_elim_cells_local+num_elim_cells_remaining))
else
allocate(gindex_elim(num_elim_cells_local))
endif

! Zero-based PE index.
pe_ix = pe_here() - pe(1)

k = 0
do jg = 1, nj
do ig = 1, ni
if (cell_mask(ig, jg) == 0) then
k = k + 1
if (k > pe_ix * num_elim_cells_local .and. &
k <= ((pe_ix+1) * num_elim_cells_local + num_elim_cells_remaining)) then
gindex_elim(k - pe_ix * num_elim_cells_local) = ni * (jg -1) + ig
endif
endif
enddo
enddo

allocate(gindex(lsize + num_elim_cells_local + num_elim_cells_remaining))
do k = 1, lsize
gindex(k) = gindex_ocn(k)
enddo
do k = 1, num_elim_cells_local + num_elim_cells_remaining
gindex(k+lsize) = gindex_elim(k)
enddo

deallocate(cell_mask)
deallocate(gindex_ocn)
deallocate(gindex_elim)

else ! no eliminated land blocks

! Create the global index space for the computational domain
allocate(gindex(lsize))
k = 0
do j = ocean_grid%jsc, ocean_grid%jec
jg = j + ocean_grid%jdg_offset
do i = ocean_grid%isc, ocean_grid%iec
ig = i + ocean_grid%idg_offset
k = k + 1 ! Increment position within gindex
gindex(k) = ni * (jg - 1) + ig
enddo
enddo

endif

DistGrid = ESMF_DistGridCreate(arbSeqIndexList=gindex, rc=rc)
if (ChkErr(rc,__LINE__,u_FILE_u)) return
Expand All @@ -987,6 +1081,10 @@ subroutine InitializeRealize(gcomp, importState, exportState, clock, rc)
call ESMF_MeshGet(Emesh, spatialDim=spatialDim, numOwnedElements=numOwnedElements, rc=rc)
if (ChkErr(rc,__LINE__,u_FILE_u)) return

if (lsize /= numOwnedElements - num_elim_cells_local - num_elim_cells_remaining) then
call MOM_error(FATAL, "Discrepancy detected between ESMF mesh and internal MOM6 domain sizes. Check mask table.")
endif

allocate(ownedElemCoords(spatialDim*numOwnedElements))
allocate(lonMesh(numOwnedElements), lon(numOwnedElements))
allocate(latMesh(numOwnedElements), lat(numOwnedElements))
Expand Down Expand Up @@ -1018,7 +1116,7 @@ subroutine InitializeRealize(gcomp, importState, exportState, clock, rc)
end do

eps_omesh = get_eps_omesh(ocean_state)
do n = 1,numOwnedElements
do n = 1,lsize
diff_lon = abs(mod(lonMesh(n) - lon(n),360.0))
if (diff_lon > eps_omesh) then
frmt = "('ERROR: Difference between ESMF Mesh and MOM6 domain coords is "//&
Expand Down Expand Up @@ -1122,11 +1220,11 @@ subroutine InitializeRealize(gcomp, importState, exportState, clock, rc)

! generate delayout and dist_grid

allocate(deBlockList(2,2,ntiles))
allocate(petMap(ntiles))
allocate(deLabelList(ntiles))
allocate(deBlockList(2,2,npes))
allocate(petMap(npes))
allocate(deLabelList(npes))

do n = 1, ntiles
do n = 1, npes
deLabelList(n) = n
deBlockList(1,1,n) = xb(n)
deBlockList(1,2,n) = xe(n)
Expand Down
9 changes: 8 additions & 1 deletion config_src/drivers/nuopc_cap/mom_cap_methods.F90
Original file line number Diff line number Diff line change
Expand Up @@ -852,7 +852,7 @@ subroutine State_SetExport(state, fldname, isc, iec, jsc, jec, input, ocean_grid

! local variables
type(ESMF_StateItem_Flag) :: itemFlag
integer :: n, i, j, i1, j1, ig,jg
integer :: n, i, j, k, i1, j1, ig,jg
integer :: lbnd1,lbnd2
real(ESMF_KIND_R8), pointer :: dataPtr1d(:)
real(ESMF_KIND_R8), pointer :: dataPtr2d(:,:)
Expand Down Expand Up @@ -888,6 +888,13 @@ subroutine State_SetExport(state, fldname, isc, iec, jsc, jec, input, ocean_grid
enddo
end if

! if a maskmap is provided, set exports of all eliminated cells to zero.
if (associated(ocean_grid%Domain%maskmap)) then
do k = n+1, size(dataPtr1d)
dataPtr1d(k) = 0.0
enddo
endif

else if (geomtype == ESMF_GEOMTYPE_GRID) then

call state_getfldptr(state, trim(fldname), dataptr2d, rc)
Expand Down
15 changes: 13 additions & 2 deletions config_src/infra/FMS1/MOM_domain_infra.F90
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ module MOM_domain_infra
use mpp_domains_mod, only : mpp_create_group_update, mpp_do_group_update
use mpp_domains_mod, only : mpp_reset_group_update_field, mpp_group_update_initialized
use mpp_domains_mod, only : mpp_start_group_update, mpp_complete_group_update
use mpp_domains_mod, only : mpp_compute_block_extent
use mpp_domains_mod, only : mpp_compute_block_extent, mpp_compute_extent
use mpp_domains_mod, only : mpp_broadcast_domain, mpp_redistribute, mpp_global_field
use mpp_domains_mod, only : AGRID, BGRID_NE, CGRID_NE, SCALAR_PAIR, BITWISE_EXACT_SUM
use mpp_domains_mod, only : CYCLIC_GLOBAL_DOMAIN, FOLD_NORTH_EDGE
Expand All @@ -40,7 +40,7 @@ module MOM_domain_infra
public :: domain2D, domain1D, group_pass_type
! These interfaces are actually implemented or have explicit interfaces in this file.
public :: create_MOM_domain, clone_MOM_domain, get_domain_components, get_domain_extent
public :: deallocate_MOM_domain, get_global_shape, compute_block_extent
public :: deallocate_MOM_domain, get_global_shape, compute_block_extent, compute_extent
public :: pass_var, pass_vector, fill_symmetric_edges, rescale_comp_data
public :: pass_var_start, pass_var_complete, pass_vector_start, pass_vector_complete
public :: create_group_pass, do_group_pass, start_group_pass, complete_group_pass
Expand Down Expand Up @@ -1945,6 +1945,17 @@ subroutine compute_block_extent(isg, ieg, ndivs, ibegin, iend)
call mpp_compute_block_extent(isg, ieg, ndivs, ibegin, iend)
end subroutine compute_block_extent

!> Return the one-dimensional index ranges of the divisions of a global index space
subroutine compute_extent(isg, ieg, ndivs, ibegin, iend)
  integer, intent(in)  :: isg       !< First index of the global index space
  integer, intent(in)  :: ieg       !< Last index of the global index space
  integer, intent(in)  :: ndivs     !< Number of divisions of the global index space
  integer, intent(out) :: ibegin(:) !< First index of each division
  integer, intent(out) :: iend(:)   !< Last index of each division

  ! Thin wrapper: every argument is handed unchanged to the mpp_domains_mod routine.
  call mpp_compute_extent(isg, ieg, ndivs, ibegin, iend)
end subroutine compute_extent

!> Broadcast a 2-d domain from the root PE to the other PEs
subroutine broadcast_domain(domain)
type(domain2d), intent(inout) :: domain !< The domain2d type that will be shared across PEs.
Expand Down
10 changes: 10 additions & 0 deletions config_src/infra/FMS2/MOM_coms_infra.F90
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ module MOM_coms_infra
interface sum_across_PEs
module procedure sum_across_PEs_int4_0d
module procedure sum_across_PEs_int4_1d
module procedure sum_across_PEs_int4_2d
module procedure sum_across_PEs_int8_0d
module procedure sum_across_PEs_int8_1d
module procedure sum_across_PEs_int8_2d
Expand Down Expand Up @@ -357,6 +358,15 @@ subroutine sum_across_PEs_int4_1d(field, length, pelist)
call mpp_sum(field, length, pelist)
end subroutine sum_across_PEs_int4_1d

!> Sum the values at corresponding positions of a 2-d 32-bit integer field across PEs,
!! returning the sums in field.
subroutine sum_across_PEs_int4_2d(field, length, pelist)
  integer(kind=int32), intent(inout) :: field(:,:) !< Values to add on input, their sums across PEs on return
  integer,             intent(in)    :: length     !< How many elements of field to add
  integer, optional,   intent(in)    :: pelist(:)  !< The PEs that take part in the sum

  ! Thin wrapper: all three arguments are forwarded unchanged to mpp_sum.
  call mpp_sum(field, length, pelist)
end subroutine sum_across_PEs_int4_2d

!> Find the sum of field across PEs, and return this sum in field.
subroutine sum_across_PEs_int8_0d(field, pelist)
integer(kind=int64), intent(inout) :: field !< Value on this PE, and the sum across PEs upon return
Expand Down
17 changes: 14 additions & 3 deletions config_src/infra/FMS2/MOM_domain_infra.F90
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ module MOM_domain_infra
use mpp_domains_mod, only : mpp_create_group_update, mpp_do_group_update
use mpp_domains_mod, only : mpp_reset_group_update_field, mpp_group_update_initialized
use mpp_domains_mod, only : mpp_start_group_update, mpp_complete_group_update
use mpp_domains_mod, only : mpp_compute_block_extent
use mpp_domains_mod, only : mpp_compute_block_extent, mpp_compute_extent
use mpp_domains_mod, only : mpp_broadcast_domain, mpp_redistribute, mpp_global_field
use mpp_domains_mod, only : AGRID, BGRID_NE, CGRID_NE, SCALAR_PAIR, BITWISE_EXACT_SUM
use mpp_domains_mod, only : CYCLIC_GLOBAL_DOMAIN, FOLD_NORTH_EDGE
Expand All @@ -38,7 +38,7 @@ module MOM_domain_infra
public :: domain2D, domain1D, group_pass_type
! These interfaces are actually implemented or have explicit interfaces in this file.
public :: create_MOM_domain, clone_MOM_domain, get_domain_components, get_domain_extent
public :: deallocate_MOM_domain, get_global_shape, compute_block_extent
public :: deallocate_MOM_domain, get_global_shape, compute_block_extent, compute_extent
public :: pass_var, pass_vector, fill_symmetric_edges, rescale_comp_data
public :: pass_var_start, pass_var_complete, pass_vector_start, pass_vector_complete
public :: create_group_pass, do_group_pass, start_group_pass, complete_group_pass
Expand Down Expand Up @@ -1936,7 +1936,7 @@ subroutine get_global_shape(domain, niglobal, njglobal)
njglobal = domain%njglobal
end subroutine get_global_shape

!> Get the array ranges in one dimension for the divisions of a global index space
!> Get the array ranges in one dimension for the divisions of a global index space (alternative to compute_extent)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is fine, but if you can be bothered...

A better description of compute_block_extent vs compute_extent would be nice here, although I would struggle to write one myself. From what I could tell, compute_extent is much more complex (and presumably safer).

subroutine compute_block_extent(isg, ieg, ndivs, ibegin, iend)
integer, intent(in) :: isg !< The starting index of the global index space
integer, intent(in) :: ieg !< The ending index of the global index space
Expand All @@ -1947,6 +1947,17 @@ subroutine compute_block_extent(isg, ieg, ndivs, ibegin, iend)
call mpp_compute_block_extent(isg, ieg, ndivs, ibegin, iend)
end subroutine compute_block_extent

!> Return the one-dimensional index ranges of the divisions of a global index space
subroutine compute_extent(isg, ieg, ndivs, ibegin, iend)
  integer, intent(in)  :: isg       !< First index of the global index space
  integer, intent(in)  :: ieg       !< Last index of the global index space
  integer, intent(in)  :: ndivs     !< Number of divisions of the global index space
  integer, intent(out) :: ibegin(:) !< First index of each division
  integer, intent(out) :: iend(:)   !< Last index of each division

  ! Thin wrapper: every argument is handed unchanged to the mpp_domains_mod routine.
  call mpp_compute_extent(isg, ieg, ndivs, ibegin, iend)
end subroutine compute_extent

!> Broadcast a 2-d domain from the root PE to the other PEs
subroutine broadcast_domain(domain)
type(domain2d), intent(inout) :: domain !< The domain2d type that will be shared across PEs.
Expand Down
Loading