Skip to content

Commit

Permalink
weight generation performance improvement (reduce np.where input)
Browse files Browse the repository at this point in the history
  • Loading branch information
JoshCu committed Jan 11, 2024
1 parent d70b7d8 commit 9010c66
Showing 1 changed file with 20 additions and 6 deletions.
26 changes: 20 additions & 6 deletions forcingprocessor/src/forcingprocessor/weight_generator.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import geopandas as gpd
import numpy as np
import xarray as xr
from rasterio.features import rasterize
from rasterio.features import rasterize, bounds
import json, argparse, os
from pathlib import Path
from math import floor, ceil
import datetime

def generate_weights_file(geopackage,grid_file,weights_filepath):

Expand All @@ -20,8 +22,8 @@ def generate_weights_file(geopackage,grid_file,weights_filepath):
PARAMETER["standard_parallel_2",60.0],PARAMETER["latitude_of_origin",40.0],UNIT["Meter",1.0]]')

crosswalk_dict = {}
i = 0
for index, row in gdf_proj.iterrows():
start_time = datetime.datetime.now()
for i, row in gdf_proj.iterrows():
geom_rasterize = rasterize(
[(row["geometry"], 1)],
out_shape=grid.rio.shape,
Expand All @@ -30,13 +32,25 @@ def generate_weights_file(geopackage,grid_file,weights_filepath):
fill=0,
dtype="uint8",
)
crosswalk_dict[row["divide_id"]] = np.where(geom_rasterize == 1)
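# numpy.where runs slowly on the full-grid array, so we first crop the
# raster to the geometry's bounding box; the crop offsets are added back
# afterwards so the stored indices still refer to the full grid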
y_min, x_max, y_max, x_min = bounds(row["geometry"], transform=~grid.rio.transform())
x_min = floor(x_min)
x_max = ceil(x_max)
y_min = floor(y_min)
y_max = ceil(y_max)
geom_rasterize = geom_rasterize[x_min:x_max, y_min:y_max]
localized_coords = np.where(geom_rasterize == 1)
global_coords = (localized_coords[0] + x_min, localized_coords[1] + y_min)

crosswalk_dict[row["divide_id"]] = global_coords

if i % 100 == 0:
perc = i / len(gdf_proj) * 100
print(f"{i}, {perc:.2f}%".ljust(40), end="\r")
i += 1
elapsed_time = datetime.datetime.now() - start_time
estimated_remaining_time = elapsed_time / i+1 * (len(gdf_proj) - i)
print(f"{i}, {perc:.2f}% elapsed: {elapsed_time}, remaining: {estimated_remaining_time}".ljust(40), end="\r")


weights_json = json.dumps(
{k: [x.tolist() for x in v] for k, v in crosswalk_dict.items()}
Expand Down

0 comments on commit 9010c66

Please sign in to comment.