diff --git a/.gitignore b/.gitignore index 9634237..778dc13 100644 --- a/.gitignore +++ b/.gitignore @@ -10,4 +10,5 @@ dist *.egg-info **/tiles/tms* **/tiles/vpu* -*.tar.gz \ No newline at end of file +*.tar.gz +*.dat diff --git a/modules/data_processing/create_realization.py b/modules/data_processing/create_realization.py index 690d709..269dad8 100644 --- a/modules/data_processing/create_realization.py +++ b/modules/data_processing/create_realization.py @@ -7,7 +7,8 @@ import json import s3fs import xarray as xr - +from tqdm.rich import tqdm +from dask.distributed import Client, LocalCluster from data_processing.file_paths import file_paths from data_processing.gpkg_utils import get_cat_to_nex_flowpairs, get_cat_to_nhd_feature_id @@ -21,11 +22,20 @@ def get_approximate_gw_storage(paths: file_paths, start_date: datetime): fs = s3fs.S3FileSystem(anon=True) nc_url = f"s3://noaa-nwm-retrospective-3-0-pds/CONUS/netcdf/GWOUT/{year}/{formatted_dt}.GWOUT_DOMAIN1" + + # make sure there's a dask cluster running + try: + client = Client.current() + except ValueError: + cluster = LocalCluster() + client = Client(cluster) + + with fs.open(nc_url) as file_obj: ds = xr.open_dataset(file_obj) water_levels = dict() - for cat, feature in cat_to_feature.items(): + for cat, feature in tqdm(cat_to_feature.items()): # this value is in CM, we need meters to match max_gw_depth # xarray says it's in mm, with 0.1 scale factor. calling .values doesn't apply the scale water_level = ds.sel(feature_id=feature).depth.values / 100 @@ -129,7 +139,6 @@ def make_ngen_realization_json( realization["time"]["start_time"] = start_time.strftime("%Y-%m-%d %H:%M:%S") realization["time"]["end_time"] = end_time.strftime("%Y-%m-%d %H:%M:%S") realization["time"]["output_interval"] = 3600 - realization["time"]["nts"] = nts with open(config_dir / "realization.json", "w") as file: json.dump(realization, file, indent=4) diff --git a/modules/data_processing/forcings.py b/modules/data_processing/forcings.py index ef893fc..00ae90b 100644 --- a/modules/data_processing/forcings.py +++ b/modules/data_processing/forcings.py @@ -6,11 +6,12 @@ from pathlib import Path from datetime import datetime from functools import partial +import psutil +from math import ceil from tqdm.rich import tqdm import numpy as np import dask -from dask.distributed import Client, LocalCluster, progress import geopandas as gpd import pandas as pd import xarray as xr @@ -21,7 +22,7 @@ from data_processing.zarr_utils import get_forcing_data logger = logging.getLogger(__name__) -# Suppress the specific warning from numpy +# Suppress the specific warning from numpy to keep the cli output clean warnings.filterwarnings( "ignore", message="'DataFrame.swapaxes' is deprecated", category=FutureWarning ) @@ -29,18 +30,20 @@ "ignore", message="'GeoDataFrame.swapaxes' is deprecated", category=FutureWarning ) - -def weighted_sum_of_cells(flat_tensor, cell_ids, factors): +def weighted_sum_of_cells(flat_raster: np.ndarray, cell_ids: np.ndarray , factors: np.ndarray): # Create an output array initialized with zeros - result = np.zeros(flat_tensor.shape[0]) - result = np.sum(flat_tensor[:, cell_ids] * factors, axis=1) + # dimensions are raster[time][x*y] + result = np.zeros(flat_raster.shape[0]) + result = np.sum(flat_raster[:, cell_ids] * factors, axis=1) sum_of_weights = np.sum(factors) result /= sum_of_weights return result + def get_cell_weights(raster, gdf): + # Get the cell weights for each divide output = exact_extract( - raster["LWDOWN"], + raster["RAINRATE"], gdf, ["cell_id", 
"coverage"], include_cols=["divide_id"], @@ -57,24 +60,40 @@ def add_APCP_SURFACE_to_dataset(dataset: xr.Dataset) -> xr.Dataset: return dataset -def save_to_csv(catchment_ds, csv_path): - catchment_df = catchment_ds.to_dataframe().drop(["catchment"], axis=1) - catchment_df.to_csv(csv_path) - return csv_path +def get_index_chunks(data: xr.DataArray) -> list[tuple[int, int]]: + # takes a data array and calculates the start and end index for each chunk + # based on the available memory. + array_memory_usage = data.nbytes + free_memory = psutil.virtual_memory().available * 0.8 # 80% of available memory + num_chunks = ceil(array_memory_usage / free_memory) + max_index = data.shape[0] + stride = max_index // num_chunks + chunk_start = range(0, max_index, stride) + index_chunks = [(start, start + stride) for start in chunk_start] + return index_chunks + + +def create_shared_memory(lazy_array): + logger.debug(f"Creating shared memory size {lazy_array.nbytes/ 10**6} Mb.") + shm = shared_memory.SharedMemory(create=True, size=lazy_array.nbytes) + shared_array = np.ndarray(lazy_array.shape, dtype=np.float32, buffer=shm.buf) + # if your data is not float32, xarray will do an automatic conversion here + # which consumes a lot more memory, forcings downloaded with this tool will work + for start, end in get_index_chunks(lazy_array): + # copy data from lazy to shared memory one chunk at a time + shared_array[start:end] = lazy_array[start:end] + time, x, y = shared_array.shape + shared_array = shared_array.reshape(time, -1) -def create_shared_memory(data): - shm = shared_memory.SharedMemory(create=True, size=data.nbytes) - shared_array = np.ndarray(data.shape, dtype=data.dtype, buffer=shm.buf) - shared_array[:] = data[:] - return shm, shared_array + return shm, shared_array.shape, shared_array.dtype def process_chunk_shared(variable, times, shm_name, shape, dtype, chunk): existing_shm = shared_memory.SharedMemory(name=shm_name) raster = np.ndarray(shape, dtype=dtype, buffer=existing_shm.buf) - results = [] + for catchment in chunk.index.unique(): cell_ids = chunk.loc[catchment]["cell_id"] weights = chunk.loc[catchment]["coverage"] @@ -87,11 +106,20 @@ def process_chunk_shared(variable, times, shm_name, shape, dtype, chunk): ) temp_da = temp_da.assign_coords(catchment=catchment) results.append(temp_da) - existing_shm.close() return xr.concat(results, dim="catchment") +def get_cell_weights_parallel(gdf, input_forcings, num_partitions): + gdf_chunks = np.array_split(gdf, num_partitions) + one_timestep = input_forcings.isel(time=0).compute() + with multiprocessing.Pool() as pool: + args = [(one_timestep, gdf_chunk) for gdf_chunk in gdf_chunks] + catchments = pool.starmap(get_cell_weights, args) + + return pd.concat(catchments) + + def compute_zonal_stats( gdf: gpd.GeoDataFrame, merged_data: xr.Dataset, forcings_dir: Path ) -> None: @@ -100,102 +128,117 @@ def compute_zonal_stats( num_partitions = multiprocessing.cpu_count() - 1 if num_partitions > len(gdf): num_partitions = len(gdf) - gfd_chunks = np.array_split(gdf, multiprocessing.cpu_count() - 1) - one_timestep = merged_data.isel(time=0).compute() - with multiprocessing.Pool() as pool: - args = [(one_timestep, gdf_chunk) for gdf_chunk in gfd_chunks] - catchments = pool.starmap(get_cell_weights, args) - catchments = pd.concat(catchments) + catchments = get_cell_weights_parallel(gdf, merged_data, num_partitions) - variables = [ - "LWDOWN", - "PSFC", - "Q2D", - "RAINRATE", - "SWDOWN", - "T2D", - "U2D", - "V2D", - ] + variables = { + "LWDOWN": "DLWRF_surface", 
+ "PSFC": "PRES_surface", + "Q2D": "SPFH_2maboveground", + "RAINRATE": "precip_rate", + "SWDOWN": "DSWRF_surface", + "T2D": "TMP_2maboveground", + "U2D": "UGRD_10maboveground", + "V2D": "VGRD_10maboveground", + } results = [] + cat_chunks = np.array_split(catchments, num_partitions) + forcing_times = merged_data.time.values + + for variable in variables.keys(): + + if variable not in merged_data.data_vars: + logger.warning(f"Variable {variable} not in forcings, skipping") + continue + + # to make sure this fits in memory, we need to chunk the data + time_chunks = get_index_chunks(merged_data[variable]) + + for i, times in enumerate(time_chunks): + start, end = times + # select the chunk of time we want to process + data_chunk = merged_data[variable].isel(time=slice(start,end)) + # put it in shared memory + shm, shape, dtype = create_shared_memory(data_chunk) + times = data_chunk.time.values + # create a partial function to pass to the multiprocessing pool + partial_process_chunk = partial(process_chunk_shared,variable,times,shm.name,shape,dtype) + + logger.debug(f"Processing variable: {variable}") + # process the chunks of catchments in parallel + with multiprocessing.Pool(num_partitions) as pool: + variable_data = pool.map(partial_process_chunk, cat_chunks) + del partial_process_chunk + # clean up the shared memory + shm.close() + shm.unlink() + logger.debug(f"Processed variable: {variable}") + concatenated_da = xr.concat(variable_data, dim="catchment") + # delete the data to free up memory + del variable_data + logger.debug(f"Concatenated variable: {variable}") + # write this to disk now to save memory + # xarray will monitor memory usage, but it doesn't account for the shared memory used to store the raster + # This reduces memory usage by about 60% + concatenated_da.to_dataset(name=variable).to_netcdf(forcings_dir/ "temp" / f"{variable}_{i}.nc") + # Merge the chunks back together + datasets = [xr.open_dataset(forcings_dir / "temp" / f"{variable}_{i}.nc") for i in range(len(time_chunks))] + xr.concat(datasets, dim="time").to_netcdf(forcings_dir / f"{variable}.nc") + for file in forcings_dir.glob("temp/*.nc"): + file.unlink() + logger.info( + f"Forcing generation complete! 
Zonal stats computed in {time.time() - timer_start} seconds" + ) + write_outputs(forcings_dir, variables) - for variable in variables: - raster = merged_data[variable].values.reshape(merged_data[variable].shape[0], -1) - - # Create shared memory for the raster - shm, shared_raster = create_shared_memory(raster) - - cat_chunks = np.array_split(catchments, num_partitions) - times = merged_data.time.values - - partial_process_chunk = partial( - process_chunk_shared, - variable, - times, - shm.name, - shared_raster.shape, - shared_raster.dtype, - ) - - logger.debug(f"Processing variable: {variable}") - with multiprocessing.Pool(num_partitions) as pool: - variable_data = pool.map(partial_process_chunk, cat_chunks) - - # Clean up the shared memory - shm.close() - shm.unlink() - - logger.debug(f"Processed variable: {variable}") - concatenated_da = xr.concat(variable_data, dim="catchment") - logger.debug(f"Concatenated variable: {variable}") - results.append(concatenated_da.to_dataset(name=variable)) +def write_outputs(forcings_dir, variables): # Combine all variables into a single dataset + results = [xr.open_dataset(file) for file in forcings_dir.glob("*.nc")] final_ds = xr.merge(results) output_folder = forcings_dir / "by_catchment" - # Clear out any existing files - for file in output_folder.glob("*.csv"): - file.unlink() - final_ds = final_ds.rename_vars( - { - "LWDOWN": "DLWRF_surface", - "PSFC": "PRES_surface", - "Q2D": "SPFH_2maboveground", - "RAINRATE": "precip_rate", - "SWDOWN": "DSWRF_surface", - "T2D": "TMP_2maboveground", - "U2D": "UGRD_10maboveground", - "V2D": "VGRD_10maboveground", - } - ) + rename_dict = {} + for key, value in variables.items(): + if key in final_ds: + rename_dict[key] = value + final_ds = final_ds.rename_vars(rename_dict) final_ds = add_APCP_SURFACE_to_dataset(final_ds) - logger.info("Saving to disk") - # Save to disk - delayed_saves = [] - for catchment in final_ds.catchment.values: - catchment_ds = final_ds.sel(catchment=catchment) - csv_path = output_folder / f"{catchment}.csv" - delayed_save = dask.delayed(save_to_csv)(catchment_ds, csv_path) - delayed_saves.append(delayed_save) - logger.debug("Delayed saves created") - try: - client = Client.current() - except ValueError: - cluster = LocalCluster() - client = Client(cluster) - - dask.compute(*delayed_saves) + # this step halves the storage size of the forcings + for var in final_ds.data_vars: + final_ds[var] = final_ds[var].astype(np.float32) - logger.info( - f"Forcing generation complete! 
Zonal stats computed in {time.time() - timer_start} seconds" - ) - client.shutdown() + logger.info("Saving to disk") + # The format for the netcdf is to support a legacy format + # which is why it's a little "unorthodox" + # There are no coordinates, just dimensions, catchment ids are stored in a 1d data var + # and time is stored in a 2d data var with the same time array for every catchment + # time is stored as unix timestamps, units have to be set + + # add the catchment ids as a 1d data var + final_ds["ids"] = final_ds["catchment"].astype(str) + # time needs to be a 2d array of the same time array as unix timestamps for every catchment + with warnings.catch_warnings(action="ignore"): + time_array = ( + final_ds.time.astype("datetime64[s]").astype(np.int64).values // 10**9 + ) ## convert from ns to s + time_array = time_array.astype(np.int32) ## convert to int32 to save space + final_ds = final_ds.drop_vars(["catchment", "time"]) ## drop the original time and catchment vars + final_ds = final_ds.rename_dims({"catchment": "catchment-id"}) # rename the catchment dimension + # add the time as a 2d data var, yes this is wasting disk space. + final_ds["Time"] = (("catchment-id", "time"), [time_array for _ in range(len(final_ds["ids"]))]) + # set the time unit + final_ds["Time"].attrs["units"] = "s" + final_ds["Time"].attrs["epoch_start"] = "01/01/1970 00:00:00" # not needed but suppresses the ngen warning + + final_ds.to_netcdf(output_folder / "forcings.nc", engine="netcdf4") + # delete the individual variable files + for file in forcings_dir.glob("*.nc"): + file.unlink() def setup_directories(cat_id: str) -> file_paths: @@ -212,8 +255,6 @@ def create_forcings(start_time: str, end_time: str, output_folder_name: str) -> gdf = gpd.read_file(forcing_paths.geopackage_path, layer="divides").to_crs(projection) logger.debug(f"gdf bounds: {gdf.total_bounds}") - logger.debug(gdf) - logger.debug("Got gdf") if type(start_time) == datetime: start_time = start_time.strftime("%Y-%m-%d %H:%M") diff --git a/modules/data_processing/gpkg_utils.py b/modules/data_processing/gpkg_utils.py index 2569f59..6339b39 100644 --- a/modules/data_processing/gpkg_utils.py +++ b/modules/data_processing/gpkg_utils.py @@ -352,11 +352,11 @@ def get_cat_from_gage_id(gage_id: str, gpkg: Path = file_paths.conus_hydrofabric # use flowpath_attributes instead # both have errors, cross reference them with sqlite3.connect(gpkg) as con: - sql_query = f"""SELECT f.id - FROM flowpaths AS f + sql_query = f"""SELECT f.id + FROM flowpaths AS f JOIN hydrolocations AS h ON f.toid = h.id JOIN flowpath_attributes AS fa ON f.id = fa.id - WHERE h.hl_uri = 'Gages-{gage_id}' + WHERE h.hl_uri = 'Gages-{gage_id}' AND fa.rl_gages LIKE '%{gage_id}%'""" result = con.execute(sql_query).fetchall() if len(result) == 0: @@ -426,7 +426,6 @@ def get_cat_to_nhd_feature_id(gpkg: Path = file_paths.conus_hydrofabric) -> dict result = conn.execute(sql_query).fetchall() mapping = {} - print(result) for cat, feature in result: # the ids are stored as floats this converts to int to match nwm output # numeric ids should be stored as strings. 
diff --git a/modules/data_processing/zarr_utils.py b/modules/data_processing/zarr_utils.py index fd25f72..a5d4409 100644 --- a/modules/data_processing/zarr_utils.py +++ b/modules/data_processing/zarr_utils.py @@ -9,10 +9,11 @@ import geopandas as gpd from data_processing.file_paths import file_paths import time +from fsspec.mapping import FSMap logger = logging.getLogger(__name__) -def open_s3_store(url: str) -> s3fs.S3Map: +def open_s3_store(url: str) -> FSMap: """Open an s3 store from a given url.""" return s3fs.S3Map(url, s3=s3fs.S3FileSystem(anon=True)) @@ -74,6 +75,14 @@ def compute_store(stores: xr.Dataset, cached_nc_path: Path) -> xr.Dataset: if os.path.exists(temp_path): os.remove(temp_path) + ## Drop crs that's included with one of the datasets + stores = stores.drop_vars("crs") + + ## Cast every single variable to float32 to save space to save a lot of memory issues later + ## easier to do it now in this slow download step than later in the steps without dask + for var in stores.data_vars: + stores[var] = stores[var].astype("float32") + client = Client.current() future = client.compute(stores.to_netcdf(temp_path, compute=False)) # Display progress bar diff --git a/modules/data_sources/ngen-realization-template.json b/modules/data_sources/ngen-realization-template.json index 6b6d7b4..9a27086 100644 --- a/modules/data_sources/ngen-realization-template.json +++ b/modules/data_sources/ngen-realization-template.json @@ -1,98 +1,96 @@ { - "global": { - "formulations": [ + "global": { + "formulations": [ + { + "name": "bmi_multi", + "params": { + "name": "bmi_multi", + "model_type_name": "bmi_multi", + "main_output_variable": "Q_OUT", + "forcing_file": "", + "init_config": "", + "allow_exceed_end_time": true, + "modules": [ { - "name": "bmi_multi", - "params": { - "name": "bmi_multi", - "model_type_name": "bmi_multi", - "main_output_variable": "Q_OUT", - "forcing_file": "", - "init_config": "", - "allow_exceed_end_time": true, - "modules": [ - { - "name": "bmi_c++", - "params": { - "name": "bmi_c++", - "model_type_name": "SLOTH", - "main_output_variable": "z", - "init_config": "/dev/null", - "allow_exceed_end_time": true, - "fixed_time_step": false, - "uses_forcing_file": false, - "model_params": { - "sloth_ice_fraction_schaake(1,double,m,node)": 0.0, - "sloth_ice_fraction_xinanjiang(1,double,1,node)": 0.0, - "sloth_soil_moisture_profile(1,double,1,node)": 0.0 - }, - "library_file": "/dmod/shared_libs/libslothmodel.so", - "registration_function": "none" - } - }, - { - "name": "bmi_fortran", - "params": { - "name": "bmi_fortran", - "model_type_name": "NoahOWP", - "library_file": "/dmod/shared_libs/libsurfacebmi.so", - "forcing_file": "", - "init_config": "./config/cat_config/NOAH-OWP-M/{{id}}.input", - "allow_exceed_end_time": true, - "main_output_variable": "QINSUR", - "variables_names_map": { - "PRCPNONC": "precip_rate", - "Q2": "SPFH_2maboveground", - "SFCTMP": "TMP_2maboveground", - "UU": "UGRD_10maboveground", - "VV": "VGRD_10maboveground", - "LWDN": "DLWRF_surface", - "SOLDN": "DSWRF_surface", - "SFCPRS": "PRES_surface" - }, - "uses_forcing_file": false - } - }, - { - "name": "bmi_c", - "params": { - "name": "bmi_c", - "model_type_name": "CFE", - "main_output_variable": "Q_OUT", - "init_config": "./config/cat_config/CFE/{{id}}.ini", - "allow_exceed_end_time": true, - "fixed_time_step": false, - "uses_forcing_file": false, - "registration_function": "register_bmi_cfe", - "variables_names_map": { - "water_potential_evaporation_flux": "EVAPOTRANS", - 
"atmosphere_water__liquid_equivalent_precipitation_rate": "QINSUR", - "ice_fraction_schaake": "sloth_ice_fraction_schaake", - "ice_fraction_xinanjiang": "sloth_ice_fraction_xinanjiang", - "soil_moisture_profile": "sloth_soil_moisture_profile" - }, - "library_file": "/dmod/shared_libs/libcfebmi.so.1.0.0" - } - } - ], - "uses_forcing_file": false - } + "name": "bmi_c++", + "params": { + "name": "bmi_c++", + "model_type_name": "SLOTH", + "main_output_variable": "z", + "init_config": "/dev/null", + "allow_exceed_end_time": true, + "fixed_time_step": false, + "uses_forcing_file": false, + "model_params": { + "sloth_ice_fraction_schaake(1,double,m,node)": 0.0, + "sloth_ice_fraction_xinanjiang(1,double,1,node)": 0.0, + "sloth_soil_moisture_profile(1,double,1,node)": 0.0 + }, + "library_file": "/dmod/shared_libs/libslothmodel.so", + "registration_function": "none" + } + }, + { + "name": "bmi_fortran", + "params": { + "name": "bmi_fortran", + "model_type_name": "NoahOWP", + "library_file": "/dmod/shared_libs/libsurfacebmi.so", + "forcing_file": "", + "init_config": "./config/cat_config/NOAH-OWP-M/{{id}}.input", + "allow_exceed_end_time": true, + "main_output_variable": "QINSUR", + "variables_names_map": { + "PRCPNONC": "precip_rate", + "Q2": "SPFH_2maboveground", + "SFCTMP": "TMP_2maboveground", + "UU": "UGRD_10maboveground", + "VV": "VGRD_10maboveground", + "LWDN": "DLWRF_surface", + "SOLDN": "DSWRF_surface", + "SFCPRS": "PRES_surface" + }, + "uses_forcing_file": false + } + }, + { + "name": "bmi_c", + "params": { + "name": "bmi_c", + "model_type_name": "CFE", + "main_output_variable": "Q_OUT", + "init_config": "./config/cat_config/CFE/{{id}}.ini", + "allow_exceed_end_time": true, + "fixed_time_step": false, + "uses_forcing_file": false, + "registration_function": "register_bmi_cfe", + "variables_names_map": { + "water_potential_evaporation_flux": "EVAPOTRANS", + "atmosphere_water__liquid_equivalent_precipitation_rate": "QINSUR", + "ice_fraction_schaake": "sloth_ice_fraction_schaake", + "ice_fraction_xinanjiang": "sloth_ice_fraction_xinanjiang", + "soil_moisture_profile": "sloth_soil_moisture_profile" + }, + "library_file": "/dmod/shared_libs/libcfebmi.so.1.0.0" + } } - ], - "forcing": { - "file_pattern": "{{id}}.csv", - "path": "./forcings/by_catchment/", - "provider": "CsvPerFeature" + ], + "uses_forcing_file": false } - }, - "time": { - "start_time": "2010-01-01 00:00:00", - "end_time": "2010-02-23 00:00:00", - "output_interval": 3600, - "nts": 15264.0 - }, - "routing": { - "t_route_config_file_with_path": "/ngen/ngen/data/config/troute.yaml" - }, - "output_root": "/ngen/ngen/data/outputs/ngen" -} \ No newline at end of file + } + ], + "forcing": { + "path": "./forcings/by_catchment/forcings.nc", + "provider": "NetCDF" + } + }, + "time": { + "start_time": "2010-01-01 00:00:00", + "end_time": "2010-02-23 00:00:00", + "output_interval": 3600 + }, + "routing": { + "t_route_config_file_with_path": "/ngen/ngen/data/config/troute.yaml" + }, + "output_root": "/ngen/ngen/data/outputs/ngen" +} diff --git a/modules/data_sources/ngen-routing-template.yaml b/modules/data_sources/ngen-routing-template.yaml index af49a5c..fc948bf 100644 --- a/modules/data_sources/ngen-routing-template.yaml +++ b/modules/data_sources/ngen-routing-template.yaml @@ -36,7 +36,7 @@ compute_parameters: parallel_compute_method: by-subnetwork-jit-clustered # serial compute_kernel: V02-structured assume_short_ts: True - subnetwork_target_size: 10000 + subnetwork_target_size: 50 cpu_pool: {cpu_pool} restart_parameters: #---------- @@ 
-62,7 +62,7 @@ compute_parameters:
         qlat_input_folder: ./outputs/ngen/
         qlat_file_pattern_filter: "nex-*"
 
