diff --git a/README.md b/README.md index 0bc89d95..1e707092 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,7 @@ # SWMManywhere - + [![Test and build](https://github.com/ImperialCollegeLondon/SWMManywhere/actions/workflows/ci.yml/badge.svg)](https://github.com/ImperialCollegeLondon/SWMManywhere/actions/workflows/ci.yml) + ## High level workflow overview diff --git a/dev-requirements.txt b/dev-requirements.txt index 7556ccb0..218d771f 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -143,9 +143,9 @@ numpy==1.26.3 # osmnx # pandas # pyarrow - # pygeos # pysheds # rasterio + # rioxarray # salib # scikit-image # scipy @@ -163,6 +163,7 @@ packaging==23.2 # geopandas # matplotlib # pytest + # rioxarray # scikit-image # xarray pandas==2.1.4 @@ -189,8 +190,6 @@ pre-commit==3.6.0 # via swmmanywhere (pyproject.toml) pyarrow==14.0.2 # via swmmanywhere (pyproject.toml) -pygeos==0.14 - # via swmmanywhere (pyproject.toml) pyparsing==3.1.1 # via # matplotlib @@ -199,6 +198,7 @@ pyproj==3.6.1 # via # geopandas # pysheds + # rioxarray pyproject-hooks==1.0.0 # via build pysheds==0.3.5 @@ -228,11 +228,14 @@ pyyaml==6.0.1 rasterio==1.3.9 # via # pysheds + # rioxarray # swmmanywhere (pyproject.toml) requests==2.31.0 # via # cdsapi # osmnx +rioxarray==0.15.1 + # via swmmanywhere (pyproject.toml) ruff==0.1.11 # via swmmanywhere (pyproject.toml) salib==1.4.7 @@ -278,7 +281,9 @@ virtualenv==20.24.5 wheel==0.41.3 # via pip-tools xarray==2023.12.0 - # via swmmanywhere (pyproject.toml) + # via + # rioxarray + # swmmanywhere (pyproject.toml) # The following packages are considered to be unsafe in a requirements file: # pip diff --git a/pyproject.toml b/pyproject.toml index 19ce1bdc..e317ac9c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,10 +27,10 @@ dependencies = [ # TODO definitely don't need all of these "osmnx", "pandas", "pyarrow", - "pygeos", "pysheds", "PyYAML", "rasterio", + "rioxarray", "SALib", "SciPy", "shapely", diff --git a/requirements.txt 
b/requirements.txt
index 8d88cf6d..120ed9d3 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -113,9 +113,9 @@ numpy==1.26.3
     #   osmnx
     #   pandas
     #   pyarrow
-    #   pygeos
     #   pysheds
     #   rasterio
+    #   rioxarray
     #   salib
     #   scikit-image
     #   scipy
@@ -131,6 +131,7 @@ packaging==23.2
     #   fastparquet
     #   geopandas
     #   matplotlib
+    #   rioxarray
     #   scikit-image
     #   xarray
 pandas==2.1.4
@@ -149,8 +150,6 @@ pillow==10.2.0
     #   scikit-image
 pyarrow==14.0.2
     # via swmmanywhere (pyproject.toml)
-pygeos==0.14
-    # via swmmanywhere (pyproject.toml)
 pyparsing==3.1.1
     # via
     #   matplotlib
@@ -159,6 +158,7 @@ pyproj==3.6.1
     # via
     #   geopandas
     #   pysheds
+    #   rioxarray
 pysheds==0.3.5
     # via swmmanywhere (pyproject.toml)
 python-dateutil==2.8.2
@@ -172,11 +172,14 @@ pyyaml==6.0.1
 rasterio==1.3.9
     # via
     #   pysheds
+    #   rioxarray
     #   swmmanywhere (pyproject.toml)
 requests==2.31.0
     # via
     #   cdsapi
     #   osmnx
+rioxarray==0.15.1
+    # via swmmanywhere (pyproject.toml)
 salib==1.4.7
     # via swmmanywhere (pyproject.toml)
 scikit-image==0.22.0
@@ -214,7 +217,9 @@ tzdata==2023.4
 urllib3==2.1.0
     # via requests
 xarray==2023.12.0
-    # via swmmanywhere (pyproject.toml)
+    # via
+    #   rioxarray
+    #   swmmanywhere (pyproject.toml)
 
 # The following packages are considered to be unsafe in a requirements file:
 # setuptools
