-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Feature: new Target InnerDistanceTarget (#90)
- Loading branch information
Showing
5 changed files
with
261 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
from .evaluators import BinarySegmentationEvaluator | ||
from .losses import MSELoss | ||
from .post_processors import ThresholdPostProcessor | ||
from .predictors import InnerDistancePredictor | ||
from .task import Task | ||
|
||
|
||
# Goal is have a distance task but with distance inside the forground only | ||
class InnerDistanceTask(Task): | ||
"""This is just a dummy task for testing.""" | ||
|
||
def __init__(self, task_config): | ||
"""Create a `DummyTask` from a `DummyTaskConfig`.""" | ||
|
||
self.predictor = InnerDistancePredictor( | ||
channels=task_config.channels, | ||
scale_factor=task_config.scale_factor, | ||
) | ||
self.loss = MSELoss() | ||
self.post_processor = ThresholdPostProcessor() | ||
self.evaluator = BinarySegmentationEvaluator( | ||
clip_distance=task_config.clip_distance, | ||
tol_distance=task_config.tol_distance, | ||
channels=task_config.channels, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import attr | ||
|
||
from .inner_distance_task import InnerDistanceTask | ||
from .task_config import TaskConfig | ||
|
||
from typing import List | ||
|
||
|
||
@attr.s | ||
class InnerDistanceTaskConfig(TaskConfig): | ||
"""This is a Distance task config used for generating and | ||
evaluating signed distance transforms as a way of generating | ||
segmentations. | ||
The advantage of generating distance transforms over regular | ||
affinities is you can get a denser signal, i.e. 1 misclassified | ||
pixel in an affinity prediction could merge 2 otherwise very | ||
distinct objects, this cannot happen with distances. | ||
""" | ||
|
||
task_type = InnerDistanceTask | ||
|
||
channels: List[str] = attr.ib(metadata={"help_text": "A list of channel names."}) | ||
clip_distance: float = attr.ib( | ||
metadata={ | ||
"help_text": "Maximum distance to consider for false positive/negatives." | ||
}, | ||
) | ||
tol_distance: float = attr.ib( | ||
metadata={ | ||
"help_text": "Tolerance distance for counting false positives/negatives" | ||
}, | ||
) | ||
scale_factor: float = attr.ib( | ||
default=1, | ||
metadata={ | ||
"help_text": "The amount by which to scale distances before applying " | ||
"a tanh normalization." | ||
}, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
191 changes: 191 additions & 0 deletions
191
dacapo/experiments/tasks/predictors/inner_distance_predictor.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,191 @@ | ||
from .predictor import Predictor | ||
from dacapo.experiments import Model | ||
from dacapo.experiments.arraytypes import DistanceArray | ||
from dacapo.experiments.datasplits.datasets.arrays import NumpyArray | ||
from dacapo.utils.balance_weights import balance_weights | ||
|
||
from funlib.geometry import Coordinate | ||
|
||
from scipy.ndimage.morphology import distance_transform_edt | ||
import numpy as np | ||
import torch | ||
|
||
import logging | ||
from typing import List | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class InnerDistancePredictor(Predictor): | ||
""" | ||
Predict signed distances for a binary segmentation task. | ||
Distances deep within background are pushed to -inf, distances deep within | ||
the foreground object are pushed to inf. After distances have been | ||
calculated they are passed through a tanh so that distances saturate at +-1. | ||
Multiple classes can be predicted via multiple distance channels. The names | ||
of each class that is being segmented can be passed in as a list of strings | ||
in the channels argument. | ||
""" | ||
|
||
def __init__(self, channels: List[str], scale_factor: float): | ||
self.channels = channels | ||
self.norm = "tanh" | ||
self.dt_scale_factor = scale_factor | ||
|
||
self.max_distance = 1 * scale_factor | ||
self.epsilon = 5e-2 | ||
self.threshold = 0.8 | ||
|
||
@property | ||
def embedding_dims(self): | ||
return len(self.channels) | ||
|
||
def create_model(self, architecture): | ||
if architecture.dims == 2: | ||
head = torch.nn.Conv2d( | ||
architecture.num_out_channels, self.embedding_dims, kernel_size=1 | ||
) | ||
elif architecture.dims == 3: | ||
head = torch.nn.Conv3d( | ||
architecture.num_out_channels, self.embedding_dims, kernel_size=1 | ||
) | ||
|
||
return Model(architecture, head) | ||
|
||
def create_target(self, gt): | ||
distances = self.process( | ||
gt.data, gt.voxel_size, self.norm, self.dt_scale_factor | ||
) | ||
return NumpyArray.from_np_array( | ||
distances, | ||
gt.roi, | ||
gt.voxel_size, | ||
gt.axes, | ||
) | ||
|
||
def create_weight(self, gt, target, mask, moving_class_counts=None): | ||
# balance weights independently for each channel | ||
|
||
weights, moving_class_counts = balance_weights( | ||
gt[target.roi], | ||
2, | ||
slab=tuple(1 if c == "c" else -1 for c in gt.axes), | ||
masks=[mask[target.roi]], | ||
moving_counts=moving_class_counts, | ||
) | ||
return ( | ||
NumpyArray.from_np_array( | ||
weights, | ||
gt.roi, | ||
gt.voxel_size, | ||
gt.axes, | ||
), | ||
moving_class_counts, | ||
) | ||
|
||
@property | ||
def output_array_type(self): | ||
return DistanceArray(self.embedding_dims) | ||
|
||
def process( | ||
self, | ||
labels: np.ndarray, | ||
voxel_size: Coordinate, | ||
normalize=None, | ||
normalize_args=None, | ||
): | ||
all_distances = np.zeros(labels.shape, dtype=np.float32) - 1 | ||
for ii, channel in enumerate(labels): | ||
boundaries = self.__find_boundaries(channel) | ||
|
||
# mark boundaries with 0 (not 1) | ||
boundaries = 1.0 - boundaries | ||
|
||
if np.sum(boundaries == 0) == 0: | ||
max_distance = min( | ||
dim * vs / 2 for dim, vs in zip(channel.shape, voxel_size) | ||
) | ||
if np.sum(channel) == 0: | ||
distances = -np.ones(channel.shape, dtype=np.float32) * max_distance | ||
else: | ||
distances = np.ones(channel.shape, dtype=np.float32) * max_distance | ||
else: | ||
# get distances (voxel_size/2 because image is doubled) | ||
distances = distance_transform_edt( | ||
boundaries, sampling=tuple(float(v) / 2 for v in voxel_size) | ||
) | ||
distances = distances.astype(np.float32) | ||
|
||
# restore original shape | ||
downsample = (slice(None, None, 2),) * len(voxel_size) | ||
distances = distances[downsample] | ||
|
||
# todo: inverted distance | ||
distances[channel == 0] = -distances[channel == 0] | ||
|
||
if normalize is not None: | ||
distances = self.__normalize(distances, normalize, normalize_args) | ||
|
||
all_distances[ii] = distances | ||
|
||
return all_distances * labels | ||
|
||
def __find_boundaries(self, labels): | ||
# labels: 1 1 1 1 0 0 2 2 2 2 3 3 n | ||
# shift : 1 1 1 1 0 0 2 2 2 2 3 n - 1 | ||
# diff : 0 0 0 1 0 1 0 0 0 1 0 n - 1 | ||
# bound.: 00000001000100000001000 2n - 1 | ||
|
||
logger.debug("computing boundaries for %s", labels.shape) | ||
|
||
dims = len(labels.shape) | ||
in_shape = labels.shape | ||
out_shape = tuple(2 * s - 1 for s in in_shape) | ||
|
||
boundaries = np.zeros(out_shape, dtype=bool) | ||
|
||
logger.debug("boundaries shape is %s", boundaries.shape) | ||
|
||
for d in range(dims): | ||
logger.debug("processing dimension %d", d) | ||
|
||
shift_p = [slice(None)] * dims | ||
shift_p[d] = slice(1, in_shape[d]) | ||
|
||
shift_n = [slice(None)] * dims | ||
shift_n[d] = slice(0, in_shape[d] - 1) | ||
|
||
diff = (labels[tuple(shift_p)] - labels[tuple(shift_n)]) != 0 | ||
|
||
logger.debug("diff shape is %s", diff.shape) | ||
|
||
target = [slice(None, None, 2)] * dims | ||
target[d] = slice(1, out_shape[d], 2) | ||
|
||
logger.debug("target slices are %s", target) | ||
|
||
boundaries[tuple(target)] = diff | ||
|
||
return boundaries | ||
|
||
def __normalize(self, distances, norm, normalize_args): | ||
if norm == "tanh": | ||
scale = normalize_args | ||
return np.tanh(distances / scale) | ||
else: | ||
raise ValueError("Only tanh is supported for normalization") | ||
|
||
def gt_region_for_roi(self, target_spec): | ||
if self.mask_distances: | ||
gt_spec = target_spec.copy() | ||
gt_spec.roi = gt_spec.roi.grow( | ||
Coordinate((self.max_distance,) * gt_spec.voxel_size.dims), | ||
Coordinate((self.max_distance,) * gt_spec.voxel_size.dims), | ||
).snap_to_grid(gt_spec.voxel_size, mode="shrink") | ||
else: | ||
gt_spec = target_spec.copy() | ||
return gt_spec | ||
|
||
def padding(self, gt_voxel_size: Coordinate) -> Coordinate: | ||
return Coordinate((self.max_distance,) * gt_voxel_size.dims) |