-        binary_nexus_file_folder: ./outputs/parquet/ # this is required if nexus_file_pattern_filter="nex-*"
+        #binary_nexus_file_folder: ./outputs/parquet/ # if nexus_file_pattern_filter="nex-*" and you want it to reformat them as parquet, you need this
         #coastal_boundary_input_file : channel_forcing/schout_1.nc
         nts: {nts} #288 for 1day
         max_loop_size: {max_loop_size} # [number of timesteps]
diff --git a/modules/ngiab_data_cli/__main__.py b/modules/ngiab_data_cli/__main__.py
index 48f3938..f65a9c1 100644
--- a/modules/ngiab_data_cli/__main__.py
+++ b/modules/ngiab_data_cli/__main__.py
@@ -4,6 +4,8 @@
 from typing import List
 import subprocess
 
+from dask.distributed import Client
+
 from data_processing.file_paths import file_paths
 from data_processing.gpkg_utils import get_catid_from_point, get_cat_from_gage_id
 from data_processing.subset import subset
@@ -74,7 +76,7 @@ def set_dependent_flags(args, paths: file_paths):
 
     # realization and forcings require subset to have been run at least once
     if args.realization or args.forcings:
-        if not paths.subset_dir.exists():
+        if not paths.subset_dir.exists() and not args.subset:
             logging.warning(
                 "Subset required for forcings and realization generation, enabling subset."
             )
