diff --git a/tests/test_coreg/test_base.py b/tests/test_coreg/test_base.py index a6d3d976..b8b423f0 100644 --- a/tests/test_coreg/test_base.py +++ b/tests/test_coreg/test_base.py @@ -1204,6 +1204,37 @@ def test_warp_dem() -> None: # Due to the randomness, the threshold is quite high, but would be something like 10+ if it was incorrect. assert spatialstats.nmad(dem - untransformed_dem) < 0.5 + # Test with Z-correction disabled + transformed_dem_no_z = coreg.base.warp_dem( + dem=dem, + transform=transform, + source_coords=source_coords, + destination_coords=dest_coords, + resampling="linear", + apply_z_correction=False, + ) + + # Try to undo the warp by reversing the source-destination coordinates with Z-correction disabled + untransformed_dem_no_z = coreg.base.warp_dem( + dem=transformed_dem_no_z, + transform=transform, + source_coords=dest_coords, + destination_coords=source_coords, + resampling="linear", + apply_z_correction=False, + ) + + # Validate that the DEM is now more or less the same as the original, with Z-correction disabled. + # The result should be similar to the original, but with no Z-shift applied. + assert spatialstats.nmad(dem - untransformed_dem_no_z) < 0.5 + + # The difference between the two DEMs should be the vertical shift. + # We expect the difference to be approximately equal to the average vertical shift. + expected_vshift = np.mean(dest_coords[:, 2] - source_coords[:, 2]) + + # Check that the mean difference between the DEMs matches the expected vertical shift. + assert np.nanmean(transformed_dem_no_z - transformed_dem) == pytest.approx(expected_vshift, rel=0.3) + if False: import matplotlib.pyplot as plt diff --git a/xdem/coreg/base.py b/xdem/coreg/base.py index 5222ddc3..ae68aa94 100644 --- a/xdem/coreg/base.py +++ b/xdem/coreg/base.py @@ -3013,6 +3013,7 @@ def __init__( success_threshold: float = 0.8, n_threads: int | None = None, warn_failures: bool = False, + apply_z_correction: bool = True, ) -> None: """ Instantiate a blockwise processing object. @@ -3022,6 +3023,7 @@ def __init__( :param success_threshold: Raise an error if fewer chunks than the fraction failed for any reason. :param n_threads: The maximum amount of threads to use. Default=auto :param warn_failures: Trigger or ignore warnings for each exception/warning in each block. + :param apply_z_correction: Boolean to toggle whether the Z-offset correction is applied or not (default True). """ if isinstance(step, type): raise ValueError( @@ -3032,6 +3034,7 @@ def __init__( self.success_threshold = success_threshold self.n_threads = n_threads self.warn_failures = warn_failures + self.apply_z_correction = apply_z_correction super().__init__() @@ -3396,6 +3399,7 @@ def _apply_rst( source_coords=all_points[:, :, 1], destination_coords=all_points[:, :, 0], resampling="linear", + apply_z_correction=self.apply_z_correction, ) return warped_dem, transform @@ -3436,6 +3440,7 @@ def warp_dem( resampling: str = "cubic", trim_border: bool = True, dilate_mask: bool = True, + apply_z_correction: bool = True, ) -> NDArrayf: """ (22/08/24: Method currently used only for blockwise coregistration) @@ -3448,6 +3453,7 @@ def warp_dem( :param resampling: The resampling order to use. Choices: ['nearest', 'linear', 'cubic']. :param trim_border: Remove values outside of the interpolation regime (True) or leave them unmodified (False). :param dilate_mask: Dilate the nan mask to exclude edge pixels that could be wrong. + :param apply_z_correction: Boolean to toggle whether the Z-offset correction is applied or not (default True). :raises ValueError: If the inputs are poorly formatted. :raises AssertionError: For unexpected outputs. @@ -3534,8 +3540,8 @@ def warp_dem( warped[new_mask] = np.nan - # If the coordinates are 3D (N, 3), apply a Z correction as well. - if not no_vertical: + # Apply the Z-correction if apply_z_correction is True and if the coordinates are 3D (N, 3) + if not no_vertical and apply_z_correction: grid_offsets = scipy.interpolate.griddata( points=destination_coords_scaled[:, :2], values=source_coords_scaled[:, 2] - destination_coords_scaled[:, 2],