Skip to content

Commit

Permalink
ADD: Ability to create a Py-ART Grid FROM an xarray object (#1479)
Browse files Browse the repository at this point in the history
* ADD: Xarray to Py-ART Grid interface

* FIX: Pre-commit hooks

* DOC: Parameters for Xgrid

* ADD: Close dataset to test_accessor

---------

Co-authored-by: Robert Jackson <[email protected]>
  • Loading branch information
rcjackson and Robert Jackson authored Oct 31, 2023
1 parent fb62211 commit 9ce78bf
Show file tree
Hide file tree
Showing 3 changed files with 379 additions and 4 deletions.
2 changes: 1 addition & 1 deletion pyart/xradar/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .accessor import Xradar # noqa
from .accessor import Xradar, Xgrid # noqa
306 changes: 304 additions & 2 deletions pyart/xradar/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,261 @@
import pandas as pd
from datatree import DataTree, formatting, formatting_html
from datatree.treenode import NodePath
from xarray import concat
from xarray import DataArray, Dataset, concat
from xarray.core import utils

from ..core.transforms import antenna_vectors_to_cartesian
from ..config import get_metadata
from ..core.transforms import (
antenna_vectors_to_cartesian,
cartesian_to_geographic,
cartesian_vectors_to_geographic,
)
from ..lazydict import LazyLoadDict


class Xgrid:
    """
    Py-ART Grid-like wrapper around a CF-compliant xarray Dataset.

    The wrapper exposes the dataset variables as Py-ART style dictionaries
    (a ``"data"`` entry plus the variable attributes) and the gridded fields
    through the ``fields`` attribute.
    """

    def __init__(self, grid_ds):
        """
        Wrap a CF-compliant xarray Dataset into a Py-ART Grid object.

        Note that the times must not be decoded by xr.open_dataset when
        loading the file.

        Parameters
        ----------
        grid_ds : xarray Dataset
            The xarray Dataset to convert to a Py-ART grid.

        Raises
        ------
        RuntimeError
            If the ``time`` variable carries no ``units`` attribute.

        """
        if "units" not in grid_ds["time"].attrs:
            raise RuntimeError(
                "decode_times must be set to false when opening grid file!"
            )
        self.ds = grid_ds
        self.time = self._variable_to_dict("time")
        self.fields = {}
        self._find_fields()
        self.origin_altitude = self._variable_to_dict("origin_altitude")
        self.origin_latitude = self._variable_to_dict("origin_latitude")
        self.origin_longitude = self._variable_to_dict("origin_longitude")
        self.z = self._variable_to_dict("z")
        self.y = self._variable_to_dict("y")
        self.x = self._variable_to_dict("x")
        self.nradar = len(self.ds["nradar"].values)
        self.radar_altitude = self._variable_to_dict("radar_altitude")
        self.radar_longitude = self._variable_to_dict("radar_longitude")
        self.radar_latitude = self._variable_to_dict("radar_latitude")
        self.radar_time = self._variable_to_dict("radar_time")

        # Radar names are converted to fixed-width unicode and, unlike the
        # other variables, are not passed through np.atleast_1d.
        radar_name = self.ds["radar_name"]
        self.radar_name = dict(data=radar_name.values.astype("<U12"))
        self.radar_name.update(radar_name.attrs)

        # Work on a copy so that normalising the flag below does not rewrite
        # the attributes of the caller's dataset.
        self.projection = dict(self.ds["projection"].attrs)
        if "_include_lon_0_lat_0" in self.projection:
            flag = self.projection["_include_lon_0_lat_0"]
            if isinstance(flag, bytes):
                flag = flag.decode()
            if isinstance(flag, str):
                flag = flag.lower() == "true"
            else:
                # Already a boolean/number, e.g. when the attributes were
                # normalised before; str.lower() would fail on it.
                flag = bool(flag)
            self.projection["_include_lon_0_lat_0"] = flag

        self.init_point_altitude()
        self.init_point_longitude_latitude()
        self.init_point_x_y_z()

    def _variable_to_dict(self, name):
        """Return dataset variable ``name`` as a Py-ART style dictionary.

        The ``"data"`` entry holds the values as an array with at least one
        dimension; the variable attributes are merged in afterwards.
        """
        variable = self.ds[name]
        entry = dict(data=np.atleast_1d(variable.values))
        entry.update(variable.attrs)
        return entry

    def _find_fields(self):
        """Populate ``self.fields`` from every (time, z, y, x) variable.

        Only a length-one time axis is dropped, so that a grid with a single
        level, row or column keeps its (nz, ny, nx) layout. Variables with
        more than one time step keep all four dimensions.
        """
        for key in self.ds.variables:
            variable = self.ds[key]
            if variable.dims != ("time", "z", "y", "x"):
                continue
            data = variable.values
            if data.shape[0] == 1:
                data = data[0]
            field = {"data": data}
            field.update(variable.attrs)
            self.fields[key] = field

    def get_projparams(self):
        """Return a copy of the projection parameters.

        The private ``_include_lon_0_lat_0`` flag is removed from the copy;
        when it is true, ``lon_0`` and ``lat_0`` are set from the grid origin.
        """
        projparams = self.projection.copy()
        if projparams.pop("_include_lon_0_lat_0", False):
            projparams["lon_0"] = self.origin_longitude["data"][0]
            projparams["lat_0"] = self.origin_latitude["data"][0]
        return projparams

    @property
    def metadata(self):
        """Attributes of the wrapped dataset."""
        return self.ds.attrs

    @property
    def ny(self):
        """Size of the ``y`` dimension of the wrapped dataset."""
        return self.ds.sizes["y"]

    @property
    def nx(self):
        """Size of the ``x`` dimension of the wrapped dataset."""
        return self.ds.sizes["x"]

    @property
    def nz(self):
        """Size of the ``z`` dimension of the wrapped dataset."""
        return self.ds.sizes["z"]

    # Attribute init/reset methods
    def init_point_x_y_z(self):
        """Initialize or reset the point_{x, y, z} attributes."""
        self.point_x = LazyLoadDict(get_metadata("point_x"))
        self.point_x.set_lazy("data", _point_data_factory(self, "x"))

        self.point_y = LazyLoadDict(get_metadata("point_y"))
        self.point_y.set_lazy("data", _point_data_factory(self, "y"))

        self.point_z = LazyLoadDict(get_metadata("point_z"))
        self.point_z.set_lazy("data", _point_data_factory(self, "z"))

    def init_point_longitude_latitude(self):
        """
        Initialize or reset the point_{longitude, latitudes} attributes.
        """
        point_longitude = LazyLoadDict(get_metadata("point_longitude"))
        point_longitude.set_lazy("data", _point_lon_lat_data_factory(self, 0))
        self.point_longitude = point_longitude

        point_latitude = LazyLoadDict(get_metadata("point_latitude"))
        point_latitude.set_lazy("data", _point_lon_lat_data_factory(self, 1))
        self.point_latitude = point_latitude

    def init_point_altitude(self):
        """Initialize the point_altitude attribute."""
        point_altitude = LazyLoadDict(get_metadata("point_altitude"))
        point_altitude.set_lazy("data", _point_altitude_data_factory(self))
        self.point_altitude = point_altitude

    def get_point_longitude_latitude(self, level=0, edges=False):
        """
        Return arrays of longitude and latitude for a given grid height level.

        Parameters
        ----------
        level : int, optional
            Grid height level at which to determine latitudes and longitudes.
            This is not currently used as all height level have the same
            layout.
        edges : bool, optional
            True to calculate the latitude and longitudes of the edges by
            interpolating between Cartesian coordinates points and
            extrapolating at the boundaries. False to calculate the locations
            at the centers.

        Returns
        -------
        longitude, latitude : 2D array
            Arrays containing the latitude and longitudes, in degrees, of the
            grid points or edges between grid points for the given height.

        """
        x = self.x["data"]
        y = self.y["data"]
        projparams = self.get_projparams()
        return cartesian_vectors_to_geographic(x, y, projparams, edges=edges)

    def add_field(self, field_name, dic, replace_existing=False):
        """
        Add a field to the object.

        Parameters
        ----------
        field_name : str
            Name of the field to add to the dictionary of fields.
        dic : dict
            Dictionary contain field data and metadata.
        replace_existing : bool, optional
            True to replace the existing field with key field_name if it
            exists, loosing any existing data. False will raise a ValueError
            when the field already exists.

        Raises
        ------
        ValueError
            If the field already exists and ``replace_existing`` is false, or
            if the data does not have shape (nz, ny, nx).
        KeyError
            If ``dic`` has no ``'data'`` key.

        """
        # check that the field dictionary to add is valid
        if field_name in self.fields and not replace_existing:
            err = "A field with name: %s already exists" % (field_name)
            raise ValueError(err)
        if "data" not in dic:
            raise KeyError("dic must contain a 'data' key")
        expected = (self.nz, self.ny, self.nx)
        if np.shape(dic["data"]) != expected:
            # Three placeholders for the three grid dimensions; two would make
            # the formatting itself fail with a TypeError.
            err = "'data' has invalid shape, should be (%i, %i, %i)" % expected
            raise ValueError(err)
        self.fields[field_name] = dic

    def to_xarray(self):
        """
        Convert the Grid object to an xarray format.

        Returns
        -------
        ds : xarray Dataset
            Dataset with one (time, z, y, x) variable per field, carrying the
            field metadata as attributes, with 1D ``time``, ``z``, ``y`` and
            ``x`` coordinates and 2D ``lat`` and ``lon`` coordinates on
            (y, x).

        """
        lon, lat = self.get_point_longitude_latitude()
        coords = {
            "time": (["time"], np.atleast_1d(self.ds["time"].values)),
            "z": (["z"], self.z["data"]),
            "lat": (["y", "x"], lat),
            "lon": (["y", "x"], lon),
            "y": (["y"], self.y["data"]),
            "x": (["x"], self.x["data"]),
        }

        # Attach the coordinates to the dataset itself so that they exist
        # even when the grid holds no fields.
        ds = Dataset(coords=coords)
        for name, field in self.fields.items():
            data = DataArray(
                np.expand_dims(field["data"], 0),
                dims=("time", "z", "y", "x"),
                coords=coords,
            )
            data.attrs.update(
                {meta: value for meta, value in field.items() if meta != "data"}
            )
            ds[name] = data

        ds.lon.attrs = {
            "long_name": "longitude of grid cell center",
            "units": "degree_E",
            "standard_name": "Longitude",
        }
        ds.lat.attrs = {
            "long_name": "latitude of grid cell center",
            "units": "degree_N",
            "standard_name": "Latitude",
        }

        ds.z.attrs = get_metadata("z")
        ds.y.attrs = get_metadata("y")
        ds.x.attrs = get_metadata("x")

        ds.z.encoding["_FillValue"] = None
        ds.lat.encoding["_FillValue"] = None
        ds.lon.encoding["_FillValue"] = None
        return ds


class Xradar:
Expand Down Expand Up @@ -427,3 +678,54 @@ def _find_fields(self, ds):
**self.combined_sweeps[field].attrs,
}
return fields