@@ -84,6 +86,7 @@ def set_dependent_flags(args, paths: file_paths):
             raise ValueError(
                 "Both --start and --end are required for forcings generation or realization creation. YYYY-MM-DD"
             )
+
     return args
 
 
@@ -138,6 +141,14 @@ def main() -> None:
         create_realization(output_folder, start_time=args.start_date, end_time=args.end_date)
         logging.info("Realization creation complete.")
 
+    # check if a dask client is still running and close it
+    try:
+        client = Client.current()
+        client.close()
+    except ValueError:
+        # value error is raised if no client is running
+        pass
+
     if args.run:
         logging.info("Running Next Gen using NGIAB...")
         # open the partitions.json file and get the number of partitions
@@ -149,8 +160,8 @@ def main() -> None:
         except:
             logging.error("Docker is not running, please start Docker and try again.")
         try:
-            # right now this expects a local hardcoded image name, while this is still a hidden feature it's fine
-            command = f'docker run --rm -it -v "{str(paths.subset_dir)}:/ngen/ngen/data" joshcu/ngiab_datastream /ngen/ngen/data/ auto {num_partitions}'
+            # command = f'docker run --rm -it -v "{str(paths.subset_dir)}:/ngen/ngen/data" prod_test /ngen/ngen/data/ auto {num_partitions}'
+            command = f'docker run --rm -it -v "{str(paths.subset_dir)}:/ngen/ngen/data" awiciroh/ciroh-ngen-image:latest-x86 /ngen/ngen/data/ auto {num_partitions}'
             subprocess.run(command, shell=True)
             logging.info("Next Gen run complete.")
         except:
