Memory management and Visualisation #53

Merged
merged 9 commits on Oct 31, 2024
3 changes: 2 additions & 1 deletion .gitignore
@@ -10,4 +10,5 @@ dist
*.egg-info
**/tiles/tms*
**/tiles/vpu*
*.tar.gz
*.tar.gz
*.dat
15 changes: 12 additions & 3 deletions modules/data_processing/create_realization.py
@@ -7,7 +7,8 @@
import json
import s3fs
import xarray as xr

from tqdm.rich import tqdm
from dask.distributed import Client, LocalCluster
from data_processing.file_paths import file_paths
from data_processing.gpkg_utils import get_cat_to_nex_flowpairs, get_cat_to_nhd_feature_id

@@ -21,11 +22,20 @@ def get_approximate_gw_storage(paths: file_paths, start_date: datetime):

fs = s3fs.S3FileSystem(anon=True)
nc_url = f"s3://noaa-nwm-retrospective-3-0-pds/CONUS/netcdf/GWOUT/{year}/{formatted_dt}.GWOUT_DOMAIN1"

# make sure there's a dask cluster running
try:
client = Client.current()
except ValueError:
cluster = LocalCluster()
client = Client(cluster)


with fs.open(nc_url) as file_obj:
ds = xr.open_dataset(file_obj)

water_levels = dict()
for cat, feature in cat_to_feature.items():
for cat, feature in tqdm(cat_to_feature.items()):
# this value is in cm; we need meters to match max_gw_depth
# xarray reports mm with a 0.1 scale factor, but calling .values doesn't apply the scale
water_level = ds.sel(feature_id=feature).depth.values / 100
@@ -129,7 +139,6 @@ def make_ngen_realization_json(
realization["time"]["start_time"] = start_time.strftime("%Y-%m-%d %H:%M:%S")
realization["time"]["end_time"] = end_time.strftime("%Y-%m-%d %H:%M:%S")
realization["time"]["output_interval"] = 3600
realization["time"]["nts"] = nts

with open(config_dir / "realization.json", "w") as file:
json.dump(realization, file, indent=4)
243 changes: 142 additions & 101 deletions modules/data_processing/forcings.py
@@ -6,11 +6,12 @@
from pathlib import Path
from datetime import datetime
from functools import partial
import psutil
from math import ceil

from tqdm.rich import tqdm
import numpy as np
import dask
from dask.distributed import Client, LocalCluster, progress
import geopandas as gpd
import pandas as pd
import xarray as xr
@@ -21,26 +22,28 @@
from data_processing.zarr_utils import get_forcing_data

logger = logging.getLogger(__name__)
# Suppress the specific warning from numpy
# Suppress the specific warning from numpy to keep the cli output clean
warnings.filterwarnings(
"ignore", message="'DataFrame.swapaxes' is deprecated", category=FutureWarning
)
warnings.filterwarnings(
"ignore", message="'GeoDataFrame.swapaxes' is deprecated", category=FutureWarning
)


def weighted_sum_of_cells(flat_tensor, cell_ids, factors):
def weighted_sum_of_cells(flat_raster: np.ndarray, cell_ids: np.ndarray, factors: np.ndarray):
# Create an output array initialized with zeros
result = np.zeros(flat_tensor.shape[0])
result = np.sum(flat_tensor[:, cell_ids] * factors, axis=1)
# dimensions are raster[time][x*y]
result = np.zeros(flat_raster.shape[0])
result = np.sum(flat_raster[:, cell_ids] * factors, axis=1)
sum_of_weights = np.sum(factors)
result /= sum_of_weights
return result
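# A minimal sketch (not part of this PR) of what weighted_sum_of_cells computes,
# using a hypothetical 2x3 toy raster (2 timesteps, 3 flattened cells):
#   raster = np.arange(6, dtype=np.float32).reshape(2, 3)
#   weighted_sum_of_cells(raster, np.array([0, 2]), np.array([0.25, 0.75]))
#   # -> array([1.5, 4.5]), the per-timestep weighted mean of cells 0 and 2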


def get_cell_weights(raster, gdf):
# Get the cell weights for each divide
output = exact_extract(
raster["LWDOWN"],
raster["RAINRATE"],
gdf,
["cell_id", "coverage"],
include_cols=["divide_id"],
@@ -57,24 +60,40 @@ def add_APCP_SURFACE_to_dataset(dataset: xr.Dataset) -> xr.Dataset:
return dataset


def save_to_csv(catchment_ds, csv_path):
catchment_df = catchment_ds.to_dataframe().drop(["catchment"], axis=1)
catchment_df.to_csv(csv_path)
return csv_path
def get_index_chunks(data: xr.DataArray) -> list[tuple[int, int]]:
# takes a data array and calculates the start and end indices for each chunk,
# based on the available memory.
array_memory_usage = data.nbytes
free_memory = psutil.virtual_memory().available * 0.8 # 80% of available memory
num_chunks = ceil(array_memory_usage / free_memory)
max_index = data.shape[0]
stride = max_index // num_chunks
chunk_start = range(0, max_index, stride)
index_chunks = [(start, start + stride) for start in chunk_start]
return index_chunks
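# A minimal usage sketch (not part of this PR), assuming a forcing variable larger than RAM;
# "process" is a hypothetical placeholder for per-chunk work:
#   for start, end in get_index_chunks(merged_data["T2D"]):
#       chunk = merged_data["T2D"].isel(time=slice(start, end))  # sized to fit in ~80% of free memory
#       process(chunk)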


def create_shared_memory(lazy_array):
logger.debug(f"Creating shared memory size {lazy_array.nbytes/ 10**6} Mb.")
shm = shared_memory.SharedMemory(create=True, size=lazy_array.nbytes)
shared_array = np.ndarray(lazy_array.shape, dtype=np.float32, buffer=shm.buf)
# if your data is not float32, xarray will do an automatic conversion here,
# which consumes a lot more memory; forcings downloaded with this tool will work as-is
for start, end in get_index_chunks(lazy_array):
# copy data from lazy to shared memory one chunk at a time
shared_array[start:end] = lazy_array[start:end]

time, x, y = shared_array.shape
shared_array = shared_array.reshape(time, -1)

def create_shared_memory(data):
shm = shared_memory.SharedMemory(create=True, size=data.nbytes)
shared_array = np.ndarray(data.shape, dtype=data.dtype, buffer=shm.buf)
shared_array[:] = data[:]
return shm, shared_array
return shm, shared_array.shape, shared_array.dtype
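# Sketch (not part of this PR) of the intended lifecycle, assuming the caller owns cleanup,
# as compute_zonal_stats does below:
#   shm, shape, dtype = create_shared_memory(data_chunk)
#   ...  # workers attach via shared_memory.SharedMemory(name=shm.name)
#   shm.close()
#   shm.unlink()  # release the segment once all workers are done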


def process_chunk_shared(variable, times, shm_name, shape, dtype, chunk):
existing_shm = shared_memory.SharedMemory(name=shm_name)
raster = np.ndarray(shape, dtype=dtype, buffer=existing_shm.buf)

results = []

for catchment in chunk.index.unique():
cell_ids = chunk.loc[catchment]["cell_id"]
weights = chunk.loc[catchment]["coverage"]
@@ -87,11 +106,20 @@ def process_chunk_shared(variable, times, shm_name, shape, dtype, chunk):
)
temp_da = temp_da.assign_coords(catchment=catchment)
results.append(temp_da)

existing_shm.close()
return xr.concat(results, dim="catchment")


def get_cell_weights_parallel(gdf, input_forcings, num_partitions):
gdf_chunks = np.array_split(gdf, num_partitions)
one_timestep = input_forcings.isel(time=0).compute()
with multiprocessing.Pool() as pool:
args = [(one_timestep, gdf_chunk) for gdf_chunk in gdf_chunks]
catchments = pool.starmap(get_cell_weights, args)

return pd.concat(catchments)


def compute_zonal_stats(
gdf: gpd.GeoDataFrame, merged_data: xr.Dataset, forcings_dir: Path
) -> None:
@@ -100,102 +128,117 @@ def compute_zonal_stats(
num_partitions = multiprocessing.cpu_count() - 1
if num_partitions > len(gdf):
num_partitions = len(gdf)
gfd_chunks = np.array_split(gdf, multiprocessing.cpu_count() - 1)
one_timestep = merged_data.isel(time=0).compute()
with multiprocessing.Pool() as pool:
args = [(one_timestep, gdf_chunk) for gdf_chunk in gfd_chunks]
catchments = pool.starmap(get_cell_weights, args)

catchments = pd.concat(catchments)
catchments = get_cell_weights_parallel(gdf, merged_data, num_partitions)
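# catchments is a DataFrame indexed by divide_id, holding the raster cell_ids and coverage
# fractions produced by exact_extract for each catchment (the set_index step presumably
# lives in the unchanged part of get_cell_weights collapsed in this diff)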

variables = [
"LWDOWN",
"PSFC",
"Q2D",
"RAINRATE",
"SWDOWN",
"T2D",
"U2D",
"V2D",
]
variables = {
"LWDOWN": "DLWRF_surface",
"PSFC": "PRES_surface",
"Q2D": "SPFH_2maboveground",
"RAINRATE": "precip_rate",
"SWDOWN": "DSWRF_surface",
"T2D": "TMP_2maboveground",
"U2D": "UGRD_10maboveground",
"V2D": "VGRD_10maboveground",
}
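# keys are the NWM forcing variable names in the source data; values are the output names
# used for ngen forcings (mostly AORC-style), applied later in write_outputs via rename_vars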

results = []
cat_chunks = np.array_split(catchments, num_partitions)
forcing_times = merged_data.time.values

for variable in variables.keys():

if variable not in merged_data.data_vars:
logger.warning(f"Variable {variable} not in forcings, skipping")
continue

# to make sure this fits in memory, we need to chunk the data
time_chunks = get_index_chunks(merged_data[variable])

for i, times in enumerate(time_chunks):
start, end = times
# select the chunk of time we want to process
data_chunk = merged_data[variable].isel(time=slice(start, end))
# put it in shared memory
shm, shape, dtype = create_shared_memory(data_chunk)
times = data_chunk.time.values
# create a partial function to pass to the multiprocessing pool
partial_process_chunk = partial(process_chunk_shared, variable, times, shm.name, shape, dtype)

logger.debug(f"Processing variable: {variable}")
# process the chunks of catchments in parallel
with multiprocessing.Pool(num_partitions) as pool:
variable_data = pool.map(partial_process_chunk, cat_chunks)
del partial_process_chunk
# clean up the shared memory
shm.close()
shm.unlink()
logger.debug(f"Processed variable: {variable}")
concatenated_da = xr.concat(variable_data, dim="catchment")
# delete the data to free up memory
del variable_data
logger.debug(f"Concatenated variable: {variable}")
# write this to disk now to save memory
# xarray will monitor memory usage, but it doesn't account for the shared memory used to store the raster
# This reduces memory usage by about 60%
concatenated_da.to_dataset(name=variable).to_netcdf(forcings_dir / "temp" / f"{variable}_{i}.nc")
# Merge the chunks back together
datasets = [xr.open_dataset(forcings_dir / "temp" / f"{variable}_{i}.nc") for i in range(len(time_chunks))]
xr.concat(datasets, dim="time").to_netcdf(forcings_dir / f"{variable}.nc")
for file in forcings_dir.glob("temp/*.nc"):
file.unlink()
logger.info(
f"Forcing generation complete! Zonal stats computed in {time.time() - timer_start} seconds"
)
write_outputs(forcings_dir, variables)

for variable in variables:
raster = merged_data[variable].values.reshape(merged_data[variable].shape[0], -1)

# Create shared memory for the raster
shm, shared_raster = create_shared_memory(raster)

cat_chunks = np.array_split(catchments, num_partitions)
times = merged_data.time.values

partial_process_chunk = partial(
process_chunk_shared,
variable,
times,
shm.name,
shared_raster.shape,
shared_raster.dtype,
)

logger.debug(f"Processing variable: {variable}")
with multiprocessing.Pool(num_partitions) as pool:
variable_data = pool.map(partial_process_chunk, cat_chunks)

# Clean up the shared memory
shm.close()
shm.unlink()

logger.debug(f"Processed variable: {variable}")
concatenated_da = xr.concat(variable_data, dim="catchment")
logger.debug(f"Concatenated variable: {variable}")
results.append(concatenated_da.to_dataset(name=variable))
def write_outputs(forcings_dir, variables):

# Combine all variables into a single dataset
results = [xr.open_dataset(file) for file in forcings_dir.glob("*.nc")]
final_ds = xr.merge(results)

output_folder = forcings_dir / "by_catchment"
# Clear out any existing files
for file in output_folder.glob("*.csv"):
file.unlink()

final_ds = final_ds.rename_vars(
{
"LWDOWN": "DLWRF_surface",
"PSFC": "PRES_surface",
"Q2D": "SPFH_2maboveground",
"RAINRATE": "precip_rate",
"SWDOWN": "DSWRF_surface",
"T2D": "TMP_2maboveground",
"U2D": "UGRD_10maboveground",
"V2D": "VGRD_10maboveground",
}
)
rename_dict = {}
for key, value in variables.items():
if key in final_ds:
rename_dict[key] = value

final_ds = final_ds.rename_vars(rename_dict)
final_ds = add_APCP_SURFACE_to_dataset(final_ds)

logger.info("Saving to disk")
# Save to disk
delayed_saves = []
for catchment in final_ds.catchment.values:
catchment_ds = final_ds.sel(catchment=catchment)
csv_path = output_folder / f"{catchment}.csv"
delayed_save = dask.delayed(save_to_csv)(catchment_ds, csv_path)
delayed_saves.append(delayed_save)
logger.debug("Delayed saves created")
try:
client = Client.current()
except ValueError:
cluster = LocalCluster()
client = Client(cluster)

dask.compute(*delayed_saves)
# this step halves the storage size of the forcings
for var in final_ds.data_vars:
final_ds[var] = final_ds[var].astype(np.float32)

logger.info(
f"Forcing generation complete! Zonal stats computed in {time.time() - timer_start} seconds"
)
client.shutdown()
logger.info("Saving to disk")
# The netcdf layout follows a legacy format, which is why it's a little "unorthodox":
# there are no coordinates, just dimensions; catchment ids are stored in a 1d data var,
# and time is stored in a 2d data var with the same time array repeated for every catchment.
# Time is stored as unix timestamps, so the units attribute has to be set.

# add the catchment ids as a 1d data var
final_ds["ids"] = final_ds["catchment"].astype(str)
# time needs to be a 2d array: the same unix-timestamp time array repeated for every catchment
with warnings.catch_warnings(action="ignore"):
time_array = (
final_ds.time.astype("datetime64[s]").astype(np.int64).values // 10**9
) ## convert from ns to s
time_array = time_array.astype(np.int32) ## convert to int32 to save space
final_ds = final_ds.drop_vars(["catchment", "time"]) ## drop the original time and catchment vars
final_ds = final_ds.rename_dims({"catchment": "catchment-id"}) # rename the catchment dimension
# add the time as a 2d data var; yes, this wastes disk space.
final_ds["Time"] = (("catchment-id", "time"), [time_array for _ in range(len(final_ds["ids"]))])
# set the time unit
final_ds["Time"].attrs["units"] = "s"
final_ds["Time"].attrs["epoch_start"] = "01/01/1970 00:00:00" # not needed but suppresses the ngen warning

final_ds.to_netcdf(output_folder / "forcings.nc", engine="netcdf4")
# delete the individual variable files
for file in forcings_dir.glob("*.nc"):
file.unlink()
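# A rough sketch (not part of this PR) of what the saved file looks like when read back:
#   ds = xr.open_dataset(output_folder / "forcings.nc")
#   ds["ids"]          # 1d catchment ids along the "catchment-id" dimension
#   ds["Time"]         # 2d (catchment-id, time) unix timestamps in seconds, identical per row
#   ds["precip_rate"]  # 2d (catchment-id, time) float32 forcing values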


def setup_directories(cat_id: str) -> file_paths:
@@ -212,8 +255,6 @@ def create_forcings(start_time: str, end_time: str, output_folder_name: str) ->

gdf = gpd.read_file(forcing_paths.geopackage_path, layer="divides").to_crs(projection)
logger.debug(f"gdf bounds: {gdf.total_bounds}")
logger.debug(gdf)
logger.debug("Got gdf")

if type(start_time) == datetime:
start_time = start_time.strftime("%Y-%m-%d %H:%M")