add multiprocessing to weight generation
JoshCu committed Jan 11, 2024
1 parent 9010c66 commit 181e785
Showing 1 changed file with 38 additions and 26 deletions.
64 changes: 38 additions & 26 deletions forcingprocessor/src/forcingprocessor/weight_generator.py
@@ -6,6 +6,31 @@
from pathlib import Path
from math import floor, ceil
import datetime
import multiprocessing
from functools import partial

def process_row(row, grid):
    # Rasterize the divide geometry onto the forcing grid
    geom_rasterize = rasterize(
        [(row["geometry"], 1)],
        out_shape=grid.rio.shape,
        transform=grid.rio.transform(),
        all_touched=True,
        fill=0,
        dtype="uint8",
    )
    # numpy.where runs slowly on large arrays
    # so we slice off the empty space
    y_min, x_max, y_max, x_min = bounds(row["geometry"], transform=~grid.rio.transform())
    x_min = floor(x_min)
    x_max = ceil(x_max)
    y_min = floor(y_min)
    y_max = ceil(y_max)
    geom_rasterize = geom_rasterize[x_min:x_max, y_min:y_max]
    localized_coords = np.where(geom_rasterize == 1)
    global_coords = (localized_coords[0] + x_min, localized_coords[1] + y_min)

    return (row["divide_id"], global_coords)

def generate_weights_file(geopackage,grid_file,weights_filepath):

@@ -23,38 +48,25 @@ def generate_weights_file(geopackage,grid_file,weights_filepath):

    crosswalk_dict = {}
    start_time = datetime.datetime.now()
    for i, row in gdf_proj.iterrows():
        geom_rasterize = rasterize(
            [(row["geometry"], 1)],
            out_shape=grid.rio.shape,
            transform=grid.rio.transform(),
            all_touched=True,
            fill=0,
            dtype="uint8",
        )
        # numpy.where runs slowly on large arrays
        # so we slice off the empty space
        y_min, x_max, y_max, x_min = bounds(row["geometry"], transform=~grid.rio.transform())
        x_min = floor(x_min)
        x_max = ceil(x_max)
        y_min = floor(y_min)
        y_max = ceil(y_max)
        geom_rasterize = geom_rasterize[x_min:x_max, y_min:y_max]
        localized_coords = np.where(geom_rasterize == 1)
        global_coords = (localized_coords[0] + x_min, localized_coords[1] + y_min)

        crosswalk_dict[row["divide_id"]] = global_coords
    print(f'Starting at {start_time}')
    rows = [row for _, row in gdf_proj.iterrows()]
    # Create a multiprocessing pool
    with multiprocessing.Pool() as pool:
        # Use a partial function to pass the constant 'grid' argument
        func = partial(process_row, grid=grid)
        # Map the function across all rows
        results = pool.map(func, rows)

        if i % 100 == 0:
            perc = i / len(gdf_proj) * 100
            elapsed_time = datetime.datetime.now() - start_time
            estimated_remaining_time = elapsed_time / i+1 * (len(gdf_proj) - i)
            print(f"{i}, {perc:.2f}% elapsed: {elapsed_time}, remaining: {estimated_remaining_time}".ljust(40), end="\r")
    # Aggregate results
    for divide_id, global_coords in results:
        crosswalk_dict[divide_id] = global_coords


    weights_json = json.dumps(
        {k: [x.tolist() for x in v] for k, v in crosswalk_dict.items()}
    )
    print(f'Finished at {datetime.datetime.now()}')
    print(f'Total time: {datetime.datetime.now() - start_time}')
    with open(weights_filepath, "w") as f:
        f.write(weights_json)
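
The commit replaces the serial per-row loop with pool.map over a partial that freezes the shared grid argument. Below is a minimal standalone sketch of that same pattern; the scale function and its values are illustrative only and are not part of the repository.

import multiprocessing
from functools import partial

def scale(value, factor):
    # Stand-in for process_row: 'factor' plays the role of the shared 'grid'.
    return value * factor

if __name__ == "__main__":
    # partial() freezes the shared argument so pool.map only has to supply
    # the per-item argument, mirroring partial(process_row, grid=grid) above.
    func = partial(scale, factor=10)
    with multiprocessing.Pool() as pool:
        results = pool.map(func, [1, 2, 3, 4])
    print(results)  # [10, 20, 30, 40]

One thing to watch with this pattern: the frozen argument is pickled and shipped to the worker processes along with the tasks, so it should be reasonably cheap to serialize.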

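For orientation, a hypothetical invocation of the updated function might look like the sketch below. The file names and the forcingprocessor.weight_generator import path are assumptions based on the changed file's location, not taken from the commit. Because the function now spawns worker processes, calling it under an if __name__ == "__main__": guard matters on platforms that use the spawn start method.

# Hypothetical usage sketch; paths are placeholders, not from the commit.
from forcingprocessor.weight_generator import generate_weights_file

if __name__ == "__main__":
    generate_weights_file(
        "hydrofabric.gpkg",   # divide polygons with a divide_id column
        "forcing_grid.nc",    # gridded forcing file readable by rioxarray
        "weights.json",       # output: divide_id -> grid cell indices
    )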
