diff --git a/environment.yml b/environment.yml index 157b849..a2b85e9 100644 --- a/environment.yml +++ b/environment.yml @@ -14,6 +14,8 @@ dependencies: - cartopy - shapely - cdo + - xesmf + - sparse - pip: - smmregrid diff --git a/src/fint/fint.py b/src/fint/fint.py index 64c8a55..5c039f6 100644 --- a/src/fint/fint.py +++ b/src/fint/fint.py @@ -6,6 +6,8 @@ import numpy as np import pandas as pd import xarray as xr +import xesmf as xe +import sparse import subprocess from smmregrid import Regridder from scipy.interpolate import ( @@ -266,25 +268,26 @@ def interpolate_linear_scipy(data_in, x2, y2, lon2, lat2): return interpolated -def interpolate_cdo(target_grid,gridfile,original_file,output_file,variable_name,interpolation, mask_zero=True): +def interpolate_cdo(target_grid,gridfile,original_file,output_file,variable_name,interpolation,weights_file_path, mask_zero=True): """ - Interpolate a variable in a file using CDO (Climate Data Operators). + Interpolate a climate variable in a NetCDF file using Climate Data Operators (CDO). Args: - target_grid (str): Path to the target grid file. - gridfile (str): Path to the grid file associated with the original file. - original_file (str): Path to the original file containing the variable to be interpolated. - output_file (str): Path to the output file where the interpolated variable will be saved. + target_grid (str): Path to the target grid file (NetCDF format) for interpolation. + gridfile (str): Path to the grid file (NetCDF format) associated with the original data. + original_file (str): Path to the original NetCDF file containing the variable to be interpolated. + output_file (str): Path to the output NetCDF file where the interpolated variable will be saved. variable_name (str): Name of the variable to be interpolated. - interpolation (str): Interpolation method to be used (cdo_remapcon,cdo_remaplaf,cdo_remapnn, cdo_remapdis). + interpolation (str): Interpolation method to be used (e.g., 'remapcon', 'remaplaf', 'remapnn', 'remapdis'). + weights_file_path (str): Path to the weights file generated by CDO for interpolation. + mask_zero (bool, optional): Whether to mask zero values in the output. Default is True. Returns: np.ndarray: Interpolated variable data as a NumPy array. """ - command = [ "cdo", - f"-{interpolation.split('_')[1]},{target_grid}", + f"-remap,{target_grid},{weights_file_path}", f"-setgrid,{gridfile}", f"{original_file}", f"{output_file}" @@ -299,34 +302,88 @@ def interpolate_cdo(target_grid,gridfile,original_file,output_file,variable_name os.remove(output_file) return interpolated -def generate_cdo_weights(target_grid,gridfile,original_file,output_file,interpolation): +def generate_cdo_weights(target_grid,gridfile,original_file,output_file,interpolation, save = False): """ - Generate CDO weights for interpolation using ssmregrid. + Generate CDO interpolation weights for smmregrid and cdo interpolation. Args: - target_grid (str): Path to the target grid file. - gridfile (str): Path to the grid file associated with the original file. - original_file (str): Path to the original file containing the data to be remapped. - output_file (str): Path to the output file where the weights will be saved. + target_grid (str): Path to the target grid file (NetCDF format). + gridfile (str): Path to the grid file (NetCDF format) associated with the original data. + original_file (str): Path to the original NetCDF file containing the data to be remapped. + output_file (str): Path to the output NetCDF file where the weights will be saved. + interpolation (str): Interpolation method to be used. + save (bool, optional): Whether to save the weights file. Default is False. Returns: - xr.Dataset: Generated weights as an xarray Dataset. + xr.Dataset: Generated interpolation weights as an xarray Dataset. """ + + int_method = interpolation.split('_')[-1][5:] + command = [ "cdo", - f"-gen{interpolation.split('_')[-1]},{target_grid}", + f"-gen{int_method},{target_grid}", f"-setgrid,{gridfile}", f"{original_file}", f"{output_file}" ] - # Execute the command subprocess.run(command) - + weights = xr.open_dataset(output_file) - os.remove(output_file) + if save == False: + os.remove(output_file) + return weights +def xesmf_weights_to_xarray(regridder): + """ + Converts xESMF regridder weights to an xarray Dataset. + + This function takes a regridder object from xESMF, extracts the weights data, + and converts it into an xarray Dataset with relevant dimensions and attributes. + + Args: + regridder (xesmf.Regridder): A regridder object created using xESMF, which contains the weights to be converted. + + Returns: + xr.Dataset: An xarray Dataset containing the weights data with dimensions 'n_s', 'col', and 'row'. + """ + w = regridder.weights.data + dim = 'n_s' + ds = xr.Dataset( + { + 'S': (dim, w.data), + 'col': (dim, w.coords[1, :] + 1), + 'row': (dim, w.coords[0, :] + 1), + } + ) + ds.attrs = {'n_in': regridder.n_in, 'n_out': regridder.n_out} + return ds + +def reconstruct_xesmf_weights(ds_w): + """ + Reconstruct weights into a format that xESMF understands. + + This function takes a dataset with weights in a specific format and converts + it into a format that can be used by xESMF for regridding. + + Args: + ds_w (xarray.Dataset): The input dataset containing weights data. + It should have 'S', 'col', 'row', and appropriate attributes 'n_out' and 'n_in'. + + Returns: + xarray.DataArray: A DataArray containing reconstructed weights in COO format suitable for use with xESMF. + """ + col = ds_w['col'].values - 1 + row = ds_w['row'].values - 1 + s = ds_w['S'].values + n_out, n_in = ds_w.attrs['n_out'], ds_w.attrs['n_in'] + crds = np.stack([row, col]) + return xr.DataArray( + sparse.COO(crds, s, (n_out, n_in)), dims=('out_dim', 'in_dim'), name='weights' + ) + def parse_depths(depths, depths_from_file): """ Parses the selected depths from the available depth values and returns the corresponding depth indices and values. @@ -614,7 +671,7 @@ def fint(args=None): "--interp", choices=["nn", "mtri_linear", "linear_scipy", "cdo_remapcon","cdo_remaplaf","cdo_remapnn", "cdo_remapdis", - "smm_con","smm_laf","smm_nn","smm_dis"], # "idist", "linear", "cubic"], + "smm_remapcon","smm_remaplaf","smm_remapnn", "smm_remapdis", "xesmf_nearest_s2d"], # "idist", "linear", "cubic"], default="nn", help="Interpolation method. Options are \ nn - nearest neighbor (KDTree implementation, fast), \ @@ -678,6 +735,15 @@ def fint(args=None): Valid units are 'D' (days), 'h' (hours), 'm' (minutes), 's' (seconds). \ To substract timedelta, put argument in quotes, and prepend ' -', so SPACE and then -, e.g. ' -10D'.", ) + parser.add_argument( + "--weightspath", + type=str, + help="File with CDO weiths for smm interpolation.") + parser.add_argument( + "--save_weights", + action = "store_true", + help = "Save CDO weights ac netcdf file at the output directory.") + args = parser.parse_args() @@ -797,7 +863,30 @@ def fint(args=None): trifinder = triang2.get_trifinder() elif interpolation == "nn": distances, inds = create_indexes_and_distances(x2, y2, lon, lat, k=1, workers=4) - elif interpolation in ["cdo_remapcon", "cdo_remaplaf", "cdo_remapnn", "cdo_remapdis", "smm_con", "smm_laf", "smm_nn", "smm_dis"]: + elif interpolation in ['xesmf_nearest_s2d']: + ds_in = xr.open_dataset(args.data) + ds_in = ds_in.assign_coords(lat=('nod2',y2), lon=('nod2',x2)) + ds_in.lat.attrs = {'units': 'degrees', 'standard_name': 'latitude'} + ds_in.lon.attrs = {'units': 'degrees', 'standard_name': 'longitude'} + ds_out = xr.Dataset( + { + 'x': xr.DataArray(x, dims=['x']), + 'y': xr.DataArray(y, dims=['y']), + 'lat': xr.DataArray(lat, dims=['y', 'x']), + 'lon': xr.DataArray(lon, dims=['y', 'x']), + }) + if args.weightspath is not None: + xesmf_weights_path = args.weightspath + ds_w = xr.open_dataset(xesmf_weights_path) + wegiths_xesmf = reconstruct_xesmf_weights(ds_w) + regridder = xe.Regridder(ds_in,ds_out, method='nearest_s2d', weights=wegiths_xesmf,locstream_in=True) + else: + regridder = xe.Regridder(ds_in, ds_out, method='nearest_s2d', locstream_in=True) + if args.save_weights is True: + ds_w = xesmf_weights_to_xarray(regridder) + xesmf_weights_path = out_path.replace(".nc", "xesmf_weights.nc") + ds_w.to_netcdf(xesmf_weights_path) + elif interpolation in ["cdo_remapcon", "cdo_remaplaf", "cdo_remapnn", "cdo_remapdis", "smm_remapcon", "smm_remaplaf", "smm_remapnn", "smm_remapdis"]: gridtype = 'latlon' gridsize = x.size*y.size xsize = x.size @@ -962,38 +1051,63 @@ def fint(args=None): variable_name: {"dtype": np.dtype("double")}, }, ) + if args.weightspath is not None: + weights_file_path = args.weightspath + else: + if t_index == 0 and d_index == 0: + weights_file_path = out_path.replace(".nc", "weighs_cdo.nc") + weights = generate_cdo_weights(target_grid_path, + gridfile, + input_file_path, + weights_file_path, + interpolation, + save = True) interpolated = interpolate_cdo(target_grid_path, gridfile, input_file_path, output_file_path, variable_name, interpolation, + weights_file_path, mask_zero=args.no_mask_zero ) os.remove(input_file_path) - elif interpolation in ["smm_con", "smm_laf", "smm_nn", "smm_dis"]: + elif interpolation in ["smm_remapcon", "smm_remaplaf", "smm_remapnn", "smm_remapdis"]: input_data = xr.Dataset({variable_name: (["nod2"], data_in)}) if args.rotate: input_data = xr.Dataset({variable_name: (["nod2"], data_in2)}) - input_file_path = args.data.replace(".nc","cdo_interpolation.nc") - input_data.to_netcdf(input_file_path,encoding={ - variable_name: {"dtype": np.dtype("double")}, - }, - ) - output_file_path = out_path.replace(".nc", "weighs_cdo.nc") - weights = generate_cdo_weights(target_grid_path, - gridfile, - input_file_path, - output_file_path, - interpolation) - os.remove(input_file_path) - interpolator = Regridder(weights=weights) + if t_index == 0 and d_index == 0: + input_file_path = args.data.replace(".nc","cdo_interpolation.nc") + input_data.to_netcdf(input_file_path,encoding={ + variable_name: {"dtype": np.dtype("double")}, + }, + ) + output_file_path = out_path.replace(".nc", "weighs_cdo.nc") + if args.weightspath is not None: + weights_file = args.weightspath + weights = xr.open_dataset(weights_file) + else: + weights = generate_cdo_weights(target_grid_path, + gridfile, + input_file_path, + output_file_path, + interpolation, + save =args.save_weights) + os.remove(input_file_path) + interpolator = Regridder(weights=weights) interpolated = interpolator.regrid(input_data) interpolated = interpolated[variable_name].values mask_zero=args.no_mask_zero if mask_zero: interpolated[interpolated == 0] = np.nan + + elif interpolation in ['xesmf_nearest_s2d']: + ds_in = xr.Dataset({variable_name: (["nod2"], data_in)}) + ds_in = ds_in.assign_coords(lat=('nod2',y2), lon=('nod2',x2)) + ds_in.lat.attrs = {'units': 'degrees', 'standard_name': 'latitude'} + ds_in.lon.attrs = {'units': 'degrees', 'standard_name': 'longitude'} + interpolated = regridder(ds_in)[variable_name].values # masking of the data if mask_file is not None: @@ -1048,7 +1162,11 @@ def fint(args=None): lat, out_path_one2, ) - if interpolation in ["cdo_remapcon","cdo_remaplaf","cdo_remapnn","cdo_remapdis","smm_con","smm_nn","smm_laf","smm_dis"]: + if interpolation in ["cdo_remapcon","cdo_remaplaf","cdo_remapnn","cdo_remapdis"]: + os.remove(target_grid_path) + if args.save_weights is False and args.weightspath is None and weights_file_path is not None: + os.remove(weights_file_path) + elif interpolation in ["smm_remapcon","smm_remapnn","smm_remaplaf","smm_remapdis"]: os.remove(target_grid_path) # save data (always 4D array) diff --git a/test/tests.sh b/test/tests.sh index 80cd73f..3ef9a22 100755 --- a/test/tests.sh +++ b/test/tests.sh @@ -48,11 +48,24 @@ fint ${FILE} ${MESH} ${INFL} -t 0 -d 0 -b "-150, 150, -50, 70" --interp cdo_rema fint ${FILE} ${MESH} ${INFL} -t 0 -d 0 -b arctic --interp cdo_remapcon #smm_regrid -fint ${FILE} ${MESH} --influence 500000 -t 0 -d -1 --interp smm_con --no_shape_mask -fint ${FILE} ${MESH} ${INFL} -t 0 -d 0 -b "-150, 150, -50, 70" --interp smm_laf -fint ${FILE} ${MESH} ${INFL} -t 0 -d 0 -b arctic --interp smm_nn -fint ${FILE} ${MESH} ${INFL} -t 0 -d 0 --interp smm_dis - +fint ${FILE} ${MESH} ${INFL} -t 0 -d -1 --interp smm_remapcon --no_shape_mask +fint ${FILE} ${MESH} ${INFL} -t 0 -d 0 -b "-150, 150, -50, 70" --interp smm_remaplaf +fint ${FILE} ${MESH} ${INFL} -t 0 -d 0 -b arctic --interp smm_remapnn +fint ${FILE} ${MESH} ${INFL} -t 0 -d 0 --interp smm_remapdis + +#xesmf_regrid +fint ${FILE} ${MESH} ${INFL} -t -1 -d -1 -b "-150, 150, -50, 70" --interp xesmf_nearest_s2d +fint ${FILE} ${MESH} ${INFL} -t 0 -d 0 -b arctic --interp xesmf_nearest_s2d +fint ${FILE} ${MESH} ${INFL} -t 0 -d 0 -b gulf --interp xesmf_nearest_s2d + +#saving weights and reuse it +fint ${FILE} ${MESH} ${INFL} -t 1:5 -d -1 -b "-150, 150, -50, 70" --interp smm_remapcon --save_weights +export WEIGHTS="--weightspath ./temp.fesom.1948_interpolated_-150_150_-50_70_2.5_6125.0_1_4weighs_cdo.nc" +fint ${FILE} ${MESH} ${INFL} -t 1:5 -d -1 -b "-150, 150, -50, 70" --interp cdo_remapcon ${WEIGHTS} + +fint ${FILE} ${MESH} ${INFL} -t 1:5 -d -1 -b "-150, 150, -50, 70" --interp xesmf_nearest_s2d --save_weights +export WEIGHTS="--weightspath ./temp.fesom.1948_interpolated_-150_150_-50_70_2.5_6125.0_1_4xesmf_weights.nc" +fint ${FILE} ${MESH} ${INFL} -t -1 -d -1 -b "-150, 150, -50, 70" --interp xesmf_nearest_s2d ${WEIGHTS} # create mask fint ${FILE} ${MESH} ${INFL} -t 0 -d -1 --interp mtri_linear -o mask.nc