From 999cf08311ca422905a32d32b50c8a40557b57da Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Tue, 13 Feb 2024 09:45:54 -0500 Subject: [PATCH 1/2] =?UTF-8?q?fix:=20=E2=9A=A1=EF=B8=8F=20Correct=20class?= =?UTF-8?q?=20weight=20balancing?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../experiments/trainers/gunpowder_trainer.py | 43 ++++++------------- .../trainers/gunpowder_trainer_config.py | 7 --- 2 files changed, 12 insertions(+), 38 deletions(-) diff --git a/dacapo/experiments/trainers/gunpowder_trainer.py b/dacapo/experiments/trainers/gunpowder_trainer.py index f5d8fcd52..6e06504c0 100644 --- a/dacapo/experiments/trainers/gunpowder_trainer.py +++ b/dacapo/experiments/trainers/gunpowder_trainer.py @@ -42,11 +42,6 @@ def __init__(self, trainer_config): self.mask_integral_downsample_factor = 4 self.clip_raw = trainer_config.clip_raw - # Testing out if calculating multiple times and multiplying is necessary - self.add_predictor_nodes_to_dataset = ( - trainer_config.add_predictor_nodes_to_dataset - ) - self.scheduler = None def create_optimizer(self, model): @@ -85,8 +80,6 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None): mask_placeholder = gp.ArrayKey("MASK_PLACEHOLDER") target_key = gp.ArrayKey("TARGET") - dataset_weight_key = gp.ArrayKey("DATASET_WEIGHT") - datasets_weight_key = gp.ArrayKey("DATASETS_WEIGHT") weight_key = gp.ArrayKey("WEIGHT") sample_points_key = gp.GraphKey("SAMPLE_POINTS") @@ -137,12 +130,12 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None): + gp.Pad(gt_key, None, 0) + gp.Pad(mask_key, None, 0) + gp.RandomLocation( - ensure_nonempty=sample_points_key - if points_source is not None - else None, - ensure_centered=sample_points_key - if points_source is not None - else None, + ensure_nonempty=( + sample_points_key if points_source is not None else None + ), + ensure_centered=( + sample_points_key if points_source is not None else None + ), ) ) @@ -151,15 +144,6 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None): for augment in self.augments: dataset_source += augment.node(raw_key, gt_key, mask_key) - if self.add_predictor_nodes_to_dataset: - # Add predictor nodes to dataset_source - dataset_source += DaCapoTargetFilter( - task.predictor, - gt_key=gt_key, - weights_key=dataset_weight_key, - mask_key=mask_key, - ) - dataset_sources.append(dataset_source) pipeline = tuple(dataset_sources) + gp.RandomProvider(weights) @@ -168,15 +152,10 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None): task.predictor, gt_key=gt_key, target_key=target_key, - weights_key=datasets_weight_key - if self.add_predictor_nodes_to_dataset - else weight_key, + weights_key=weight_key, mask_key=mask_key, ) - if self.add_predictor_nodes_to_dataset: - pipeline += Product(dataset_weight_key, datasets_weight_key, weight_key) - # Trainer attributes: if self.num_data_fetchers > 1: pipeline += gp.PreCache(num_workers=self.num_data_fetchers) @@ -332,9 +311,11 @@ def next(self): NumpyArray.from_gp_array(batch[self._gt_key]), NumpyArray.from_gp_array(batch[self._target_key]), NumpyArray.from_gp_array(batch[self._weight_key]), - NumpyArray.from_gp_array(batch[self._mask_key]) - if self._mask_key is not None - else None, + ( + NumpyArray.from_gp_array(batch[self._mask_key]) + if self._mask_key is not None + else None + ), ) def __enter__(self): diff --git a/dacapo/experiments/trainers/gunpowder_trainer_config.py b/dacapo/experiments/trainers/gunpowder_trainer_config.py index 539e3c5e1..ae4243059 100644 --- a/dacapo/experiments/trainers/gunpowder_trainer_config.py +++ b/dacapo/experiments/trainers/gunpowder_trainer_config.py @@ -29,10 +29,3 @@ class GunpowderTrainerConfig(TrainerConfig): ) min_masked: Optional[float] = attr.ib(default=0.15) clip_raw: bool = attr.ib(default=True) - - add_predictor_nodes_to_dataset: Optional[bool] = attr.ib( - default=True, - metadata={ - "help_text": "Whether to add a predictor node to dataset_source and apply product of weights" - }, - ) From 71bb73e72fceb7b6d8386c554ff0ce42056ebfbf Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Tue, 13 Feb 2024 10:05:59 -0500 Subject: [PATCH 2/2] =?UTF-8?q?feat:=20=E2=9C=A8=20Allow=20configuration?= =?UTF-8?q?=20of=20weight=20clipping=20from=20task=20config.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- dacapo/experiments/tasks/affinities_task.py | 7 ++++++- .../experiments/tasks/affinities_task_config.py | 16 ++++++++++++++++ dacapo/experiments/tasks/distance_task.py | 2 ++ dacapo/experiments/tasks/distance_task_config.py | 8 ++++++++ .../tasks/predictors/affinities_predictor.py | 12 ++++++++++++ .../tasks/predictors/distance_predictor.py | 13 ++++++++++++- 6 files changed, 56 insertions(+), 2 deletions(-) diff --git a/dacapo/experiments/tasks/affinities_task.py b/dacapo/experiments/tasks/affinities_task.py index 859494e7e..5341da8c6 100644 --- a/dacapo/experiments/tasks/affinities_task.py +++ b/dacapo/experiments/tasks/affinities_task.py @@ -12,7 +12,12 @@ def __init__(self, task_config): """Create a `DummyTask` from a `DummyTaskConfig`.""" self.predictor = AffinitiesPredictor( - neighborhood=task_config.neighborhood, lsds=task_config.lsds + neighborhood=task_config.neighborhood, + lsds=task_config.lsds, + affs_weight_clipmin=task_config.affs_weight_clipmin, + affs_weight_clipmax=task_config.affs_weight_clipmax, + lsd_weight_clipmin=task_config.lsd_weight_clipmin, + lsd_weight_clipmax=task_config.lsd_weight_clipmax, ) self.loss = AffinitiesLoss( len(task_config.neighborhood), task_config.lsds_to_affs_weight_ratio diff --git a/dacapo/experiments/tasks/affinities_task_config.py b/dacapo/experiments/tasks/affinities_task_config.py index a50c2141e..913a28187 100644 --- a/dacapo/experiments/tasks/affinities_task_config.py +++ b/dacapo/experiments/tasks/affinities_task_config.py @@ -36,3 +36,19 @@ class AffinitiesTaskConfig(TaskConfig): "help_text": "If training with lsds, set how much they should be weighted compared to affs." }, ) + affs_weight_clipmin: float = attr.ib( + default=0.05, + metadata={"help_text": "The minimum value for affinities weights."}, + ) + affs_weight_clipmax: float = attr.ib( + default=0.95, + metadata={"help_text": "The maximum value for affinities weights."}, + ) + lsd_weight_clipmin: float = attr.ib( + default=0.05, + metadata={"help_text": "The minimum value for lsds weights."}, + ) + lsd_weight_clipmax: float = attr.ib( + default=0.95, + metadata={"help_text": "The maximum value for lsds weights."}, + ) diff --git a/dacapo/experiments/tasks/distance_task.py b/dacapo/experiments/tasks/distance_task.py index cdb82e95c..10a4e8178 100644 --- a/dacapo/experiments/tasks/distance_task.py +++ b/dacapo/experiments/tasks/distance_task.py @@ -15,6 +15,8 @@ def __init__(self, task_config): channels=task_config.channels, scale_factor=task_config.scale_factor, mask_distances=task_config.mask_distances, + clipmin=task_config.clipmin, + clipmax=task_config.clipmax, ) self.loss = MSELoss() self.post_processor = ThresholdPostProcessor() diff --git a/dacapo/experiments/tasks/distance_task_config.py b/dacapo/experiments/tasks/distance_task_config.py index 130cf1c20..a26263375 100644 --- a/dacapo/experiments/tasks/distance_task_config.py +++ b/dacapo/experiments/tasks/distance_task_config.py @@ -46,3 +46,11 @@ class DistanceTaskConfig(TaskConfig): "is less than the distance to object boundary." }, ) + clipmin: float = attr.ib( + default=0.05, + metadata={"help_text": "The minimum value for distance weights."}, + ) + clipmax: float = attr.ib( + default=0.95, + metadata={"help_text": "The maximum value for distance weights."}, + ) diff --git a/dacapo/experiments/tasks/predictors/affinities_predictor.py b/dacapo/experiments/tasks/predictors/affinities_predictor.py index 40d81f5da..92915384a 100644 --- a/dacapo/experiments/tasks/predictors/affinities_predictor.py +++ b/dacapo/experiments/tasks/predictors/affinities_predictor.py @@ -24,6 +24,10 @@ def __init__( num_voxels: int = 20, downsample_lsds: int = 1, grow_boundary_iterations: int = 0, + affs_weight_clipmin: float = 0.05, + affs_weight_clipmax: float = 0.95, + lsd_weight_clipmin: float = 0.05, + lsd_weight_clipmax: float = 0.95, ): self.neighborhood = neighborhood self.lsds = lsds @@ -42,6 +46,10 @@ def __init__( else: self.num_lsds = 0 self.grow_boundary_iterations = grow_boundary_iterations + self.affs_weight_clipmin = affs_weight_clipmin + self.affs_weight_clipmax = affs_weight_clipmax + self.lsd_weight_clipmin = lsd_weight_clipmin + self.lsd_weight_clipmax = lsd_weight_clipmax def extractor(self, voxel_size): if self._extractor is None: @@ -155,6 +163,8 @@ def create_weight(self, gt, target, mask, moving_class_counts=None): slab=tuple(1 if c == "c" else -1 for c in target.axes), masks=[mask_data], moving_counts=moving_class_counts, + clipmin=self.affs_weight_clipmin, + clipmax=self.affs_weight_clipmax, ) if self.lsds: lsd_weights, moving_lsd_class_counts = balance_weights( @@ -163,6 +173,8 @@ def create_weight(self, gt, target, mask, moving_class_counts=None): slab=(-1,) * len(gt.axes), masks=[mask_data], moving_counts=moving_lsd_class_counts, + clipmin=self.lsd_weight_clipmin, + clipmax=self.lsd_weight_clipmax, ) lsd_weights = np.ones( (self.num_lsds,) + aff_weights.shape[1:], dtype=aff_weights.dtype diff --git a/dacapo/experiments/tasks/predictors/distance_predictor.py b/dacapo/experiments/tasks/predictors/distance_predictor.py index 70c2bde4a..8ddab6131 100644 --- a/dacapo/experiments/tasks/predictors/distance_predictor.py +++ b/dacapo/experiments/tasks/predictors/distance_predictor.py @@ -27,7 +27,14 @@ class DistancePredictor(Predictor): in the channels argument. """ - def __init__(self, channels: List[str], scale_factor: float, mask_distances: bool): + def __init__( + self, + channels: List[str], + scale_factor: float, + mask_distances: bool, + clipmin: float = 0.05, + clipmax: float = 0.95, + ): self.channels = channels self.norm = "tanh" self.dt_scale_factor = scale_factor @@ -36,6 +43,8 @@ def __init__(self, channels: List[str], scale_factor: float, mask_distances: boo self.max_distance = 1 * scale_factor self.epsilon = 5e-2 self.threshold = 0.8 + self.clipmin = clipmin + self.clipmax = clipmax @property def embedding_dims(self): @@ -83,6 +92,8 @@ def create_weight(self, gt, target, mask, moving_class_counts=None): slab=tuple(1 if c == "c" else -1 for c in gt.axes), masks=[mask[target.roi], distance_mask], moving_counts=moving_class_counts, + clipmin=self.clipmin, + clipmax=self.clipmax, ) return ( NumpyArray.from_np_array(