Skip to content

Commit

Permalink
Fix explicit augmentations passing in the fit method
Browse files Browse the repository at this point in the history
  • Loading branch information
AlbertDominguez committed Nov 22, 2024
1 parent f6c7bba commit 86af2a3
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions spotiflow/model/spotiflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,11 +415,9 @@ def fit(
transforms.Crop if not self.config.is_3d else transforms3d.Crop3D
)
assert any(
isinstance(p, _crop_cls) for p in augment_train.transforms
isinstance(p, _crop_cls) for p in augment_train.augmentations
), "Custom augmenter must contain a cropping transform!"
tr_augmenter = self.build_image_augmenter(
crop_size, point_priority=point_priority
)
tr_augmenter = augment_train
elif augment_train:
tr_augmenter = self.build_image_augmenter(
crop_size, point_priority=point_priority
Expand Down

0 comments on commit 86af2a3

Please sign in to comment.