Skip to content

Commit

Permalink
fix: typing
Browse files Browse the repository at this point in the history
  • Loading branch information
vschaffn committed Nov 27, 2024
1 parent 5c72f02 commit 8a356e8
Showing 1 changed file with 21 additions and 10 deletions.
31 changes: 21 additions & 10 deletions xdem/coreg/workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from geoutils.raster import RasterType

from xdem._typing import NDArrayf
from xdem.coreg import CoregPipeline
from xdem.coreg import AffineCoreg, CoregPipeline
from xdem.coreg.affine import NuthKaab, VerticalShift
from xdem.coreg.base import Coreg
from xdem.dem import DEM
Expand Down Expand Up @@ -162,7 +162,7 @@ def dem_coregistration(
random_state: int | np.random.Generator | None = None,
plot: bool = False,
out_fig: str = None,
estimated_initial_shift: list[float] | tuple[float, float] = None,
estimated_initial_shift: list[Number] | tuple[Number, Number] | None = None,
) -> tuple[DEM, Coreg | CoregPipeline, pd.DataFrame, NDArrayf]:
"""
A one-line function to coregister a selected DEM to a reference DEM.
Expand All @@ -175,7 +175,7 @@ def dem_coregistration(
:param src_dem_path: Path to the input DEM to be coregistered
:param ref_dem_path: Path to the reference DEM
:param out_dem_path: Path where to save the coregistered DEM. If set to None (default), will not save to file.
:param coreg_method: Coregistration method or pipeline. Defaults to NuthKaab + VerticalShift.
:param coreg_method: Coregistration method, or pipeline.
:param grid: The grid to be used during coregistration, set either to "ref" or "src".
:param resample: If set to True, will reproject output Raster on the same grid as input. Otherwise, only \
the array/transform will be updated (if possible) and no resampling is done. Useful to avoid spreading data gaps.
Expand Down Expand Up @@ -225,21 +225,32 @@ def dem_coregistration(
if grid not in ["ref", "src"]:
raise ValueError(f"Argument `grid` must be either 'ref' or 'src' - currently set to {grid}.")

# Ensure that if an initial shift is provided, at least one coregistration method is affine.
if estimated_initial_shift:
if isinstance(coreg_method, CoregPipeline):
if not any(isinstance(step, AffineCoreg) for step in coreg_method.pipeline):
raise TypeError(
"An initial shift has been provided, but none of the coregistration methods in the pipeline "
"are affine. At least one affine coregistration method (e.g., AffineCoreg) is required."
)
elif not isinstance(coreg_method, AffineCoreg):
raise TypeError(
"An initial shift has been provided, but the coregistration method is not affine. "
"An affine coregistration method (e.g., AffineCoreg) is required."
)

# Load both DEMs
logging.info("Loading and reprojecting input data")

if isinstance(ref_dem_path, str):
if grid == "ref":
ref_dem, src_dem = gu.raster.load_multiple_rasters([ref_dem_path, src_dem_path])
elif grid == "src":
ref_dem, src_dem = gu.raster.load_multiple_rasters([ref_dem_path, src_dem_path])
else:
ref_dem, src_dem = gu.raster.load_multiple_rasters([ref_dem_path, src_dem_path])

elif isinstance(src_dem_path, gu.Raster):
ref_dem = ref_dem_path
src_dem = src_dem_path
src_dem = src_dem_path.copy()

# If an initial shift is provided, apply it before coregistration
if estimated_initial_shift:
logging.warning("Initial shift in affine mode only")

# convert shift
shift_x = estimated_initial_shift[0] * src_dem.res[0]
Expand Down

0 comments on commit 8a356e8

Please sign in to comment.