New Script: create_trunks() #30

Merged 2 commits on Nov 18, 2024
7 changes: 4 additions & 3 deletions marquette/conf/config.yaml
@@ -1,6 +1,7 @@
 name: MERIT
 data_path: /projects/mhpi/data/${name}
-zone: 72
+zone: 74
+gpu: 6
 create_edges:
   buffer: 0.3334
   dx: 2000
@@ -15,14 +16,14 @@ create_N:
   pad_gage_id: False
   obs_dataset: ${data_path}/gage_information/obs_csvs/GRDC_point_data.csv
   obs_dataset_output: ${data_path}/gage_information/formatted_gage_csvs/gnn_formatted_basins.csv
-  zone_obs_dataset: ${data_path}/gage_information/formatted_gage_csvs/gnn_zone_${zone}.csv
+  zone_obs_dataset: ${data_path}/gage_information/formatted_gage_csvs/subzones.csv
 create_TMs:
   MERIT:
     save_sparse: True
     TM: ${data_path}/zarr/TMs/sparse_MERIT_FLOWLINES_${zone}
     shp_files: ${data_path}/raw/basins/cat_pfaf_${zone}_MERIT_Hydro_v07_Basins_v01_bugfix1.shp
 create_streamflow:
-  version: merit_conus_v6.14
+  version: merit_conus_v6.18_snow
   data_store: ${data_path}/streamflow/zarr/${create_streamflow.version}/${zone}
   obs_attributes: ${data_path}/gage_information/MERIT_basin_area_info
   predictions: /projects/mhpi/yxs275/DM_output/water_loss_model/dPL_local_daymet_new_attr_RMSEloss_with_log_2800
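The ${...} values in this config are OmegaConf/Hydra-style interpolations: changing zone: 74 automatically updates every path that references ${zone}. A minimal sketch of how resolution works, using the omegaconf package the repo already imports (keys below mirror the diff but are flattened for illustration):

    from omegaconf import OmegaConf

    # Interpolations resolve on access.
    cfg = OmegaConf.create(
        {
            "name": "MERIT",
            "data_path": "/projects/mhpi/data/${name}",
            "zone": 74,
            "create_streamflow": {"version": "merit_conus_v6.18_snow"},
            "data_store": "${data_path}/streamflow/zarr/${create_streamflow.version}/${zone}",
        }
    )
    print(cfg.data_store)
    # /projects/mhpi/data/MERIT/streamflow/zarr/merit_conus_v6.18_snow/74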
1 change: 1 addition & 0 deletions marquette/merit/_streamflow_conversion_functions.py
@@ -140,6 +140,7 @@ def calculate_merit_flow(cfg: DictConfig, edges: zarr.hierarchy.Group) -> None:
     for key in zone_keys:
         zone_comids.append(streamflow_predictions_root[key].COMID[:])
         zone_runoff.append(streamflow_predictions_root[key].Qr[:])
+        # zone_runoff.append(streamflow_predictions_root[key]["Q0"][:])
     streamflow_comids = np.concatenate(zone_comids).astype(int)
     file_runoff = np.transpose(np.concatenate(zone_runoff))
     del zone_comids
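The new commented-out line keeps a switch handy for reading Q0 instead of Qr from the streamflow predictions. If that toggle were made a parameter, it could look like the following sketch (the helper and its var argument are illustrative, not part of this PR):

    import numpy as np
    import zarr

    def read_zone_runoff(root: zarr.Group, zone_keys: list[str], var: str = "Qr") -> np.ndarray:
        # var may be "Qr" or "Q0"; hypothetical helper, not in the PR.
        zone_runoff = [root[key][var][:] for key in zone_keys]
        return np.transpose(np.concatenate(zone_runoff))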
20 changes: 13 additions & 7 deletions marquette/merit/extensions.py
@@ -1,14 +1,15 @@
 import logging
 from pathlib import Path

-# import cugraph as cnx
-# import cudf
+import binsparse
+import cupy as cp
+from cupyx.scipy import sparse as cp_sparse
 import geopandas as gpd
 import networkx as nx
 import numpy as np
 import pandas as pd
 import polars as pl
 from scipy import sparse
 import zarr
 from omegaconf import DictConfig
 from tqdm import tqdm
@@ -282,7 +283,7 @@ def calculate_q_prime_summation(cfg: DictConfig, edges: zarr.Group) -> None:
     time_split = np.array_split(streamflow_time, n)  # type: ignore
     for idx, time_range in enumerate(time_split):
         q_prime_cp = cp.zeros([time_range.shape[0], dim_1]).transpose(1, 0)
-
+        q_prime_np_edge_count_cp = cp.zeros([time_range.shape[0], dim_1]).transpose(1, 0)
         for jdx, _ in enumerate(
             tqdm(
                 edge_ids,
@@ -300,18 +301,23 @@ def calculate_q_prime_summation(cfg: DictConfig, edges: zarr.Group) -> None:
                 downstream_comid_idx = cp.unique(
                     edges_segment_sorting_index[downstream_idx]
                 )  # type: ignore
-                q_prime_cp[downstream_comid_idx] += streamflow_data[
+                streamflow = streamflow_data[
                     time_range, streamflow_ds_id
-                ]  # type: ignore
+                ]
+                q_prime_cp[downstream_comid_idx] += streamflow
+                q_prime_np_edge_count_cp[downstream_comid_idx] += 1
             except nx.exception.NetworkXError:
                 # This means there is no connectivity from this basin. It's a one-node graph
-                q_prime_cp[streamflow_ds_id] = streamflow_data[
+                streamflow = streamflow_data[
                     time_range, streamflow_ds_id
                 ]
+                q_prime_cp[streamflow_ds_id] = streamflow
+                q_prime_np_edge_count_cp[streamflow_ds_id] += 1

         print("Saving GPU Memory to CPU; freeing GPU Memory")
-        q_prime_np[:, time_range] = cp.asnumpy(q_prime_cp)
+        q_prime_np[:, time_range] = cp.asnumpy(q_prime_cp / q_prime_np_edge_count_cp)
         del q_prime_cp
+        del q_prime_np_edge_count_cp
         cp.get_default_memory_pool().free_all_blocks()

     edges.array(
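The substantive change to calculate_q_prime_summation is that the accumulated upstream streamflow is now divided by a per-edge contribution count before being written back to CPU memory, so q_prime becomes an average of upstream contributions rather than a raw sum. A CPU-only toy sketch of that counting/averaging logic with NumPy (shapes and values are illustrative):

    import numpy as np

    n_edges, n_times = 4, 3
    q_prime = np.zeros((n_edges, n_times))
    counts = np.zeros((n_edges, n_times))

    # Each downstream edge may receive several upstream contributions.
    contributions = {0: [1.0, 3.0], 2: [2.0]}  # edge index -> flows
    for idx, flows in contributions.items():
        for flow in flows:
            q_prime[idx] += flow
            counts[idx] += 1

    # Divide where a count exists; the guard against zero counts is this
    # sketch's addition, not something the diff itself does.
    averaged = np.divide(q_prime, counts, out=np.zeros_like(q_prime), where=counts > 0)
    print(averaged[0])  # [2. 2. 2.], the mean of the two contributions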
132 changes: 132 additions & 0 deletions marquette/merit/scripts/create_trunks.py
@@ -0,0 +1,132 @@
"""
Functions for processing and filtering zone pairs based on subzone indices.

This module provides functionality to take indices from subzones and create saved
coordinate (coo) pairs based on the missing indices. It handles memory-efficient
processing of large arrays using chunking and GPU acceleration via CuPy.

This module creates the pairs that are used in the Network object of dMC.

Notes
-----
The main workflow consists of:
1. Loading full zone and subzone pairs
2. Finding and removing upstream connections
3. Saving the filtered pairs to a new Zarr array
"""
import time
from pathlib import Path

import cupy as cp
import numpy as np
import zarr
from tqdm import trange

def _find_upstream_mask(full_pairs: cp.ndarray, sub_pairs: cp.ndarray, chunk_size: int = 1000) -> np.ndarray:
    """
    Find indices where upstream connections exist in the full-zone connections.

    Parameters
    ----------
    full_pairs : cp.ndarray
        The pair intersections [to, from] for the full zone
    sub_pairs : cp.ndarray
        The pair intersections [to, from] for the subzone being masked out
    chunk_size : int, optional
        Size of chunks to process at once, by default 1000

    Returns
    -------
    np.ndarray
        Boolean mask indicating which pairs in full_pairs have upstream connections

    Notes
    -----
    Algorithm steps:
    1. Create a broadcasted comparison
    2. Handle NaN equality separately
    3. Combine regular equality and NaN equality
    4. Check if both elements in each pair match (axis=2)
    5. Check if any pair in sub_pairs matches (axis=1)
    """
n_rows = full_pairs.shape[0]
final_mask = np.zeros(n_rows, dtype=bool)

for start_idx in trange(0, n_rows, chunk_size, desc="Processing chunks for subpairs"):
end_idx = min(start_idx + chunk_size, n_rows)
chunk = full_pairs[start_idx:end_idx]

regular_equal = chunk[:, None] == sub_pairs
nan_equal = cp.isnan(chunk)[:, None] & cp.isnan(sub_pairs)
equal_or_both_nan = regular_equal | nan_equal
pairs_match = cp.all(equal_or_both_nan, axis=2)
chunk_mask = pairs_match.any(axis=1)

final_mask[start_idx:end_idx] = cp.asnumpy(chunk_mask)
return final_mask


def create_trunks(coo_path: zarr.Group, subzones: list[str]) -> cp.ndarray:
"""
Create trunk pairs by removing subzone connections from full zone pairs.

Parameters
----------
coo_path : zarr.Group
Zarr group containing the coordinate pairs for full zone and subzones
subzones : list[str]
List of subzone names to process and remove from full zone

Returns
-------
cp.ndarray
Filtered pairs array with subzone connections removed

Notes
-----
This function:
1. Loads full zone pairs into GPU memory
2. Iteratively processes each subzone
3. Removes matching pairs from full zone
4. Manages GPU memory by freeing unused arrays
"""
full_zone_pairs = cp.array(coo_path["full_zone"].pairs[:])
mempool = cp.get_default_memory_pool()
start_time = time.perf_counter()
for _subzone in subzones:
subzone_pairs = cp.array(coo_path[_subzone].pairs[:])
mask = _find_upstream_mask(full_zone_pairs, subzone_pairs)
full_zone_pairs = full_zone_pairs[~mask]
del mask
del subzone_pairs
mempool.free_all_blocks()
end_time = time.perf_counter()
print(f"Time taken: {end_time - start_time:.4f} seconds")
return full_zone_pairs


if __name__ == "__main__":
"""
Main execution block for processing and saving trunk pairs.

This script:
1. Loads coordinate pairs from a specified zone
2. Removes connections from specified subzones
3. Saves the resulting filtered pairs to a new Zarr array
"""
zone = "74"
coo_path = zarr.open_group(Path("/projects/mhpi/data/MERIT/zarr/gage_coo_indices") / zone)
subzones = [
"arkansas",
"missouri",
"ohio",
"tennessee",
"upper_mississippi",
]
save_name = f"zone_without_{'_'.join(subzones)}"
pairs = create_trunks(coo_path, subzones)
root = zarr.group(Path("/projects/mhpi/data/MERIT/zarr/gage_coo_indices") / zone / save_name)
root.create_dataset(
"pairs", data=cp.asnumpy(pairs), chunks=(5000, 5000), dtype="float32"
)
print(f"Saved: {save_name}")