Skip to content

Commit

Permalink
Improve tests for subsampling in NuthKaab
Browse files Browse the repository at this point in the history
  • Loading branch information
rhugonnet committed Oct 9, 2023
1 parent f8faf97 commit 22aad4a
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 3 deletions.
5 changes: 3 additions & 2 deletions tests/test_coreg/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,9 @@ def test_get_subsample_on_valid_mask(self, subsample: float | int) -> None:
subsample_val = subsample
assert np.count_nonzero(subsample_mask) == min(subsample_val, np.count_nonzero(valid_mask))

# TODO: Activate NuthKaab once subsampling there is made consistent
all_coregs = [
coreg.VerticalShift,
# coreg.NuthKaab,
coreg.NuthKaab,
coreg.ICP,
coreg.Deramp,
coreg.TerrainBias,
Expand All @@ -156,6 +155,8 @@ def test_subsample(self, coreg: Callable) -> None: # type: ignore
# But can be overridden during fit
coreg_full.fit(**self.fit_params, subsample=10000, random_state=42)
assert coreg_full._meta["subsample"] == 10000
# Check that the random state is properly set when subsampling explicitly or implicitly
assert coreg_full._meta["random_state"] == 42

# Test subsampled vertical shift correction
coreg_sub = coreg(subsample=0.1)
Expand Down
2 changes: 1 addition & 1 deletion xdem/coreg/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -822,7 +822,7 @@ def fit(
# In any case, override!
self._meta["subsample"] = subsample

# Save random_state is a subsample is used
# Save random_state if a subsample is used
if self._meta["subsample"] != 1:
self._meta["random_state"] = random_state

Expand Down
4 changes: 4 additions & 0 deletions xdem/examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,10 @@ def process_coregistered_examples(name: str, overwrite: bool = False) -> None:

nuth_kaab = xdem.coreg.NuthKaab()
nuth_kaab.fit(reference_raster, to_be_aligned_raster, inlier_mask=inlier_mask, random_state=42)

# Check that random state is respected
assert nuth_kaab._meta["random_state"] == 42

aligned_raster = nuth_kaab.apply(to_be_aligned_raster, resample=True)

diff = reference_raster - aligned_raster
Expand Down

0 comments on commit 22aad4a

Please sign in to comment.