Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Xesmf interpolation #30

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ dependencies:
- cartopy
- shapely
- cdo
- xesmf
- sparse
- pip:
- smmregrid

Expand Down
192 changes: 155 additions & 37 deletions src/fint/fint.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import numpy as np
import pandas as pd
import xarray as xr
import xesmf as xe
import sparse
import subprocess
from smmregrid import Regridder
from scipy.interpolate import (
Expand Down Expand Up @@ -266,25 +268,26 @@ def interpolate_linear_scipy(data_in, x2, y2, lon2, lat2):
return interpolated


def interpolate_cdo(target_grid,gridfile,original_file,output_file,variable_name,interpolation, mask_zero=True):
def interpolate_cdo(target_grid,gridfile,original_file,output_file,variable_name,interpolation,weights_file_path, mask_zero=True):
"""
Interpolate a variable in a file using CDO (Climate Data Operators).
Interpolate a climate variable in a NetCDF file using Climate Data Operators (CDO).

Args:
target_grid (str): Path to the target grid file.
gridfile (str): Path to the grid file associated with the original file.
original_file (str): Path to the original file containing the variable to be interpolated.
output_file (str): Path to the output file where the interpolated variable will be saved.
target_grid (str): Path to the target grid file (NetCDF format) for interpolation.
gridfile (str): Path to the grid file (NetCDF format) associated with the original data.
original_file (str): Path to the original NetCDF file containing the variable to be interpolated.
output_file (str): Path to the output NetCDF file where the interpolated variable will be saved.
variable_name (str): Name of the variable to be interpolated.
interpolation (str): Interpolation method to be used (cdo_remapcon,cdo_remaplaf,cdo_remapnn, cdo_remapdis).
interpolation (str): Interpolation method to be used (e.g., 'remapcon', 'remaplaf', 'remapnn', 'remapdis').
weights_file_path (str): Path to the weights file generated by CDO for interpolation.
mask_zero (bool, optional): Whether to mask zero values in the output. Default is True.

Returns:
np.ndarray: Interpolated variable data as a NumPy array.
"""

command = [
"cdo",
f"-{interpolation.split('_')[1]},{target_grid}",
f"-remap,{target_grid},{weights_file_path}",
f"-setgrid,{gridfile}",
f"{original_file}",
f"{output_file}"
Expand All @@ -299,34 +302,88 @@ def interpolate_cdo(target_grid,gridfile,original_file,output_file,variable_name
os.remove(output_file)
return interpolated

def generate_cdo_weights(target_grid, gridfile, original_file, output_file, interpolation, save=False):
    """
    Generate CDO interpolation weights for smmregrid and cdo interpolation.

    Runs ``cdo -gen<method>,<target_grid> -setgrid,<gridfile> <original_file> <output_file>``
    as a subprocess and reads the resulting weights file with xarray.

    Args:
        target_grid (str): Path to the target grid file (NetCDF format).
        gridfile (str): Path to the grid file (NetCDF format) associated with the original data.
        original_file (str): Path to the original NetCDF file containing the data to be remapped.
        output_file (str): Path to the output NetCDF file where the weights will be saved.
        interpolation (str): Interpolation method to be used. The part after the last
            underscore must look like ``remap<method>`` (e.g. 'cdo_remapcon', 'smm_remapnn');
            the first five characters of that part are dropped to build ``gen<method>``.
        save (bool, optional): Whether to keep the weights file on disk. Default is False.

    Returns:
        xr.Dataset: Generated interpolation weights as an xarray Dataset.

    Raises:
        subprocess.CalledProcessError: If the cdo command exits with a non-zero status.
    """
    # 'smm_remapcon' -> 'remapcon' -> 'con', which becomes the cdo operator '-gencon'.
    int_method = interpolation.split('_')[-1][5:]

    command = [
        "cdo",
        f"-gen{int_method},{target_grid}",
        f"-setgrid,{gridfile}",
        f"{original_file}",
        f"{output_file}",
    ]

    # check=True: fail right here if cdo fails, instead of later with a
    # misleading "file not found" error when the weights are opened.
    subprocess.run(command, check=True)

    if save:
        # The file stays on disk, so a lazily opened dataset remains valid.
        return xr.open_dataset(output_file)

    # The file is about to be deleted: read everything into memory and close
    # the file handle first, so the returned dataset never refers to a removed file.
    weights = xr.load_dataset(output_file)
    os.remove(output_file)
    return weights

def xesmf_weights_to_xarray(regridder):
    """
    Export the sparse weights of an xESMF regridder as an xarray Dataset.

    The sparse array found at ``regridder.weights.data`` is flattened into three
    one-dimensional variables that share the single dimension 'n_s':

    * 'S'   - the stored weight values,
    * 'row' - first coordinate of every stored entry, incremented by one,
    * 'col' - second coordinate of every stored entry, incremented by one.

    Args:
        regridder (xesmf.Regridder): Regridder whose weights should be exported.

    Returns:
        xr.Dataset: Dataset with variables 'S', 'col' and 'row' along 'n_s' and the
        attributes 'n_in' and 'n_out' copied from the regridder.
    """
    sparse_weights = regridder.weights.data
    entry_dim = 'n_s'

    # Indices are written with a +1 offset relative to the sparse array coordinates.
    col_index = sparse_weights.coords[1, :] + 1
    row_index = sparse_weights.coords[0, :] + 1

    variables = {
        'S': (entry_dim, sparse_weights.data),
        'col': (entry_dim, col_index),
        'row': (entry_dim, row_index),
    }
    grid_sizes = {'n_in': regridder.n_in, 'n_out': regridder.n_out}
    return xr.Dataset(variables, attrs=grid_sizes)

def reconstruct_xesmf_weights(ds_w):
    """
    Rebuild a sparse weights DataArray from a flat weights dataset.

    The variables 'row' and 'col' are each decremented by one and stacked into a
    (2, n) coordinate array; together with the values in 'S' they form a
    ``sparse.COO`` matrix of shape ``(n_out, n_in)``, where both sizes are read
    from the dataset attributes.

    Args:
        ds_w (xarray.Dataset): Dataset providing the variables 'S', 'col' and 'row'
            and the attributes 'n_out' and 'n_in'.

    Returns:
        xarray.DataArray: DataArray named 'weights' with dimensions
        ('out_dim', 'in_dim') that wraps the sparse COO matrix.
    """
    # Undo the +1 offset the indices are stored with.
    col_index = ds_w['col'].values - 1
    row_index = ds_w['row'].values - 1
    values = ds_w['S'].values

    shape = (ds_w.attrs['n_out'], ds_w.attrs['n_in'])
    coordinates = np.stack([row_index, col_index])
    matrix = sparse.COO(coordinates, values, shape)

    return xr.DataArray(matrix, dims=('out_dim', 'in_dim'), name='weights')

def parse_depths(depths, depths_from_file):
"""
Parses the selected depths from the available depth values and returns the corresponding depth indices and values.
Expand Down Expand Up @@ -614,7 +671,7 @@ def fint(args=None):
"--interp",
choices=["nn", "mtri_linear", "linear_scipy",
"cdo_remapcon","cdo_remaplaf","cdo_remapnn", "cdo_remapdis",
"smm_con","smm_laf","smm_nn","smm_dis"], # "idist", "linear", "cubic"],
"smm_remapcon","smm_remaplaf","smm_remapnn", "smm_remapdis", "xesmf_nearest_s2d"], # "idist", "linear", "cubic"],
default="nn",
help="Interpolation method. Options are \
nn - nearest neighbor (KDTree implementation, fast), \
Expand Down Expand Up @@ -678,6 +735,15 @@ def fint(args=None):
Valid units are 'D' (days), 'h' (hours), 'm' (minutes), 's' (seconds). \
To substract timedelta, put argument in quotes, and prepend ' -', so SPACE and then -, e.g. ' -10D'.",
)
parser.add_argument(
"--weightspath",
type=str,
help="File with CDO weiths for smm interpolation.")
parser.add_argument(
"--save_weights",
action = "store_true",
help = "Save CDO weights ac netcdf file at the output directory.")


args = parser.parse_args()

Expand Down Expand Up @@ -797,7 +863,30 @@ def fint(args=None):
trifinder = triang2.get_trifinder()
elif interpolation == "nn":
distances, inds = create_indexes_and_distances(x2, y2, lon, lat, k=1, workers=4)
elif interpolation in ["cdo_remapcon", "cdo_remaplaf", "cdo_remapnn", "cdo_remapdis", "smm_con", "smm_laf", "smm_nn", "smm_dis"]:
elif interpolation in ['xesmf_nearest_s2d']:
ds_in = xr.open_dataset(args.data)
ds_in = ds_in.assign_coords(lat=('nod2',y2), lon=('nod2',x2))
ds_in.lat.attrs = {'units': 'degrees', 'standard_name': 'latitude'}
ds_in.lon.attrs = {'units': 'degrees', 'standard_name': 'longitude'}
ds_out = xr.Dataset(
{
'x': xr.DataArray(x, dims=['x']),
'y': xr.DataArray(y, dims=['y']),
'lat': xr.DataArray(lat, dims=['y', 'x']),
'lon': xr.DataArray(lon, dims=['y', 'x']),
})
if args.weightspath is not None:
xesmf_weights_path = args.weightspath
ds_w = xr.open_dataset(xesmf_weights_path)
wegiths_xesmf = reconstruct_xesmf_weights(ds_w)
regridder = xe.Regridder(ds_in,ds_out, method='nearest_s2d', weights=wegiths_xesmf,locstream_in=True)
else:
regridder = xe.Regridder(ds_in, ds_out, method='nearest_s2d', locstream_in=True)
if args.save_weights is True:
ds_w = xesmf_weights_to_xarray(regridder)
xesmf_weights_path = out_path.replace(".nc", "xesmf_weights.nc")
ds_w.to_netcdf(xesmf_weights_path)
elif interpolation in ["cdo_remapcon", "cdo_remaplaf", "cdo_remapnn", "cdo_remapdis", "smm_remapcon", "smm_remaplaf", "smm_remapnn", "smm_remapdis"]:
gridtype = 'latlon'
gridsize = x.size*y.size
xsize = x.size
Expand Down Expand Up @@ -962,38 +1051,63 @@ def fint(args=None):
variable_name: {"dtype": np.dtype("double")},
},
)
if args.weightspath is not None:
weights_file_path = args.weightspath
else:
if t_index == 0 and d_index == 0:
weights_file_path = out_path.replace(".nc", "weighs_cdo.nc")
weights = generate_cdo_weights(target_grid_path,
gridfile,
input_file_path,
weights_file_path,
interpolation,
save = True)
interpolated = interpolate_cdo(target_grid_path,
gridfile,
input_file_path,
output_file_path,
variable_name,
interpolation,
weights_file_path,
mask_zero=args.no_mask_zero
)
os.remove(input_file_path)

elif interpolation in ["smm_con", "smm_laf", "smm_nn", "smm_dis"]:
elif interpolation in ["smm_remapcon", "smm_remaplaf", "smm_remapnn", "smm_remapdis"]:
input_data = xr.Dataset({variable_name: (["nod2"], data_in)})
if args.rotate:
input_data = xr.Dataset({variable_name: (["nod2"], data_in2)})
input_file_path = args.data.replace(".nc","cdo_interpolation.nc")
input_data.to_netcdf(input_file_path,encoding={
variable_name: {"dtype": np.dtype("double")},
},
)
output_file_path = out_path.replace(".nc", "weighs_cdo.nc")
weights = generate_cdo_weights(target_grid_path,
gridfile,
input_file_path,
output_file_path,
interpolation)
os.remove(input_file_path)
interpolator = Regridder(weights=weights)
if t_index == 0 and d_index == 0:
input_file_path = args.data.replace(".nc","cdo_interpolation.nc")
input_data.to_netcdf(input_file_path,encoding={
variable_name: {"dtype": np.dtype("double")},
},
)
output_file_path = out_path.replace(".nc", "weighs_cdo.nc")
if args.weightspath is not None:
weights_file = args.weightspath
weights = xr.open_dataset(weights_file)
else:
weights = generate_cdo_weights(target_grid_path,
gridfile,
input_file_path,
output_file_path,
interpolation,
save =args.save_weights)
os.remove(input_file_path)
interpolator = Regridder(weights=weights)
interpolated = interpolator.regrid(input_data)
interpolated = interpolated[variable_name].values
mask_zero=args.no_mask_zero
if mask_zero:
interpolated[interpolated == 0] = np.nan

elif interpolation in ['xesmf_nearest_s2d']:
ds_in = xr.Dataset({variable_name: (["nod2"], data_in)})
ds_in = ds_in.assign_coords(lat=('nod2',y2), lon=('nod2',x2))
ds_in.lat.attrs = {'units': 'degrees', 'standard_name': 'latitude'}
ds_in.lon.attrs = {'units': 'degrees', 'standard_name': 'longitude'}
interpolated = regridder(ds_in)[variable_name].values

# masking of the data
if mask_file is not None:
Expand Down Expand Up @@ -1048,7 +1162,11 @@ def fint(args=None):
lat,
out_path_one2,
)
if interpolation in ["cdo_remapcon","cdo_remaplaf","cdo_remapnn","cdo_remapdis","smm_con","smm_nn","smm_laf","smm_dis"]:
if interpolation in ["cdo_remapcon","cdo_remaplaf","cdo_remapnn","cdo_remapdis"]:
os.remove(target_grid_path)
if args.save_weights is False and args.weightspath is None and weights_file_path is not None:
os.remove(weights_file_path)
elif interpolation in ["smm_remapcon","smm_remapnn","smm_remaplaf","smm_remapdis"]:
os.remove(target_grid_path)

# save data (always 4D array)
Expand Down
23 changes: 18 additions & 5 deletions test/tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,24 @@ fint ${FILE} ${MESH} ${INFL} -t 0 -d 0 -b "-150, 150, -50, 70" --interp cdo_rema
fint ${FILE} ${MESH} ${INFL} -t 0 -d 0 -b arctic --interp cdo_remapcon

#smm_regrid
fint ${FILE} ${MESH} --influence 500000 -t 0 -d -1 --interp smm_con --no_shape_mask
fint ${FILE} ${MESH} ${INFL} -t 0 -d 0 -b "-150, 150, -50, 70" --interp smm_laf
fint ${FILE} ${MESH} ${INFL} -t 0 -d 0 -b arctic --interp smm_nn
fint ${FILE} ${MESH} ${INFL} -t 0 -d 0 --interp smm_dis

fint ${FILE} ${MESH} ${INFL} -t 0 -d -1 --interp smm_remapcon --no_shape_mask
fint ${FILE} ${MESH} ${INFL} -t 0 -d 0 -b "-150, 150, -50, 70" --interp smm_remaplaf
fint ${FILE} ${MESH} ${INFL} -t 0 -d 0 -b arctic --interp smm_remapnn
fint ${FILE} ${MESH} ${INFL} -t 0 -d 0 --interp smm_remapdis

#xesmf_regrid
fint ${FILE} ${MESH} ${INFL} -t -1 -d -1 -b "-150, 150, -50, 70" --interp xesmf_nearest_s2d
fint ${FILE} ${MESH} ${INFL} -t 0 -d 0 -b arctic --interp xesmf_nearest_s2d
fint ${FILE} ${MESH} ${INFL} -t 0 -d 0 -b gulf --interp xesmf_nearest_s2d

#saving weights and reuse it
fint ${FILE} ${MESH} ${INFL} -t 1:5 -d -1 -b "-150, 150, -50, 70" --interp smm_remapcon --save_weights
export WEIGHTS="--weightspath ./temp.fesom.1948_interpolated_-150_150_-50_70_2.5_6125.0_1_4weighs_cdo.nc"
fint ${FILE} ${MESH} ${INFL} -t 1:5 -d -1 -b "-150, 150, -50, 70" --interp cdo_remapcon ${WEIGHTS}

fint ${FILE} ${MESH} ${INFL} -t 1:5 -d -1 -b "-150, 150, -50, 70" --interp xesmf_nearest_s2d --save_weights
export WEIGHTS="--weightspath ./temp.fesom.1948_interpolated_-150_150_-50_70_2.5_6125.0_1_4xesmf_weights.nc"
fint ${FILE} ${MESH} ${INFL} -t -1 -d -1 -b "-150, 150, -50, 70" --interp xesmf_nearest_s2d ${WEIGHTS}

# create mask
fint ${FILE} ${MESH} ${INFL} -t 0 -d -1 --interp mtri_linear -o mask.nc
Expand Down
Loading