Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speed up weight generation #36

Merged
merged 2 commits into from
Jan 11, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 44 additions & 18 deletions forcingprocessor/src/forcingprocessor/weight_generator.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,36 @@
import geopandas as gpd
import numpy as np
import xarray as xr
from rasterio.features import rasterize
from rasterio.features import rasterize, bounds
import json, argparse, os
from pathlib import Path
from math import floor, ceil
import datetime
import multiprocessing
from functools import partial

def process_row(row, grid):
    """Compute the grid-cell indices covered by one catchment polygon.

    Parameters
    ----------
    row : mapping
        A GeoDataFrame row (or any mapping) with a shapely ``"geometry"``
        and a ``"divide_id"`` key.
    grid : DataArray/Dataset with a rioxarray ``rio`` accessor
        Supplies ``rio.shape`` and ``rio.transform()`` for the target raster.

    Returns
    -------
    tuple
        ``(divide_id, (row_indices, col_indices))`` — index arrays locating
        every grid cell touched by the polygon.
    """
    # Burn the polygon onto a full-size raster mask (1 = touched, 0 = not).
    mask = rasterize(
        [(row["geometry"], 1)],
        out_shape=grid.rio.shape,
        transform=grid.rio.transform(),
        all_touched=True,
        fill=0,
        dtype="uint8",
    )

    # np.where over the full raster is slow on large grids, so restrict the
    # search to the polygon's pixel-space bounding box.  ``bounds`` with the
    # inverse transform returns (left, bottom, right, top) in pixel
    # coordinates; for a north-up grid that is (col_min, row_max, col_max,
    # row_min) because geographic "bottom" maps to the largest row index.
    col_min, row_max, col_max, row_min = bounds(
        row["geometry"], transform=~grid.rio.transform()
    )
    # Clamp to the grid extent: a polygon spilling past the raster edge would
    # otherwise yield negative floor() indices, silently wrapping the slice
    # and shifting every reported coordinate.
    n_rows, n_cols = grid.rio.shape
    row_min = max(0, floor(row_min))
    row_max = min(n_rows, ceil(row_max))
    col_min = max(0, floor(col_min))
    col_max = min(n_cols, ceil(col_max))

    window = mask[row_min:row_max, col_min:col_max]
    local = np.where(window == 1)
    # Shift window-local indices back into full-grid coordinates.
    global_coords = (local[0] + row_min, local[1] + col_min)

    return (row["divide_id"], global_coords)

def generate_weights_file(geopackage,grid_file,weights_filepath):

Expand All @@ -20,27 +47,26 @@ def generate_weights_file(geopackage,grid_file,weights_filepath):
PARAMETER["standard_parallel_2",60.0],PARAMETER["latitude_of_origin",40.0],UNIT["Meter",1.0]]')

crosswalk_dict = {}
i = 0
for index, row in gdf_proj.iterrows():
geom_rasterize = rasterize(
[(row["geometry"], 1)],
out_shape=grid.rio.shape,
transform=grid.rio.transform(),
all_touched=True,
fill=0,
dtype="uint8",
)
crosswalk_dict[row["divide_id"]] = np.where(geom_rasterize == 1)


if i % 100 == 0:
perc = i / len(gdf_proj) * 100
print(f"{i}, {perc:.2f}%".ljust(40), end="\r")
i += 1
start_time = datetime.datetime.now()
print(f'Starting at {start_time}')
rows = [row for _, row in gdf_proj.iterrows()]
# Create a multiprocessing pool
with multiprocessing.Pool() as pool:
# Use a partial function to pass the constant 'grid' argument
func = partial(process_row, grid=grid)
# Map the function across all rows
results = pool.map(func, rows)

# Aggregate results
for divide_id, global_coords in results:
crosswalk_dict[divide_id] = global_coords


weights_json = json.dumps(
{k: [x.tolist() for x in v] for k, v in crosswalk_dict.items()}
)
print(f'Finished at {datetime.datetime.now()}')
print(f'Total time: {datetime.datetime.now() - start_time}')
with open(weights_filepath, "w") as f:
f.write(weights_json)

Expand Down