Skip to content

Commit

Permalink
Interpolation to unstructured meshes
Browse files Browse the repository at this point in the history
  • Loading branch information
boryasbora committed Dec 21, 2023
1 parent e155735 commit 628512a
Show file tree
Hide file tree
Showing 2 changed files with 179 additions and 101 deletions.
276 changes: 177 additions & 99 deletions src/fint/fint.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,7 @@ def save_data(
lon,
lat,
out_path,
unstructured = False
):
"""
Saves the interpolated data to a NetCDF file.
Expand All @@ -516,6 +517,7 @@ def save_data(
lon (np.ndarray): The longitudes of the original dataset.
lat (np.ndarray): The latitudes of the original dataset.
out_path (str): The path to save the output NetCDF file.
unstructured (bool, optional): Interpolation to unstructured mesh. Default is False.
Returns:
None: This function does not return any value.
Expand All @@ -533,16 +535,26 @@ def save_data(
out_time = np.atleast_1d(shifted_time)
else:
out_time = np.atleast_1d(data.time.data[timesteps])

out1 = xr.Dataset(
{variable_name: (["time", "depth", "lat", "lon"], interpolated3d)},
if unstructured is False:
out1 = xr.Dataset(
{variable_name: (["time", "depth", "lat", "lon"], interpolated3d)},
coords={
"time": out_time,
"depth": realdepths,
"lon": (["lon"], x),
"lat": (["lat"], y),
"longitude": (["lat", "lon"], lon),
"latitude": (["lat", "lon"], lat),
},
attrs=data.attrs,
)
else:
out1 = xr.Dataset(
{variable_name: (["time", "depth", "ncells"], interpolated3d)},
coords={
"time": out_time,
"depth": realdepths,
"lon": (["lon"], x),
"lat": (["lat"], y),
"longitude": (["lat", "lon"], lon),
"latitude": (["lat", "lon"], lat),

},
attrs=data.attrs,
)
Expand All @@ -561,18 +573,28 @@ def save_data(
# )

# out1.to_netcdf(out_path, encoding={variable_name: {"zlib": True, "complevel": 9}})
out1.to_netcdf(
out_path,
encoding={
"time": {"dtype": np.dtype("double")},
"depth": {"dtype": np.dtype("double")},
"lat": {"dtype": np.dtype("double")},
"lon": {"dtype": np.dtype("double")},
"longitude": {"dtype": np.dtype("double")},
"latitude": {"dtype": np.dtype("double")},
variable_name: {"zlib": True, "complevel": 1, "dtype": np.dtype("single")},
},
)
if unstructured is False:
out1.to_netcdf(
out_path,
encoding={
"time": {"dtype": np.dtype("double")},
"depth": {"dtype": np.dtype("double")},
"lat": {"dtype": np.dtype("double")},
"lon": {"dtype": np.dtype("double")},
"longitude": {"dtype": np.dtype("double")},
"latitude": {"dtype": np.dtype("double")},
variable_name: {"zlib": True, "complevel": 1, "dtype": np.dtype("single")},
},
)
else:
out1.to_netcdf(
out_path,
encoding={
"time": {"dtype": np.dtype("double")},
"depth": {"dtype": np.dtype("double")},
variable_name: {"zlib": True, "complevel": 1, "dtype": np.dtype("single")},
},
)
# if args.rotate:
# out2.to_netcdf(
# out_path2,
Expand Down Expand Up @@ -743,6 +765,10 @@ def fint(args=None):
"--save_weights",
action = "store_true",
help = "Save CDO weights ac netcdf file at the output directory.")
parser.add_argument(
"--to_fesom_mesh",
action = "store_true",
help = "Interpolate data to another fesom mesh. Works only with the target flag.")


args = parser.parse_args()
Expand Down Expand Up @@ -848,12 +874,27 @@ def fint(args=None):
# define region of interpolation
if args.target is None:
x, y, lon, lat = define_region(args.box, args.res, projection)
if args.to_fesom_mesh:
raise ValueError("target mesh for interpolation is not provided")
else:
x, y, lon, lat = define_region_from_file(args.target)
if args.to_fesom_mesh:
lon, lat = load_mesh(args.target)[:2]
x = 0
y = 0
if placement == "elements":
endswith = "_elements_IFS.nc"
for root, dirs, files in os.walk(args.target):
for file in files:
if file.endswith(endswith):
mesh_xr = xr.open_dataset(os.path.join(root, file))
lon = mesh_xr.grid_center_lon.values
lat = mesh_xr.grid_center_lat.values
else:
x, y, lon, lat = define_region_from_file(args.target)

x2, y2 = match_longitude_format(x2, y2, lon, lat)
# if we want to use shapelly mask, load it
if args.no_shape_mask is False:
if args.no_shape_mask is False and args.to_fesom_mesh is False:
m2 = mask_ne(lon, lat)

# additional variables, that we need for different interplations
Expand Down Expand Up @@ -887,66 +928,75 @@ def fint(args=None):
xesmf_weights_path = out_path.replace(".nc", "xesmf_weights.nc")
ds_w.to_netcdf(xesmf_weights_path)
elif interpolation in ["cdo_remapcon", "cdo_remaplaf", "cdo_remapnn", "cdo_remapdis", "smm_remapcon", "smm_remaplaf", "smm_remapnn", "smm_remapdis"]:
gridtype = 'latlon'
gridsize = x.size*y.size
xsize = x.size
ysize = y.size
xname = 'longitude'
xlongname = 'longitude'
xunits = 'degrees_east'
yname = 'latitude'
ylongname = 'latitude'
yunits = 'degrees_north'
xfirst = float(lon[0,0])
xinc = float(lon[0,1]-lon[0,0])
yfirst = float(lat[0,0])
yinc = float(lat[1,0]-lat[0,0])
grid_mapping = []
grid_mapping_name = []
straight_vertical_longitude_from_pole = []
latitude_of_projection_origin = []
standard_parallel = []
if projection == "np":
gridtype = 'projection'
xlongname = 'x coordinate of projection'
xunits = 'meters'
ylongname = 'y coordinate of projection'
yunits = 'meters'
xfirst = float(x[0])
xinc = float(x[1]-x[0])
yfirst = float(y[0])
yinc = float(y[1]-y[0])
grid_mapping = 'crs'
grid_mapping_name = 'polar_stereographic'
straight_vertical_longitude_from_pole = 0.0
latitude_of_projection_origin = 90.0
standard_parallel = 71.0


formatted_content = f"""\
gridtype = {gridtype}
gridsize = {gridsize}
xsize = {xsize}
ysize = {ysize}
xname = {xname}
xlongname = "{xlongname}"
xunits = "{xunits}"
yname = {yname}
ylongname = "{ylongname}"
yunits = "{yunits}"
xfirst = {xfirst}
xinc = {xinc}
yfirst = {yfirst}
yinc = {yinc}
grid_mapping = {grid_mapping}
grid_mapping_name = {grid_mapping_name}
straight_vertical_longitude_from_pole = {straight_vertical_longitude_from_pole}
latitude_of_projection_origin = {latitude_of_projection_origin}
standard_parallel = {standard_parallel}"""

target_grid_path = out_path.replace(".nc", "target_grid.txt")
with open(target_grid_path, 'w') as file:
file.write(formatted_content)
if args.to_fesom_mesh:
endswith = "_nodes_IFS.nc"
if placement == "elements":
endswith = "_elements_IFS.nc"
for root, dirs, files in os.walk(args.target):
for file in files:
if file.endswith(endswith):
target_grid_path = os.path.join(root, file)
else:
gridtype = 'latlon'
gridsize = x.size*y.size
xsize = x.size
ysize = y.size
xname = 'longitude'
xlongname = 'longitude'
xunits = 'degrees_east'
yname = 'latitude'
ylongname = 'latitude'
yunits = 'degrees_north'
xfirst = float(lon[0,0])
xinc = float(lon[0,1]-lon[0,0])
yfirst = float(lat[0,0])
yinc = float(lat[1,0]-lat[0,0])
grid_mapping = []
grid_mapping_name = []
straight_vertical_longitude_from_pole = []
latitude_of_projection_origin = []
standard_parallel = []
if projection == "np":
gridtype = 'projection'
xlongname = 'x coordinate of projection'
xunits = 'meters'
ylongname = 'y coordinate of projection'
yunits = 'meters'
xfirst = float(x[0])
xinc = float(x[1]-x[0])
yfirst = float(y[0])
yinc = float(y[1]-y[0])
grid_mapping = 'crs'
grid_mapping_name = 'polar_stereographic'
straight_vertical_longitude_from_pole = 0.0
latitude_of_projection_origin = 90.0
standard_parallel = 71.0


formatted_content = f"""\
gridtype = {gridtype}
gridsize = {gridsize}
xsize = {xsize}
ysize = {ysize}
xname = {xname}
xlongname = "{xlongname}"
xunits = "{xunits}"
yname = {yname}
ylongname = "{ylongname}"
yunits = "{yunits}"
xfirst = {xfirst}
xinc = {xinc}
yfirst = {yfirst}
yinc = {yinc}
grid_mapping = {grid_mapping}
grid_mapping_name = {grid_mapping_name}
straight_vertical_longitude_from_pole = {straight_vertical_longitude_from_pole}
latitude_of_projection_origin = {latitude_of_projection_origin}
standard_parallel = {standard_parallel}"""

target_grid_path = out_path.replace(".nc", "target_grid.txt")
with open(target_grid_path, 'w') as file:
file.write(formatted_content)

endswith = "_nodes_IFS.nc"
if placement == "elements":
Expand All @@ -959,15 +1009,27 @@ def fint(args=None):

# we will fill this array with interpolated values
if not args.oneout:
interpolated3d = np.zeros((len(timesteps), len(realdepths), len(y), len(x)))
if args.rotate:
interpolated3d2 = np.zeros(
(len(timesteps), len(realdepths), len(y), len(x))
)
if args.to_fesom_mesh is False:
interpolated3d = np.zeros((len(timesteps), len(realdepths), len(y), len(x)))
if args.rotate:
interpolated3d2 = np.zeros(
(len(timesteps), len(realdepths), len(y), len(x))
)
else:
interpolated3d = np.zeros((len(timesteps), len(realdepths), len(lat)))
if args.rotate:
interpolated3d2 = np.zeros(
(len(timesteps), len(realdepths), len(lat))
)
else:
interpolated3d = np.zeros((1, len(realdepths), len(y), len(x)))
if args.rotate:
interpolated3d2 = np.zeros((1, len(realdepths), len(y), len(x)))
if args.to_fesom_mesh is False:
interpolated3d = np.zeros((1, len(realdepths), len(y), len(x)))
if args.rotate:
interpolated3d2 = np.zeros((1, len(realdepths), len(y), len(x)))
else:
interpolated3d = np.zeros((1, len(realdepths), len(lat)))
if args.rotate:
interpolated3d2 = np.zeros((1, len(realdepths), len(lat)))

# main loop
for t_index, ttime in enumerate(timesteps):
Expand Down Expand Up @@ -1116,19 +1178,29 @@ def fint(args=None):
interpolated[mask] = np.nan
if args.rotate:
interpolated2[mask] = np.nan
elif args.no_shape_mask is False:
elif args.no_shape_mask is False and args.to_fesom_mesh is False:
interpolated[m2] = np.nan
if args.rotate:
interpolated2[m2] = np.nan

if args.oneout:
interpolated3d[0, d_index, :, :] = interpolated
if args.rotate:
interpolated3d2[0, d_index, :, :] = interpolated2
if args.to_fesom_mesh is False:
interpolated3d[0, d_index, :, :] = interpolated
if args.rotate:
interpolated3d2[0, d_index, :, :] = interpolated2
else:
interpolated3d[0, d_index, :] = interpolated
if args.rotate:
interpolated3d2[0, d_index, :] = interpolated2
else:
interpolated3d[t_index, d_index, :, :] = interpolated
if args.rotate:
interpolated3d2[t_index, d_index, :, :] = interpolated2
if args.to_fesom_mesh is False:
interpolated3d[t_index, d_index, :, :] = interpolated
if args.rotate:
interpolated3d2[t_index, d_index, :, :] = interpolated2
else:
interpolated3d[t_index, d_index, :] = interpolated
if args.rotate:
interpolated3d2[t_index, d_index, :] = interpolated2

if args.oneout:
out_path_one = out_path.replace(".nc", f"_{str(t_index).zfill(10)}.nc")
Expand All @@ -1144,6 +1216,7 @@ def fint(args=None):
lon,
lat,
out_path_one,
args.to_fesom_mesh
)
if args.rotate:
out_path_one2 = out_path2.replace(
Expand All @@ -1161,13 +1234,16 @@ def fint(args=None):
lon,
lat,
out_path_one2,
args.to_fesom_mesh
)
if interpolation in ["cdo_remapcon","cdo_remaplaf","cdo_remapnn","cdo_remapdis"]:
os.remove(target_grid_path)
if interpolation in ["cdo_remapcon","cdo_remaplaf","cdo_remapnn","cdo_remapdis"]:
if args.to_fesom_mesh is False:
os.remove(target_grid_path)
if args.save_weights is False and args.weightspath is None and weights_file_path is not None:
os.remove(weights_file_path)
elif interpolation in ["smm_remapcon","smm_remapnn","smm_remaplaf","smm_remapdis"]:
os.remove(target_grid_path)
if args.to_fesom_mesh is False:
os.remove(target_grid_path)

# save data (always 4D array)
if not args.oneout:
Expand All @@ -1183,6 +1259,7 @@ def fint(args=None):
lon,
lat,
out_path,
args.to_fesom_mesh
)
if args.rotate:
save_data(
Expand All @@ -1197,6 +1274,7 @@ def fint(args=None):
lon,
lat,
out_path2,
args.to_fesom_mesh
)


Expand Down
4 changes: 2 additions & 2 deletions src/fint/ut.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@ def nodes_or_ements(data, variable_name, node_num, elem_num):
"""


if data[variable_name].shape[-1] == node_num:
if node_num in data[variable_name].shape:
return "nodes"
elif data[variable_name].shape[-1] == elem_num:
elif elem_num in data[variable_name].shape :
return "elements"


Expand Down

0 comments on commit 628512a

Please sign in to comment.