@@ -173,12 +184,20 @@ def main() -> None:
             if plot:
                 logging.info("Plotting enabled")
             logging.info("Evaluating model performance...")
-            evaluate_folder(paths.subset_dir, plot=plot)
+            evaluate_folder(paths.subset_dir, plot=plot, debug=args.debug)
         except ImportError:
             logging.error(
                 "Evaluation module not found. Please install the ngiab_eval package to evaluate model performance."
            )
 
+    if args.vis:
+        try:
+            command = f'docker run --rm -it -p 3000:3000 -v "{str(paths.subset_dir)}:/ngen/ngen/data/" joshcu/ngiab_grafana:0.1.0'
+            subprocess.run(command, shell=True)
+            logging.info("Visualization complete.")
+        except:
+            logging.error("Visualization failed.")
+
     logging.info("All operations completed successfully.")
     logging.info(f"Output folder: file:///{paths.subset_dir}")
     # set logging to ERROR level only as dask distributed can clutter the terminal with INFO messages
diff --git a/modules/ngiab_data_cli/arguments.py b/modules/ngiab_data_cli/arguments.py
index 723cc47..a615815 100644
--- a/modules/ngiab_data_cli/arguments.py
+++ b/modules/ngiab_data_cli/arguments.py
@@ -83,6 +83,9 @@ def parse_arguments() -> argparse.Namespace:
     parser.add_argument(
         "--eval", action="store_true", help="Evaluate perforance of the model after running"
     )
+    parser.add_argument(
+        "--vis", "--visualise", action="store_true", help="Visualize the model output"
+    )
     parser.add_argument(
         "-a",
         "--all",
@@ -98,6 +101,10 @@
         args.realization = True
         args.run = True
         args.eval = True
+        args.vis = True
+
+    if args.vis:
+        args.eval = True
 
     if args.run:
         args.validate = True