diff --git a/swmmanywhere/geospatial_operations.py b/swmmanywhere/geospatial_operations.py index 8576e41b..6749f4c7 100644 --- a/swmmanywhere/geospatial_operations.py +++ b/swmmanywhere/geospatial_operations.py @@ -1,6 +1,9 @@ # -*- coding: utf-8 -*- """Created 2024-01-20. +A module containing functions to perform a variety of geospatial operations, +such as reprojecting coordinates and handling raster data. + @author: Barnaby Dobson """ from functools import lru_cache @@ -21,7 +24,7 @@ TransformerFromCRS = lru_cache(pyproj.transformer.Transformer.from_crs) -def get_utm_crs(x: float, +def get_utm_epsg(x: float, y: float, crs: str | int | pyproj.CRS = 'EPSG:4326', datum_name: str = "WGS 84"): @@ -141,7 +144,7 @@ def interpolate_points_on_raster(x: list[float], bounds_error=False, fill_value=None) # Interpolate for x,y - return [interp_wrap((y_, x_), interp, grid, values) for x_, y_ in zip(x,y)] + return [interp_with_nans((y_, x_), interp, grid, values) for x_, y_ in zip(x,y)] def reproject_raster(target_crs: str, fid: str, diff --git a/tests/test_geospatial.py b/tests/test_geospatial.py index cf320c62..fd70f724 100644 --- a/tests/test_geospatial.py +++ b/tests/test_geospatial.py @@ -16,8 +16,8 @@ from swmmanywhere import geospatial_operations as go -def test_interp_wrap(): - """Test the interp_wrap function.""" +def test_interp_with_nans(): + """Test the interp_interp_with_nans function.""" # Define a simple grid and values x = np.linspace(0, 1, 5) y = np.linspace(0, 1, 5) @@ -32,13 +32,13 @@ def test_interp_wrap(): # Test the function at a point inside the grid yx = (0.875, 0.875) - result = go.interp_wrap(yx, interp, grid, values) + result = go.interp_with_nans(yx, interp, grid, values) assert result == 0.875 # Test the function on a nan point values_grid[1][1] = np.nan yx = (0.251, 0.25) - result = go.interp_wrap(yx, interp, grid, values) + result = go.interp_with_nans(yx, interp, grid, values) assert result == values_grid[1][2] @patch('rasterio.open') @@ -70,11 +70,11 @@ def test_interpolate_points_on_raster(mock_rst_open): def test_get_utm(): """Test the get_utm_epsg function.""" # Test a northern hemisphere point - crs = go.get_utm_epsg(-1, 51) + crs = go.get_utm_epsg(-1.0, 51.0) assert crs == 'EPSG:32630' # Test a southern hemisphere point - crs = go.get_utm_epsg(-1, -51) + crs = go.get_utm_epsg(-1.0, -51.0) assert crs == 'EPSG:32730' def create_raster(fid):