diff --git a/swmmanywhere/geospatial_operations.py b/swmmanywhere/geospatial_operations.py
new file mode 100644
index 00000000..6749f4c7
--- /dev/null
+++ b/swmmanywhere/geospatial_operations.py
@@ -0,0 +1,364 @@
+# -*- coding: utf-8 -*-
+"""Created 2024-01-20.
+
+A module containing functions to perform a variety of geospatial operations,
+such as reprojecting coordinates and handling raster data.
+
+@author: Barnaby Dobson
+"""
+import os
+from functools import lru_cache
+from typing import Optional
+
+import geopandas as gpd
+import networkx as nx
+import numpy as np
+import pandas as pd
+import pyproj
+import rasterio as rst
+import rioxarray
+from rasterio import features
+from scipy.interpolate import RegularGridInterpolator
+from shapely import geometry as sgeom
+from shapely.strtree import STRtree
+
+# Cache transformers so that repeated requests for the same CRS pair reuse a
+# single object rather than constructing a new one each time.
+TransformerFromCRS = lru_cache(pyproj.transformer.Transformer.from_crs)
+
+
+def get_utm_epsg(x: float,
+                 y: float,
+                 crs: str | int | pyproj.CRS = 'EPSG:4326',
+                 datum_name: str = "WGS 84") -> str:
+    """Get the UTM CRS code for a given coordinate.
+
+    Note, this function is taken from GeoPandas and modified to use
+    for getting the UTM CRS code for a given coordinate.
+
+    Args:
+        x (float): Longitude in crs
+        y (float): Latitude in crs
+        crs (str | int | pyproj.CRS, optional): The CRS of the input
+            coordinates. Defaults to 'EPSG:4326'.
+        datum_name (str, optional): The datum name to use for the UTM CRS
+
+    Returns:
+        str: Formatted EPSG code for the UTM zone.
+
+    Raises:
+        TypeError: If x or y is not a float.
+        ValueError: If crs is invalid or the UTM query returns no match.
+
+    Example:
+        >>> get_utm_epsg(-0.1276, 51.5074)
+        'EPSG:32630'
+    """
+    if not isinstance(x, float) or not isinstance(y, float):
+        raise TypeError("x and y must be floats")
+
+    try:
+        crs = pyproj.CRS(crs)
+    except pyproj.exceptions.CRSError as err:
+        raise ValueError("Invalid CRS") from err
+
+    # ensure using geographic coordinates
+    if crs.is_geographic:
+        lon = x
+        lat = y
+    else:
+        transformer = TransformerFromCRS(crs, "EPSG:4326", always_xy=True)
+        lon, lat = transformer.transform(x, y)
+    utm_crs_list = pyproj.database.query_utm_crs_info(
+        datum_name=datum_name,
+        area_of_interest=pyproj.aoi.AreaOfInterest(
+            west_lon_degree=lon,
+            south_lat_degree=lat,
+            east_lon_degree=lon,
+            north_lat_degree=lat,
+        ),
+    )
+    if not utm_crs_list:
+        # Avoid an opaque IndexError when the query matches nothing
+        raise ValueError(f"No UTM CRS found for ({lon}, {lat})")
+    return f"{utm_crs_list[0].auth_name}:{utm_crs_list[0].code}"
+
+
+def interp_with_nans(xy: tuple[float,float],
+                     interp: RegularGridInterpolator,
+                     grid: np.ndarray,
+                     values: list[float]) -> float:
+    """Wrap the interpolation function to handle NaNs.
+
+    Picks the nearest non NaN grid point if the interpolated value is NaN,
+    otherwise returns the interpolated value.
+
+    Args:
+        xy (tuple): Coordinate of interest
+        interp (RegularGridInterpolator): The interpolator object.
+        grid (np.ndarray): List of xy coordinates of the grid points.
+        values (list): The list of values at each point in the grid.
+
+    Returns:
+        float: The interpolated value.
+
+    Raises:
+        ValueError: If every value in the grid is NaN.
+    """
+    # Call the interpolator, item() extracts the scalar whether the result
+    # is a 0-d or a one element array
+    val = float(np.asarray(interp(xy)).item())
+    if not np.isnan(val):
+        return val
+
+    # The value is NaN, so pick the nearest non NaN grid point. Visit the
+    # grid points in order of increasing distance from xy.
+    distances = np.linalg.norm(grid - xy, axis=1)
+    for index in np.argsort(distances):
+        if not np.isnan(values[index]):
+            return float(values[index])
+
+    raise ValueError("No non NaN values found in grid.")
+
+def interpolate_points_on_raster(x: list[float],
+                                 y: list[float],
+                                 elevation_fid: str) -> list[float]:
+    """Interpolate points on a raster.
+
+    Args:
+        x (list): X coordinates.
+        y (list): Y coordinates.
+        elevation_fid (str): Filepath to elevation raster.
+
+    Returns:
+        list[float]: Elevation at each (x, y) point.
+    """
+    with rst.open(elevation_fid) as src:
+        # Read the raster data
+        data = src.read(1).astype(float) # Assuming it's a single-band raster
+        if src.nodata is not None:
+            data[data == src.nodata] = np.nan
+
+        # Get the raster's coordinates. These must not reuse the names x and
+        # y, otherwise the points to interpolate are overwritten.
+        grid_x = np.linspace(src.bounds.left, src.bounds.right, src.width)
+        grid_y = np.linspace(src.bounds.bottom, src.bounds.top, src.height)
+
+    # Raster rows run from top to bottom, flip them to ascend with grid_y
+    data = np.flipud(data)
+
+    # Define grid as (y, x) pairs so that it matches both the axis order of
+    # the interpolator and the points passed to interp_with_nans
+    xx, yy = np.meshgrid(grid_x, grid_y)
+    grid = np.vstack([yy.ravel(), xx.ravel()]).T
+    values = data.ravel()
+
+    # Define interpolator
+    interp = RegularGridInterpolator((grid_y, grid_x),
+                                     data,
+                                     method='linear',
+                                     bounds_error=False,
+                                     fill_value=None)
+    # Interpolate for x,y
+    return [interp_with_nans((y_, x_), interp, grid, values)
+            for x_, y_ in zip(x, y)]
+
+def reproject_raster(target_crs: str,
+                     fid: str,
+                     new_fid: Optional[str] = None):
+    """Reproject a raster to a new CRS.
+
+    Args:
+        target_crs (str): Target CRS in EPSG format (e.g., EPSG:32630).
+        fid (str): Filepath to the raster to reproject.
+        new_fid (str, optional): Filepath to save the reprojected raster.
+            Defaults to None, which will just use fid with '_reprojected'
+            inserted before the file extension.
+    """
+    # Define the output filepath. Only the final extension is split off, so
+    # the name never collapses onto fid when the raster is not a '.tif'.
+    if new_fid is None:
+        root, ext = os.path.splitext(fid)
+        new_fid = f"{root}_reprojected{ext}"
+
+    # Open the raster
+    with rioxarray.open_rasterio(fid) as raster:
+        # Reproject the raster
+        reprojected = raster.rio.reproject(target_crs)
+
+        # Save the reprojected raster
+        reprojected.rio.to_raster(new_fid)
+
+def get_transformer(source_crs: str,
+                    target_crs: str) -> pyproj.Transformer:
+    """Get a transformer object for reprojection.
+
+    Args:
+        source_crs (str): Source CRS in EPSG format (e.g., EPSG:32630).
+        target_crs (str): Target CRS in EPSG format (e.g., EPSG:32630).
+
+    Returns:
+        pyproj.Transformer: Transformer object for reprojection.
+
+    Example:
+        >>> transformer = get_transformer('EPSG:4326', 'EPSG:32630')
+        >>> transformer.transform(-0.1276, 51.5074)
+        (699330.1106898375, 5710164.30300683)
+    """
+    return pyproj.Transformer.from_crs(source_crs,
+                                       target_crs,
+                                       always_xy=True)
+
+def reproject_df(df: pd.DataFrame,
+                 source_crs: str,
+                 target_crs: str) -> pd.DataFrame:
+    """Reproject the coordinates in a DataFrame.
+
+    Args:
+        df (pd.DataFrame): DataFrame with columns 'longitude' and 'latitude'.
+        source_crs (str): Source CRS in EPSG format (e.g., EPSG:4326).
+        target_crs (str): Target CRS in EPSG format (e.g., EPSG:32630).
+
+    Returns:
+        pd.DataFrame: A copy of df with the reprojected coordinates in the
+            columns 'x' and 'y'.
+    """
+    # Transform the coordinates
+    pts = gpd.points_from_xy(df["longitude"],
+                             df["latitude"],
+                             crs=source_crs).to_crs(target_crs)
+    df = df.copy()
+    df['x'] = pts.x
+    df['y'] = pts.y
+    return df
+
+def reproject_graph(G: nx.Graph,
+                    source_crs: str,
+                    target_crs: str) -> nx.Graph:
+    """Reproject the coordinates in a graph.
+
+    osmnx.projection.project_graph might be suitable if some other behaviour
+    needs to be captured, but it currently fails the tests so I will ignore for
+    now.
+
+    Args:
+        G (nx.Graph): Graph with nodes containing 'x' and 'y' properties.
+        source_crs (str): Source CRS in EPSG format (e.g., EPSG:4326).
+        target_crs (str): Target CRS in EPSG format (e.g., EPSG:32630).
+
+    Returns:
+        nx.Graph: Graph with nodes containing 'x' and 'y' properties.
+    """
+    # Create a PyProj transformer for CRS conversion
+    transformer = get_transformer(source_crs, target_crs)
+
+    # Create a new graph with the converted nodes and edges
+    G_new = G.copy()
+
+    # Convert and add nodes with 'x', 'y' properties
+    for node, data in G_new.nodes(data=True):
+        x, y = transformer.transform(data['x'], data['y'])
+        data['x'] = x
+        data['y'] = y
+
+    # Convert and add edges with 'geometry' property
+    for u, v, data in G_new.edges(data=True):
+        if 'geometry' in data:
+            data['geometry'] = sgeom.LineString(transformer.transform(x, y)
+                                      for x, y in data['geometry'].coords)
+        else:
+            data['geometry'] = sgeom.LineString([[G_new.nodes[u]['x'],
+                                                  G_new.nodes[u]['y']],
+                                                 [G_new.nodes[v]['x'],
+                                                  G_new.nodes[v]['y']]])
+
+    return G_new
+
+
+def nearest_node_buffer(points1: dict[str, sgeom.Point],
+                        points2: dict[str, sgeom.Point],
+                        threshold: float) -> dict:
+    """Find the nearest node within a given buffer threshold.
+
+    Args:
+        points1 (dict): A dictionary where keys are labels and values are
+            Shapely points geometries.
+        points2 (dict): A dictionary where keys are labels and values are
+            Shapely points geometries.
+        threshold (float): The maximum distance for a node to be considered
+            'nearest'. If no nodes are within this distance, the node is not
+            included in the output.
+
+    Returns:
+        dict: A dictionary where keys are labels from points1 and values are
+            labels from points2 of the nearest nodes within the threshold.
+    """
+    # Without candidate nodes there is nothing to match against
+    if not points2:
+        return {}
+
+    # Split points2 into aligned lists of labels and geometries, the spatial
+    # index refers to geometries by their position
+    labels2 = list(points2.keys())
+    geoms2 = list(points2.values())
+
+    # Create a spatial index
+    tree = STRtree(geoms2)
+
+    # Initialize an empty dictionary to store the matching nodes
+    matching = {}
+
+    # Iterate over points1
+    for key, geom in points1.items():
+        # Find the nearest node in the spatial index to the current geometry
+        nearest = tree.nearest(geom)
+
+        # If the nearest node is within the threshold, add it to the
+        # matching dictionary. Comparing the distance directly is exact,
+        # unlike a buffer which only approximates a circle.
+        if geom.distance(geoms2[nearest]) <= threshold:
+            matching[key] = labels2[nearest]
+
+    # Return the matching dictionary
+    return matching
+
+def burn_shape_in_raster(geoms: list[sgeom.LineString],
+                         depth: float,
+                         raster_fid: str,
+                         new_raster_fid: str):
+    """Burn a depth into a raster along a list of shapely geometries.
+
+    Args:
+        geoms (list): List of Shapely geometries.
+        depth (float): Depth to carve.
+        raster_fid (str): Filepath to input raster.
+        new_raster_fid (str): Filepath to save the carved raster.
+    """
+    shapes = [sgeom.mapping(geom) for geom in geoms]
+    with rst.open(raster_fid) as src:
+        # read data
+        data = src.read(1)
+        data = data.astype(float)
+        # Cells holding the nodata value are never carved
+        data_mask = data != src.nodata
+        if shapes:
+            # Rasterise every geometry in a single pass, with invert=True
+            # the mask is True for the cells covered by a geometry
+            bool_mask = features.geometry_mask(shapes,
+                                               out_shape=src.shape,
+                                               transform=src.transform,
+                                               invert=True)
+            # modify data
+            data[bool_mask & data_mask] -= depth
+        # Create a new GeoTIFF with modified values
+        with rst.open(new_raster_fid,
+                      'w',
+                      driver='GTiff',
+                      height=src.height,
+                      width=src.width,
+                      count=1,
+                      dtype=data.dtype,
+                      crs=src.crs,
+                      transform=src.transform,
+                      nodata = src.nodata) as dest:
+            dest.write(data, 1)
diff --git a/tests/test_geospatial.py b/tests/test_geospatial.py
new file mode 100644
index 00000000..fd70f724
--- /dev/null
+++ b/tests/test_geospatial.py
@@ -0,0 +1,220 @@
+# -*- coding: utf-8 -*-
+"""Created on Tue Oct 18 10:35:51 2022.
+
+@author: Barney
+"""
+
+import os
+from unittest.mock import MagicMock, patch
+
+import networkx as nx
+import numpy as np
+import rasterio as rst
+from scipy.interpolate import RegularGridInterpolator
+from shapely import geometry as sgeom
+
+from swmmanywhere import geospatial_operations as go
+
+
+def test_interp_with_nans():
+    """Test the interp_with_nans function."""
+    # Define a simple grid and values
+    x = np.linspace(0, 1, 5)
+    y = np.linspace(0, 1, 5)
+    xx, yy = np.meshgrid(x, y)
+    grid = np.vstack([xx.ravel(), yy.ravel()]).T
+    values = np.linspace(0, 1, 25)
+    values_grid = values.reshape(5, 5)
+
+    # Define an interpolator
+    interp = RegularGridInterpolator((x,y),
+                                     values_grid)
+
+    # Test the function at a point inside the grid
+    yx = (0.875, 0.875)
+    result = go.interp_with_nans(yx, interp, grid, values)
+    assert result == 0.875
+
+    # Test the function on a nan point
+    values_grid[1][1] = np.nan
+    yx = (0.251, 0.25)
+    result = go.interp_with_nans(yx, interp, grid, values)
+    assert result == values_grid[1][2]
+
+@patch('rasterio.open')
+def test_interpolate_points_on_raster(mock_rst_open):
+    """Test the interpolate_points_on_raster function."""
+    # Mock the raster file
+    mock_src = MagicMock()
+    mock_src.read.return_value = np.array([[1, 2], [3, 4]])
+    mock_src.bounds = MagicMock()
+    mock_src.bounds.left = 0
+    mock_src.bounds.right = 1
+    mock_src.bounds.bottom = 0
+    mock_src.bounds.top = 1
+    mock_src.width = 2
+    mock_src.height = 2
+    mock_src.nodata = None
+    mock_rst_open.return_value.__enter__.return_value = mock_src
+
+    # Define the x and y coordinates
+    x = [0.25, 0.75]
+    y = [0.25, 0.75]
+
+    # Call the function
+    result = go.interpolate_points_on_raster(x, y, 'fake_path')
+
+    # Rasters measure from the top, so the bottom-left corner holds 3 and the
+    # top-right corner holds 2. The points sit a quarter of the way in from
+    # those corners, giving the bilinear blends 2.75 and 2.25.
+    assert np.allclose(result, [2.75, 2.25])
+
+def test_get_utm():
+    """Test the get_utm_epsg function."""
+    # Test a northern hemisphere point
+    crs = go.get_utm_epsg(-1.0, 51.0)
+    assert crs == 'EPSG:32630'
+
+    # Test a southern hemisphere point
+    crs = go.get_utm_epsg(-1.0, -51.0)
+    assert crs == 'EPSG:32730'
+
+def create_raster(fid):
+    """Define a function to create a mock raster file."""
+    data = np.ones((100, 100))
+    transform = rst.transform.from_origin(0, 0, 0.1, 0.1)
+    with rst.open(fid,
+                  'w',
+                  driver='GTiff',
+                  height=100,
+                  width=100,
+                  count=1,
+                  dtype='uint8',
+                  crs='EPSG:4326',
+                  transform=transform) as src:
+        src.write(data, 1)
+
+def test_reproject_raster():
+    """Test the reproject_raster function."""
+    # Define the file names up front so that cleanup can always run
+    fid = 'test.tif'
+    new_fid = 'test_reprojected.tif'
+    try:
+        # Create a mock raster file
+        create_raster(fid)
+
+        # Define the input parameters
+        target_crs = 'EPSG:32630'
+
+        # Call the function
+        go.reproject_raster(target_crs, fid)
+
+        # Check if the reprojected file exists
+        assert os.path.exists(new_fid)
+
+        # Check if the reprojected file has the correct CRS
+        with rst.open(new_fid) as src:
+            assert src.crs.to_string() == target_crs
+    finally:
+        # Regardless of test outcome, delete the temp file
+        if os.path.exists(fid):
+            os.remove(fid)
+        if os.path.exists(new_fid):
+            os.remove(new_fid)
+
+
+def almost_equal(a, b, tol=1e-6):
+    """Check if two numbers are almost equal."""
+    return abs(a-b) < tol
+
+def test_get_transformer():
+    """Test the get_transformer function."""
+    # Test a northern hemisphere point
+    transformer = go.get_transformer('EPSG:4326', 'EPSG:32630')
+
+    initial_point = (-0.1276, 51.5074)
+    expected_point = (699330.1106898375, 5710164.30300683)
+    new_point = transformer.transform(*initial_point)
+    assert almost_equal(new_point[0],
+                        expected_point[0])
+    assert almost_equal(new_point[1],
+                        expected_point[1])
+
+def test_reproject_graph():
+    """Test the reproject_graph function."""
+    # Create a mock graph
+    G = nx.Graph()
+    G.add_node(1, x=0, y=0)
+    G.add_node(2, x=1, y=1)
+    G.add_edge(1, 2)
+    G.add_node(3, x=1, y=2)
+    G.add_edge(2, 3, geometry=sgeom.LineString([(1, 1), (1, 2)]))
+
+    # Define the input parameters
+    source_crs = 'EPSG:4326'
+    target_crs = 'EPSG:32630'
+
+    # Call the function
+    G_new = go.reproject_graph(G, source_crs, target_crs)
+
+    # Test node coordinates
+    assert almost_equal(G_new.nodes[1]['x'], 833978.5569194595)
+    assert almost_equal(G_new.nodes[1]['y'], 0)
+    assert almost_equal(G_new.nodes[2]['x'], 945396.6839773951)
+    assert almost_equal(G_new.nodes[2]['y'], 110801.83254625657)
+    assert almost_equal(G_new.nodes[3]['x'], 945193.8596723974)
+    assert almost_equal(G_new.nodes[3]['y'], 221604.0105092727)
+
+    # Test edge geometry
+    assert almost_equal(list(G_new[1][2]['geometry'].coords)[0][0],
+                        833978.5569194595)
+    assert almost_equal(list(G_new[2][3]['geometry'].coords)[0][0],
+                        945396.6839773951)
+
+def test_nearest_node_buffer():
+    """Test the nearest_node_buffer function."""
+    # Create mock dictionaries of points
+    points1 = {'a': sgeom.Point(0, 0), 'b': sgeom.Point(1, 1)}
+    points2 = {'c': sgeom.Point(0.5, 0.5), 'd': sgeom.Point(2, 2)}
+
+    # Define the input threshold
+    threshold = 1.0
+
+    # Call the function
+    matching = go.nearest_node_buffer(points1, points2, threshold)
+
+    # Check if the function returns the correct matching nodes
+    assert matching == {'a': 'c', 'b': 'c'}
+
+    # Check that an empty set of candidates gives no matches
+    assert go.nearest_node_buffer(points1, {}, threshold) == {}
+
+def test_burn_shape_in_raster():
+    """Test the burn_shape_in_raster function."""
+    # Create a mock geometry
+    geoms = [sgeom.LineString([(0, 0), (1, 1)]),
+             sgeom.Polygon([(0, 0), (1, 0), (1, 1), (0, 1)])]
+
+    # Define the input parameters
+    depth = 1.0
+    raster_fid = 'input.tif'
+    new_raster_fid = 'output.tif'
+    try:
+        create_raster(raster_fid)
+
+        # Call the function
+        go.burn_shape_in_raster(geoms, depth, raster_fid, new_raster_fid)
+
+        with rst.open(raster_fid) as src:
+            data_ = src.read(1)
+
+        # Open the new GeoTIFF file and check if it has been correctly modified
+        with rst.open(new_raster_fid) as src:
+            data = src.read(1)
+        assert (data != data_).any()
+    finally:
+        # Regardless of test outcome, delete the temp file
+        if os.path.exists(raster_fid):
+            os.remove(raster_fid)
+        if os.path.exists(new_raster_fid):
+            os.remove(new_raster_fid)