fix forcing generation out of memory error
JoshCu committed Jan 6, 2025
1 parent 02f5b60 commit 3aa40e2
Showing 2 changed files with 21 additions and 8 deletions.
27 changes: 20 additions & 7 deletions modules/data_processing/forcings.py
@@ -10,6 +10,8 @@
from multiprocessing import shared_memory
from pathlib import Path

from dask.distributed import Client, LocalCluster

import geopandas as gpd
import numpy as np
import pandas as pd
@@ -68,18 +70,22 @@ def get_cell_weights(raster, gdf, wkt):


def add_APCP_SURFACE_to_dataset(dataset: xr.Dataset) -> xr.Dataset:
# precip rate is mm/s
# cfe says input m/h
dataset["APCP_surface"] = dataset["precip_rate"] * 3600 / 1000
# technically should be kg/m^2/h, at 1kg = 1l it equates to mm/h
# precip_rate is mm/s
# cfe says input atmosphere_water__liquid_equivalent_precipitation_rate is mm/h
# nom says prcpnonc input is mm/s
# technically should be kg/m^2/s at 1kg = 1l it equates to mm/s
# nom says qinsur output is m/s, hopefully qinsur is converted to mm/h by ngen
dataset["APCP_surface"] = dataset["precip_rate"] * 3600
return dataset


def get_index_chunks(data: xr.DataArray) -> list[tuple[int, int]]:
# takes a data array and calculates the start and end index for each chunk
# based on the available memory.
array_memory_usage = data.nbytes
free_memory = psutil.virtual_memory().available * 0.8 # 80% of available memory
# free_memory = psutil.virtual_memory().available * 0.8 # 80% of available memory
# limit the chunk to 20gb, makes things more stable
free_memory = min(free_memory, 20 * 1024 * 1024 * 1024)
num_chunks = ceil(array_memory_usage / free_memory)
max_index = data.shape[0]
stride = max_index // num_chunks
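
This hunk changes two things: the precipitation conversion now multiplies precip_rate (mm/s) by 3600 to get APCP_surface in mm/h (e.g. 0.0005 mm/s becomes 1.8 mm/h), and the per-chunk memory budget in get_index_chunks is capped at 20 GB. A minimal standalone sketch of the capped chunking helper, under the assumption that free_memory is still derived from psutil before being clamped and that the (start, end) pairs are built the same way as in the unchanged remainder of the function (the name index_chunks and the synthetic DataArray are illustrative only):

from math import ceil

import numpy as np
import psutil
import xarray as xr


def index_chunks(data: xr.DataArray, cap_bytes: int = 20 * 1024**3) -> list[tuple[int, int]]:
    # Budget each chunk at 80% of available RAM, but never more than the 20 GB cap.
    free_memory = min(psutil.virtual_memory().available * 0.8, cap_bytes)
    num_chunks = ceil(data.nbytes / free_memory)
    # Walk the leading (time) dimension in fixed strides; the last pair absorbs any remainder.
    stride = max(data.shape[0] // num_chunks, 1)
    return [
        (start, min(start + stride, data.shape[0]))
        for start in range(0, data.shape[0], stride)
    ]


if __name__ == "__main__":
    da = xr.DataArray(np.zeros((8760, 100, 100), dtype=np.float32))
    print(index_chunks(da))  # one (start, end) pair per memory-sized slab

Capping the budget rather than trusting 80% of free RAM keeps chunk sizes stable on large machines, which lines up with the out-of-memory error named in the commit message.
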
@@ -233,8 +239,15 @@ def compute_zonal_stats(

def write_outputs(forcings_dir, variables):

# Combine all variables into a single dataset
results = [xr.open_dataset(file) for file in forcings_dir.glob("*.nc")]
# start a dask cluster if there isn't one already running
try:
client = Client.current()
except ValueError:
cluster = LocalCluster()
client = Client(cluster)

# Combine all variables into a single dataset using dask
results = [xr.open_dataset(file, chunks="auto") for file in forcings_dir.glob("*.nc")]
final_ds = xr.merge(results)

output_folder = forcings_dir / "by_catchment"
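
The write_outputs change follows a standard dask pattern: reuse the current distributed client if one is already registered, otherwise start a LocalCluster, then open the per-variable NetCDF files with chunks="auto" so the merge stays lazy instead of loading everything into memory. A minimal sketch of that pattern (the merge_forcings name and forcings_dir argument are illustrative; the real function goes on to write per-catchment outputs not shown in this hunk):

from pathlib import Path

import xarray as xr
from dask.distributed import Client, LocalCluster


def merge_forcings(forcings_dir: Path) -> xr.Dataset:
    # Reuse an already-running dask client; Client.current() raises ValueError if none exists.
    try:
        client = Client.current()
    except ValueError:
        client = Client(LocalCluster())

    # chunks="auto" keeps each file dask-backed, so xr.merge builds a lazy graph
    # that is executed in bounded-memory pieces instead of all at once.
    datasets = [xr.open_dataset(f, chunks="auto") for f in forcings_dir.glob("*.nc")]
    return xr.merge(datasets)

Because the distributed client registers itself as the default scheduler, later computations on the merged dataset run on the cluster without the client being passed around explicitly.
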
2 changes: 1 addition & 1 deletion modules/data_processing/s3fs_utils.py
@@ -36,7 +36,7 @@ async def _cat_file(
# Fall back to single request if HEAD fails
return await self._download_chunk(bucket, key, {}, version_kw)

CHUNK_SIZE = 1 * 1024 * 1024 # 1MB chunks
CHUNK_SIZE = 5 * 1024 * 1024 # 5MB chunks
if obj_size <= CHUNK_SIZE:
return await self._download_chunk(bucket, key, {}, version_kw)

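
The s3fs_utils change only bumps the per-request chunk size from 1 MB to 5 MB; objects at or below CHUNK_SIZE are still fetched in a single request, and larger ones are split into ranged downloads. A small sketch of how that split plays out (byte_ranges is an illustrative helper, not part of the module; the real download path goes through _download_chunk):

CHUNK_SIZE = 5 * 1024 * 1024  # 5MB per ranged request


def byte_ranges(obj_size: int, chunk_size: int = CHUNK_SIZE) -> list[tuple[int, int]]:
    # Split an object into (start, end) byte offsets, one pair per request.
    return [
        (start, min(start + chunk_size, obj_size))
        for start in range(0, obj_size, chunk_size)
    ]


if __name__ == "__main__":
    # A 100 MB object: 100 requests at the old 1 MB chunk size, 20 at 5 MB.
    print(len(byte_ranges(100 * 1024 * 1024)))  # 20

Fewer, larger requests cut per-request overhead when pulling forcing data from S3, at the cost of slightly more memory per in-flight chunk.
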
