Skip to content

Commit

Permalink
Merge pull request CIROH-UA#36 from JoshCu/main
Browse files Browse the repository at this point in the history
Speed up weight generation
  • Loading branch information
JordanLaserGit authored Jan 11, 2024
2 parents e276452 + ffd2032 commit e70bade
Showing 1 changed file with 44 additions and 18 deletions.
62 changes: 44 additions & 18 deletions forcingprocessor/src/forcingprocessor/weight_generator.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,36 @@
import geopandas as gpd
import numpy as np
import xarray as xr
from rasterio.features import rasterize
from rasterio.features import rasterize, bounds
import json, argparse, os
from pathlib import Path
from math import floor, ceil
import datetime
import multiprocessing
from functools import partial

def process_row(row, grid):
    """Rasterize one divide polygon onto the grid and return its cell indices.

    Parameters
    ----------
    row : mapping / GeoDataFrame row
        Must provide a shapely polygon under ``"geometry"`` and an
        identifier under ``"divide_id"``.
    grid : xarray object exposing the rioxarray ``.rio`` accessor
        Supplies the raster shape and affine transform.
        NOTE(review): ``.rio`` requires rioxarray to be imported somewhere
        in the process — confirm it is, since this file does not import it.

    Returns
    -------
    tuple
        ``(divide_id, (row_indices, col_indices))`` where the two index
        arrays locate the polygon's cells in the FULL grid.
    """
    # Burn the polygon into a full-size uint8 mask: 1 inside, 0 outside.
    mask = rasterize(
        [(row["geometry"], 1)],
        out_shape=grid.rio.shape,
        transform=grid.rio.transform(),
        all_touched=True,
        fill=0,
        dtype="uint8",
    )

    # numpy.where runs slowly on large arrays, so crop to the polygon's
    # bounding box first. `bounds` with the inverse transform gives the
    # geometry's extent in pixel space as (left, bottom, right, top); for
    # a north-up grid (negative y pixel size — presumed here, matching the
    # original unpacking) that corresponds to
    # (col_min, row_max, col_max, row_min).
    col_min, row_max, col_max, row_min = bounds(
        row["geometry"], transform=~grid.rio.transform()
    )
    row_min = floor(row_min)
    row_max = ceil(row_max)
    col_min = floor(col_min)
    col_max = ceil(col_max)

    window = mask[row_min:row_max, col_min:col_max]
    local_coords = np.where(window == 1)
    # Shift the cropped-window indices back into full-grid coordinates.
    global_coords = (local_coords[0] + row_min, local_coords[1] + col_min)

    return (row["divide_id"], global_coords)

def generate_weights_file(geopackage,grid_file,weights_filepath):

Expand All @@ -20,27 +47,26 @@ def generate_weights_file(geopackage,grid_file,weights_filepath):
PARAMETER["standard_parallel_2",60.0],PARAMETER["latitude_of_origin",40.0],UNIT["Meter",1.0]]')

crosswalk_dict = {}
i = 0
for index, row in gdf_proj.iterrows():
geom_rasterize = rasterize(
[(row["geometry"], 1)],
out_shape=grid.rio.shape,
transform=grid.rio.transform(),
all_touched=True,
fill=0,
dtype="uint8",
)
crosswalk_dict[row["divide_id"]] = np.where(geom_rasterize == 1)


if i % 100 == 0:
perc = i / len(gdf_proj) * 100
print(f"{i}, {perc:.2f}%".ljust(40), end="\r")
i += 1
start_time = datetime.datetime.now()
print(f'Starting at {start_time}')
rows = [row for _, row in gdf_proj.iterrows()]
# Create a multiprocessing pool
with multiprocessing.Pool() as pool:
# Use a partial function to pass the constant 'grid' argument
func = partial(process_row, grid=grid)
# Map the function across all rows
results = pool.map(func, rows)

# Aggregate results
for divide_id, global_coords in results:
crosswalk_dict[divide_id] = global_coords


weights_json = json.dumps(
{k: [x.tolist() for x in v] for k, v in crosswalk_dict.items()}
)
print(f'Finished at {datetime.datetime.now()}')
print(f'Total time: {datetime.datetime.now() - start_time}')
with open(weights_filepath, "w") as f:
f.write(weights_json)

Expand Down

0 comments on commit e70bade

Please sign in to comment.