From f32489972b01e76ed97bfc154af7c2649605e690 Mon Sep 17 00:00:00 2001 From: vschaffn <152897324+vschaffn@users.noreply.github.com> Date: Tue, 3 Dec 2024 14:24:46 +0100 Subject: [PATCH] Add Initial shift Capability to DEM Coregistration (#650) --- doc/source/api.md | 8 +++ doc/source/coregistration.md | 9 +++- tests/test_coreg/test_workflows.py | 82 ++++++++++++++++++++++++++++-- xdem/coreg/workflows.py | 82 +++++++++++++++++++++++++----- 4 files changed, 163 insertions(+), 18 deletions(-) diff --git a/doc/source/api.md b/doc/source/api.md index e46532cc..085378ba 100644 --- a/doc/source/api.md +++ b/doc/source/api.md @@ -231,6 +231,14 @@ To build and pass your coregistration pipeline to {func}`~xdem.DEM.coregister_3d coreg.Coreg.meta ``` +#### Quick coregistration +```{eval-rst} +.. autosummary:: + :toctree: gen_modules/ + + coreg.workflows.dem_coregistration +``` + ### Affine coregistration #### Parent object (to define custom methods) diff --git a/doc/source/coregistration.md b/doc/source/coregistration.md index efc4e930..d0b6119c 100644 --- a/doc/source/coregistration.md +++ b/doc/source/coregistration.md @@ -55,7 +55,8 @@ my_coreg_pipeline = xdem.coreg.ICP() + xdem.coreg.NuthKaab() my_coreg_pipeline = xdem.coreg.NuthKaab() ``` -Then, coregistering a pair of elevation data can be done by calling {func}`xdem.DEM.coregister_3d` from the DEM that should be aligned. +Then, coregistering a pair of elevation data can be done by calling {func}`xdem.coreg.workflows.dem_coregistration`, or +{func}`xdem.DEM.coregister_3d` from the DEM that should be aligned. ```{code-cell} ipython3 :tags: [hide-cell] @@ -66,12 +67,18 @@ Then, coregistering a pair of elevation data can be done by calling {func}`xdem. import geoutils as gu import numpy as np import matplotlib.pyplot as plt +from xdem.coreg.workflows import dem_coregistration # Open a reference and to-be-aligned DEM ref_dem = xdem.DEM(xdem.examples.get_path("longyearbyen_ref_dem")) tba_dem = xdem.DEM(xdem.examples.get_path("longyearbyen_tba_dem")) ``` +```{code-cell} ipython3 +# Coregister by calling the dem_coregistration function +aligned_dem = dem_coregistration(tba_dem, ref_dem, coreg_method=my_coreg_pipeline) +``` + ```{code-cell} ipython3 # Coregister by calling the DEM method aligned_dem = tba_dem.coregister_3d(ref_dem, my_coreg_pipeline) diff --git a/tests/test_coreg/test_workflows.py b/tests/test_coreg/test_workflows.py index 8e178d79..3e7f0157 100644 --- a/tests/test_coreg/test_workflows.py +++ b/tests/test_coreg/test_workflows.py @@ -256,9 +256,83 @@ def test_dem_coregistration(self) -> None: out_fig.close() # Testing different coreg method - dem_coreg, coreg_method, coreg_stats, inlier_mask = dem_coregistration( + dem_coreg2, coreg_method2, coreg_stats2, inlier_mask2 = dem_coregistration( tba_dem, ref_dem, coreg_method=xdem.coreg.Deramp() ) - assert isinstance(coreg_method, xdem.coreg.Deramp) - assert abs(coreg_stats["med_orig"].values) > abs(coreg_stats["med_coreg"].values) - assert coreg_stats["nmad_orig"].values > coreg_stats["nmad_coreg"].values + assert isinstance(coreg_method2, xdem.coreg.Deramp) + assert abs(coreg_stats2["med_orig"].values) > abs(coreg_stats2["med_coreg"].values) + assert coreg_stats2["nmad_orig"].values > coreg_stats2["nmad_coreg"].values + + # Testing with initial shift + test_shift_list = [10, 5] + tba_dem_origin = tba_dem.copy() + coreg_pipeline = xdem.coreg.affine.NuthKaab() + xdem.coreg.affine.VerticalShift() + + dem_coreg2, coreg_method2, coreg_stats2, inlier_mask2 = dem_coregistration( + tba_dem, ref_dem, coreg_method=coreg_pipeline, estimated_initial_shift=test_shift_list, random_state=42 + ) + dem_coreg3, coreg_method3, coreg_stats3, inlier_mask3 = dem_coregistration( + tba_dem, ref_dem, coreg_method=coreg_pipeline, random_state=42 + ) + assert tba_dem.raster_equal(tba_dem_origin) + assert isinstance(coreg_method2, xdem.coreg.CoregPipeline) + assert isinstance(coreg_method3, xdem.coreg.CoregPipeline) + assert isinstance(coreg_method2.pipeline[0], xdem.coreg.AffineCoreg) + assert isinstance(coreg_method3.pipeline[0], xdem.coreg.AffineCoreg) + assert ( + coreg_method2.pipeline[0].meta["outputs"]["affine"]["shift_x"] + == coreg_method3.pipeline[0].meta["outputs"]["affine"]["shift_x"] + ) + assert ( + coreg_method2.pipeline[0].meta["outputs"]["affine"]["shift_y"] + == coreg_method3.pipeline[0].meta["outputs"]["affine"]["shift_y"] + ) + + # Testing without coreg pipeline + test_shift_tuple = (-5, 2) # tuple + coreg_simple = xdem.coreg.affine.DhMinimize() + + dem_coreg2, coreg_method2, coreg_stats2, inlier_mask2 = dem_coregistration( + tba_dem, ref_dem, coreg_method=coreg_simple, estimated_initial_shift=test_shift_tuple, random_state=42 + ) + dem_coreg3, coreg_method3, coreg_stats3, inlier_mask3 = dem_coregistration( + tba_dem, ref_dem, coreg_method=coreg_simple, random_state=42 + ) + assert isinstance(coreg_method2, xdem.coreg.AffineCoreg) + assert isinstance(coreg_method3, xdem.coreg.AffineCoreg) + assert coreg_method2.meta["outputs"]["affine"]["shift_x"] == coreg_method3.meta["outputs"]["affine"]["shift_x"] + assert coreg_method2.meta["outputs"]["affine"]["shift_y"] == coreg_method3.meta["outputs"]["affine"]["shift_y"] + + # Check if the appropriate exception is raised with an initial shift and without affine coreg + with pytest.raises(TypeError, match=r".*affine.*"): + dem_coregistration( + tba_dem, + ref_dem, + coreg_method=xdem.coreg.Deramp(), + estimated_initial_shift=test_shift_tuple, + random_state=42, + ) + with pytest.raises(TypeError, match=r".*affine.*"): + dem_coregistration( + tba_dem, + ref_dem, + coreg_method=xdem.coreg.Deramp() + xdem.coreg.TerrainBias(), + estimated_initial_shift=test_shift_tuple, + random_state=42, + ) + + # Check if the appropriate exception is raised with a wrong type initial shift + with pytest.raises(ValueError, match=r".*two numerical values.*"): + dem_coregistration( + tba_dem, + ref_dem, + estimated_initial_shift=["2", 2], + random_state=42, + ) + with pytest.raises(ValueError, match=r".*two numerical values.*"): + dem_coregistration( + tba_dem, + ref_dem, + estimated_initial_shift=[2, 3, 5], + random_state=42, + ) diff --git a/xdem/coreg/workflows.py b/xdem/coreg/workflows.py index 04ec1ee4..4ec4b390 100644 --- a/xdem/coreg/workflows.py +++ b/xdem/coreg/workflows.py @@ -31,6 +31,7 @@ from geoutils.raster import RasterType from xdem._typing import NDArrayf +from xdem.coreg import AffineCoreg, CoregPipeline from xdem.coreg.affine import NuthKaab, VerticalShift from xdem.coreg.base import Coreg from xdem.dem import DEM @@ -148,7 +149,7 @@ def dem_coregistration( src_dem_path: str | RasterType, ref_dem_path: str | RasterType, out_dem_path: str | None = None, - coreg_method: Coreg | None = None, + coreg_method: Coreg | CoregPipeline | None = None, grid: str = "ref", resample: bool = False, resampling: rio.warp.Resampling | None = rio.warp.Resampling.bilinear, @@ -161,7 +162,8 @@ def dem_coregistration( random_state: int | np.random.Generator | None = None, plot: bool = False, out_fig: str = None, -) -> tuple[DEM, Coreg, pd.DataFrame, NDArrayf]: + estimated_initial_shift: list[Number] | tuple[Number, Number] | None = None, +) -> tuple[DEM, Coreg | CoregPipeline, pd.DataFrame, NDArrayf]: """ A one-line function to coregister a selected DEM to a reference DEM. @@ -173,7 +175,7 @@ def dem_coregistration( :param src_dem_path: Path to the input DEM to be coregistered :param ref_dem_path: Path to the reference DEM :param out_dem_path: Path where to save the coregistered DEM. If set to None (default), will not save to file. - :param coreg_method: Coregistration method or pipeline. Defaults to NuthKaab + VerticalShift. + :param coreg_method: Coregistration method, or pipeline. :param grid: The grid to be used during coregistration, set either to "ref" or "src". :param resample: If set to True, will reproject output Raster on the same grid as input. Otherwise, only \ the array/transform will be updated (if possible) and no resampling is done. Useful to avoid spreading data gaps. @@ -189,6 +191,8 @@ def dem_coregistration( :param random_state: Random state or seed number to use for subsampling and optimizer. :param plot: Set to True to plot a figure of elevation diff before/after coregistration. :param out_fig: Path to the output figure. If None will display to screen. + :param estimated_initial_shift: List containing x and y shifts (in pixels). These shifts are applied before \ +the coregistration process begins. :returns: A tuple containing 1) coregistered DEM as an xdem.DEM instance 2) the coregistration method \ 3) DataFrame of coregistration statistics (count of obs, median and NMAD over stable terrain) before and after \ @@ -221,21 +225,52 @@ def dem_coregistration( if grid not in ["ref", "src"]: raise ValueError(f"Argument `grid` must be either 'ref' or 'src' - currently set to {grid}.") + # Ensure that if an initial shift is provided, at least one coregistration method is affine. + if estimated_initial_shift: + if not ( + isinstance(estimated_initial_shift, (list, tuple)) + and len(estimated_initial_shift) == 2 + and all(isinstance(val, (float, int)) for val in estimated_initial_shift) + ): + raise ValueError( + "Argument `estimated_initial_shift` must be a list or tuple of exactly two numerical values." + ) + if isinstance(coreg_method, CoregPipeline): + if not any(isinstance(step, AffineCoreg) for step in coreg_method.pipeline): + raise TypeError( + "An initial shift has been provided, but none of the coregistration methods in the pipeline " + "are affine. At least one affine coregistration method (e.g., AffineCoreg) is required." + ) + elif not isinstance(coreg_method, AffineCoreg): + raise TypeError( + "An initial shift has been provided, but the coregistration method is not affine. " + "An affine coregistration method (e.g., AffineCoreg) is required." + ) + # Load both DEMs logging.info("Loading and reprojecting input data") if isinstance(ref_dem_path, str): - if grid == "ref": - ref_dem, src_dem = gu.raster.load_multiple_rasters([ref_dem_path, src_dem_path], ref_grid=0) - elif grid == "src": - ref_dem, src_dem = gu.raster.load_multiple_rasters([ref_dem_path, src_dem_path], ref_grid=1) - else: + ref_dem, src_dem = gu.raster.load_multiple_rasters([ref_dem_path, src_dem_path]) + + elif isinstance(src_dem_path, gu.Raster): ref_dem = ref_dem_path - src_dem = src_dem_path - if grid == "ref": - src_dem = src_dem.reproject(ref_dem, silent=True) - elif grid == "src": - ref_dem = ref_dem.reproject(src_dem, silent=True) + src_dem = src_dem_path.copy() + + # If an initial shift is provided, apply it before coregistration + if estimated_initial_shift: + + # convert shift + shift_x = estimated_initial_shift[0] * src_dem.res[0] + shift_y = estimated_initial_shift[1] * src_dem.res[1] + + # Apply the shift to the source dem + src_dem.translate(shift_x, shift_y, inplace=True) + + if grid == "ref": + src_dem = src_dem.reproject(ref_dem, silent=True) + elif grid == "src": + ref_dem = ref_dem.reproject(src_dem, silent=True) # Convert to DEM instance with Float32 dtype # TODO: Could only convert types int into float, but any other float dtype should yield very similar results @@ -268,6 +303,27 @@ def dem_coregistration( coreg_method.fit(ref_dem, src_dem, inlier_mask, random_state=random_state) dem_coreg = coreg_method.apply(src_dem, resample=resample, resampling=resampling) + # Add the initial shift to the calculated shift + if estimated_initial_shift: + + def update_shift( + coreg_method: Coreg | CoregPipeline, shift_x: float = shift_x, shift_y: float = shift_y + ) -> None: + if isinstance(coreg_method, CoregPipeline): + for step in coreg_method.pipeline: + update_shift(step) + else: + # check if the keys exists + if "outputs" in coreg_method.meta and "affine" in coreg_method.meta["outputs"]: + if "shift_x" in coreg_method.meta["outputs"]["affine"]: + coreg_method.meta["outputs"]["affine"]["shift_x"] += shift_x + logging.debug(f"Updated shift_x by {shift_x} in {coreg_method}") + if "shift_y" in coreg_method.meta["outputs"]["affine"]: + coreg_method.meta["outputs"]["affine"]["shift_y"] += shift_y + logging.debug(f"Updated shift_y by {shift_y} in {coreg_method}") + + update_shift(coreg_method) + # Calculate coregistered ddem (might need resampling if resample set to False), needed for stats and plot only ddem_coreg = dem_coreg.reproject(ref_dem, silent=True) - ref_dem