Skip to content

Commit

Permalink
fix: Fix data module wrapper behavior and errors in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jdeschamps committed Jan 10, 2025
1 parent facf03f commit 8bbee5e
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 47 deletions.
65 changes: 22 additions & 43 deletions src/careamics/lightning/train_data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,9 @@
from numpy.typing import NDArray
from torch.utils.data import DataLoader, IterableDataset

from careamics.config import DataFactory
from careamics.config.data import GeneralDataConfig
from careamics.config.support import (
SupportedData,
SupportedPixelManipulation,
SupportedTransform,
)
from careamics.config.transformations import N2VManipulateModel, TransformModel
from careamics.config.data import DataConfig, GeneralDataConfig, N2VDataConfig
from careamics.config.support import SupportedData
from careamics.config.transformations import TransformModel
from careamics.dataset.dataset_utils import (
get_files_size,
list_files,
Expand Down Expand Up @@ -507,14 +502,22 @@ def create_train_datamodule(
"""Create a TrainDataModule.
This function is used to explicitly pass the parameters usually contained in a
`data_model` configuration to a TrainDataModule.
`GenericDataConfig` to a TrainDataModule.
Since the lightning datamodule has no access to the model, make sure that the
parameters passed to the datamodule are consistent with the model's requirements and
are coherent.
By default, the train DataModule will be set for Noise2Void if no target data is
provided.
provided. That means that it will add a `N2VManipulateModel` transformation to the
list of augmentations. The default augmentations are XY flip, XY rotation, and N2V
pixel manipulation. If you pass a training target data, the default behaviour is to
train a supervised model. It will use the default XY flip and rotation
augmentations.
To use a different set of transformations, you can pass a list of transforms to
`transforms`. Note that if you intend to use Noise2Void, you should add
`N2VManipulateModel` as the last transform in the list of transformations.
The data module can be used with Path, str or numpy arrays. In the case of
numpy arrays, it loads and computes all the patches in memory. For Path and str
Expand All @@ -526,11 +529,6 @@ def create_train_datamodule(
To use array data, set `data_type` to `array` and pass a numpy array to
`train_data`.
In particular, N2V requires a specific transformation (N2V manipulates), which is
not compatible with supervised training. The default transformations applied to the
training patches are defined in `careamics.config.data_model`. To use different
transformations, pass a list of transforms. See examples for more details.
By default, CAREamics only supports types defined in
`careamics.config.support.SupportedData`. To read custom data types, you can set
`data_type` to `custom` and provide a function that returns a numpy array from a
Expand Down Expand Up @@ -635,11 +633,12 @@ def create_train_datamodule(
transforms:
>>> import numpy as np
>>> from careamics.lightning import create_train_datamodule
>>> from careamics.config.transformations import XYFlipModel
>>> from careamics.config.transformations import XYFlipModel, N2VManipulateModel
>>> from careamics.config.support import SupportedTransform
>>> my_array = np.arange(256).reshape(16, 16)
>>> my_transforms = [
... XYFlipModel(flip_y=False)
... XYFlipModel(flip_y=False),
... N2VManipulateModel()
... ]
>>> data_module = create_train_datamodule(
... train_data=my_array,
Expand Down Expand Up @@ -667,34 +666,14 @@ def create_train_datamodule(
data_dict["transforms"] = transforms

# TODO not compatible with HDN, consider adding an argument for n2v/hdn
# if there are no target, then we enforce n2v via the transforms
if train_target_data is None:
n2v_manipulate = N2VManipulateModel(
strategy=(
SupportedPixelManipulation.MEDIAN
if use_n2v2
else SupportedPixelManipulation.UNIFORM
),
struct_mask_axis=struct_n2v_axis,
struct_mask_span=struct_n2v_span,
)

if "transforms" in data_dict.keys():
# check if there is a N2V transformation
found_n2v = False
for t in data_dict["transforms"]:
if t.name == SupportedTransform.N2V_MANIPULATE.value:
found_n2v = True

if not found_n2v:
# add it
data_dict["transforms"].append(n2v_manipulate)
else:
# create the transforms
data_dict["transforms"] = [n2v_manipulate]
data_config: GeneralDataConfig = N2VDataConfig(**data_dict)
assert isinstance(data_config, N2VDataConfig)

# validate configuration
data_config = DataFactory(data=data_dict).data
data_config.set_n2v2(use_n2v2)
data_config.set_structN2V_mask(struct_n2v_axis, struct_n2v_span)
else:
data_config = DataConfig(**data_dict)

# sanity check on the dataloader parameters
if "batch_size" in dataloader_params:
Expand Down
60 changes: 56 additions & 4 deletions tests/lightning/test_train_data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
SupportedData,
SupportedPixelManipulation,
SupportedStructAxis,
SupportedTransform,
)
from careamics.config.transformations import N2VManipulateModel, XYFlipModel
from careamics.dataset import InMemoryDataset, PathIterableDataset
from careamics.lightning import TrainDataModule, create_train_datamodule

Expand Down Expand Up @@ -40,7 +42,7 @@ def test_wrapper_unknown_type(simple_array):
create_train_datamodule(
train_data=simple_array,
data_type="wrong_type",
patch_size=(10, 10),
patch_size=(8, 8),
axes="YX",
batch_size=2,
)
Expand All @@ -63,21 +65,71 @@ def test_wrapper_train_array(simple_array):
assert len(list(data_module.train_dataloader())) > 0


def test_wrapper_supervised(simple_array):
"""Test that a supervised data config is created."""
data_module = create_train_datamodule(
train_data=simple_array,
data_type="array",
patch_size=(8, 8),
axes="YX",
batch_size=2,
train_target_data=simple_array,
val_minimum_patches=2,
)
for transform in data_module.data_config.transforms:
assert transform.name != SupportedTransform.N2V_MANIPULATE.value


def test_wrapper_supervised_n2v_throws_error(simple_array):
"""Test that an error is raised if target data is passed but the transformations
(default ones) contain N2V manipulate."""
with pytest.raises(ValueError):
create_train_datamodule(
train_data=simple_array,
data_type="array",
patch_size=(10, 10),
patch_size=(8, 8),
axes="YX",
batch_size=2,
train_target_data=simple_array,
val_minimum_patches=2,
transforms=[XYFlipModel(), N2VManipulateModel()],
)


def test_wrapper_n2v_wthout_pm_error(simple_array):
"""Test that an error is raised if target data is passed but the transformations
(default ones) contain N2V manipulate."""
with pytest.raises(ValueError):
create_train_datamodule(
train_data=simple_array,
data_type="array",
patch_size=(8, 8),
axes="YX",
batch_size=2,
val_minimum_patches=2,
transforms=[
XYFlipModel(),
],
)


def test_wrapper_default_n2v():
"""Test that instantiating a TrainDataModule with N2V works."""
data_module = create_train_datamodule(
train_data=np.zeros((10, 10)),
data_type="array",
patch_size=(8, 8),
axes="YX",
batch_size=2,
)

# N2VManipulate as last transform
assert (
data_module.data_config.transforms[-1].name
== SupportedTransform.N2V_MANIPULATE.value
)


@pytest.mark.parametrize(
"use_n2v2, strategy",
[
Expand All @@ -90,7 +142,7 @@ def test_wrapper_n2v2(simple_array, use_n2v2, strategy):
data_module = create_train_datamodule(
train_data=simple_array,
data_type="array",
patch_size=(16, 16),
patch_size=(8, 8),
axes="YX",
batch_size=2,
use_n2v2=use_n2v2,
Expand All @@ -106,7 +158,7 @@ def test_wrapper_structn2v(simple_array):
data_module = create_train_datamodule(
train_data=simple_array,
data_type="array",
patch_size=(16, 16),
patch_size=(8, 8),
axes="YX",
batch_size=2,
struct_n2v_axis=struct_axis,
Expand Down

0 comments on commit 8bbee5e

Please sign in to comment.