diff --git a/marquette/conf/config.yaml b/marquette/conf/config.yaml
index 6d2a5e0..c5fed54 100644
--- a/marquette/conf/config.yaml
+++ b/marquette/conf/config.yaml
@@ -1,6 +1,7 @@
 name: MERIT
 data_path: /projects/mhpi/data/${name}
-zone: 72
+zone: 74
+gpu: 6
 create_edges:
   buffer: 0.3334
   dx: 2000
@@ -15,14 +16,14 @@ create_N:
   pad_gage_id: False
   obs_dataset: ${data_path}/gage_information/obs_csvs/GRDC_point_data.csv
   obs_dataset_output: ${data_path}/gage_information/formatted_gage_csvs/gnn_formatted_basins.csv
-  zone_obs_dataset: ${data_path}/gage_information/formatted_gage_csvs/gnn_zone_${zone}.csv
+  zone_obs_dataset: ${data_path}/gage_information/formatted_gage_csvs/subzones.csv
 create_TMs:
   MERIT:
     save_sparse: True
     TM: ${data_path}/zarr/TMs/sparse_MERIT_FLOWLINES_${zone}
     shp_files: ${data_path}/raw/basins/cat_pfaf_${zone}_MERIT_Hydro_v07_Basins_v01_bugfix1.shp
 create_streamflow:
-  version: merit_conus_v6.14
+  version: merit_conus_v6.18_snow
   data_store: ${data_path}/streamflow/zarr/${create_streamflow.version}/${zone}
   obs_attributes: ${data_path}/gage_information/MERIT_basin_area_info
   predictions: /projects/mhpi/yxs275/DM_output/water_loss_model/dPL_local_daymet_new_attr_RMSEloss_with_log_2800
diff --git a/marquette/merit/_streamflow_conversion_functions.py b/marquette/merit/_streamflow_conversion_functions.py
index d34907b..6972e10 100644
--- a/marquette/merit/_streamflow_conversion_functions.py
+++ b/marquette/merit/_streamflow_conversion_functions.py
@@ -140,6 +140,7 @@ def calculate_merit_flow(cfg: DictConfig, edges: zarr.hierarchy.Group) -> None:
     for key in zone_keys:
         zone_comids.append(streamflow_predictions_root[key].COMID[:])
         zone_runoff.append(streamflow_predictions_root[key].Qr[:])
+        # zone_runoff.append(streamflow_predictions_root[key]["Q0"][:])
     streamflow_comids = np.concatenate(zone_comids).astype(int)
     file_runoff = np.transpose(np.concatenate(zone_runoff))
     del zone_comids
diff --git a/marquette/merit/extensions.py b/marquette/merit/extensions.py
index bbf9bdc..271d8a5 100644
--- a/marquette/merit/extensions.py
+++ b/marquette/merit/extensions.py
@@ -1,14 +1,15 @@
 import logging
 from pathlib import Path
 
-# import cugraph as cnx
-# import cudf
+import binsparse
 import cupy as cp
+from cupyx.scipy import sparse as cp_sparse
 import geopandas as gpd
 import networkx as nx
 import numpy as np
 import pandas as pd
 import polars as pl
+from scipy import sparse
 import zarr
 from omegaconf import DictConfig
 from tqdm import tqdm
@@ -282,7 +283,7 @@ def calculate_q_prime_summation(cfg: DictConfig, edges: zarr.Group) -> None:
     time_split = np.array_split(streamflow_time, n)  # type: ignore
     for idx, time_range in enumerate(time_split):
         q_prime_cp = cp.zeros([time_range.shape[0], dim_1]).transpose(1, 0)
-
+        q_prime_np_edge_count_cp = cp.zeros([time_range.shape[0], dim_1]).transpose(1, 0)
         for jdx, _ in enumerate(
             tqdm(
                 edge_ids,
@@ -300,18 +301,23 @@ def calculate_q_prime_summation(cfg: DictConfig, edges: zarr.Group) -> None:
                 downstream_comid_idx = cp.unique(
                     edges_segment_sorting_index[downstream_idx]
                 )  # type: ignore
-                q_prime_cp[downstream_comid_idx] += streamflow_data[
+                streamflow = streamflow_data[
                     time_range, streamflow_ds_id
-                ]  # type: ignore
+                ]
+                q_prime_cp[downstream_comid_idx] += streamflow
+                q_prime_np_edge_count_cp[downstream_comid_idx] += 1
             except nx.exception.NetworkXError:
                 # This means there is no connectivity from this basin. It's one-node graph
-                q_prime_cp[streamflow_ds_id] = streamflow_data[
+                streamflow = streamflow_data[
                     time_range, streamflow_ds_id
                 ]
+                q_prime_cp[streamflow_ds_id] = streamflow
+                q_prime_np_edge_count_cp[streamflow_ds_id] += 1
 
         print("Saving GPU Memory to CPU; freeing GPU Memory")
-        q_prime_np[:, time_range] = cp.asnumpy(q_prime_cp)
+        q_prime_np[:, time_range] = cp.asnumpy(q_prime_cp / q_prime_np_edge_count_cp)
         del q_prime_cp
+        del q_prime_np_edge_count_cp
         cp.get_default_memory_pool().free_all_blocks()
 
     edges.array(
diff --git a/marquette/merit/scripts/create_trunks.py b/marquette/merit/scripts/create_trunks.py
new file mode 100644
index 0000000..fcdfb32
--- /dev/null
+++ b/marquette/merit/scripts/create_trunks.py
@@ -0,0 +1,132 @@
+"""
+Functions for processing and filtering zone pairs based on subzone indices.
+
+This module provides functionality to take indices from subzones and create saved
+coordinate (coo) pairs based on the missing indices. It handles memory-efficient
+processing of large arrays using chunking and GPU acceleration via CuPy.
+
+This function creates the pairs that are used in the Network object of dMC
+
+Notes
+-----
+The main workflow consists of:
+1. Loading full zone and subzone pairs
+2. Finding and removing upstream connections
+3. Saving the filtered pairs to a new Zarr array
+"""
+import time
+from pathlib import Path
+
+import cupy as cp
+import numpy as np
+import zarr
+from tqdm import trange
+
+
+def _find_upstream_mask(full_pairs: np.ndarray, sub_pairs: np.ndarray, chunk_size=1000) -> cp.ndarray:
+    """
+    Find indices where upstream connections exist in the full-zone connections.
+
+    Notes
+    -----
+    Algorithm steps:
+    1. Create broadcasted comparison
+    2. Handle NaN equality separately
+    3. Combine regular equality and NaN equality
+    4. Check if both elements in each pair match (axis=2)
+    5. Check if any pair in sub_pairs matches (axis=1)
+
+    Parameters
+    ----------
+    full_pairs : cp.ndarray
+        The pairs intersections [to, from] for the full zone
+    sub_pairs : cp.ndarray
+        The pairs intersections [to, from] for the sub zone we're masking out
+    chunk_size : int, optional
+        Size of chunks to process at once, by default 1000
+
+    Returns
+    -------
+    cp.ndarray
+        Boolean mask indicating which pairs in full_pairs have upstream connections
+    """
+    n_rows = full_pairs.shape[0]
+    final_mask = np.zeros(n_rows, dtype=bool)
+
+    for start_idx in trange(0, n_rows, chunk_size, desc="Processing chunks for subpairs"):
+        end_idx = min(start_idx + chunk_size, n_rows)
+        chunk = full_pairs[start_idx:end_idx]
+
+        regular_equal = chunk[:, None] == sub_pairs
+        nan_equal = cp.isnan(chunk)[:, None] & cp.isnan(sub_pairs)
+        equal_or_both_nan = regular_equal | nan_equal
+        pairs_match = cp.all(equal_or_both_nan, axis=2)
+        chunk_mask = pairs_match.any(axis=1)
+
+        final_mask[start_idx:end_idx] = cp.asnumpy(chunk_mask)
+    return final_mask
+
+
+def create_trunks(coo_path: zarr.Group, subzones: list[str]) -> cp.ndarray:
+    """
+    Create trunk pairs by removing subzone connections from full zone pairs.
+
+    Parameters
+    ----------
+    coo_path : zarr.Group
+        Zarr group containing the coordinate pairs for full zone and subzones
+    subzones : list[str]
+        List of subzone names to process and remove from full zone
+
+    Returns
+    -------
+    cp.ndarray
+        Filtered pairs array with subzone connections removed
+
+    Notes
+    -----
+    This function:
+    1. Loads full zone pairs into GPU memory
+    2. Iteratively processes each subzone
+    3. Removes matching pairs from full zone
+    4. Manages GPU memory by freeing unused arrays
+    """
+    full_zone_pairs = cp.array(coo_path["full_zone"].pairs[:])
+    mempool = cp.get_default_memory_pool()
+    start_time = time.perf_counter()
+    for _subzone in subzones:
+        subzone_pairs = cp.array(coo_path[_subzone].pairs[:])
+        mask = _find_upstream_mask(full_zone_pairs, subzone_pairs)
+        full_zone_pairs = full_zone_pairs[~mask]
+        del mask
+        del subzone_pairs
+        mempool.free_all_blocks()
+    end_time = time.perf_counter()
+    print(f"Time taken: {end_time - start_time:.4f} seconds")
+    return full_zone_pairs
+
+
+if __name__ == "__main__":
+    """
+    Main execution block for processing and saving trunk pairs.
+    This script:
+    1. Loads coordinate pairs from a specified zone
+    2. Removes connections from specified subzones
+    3. Saves the resulting filtered pairs to a new Zarr array
+    """
+    zone = "74"
+    coo_path = zarr.open_group(Path("/projects/mhpi/data/MERIT/zarr/gage_coo_indices") / zone)
+    subzones = [
+        "arkansas",
+        "missouri",
+        "ohio",
+        "tennessee",
+        "upper_mississippi",
+    ]
+    save_name = f"zone_without_{'_'.join(subzones)}"
+    pairs = create_trunks(coo_path, subzones)
+    root = zarr.group(Path("/projects/mhpi/data/MERIT/zarr/gage_coo_indices") / zone / save_name)
+    root.create_dataset(
+        "pairs", data=cp.asnumpy(pairs), chunks=(5000, 5000), dtype="float32"
+    )
+    print(f"Saved: {save_name}")
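
Notes on the changes above. The snippets below are illustrative sketches, not part of the patch.

Because config.yaml builds most of its paths from ${...} interpolations, the zone: 74 and version: merit_conus_v6.18_snow edits propagate into every value that references them (data_store, TM, shp_files, and so on). A minimal OmegaConf sketch of how the updated values resolve; the literal dict here just mirrors a slice of the config, since the repo already passes omegaconf DictConfig objects around:

    from omegaconf import OmegaConf

    # A slice of marquette/conf/config.yaml as a literal dict (illustrative only).
    cfg = OmegaConf.create({
        "name": "MERIT",
        "data_path": "/projects/mhpi/data/${name}",
        "zone": 74,
        "create_streamflow": {
            "version": "merit_conus_v6.18_snow",
            "data_store": "${data_path}/streamflow/zarr/${create_streamflow.version}/${zone}",
        },
    })

    # Interpolations resolve on access, including the nested ${name} inside data_path.
    print(cfg.create_streamflow.data_store)
    # /projects/mhpi/data/MERIT/streamflow/zarr/merit_conus_v6.18_snow/74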
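
The extensions.py hunks change calculate_q_prime_summation from summing upstream streamflow into q_prime_cp to averaging it: each write now also increments the parallel q_prime_np_edge_count_cp counter at the same indices, and the transfer back to CPU divides the two arrays elementwise. A minimal NumPy sketch of that accumulate-and-average pattern (toy names and shapes; CuPy mirrors the NumPy API used here). One caveat worth flagging: rows that never receive a contribution keep a count of zero, so the patch's unguarded division yields NaN (0/0) for them, which the clamp below would avoid:

    import numpy as np

    # Toy stand-ins for q_prime_cp and q_prime_np_edge_count_cp: (edges, time).
    n_edges, n_time = 5, 3
    q_sum = np.zeros((n_edges, n_time))
    count = np.zeros((n_edges, n_time))

    # Each upstream contribution adds its flow to the downstream rows and
    # bumps the visit count at the same indices, as in the patch.
    for downstream_idx, flow in [
        (np.array([0, 2]), np.full(n_time, 1.0)),
        (np.array([2, 4]), np.full(n_time, 2.0)),
    ]:
        q_sum[downstream_idx] += flow
        count[downstream_idx] += 1

    # Division turns the running sum into a per-edge mean. Clamping the count
    # at 1 avoids 0/0 -> NaN for rows that were never visited (the patch
    # divides unguarded, which assumes every row receives a contribution).
    q_mean = q_sum / np.maximum(count, 1.0)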
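
_find_upstream_mask in the new create_trunks.py is built on a broadcast row-membership test: comparing a (chunk_size, 2) block of full-zone pairs against all (n_sub, 2) subzone pairs at once yields a (chunk_size, n_sub, 2) boolean array, which is reduced with all(axis=2) (both elements of a pair match) and any(axis=1) (at least one subzone pair matches). Since NaN != NaN under IEEE rules, NaN entries must be matched separately with isnan. The same test in plain NumPy (the script runs it chunked on the GPU via CuPy to bound memory):

    import numpy as np

    full = np.array([[1.0, 2.0],
                     [3.0, np.nan],
                     [5.0, 6.0]])
    sub = np.array([[3.0, np.nan],
                    [9.0, 9.0]])

    eq = full[:, None] == sub                           # (3, 2, 2) elementwise equality
    both_nan = np.isnan(full)[:, None] & np.isnan(sub)  # NaN == NaN is False, so match NaNs explicitly
    mask = (eq | both_nan).all(axis=2).any(axis=1)      # is each row of `full` present in `sub`?
    print(mask)  # [False  True False] -> the pair [3.0, nan] would be filtered out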
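
Finally, the __main__ block persists the filtered pairs to a Zarr group named after the removed subzones. A hedged sketch of loading that output back, for example to feed the dMC Network object mentioned in the module docstring; the path and group layout come from the script, while the downstream usage is an assumption:

    from pathlib import Path

    import zarr

    zone = "74"
    save_name = "zone_without_arkansas_missouri_ohio_tennessee_upper_mississippi"
    root = zarr.open_group(
        Path("/projects/mhpi/data/MERIT/zarr/gage_coo_indices") / zone / save_name,
        mode="r",
    )
    pairs = root["pairs"][:]  # float32 [to, from] COO pairs, subzone connections removed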