Skip to content

Commit

Permalink
add tests for K->C conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
navidcy committed May 14, 2024
1 parent 4115c76 commit e9ec8f8
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 65 deletions.
19 changes: 14 additions & 5 deletions regional_mom6/regional_mom6.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,7 @@ def initial_condition(
+ "in the varnames dictionary. For example, {'x': 'lon', 'y': 'lat'}.\n\n"
+ "Terminating!"
)

if arakawa_grid == "B":
if (
"xq" in varnames.keys()
Expand Down Expand Up @@ -695,6 +696,7 @@ def initial_condition(
+ "in the varnames dictionary. For example, {'xh': 'lonh', 'yh': 'lath', ...}.\n\n"
+ "Terminating!"
)

## Construct the xq, yh and xh, yq grids
ugrid = (
self.hgrid[["x", "y"]]
Expand Down Expand Up @@ -780,8 +782,11 @@ def initial_condition(
)

print("INITIAL CONDITIONS")

## Regrid all fields horizontally.
print("Regridding Velocities...", end="")

print("Regridding Velocities... ", end="")

vel_out = xr.merge(
[
regridder_u(ic_raw_u)
Expand All @@ -792,18 +797,22 @@ def initial_condition(
.rename("v"),
]
)
print("Done.\nRegridding Tracers...")

print("Done.\nRegridding Tracers... ", end="")

tracers_out = xr.merge(
[
regridder_t(ic_raw_tracers[varnames["tracers"][i]]).rename(i)
for i in varnames["tracers"]
]
).rename({"lon": "xh", "lat": "yh", varnames["zl"]: "zl"})
print("Done.\nRegridding Free surface...")

print("Done.\nRegridding Free surface... ", end="")

eta_out = (
regridder_t(ic_raw_eta).rename({"lon": "xh", "lat": "yh"}).rename("eta_t")
) ## eta_t is the name set in MOM_input by default
print("Done.")

## Return attributes to arrays

Expand All @@ -827,7 +836,7 @@ def initial_condition(

## if min(temp) > 100 then assume that units must be degrees K
## (otherwise we can't be on Earth) and convert to degrees C
if np.nanmin(tracers_out["temp"].isel({"zl": 0})) > 100:
if np.min(tracers_out["temp"].isel({"zl": 0})) > 100:
tracers_out["temp"] -= 273.15

## Regrid the fields vertically
Expand Down Expand Up @@ -1910,7 +1919,7 @@ def rectangular_brushcut(self):
del segment_out["lat"]
## Convert temperatures to celsius # use pint
if (
np.nanmin(segment_out[self.tracers["temp"]].isel({self.time: 0, self.z: 0}))
np.min(segment_out[self.tracers["temp"]].isel({self.time: 0, self.z: 0}))
> 100
):
segment_out[self.tracers["temp"]] -= 273.15
Expand Down
181 changes: 121 additions & 60 deletions tests/test_expt_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,88 @@ def test_setup_bathymetry(
bathymetry_file.unlink()


def number_of_gridpoints(longitude_extent, latitude_extent, resolution):
nx = int((longitude_extent[-1] - longitude_extent[0]) / resolution)
ny = int((latitude_extent[-1] - latitude_extent[0]) / resolution)

return nx, ny


def generate_temperature_arrays(nx, ny, number_vertical_layers):

temp_in_C = np.random.randn(ny, nx, number_vertical_layers)

temp_in_C_masked = np.copy(temp_in_C)
temp_in_C_masked[20:24, 32:35, :] = float("nan")

temp_in_K = np.copy(temp_in_C) + 273.15
temp_in_K_masked = np.copy(temp_in_C_masked) + 273.15

if np.nanmin(temp_in_C_masked) == np.min(temp_in_C):
return temp_in_C, temp_in_C_masked, temp_in_K, temp_in_K_masked
else:
return generate_temperature_arrays(nx, ny, number_vertical_layers)


def generate_silly_coords(
longitude_extent, latitude_extent, resolution, depth, number_vertical_layers
):
nx, ny = number_of_gridpoints(longitude_extent, latitude_extent, resolution)

horizontal_buffer = 5

silly_lat = np.linspace(
latitude_extent[0] - horizontal_buffer,
latitude_extent[1] + horizontal_buffer,
ny,
)
silly_lon = np.linspace(
longitude_extent[0] - horizontal_buffer,
longitude_extent[1] + horizontal_buffer,
nx,
)
silly_depth = np.linspace(0, depth, number_vertical_layers)

return silly_lat, silly_lon, silly_depth


longitude_extent = [-5, 3]
latitude_extent = (0, 10)
date_range = ["2003-01-01 00:00:00", "2003-01-01 00:00:00"]
resolution = 0.1
number_vertical_layers = 5
layer_thickness_ratio = 1
depth = 1000

silly_lat, silly_lon, silly_depth = generate_silly_coords(
longitude_extent, latitude_extent, resolution, depth, number_vertical_layers
)

dims = ["silly_lat", "silly_lon", "silly_depth"]

coords = {"silly_lat": silly_lat, "silly_lon": silly_lon, "silly_depth": silly_depth}

mom_run_dir = "rundir/"
mom_input_dir = "inputdir/"
toolpath_dir = "toolpath"
grid_type = "even_spacing"

nx, ny = number_of_gridpoints(longitude_extent, latitude_extent, resolution)

temp_in_C, temp_in_C_masked, temp_in_K, temp_in_K_masked = generate_temperature_arrays(
nx, ny, number_vertical_layers
)

temp_C = xr.DataArray(temp_in_C, dims=dims, coords=coords)
temp_K = xr.DataArray(temp_in_K, dims=dims, coords=coords)
temp_C_masked = xr.DataArray(temp_in_C_masked, dims=dims, coords=coords)
temp_K_masked = xr.DataArray(temp_in_K_masked, dims=dims, coords=coords)


@pytest.mark.parametrize(
"temp_dataarray_initial_condition",
[temp_C, temp_C_masked, temp_K, temp_K_masked],
)
@pytest.mark.parametrize(
(
"longitude_extent",
Expand All @@ -116,13 +198,13 @@ def test_setup_bathymetry(
),
[
(
[-5, 5],
(0, 10),
["2003-01-01 00:00:00", "2003-01-01 00:00:00"],
0.1,
5,
1,
1000,
longitude_extent,
latitude_extent,
date_range,
resolution,
number_vertical_layers,
layer_thickness_ratio,
depth,
"rundir/",
"inputdir/",
"toolpath",
Expand All @@ -142,8 +224,22 @@ def test_ocean_forcing(
mom_input_dir,
toolpath_dir,
grid_type,
temp_dataarray_initial_condition,
tmp_path,
):

silly_lat, silly_lon, silly_depth = generate_silly_coords(
longitude_extent, latitude_extent, resolution, depth, number_vertical_layers
)

dims = ["silly_lat", "silly_lon", "silly_depth"]

coords = {
"silly_lat": silly_lat,
"silly_lon": silly_lon,
"silly_depth": silly_depth,
}

expt = experiment(
longitude_extent=longitude_extent,
latitude_extent=latitude_extent,
Expand All @@ -160,72 +256,34 @@ def test_ocean_forcing(

## Generate some initial condition to test on

nx, ny = number_of_gridpoints(longitude_extent, latitude_extent, resolution)

# initial condition includes, temp, salt, eta, u, v
initial_cond = xr.Dataset(
{
"temp": xr.DataArray(
np.random.random((100, 100, 10)),
dims=["silly_lat", "silly_lon", "silly_depth"],
coords={
"silly_lat": np.linspace(
latitude_extent[0] - 5, latitude_extent[1] + 5, 100
),
"silly_lon": np.linspace(
longitude_extent[0] - 5, longitude_extent[1] + 5, 100
),
"silly_depth": np.linspace(0, 1000, 10),
},
),
"eta": xr.DataArray(
np.random.random((100, 100)),
np.random.random((ny, nx)),
dims=["silly_lat", "silly_lon"],
coords={
"silly_lat": np.linspace(
latitude_extent[0] - 5, latitude_extent[1] + 5, 100
),
"silly_lon": np.linspace(
longitude_extent[0] - 5, longitude_extent[1] + 5, 100
),
"silly_lat": silly_lat,
"silly_lon": silly_lon,
},
),
"temp": temp_dataarray_initial_condition,
"salt": xr.DataArray(
np.random.random((100, 100, 10)),
dims=["silly_lat", "silly_lon", "silly_depth"],
coords={
"silly_lat": np.linspace(
latitude_extent[0] - 5, latitude_extent[1] + 5, 100
),
"silly_lon": np.linspace(
longitude_extent[0] - 5, longitude_extent[1] + 5, 100
),
"silly_depth": np.linspace(0, 1000, 10),
},
np.random.random((ny, nx, number_vertical_layers)),
dims=dims,
coords=coords,
),
"u": xr.DataArray(
np.random.random((100, 100, 10)),
dims=["silly_lat", "silly_lon", "silly_depth"],
coords={
"silly_lat": np.linspace(
latitude_extent[0] - 5, latitude_extent[1] + 5, 100
),
"silly_lon": np.linspace(
longitude_extent[0] - 5, longitude_extent[1] + 5, 100
),
"silly_depth": np.linspace(0, 1000, 10),
},
np.random.random((ny, nx, number_vertical_layers)),
dims=dims,
coords=coords,
),
"v": xr.DataArray(
np.random.random((100, 100, 10)),
dims=["silly_lat", "silly_lon", "silly_depth"],
coords={
"silly_lat": np.linspace(
latitude_extent[0] - 5, latitude_extent[1] + 5, 100
),
"silly_lon": np.linspace(
longitude_extent[0] - 5, longitude_extent[1] + 5, 100
),
"silly_depth": np.linspace(0, 1000, 10),
},
np.random.random((ny, nx, number_vertical_layers)),
dims=dims,
coords=coords,
),
}
)
Expand All @@ -251,6 +309,9 @@ def test_ocean_forcing(
arakawa_grid="A",
)

# ensure that temperature is in degrees C
assert np.nanmin(expt.ic_tracers["temp"]) < 100.0


@pytest.mark.parametrize(
(
Expand Down

0 comments on commit e9ec8f8

Please sign in to comment.