Skip to content

Commit

Permalink
Changed grid data read-in unit tests to compare data directly from fi…
Browse files Browse the repository at this point in the history
…le to driver grid data generated from yaml
  • Loading branch information
fmalatino committed Dec 12, 2023
1 parent ada497b commit 6c1acfb
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 71 deletions.
42 changes: 6 additions & 36 deletions tests/mpi_54rank/test_external_grid_1x1.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

import pace.util
from pace.driver import Driver, DriverConfig
from pace.util.grid import MetricTerms
from pace.util.mpi import MPIComm


Expand Down Expand Up @@ -40,7 +39,7 @@ def get_tile_num(comm: MPIComm):
def test_extgrid_equals_generated_1x1():

with open(
os.path.join(DIR, "../../driver/examples/configs/test_external_C12_2x2.yaml"),
os.path.join(DIR, "../../driver/examples/configs/test_external_C12_1x1.yaml"),
"r",
) as ext_f:
ext_config = yaml.safe_load(ext_f)
Expand Down Expand Up @@ -93,50 +92,21 @@ def test_extgrid_equals_generated_1x1():
overlap=True,
)

metric_terms_ext = MetricTerms.from_external(
x=lon[subtile_slice_grid],
y=lat[subtile_slice_grid],
dx=dx[subtile_slice_dx],
dy=dy[subtile_slice_dy],
area=area[subtile_slice_area],
quantity_factory=get_quantity_factory(
layout=(1, 1), nx_tile=nx_tile, ny_tile=ny_tile, nz=nz
),
communicator=cube_comm,
grid_type=0,
extdgrid=True,
)

errors = []

if (
ext_driver.state.grid_data.lon.data.any()
!= metric_terms_ext.grid.data[:, :, 0].any()
):
if ext_driver.state.grid_data.lon.data.any() != lon[subtile_slice_grid].any():
errors.append("lon data mismatch between generated and external grid data")

if (
ext_driver.state.grid_data.lat.data.any()
!= metric_terms_ext.grid.data[:, :, 1].any()
):
if ext_driver.state.grid_data.lat.data.any() != lat[subtile_slice_grid].any():
errors.append("lon data mismatch between generated and external grid data")

if (
ext_driver.state.grid_data.dx.data.any()
!= metric_terms_ext._dx.view[:, :].any()
):
if ext_driver.state.grid_data.dx.data.any() != dx[subtile_slice_dx].any():
errors.append("dx data mismatch between generated and external grid data")

if (
ext_driver.state.grid_data.dy.data.any()
!= metric_terms_ext._dy.view[:, :].any()
):
if ext_driver.state.grid_data.dy.data.any() != dy[subtile_slice_dy].any():
errors.append("dy data mismatch between generated and external grid data")

if (
ext_driver.state.grid_data.area.data.any()
!= metric_terms_ext._area.view[:, :].any()
):
if ext_driver.state.grid_data.area.data.any() != area[subtile_slice_area].any():
errors.append("area data mismatch between generated and external grid data")

assert not errors, "errors occured in 1x1:\n{}".format("\n".join(errors))
40 changes: 5 additions & 35 deletions tests/mpi_54rank/test_external_grid_2x2.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

import pace.util
from pace.driver import Driver, DriverConfig
from pace.util.grid import MetricTerms
from pace.util.mpi import MPIComm


Expand Down Expand Up @@ -93,50 +92,21 @@ def test_extgrid_equals_generated_2x2():
overlap=True,
)

metric_terms_ext = MetricTerms.from_external(
x=lon[subtile_slice_grid],
y=lat[subtile_slice_grid],
dx=dx[subtile_slice_dx],
dy=dy[subtile_slice_dy],
area=area[subtile_slice_area],
quantity_factory=get_quantity_factory(
layout=(2, 2), nx_tile=nx_tile, ny_tile=ny_tile, nz=nz
),
communicator=cube_comm,
grid_type=0,
extdgrid=True,
)

errors = []

if (
ext_driver.state.grid_data.lon.data.any()
!= metric_terms_ext.grid.data[:, :, 0].any()
):
if ext_driver.state.grid_data.lon.data.any() != lon[subtile_slice_grid].any():
errors.append("lon data mismatch between generated and external grid data")

if (
ext_driver.state.grid_data.lat.data.any()
!= metric_terms_ext.grid.data[:, :, 1].any()
):
if ext_driver.state.grid_data.lat.data.any() != lat[subtile_slice_grid].any():
errors.append("lon data mismatch between generated and external grid data")

if (
ext_driver.state.grid_data.dx.data.any()
!= metric_terms_ext._dx.view[:, :].any()
):
if ext_driver.state.grid_data.dx.data.any() != dx[subtile_slice_dx].any():
errors.append("dx data mismatch between generated and external grid data")

if (
ext_driver.state.grid_data.dy.data.any()
!= metric_terms_ext._dy.view[:, :].any()
):
if ext_driver.state.grid_data.dy.data.any() != dy[subtile_slice_dy].any():
errors.append("dy data mismatch between generated and external grid data")

if (
ext_driver.state.grid_data.area.data.any()
!= metric_terms_ext._area.view[:, :].any()
):
if ext_driver.state.grid_data.area.data.any() != area[subtile_slice_area].any():
errors.append("area data mismatch between generated and external grid data")

assert not errors, "errors occured in 2x2:\n{}".format("\n".join(errors))

0 comments on commit 6c1acfb

Please sign in to comment.