diff --git a/forcingprocessor/src/forcingprocessor/weight_generator.py b/forcingprocessor/src/forcingprocessor/weight_generator.py
index c45207a5..cceb909b 100644
--- a/forcingprocessor/src/forcingprocessor/weight_generator.py
+++ b/forcingprocessor/src/forcingprocessor/weight_generator.py
@@ -1,9 +1,36 @@
 import geopandas as gpd
 import numpy as np
 import xarray as xr
-from rasterio.features import rasterize
+from rasterio.features import rasterize, bounds
 import json, argparse, os
 from pathlib import Path
+from math import floor, ceil
+import datetime
+import multiprocessing
+from functools import partial
+
+def process_row(row, grid):
+    # Rasterize one divide polygon onto the grid; return (divide_id, covered cell indices)
+    geom_rasterize = rasterize(
+        [(row["geometry"], 1)],
+        out_shape=grid.rio.shape,
+        transform=grid.rio.transform(),
+        all_touched=True,
+        fill=0,
+        dtype="uint8",
+    )
+    # numpy.where runs slowly on large arrays
+    # so we slice off the empty space
+    y_min, x_max, y_max, x_min = bounds(row["geometry"], transform=~grid.rio.transform())
+    x_min = floor(x_min)
+    x_max = ceil(x_max)
+    y_min = floor(y_min)
+    y_max = ceil(y_max)
+    geom_rasterize = geom_rasterize[x_min:x_max, y_min:y_max]
+    localized_coords = np.where(geom_rasterize == 1)
+    global_coords = (localized_coords[0] + x_min, localized_coords[1] + y_min)
+
+    return (row["divide_id"], global_coords)
 
 
 def generate_weights_file(geopackage,grid_file,weights_filepath):
@@ -20,27 +47,25 @@
     PARAMETER["standard_parallel_2",60.0],PARAMETER["latitude_of_origin",40.0],UNIT["Meter",1.0]]')
 
     crosswalk_dict = {}
-    i = 0
-    for index, row in gdf_proj.iterrows():
-        geom_rasterize = rasterize(
-            [(row["geometry"], 1)],
-            out_shape=grid.rio.shape,
-            transform=grid.rio.transform(),
-            all_touched=True,
-            fill=0,
-            dtype="uint8",
-        )
-        crosswalk_dict[row["divide_id"]] = np.where(geom_rasterize == 1)
-
-
-        if i % 100 == 0:
-            perc = i / len(gdf_proj) * 100
-            print(f"{i}, {perc:.2f}%".ljust(40), 
-                end="\r")
-        i += 1
+    start_time = datetime.datetime.now()
+    print(f'Starting at {start_time}')
+    rows = [row for _, row in gdf_proj.iterrows()]
+    # Create a multiprocessing pool
+    with multiprocessing.Pool() as pool:
+        # Use a partial function to pass the constant 'grid' argument
+        func = partial(process_row, grid=grid)
+        # Map the function across all rows
+        results = pool.map(func, rows)
+
+    # Aggregate results
+    for divide_id, global_coords in results:
+        crosswalk_dict[divide_id] = global_coords
+
     weights_json = json.dumps(
         {k: [x.tolist() for x in v] for k, v in crosswalk_dict.items()}
     )
+    print(f'Finished at {datetime.datetime.now()}')
+    print(f'Total time: {datetime.datetime.now() - start_time}')
 
     with open(weights_filepath, "w") as f:
         f.write(weights_json)