diff --git a/tests/test_coreg/test_workflows.py b/tests/test_coreg/test_workflows.py index d710df4b..88d0cbf8 100644 --- a/tests/test_coreg/test_workflows.py +++ b/tests/test_coreg/test_workflows.py @@ -320,3 +320,19 @@ def test_dem_coregistration(self) -> None: estimated_initial_shift=test_shift_tuple, random_state=42, ) + + # Check if the appropriate exception is raised with a wrong type initial shift + with pytest.raises(ValueError, match=r".*two numerical values*"): + dem_coregistration( + tba_dem, + ref_dem, + estimated_initial_shift=["2", 2], + random_state=42, + ) + with pytest.raises(ValueError, match=r".*two numerical values*"): + dem_coregistration( + tba_dem, + ref_dem, + estimated_initial_shift=[2, 3, 5], + random_state=42, + ) diff --git a/xdem/coreg/workflows.py b/xdem/coreg/workflows.py index 9f10ba3e..4ec4b390 100644 --- a/xdem/coreg/workflows.py +++ b/xdem/coreg/workflows.py @@ -227,6 +227,14 @@ def dem_coregistration( # Ensure that if an initial shift is provided, at least one coregistration method is affine. if estimated_initial_shift: + if not ( + isinstance(estimated_initial_shift, (list, tuple)) + and len(estimated_initial_shift) == 2 + and all(isinstance(val, (float, int)) for val in estimated_initial_shift) + ): + raise ValueError( + "Argument `estimated_initial_shift` must be a list or tuple of exactly two numerical values." + ) if isinstance(coreg_method, CoregPipeline): if not any(isinstance(step, AffineCoreg) for step in coreg_method.pipeline): raise TypeError(