diff --git a/satflow/data/datasets.py b/satflow/data/datasets.py
index a96e6e4b..40f30d44 100644
--- a/satflow/data/datasets.py
+++ b/satflow/data/datasets.py
@@ -630,7 +630,9 @@ def __iter__(self) -> Iterator[T_co]:
             if self.add_pixel_coords:
                 # Add channels for pixel_coords, once per channel, or once per stack, dependent
                 if self.time_as_channels:
-                    mask = np.concatenate([mask, self.pixel_coords], axis=0)
+                    mask = np.concatenate(
+                        [mask, np.moveaxis(self.pixel_coords, [2], [0])], axis=0
+                    )
                 else:
                     mask = np.concatenate([mask, self.pixel_coords], axis=1)
             logger.debug(f"Mask: {mask.shape} Target: {target_mask.shape}")
@@ -707,9 +709,16 @@ def __iter__(self):
             # Move channel to Time x Channel x W x H
             image = np.moveaxis(image, [2], [0])
             if self.add_pixel_coords:
-                # Add channels for pixel_coords
-                image = np.concatenate([image, self.pixel_coords], axis=0)
-            yield np.nan_to_num(image), np.nan_to_num(target_mask)
+                # Add channels for pixel_coords, once per channel, or once per stack, dependent
+                if self.time_as_channels:
+                    image = np.concatenate(
+                        [image, np.moveaxis(self.pixel_coords, [2], [0])], axis=0
+                    )
+                else:
+                    image = np.concatenate([image, self.pixel_coords], axis=1)
+            image = np.nan_to_num(image, posinf=0.0, neginf=0.0)
+            target_mask = np.nan_to_num(target_mask, posinf=0, neginf=0)
+            yield image, target_mask
 
 
 def crop_center(img, cropx, cropy):
diff --git a/satflow/models/losses.py b/satflow/models/losses.py
index 8259cc63..4ff023bb 100644
--- a/satflow/models/losses.py
+++ b/satflow/models/losses.py
@@ -98,3 +98,64 @@ def forward(self, logit, target):
         else:
             loss = loss.sum()
         return loss
+
+
+def _unbind_images(x, dim=1):
+    "only unstack images"
+    if isinstance(x, torch.Tensor):
+        if len(x.shape) >= 4:
+            return x.unbind(dim=dim)
+    return x
+
+
+class StackUnstack(nn.Module):
+    "Stack together inputs, apply module, unstack output"
+
+    def __init__(self, module, dim=1):
+        super().__init__()
+        self.dim = dim
+        self.module = module
+
+    @staticmethod
+    def unbind_images(x, dim=1):
+        return _unbind_images(x, dim)
+
+    def forward(self, *args):
+        inputs = [torch.stack(x, dim=self.dim) for x in args]
+        outputs = self.module(*inputs)
+        if isinstance(outputs, (tuple, list)):
+            return [self.unbind_images(output, dim=self.dim) for output in outputs]
+        else:
+            return outputs.unbind(dim=self.dim)
+
+
+def StackLoss(loss_func=F.mse_loss, axis=-1):
+    def _inner_loss(x, y):
+        x = torch.cat(x, axis)
+        y = torch.cat(y, axis)
+        return loss_func(x, y)
+
+    return _inner_loss
+
+
+class MultiImageDice:
+    "Dice coefficient metric for binary target in segmentation"
+
+    def __init__(self, axis=1):
+        self.axis = axis
+        # Initialize accumulators so accumulate()/value work before an explicit reset()
+        self.reset()
+
+    def reset(self):
+        self.inter, self.union = 0, 0
+
+    def accumulate(self, pred, y):
+        x = torch.cat(pred, -1)
+        y = torch.cat(y, -1)
+        pred = x.argmax(dim=self.axis).flatten()
+        # np.flatten does not exist; flatten the target tensor directly
+        targ = y.flatten()
+        self.inter += (pred * targ).float().sum().item()
+        self.union += (pred + targ).float().sum().item()
+
+    @property
+    def value(self):
+        return 2.0 * self.inter / self.union if self.union > 0 else None