Skip to content

Commit

Permalink
refac: Refactor config factories and N2V
Browse files Browse the repository at this point in the history
  • Loading branch information
jdeschamps committed Jan 21, 2025
1 parent 3c4b919 commit 70a81e3
Show file tree
Hide file tree
Showing 33 changed files with 398 additions and 793 deletions.
2 changes: 0 additions & 2 deletions src/careamics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
"Configuration",
"algorithm_factory",
"configuration_factory",
"data_factory",
"load_configuration",
"save_configuration",
]
Expand All @@ -22,7 +21,6 @@
Configuration,
algorithm_factory,
configuration_factory,
data_factory,
load_configuration,
save_configuration,
)
7 changes: 2 additions & 5 deletions src/careamics/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,15 @@
"CheckpointModel",
"Configuration",
"DataConfig",
"DataConfig",
"GaussianMixtureNMConfig",
"GeneralDataConfig",
"InferenceConfig",
"LVAELossConfig",
"MultiChannelNMConfig",
"N2NAlgorithm",
"N2NConfiguration",
"N2VAlgorithm",
"N2VConfiguration",
"N2VDataConfig",
"TrainingConfig",
"UNetBasedAlgorithm",
"VAEBasedAlgorithm",
Expand All @@ -30,7 +29,6 @@
"create_care_configuration",
"create_n2n_configuration",
"create_n2v_configuration",
"data_factory",
"load_configuration",
"save_configuration",
]
Expand All @@ -51,10 +49,9 @@
create_care_configuration,
create_n2n_configuration,
create_n2v_configuration,
data_factory,
)
from .configuration_io import load_configuration, save_configuration
from .data import DataConfig, GeneralDataConfig, N2VDataConfig
from .data import DataConfig
from .inference_model import InferenceConfig
from .loss_model import LVAELossConfig
from .n2n_configuration import N2NConfiguration
Expand Down
80 changes: 40 additions & 40 deletions src/careamics/config/algorithms/n2v_algorithm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@

from typing import Annotated, Literal

from pydantic import AfterValidator, ConfigDict
from pydantic import AfterValidator, ConfigDict, model_validator
from typing_extensions import Self

from careamics.config.architectures import UNetModel
from careamics.config.support import SupportedPixelManipulation, SupportedStructAxis
from careamics.config.transformations import N2VManipulateModel
from careamics.config.validators import (
model_matching_in_out_channels,
Expand Down Expand Up @@ -33,63 +35,61 @@ class N2VAlgorithm(UNetBasedAlgorithm):
AfterValidator(model_without_final_activation),
]

def get_masking_strategy(self) -> str:
"""Get the masking strategy for N2V."""
return self.n2v_masking.strategy
@model_validator(mode="after")
def validate_n2v2(self) -> Self:
"""Validate that the N2V2 strategy and models are set correctly.
def set_masking_strategy(self, strategy: Literal["uniform", "median"]) -> None:
"""
Set masking strategy.
Parameters
----------
strategy : "uniform" or "median"
Strategy to use for N2V2.
Returns
-------
Self
The validateed configuration.
Raises
------
ValueError
If the N2V pixel manipulate transform is not found in the transforms.
If N2V2 is used with the wrong pixel manipulation strategy.
"""
self.model.n2v_masking.strategy = strategy
if self.model.n2v2:
if self.n2v_masking.strategy != SupportedPixelManipulation.MEDIAN.value:
raise ValueError(
f"N2V2 can only be used with the "
f"{SupportedPixelManipulation.MEDIAN} pixel manipulation strategy"
f". Change the N2VManipulate transform strategy."
)
else:
if self.n2v_masking.strategy != SupportedPixelManipulation.UNIFORM.value:
raise ValueError(
f"N2V can only be used with the "
f"{SupportedPixelManipulation.UNIFORM} pixel manipulation strategy"
f". Change the N2VManipulate transform strategy."
)
return self

def set_n2v2(self, use_n2v2: bool) -> None:
"""
Set the configuration to use N2V2 or the vanilla Noise2Void.
This method ensures that N2V2 is set correctly and remain coherent, as opposed
to setting the different parameters individually.
Parameters
----------
use_n2v2 : bool
Whether to use N2V2.
"""
if use_n2v2:
self.set_masking_strategy("median")
self.n2v_masking.strategy = SupportedPixelManipulation.MEDIAN.value
self.model.n2v2 = True
else:
self.set_masking_strategy("uniform")
self.n2v_masking.strategy = SupportedPixelManipulation.UNIFORM.value
self.model.n2v2 = False

def is_using_struct_n2v(self) -> bool:
"""Check if the configuration is using structN2V."""
return self.n2v_masking.struct_mask_axis != "none" # TODO change!

def set_structN2V_mask(
self, mask_axis: Literal["horizontal", "vertical", "none"], mask_span: int
) -> None:
"""
Set structN2V mask parameters.
def is_struct_n2v(self) -> bool:
"""Check if the configuration is using structN2V.
Setting `mask_axis` to `none` will disable structN2V.
Parameters
----------
mask_axis : Literal["horizontal", "vertical", "none"]
Axis along which to apply the mask. `none` will disable structN2V.
mask_span : int
Total span of the mask in pixels.
Raises
------
ValueError
If the N2V pixel manipulate transform is not found in the transforms.
Returns
-------
bool
Whether the configuration is using structN2V.
"""
self.n2v_masking.struct_mask_axis = mask_axis
self.n2v_masking.struct_mask_span = mask_span
return self.n2v_masking.struct_mask_axis != SupportedStructAxis.NONE.value
4 changes: 0 additions & 4 deletions src/careamics/config/care_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from careamics.config.algorithms.care_algorithm_model import CAREAlgorithm
from careamics.config.configuration import Configuration
from careamics.config.data import DataConfig

CARE = "CARE"

Expand All @@ -30,9 +29,6 @@ class CAREConfiguration(Configuration):
algorithm_config: CAREAlgorithm
"""Algorithm configuration."""

data_config: DataConfig
"""Data configuration."""

def get_algorithm_friendly_name(self) -> str:
"""
Get the algorithm friendly name.
Expand Down
4 changes: 2 additions & 2 deletions src/careamics/config/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from typing_extensions import Self

from careamics.config.algorithms import UNetBasedAlgorithm, VAEBasedAlgorithm
from careamics.config.data import GeneralDataConfig
from careamics.config.data import DataConfig
from careamics.config.training_model import TrainingConfig


Expand Down Expand Up @@ -129,7 +129,7 @@ class Configuration(BaseModel):
"""Algorithm configuration, holding all parameters required to configure the
model."""

data_config: GeneralDataConfig
data_config: DataConfig
"""Data configuration, holding all parameters required to configure the training
data loader."""

Expand Down
Loading

0 comments on commit 70a81e3

Please sign in to comment.