Skip to content

Commit

Permalink
fix: retrieve center coordinates of failing chunks
Browse files Browse the repository at this point in the history
  • Loading branch information
vschaffn committed Dec 5, 2024
1 parent 260bbdd commit 936e20b
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 33 deletions.
13 changes: 5 additions & 8 deletions tests/test_coreg/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -925,11 +925,8 @@ def test_blockwise_coreg_large_gaps(self) -> None:

stats = blockwise.stats()

# We expect holes in the blockwise coregistration, so there should not be 64 "successful" blocks.
assert stats.shape[0] < 64

# Statistics are only calculated on finite values, so all of these should be finite as well.
assert np.all(np.isfinite(stats) | np.isnan(stats))
# We expect holes in the blockwise coregistration, but not in stats due to nan padding for failing chunks
assert stats.shape[0] == 64

# Copy the TBA DEM and set a square portion to nodata
tba = self.tba.copy()
Expand All @@ -939,7 +936,7 @@ def test_blockwise_coreg_large_gaps(self) -> None:

blockwise = xdem.coreg.BlockwiseCoreg(xdem.coreg.NuthKaab(), 8, warn_failures=False)

# Align the DEM and apply the blockwise to a zero-array (to get the zshift)
# Align the DEM and apply blockwise to a zero-array (to get the z_shift)
aligned = blockwise.fit(self.ref, tba).apply(tba)
zshift, _ = blockwise.apply(np.zeros_like(tba.data), transform=tba.transform, crs=tba.crs)

