From 628512aa03417378e65c7949f7e5a4f960f133af Mon Sep 17 00:00:00 2001 From: Boris Shapkin Date: Thu, 21 Dec 2023 19:13:30 +0100 Subject: [PATCH] Interpolation to unstructured meshes --- src/fint/fint.py | 276 ++++++++++++++++++++++++++++++----------------- src/fint/ut.py | 4 +- 2 files changed, 179 insertions(+), 101 deletions(-) diff --git a/src/fint/fint.py b/src/fint/fint.py index 5c039f6..8f5a4d4 100644 --- a/src/fint/fint.py +++ b/src/fint/fint.py @@ -500,6 +500,7 @@ def save_data( lon, lat, out_path, + unstructured = False ): """ Saves the interpolated data to a NetCDF file. @@ -516,6 +517,7 @@ def save_data( lon (np.ndarray): The longitudes of the original dataset. lat (np.ndarray): The latitudes of the original dataset. out_path (str): The path to save the output NetCDF file. + unstructured (bool, optional): Interpolation to unstructured mesh. Default is False. Returns: None: This function does not return any value. @@ -533,16 +535,26 @@ def save_data( out_time = np.atleast_1d(shifted_time) else: out_time = np.atleast_1d(data.time.data[timesteps]) - - out1 = xr.Dataset( - {variable_name: (["time", "depth", "lat", "lon"], interpolated3d)}, + if unstructured is False: + out1 = xr.Dataset( + {variable_name: (["time", "depth", "lat", "lon"], interpolated3d)}, + coords={ + "time": out_time, + "depth": realdepths, + "lon": (["lon"], x), + "lat": (["lat"], y), + "longitude": (["lat", "lon"], lon), + "latitude": (["lat", "lon"], lat), + }, + attrs=data.attrs, + ) + else: + out1 = xr.Dataset( + {variable_name: (["time", "depth", "ncells"], interpolated3d)}, coords={ "time": out_time, "depth": realdepths, - "lon": (["lon"], x), - "lat": (["lat"], y), - "longitude": (["lat", "lon"], lon), - "latitude": (["lat", "lon"], lat), + }, attrs=data.attrs, ) @@ -561,18 +573,28 @@ def save_data( # ) # out1.to_netcdf(out_path, encoding={variable_name: {"zlib": True, "complevel": 9}}) - out1.to_netcdf( - out_path, - encoding={ - "time": {"dtype": np.dtype("double")}, - "depth": {"dtype": np.dtype("double")}, - "lat": {"dtype": np.dtype("double")}, - "lon": {"dtype": np.dtype("double")}, - "longitude": {"dtype": np.dtype("double")}, - "latitude": {"dtype": np.dtype("double")}, - variable_name: {"zlib": True, "complevel": 1, "dtype": np.dtype("single")}, - }, - ) + if unstructured is False: + out1.to_netcdf( + out_path, + encoding={ + "time": {"dtype": np.dtype("double")}, + "depth": {"dtype": np.dtype("double")}, + "lat": {"dtype": np.dtype("double")}, + "lon": {"dtype": np.dtype("double")}, + "longitude": {"dtype": np.dtype("double")}, + "latitude": {"dtype": np.dtype("double")}, + variable_name: {"zlib": True, "complevel": 1, "dtype": np.dtype("single")}, + }, + ) + else: + out1.to_netcdf( + out_path, + encoding={ + "time": {"dtype": np.dtype("double")}, + "depth": {"dtype": np.dtype("double")}, + variable_name: {"zlib": True, "complevel": 1, "dtype": np.dtype("single")}, + }, + ) # if args.rotate: # out2.to_netcdf( # out_path2, @@ -743,6 +765,10 @@ def fint(args=None): "--save_weights", action = "store_true", help = "Save CDO weights ac netcdf file at the output directory.") + parser.add_argument( + "--to_fesom_mesh", + action = "store_true", + help = "Interpolate data to another fesom mesh. Works only with the target flag.") args = parser.parse_args() @@ -848,12 +874,27 @@ def fint(args=None): # define region of interpolation if args.target is None: x, y, lon, lat = define_region(args.box, args.res, projection) + if args.to_fesom_mesh: + raise ValueError("target mesh for interpolation is not provided") else: - x, y, lon, lat = define_region_from_file(args.target) + if args.to_fesom_mesh: + lon, lat = load_mesh(args.target)[:2] + x = 0 + y = 0 + if placement == "elements": + endswith = "_elements_IFS.nc" + for root, dirs, files in os.walk(args.target): + for file in files: + if file.endswith(endswith): + mesh_xr = xr.open_dataset(os.path.join(root, file)) + lon = mesh_xr.grid_center_lon.values + lat = mesh_xr.grid_center_lat.values + else: + x, y, lon, lat = define_region_from_file(args.target) x2, y2 = match_longitude_format(x2, y2, lon, lat) # if we want to use shapelly mask, load it - if args.no_shape_mask is False: + if args.no_shape_mask is False and args.to_fesom_mesh is False: m2 = mask_ne(lon, lat) # additional variables, that we need for different interplations @@ -887,66 +928,75 @@ def fint(args=None): xesmf_weights_path = out_path.replace(".nc", "xesmf_weights.nc") ds_w.to_netcdf(xesmf_weights_path) elif interpolation in ["cdo_remapcon", "cdo_remaplaf", "cdo_remapnn", "cdo_remapdis", "smm_remapcon", "smm_remaplaf", "smm_remapnn", "smm_remapdis"]: - gridtype = 'latlon' - gridsize = x.size*y.size - xsize = x.size - ysize = y.size - xname = 'longitude' - xlongname = 'longitude' - xunits = 'degrees_east' - yname = 'latitude' - ylongname = 'latitude' - yunits = 'degrees_north' - xfirst = float(lon[0,0]) - xinc = float(lon[0,1]-lon[0,0]) - yfirst = float(lat[0,0]) - yinc = float(lat[1,0]-lat[0,0]) - grid_mapping = [] - grid_mapping_name = [] - straight_vertical_longitude_from_pole = [] - latitude_of_projection_origin = [] - standard_parallel = [] - if projection == "np": - gridtype = 'projection' - xlongname = 'x coordinate of projection' - xunits = 'meters' - ylongname = 'y coordinate of projection' - yunits = 'meters' - xfirst = float(x[0]) - xinc = float(x[1]-x[0]) - yfirst = float(y[0]) - yinc = float(y[1]-y[0]) - grid_mapping = 'crs' - grid_mapping_name = 'polar_stereographic' - straight_vertical_longitude_from_pole = 0.0 - latitude_of_projection_origin = 90.0 - standard_parallel = 71.0 - - - formatted_content = f"""\ - gridtype = {gridtype} - gridsize = {gridsize} - xsize = {xsize} - ysize = {ysize} - xname = {xname} - xlongname = "{xlongname}" - xunits = "{xunits}" - yname = {yname} - ylongname = "{ylongname}" - yunits = "{yunits}" - xfirst = {xfirst} - xinc = {xinc} - yfirst = {yfirst} - yinc = {yinc} - grid_mapping = {grid_mapping} - grid_mapping_name = {grid_mapping_name} - straight_vertical_longitude_from_pole = {straight_vertical_longitude_from_pole} - latitude_of_projection_origin = {latitude_of_projection_origin} - standard_parallel = {standard_parallel}""" - - target_grid_path = out_path.replace(".nc", "target_grid.txt") - with open(target_grid_path, 'w') as file: - file.write(formatted_content) + if args.to_fesom_mesh: + endswith = "_nodes_IFS.nc" + if placement == "elements": + endswith = "_elements_IFS.nc" + for root, dirs, files in os.walk(args.target): + for file in files: + if file.endswith(endswith): + target_grid_path = os.path.join(root, file) + else: + gridtype = 'latlon' + gridsize = x.size*y.size + xsize = x.size + ysize = y.size + xname = 'longitude' + xlongname = 'longitude' + xunits = 'degrees_east' + yname = 'latitude' + ylongname = 'latitude' + yunits = 'degrees_north' + xfirst = float(lon[0,0]) + xinc = float(lon[0,1]-lon[0,0]) + yfirst = float(lat[0,0]) + yinc = float(lat[1,0]-lat[0,0]) + grid_mapping = [] + grid_mapping_name = [] + straight_vertical_longitude_from_pole = [] + latitude_of_projection_origin = [] + standard_parallel = [] + if projection == "np": + gridtype = 'projection' + xlongname = 'x coordinate of projection' + xunits = 'meters' + ylongname = 'y coordinate of projection' + yunits = 'meters' + xfirst = float(x[0]) + xinc = float(x[1]-x[0]) + yfirst = float(y[0]) + yinc = float(y[1]-y[0]) + grid_mapping = 'crs' + grid_mapping_name = 'polar_stereographic' + straight_vertical_longitude_from_pole = 0.0 + latitude_of_projection_origin = 90.0 + standard_parallel = 71.0 + + + formatted_content = f"""\ + gridtype = {gridtype} + gridsize = {gridsize} + xsize = {xsize} + ysize = {ysize} + xname = {xname} + xlongname = "{xlongname}" + xunits = "{xunits}" + yname = {yname} + ylongname = "{ylongname}" + yunits = "{yunits}" + xfirst = {xfirst} + xinc = {xinc} + yfirst = {yfirst} + yinc = {yinc} + grid_mapping = {grid_mapping} + grid_mapping_name = {grid_mapping_name} + straight_vertical_longitude_from_pole = {straight_vertical_longitude_from_pole} + latitude_of_projection_origin = {latitude_of_projection_origin} + standard_parallel = {standard_parallel}""" + + target_grid_path = out_path.replace(".nc", "target_grid.txt") + with open(target_grid_path, 'w') as file: + file.write(formatted_content) endswith = "_nodes_IFS.nc" if placement == "elements": @@ -959,15 +1009,27 @@ def fint(args=None): # we will fill this array with interpolated values if not args.oneout: - interpolated3d = np.zeros((len(timesteps), len(realdepths), len(y), len(x))) - if args.rotate: - interpolated3d2 = np.zeros( - (len(timesteps), len(realdepths), len(y), len(x)) - ) + if args.to_fesom_mesh is False: + interpolated3d = np.zeros((len(timesteps), len(realdepths), len(y), len(x))) + if args.rotate: + interpolated3d2 = np.zeros( + (len(timesteps), len(realdepths), len(y), len(x)) + ) + else: + interpolated3d = np.zeros((len(timesteps), len(realdepths), len(lat))) + if args.rotate: + interpolated3d2 = np.zeros( + (len(timesteps), len(realdepths), len(lat)) + ) else: - interpolated3d = np.zeros((1, len(realdepths), len(y), len(x))) - if args.rotate: - interpolated3d2 = np.zeros((1, len(realdepths), len(y), len(x))) + if args.to_fesom_mesh is False: + interpolated3d = np.zeros((1, len(realdepths), len(y), len(x))) + if args.rotate: + interpolated3d2 = np.zeros((1, len(realdepths), len(y), len(x))) + else: + interpolated3d = np.zeros((1, len(realdepths), len(lat))) + if args.rotate: + interpolated3d2 = np.zeros((1, len(realdepths), len(lat))) # main loop for t_index, ttime in enumerate(timesteps): @@ -1116,19 +1178,29 @@ def fint(args=None): interpolated[mask] = np.nan if args.rotate: interpolated2[mask] = np.nan - elif args.no_shape_mask is False: + elif args.no_shape_mask is False and args.to_fesom_mesh is False: interpolated[m2] = np.nan if args.rotate: interpolated2[m2] = np.nan if args.oneout: - interpolated3d[0, d_index, :, :] = interpolated - if args.rotate: - interpolated3d2[0, d_index, :, :] = interpolated2 + if args.to_fesom_mesh is False: + interpolated3d[0, d_index, :, :] = interpolated + if args.rotate: + interpolated3d2[0, d_index, :, :] = interpolated2 + else: + interpolated3d[0, d_index, :] = interpolated + if args.rotate: + interpolated3d2[0, d_index, :] = interpolated2 else: - interpolated3d[t_index, d_index, :, :] = interpolated - if args.rotate: - interpolated3d2[t_index, d_index, :, :] = interpolated2 + if args.to_fesom_mesh is False: + interpolated3d[t_index, d_index, :, :] = interpolated + if args.rotate: + interpolated3d2[t_index, d_index, :, :] = interpolated2 + else: + interpolated3d[t_index, d_index, :] = interpolated + if args.rotate: + interpolated3d2[t_index, d_index, :] = interpolated2 if args.oneout: out_path_one = out_path.replace(".nc", f"_{str(t_index).zfill(10)}.nc") @@ -1144,6 +1216,7 @@ def fint(args=None): lon, lat, out_path_one, + args.to_fesom_mesh ) if args.rotate: out_path_one2 = out_path2.replace( @@ -1161,13 +1234,16 @@ def fint(args=None): lon, lat, out_path_one2, + args.to_fesom_mesh ) - if interpolation in ["cdo_remapcon","cdo_remaplaf","cdo_remapnn","cdo_remapdis"]: - os.remove(target_grid_path) + if interpolation in ["cdo_remapcon","cdo_remaplaf","cdo_remapnn","cdo_remapdis"]: + if args.to_fesom_mesh is False: + os.remove(target_grid_path) if args.save_weights is False and args.weightspath is None and weights_file_path is not None: os.remove(weights_file_path) elif interpolation in ["smm_remapcon","smm_remapnn","smm_remaplaf","smm_remapdis"]: - os.remove(target_grid_path) + if args.to_fesom_mesh is False: + os.remove(target_grid_path) # save data (always 4D array) if not args.oneout: @@ -1183,6 +1259,7 @@ def fint(args=None): lon, lat, out_path, + args.to_fesom_mesh ) if args.rotate: save_data( @@ -1197,6 +1274,7 @@ def fint(args=None): lon, lat, out_path2, + args.to_fesom_mesh ) diff --git a/src/fint/ut.py b/src/fint/ut.py index e3baac4..a295fe8 100644 --- a/src/fint/ut.py +++ b/src/fint/ut.py @@ -42,9 +42,9 @@ def nodes_or_ements(data, variable_name, node_num, elem_num): """ - if data[variable_name].shape[-1] == node_num: + if node_num in data[variable_name].shape: return "nodes" - elif data[variable_name].shape[-1] == elem_num: + elif elem_num in data[variable_name].shape : return "elements"