From 2e0da0d1b2292983dcfb13d61890229cbff8d752 Mon Sep 17 00:00:00 2001 From: vschaffn Date: Thu, 5 Dec 2024 10:01:56 +0100 Subject: [PATCH] fix: retrieve center coordinates of failing chunks --- tests/test_coreg/test_base.py | 13 +++--- xdem/coreg/base.py | 80 ++++++++++++++++++++++++----------- 2 files changed, 60 insertions(+), 33 deletions(-) diff --git a/tests/test_coreg/test_base.py b/tests/test_coreg/test_base.py index 233aca00..0c1a7414 100644 --- a/tests/test_coreg/test_base.py +++ b/tests/test_coreg/test_base.py @@ -925,11 +925,8 @@ def test_blockwise_coreg_large_gaps(self) -> None: stats = blockwise.stats() - # We expect holes in the blockwise coregistration, so there should not be 64 "successful" blocks. - assert stats.shape[0] < 64 - - # Statistics are only calculated on finite values, so all of these should be finite as well. - assert np.all(np.isfinite(stats) | np.isnan(stats)) + # We expect holes in the blockwise coregistration, but not in stats due to nan padding for failing chunks + assert stats.shape[0] == 64 # Copy the TBA DEM and set a square portion to nodata tba = self.tba.copy() @@ -939,7 +936,7 @@ def test_blockwise_coreg_large_gaps(self) -> None: blockwise = xdem.coreg.BlockwiseCoreg(xdem.coreg.NuthKaab(), 8, warn_failures=False) - # Align the DEM and apply the blockwise to a zero-array (to get the zshift) + # Align the DEM and apply blockwise to a zero-array (to get the z_shift) aligned = blockwise.fit(self.ref, tba).apply(tba) zshift, _ = blockwise.apply(np.zeros_like(tba.data), transform=tba.transform, crs=tba.crs) @@ -965,8 +962,8 @@ def test_failed_chunks_return_nan(self) -> None: assert np.isnan(result_df.loc[1, "inlier_count"]) assert np.isnan(result_df.loc[1, "nmad"]) assert np.isnan(result_df.loc[1, "median"]) - assert np.isnan(result_df.loc[1, "center_x"]) - assert np.isnan(result_df.loc[1, "center_y"]) + assert isinstance(result_df.loc[1, "center_x"], float) + assert isinstance(result_df.loc[1, "center_y"], float) assert np.isnan(result_df.loc[1, "center_z"]) assert np.isnan(result_df.loc[1, "x_off"]) assert np.isnan(result_df.loc[1, "y_off"]) diff --git a/xdem/coreg/base.py b/xdem/coreg/base.py index e26f131d..a37d81f3 100644 --- a/xdem/coreg/base.py +++ b/xdem/coreg/base.py @@ -3036,6 +3036,7 @@ def __init__( super().__init__() self._meta: CoregDict = {"step_meta": []} + self._groups: NDArrayf = np.array([]) def fit( self: CoregType, @@ -3091,9 +3092,9 @@ def fit( else: mask = inlier_mask - groups = self.subdivide_array(tba_dem.shape if isinstance(tba_dem, np.ndarray) else ref_dem.shape) + self._groups = self.subdivide_array(tba_dem.shape if isinstance(tba_dem, np.ndarray) else ref_dem.shape) - indices = np.unique(groups) + indices = np.unique(self._groups) progress_bar = tqdm( total=indices.size, desc="Processing chunks", disable=logging.getLogger().getEffectiveLevel() > logging.INFO @@ -3108,7 +3109,7 @@ def process(i: int) -> dict[str, Any] | BaseException | None: * If it fails: The associated exception. * If the block is empty: None """ - group_mask = groups == i + group_mask = self._groups == i # Find the corresponding slice of the inlier_mask to subset the data rows, cols = np.where(group_mask) @@ -3272,24 +3273,44 @@ def to_points(self) -> NDArrayf: if len(self._meta["step_meta"]) == 0: raise AssertionError("No coreg results exist. Has '.fit()' been called?") points = np.empty(shape=(0, 3, 2)) - for meta in self._meta["step_meta"]: - self._restore_metadata(meta) - # x_coord, y_coord = rio.transform.xy(meta["transform"], meta["representative_row"], - # meta["representative_col"]) - x_coord, y_coord = meta["representative_x"], meta["representative_y"] + for i in range(self.subdivision): + # Try to restore the metadata for this chunk (if it succeeded) + chunk_meta = next((meta for meta in self._meta["step_meta"] if meta["i"] == i), None) - old_pos_arr = np.reshape([x_coord, y_coord, meta["representative_val"]], (1, 3)) + if chunk_meta is not None: + # Successful chunk: Retrieve the representative X, Y, Z coordinates + self._restore_metadata(chunk_meta) + x_coord, y_coord = chunk_meta["representative_x"], chunk_meta["representative_y"] + repr_val = chunk_meta["representative_val"] + else: + # Failed chunk: Calculate the approximate center using the group's bounds + rows, cols = np.where(self._groups == i) + center_row = (rows.min() + rows.max()) // 2 + center_col = (cols.min() + cols.max()) // 2 + + transform = self._meta["step_meta"][0]["transform"] # Assuming all chunks share a transform + x_coord, y_coord = rio.transform.xy(transform, center_row, center_col) + repr_val = np.nan # No valid Z value for failed chunks + + # Old position based on the calculated or retrieved coordinates + old_pos_arr = np.reshape([x_coord, y_coord, repr_val], (1, 3)) old_position = gpd.GeoDataFrame( geometry=gpd.points_from_xy(x=old_pos_arr[:, 0], y=old_pos_arr[:, 1], crs=None), data={"z": old_pos_arr[:, 2]}, ) - new_position = self.procstep.apply(old_position) - new_pos_arr = np.reshape( - [new_position.geometry.x.values, new_position.geometry.y.values, new_position["z"].values], (1, 3) - ) + if chunk_meta is not None: + # Successful chunk: Apply the transformation + new_position = self.procstep.apply(old_position) + new_pos_arr = np.reshape( + [new_position.geometry.x.values, new_position.geometry.y.values, new_position["z"].values], (1, 3) + ) + else: + # Failed chunk: Keep the new position the same as the old position (no transformation) + new_pos_arr = old_pos_arr.copy() + # Append the result points = np.append(points, np.dstack((old_pos_arr, new_pos_arr)), axis=0) return points @@ -3314,15 +3335,14 @@ def stats(self) -> pd.DataFrame: chunk_meta = {meta["i"]: meta for meta in self.meta["step_meta"]} statistics: list[dict[str, Any]] = [] - cpt_in_chunk = 0 for i in range(points.shape[0]): if i not in chunk_meta: # For missing chunks, return NaN for all stats statistics.append( { - "center_x": np.nan, - "center_y": np.nan, - "center_z": np.nan, + "center_x": points[i, 0, 0], + "center_y": points[i, 1, 0], + "center_z": points[i, 2, 0], "x_off": np.nan, "y_off": np.nan, "z_off": np.nan, @@ -3334,18 +3354,17 @@ def stats(self) -> pd.DataFrame: else: statistics.append( { - "center_x": points[cpt_in_chunk, 0, 0], - "center_y": points[cpt_in_chunk, 1, 0], - "center_z": points[cpt_in_chunk, 2, 0], - "x_off": points[cpt_in_chunk, 0, 1] - points[cpt_in_chunk, 0, 0], - "y_off": points[cpt_in_chunk, 1, 1] - points[cpt_in_chunk, 1, 0], - "z_off": points[cpt_in_chunk, 2, 1] - points[cpt_in_chunk, 2, 0], + "center_x": points[i, 0, 0], + "center_y": points[i, 1, 0], + "center_z": points[i, 2, 0], + "x_off": points[i, 0, 1] - points[i, 0, 0], + "y_off": points[i, 1, 1] - points[i, 1, 0], + "z_off": points[i, 2, 1] - points[i, 2, 0], "inlier_count": chunk_meta[i]["inlier_count"], "nmad": chunk_meta[i]["nmad"], "median": chunk_meta[i]["median"], } ) - cpt_in_chunk += 1 stats_df = pd.DataFrame(statistics) stats_df.index.name = "chunk" @@ -3381,6 +3400,11 @@ def _apply_rst( raise NotImplementedError("Option `resample=False` not supported for coreg method BlockwiseCoreg.") points = self.to_points() + # Check for NaN values across both the old and new positions for each point + mask = ~np.isnan(points).any(axis=(1, 2)) + + # Filter out points where there are no NaN values + points = points[mask] bounds = _bounds(transform=transform, shape=elev.shape) resolution = _res(transform) @@ -3423,6 +3447,12 @@ def _apply_pts( """Apply the scaling model to a set of points.""" points = self.to_points() + # Check for NaN values across both the old and new positions for each point + mask = ~np.isnan(points).any(axis=(1, 2)) + + # Filter out points where there are no NaN values + points = points[mask] + new_coords = np.array([elev.geometry.x.values, elev.geometry.y.values, elev["z"].values]).T for dim in range(0, 3): @@ -3529,7 +3559,7 @@ def warp_dem( order = {"nearest": 0, "linear": 1, "cubic": 3} with warnings.catch_warnings(): - # An skimage warning that will hopefully be fixed soon. (2021-06-08) + # A skimage warning that will hopefully be fixed soon. (2021-06-08) warnings.filterwarnings("ignore", message="Passing `np.nan` to mean no clipping in np.clip") warped = skimage.transform.warp( image=np.where(dem_mask, np.nan, dem_arr),