Commit: Add script to extend GCP zarrs

jacobbieker committed Apr 20, 2023
1 parent 4e752bb commit f426197
Showing 1 changed file with 235 additions and 0 deletions.
scripts/extend_gcp_zarr.py
@@ -0,0 +1,235 @@
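"""Extend the GCP HRV and non-HRV satellite Zarr archives with new EUMETSAT data.

Downloads the native SEVIRI files covering the gap between the last timestamp
already in the Zarr stores and now, rescales them to [0, 1], appends them in
12-timestep chunks, and rewrites the time coordinate metadata so the extended
stores open cleanly.
"""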
import glob
import json
import os
import shutil

import numpy as np
import ocf_blosc2  # noqa: F401 - imported for its side effects (codec registration)
import pandas as pd
import xarray as xr
from satpy import Scene

from satip.eumetsat import DownloadManager
from satip.jpeg_xl_float_with_nans import JpegXlFloatWithNaNs  # noqa: F401 - codec registration
from satip.scale_to_zero_to_one import ScaleToZeroToOne
from satip.serialize import serialize_attrs
from satip.utils import convert_scene_to_dataarray


def download_data(last_zarr_time):
    """Download native files covering the period from ``last_zarr_time`` to now."""
    api_key = os.environ["SAT_API_KEY"]
    api_secret = os.environ["SAT_API_SECRET"]
    download_manager = DownloadManager(user_key=api_key, user_secret=api_secret, data_dir="./")
    # Walk forward one day at a time from the last timestamp already in the Zarr store
    date_range = pd.date_range(
        start=pd.Timestamp(last_zarr_time, tz="UTC"),
        end=pd.Timestamp.utcnow(),
        freq="1D",
    )
    for date in date_range:
        start_date = pd.Timestamp(date) - pd.Timedelta("1min")
        end_date = pd.Timestamp(date) + pd.Timedelta("1min")
        datasets = download_manager.identify_available_datasets(
            start_date=start_date.strftime("%Y-%m-%d-%H-%M-%S"),
            end_date=end_date.strftime("%Y-%m-%d-%H-%M-%S"),
        )
        download_manager.download_datasets(datasets)

    # Get the downloaded native files in time order
    native_files = sorted(glob.glob("./native_files/*"))
    return native_files


def preprocess_function(xr_data: xr.Dataset) -> xr.Dataset:
    """Store the x/y geostationary coordinates as per-timestep data variables.

    The coordinates are given a leading time dimension so they can vary across
    appended timesteps, since the HRV scan region can shift.
    """
    attrs = xr_data.attrs
    y_coords = xr_data.coords["y_geostationary"].values
    x_coords = xr_data.coords["x_geostationary"].values
    x_dataarray = xr.DataArray(
        data=np.expand_dims(x_coords, axis=0),
        dims=["time", "x_geostationary"],
        coords=dict(time=xr_data.coords["time"].values, x_geostationary=x_coords),
    )
    y_dataarray = xr.DataArray(
        data=np.expand_dims(y_coords, axis=0),
        dims=["time", "y_geostationary"],
        coords=dict(time=xr_data.coords["time"].values, y_geostationary=y_coords),
    )
    xr_data["x_geostationary_coordinates"] = x_dataarray
    xr_data["y_geostationary_coordinates"] = y_dataarray
    xr_data.attrs = attrs
    return xr_data


def open_and_scale_data_hrv(zarr_times, f):
    """Open the native file ``f``, scale the HRV channel, and return it as a dataset.

    Returns None if the file's timestamp is already present in ``zarr_times``.
    """
    hrv_scaler = ScaleToZeroToOne(
        variable_order=["HRV"], maxs=np.array([103.90016]), mins=np.array([-1.2278595])
    )

    hrv_scene = Scene(filenames={"seviri_l1b_native": [f]})
    hrv_scene.load(["HRV"])
    hrv_dataarray: xr.DataArray = convert_scene_to_dataarray(
        hrv_scene, band="HRV", area="RSS", calculate_osgb=False
    )
    attrs = serialize_attrs(hrv_dataarray.attrs)
    hrv_dataarray = hrv_scaler.rescale(hrv_dataarray)
    hrv_dataarray.attrs.update(attrs)
    hrv_dataarray = hrv_dataarray.transpose(
        "time", "y_geostationary", "x_geostationary", "variable"
    )
    hrv_dataset = hrv_dataarray.to_dataset(name="data")
    hrv_dataset["data"] = hrv_dataset["data"].astype(np.float16)

    # Skip files whose timestamp is already in the Zarr store
    if hrv_dataset.time.values[0] in zarr_times:
        print("Skipping")
        return None

    return hrv_dataset


def open_and_scale_data_nonhrv(zarr_times, f):
    """Open the native file ``f``, scale the non-HRV channels, and return them as a dataset.

    Returns None if the file's timestamp is already present in ``zarr_times``.
    """
    non_hrv_bands = [
        "IR_016",
        "IR_039",
        "IR_087",
        "IR_097",
        "IR_108",
        "IR_120",
        "IR_134",
        "VIS006",
        "VIS008",
        "WV_062",
        "WV_073",
    ]
    # Per-channel scaling constants, in the same order as non_hrv_bands
    scaler = ScaleToZeroToOne(
        mins=np.array(
            [
                -2.5118103,
                -64.83977,
                63.404694,
                2.844452,
                199.10002,
                -17.254883,
                -26.29155,
                -1.1009827,
                -2.4184198,
                199.57048,
                198.95093,
            ]
        ),
        maxs=np.array(
            [
                69.60857,
                339.15588,
                340.26526,
                317.86752,
                313.2767,
                315.99194,
                274.82297,
                93.786545,
                101.34922,
                249.91806,
                286.96323,
            ]
        ),
        variable_order=non_hrv_bands,
    )

    scene = Scene(filenames={"seviri_l1b_native": [f]})
    scene.load(non_hrv_bands)
    dataarray: xr.DataArray = convert_scene_to_dataarray(
        scene, band="IR_016", area="RSS", calculate_osgb=False
    )
    attrs = serialize_attrs(dataarray.attrs)
    dataarray = scaler.rescale(dataarray)
    dataarray.attrs.update(attrs)
    dataarray = dataarray.transpose("time", "y_geostationary", "x_geostationary", "variable")
    dataset = dataarray.to_dataset(name="data")
    dataset["data"] = dataset["data"].astype(np.float16)

    # Skip files whose timestamp is already in the Zarr store
    if dataset.time.values[0] in zarr_times:
        print("Skipping")
        return None

    return dataset


def write_to_zarr(dataset, zarr_name, mode, chunks):
    mode_extra_kwargs = {
        "a": {"append_dim": "time"},
    }
    extra_kwargs = mode_extra_kwargs[mode]
    dataset.chunk(chunks).to_zarr(
        zarr_name, compute=True, **extra_kwargs, consolidated=True, mode=mode
    )


def rewrite_zarr_times(output_name):
    """Rewrite the time coordinate of the Zarr store as a single unchunked array."""
    # Open the store and keep only the coordinates
    ds = xr.open_zarr(output_name)
    del ds["data"]
    # Remove the chunk encodings so the time coordinate is written as one chunk
    del ds.time.encoding["chunks"]
    del ds.time.encoding["preferred_chunks"]
    ds.to_zarr(f"{output_name.split('.zarr')[0]}_coord.zarr", consolidated=True)
    # Replace the store's time arrays with the newly written, unchunked ones
    shutil.rmtree(f"{output_name}/time/")
    shutil.copytree(f"{output_name.split('.zarr')[0]}_coord.zarr/time", f"{output_name}/time")

    # Patch the consolidated .zmetadata so the time array metadata matches the new coord store
    with open(f"{output_name}/.zmetadata", "r") as f:
        data = json.load(f)
    with open(f"{output_name.split('.zarr')[0]}_coord.zarr/.zmetadata", "r") as f2:
        coord_data = json.load(f2)
    data["metadata"]["time/.zarray"] = coord_data["metadata"]["time/.zarray"]
    with open(f"{output_name}/.zmetadata", "w") as f:
        json.dump(data, f)


if __name__ == "__main__":
    zarr_path = "/mnt/disks/data/2023_hrv.zarr"
    non_zarr_path = "/mnt/disks/data/2023_nonhrv.zarr"
    zarr_times = xr.open_zarr(non_zarr_path).sortby("time").time.values
    hrv_zarr_times = xr.open_zarr(zarr_path).sortby("time").time.values
    last_zarr_time = zarr_times[-1]
    native_files = download_data(last_zarr_time)

    # Append the non-HRV data in batches of 12 timesteps, one chunk per write
    datasets = []
    for f in native_files:
        dataset = open_and_scale_data_nonhrv(zarr_times, f)
        if dataset is not None:
            datasets.append(dataset)
        if len(datasets) == 12:
            write_to_zarr(xr.concat(datasets, dim="time"), non_zarr_path, "a", chunks={"time": 12})
            datasets = []

    # Append the HRV data in the same way; note that any trailing batch
    # smaller than 12 timesteps is not written
    hrv_datasets = []
    for f in native_files:
        dataset = open_and_scale_data_hrv(hrv_zarr_times, f)
        if dataset is not None:
            dataset = preprocess_function(dataset)
            hrv_datasets.append(dataset)
        if len(hrv_datasets) == 12:
            write_to_zarr(
                xr.concat(hrv_datasets, dim="time"), zarr_path, "a", chunks={"time": 12}
            )
            hrv_datasets = []

    # Finally, rewrite the time coordinate of each extended store
    rewrite_zarr_times(non_zarr_path)
    rewrite_zarr_times(zarr_path)
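A minimal usage sketch, assuming the EUMETSAT credentials are supplied through
the SAT_API_KEY and SAT_API_SECRET environment variables read in download_data,
and that the hardcoded /mnt/disks/data Zarr stores already exist:

    export SAT_API_KEY=...
    export SAT_API_SECRET=...
    python scripts/extend_gcp_zarr.py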
