Skip to content

Commit

Permalink
Dev/balancing (#70)
Browse files Browse the repository at this point in the history
Correct balancing of class weights across datasets. Add clipmin and
clipmax weight configurations for Affinity and Distance task
configurations.
  • Loading branch information
rhoadesScholar authored Feb 14, 2024
2 parents 06bf991 + 3d3a7ba commit fbb18fe
Show file tree
Hide file tree
Showing 8 changed files with 68 additions and 40 deletions.
7 changes: 6 additions & 1 deletion dacapo/experiments/tasks/affinities_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@ def __init__(self, task_config):
"""Create a `DummyTask` from a `DummyTaskConfig`."""

self.predictor = AffinitiesPredictor(
neighborhood=task_config.neighborhood, lsds=task_config.lsds
neighborhood=task_config.neighborhood,
lsds=task_config.lsds,
affs_weight_clipmin=task_config.affs_weight_clipmin,
affs_weight_clipmax=task_config.affs_weight_clipmax,
lsd_weight_clipmin=task_config.lsd_weight_clipmin,
lsd_weight_clipmax=task_config.lsd_weight_clipmax,
)
self.loss = AffinitiesLoss(
len(task_config.neighborhood), task_config.lsds_to_affs_weight_ratio
Expand Down
16 changes: 16 additions & 0 deletions dacapo/experiments/tasks/affinities_task_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,19 @@ class AffinitiesTaskConfig(TaskConfig):
"help_text": "If training with lsds, set how much they should be weighted compared to affs."
},
)
affs_weight_clipmin: float = attr.ib(
default=0.05,
metadata={"help_text": "The minimum value for affinities weights."},
)
affs_weight_clipmax: float = attr.ib(
default=0.95,
metadata={"help_text": "The maximum value for affinities weights."},
)
lsd_weight_clipmin: float = attr.ib(
default=0.05,
metadata={"help_text": "The minimum value for lsds weights."},
)
lsd_weight_clipmax: float = attr.ib(
default=0.95,
metadata={"help_text": "The maximum value for lsds weights."},
)
2 changes: 2 additions & 0 deletions dacapo/experiments/tasks/distance_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ def __init__(self, task_config):
channels=task_config.channels,
scale_factor=task_config.scale_factor,
mask_distances=task_config.mask_distances,
clipmin=task_config.clipmin,
clipmax=task_config.clipmax,
)
self.loss = MSELoss()
self.post_processor = ThresholdPostProcessor()
Expand Down
8 changes: 8 additions & 0 deletions dacapo/experiments/tasks/distance_task_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,11 @@ class DistanceTaskConfig(TaskConfig):
"is less than the distance to object boundary."
},
)
clipmin: float = attr.ib(
default=0.05,
metadata={"help_text": "The minimum value for distance weights."},
)
clipmax: float = attr.ib(
default=0.95,
metadata={"help_text": "The maximum value for distance weights."},
)
12 changes: 12 additions & 0 deletions dacapo/experiments/tasks/predictors/affinities_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ def __init__(
num_voxels: int = 20,
downsample_lsds: int = 1,
grow_boundary_iterations: int = 0,
affs_weight_clipmin: float = 0.05,
affs_weight_clipmax: float = 0.95,
lsd_weight_clipmin: float = 0.05,
lsd_weight_clipmax: float = 0.95,
):
self.neighborhood = neighborhood
self.lsds = lsds
Expand All @@ -42,6 +46,10 @@ def __init__(
else:
self.num_lsds = 0
self.grow_boundary_iterations = grow_boundary_iterations
self.affs_weight_clipmin = affs_weight_clipmin
self.affs_weight_clipmax = affs_weight_clipmax
self.lsd_weight_clipmin = lsd_weight_clipmin
self.lsd_weight_clipmax = lsd_weight_clipmax

def extractor(self, voxel_size):
if self._extractor is None:
Expand Down Expand Up @@ -155,6 +163,8 @@ def create_weight(self, gt, target, mask, moving_class_counts=None):
slab=tuple(1 if c == "c" else -1 for c in target.axes),
masks=[mask_data],
moving_counts=moving_class_counts,
clipmin=self.affs_weight_clipmin,
clipmax=self.affs_weight_clipmax,
)
if self.lsds:
lsd_weights, moving_lsd_class_counts = balance_weights(
Expand All @@ -163,6 +173,8 @@ def create_weight(self, gt, target, mask, moving_class_counts=None):
slab=(-1,) * len(gt.axes),
masks=[mask_data],
moving_counts=moving_lsd_class_counts,
clipmin=self.lsd_weight_clipmin,
clipmax=self.lsd_weight_clipmax,
)
lsd_weights = np.ones(
(self.num_lsds,) + aff_weights.shape[1:], dtype=aff_weights.dtype
Expand Down
13 changes: 12 additions & 1 deletion dacapo/experiments/tasks/predictors/distance_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,14 @@ class DistancePredictor(Predictor):
in the channels argument.
"""

def __init__(self, channels: List[str], scale_factor: float, mask_distances: bool):
def __init__(
self,
channels: List[str],
scale_factor: float,
mask_distances: bool,
clipmin: float = 0.05,
clipmax: float = 0.95,
):
self.channels = channels
self.norm = "tanh"
self.dt_scale_factor = scale_factor
Expand All @@ -36,6 +43,8 @@ def __init__(self, channels: List[str], scale_factor: float, mask_distances: boo
self.max_distance = 1 * scale_factor
self.epsilon = 5e-2
self.threshold = 0.8
self.clipmin = clipmin
self.clipmax = clipmax

@property
def embedding_dims(self):
Expand Down Expand Up @@ -83,6 +92,8 @@ def create_weight(self, gt, target, mask, moving_class_counts=None):
slab=tuple(1 if c == "c" else -1 for c in gt.axes),
masks=[mask[target.roi], distance_mask],
moving_counts=moving_class_counts,
clipmin=self.clipmin,
clipmax=self.clipmax,
)
return (
NumpyArray.from_np_array(
Expand Down
43 changes: 12 additions & 31 deletions dacapo/experiments/trainers/gunpowder_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,6 @@ def __init__(self, trainer_config):
self.mask_integral_downsample_factor = 4
self.clip_raw = trainer_config.clip_raw

# Testing out if calculating multiple times and multiplying is necessary
self.add_predictor_nodes_to_dataset = (
trainer_config.add_predictor_nodes_to_dataset
)

self.scheduler = None

def create_optimizer(self, model):
Expand Down Expand Up @@ -85,8 +80,6 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None):
mask_placeholder = gp.ArrayKey("MASK_PLACEHOLDER")

target_key = gp.ArrayKey("TARGET")
dataset_weight_key = gp.ArrayKey("DATASET_WEIGHT")
datasets_weight_key = gp.ArrayKey("DATASETS_WEIGHT")
weight_key = gp.ArrayKey("WEIGHT")
sample_points_key = gp.GraphKey("SAMPLE_POINTS")

Expand Down Expand Up @@ -137,12 +130,12 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None):
+ gp.Pad(gt_key, None)
+ gp.Pad(mask_key, None)
+ gp.RandomLocation(
ensure_nonempty=sample_points_key
if points_source is not None
else None,
ensure_centered=sample_points_key
if points_source is not None
else None,
ensure_nonempty=(
sample_points_key if points_source is not None else None
),
ensure_centered=(
sample_points_key if points_source is not None else None
),
)
)

Expand All @@ -151,15 +144,6 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None):
for augment in self.augments:
dataset_source += augment.node(raw_key, gt_key, mask_key)

if self.add_predictor_nodes_to_dataset:
# Add predictor nodes to dataset_source
dataset_source += DaCapoTargetFilter(
task.predictor,
gt_key=gt_key,
weights_key=dataset_weight_key,
mask_key=mask_key,
)

dataset_sources.append(dataset_source)
pipeline = tuple(dataset_sources) + gp.RandomProvider(weights)

Expand All @@ -168,15 +152,10 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None):
task.predictor,
gt_key=gt_key,
target_key=target_key,
weights_key=datasets_weight_key
if self.add_predictor_nodes_to_dataset
else weight_key,
weights_key=weight_key,
mask_key=mask_key,
)

if self.add_predictor_nodes_to_dataset:
pipeline += Product(dataset_weight_key, datasets_weight_key, weight_key)

# Trainer attributes:
if self.num_data_fetchers > 1:
pipeline += gp.PreCache(num_workers=self.num_data_fetchers)
Expand Down Expand Up @@ -332,9 +311,11 @@ def next(self):
NumpyArray.from_gp_array(batch[self._gt_key]),
NumpyArray.from_gp_array(batch[self._target_key]),
NumpyArray.from_gp_array(batch[self._weight_key]),
NumpyArray.from_gp_array(batch[self._mask_key])
if self._mask_key is not None
else None,
(
NumpyArray.from_gp_array(batch[self._mask_key])
if self._mask_key is not None
else None
),
)

def __enter__(self):
Expand Down
7 changes: 0 additions & 7 deletions dacapo/experiments/trainers/gunpowder_trainer_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,3 @@ class GunpowderTrainerConfig(TrainerConfig):
)
min_masked: Optional[float] = attr.ib(default=0.15)
clip_raw: bool = attr.ib(default=True)

add_predictor_nodes_to_dataset: Optional[bool] = attr.ib(
default=True,
metadata={
"help_text": "Whether to add a predictor node to dataset_source and apply product of weights"
},
)

0 comments on commit fbb18fe

Please sign in to comment.