Expand All @@ -965,8 +962,8 @@ def test_failed_chunks_return_nan(self) -> None:
assert np.isnan(result_df.loc[1, "inlier_count"])
assert np.isnan(result_df.loc[1, "nmad"])
assert np.isnan(result_df.loc[1, "median"])
assert np.isnan(result_df.loc[1, "center_x"])
assert np.isnan(result_df.loc[1, "center_y"])
assert isinstance(result_df.loc[1, "center_x"], float)
assert isinstance(result_df.loc[1, "center_y"], float)
assert np.isnan(result_df.loc[1, "center_z"])
assert np.isnan(result_df.loc[1, "x_off"])
assert np.isnan(result_df.loc[1, "y_off"])
Expand Down
80 changes: 55 additions & 25 deletions xdem/coreg/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3036,6 +3036,7 @@ def __init__(
super().__init__()

self._meta: CoregDict = {"step_meta": []}
self._groups: NDArrayf = np.array([])

def fit(
self: CoregType,
Expand Down Expand Up @@ -3091,9 +3092,9 @@ def fit(
else:
mask = inlier_mask

groups = self.subdivide_array(tba_dem.shape if isinstance(tba_dem, np.ndarray) else ref_dem.shape)
self._groups = self.subdivide_array(tba_dem.shape if isinstance(tba_dem, np.ndarray) else ref_dem.shape)

indices = np.unique(groups)
indices = np.unique(self._groups)

progress_bar = tqdm(
total=indices.size, desc="Processing chunks", disable=logging.getLogger().getEffectiveLevel() > logging.INFO
Expand All @@ -3108,7 +3109,7 @@ def process(i: int) -> dict[str, Any] | BaseException | None:
* If it fails: The associated exception.
* If the block is empty: None
"""
group_mask = groups == i
group_mask = self._groups == i

# Find the corresponding slice of the inlier_mask to subset the data
rows, cols = np.where(group_mask)
Expand Down Expand Up @@ -3272,24 +3273,44 @@ def to_points(self) -> NDArrayf:
if len(self._meta["step_meta"]) == 0:
raise AssertionError("No coreg results exist. Has '.fit()' been called?")
points = np.empty(shape=(0, 3, 2))
for meta in self._meta["step_meta"]:
self._restore_metadata(meta)

# x_coord, y_coord = rio.transform.xy(meta["transform"], meta["representative_row"],
# meta["representative_col"])
x_coord, y_coord = meta["representative_x"], meta["representative_y"]
for i in range(self.subdivision):
# Try to restore the metadata for this chunk (if it succeeded)
chunk_meta = next((meta for meta in self._meta["step_meta"] if meta["i"] == i), None)

old_pos_arr = np.reshape([x_coord, y_coord, meta["representative_val"]], (1, 3))
if chunk_meta is not None:
# Successful chunk: Retrieve the representative X, Y, Z coordinates
self._restore_metadata(chunk_meta)
x_coord, y_coord = chunk_meta["representative_x"], chunk_meta["representative_y"]
repr_val = chunk_meta["representative_val"]
else:
# Failed chunk: Calculate the approximate center using the group's bounds
rows, cols = np.where(self._groups == i)
center_row = (rows.min() + rows.max()) // 2
center_col = (cols.min() + cols.max()) // 2

transform = self._meta["step_meta"][0]["transform"] # Assuming all chunks share a transform
x_coord, y_coord = rio.transform.xy(transform, center_row, center_col)
repr_val = np.nan # No valid Z value for failed chunks

# Old position based on the calculated or retrieved coordinates
old_pos_arr = np.reshape([x_coord, y_coord, repr_val], (1, 3))
old_position = gpd.GeoDataFrame(
geometry=gpd.points_from_xy(x=old_pos_arr[:, 0], y=old_pos_arr[:, 1], crs=None),
data={"z": old_pos_arr[:, 2]},
)

new_position = self.procstep.apply(old_position)
new_pos_arr = np.reshape(
[new_position.geometry.x.values, new_position.geometry.y.values, new_position["z"].values], (1, 3)
)
if chunk_meta is not None:
# Successful chunk: Apply the transformation
new_position = self.procstep.apply(old_position)
new_pos_arr = np.reshape(
[new_position.geometry.x.values, new_position.geometry.y.values, new_position["z"].values], (1, 3)
)
else:
# Failed chunk: Keep the new position the same as the old position (no transformation)
new_pos_arr = old_pos_arr.copy()

# Append the result
points = np.append(points, np.dstack((old_pos_arr, new_pos_arr)), axis=0)

return points
Expand All @@ -3314,15 +3335,14 @@ def stats(self) -> pd.DataFrame:
chunk_meta = {meta["i"]: meta for meta in self.meta["step_meta"]}

statistics: list[dict[str, Any]] = []
cpt_in_chunk = 0
for i in range(points.shape[0]):
if i not in chunk_meta:
# For missing chunks, return NaN for all stats
statistics.append(
{
"center_x": np.nan,
"center_y": np.nan,
"center_z": np.nan,
"center_x": points[i, 0, 0],
"center_y": points[i, 1, 0],
"center_z": points[i, 2, 0],
"x_off": np.nan,
"y_off": np.nan,
"z_off": np.nan,
Expand All @@ -3334,18 +3354,17 @@ def stats(self) -> pd.DataFrame:
else:
statistics.append(
{
"center_x": points[cpt_in_chunk, 0, 0],
"center_y": points[cpt_in_chunk, 1, 0],
"center_z": points[cpt_in_chunk, 2, 0],
"x_off": points[cpt_in_chunk, 0, 1] - points[cpt_in_chunk, 0, 0],
"y_off": points[cpt_in_chunk, 1, 1] - points[cpt_in_chunk, 1, 0],
"z_off": points[cpt_in_chunk, 2, 1] - points[cpt_in_chunk, 2, 0],
"center_x": points[i, 0, 0],
"center_y": points[i, 1, 0],
"center_z": points[i, 2, 0],
"x_off": points[i, 0, 1] - points[i, 0, 0],
"y_off": points[i, 1, 1] - points[i, 1, 0],
"z_off": points[i, 2, 1] - points[i, 2, 0],
"inlier_count": chunk_meta[i]["inlier_count"],
"nmad": chunk_meta[i]["nmad"],
"median": chunk_meta[i]["median"],
}
)
cpt_in_chunk += 1

stats_df = pd.DataFrame(statistics)
stats_df.index.name = "chunk"
Expand Down Expand Up @@ -3381,6 +3400,11 @@ def _apply_rst(
raise NotImplementedError("Option `resample=False` not supported for coreg method BlockwiseCoreg.")

points = self.to_points()
# Check for NaN values across both the old and new positions for each point
mask = ~np.isnan(points).any(axis=(1, 2))

# Filter out points where there are no NaN values
points = points[mask]

bounds = _bounds(transform=transform, shape=elev.shape)
resolution = _res(transform)
Expand Down Expand Up @@ -3423,6 +3447,12 @@ def _apply_pts(
"""Apply the scaling model to a set of points."""
points = self.to_points()

# Check for NaN values across both the old and new positions for each point
mask = ~np.isnan(points).any(axis=(1, 2))

# Filter out points where there are no NaN values
points = points[mask]

new_coords = np.array([elev.geometry.x.values, elev.geometry.y.values, elev["z"].values]).T

for dim in range(0, 3):
Expand Down Expand Up @@ -3529,7 +3559,7 @@ def warp_dem(
order = {"nearest": 0, "linear": 1, "cubic": 3}

with warnings.catch_warnings():
# An skimage warning that will hopefully be fixed soon. (2021-06-08)
# A skimage warning that will hopefully be fixed soon. (2021-06-08)
warnings.filterwarnings("ignore", message="Passing `np.nan` to mean no clipping in np.clip")
warped = skimage.transform.warp(
image=np.where(dem_mask, np.nan, dem_arr),
Expand Down

0 comments on commit 936e20b

Please sign in to comment.