Skip to content

Commit

Permalink
Feature: new Target InnerDistanceTarget (#90)
Browse files Browse the repository at this point in the history
  • Loading branch information
rhoadesScholar authored Feb 14, 2024
2 parents c956257 + 58bbaf4 commit 39915ac
Show file tree
Hide file tree
Showing 5 changed files with 261 additions and 0 deletions.
4 changes: 4 additions & 0 deletions dacapo/experiments/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,8 @@
from .one_hot_task_config import OneHotTaskConfig, OneHotTask # noqa
from .pretrained_task_config import PretrainedTaskConfig, PretrainedTask # noqa
from .affinities_task_config import AffinitiesTaskConfig, AffinitiesTask # noqa
from .inner_distance_task_config import (
InnerDistanceTaskConfig,
InnerDistanceTask,
) # noqa
from .hot_distance_task_config import HotDistanceTaskConfig, HotDistanceTask # noqa
25 changes: 25 additions & 0 deletions dacapo/experiments/tasks/inner_distance_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from .evaluators import BinarySegmentationEvaluator
from .losses import MSELoss
from .post_processors import ThresholdPostProcessor
from .predictors import InnerDistancePredictor
from .task import Task


# Goal is have a distance task but with distance inside the forground only
class InnerDistanceTask(Task):
"""This is just a dummy task for testing."""

def __init__(self, task_config):
"""Create a `DummyTask` from a `DummyTaskConfig`."""

self.predictor = InnerDistancePredictor(
channels=task_config.channels,
scale_factor=task_config.scale_factor,
)
self.loss = MSELoss()
self.post_processor = ThresholdPostProcessor()
self.evaluator = BinarySegmentationEvaluator(
clip_distance=task_config.clip_distance,
tol_distance=task_config.tol_distance,
channels=task_config.channels,
)
40 changes: 40 additions & 0 deletions dacapo/experiments/tasks/inner_distance_task_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import attr

from .inner_distance_task import InnerDistanceTask
from .task_config import TaskConfig

from typing import List


@attr.s
class InnerDistanceTaskConfig(TaskConfig):
"""This is a Distance task config used for generating and
evaluating signed distance transforms as a way of generating
segmentations.
The advantage of generating distance transforms over regular
affinities is you can get a denser signal, i.e. 1 misclassified
pixel in an affinity prediction could merge 2 otherwise very
distinct objects, this cannot happen with distances.
"""

task_type = InnerDistanceTask

channels: List[str] = attr.ib(metadata={"help_text": "A list of channel names."})
clip_distance: float = attr.ib(
metadata={
"help_text": "Maximum distance to consider for false positive/negatives."
},
)
tol_distance: float = attr.ib(
metadata={
"help_text": "Tolerance distance for counting false positives/negatives"
},
)
scale_factor: float = attr.ib(
default=1,
metadata={
"help_text": "The amount by which to scale distances before applying "
"a tanh normalization."
},
)
1 change: 1 addition & 0 deletions dacapo/experiments/tasks/predictors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@
from .one_hot_predictor import OneHotPredictor # noqa
from .predictor import Predictor # noqa
from .affinities_predictor import AffinitiesPredictor # noqa
from .inner_distance_predictor import InnerDistancePredictor # noqa
from .hot_distance_predictor import HotDistancePredictor # noqa
191 changes: 191 additions & 0 deletions dacapo/experiments/tasks/predictors/inner_distance_predictor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
from .predictor import Predictor
from dacapo.experiments import Model
from dacapo.experiments.arraytypes import DistanceArray
from dacapo.experiments.datasplits.datasets.arrays import NumpyArray
from dacapo.utils.balance_weights import balance_weights

from funlib.geometry import Coordinate

from scipy.ndimage.morphology import distance_transform_edt
import numpy as np
import torch

import logging
from typing import List

logger = logging.getLogger(__name__)


class InnerDistancePredictor(Predictor):
"""
Predict signed distances for a binary segmentation task.
Distances deep within background are pushed to -inf, distances deep within
the foreground object are pushed to inf. After distances have been
calculated they are passed through a tanh so that distances saturate at +-1.
Multiple classes can be predicted via multiple distance channels. The names
of each class that is being segmented can be passed in as a list of strings
in the channels argument.
"""

def __init__(self, channels: List[str], scale_factor: float):
self.channels = channels
self.norm = "tanh"
self.dt_scale_factor = scale_factor

self.max_distance = 1 * scale_factor
self.epsilon = 5e-2
self.threshold = 0.8

@property
def embedding_dims(self):
return len(self.channels)

def create_model(self, architecture):
if architecture.dims == 2:
head = torch.nn.Conv2d(
architecture.num_out_channels, self.embedding_dims, kernel_size=1
)
elif architecture.dims == 3:
head = torch.nn.Conv3d(
architecture.num_out_channels, self.embedding_dims, kernel_size=1
)

return Model(architecture, head)

def create_target(self, gt):
distances = self.process(
gt.data, gt.voxel_size, self.norm, self.dt_scale_factor
)
return NumpyArray.from_np_array(
distances,
gt.roi,
gt.voxel_size,
gt.axes,
)

def create_weight(self, gt, target, mask, moving_class_counts=None):
# balance weights independently for each channel

weights, moving_class_counts = balance_weights(
gt[target.roi],
2,
slab=tuple(1 if c == "c" else -1 for c in gt.axes),
masks=[mask[target.roi]],
moving_counts=moving_class_counts,
)
return (
NumpyArray.from_np_array(
weights,
gt.roi,
gt.voxel_size,
gt.axes,
),
moving_class_counts,
)

@property
def output_array_type(self):
return DistanceArray(self.embedding_dims)

def process(
self,
labels: np.ndarray,
voxel_size: Coordinate,
normalize=None,
normalize_args=None,
):
all_distances = np.zeros(labels.shape, dtype=np.float32) - 1
for ii, channel in enumerate(labels):
boundaries = self.__find_boundaries(channel)

# mark boundaries with 0 (not 1)
boundaries = 1.0 - boundaries

if np.sum(boundaries == 0) == 0:
max_distance = min(
dim * vs / 2 for dim, vs in zip(channel.shape, voxel_size)
)
if np.sum(channel) == 0:
distances = -np.ones(channel.shape, dtype=np.float32) * max_distance
else:
distances = np.ones(channel.shape, dtype=np.float32) * max_distance
else:
# get distances (voxel_size/2 because image is doubled)
distances = distance_transform_edt(
boundaries, sampling=tuple(float(v) / 2 for v in voxel_size)
)
distances = distances.astype(np.float32)

# restore original shape
downsample = (slice(None, None, 2),) * len(voxel_size)
distances = distances[downsample]

# todo: inverted distance
distances[channel == 0] = -distances[channel == 0]

if normalize is not None:
distances = self.__normalize(distances, normalize, normalize_args)

all_distances[ii] = distances

return all_distances * labels

def __find_boundaries(self, labels):
# labels: 1 1 1 1 0 0 2 2 2 2 3 3 n
# shift : 1 1 1 1 0 0 2 2 2 2 3 n - 1
# diff : 0 0 0 1 0 1 0 0 0 1 0 n - 1
# bound.: 00000001000100000001000 2n - 1

logger.debug("computing boundaries for %s", labels.shape)

dims = len(labels.shape)
in_shape = labels.shape
out_shape = tuple(2 * s - 1 for s in in_shape)

boundaries = np.zeros(out_shape, dtype=bool)

logger.debug("boundaries shape is %s", boundaries.shape)

for d in range(dims):
logger.debug("processing dimension %d", d)

shift_p = [slice(None)] * dims
shift_p[d] = slice(1, in_shape[d])

shift_n = [slice(None)] * dims
shift_n[d] = slice(0, in_shape[d] - 1)

diff = (labels[tuple(shift_p)] - labels[tuple(shift_n)]) != 0

logger.debug("diff shape is %s", diff.shape)

target = [slice(None, None, 2)] * dims
target[d] = slice(1, out_shape[d], 2)

logger.debug("target slices are %s", target)

boundaries[tuple(target)] = diff

return boundaries

def __normalize(self, distances, norm, normalize_args):
if norm == "tanh":
scale = normalize_args
return np.tanh(distances / scale)
else:
raise ValueError("Only tanh is supported for normalization")

def gt_region_for_roi(self, target_spec):
if self.mask_distances:
gt_spec = target_spec.copy()
gt_spec.roi = gt_spec.roi.grow(
Coordinate((self.max_distance,) * gt_spec.voxel_size.dims),
Coordinate((self.max_distance,) * gt_spec.voxel_size.dims),
).snap_to_grid(gt_spec.voxel_size, mode="shrink")
else:
gt_spec = target_spec.copy()
return gt_spec

def padding(self, gt_voxel_size: Coordinate) -> Coordinate:
return Coordinate((self.max_distance,) * gt_voxel_size.dims)

0 comments on commit 39915ac

Please sign in to comment.