Skip to content

Commit

Permalink
Add Initial shift Capability to DEM Coregistration (#650)
Browse files Browse the repository at this point in the history
  • Loading branch information
vschaffn authored Dec 3, 2024
1 parent e0afc9d commit f324899
Show file tree
Hide file tree
Showing 4 changed files with 163 additions and 18 deletions.
8 changes: 8 additions & 0 deletions doc/source/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,14 @@ To build and pass your coregistration pipeline to {func}`~xdem.DEM.coregister_3d
coreg.Coreg.meta
```

#### Quick coregistration
```{eval-rst}
.. autosummary::
:toctree: gen_modules/
coreg.workflows.dem_coregistration
```

### Affine coregistration

#### Parent object (to define custom methods)
Expand Down
9 changes: 8 additions & 1 deletion doc/source/coregistration.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ my_coreg_pipeline = xdem.coreg.ICP() + xdem.coreg.NuthKaab()
my_coreg_pipeline = xdem.coreg.NuthKaab()
```

Then, coregistering a pair of elevation data can be done by calling {func}`xdem.DEM.coregister_3d` from the DEM that should be aligned.
Then, coregistering a pair of elevation data can be done by calling {func}`xdem.coreg.workflows.dem_coregistration`, or
{func}`xdem.DEM.coregister_3d` from the DEM that should be aligned.

```{code-cell} ipython3
:tags: [hide-cell]
Expand All @@ -66,12 +67,18 @@ Then, coregistering a pair of elevation data can be done by calling {func}`xdem.
import geoutils as gu
import numpy as np
import matplotlib.pyplot as plt
from xdem.coreg.workflows import dem_coregistration
# Open a reference and to-be-aligned DEM
ref_dem = xdem.DEM(xdem.examples.get_path("longyearbyen_ref_dem"))
tba_dem = xdem.DEM(xdem.examples.get_path("longyearbyen_tba_dem"))
```

```{code-cell} ipython3
# Coregister by calling the dem_coregistration function
aligned_dem = dem_coregistration(tba_dem, ref_dem, coreg_method=my_coreg_pipeline)
```

```{code-cell} ipython3
# Coregister by calling the DEM method
aligned_dem = tba_dem.coregister_3d(ref_dem, my_coreg_pipeline)
Expand Down
82 changes: 78 additions & 4 deletions tests/test_coreg/test_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,9 +256,83 @@ def test_dem_coregistration(self) -> None:
out_fig.close()

# Testing different coreg method
dem_coreg, coreg_method, coreg_stats, inlier_mask = dem_coregistration(
dem_coreg2, coreg_method2, coreg_stats2, inlier_mask2 = dem_coregistration(
tba_dem, ref_dem, coreg_method=xdem.coreg.Deramp()
)
assert isinstance(coreg_method, xdem.coreg.Deramp)
assert abs(coreg_stats["med_orig"].values) > abs(coreg_stats["med_coreg"].values)
assert coreg_stats["nmad_orig"].values > coreg_stats["nmad_coreg"].values
assert isinstance(coreg_method2, xdem.coreg.Deramp)
assert abs(coreg_stats2["med_orig"].values) > abs(coreg_stats2["med_coreg"].values)
assert coreg_stats2["nmad_orig"].values > coreg_stats2["nmad_coreg"].values

# Testing with initial shift
test_shift_list = [10, 5]
tba_dem_origin = tba_dem.copy()
coreg_pipeline = xdem.coreg.affine.NuthKaab() + xdem.coreg.affine.VerticalShift()

dem_coreg2, coreg_method2, coreg_stats2, inlier_mask2 = dem_coregistration(
tba_dem, ref_dem, coreg_method=coreg_pipeline, estimated_initial_shift=test_shift_list, random_state=42
)
dem_coreg3, coreg_method3, coreg_stats3, inlier_mask3 = dem_coregistration(
tba_dem, ref_dem, coreg_method=coreg_pipeline, random_state=42
)
assert tba_dem.raster_equal(tba_dem_origin)
assert isinstance(coreg_method2, xdem.coreg.CoregPipeline)
assert isinstance(coreg_method3, xdem.coreg.CoregPipeline)
assert isinstance(coreg_method2.pipeline[0], xdem.coreg.AffineCoreg)
assert isinstance(coreg_method3.pipeline[0], xdem.coreg.AffineCoreg)
assert (
coreg_method2.pipeline[0].meta["outputs"]["affine"]["shift_x"]
== coreg_method3.pipeline[0].meta["outputs"]["affine"]["shift_x"]
)
assert (
coreg_method2.pipeline[0].meta["outputs"]["affine"]["shift_y"]
== coreg_method3.pipeline[0].meta["outputs"]["affine"]["shift_y"]
)

# Testing without coreg pipeline
test_shift_tuple = (-5, 2) # tuple
coreg_simple = xdem.coreg.affine.DhMinimize()

dem_coreg2, coreg_method2, coreg_stats2, inlier_mask2 = dem_coregistration(
tba_dem, ref_dem, coreg_method=coreg_simple, estimated_initial_shift=test_shift_tuple, random_state=42
)
dem_coreg3, coreg_method3, coreg_stats3, inlier_mask3 = dem_coregistration(
tba_dem, ref_dem, coreg_method=coreg_simple, random_state=42
)
assert isinstance(coreg_method2, xdem.coreg.AffineCoreg)
assert isinstance(coreg_method3, xdem.coreg.AffineCoreg)
assert coreg_method2.meta["outputs"]["affine"]["shift_x"] == coreg_method3.meta["outputs"]["affine"]["shift_x"]
assert coreg_method2.meta["outputs"]["affine"]["shift_y"] == coreg_method3.meta["outputs"]["affine"]["shift_y"]

# Check if the appropriate exception is raised with an initial shift and without affine coreg
with pytest.raises(TypeError, match=r".*affine.*"):
dem_coregistration(
tba_dem,
ref_dem,
coreg_method=xdem.coreg.Deramp(),
estimated_initial_shift=test_shift_tuple,
random_state=42,
)
with pytest.raises(TypeError, match=r".*affine.*"):
dem_coregistration(
tba_dem,
ref_dem,
coreg_method=xdem.coreg.Deramp() + xdem.coreg.TerrainBias(),
estimated_initial_shift=test_shift_tuple,
random_state=42,
)

# Check if the appropriate exception is raised with a wrong type initial shift
with pytest.raises(ValueError, match=r".*two numerical values.*"):
dem_coregistration(
tba_dem,
ref_dem,
estimated_initial_shift=["2", 2],
random_state=42,
)
with pytest.raises(ValueError, match=r".*two numerical values.*"):
dem_coregistration(
tba_dem,
ref_dem,
estimated_initial_shift=[2, 3, 5],
random_state=42,
)
82 changes: 69 additions & 13 deletions xdem/coreg/workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from geoutils.raster import RasterType

from xdem._typing import NDArrayf
from xdem.coreg import AffineCoreg, CoregPipeline
from xdem.coreg.affine import NuthKaab, VerticalShift
from xdem.coreg.base import Coreg
from xdem.dem import DEM
Expand Down Expand Up @@ -148,7 +149,7 @@ def dem_coregistration(
src_dem_path: str | RasterType,
ref_dem_path: str | RasterType,
out_dem_path: str | None = None,
coreg_method: Coreg | None = None,
coreg_method: Coreg | CoregPipeline | None = None,
grid: str = "ref",
resample: bool = False,
resampling: rio.warp.Resampling | None = rio.warp.Resampling.bilinear,
Expand All @@ -161,7 +162,8 @@ def dem_coregistration(
random_state: int | np.random.Generator | None = None,
plot: bool = False,
out_fig: str = None,
) -> tuple[DEM, Coreg, pd.DataFrame, NDArrayf]:
estimated_initial_shift: list[Number] | tuple[Number, Number] | None = None,
) -> tuple[DEM, Coreg | CoregPipeline, pd.DataFrame, NDArrayf]:
"""
A one-line function to coregister a selected DEM to a reference DEM.
Expand All @@ -173,7 +175,7 @@ def dem_coregistration(
:param src_dem_path: Path to the input DEM to be coregistered
:param ref_dem_path: Path to the reference DEM
:param out_dem_path: Path where to save the coregistered DEM. If set to None (default), will not save to file.
:param coreg_method: Coregistration method or pipeline. Defaults to NuthKaab + VerticalShift.
:param coreg_method: Coregistration method, or pipeline.
:param grid: The grid to be used during coregistration, set either to "ref" or "src".
:param resample: If set to True, will reproject output Raster on the same grid as input. Otherwise, only \
the array/transform will be updated (if possible) and no resampling is done. Useful to avoid spreading data gaps.
Expand All @@ -189,6 +191,8 @@ def dem_coregistration(
:param random_state: Random state or seed number to use for subsampling and optimizer.
:param plot: Set to True to plot a figure of elevation diff before/after coregistration.
:param out_fig: Path to the output figure. If None will display to screen.
:param estimated_initial_shift: List containing x and y shifts (in pixels). These shifts are applied before \
the coregistration process begins.
:returns: A tuple containing 1) coregistered DEM as an xdem.DEM instance 2) the coregistration method \
3) DataFrame of coregistration statistics (count of obs, median and NMAD over stable terrain) before and after \
Expand Down Expand Up @@ -221,21 +225,52 @@ def dem_coregistration(
if grid not in ["ref", "src"]:
raise ValueError(f"Argument `grid` must be either 'ref' or 'src' - currently set to {grid}.")

# Ensure that if an initial shift is provided, at least one coregistration method is affine.
if estimated_initial_shift:
if not (
isinstance(estimated_initial_shift, (list, tuple))
and len(estimated_initial_shift) == 2
and all(isinstance(val, (float, int)) for val in estimated_initial_shift)
):
raise ValueError(
"Argument `estimated_initial_shift` must be a list or tuple of exactly two numerical values."
)
if isinstance(coreg_method, CoregPipeline):
if not any(isinstance(step, AffineCoreg) for step in coreg_method.pipeline):
raise TypeError(
"An initial shift has been provided, but none of the coregistration methods in the pipeline "
"are affine. At least one affine coregistration method (e.g., AffineCoreg) is required."
)
elif not isinstance(coreg_method, AffineCoreg):
raise TypeError(
"An initial shift has been provided, but the coregistration method is not affine. "
"An affine coregistration method (e.g., AffineCoreg) is required."
)

# Load both DEMs
logging.info("Loading and reprojecting input data")

if isinstance(ref_dem_path, str):
if grid == "ref":
ref_dem, src_dem = gu.raster.load_multiple_rasters([ref_dem_path, src_dem_path], ref_grid=0)
elif grid == "src":
ref_dem, src_dem = gu.raster.load_multiple_rasters([ref_dem_path, src_dem_path], ref_grid=1)
else:
ref_dem, src_dem = gu.raster.load_multiple_rasters([ref_dem_path, src_dem_path])

elif isinstance(src_dem_path, gu.Raster):
ref_dem = ref_dem_path
src_dem = src_dem_path
if grid == "ref":
src_dem = src_dem.reproject(ref_dem, silent=True)
elif grid == "src":
ref_dem = ref_dem.reproject(src_dem, silent=True)
src_dem = src_dem_path.copy()

# If an initial shift is provided, apply it before coregistration
if estimated_initial_shift:

# convert shift
shift_x = estimated_initial_shift[0] * src_dem.res[0]
shift_y = estimated_initial_shift[1] * src_dem.res[1]

# Apply the shift to the source dem
src_dem.translate(shift_x, shift_y, inplace=True)

if grid == "ref":
src_dem = src_dem.reproject(ref_dem, silent=True)
elif grid == "src":
ref_dem = ref_dem.reproject(src_dem, silent=True)

# Convert to DEM instance with Float32 dtype
# TODO: Could only convert types int into float, but any other float dtype should yield very similar results
Expand Down Expand Up @@ -268,6 +303,27 @@ def dem_coregistration(
coreg_method.fit(ref_dem, src_dem, inlier_mask, random_state=random_state)
dem_coreg = coreg_method.apply(src_dem, resample=resample, resampling=resampling)

# Add the initial shift to the calculated shift
if estimated_initial_shift:

def update_shift(
coreg_method: Coreg | CoregPipeline, shift_x: float = shift_x, shift_y: float = shift_y
) -> None:
if isinstance(coreg_method, CoregPipeline):
for step in coreg_method.pipeline:
update_shift(step)
else:
# check if the keys exists
if "outputs" in coreg_method.meta and "affine" in coreg_method.meta["outputs"]:
if "shift_x" in coreg_method.meta["outputs"]["affine"]:
coreg_method.meta["outputs"]["affine"]["shift_x"] += shift_x
logging.debug(f"Updated shift_x by {shift_x} in {coreg_method}")
if "shift_y" in coreg_method.meta["outputs"]["affine"]:
coreg_method.meta["outputs"]["affine"]["shift_y"] += shift_y
logging.debug(f"Updated shift_y by {shift_y} in {coreg_method}")

update_shift(coreg_method)

# Calculate coregistered ddem (might need resampling if resample set to False), needed for stats and plot only
ddem_coreg = dem_coreg.reproject(ref_dem, silent=True) - ref_dem

Expand Down

0 comments on commit f324899

Please sign in to comment.