diff --git a/forcingprocessor/src/forcingprocessor/weight_generator.py b/forcingprocessor/src/forcingprocessor/weight_generator.py
index 311a964e..cceb909b 100644
--- a/forcingprocessor/src/forcingprocessor/weight_generator.py
+++ b/forcingprocessor/src/forcingprocessor/weight_generator.py
@@ -6,6 +6,31 @@
 from pathlib import Path
 from math import floor, ceil
 import datetime
+import multiprocessing
+from functools import partial
+
+def process_row(row, grid):
+    # Rasterize one divide polygon onto the forcing grid
+    geom_rasterize = rasterize(
+        [(row["geometry"], 1)],
+        out_shape=grid.rio.shape,
+        transform=grid.rio.transform(),
+        all_touched=True,
+        fill=0,
+        dtype="uint8",
+    )
+    # numpy.where runs slowly on large arrays
+    # so we slice off the empty space
+    y_min, x_max, y_max, x_min = bounds(row["geometry"], transform=~grid.rio.transform())
+    x_min = floor(x_min)
+    x_max = ceil(x_max)
+    y_min = floor(y_min)
+    y_max = ceil(y_max)
+    geom_rasterize = geom_rasterize[x_min:x_max, y_min:y_max]
+    localized_coords = np.where(geom_rasterize == 1)
+    global_coords = (localized_coords[0] + x_min, localized_coords[1] + y_min)
+
+    return (row["divide_id"], global_coords)
 
 
 def generate_weights_file(geopackage,grid_file,weights_filepath):
@@ -23,38 +48,25 @@ def generate_weights_file(geopackage,grid_file,weights_filepath):
     crosswalk_dict = {}
     start_time = datetime.datetime.now()
-    for i, row in gdf_proj.iterrows():
-        geom_rasterize = rasterize(
-            [(row["geometry"], 1)],
-            out_shape=grid.rio.shape,
-            transform=grid.rio.transform(),
-            all_touched=True,
-            fill=0,
-            dtype="uint8",
-        )
-        # numpy.where runs slowly on large arrays
-        # so we slice off the empty space
-        y_min, x_max, y_max, x_min = bounds(row["geometry"], transform=~grid.rio.transform())
-        x_min = floor(x_min)
-        x_max = ceil(x_max)
-        y_min = floor(y_min)
-        y_max = ceil(y_max)
-        geom_rasterize = geom_rasterize[x_min:x_max, y_min:y_max]
-        localized_coords = np.where(geom_rasterize == 1)
-        global_coords = (localized_coords[0] + x_min, localized_coords[1] + y_min)
-
-        crosswalk_dict[row["divide_id"]] = global_coords
+    print(f'Starting at {start_time}')
+    rows = [row for _, row in gdf_proj.iterrows()]
+    # Create a multiprocessing pool
+    with multiprocessing.Pool() as pool:
+        # Use a partial function to pass the constant 'grid' argument
+        func = partial(process_row, grid=grid)
+        # Map the function across all rows
+        results = pool.map(func, rows)
 
-        if i % 100 == 0:
-            perc = i / len(gdf_proj) * 100
-            elapsed_time = datetime.datetime.now() - start_time
-            estimated_remaining_time = elapsed_time / i+1 * (len(gdf_proj) - i)
-            print(f"{i}, {perc:.2f}% elapsed: {elapsed_time}, remaining: {estimated_remaining_time}".ljust(40), end="\r")
+    # Aggregate results
+    for divide_id, global_coords in results:
+        crosswalk_dict[divide_id] = global_coords
 
     weights_json = json.dumps(
         {k: [x.tolist() for x in v] for k, v in crosswalk_dict.items()}
     )
+    print(f'Finished at {datetime.datetime.now()}')
+    print(f'Total time: {datetime.datetime.now() - start_time}')
 
     with open(weights_filepath, "w") as f:
         f.write(weights_json)
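
Note for callers: because generate_weights_file now creates a multiprocessing.Pool, the invoking script must sit behind an if __name__ == "__main__" guard on spawn-based platforms (Windows, macOS), or the workers will re-execute the module's top level. A minimal caller sketch; the import path and file names below are assumptions for illustration, not part of this patch:

from forcingprocessor.weight_generator import generate_weights_file

if __name__ == "__main__":
    # Guard is required: spawn-based multiprocessing re-imports this module
    # in every worker process.
    generate_weights_file(
        geopackage="hydrofabric.gpkg",    # hypothetical divide polygons
        grid_file="forcing_grid.nc",      # hypothetical gridded forcing file
        weights_filepath="weights.json",  # output: grid cell indices per divide_id
    )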
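
A side effect of switching to pool.map is that the old loop's incremental progress printout disappears, since pool.map blocks until every row has finished. If progress reporting is still wanted, pool.imap yields results in order as workers complete and can drive the same printout. A sketch only, assuming the process_row defined in this patch is importable; it also sidesteps the precedence bug in the old estimate (elapsed_time / i+1 divided before adding 1) by using 1-based enumeration:

import datetime
import multiprocessing
from functools import partial

def map_rows_with_progress(rows, grid, report_every=100):
    # Same pooled design as the patch, but pool.imap streams results back
    # as they complete, so a progress estimate can be printed along the way.
    start_time = datetime.datetime.now()
    crosswalk_dict = {}
    func = partial(process_row, grid=grid)
    with multiprocessing.Pool() as pool:
        for i, (divide_id, global_coords) in enumerate(pool.imap(func, rows), 1):
            crosswalk_dict[divide_id] = global_coords
            if i % report_every == 0:
                elapsed = datetime.datetime.now() - start_time
                remaining = elapsed / i * (len(rows) - i)
                perc = i / len(rows) * 100
                print(f"{i}, {perc:.2f}% elapsed: {elapsed}, remaining: {remaining}", end="\r")
    return crosswalk_dict

One further design note: partial(process_row, grid=grid) means the grid object is pickled and shipped to the workers; if the grid is large, a Pool initializer that opens the grid once per worker process would avoid that overhead.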