def _point_data_factory(grid, coordinate):
"""Return a function which returns the locations of all points."""

def _point_data():
"""The function which returns the locations of all points."""
reg_x = grid.x["data"]
reg_y = grid.y["data"]
reg_z = grid.z["data"]
if coordinate == "x":
return np.tile(reg_x, (len(reg_z), len(reg_y), 1)).swapaxes(2, 2)
elif coordinate == "y":
return np.tile(reg_y, (len(reg_z), len(reg_x), 1)).swapaxes(1, 2)
else:
assert coordinate == "z"
return np.tile(reg_z, (len(reg_x), len(reg_y), 1)).swapaxes(0, 2)

return _point_data


def _point_lon_lat_data_factory(grid, coordinate):
    """
    Build the lazy loader for the geographic location of every grid point.

    Parameters
    ----------
    grid : object
        Grid providing ``point_x``, ``point_y``, ``point_longitude`` and
        ``point_latitude`` dictionaries and a ``get_projparams`` method.
    coordinate : int
        Index into the result of ``cartesian_to_geographic`` to return:
        0 for longitude, 1 for latitude.

    Returns
    -------
    _point_lon_lat_data : callable
        Function without arguments returning the requested coordinate.

    """

    def _point_lon_lat_data():
        """Project the Cartesian point locations and return one coordinate."""
        lon_lat = cartesian_to_geographic(
            grid.point_x["data"],
            grid.point_y["data"],
            grid.get_projparams(),
        )
        # Store the companion coordinate on the grid as well, so that both
        # attributes come from the same projection call and the projection
        # is evaluated only once.
        wants_longitude = coordinate == 0
        if wants_longitude:
            grid.point_latitude["data"] = lon_lat[1]
        else:
            grid.point_longitude["data"] = lon_lat[0]
        return lon_lat[coordinate]

    return _point_lon_lat_data


def _point_altitude_data_factory(grid):
"""Return a function which returns the point altitudes."""

def _point_altitude_data():
"""The function which returns the point altitudes."""
return grid.origin_altitude["data"][0] + grid.point_z["data"]

return _point_altitude_data
Loading

0 comments on commit 9ce78bf

Please sign in to comment.