diff --git a/dev-requirements.txt b/dev-requirements.txt index 19ab8947..c09b373c 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -2,7 +2,7 @@ # This file is autogenerated by pip-compile with Python 3.11 # by the following command: # -# pip-compile --extra=dev --output-file=dev-requirements.txt +# pip-compile --extra=dev --output-file=dev-requirements.txt pyproject.toml # aenum==3.1.11 # via @@ -11,7 +11,6 @@ aenum==3.1.11 affine==2.4.0 # via # pyflwdir - # pysheds # rasterio annotated-types==0.7.0 # via pydantic @@ -86,8 +85,6 @@ fsspec==2024.6.1 # via fastparquet geographiclib==2.0 # via geopy -geojson==3.1.0 - # via pysheds geopandas==1.0.1 # via # osmnx @@ -102,8 +99,6 @@ identify==2.6.0 # via pre-commit idna==3.8 # via requests -imageio==2.34.2 - # via scikit-image iniconfig==2.0.0 # via pytest joblib==1.4.2 @@ -116,8 +111,6 @@ jsonschema-specifications==2023.12.1 # via jsonschema julian==0.14 # via pyswmm -lazy-loader==0.4 - # via scikit-image llvmlite==0.43.0 # via numba loguru==0.7.2 @@ -138,20 +131,16 @@ networkx==3.3 # via # netcomp # osmnx - # scikit-image # swmmanywhere (pyproject.toml) nodeenv==1.9.1 # via pre-commit numba==0.60.0 - # via - # pyflwdir - # pysheds + # via pyflwdir numpy==1.26.4 # via # cftime # fastparquet # geopandas - # imageio # netcdf4 # netcomp # numba @@ -160,15 +149,12 @@ numpy==1.26.4 # pyarrow # pyflwdir # pyogrio - # pysheds # rasterio # rioxarray - # scikit-image # scipy # shapely # snuggs # swmmanywhere (pyproject.toml) - # tifffile # xarray osmnx==1.9.3 # via swmmanywhere (pyproject.toml) @@ -177,26 +163,19 @@ packaging==24.1 # build # fastparquet # geopandas - # lazy-loader # planetary-computer # pyogrio # pyswmm # pytest # rioxarray - # scikit-image # xarray pandas==2.2.2 # via # fastparquet # geopandas # osmnx - # pysheds # swmmanywhere (pyproject.toml) # xarray -pillow==10.4.0 - # via - # imageio - # scikit-image pip-tools==7.4.1 # via swmmanywhere (pyproject.toml) planetary-computer==1.0.0 @@ -224,14 +203,11 @@ pyparsing==3.1.2 pyproj==3.6.1 # via # geopandas - # pysheds # rioxarray pyproject-hooks==1.1.0 # via # build # pip-tools -pysheds==0.3.5 - # via swmmanywhere (pyproject.toml) pystac[validation]==1.10.1 # via # planetary-computer @@ -267,13 +243,14 @@ pytz==2024.1 # multiurl # pandas # planetary-computer +pywbt==0.1.1 + # via swmmanywhere (pyproject.toml) pyyaml==6.0.1 # via # pre-commit # swmmanywhere (pyproject.toml) rasterio==1.3.10 # via - # pysheds # rioxarray # swmmanywhere (pyproject.toml) referencing==0.35.1 @@ -296,14 +273,10 @@ rpds-py==0.19.1 # referencing ruff==0.5.5 # via swmmanywhere (pyproject.toml) -scikit-image==0.24.0 - # via pysheds scipy==1.14.0 # via # netcomp # pyflwdir - # pysheds - # scikit-image # swmmanywhere (pyproject.toml) shapely==2.0.5 # via @@ -320,8 +293,6 @@ snuggs==1.4.7 # via rasterio swmm-toolkit==0.15.5 # via pyswmm -tifffile==2024.7.2 - # via scikit-image toolz==0.12.1 # via cytoolz tqdm==4.66.4 diff --git a/doc-requirements.txt b/doc-requirements.txt index a6ba56e0..d4cc7566 100644 --- a/doc-requirements.txt +++ b/doc-requirements.txt @@ -2,7 +2,7 @@ # This file is autogenerated by pip-compile with Python 3.11 # by the following command: # -# pip-compile --extra=doc --output-file=doc-requirements.txt +# pip-compile --extra=doc --output-file=doc-requirements.txt pyproject.toml # aenum==3.1.11 # via @@ -11,7 +11,6 @@ aenum==3.1.11 affine==2.4.0 # via # pyflwdir - # pysheds # rasterio annotated-types==0.7.0 # via pydantic @@ -98,8 +97,6 @@ fsspec==2024.6.1 # via fastparquet geographiclib==2.0 # via geopy -geojson==3.1.0 - # via pysheds geopandas==1.0.1 # via # osmnx @@ -116,8 +113,6 @@ griffe==0.47.0 # via mkdocstrings-python idna==3.8 # via requests -imageio==2.34.2 - # via scikit-image ipykernel==6.29.5 # via mkdocs-jupyter ipython==8.26.0 @@ -156,8 +151,6 @@ jupyterlab-pygments==0.3.0 # via nbconvert jupytext==1.16.3 # via mkdocs-jupyter -lazy-loader==0.4 - # via scikit-image llvmlite==0.43.0 # via numba loguru==0.7.2 @@ -248,18 +241,14 @@ networkx==3.3 # via # netcomp # osmnx - # scikit-image # swmmanywhere (pyproject.toml) numba==0.60.0 - # via - # pyflwdir - # pysheds + # via pyflwdir numpy==1.26.4 # via # cftime # fastparquet # geopandas - # imageio # netcdf4 # netcomp # numba @@ -268,15 +257,12 @@ numpy==1.26.4 # pyarrow # pyflwdir # pyogrio - # pysheds # rasterio # rioxarray - # scikit-image # scipy # shapely # snuggs # swmmanywhere (pyproject.toml) - # tifffile # xarray osmnx==1.9.3 # via swmmanywhere (pyproject.toml) @@ -286,14 +272,12 @@ packaging==24.1 # geopandas # ipykernel # jupytext - # lazy-loader # mkdocs # nbconvert # planetary-computer # pyogrio # pyswmm # rioxarray - # scikit-image # xarray paginate==0.5.7 # via mkdocs-material @@ -302,7 +286,6 @@ pandas==2.2.2 # fastparquet # geopandas # osmnx - # pysheds # swmmanywhere (pyproject.toml) # xarray pandocfilters==1.5.1 @@ -311,10 +294,6 @@ parso==0.8.4 # via jedi pathspec==0.12.1 # via mkdocs -pillow==10.4.0 - # via - # imageio - # scikit-image planetary-computer==1.0.0 # via swmmanywhere (pyproject.toml) platformdirs==4.2.2 @@ -355,10 +334,7 @@ pyparsing==3.1.2 pyproj==3.6.1 # via # geopandas - # pysheds # rioxarray -pysheds==0.3.5 - # via swmmanywhere (pyproject.toml) pystac[validation]==1.10.1 # via # planetary-computer @@ -384,6 +360,8 @@ pytz==2024.1 # multiurl # pandas # planetary-computer +pywbt==0.1.1 + # via swmmanywhere (pyproject.toml) pywin32==306 # via jupyter-core pyyaml==6.0.1 @@ -402,7 +380,6 @@ pyzmq==26.0.3 # jupyter-client rasterio==1.3.10 # via - # pysheds # rioxarray # swmmanywhere (pyproject.toml) referencing==0.35.1 @@ -426,14 +403,10 @@ rpds-py==0.19.1 # via # jsonschema # referencing -scikit-image==0.24.0 - # via pysheds scipy==1.14.0 # via # netcomp # pyflwdir - # pysheds - # scikit-image # swmmanywhere (pyproject.toml) shapely==2.0.5 # via @@ -456,8 +429,6 @@ stack-data==0.6.3 # via ipython swmm-toolkit==0.15.5 # via pyswmm -tifffile==2024.7.2 - # via scikit-image tinycss2==1.3.0 # via nbconvert toolz==0.12.1 diff --git a/pyproject.toml b/pyproject.toml index 19d5dd05..98a9d170 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,16 +36,16 @@ dependencies = [ "netcdf4", "netcomp@ git+https://github.com/barneydobson/NetComp.git", "networkx>=3", - "numpy<2.0.0", + "numpy", "osmnx", "pandas", "planetary_computer", "pyarrow", "pydantic", "pyflwdir", - "pysheds==0.3.5", "pystac_client", "pyswmm", + "pywbt", "PyYAML", "rasterio", "rioxarray", diff --git a/requirements.txt b/requirements.txt index 9c8b1212..5f23ec4f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,7 +11,6 @@ aenum==3.1.11 affine==2.4.0 # via # pyflwdir - # pysheds # rasterio annotated-types==0.7.0 # via pydantic @@ -70,8 +69,6 @@ fsspec==2024.6.1 # via fastparquet geographiclib==2.0 # via geopy -geojson==3.1.0 - # via pysheds geopandas==1.0.1 # via # osmnx @@ -84,8 +81,6 @@ gitpython==3.1.43 # via swmmanywhere (pyproject.toml) idna==3.8 # via requests -imageio==2.34.2 - # via scikit-image joblib==1.4.2 # via swmmanywhere (pyproject.toml) jsonschema==4.23.0 @@ -96,8 +91,6 @@ jsonschema-specifications==2023.12.1 # via jsonschema julian==0.14 # via pyswmm -lazy-loader==0.4 - # via scikit-image llvmlite==0.43.0 # via numba loguru==0.7.2 @@ -112,18 +105,14 @@ networkx==3.3 # via # netcomp # osmnx - # scikit-image # swmmanywhere (pyproject.toml) numba==0.60.0 - # via - # pyflwdir - # pysheds + # via pyflwdir numpy==1.26.4 # via # cftime # fastparquet # geopandas - # imageio # netcdf4 # netcomp # numba @@ -132,15 +121,12 @@ numpy==1.26.4 # pyarrow # pyflwdir # pyogrio - # pysheds # rasterio # rioxarray - # scikit-image # scipy # shapely # snuggs # swmmanywhere (pyproject.toml) - # tifffile # xarray osmnx==1.9.3 # via swmmanywhere (pyproject.toml) @@ -148,25 +134,18 @@ packaging==24.1 # via # fastparquet # geopandas - # lazy-loader # planetary-computer # pyogrio # pyswmm # rioxarray - # scikit-image # xarray pandas==2.2.2 # via # fastparquet # geopandas # osmnx - # pysheds # swmmanywhere (pyproject.toml) # xarray -pillow==10.4.0 - # via - # imageio - # scikit-image planetary-computer==1.0.0 # via swmmanywhere (pyproject.toml) pyarrow==16.1.0 @@ -186,10 +165,7 @@ pyparsing==3.1.2 pyproj==3.6.1 # via # geopandas - # pysheds # rioxarray -pysheds==0.3.5 - # via swmmanywhere (pyproject.toml) pystac[validation]==1.10.1 # via # planetary-computer @@ -213,11 +189,12 @@ pytz==2024.1 # multiurl # pandas # planetary-computer +pywbt==0.1.1 + # via swmmanywhere (pyproject.toml) pyyaml==6.0.1 # via swmmanywhere (pyproject.toml) rasterio==1.3.10 # via - # pysheds # rioxarray # swmmanywhere (pyproject.toml) referencing==0.35.1 @@ -238,14 +215,10 @@ rpds-py==0.19.1 # via # jsonschema # referencing -scikit-image==0.24.0 - # via pysheds scipy==1.14.0 # via # netcomp # pyflwdir - # pysheds - # scikit-image # swmmanywhere (pyproject.toml) shapely==2.0.5 # via @@ -262,8 +235,6 @@ snuggs==1.4.7 # via rasterio swmm-toolkit==0.15.5 # via pyswmm -tifffile==2024.7.2 - # via scikit-image toolz==0.12.1 # via cytoolz tqdm==4.66.4 diff --git a/swmmanywhere/filepaths.py b/swmmanywhere/filepaths.py index dabeedde..d3099996 100644 --- a/swmmanywhere/filepaths.py +++ b/swmmanywhere/filepaths.py @@ -289,7 +289,7 @@ def filepaths_from_yaml(f: Path): """Get file paths from a yaml file.""" address_dict = yaml_load(f.read_text()) address_dict["base_dir"] = Path(address_dict["base_dir"]) - overrides = address_dict.pop("overrides") + overrides = address_dict.pop("overrides", {}) addresses = FilePaths(**address_dict, **overrides) return addresses diff --git a/swmmanywhere/geospatial_utilities.py b/swmmanywhere/geospatial_utilities.py index e94071b0..c990d970 100644 --- a/swmmanywhere/geospatial_utilities.py +++ b/swmmanywhere/geospatial_utilities.py @@ -8,9 +8,9 @@ import itertools import json import math -import operator import os -from copy import deepcopy +import shutil +import tempfile from functools import lru_cache from pathlib import Path from typing import List, Optional @@ -23,11 +23,11 @@ import rasterio as rst import rioxarray import shapely +from pywbt import whitebox_tools from rasterio import features from scipy.interpolate import RegularGridInterpolator from scipy.spatial import KDTree from shapely import geometry as sgeom -from shapely import ops as sops from shapely.strtree import STRtree from tqdm.auto import tqdm @@ -35,8 +35,6 @@ os.environ["NUMBA_NUM_THREADS"] = "1" import pyflwdir # noqa: E402 -import pysheds # noqa: E402 -from pysheds import grid as pgrid # noqa: E402 TransformerFromCRS = lru_cache(pyproj.transformer.Transformer.from_crs) @@ -368,117 +366,6 @@ def burn_shape_in_raster( dest.write(data, 1) -def condition_dem( - grid: pysheds.sgrid.sGrid, dem: pysheds.sview.Raster -) -> pysheds.sview.Raster: - """Condition a DEM with pysheds. - - Args: - grid (pysheds.sgrid.sGrid): The grid object. - dem (pysheds.sview.Raster): The input DEM. - - Returns: - pysheds.sview.Raster: The conditioned DEM. - """ - # Fill pits, depressions, and resolve flats in the DEM - pit_filled_dem = grid.fill_pits(dem) - flooded_dem = grid.fill_depressions(pit_filled_dem) - inflated_dem = grid.resolve_flats(flooded_dem) - - return inflated_dem - - -def compute_flow_directions( - grid: pysheds.sgrid.sGrid, inflated_dem: pysheds.sview.Raster -) -> tuple[pysheds.sview.Raster, tuple]: - """Compute flow directions. - - Args: - grid (pysheds.sgrid.sGrid): The grid object. - inflated_dem (pysheds.sview.Raster): The input DEM. - - Returns: - pysheds.sview.Raster: Flow directions. - tuple: Direction mapping. - """ - dirmap = (64, 128, 1, 2, 4, 8, 16, 32) - flow_dir = grid.flowdir(inflated_dem, dirmap=dirmap) - return flow_dir, dirmap - - -def calculate_flow_accumulation( - grid: pysheds.sgrid.sGrid, flow_dir: pysheds.sview.Raster, dirmap: tuple -) -> pysheds.sview.Raster: - """Calculate flow accumulation. - - Args: - grid (pysheds.sgrid.sGrid): The grid object. - flow_dir (pysheds.sview.Raster): Flow directions. - dirmap (tuple): Direction mapping. - - Returns: - pysheds.sview.Raster: Flow accumulations. - """ - flow_acc = grid.accumulation(flow_dir, dirmap=dirmap) - return flow_acc - - -def delineate_catchment( - grid: pysheds.sgrid.sGrid, - flow_acc: pysheds.sview.Raster, - flow_dir: pysheds.sview.Raster, - dirmap: tuple, - G: nx.Graph, -) -> gpd.GeoDataFrame: - """Delineate catchments. - - Args: - grid (pysheds.sgrid.Grid): The grid object. - flow_acc (pysheds.sview.Raster): Flow accumulations. - flow_dir (pysheds.sview.Raster): Flow directions. - dirmap (tuple): Direction mapping. - G (nx.Graph): The input graph with nodes containing 'x' and 'y'. - - Returns: - gpd.GeoDataFrame: A GeoDataFrame containing polygons with columns: - 'geometry', 'area', and 'id'. Sorted by area in descending order. - """ - polys = [] - # Iterate over the nodes in the graph - for id, data in tqdm(G.nodes(data=True), total=len(G.nodes), disable=not verbose()): - # Snap the node to the nearest grid cell - x, y = data["x"], data["y"] - grid_ = deepcopy(grid) - x_snap, y_snap = grid_.snap_to_mask(flow_acc >= 0, (x, y)) - - # Delineate the catchment - catch = grid_.catchment( - x=x_snap, - y=y_snap, - fdir=flow_dir, - dirmap=dirmap, - xytype="coordinate", - algorithm="recursive", - ) - # n.b. recursive algorithm is not recommended, but crashes with a seg - # fault occasionally otherwise. - - grid_.clip_to(catch) - - # Polygonize the catchment - shapes = grid_.polygonize() - catchment_polygon = sops.unary_union( - [sgeom.shape(shape) for shape, value in shapes] - ) - - # Add the catchment to the list - polys.append( - {"id": id, "geometry": catchment_polygon, "area": catchment_polygon.area} - ) - polys.sort(key=operator.itemgetter("area"), reverse=True) - return gpd.GeoDataFrame(polys, crs=grid.crs) - - def remove_intersections(polys: gpd.GeoDataFrame) -> gpd.GeoDataFrame: """Remove intersections from a GeoDataFrame of polygons. @@ -573,14 +460,14 @@ def attach_unconnected_subareas( def calculate_slope( - polys_gdf: gpd.GeoDataFrame, grid: pysheds.sgrid.sGrid, cell_slopes: np.ndarray + polys_gdf: gpd.GeoDataFrame, grid: Grid, cell_slopes: np.ndarray ) -> gpd.GeoDataFrame: """Calculate the average slope of each polygon. Args: polys_gdf (gpd.GeoDataFrame): A GeoDataFrame containing polygons with columns: 'geometry', 'area', and 'id'. - grid (pysheds.sgrid.sGrid): The grid object. + grid (Grid): Information of the raster (affine, shape, crs, bbox) cell_slopes (np.ndarray): The slopes of each cell in the grid. Returns: @@ -633,7 +520,7 @@ def vectorize( def delineate_catchment_pyflwdir( - grid: pysheds.sgrid.sGrid, flow_dir: pysheds.sview.Raster, G: nx.Graph + grid: Grid, flow_dir: np.array, G: nx.Graph ) -> gpd.GeoDataFrame: """Derive subcatchments from the nodes on a graph and a DEM. @@ -641,8 +528,8 @@ def delineate_catchment_pyflwdir( faster than delineate_catchment. Args: - grid (pysheds.sgrid.Grid): The grid object. - flow_dir (pysheds.sview.Raster): Flow directions. + grid (Grid): Information of the raster (affine, shape, crs, bbox). + flow_dir (np.array): Flow directions. G (nx.Graph): The input graph with nodes containing 'x' and 'y'. Returns: @@ -691,7 +578,7 @@ def derive_subbasins_streamorder( gpd.GeoDataFrame: A GeoDataFrame containing polygons. """ # Load and process the DEM - grid, flow_dir, _, _ = load_and_process_dem(fid) + grid, flow_dir, _ = load_and_process_dem(fid) flw = pyflwdir.from_array( flow_dir, @@ -732,67 +619,133 @@ def derive_subbasins_streamorder( return gdf_bas +def flwdir_whitebox(fid: Path) -> np.array: + """Calculate flow direction using WhiteboxTools. + + Args: + fid (Path): Filepath to the DEM. + + Returns: + np.array: Flow directions. + """ + # Initialize WhiteboxTools + with tempfile.TemporaryDirectory(dir=str(fid.parent)) as temp_dir: + temp_path = Path(temp_dir) + + # Copy raster to working directory + dem = temp_path / "dem.tif" + shutil.copy(fid, dem) + + # Condition + wbt_args = { + "BreachDepressions": ["-i=dem.tif", "--fillpits", "-o=dem_corr.tif"], + "D8Pointer": ["-i=dem_corr.tif", "-o=fdir.tif"], + } + whitebox_tools( + wbt_args, + work_dir=temp_path, + verbose=verbose(), + wbt_root=temp_path / "WBT", + max_procs=1, + ) + + fdir = temp_path / "fdir.tif" + if not Path(fdir).exists(): + raise ValueError("Flow direction raster not created.") + + with rst.open(fdir) as src: + flow_dir = src.read(1) + + # Adjust mapping from WhiteboxTools to pyflwdir + mapping = {1: 128, 2: 1, 4: 2, 8: 4, 16: 8, 32: 16, 64: 32, 128: 64} + get_flow_dir = np.vectorize(mapping.get, excluded=["default"]) + flow_dir = get_flow_dir(flow_dir, 0) + return flow_dir + + +class Grid: + """A class to represent a grid.""" + + def __init__(self, affine: rst.Affine, shape: tuple, crs: int, bbox: tuple): + """Initialize the Grid class. + + Args: + affine (rst.Affine): The affine transformation. + shape (tuple): The shape of the grid. + crs (int): The CRS of the grid. + bbox (tuple): The bounding box of the grid. + """ + self.affine = affine + self.shape = shape + self.crs = crs + self.bbox = bbox + + def load_and_process_dem( fid: Path, -) -> tuple[pysheds.sgrid.sGrid, pysheds.sview.Raster, tuple, pysheds.sview.Raster]: + method: str = "whitebox", +) -> tuple[Grid, np.array, np.array]: """Load and condition a DEM. Args: fid (Path): Filepath to the DEM. + method (str, optional): The method to use for conditioning. Defaults to + "whitebox". Returns: - tuple: A tuple containing the grid, flow directions, direction mapping, - and cell slopes. + tuple: A tuple containing the grid, flow directions, and cell slopes. """ - # Initialise pysheds grids - grid = pgrid.Grid.from_raster(str(fid)) - dem = grid.read_raster(str(fid)) + with rst.open(fid, "r") as src: + elevtn = src.read(1).astype(float) + nodata = float(src.nodata) + transform = src.transform + crs = src.crs + + if method not in ("whitebox", "pyflwdir"): + raise ValueError("Method must be 'whitebox' or 'pyflwdir'.") - # Condition the DEM - inflated_dem = condition_dem(grid, dem) + if method == "whitebox": + flow_dir = flwdir_whitebox(fid) + elif method == "pyflwdir": + flw = pyflwdir.from_dem( + data=elevtn, + nodata=nodata, + transform=transform, + latlon=crs.is_geographic, + ) + flow_dir = flw.to_array(ftype="d8").astype(int) - # Compute flow directions - flow_dir, dirmap = compute_flow_directions(grid, inflated_dem) + cell_slopes = pyflwdir.dem.slope( + elevtn, + nodata=nodata, + transform=transform, + latlon=crs.is_geographic, + ) - # Calculate slopes - cell_slopes = grid.cell_slopes(dem, flow_dir) + grid = Grid(transform, elevtn.shape, crs, src.bounds) - return grid, flow_dir, dirmap, cell_slopes + return grid, flow_dir, cell_slopes -def derive_subcatchments(G: nx.Graph, fid: Path, method="pyflwdir") -> gpd.GeoDataFrame: +def derive_subcatchments( + G: nx.Graph, fid: Path, method: str = "whitebox" +) -> gpd.GeoDataFrame: """Derive subcatchments from the nodes on a graph and a DEM. Args: G (nx.Graph): The input graph with nodes containing 'x' and 'y'. fid (Path): Filepath to the DEM. - method (str, optional): The method to use for delineating catchments. - Defaults to 'pyflwdir'. Can also be `pysheds` to use the old - method. + method (str, optional): The method to use for conditioning. Returns: gpd.GeoDataFrame: A GeoDataFrame containing polygons with columns: 'geometry', 'area', 'id', 'width', and 'slope'. """ - if method not in ["pyflwdir", "pysheds"]: - raise ValueError("Invalid method. Must be 'pyflwdir' or 'pysheds'.") - # Load and process the DEM - grid, flow_dir, dirmap, cell_slopes = load_and_process_dem(fid) + grid, flow_dir, cell_slopes = load_and_process_dem(fid, method) - if method == "pysheds": - # Calculate flow accumulations - flow_acc = calculate_flow_accumulation(grid, flow_dir, dirmap) - - # Delineate catchments - polys = delineate_catchment(grid, flow_acc, flow_dir, dirmap, G) - - # Remove intersections - result_polygons = remove_intersections(polys) - - elif method == "pyflwdir": - # Delineate catchments - result_polygons = delineate_catchment_pyflwdir(grid, flow_dir, G) + # Delineate catchments + result_polygons = delineate_catchment_pyflwdir(grid, flow_dir, G) # Convert to GeoDataFrame polys_gdf = result_polygons.dropna(subset=["geometry"]) diff --git a/swmmanywhere/logging.py b/swmmanywhere/logging.py index 29af226b..cae70f24 100644 --- a/swmmanywhere/logging.py +++ b/swmmanywhere/logging.py @@ -25,6 +25,11 @@ def verbose() -> bool: return os.getenv("SWMMANYWHERE_VERBOSE", "false").lower() == "true" +def set_verbose(verbose: bool): + """Set the verbosity.""" + os.environ["SWMMANYWHERE_VERBOSE"] = str(verbose).lower() + + def dynamic_filter(record): """A dynamic filter.""" return verbose() diff --git a/swmmanywhere/preprocessing.py b/swmmanywhere/preprocessing.py index e38b79df..47867c1b 100644 --- a/swmmanywhere/preprocessing.py +++ b/swmmanywhere/preprocessing.py @@ -32,7 +32,7 @@ def write_df(df: pd.DataFrame | gpd.GeoDataFrame, fid: Path): """ if fid.suffix in (".geoparquet", ".parquet"): df.to_parquet(fid) - elif fid.suffix == ".json": + elif fid.suffix in (".geojson", ".json"): if isinstance(df, gpd.GeoDataFrame): df.to_file(fid, driver="GeoJSON") else: diff --git a/swmmanywhere/swmmanywhere.py b/swmmanywhere/swmmanywhere.py index 5d48a756..d36fa942 100644 --- a/swmmanywhere/swmmanywhere.py +++ b/swmmanywhere/swmmanywhere.py @@ -78,6 +78,7 @@ def swmmanywhere(config: dict) -> tuple[Path, dict | None]: config["bbox"], config.get("bbox_number", None), config.get("model_number", None), + config.get("extension", "parquet"), **config.get("address_overrides", {}), ) @@ -326,7 +327,7 @@ def check_and_register_custom_graphfcns(config: dict): spec.loader.exec_module(custom_graphfcn_module) # Validate the import - validate_graphfcn_list(config["graphfcn_list"]) + validate_graphfcn_list(config.get("graphfcn_list", [])) return config @@ -357,7 +358,7 @@ def check_and_register_custom_metrics(config: dict): spec.loader.exec_module(custom_metric_module) # Validate metric list - validate_metric_list(config["metric_list"]) + validate_metric_list(config.get("metric_list", [])) return config diff --git a/tests/test_geospatial_utilities.py b/tests/test_geospatial_utilities.py index d70daf81..1ffdbe33 100644 --- a/tests/test_geospatial_utilities.py +++ b/tests/test_geospatial_utilities.py @@ -15,6 +15,7 @@ from swmmanywhere import geospatial_utilities as go from swmmanywhere import graph_utilities as ge +from swmmanywhere.logging import set_verbose from swmmanywhere.misc.debug_derive_rc import derive_rc_alt @@ -228,27 +229,33 @@ def test_burn_shape_in_raster(): new_raster_fid.unlink(missing_ok=True) -def test_derive_subcatchments(street_network): +@pytest.mark.parametrize( + "method,area,slope,width", + [("pyflwdir", 2498, 0.1187, 28.202), ("whitebox", 2998, 0.1102, 30.894)], +) +@pytest.mark.parametrize("verbose", [True, False]) +def test_derive_subcatchments(street_network, method, area, slope, width, verbose): """Test the derive_subcatchments function.""" + set_verbose(verbose) + elev_fid = Path(__file__).parent / "test_data" / "elevation.tif" - for method in ["pysheds", "pyflwdir"]: - polys = go.derive_subcatchments(street_network, elev_fid, method=method) - assert "slope" in polys.columns - assert "area" in polys.columns - assert "geometry" in polys.columns - assert "id" in polys.columns - assert polys.shape[0] > 0 - assert polys.dropna().shape == polys.shape - assert polys.crs == street_network.graph["crs"] - - # Pyflwdir and pysheds catchment derivation aren't absolutely identical - assert almost_equal(polys.set_index("id").loc[2623975694, "area"], 1499, tol=1) - assert almost_equal( - polys.set_index("id").loc[2623975694, "slope"], 0.06145, tol=0.001 - ) - assert almost_equal( - polys.set_index("id").loc[2623975694, "width"], 21.845, tol=0.001 - ) + + polys = go.derive_subcatchments(street_network, elev_fid, method) + assert "slope" in polys.columns + assert "area" in polys.columns + assert "geometry" in polys.columns + assert "id" in polys.columns + assert polys.shape[0] > 0 + assert polys.dropna().shape == polys.shape + assert polys.crs == street_network.graph["crs"] + + assert almost_equal(polys.set_index("id").loc[2623975694, "area"], area, tol=1) + assert almost_equal( + polys.set_index("id").loc[2623975694, "slope"], slope, tol=0.001 + ) + assert almost_equal( + polys.set_index("id").loc[2623975694, "width"], width, tol=0.001 + ) def test_derive_rc(street_network): diff --git a/tests/test_graph_utilities.py b/tests/test_graph_utilities.py index 2822059a..0c0abbd9 100644 --- a/tests/test_graph_utilities.py +++ b/tests/test_graph_utilities.py @@ -690,7 +690,7 @@ def test_clip_to_catchments(street_network): G_ = gu.clip_to_catchments( G, addresses=addresses, subcatchment_derivation=subcatchment_derivation ) - assert len(G_.edges) == 9 + assert len(G_.edges) == 7 # Test default clipping streamorder subcatchment_derivation = parameters.SubcatchmentDerivation() @@ -698,7 +698,7 @@ def test_clip_to_catchments(street_network): G_ = gu.clip_to_catchments( G, addresses=addresses, subcatchment_derivation=subcatchment_derivation ) - assert len(G_.edges) == 4 + assert len(G_.edges) == 2 # Test clipping subcatchment_derivation = parameters.SubcatchmentDerivation( @@ -709,7 +709,7 @@ def test_clip_to_catchments(street_network): G_ = gu.clip_to_catchments( G, addresses=addresses, subcatchment_derivation=subcatchment_derivation ) - assert len(G_.edges) == 30 + assert len(G_.edges) == 31 # Test clipping with different params subcatchment_derivation = parameters.SubcatchmentDerivation(