Skip to content
This repository has been archived by the owner on Nov 29, 2023. It is now read-only.

Commit

Permalink
Add new loss
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobbieker committed Jul 9, 2021
1 parent 11d9297 commit 35dd655
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 4 deletions.
17 changes: 13 additions & 4 deletions satflow/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,7 +630,9 @@ def __iter__(self) -> Iterator[T_co]:
if self.add_pixel_coords:
# Add channels for pixel_coords, once per channel, or once per stack, dependent
if self.time_as_channels:
mask = np.concatenate([mask, self.pixel_coords], axis=0)
mask = np.concatenate(
[mask, np.moveaxis(self.pixel_coords, [2], [0])], axis=0
)
else:
mask = np.concatenate([mask, self.pixel_coords], axis=1)
logger.debug(f"Mask: {mask.shape} Target: {target_mask.shape}")
Expand Down Expand Up @@ -707,9 +709,16 @@ def __iter__(self):
# Move channel to Time x Channel x W x H
image = np.moveaxis(image, [2], [0])
if self.add_pixel_coords:
# Add channels for pixel_coords
image = np.concatenate([image, self.pixel_coords], axis=0)
yield np.nan_to_num(image), np.nan_to_num(target_mask)
# Add channels for pixel_coords, once per channel, or once per stack, dependent
if self.time_as_channels:
image = np.concatenate(
[image, np.moveaxis(self.pixel_coords, [2], [0])], axis=0
)
else:
image = np.concatenate([image, self.pixel_coords], axis=1)
image = np.nan_to_num(image, posinf=0.0, neginf=0.0)
target_mask = np.nan_to_num(target_mask, posinf=0, neginf=0)
yield image, target_mask


def crop_center(img, cropx, cropy):
Expand Down
60 changes: 60 additions & 0 deletions satflow/models/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,63 @@ def forward(self, logit, target):
else:
loss = loss.sum()
return loss


def _unbind_images(x, dim=1):
"only unstack images"
if isinstance(x, torch.Tensor):
if len(x.shape) >= 4:
return x.unbind(dim=dim)
return x


class StackUnstack(nn.Module):
    """Stack list-of-image inputs, apply ``module``, unstack its output.

    Each positional argument to ``forward`` is a sequence of per-image
    tensors; they are stacked along ``dim`` before being handed to the
    wrapped module, and the module's output is split back into a sequence
    along the same dimension.
    """

    def __init__(self, module, dim=1):
        super().__init__()
        self.dim = dim
        self.module = module

    @staticmethod
    def unbind_images(x, dim=1):
        # Thin passthrough to the module-level helper so subclasses can
        # override the unstacking behavior.
        return _unbind_images(x, dim)

    def forward(self, *args):
        stacked = [torch.stack(seq, dim=self.dim) for seq in args]
        result = self.module(*stacked)
        # A lone tensor is unbound directly; tuples/lists of outputs are
        # unstacked element-wise (images only, via unbind_images).
        if not isinstance(result, (tuple, list)):
            return result.unbind(dim=self.dim)
        return [self.unbind_images(item, dim=self.dim) for item in result]


def StackLoss(loss_func=F.mse_loss, axis=-1):
    """Adapt *loss_func* to accept lists of per-image tensors.

    Returns a callable that concatenates each list of tensors along
    *axis* and evaluates ``loss_func`` on the merged tensors.
    """

    def _inner_loss(x, y):
        merged_pred = torch.cat(x, axis)
        merged_target = torch.cat(y, axis)
        return loss_func(merged_pred, merged_target)

    return _inner_loss


class MultiImageDice:
    """Dice coefficient metric for binary targets in segmentation.

    Accumulates intersection/union counts over batches of per-image
    predictions; ``value`` yields the running Dice score.
    """

    def __init__(self, axis=1):
        # axis: the class/channel dimension that argmax reduces over.
        self.axis = axis
        # BUG FIX: initialize the accumulators; the original only set them
        # in reset(), so accumulate() before reset() raised AttributeError.
        self.reset()

    def reset(self):
        """Zero the running intersection and union counts."""
        self.inter, self.union = 0, 0

    def accumulate(self, pred, y):
        """Fold one batch into the running counts.

        pred: list of per-image logit tensors (concatenated on last dim).
        y: list of per-image binary target tensors.
        """
        x = torch.cat(pred, -1)
        y = torch.cat(y, -1)
        pred = x.argmax(dim=self.axis).flatten()
        # BUG FIX: ``np.flatten`` does not exist; flatten the tensor itself.
        targ = y.flatten()
        self.inter += (pred * targ).float().sum().item()
        self.union += (pred + targ).float().sum().item()

    @property
    def value(self):
        """Running Dice score (2*|A∩B| / (|A|+|B|)), or None if empty."""
        return 2.0 * self.inter / self.union if self.union > 0 else None

0 comments on commit 35dd655

Please sign in to comment.