diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index aca3e083..71d52fcd 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -25,6 +25,7 @@ If applicable, add screenshots to help explain your problem. **Desktop (please complete the following information):** - OS: [e.g. iOS] - Version [e.g. 22] +- microscopy-portfolio version [e.g. 0.0.5] **Environment:** Please add here the content of your conda environment with versions. diff --git a/.github/pull_request_template.md b/.github/PR_TEMPLATE/pull_request.md similarity index 100% rename from .github/pull_request_template.md rename to .github/PR_TEMPLATE/pull_request.md diff --git a/.github/TEST_FAIL_TEMPLATE.md b/.github/TEST_FAIL_TEMPLATE.md deleted file mode 100644 index 2f95b732..00000000 --- a/.github/TEST_FAIL_TEMPLATE.md +++ /dev/null @@ -1,12 +0,0 @@ ---- -title: "{{ env.TITLE }}" -labels: [bug] ---- -The {{ workflow }} workflow failed on {{ date | date("YYYY-MM-DD HH:mm") }} UTC. - -The most recent failing test was on {{ env.PLATFORM }} py{{ env.PYTHON }} {{ env.BACKEND }} -with commit: {{ sha }}. - -Full run: https://github.com/CAREamics/careamics-portfolio/actions/runs/{{ env.RUN_ID }} - -(This post will be updated if another test fails, as long as this issue remains open.) 
diff --git a/.github/dependabot.yml b/.github/dependabot.yml deleted file mode 100644 index 2b6e52b8..00000000 --- a/.github/dependabot.yml +++ /dev/null @@ -1,11 +0,0 @@ ---- -# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates - -version: 2 -updates: - - package-ecosystem: github-actions - directory: / - schedule: - interval: weekly - commit-message: - prefix: 'ci(dependabot):' diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 3d18a5c2..d6637c26 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,109 +1,103 @@ ---- name: CI on: - push: - branches: - - main - tags: - - v*.*.* - pull_request: - workflow_dispatch: - schedule: + push: + branches: + - main + tags: + - "v*" + pull_request: + workflow_dispatch: + schedule: # run every week (for --pre release tests) - - cron: 0 0 * * 0 + - cron: "0 0 * * 0" jobs: - check-manifest: + check-manifest: # check-manifest is a tool that checks that all files in version control are # included in the sdist (unless explicitly excluded) - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - run: pipx run check-manifest - - test: - name: ${{ matrix.platform }} (${{ matrix.python-version }}) - runs-on: ${{ matrix.platform }} - strategy: - fail-fast: false - matrix: - python-version: ['3.8', '3.9', '3.10', '3.11'] - platform: [ubuntu-latest, macos-latest, windows-latest] - - steps: - - name: ๐Ÿ›‘ Cancel Previous Runs - uses: styfle/cancel-workflow-action@0.12.1 - with: - access_token: ${{ github.token }} - - - uses: actions/checkout@v4 - - - name: ๐Ÿ Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python-version }} - cache-dependency-path: pyproject.toml - cache: pip - - - name: Install Dependencies - run: | - python -m pip install -U pip - # if running a cron job, we add the --pre flag to test against pre-releases - python -m pip install .[test] ${{ 
github.event_name == 'schedule' && '--pre' || '' }} - - - name: ๐Ÿงช Run Tests - run: pytest --color=yes --cov --cov-report=xml --cov-report=term-missing -m "not gpu" - - # If something goes wrong with --pre tests, we can open an issue in the repo - - name: ๐Ÿ“ Report --pre Failures - if: failure() && github.event_name == 'schedule' - uses: JasonEtco/create-an-issue@v2 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - PLATFORM: ${{ matrix.platform }} - PYTHON: ${{ matrix.python-version }} - RUN_ID: ${{ github.run_id }} - TITLE: '[test-bot] pip install --pre is failing' - with: - filename: .github/TEST_FAIL_TEMPLATE.md - update_existing: true - - - name: Coverage - uses: codecov/codecov-action@v4 - - deploy: - name: Deploy - needs: test - if: success() && startsWith(github.ref, 'refs/tags/') && github.event_name != 'schedule' - runs-on: ubuntu-latest - - permissions: - # IMPORTANT: this permission is mandatory for trusted publishing on PyPi - # see https://docs.pypi.org/trusted-publishers/ - id-token: write - # This permission allows writing releases - contents: write - - steps: - - uses: actions/checkout@v4 - with: - fetch-depth: 0 - - - name: ๐Ÿ Set up Python - uses: actions/setup-python@v5 - with: - python-version: 3.x - - - name: ๐Ÿ‘ท Build - run: | - python -m pip install build - python -m build - - - name: ๐Ÿšข Publish to PyPI - uses: pypa/gh-action-pypi-publish@release/v1 - - - uses: softprops/action-gh-release@v1 - with: - generate_release_notes: true - files: ./dist/* + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - run: pipx run check-manifest + + test: + name: ${{ matrix.platform }} (${{ matrix.python-version }}) + runs-on: ${{ matrix.platform }} + strategy: + fail-fast: false + matrix: + python-version: ["3.8", "3.9", "3.10", "3.11"] + # https://docs.github.com/en/actions/using-github-hosted-runners/about-github-hosted-runners/about-github-hosted-runners#standard-github-hosted-runners-for-public-repositories + platform: 
[ubuntu-latest, macos-13, windows-latest] + + steps: + - name: ๐Ÿ›‘ Cancel Previous Runs + uses: styfle/cancel-workflow-action@0.11.0 + with: + access_token: ${{ github.token }} + + - uses: actions/checkout@v3 + + - name: ๐Ÿ Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + cache-dependency-path: "pyproject.toml" + cache: "pip" + + - name: Install Dependencies + run: | + python -m pip install -U pip + # if running a cron job, we add the --pre flag to test against pre-releases + python -m pip install .[dev] ${{ github.event_name == 'schedule' && '--pre' || '' }} + + - name: ๐Ÿงช Run Tests + run: pytest --color=yes --cov --cov-report=xml --cov-report=term-missing -m "not gpu" + + # If something goes wrong with --pre tests, we can open an issue in the repo + - name: ๐Ÿ“ Report --pre Failures + if: failure() && github.event_name == 'schedule' + uses: JasonEtco/create-an-issue@v2 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + PLATFORM: ${{ matrix.platform }} + PYTHON: ${{ matrix.python-version }} + RUN_ID: ${{ github.run_id }} + TITLE: "[test-bot] pip install --pre is failing" + with: + filename: .github/TEST_FAIL_TEMPLATE.md + update_existing: true + + - name: Coverage + uses: codecov/codecov-action@v3 + + deploy: + name: Deploy + needs: test + if: success() && startsWith(github.ref, 'refs/tags/') && github.event_name != 'schedule' + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + with: + fetch-depth: 0 + + - name: ๐Ÿ Set up Python + uses: actions/setup-python@v4 + with: + python-version: "3.x" + + - name: ๐Ÿ‘ท Build + run: | + python -m pip install build + python -m build + + - name: ๐Ÿšข Publish to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 + with: + password: ${{ secrets.TWINE_API_KEY }} + + - uses: softprops/action-gh-release@v1 + with: + generate_release_notes: true diff --git a/.gitignore b/.gitignore index 1be3365b..eb3fbdf9 100644 --- 
a/.gitignore +++ b/.gitignore @@ -3,3 +3,6 @@ .pytest* *.DS_Store* **/.ipynb_checkpoints +.coverage +coverage.xml +lightning_logs diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3df6b985..220afd90 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,62 +1,56 @@ ---- # enable pre-commit.ci at https://pre-commit.ci/ # it adds: # 1. auto fixing pull requests # 2. auto updating the pre-commit configuration ci: - autoupdate_schedule: monthly - autofix_commit_msg: 'style(pre-commit.ci): auto fixes [...]' - autoupdate_commit_msg: 'ci(pre-commit.ci): autoupdate' + autoupdate_schedule: monthly + autofix_commit_msg: "style(pre-commit.ci): auto fixes [...]" + autoupdate_commit_msg: "ci(pre-commit.ci): autoupdate" repos: - - repo: https://github.com/abravalheri/validate-pyproject - rev: v0.15 - hooks: - - id: validate-pyproject - - - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.1.9 - hooks: - - id: ruff - args: [--fix, --target-version, py38] - - - repo: https://github.com/psf/black - rev: 23.12.1 - hooks: - - id: black - - - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.8.0 - hooks: - - id: mypy - files: ^src/ - additional_dependencies: - - numpy - - types-PyYAML + - repo: https://github.com/abravalheri/validate-pyproject + rev: v0.14 + hooks: + - id: validate-pyproject + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.0.292 + hooks: + - id: ruff + args: [--fix, --target-version, py38] + + - repo: https://github.com/psf/black + rev: 23.9.1 + hooks: + - id: black + + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.5.1 + hooks: + - id: mypy + files: "^src/" + additional_dependencies: + - numpy + - types-PyYAML + - types-setuptools # check docstrings - - repo: https://github.com/numpy/numpydoc - rev: v1.6.0 - hooks: - - id: numpydoc-validation + - repo: https://github.com/numpy/numpydoc + rev: v1.6.0 + hooks: + - id: numpydoc-validation # jupyter linting and formatting - - repo: 
https://github.com/nbQA-dev/nbQA - rev: 1.7.1 - hooks: - - id: nbqa-ruff - args: [--fix] - - id: nbqa-black - #- id: nbqa-mypy + - repo: https://github.com/nbQA-dev/nbQA + rev: 1.7.0 + hooks: + - id: nbqa-ruff + args: [--fix] + - id: nbqa-black + #- id: nbqa-mypy # strip out jupyter notebooks - - repo: https://github.com/kynan/nbstripout - rev: 0.6.1 - hooks: - - id: nbstripout - - # yaml formatter - - repo: https://github.com/jumanjihouse/pre-commit-hook-yamlfmt - rev: 0.2.3 - hooks: - - id: yamlfmt + - repo: https://github.com/kynan/nbstripout + rev: 0.6.1 + hooks: + - id: nbstripout diff --git a/README.md b/README.md index c3c7ceb9..764bf01f 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@

-# CAREamics +# CAREamics Restoration [![License](https://img.shields.io/pypi/l/careamics.svg?color=green)](https://github.com/CAREamics/careamics/blob/main/LICENSE) [![PyPI](https://img.shields.io/pypi/v/careamics.svg?color=green)](https://pypi.org/project/careamics) @@ -12,12 +12,67 @@ [![CI](https://github.com/CAREamics/careamics/actions/workflows/ci.yml/badge.svg)](https://github.com/CAREamics/careamics/actions/workflows/ci.yml) [![codecov](https://codecov.io/gh/CAREamics/careamics/branch/main/graph/badge.svg)](https://codecov.io/gh/CAREamics/careamics) +## Installation -For details installation and usage, please refer to the [guide](https://careamics.github.io/). +``` bash +pip install careamics +``` +For more details on the options please follow the installation [guide](https://careamics.github.io/careamics/). - diff --git a/examples/2D/n2n/example_SEM_careamist.ipynb b/examples/2D/n2n/example_SEM_careamist.ipynb new file mode 100644 index 00000000..db773806 --- /dev/null +++ b/examples/2D/n2n/example_SEM_careamist.ipynb @@ -0,0 +1,239 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import shutil\n", + "from pathlib import Path\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import tifffile\n", + "from careamics_portfolio import PortfolioManager\n", + "\n", + "from careamics import CAREamist" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Import Dataset Portfolio" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Explore portfolio\n", + "portfolio = PortfolioManager()\n", + "print(portfolio.denoising)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Download files\n", + "root_path = Path(\"./data\")\n", + "files = portfolio.denoising.N2N_SEM.download(root_path)\n", + 
"print(f\"List of downloaded files: {files}\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Visualize training data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Load images\n", + "train_image = tifffile.imread(files[0])\n", + "print(f\"Train image shape: {train_image.shape}\")\n", + "\n", + "# Display images\n", + "side = int(np.ceil(np.sqrt(train_image.shape[0])))\n", + "fig, ax = plt.subplots(side, side, figsize=(15, 15))\n", + "\n", + "for i in range(train_image.shape[0]):\n", + " ax.flat[i].imshow(train_image[i], cmap=\"gray\")\n", + " ax.flat[i].axis(\"off\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Visualize validation data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "val_image = tifffile.imread(files[2])\n", + "print(f\"Validation image shape: {val_image.shape}\")\n", + "\n", + "# Display images\n", + "side = int(np.ceil(np.sqrt(val_image.shape[0])))\n", + "fig, ax = plt.subplots(side, side, figsize=(15, 15))\n", + "for i in range(val_image.shape[0]):\n", + " ax.flat[i].imshow(val_image[i], cmap=\"gray\")\n", + " ax.flat[i].axis(\"off\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Set paths\n", + "\n", + "data_path = Path(root_path / \"n2n_sem\")\n", + "train_path = data_path / \"train\"\n", + "test_path = data_path / \"val\"\n", + "\n", + "train_path.mkdir(parents=True, exist_ok=True)\n", + "test_path.mkdir(parents=True, exist_ok=True)\n", + "\n", + "shutil.copy(root_path / files[0], train_path / \"train_image.tif\")\n", + "shutil.copy(root_path / files[1], test_path / \"test_image.tif\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Initialize the Model\n", + "\n", + 
"Create a Pytorch Lightning module\n", + "\n", + "Please take as look at the [documentation](https://careamics.github.io) to see the full list of parameters and configuration options" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "engine = CAREamist(source=\"n2n_2D_SEM.yml\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Part 3. Run training \n", + "\n", + "We need to specify the paths to training and validation data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "engine.train(\n", + " train_source=train_image[0],\n", + " val_source=train_image[1],\n", + " train_target=train_image[2],\n", + " val_target=train_image[3],\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Run prediction\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "preds = engine.predict(source=val_image[0], tile_size=(256, 256))" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Visualize the prediction" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fi, ax = plt.subplots(1, 2, figsize=(15, 15))\n", + "ax[0].imshow(preds[0].squeeze(), cmap=\"gray\")\n", + "ax[0].set_title(\"Prediction\")\n", + "ax[1].imshow(val_image[0].squeeze(), cmap=\"gray\")\n", + "ax[1].set_title(\"Ground Truth\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.9.13 ('HDNn')", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": 
"text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.18" + }, + "vscode": { + "interpreter": { + "hash": "faf8b084d52efbff00ddf863c4fb0ca7a3b023f9f18590a5b65c31dc02d793e2" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/2D/n2n/n2n_2D_SEM.yml b/examples/2D/n2n/n2n_2D_SEM.yml new file mode 100644 index 00000000..936087da --- /dev/null +++ b/examples/2D/n2n/n2n_2D_SEM.yml @@ -0,0 +1,60 @@ +# Name, composed of letters, numbers, spaces, dashes and underscores +experiment_name: N2V_SEM_N2N + +algorithm: + algorithm: n2n + loss: mae + model: + architecture: UNet + n2v2: False + optimizer_parameters: + lr: 1e-3 + lr_scheduler_parameters: + factor: 0.5 + patience: 10 + +training: + # Number of epochs, greater or equal than 1 + num_epochs: 1 + + # Batch size, greater or equal than 1 + batch_size: 128 + + # Optional, use WandB (True or False), must be installed in your conda environment + use_wandb: False + + callbacks: + # Optional, any parameter for Pytorch lightning callbacks + + # Optional, any parameter for Pytorch lightning ModelCheckpoint + checkpoint: + # Optional, monitor the validation loss (val_loss) or the validation accuracy (val_acc) + # parameters: + # dirpath: 'checkpoints' + # monitor : val_loss + save_last: True + save_top_k: 3 + # TODO should it be here or in predict ? 
+ # Optional, path to the checkpoint to load the model + # load_path: '' + + # Optional, automatic mixed precision + amp: + # Use (True or False) + use: True + +data: + # Extension of the data + data_type: array + + patch_size: [64, 64] + + # Axes, among STCZYX with constraints on order + axes: YX + + tta_transforms: True + + batch_size: 128 + + # Optional, number of workers for data loading, greater or equal than 0 + num_workers: 0 \ No newline at end of file diff --git a/examples/2D/example_BSD68.ipynb b/examples/2D/n2v/example_BSD68_careamist.ipynb similarity index 72% rename from examples/2D/example_BSD68.ipynb rename to examples/2D/n2v/example_BSD68_careamist.ipynb index 5643f344..81698274 100644 --- a/examples/2D/example_BSD68.ipynb +++ b/examples/2D/n2v/example_BSD68_careamist.ipynb @@ -6,15 +6,17 @@ "metadata": {}, "outputs": [], "source": [ - "import pprint\n", "from pathlib import Path\n", "\n", "import matplotlib.pyplot as plt\n", "import tifffile\n", "from careamics_portfolio import PortfolioManager\n", "\n", - "from careamics.engine import Engine\n", - "from careamics.metrics import psnr" + "from careamics import CAREamist\n", + "from careamics.lightning_datamodule import (\n", + " CAREamicsPredictDataModule,\n", + ")\n", + "from careamics.utils.metrics import psnr" ] }, { @@ -109,9 +111,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Initialize the Engine\n", + "### Initialize the Model\n", "\n", - "Engine contains the dataloading pipeline and the model training logic. 
We'll initialize the engine with the config file, but it can also be initialized from a pre-trained checkpoint.\n", + "Create a Pytorch Lightning module\n", "\n", "Please take as look at the [documentation](https://careamics.github.io) to see the full list of parameters and configuration options" ] @@ -122,24 +124,7 @@ "metadata": {}, "outputs": [], "source": [ - "engine = Engine(config_path=\"n2v_2D_BSD.yml\")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Visualize training configuration" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "pprint.PrettyPrinter(indent=2).pprint(engine.cfg.model_dump(exclude_optionals=False))" + "engine = CAREamist(source=\"n2v_2D_BSD.yml\")" ] }, { @@ -158,7 +143,7 @@ "metadata": {}, "outputs": [], "source": [ - "train_stats, val_stats = engine.train(train_path=train_path, val_path=val_path)" + "engine.train(train_source=train_path, val_source=val_path)" ] }, { @@ -166,7 +151,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Visualize statistics" + "### Define a prediction datamodule" ] }, { @@ -175,10 +160,14 @@ "metadata": {}, "outputs": [], "source": [ - "plt.plot([next(iter(d.values())) for d in train_stats], label=\"Train loss\")\n", - "plt.plot([next(iter(d.values())) for d in val_stats], label=\"Validation loss\")\n", - "plt.legend(loc=\"best\")\n", - "plt.xlabel(\"Epoch\")" + "pred_data_module = CAREamicsPredictDataModule(\n", + " pred_data=test_path,\n", + " data_type=\"tiff\",\n", + " tile_size=(256, 256),\n", + " axes=\"YX\",\n", + " batch_size=1,\n", + " tta_transforms=True,\n", + ")" ] }, { @@ -186,7 +175,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Visualize files to denoise" + "### Run prediction\n", + "\n", + "We need to specify the path to the data we want to denoise" ] }, { @@ -195,9 +186,7 @@ "metadata": {}, "outputs": [], "source": [ - "test_img = 
tifffile.imread(test_path / \"bsd68_gaussian25_1.tiff\")\n", - "plt.imshow(test_img, cmap=\"gray\")\n", - "print(test_img.shape)" + "preds = engine.predict(source=pred_data_module)" ] }, { @@ -205,9 +194,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Run prediction\n", - "\n", - "We need to specify the path to the data we want to denoise" + "### Visualize results and compute metrics\n" ] }, { @@ -216,17 +203,9 @@ "metadata": {}, "outputs": [], "source": [ - "preds = engine.predict(\n", - " input=test_path, tile_shape=[256, 256], overlaps=[48, 48], axes=\"YX\"\n", - ")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Visualize results and compute metrics\n" + "# Create a list of ground truth images\n", + "\n", + "gts = [tifffile.imread(f) for f in sorted(gt_path.glob(\"*.tiff\"))]" ] }, { @@ -235,9 +214,20 @@ "metadata": {}, "outputs": [], "source": [ - "# Create a list of ground truth images\n", + "# Plot single image\n", "\n", - "gts = [tifffile.imread(f) for f in sorted(gt_path.glob(\"*.tiff\"))]" + "image_idx = 0\n", + "_, subplot = plt.subplots(1, 2, figsize=(10, 10))\n", + "\n", + "subplot[0].imshow(preds[image_idx].squeeze(), cmap=\"gray\")\n", + "subplot[0].set_title(\"Prediction\")\n", + "subplot[1].imshow(gts[image_idx], cmap=\"gray\")\n", + "subplot[1].set_title(\"Ground truth\")\n", + "\n", + "\n", + "# Calculate PSNR for single image\n", + "psnr_single = psnr(preds[image_idx].squeeze(), gts[image_idx].squeeze())\n", + "print(f\"PSNR for image {image_idx}: {psnr_single}\")" ] }, { @@ -251,7 +241,7 @@ "image_idx = 0\n", "_, subplot = plt.subplots(1, 2, figsize=(10, 10))\n", "\n", - "subplot[0].imshow(preds[image_idx], cmap=\"gray\")\n", + "subplot[0].imshow(preds[image_idx].squeeze(), cmap=\"gray\")\n", "subplot[0].set_title(\"Prediction\")\n", "subplot[1].imshow(gts[image_idx], cmap=\"gray\")\n", "subplot[1].set_title(\"Ground truth\")" @@ -265,7 +255,7 @@ "source": [ "# Calculate PSNR 
for single image\n", "\n", - "psnr_single = psnr(gts[image_idx], preds[image_idx])\n", + "psnr_single = psnr(gts[image_idx], preds[image_idx].squeeze())\n", "print(f\"PSNR for image {image_idx}: {psnr_single}\")" ] }, @@ -278,34 +268,10 @@ "psnr_total = 0\n", "\n", "for pred, gt in zip(preds, gts):\n", - " psnr_total += psnr(gt, pred)\n", + " psnr_total += psnr(gt, pred.squeeze())\n", "\n", "print(f\"PSNR total: {psnr_total / len(preds)}\")" ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Export to bioimage.io" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "engine.save_as_bioimage(engine.cfg.experiment_name + \"bioimage.zip\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { @@ -324,7 +290,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.16" + "version": "3.9.18" }, "vscode": { "interpreter": { diff --git a/examples/2D/n2v/example_BSD68_lightning.ipynb b/examples/2D/n2v/example_BSD68_lightning.ipynb new file mode 100644 index 00000000..7b9dec78 --- /dev/null +++ b/examples/2D/n2v/example_BSD68_lightning.ipynb @@ -0,0 +1,397 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from pathlib import Path\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import tifffile\n", + "from careamics_portfolio import PortfolioManager\n", + "from pytorch_lightning import Trainer\n", + "\n", + "from careamics import CAREamicsModule\n", + "from careamics.lightning_datamodule import (\n", + " CAREamicsPredictDataModule,\n", + " CAREamicsTrainDataModule,\n", + ")\n", +<<<<<<< Updated upstream + "from careamics.lightning_prediction import CAREamicsPredictionLoop\n", +======= + "from careamics.lightning_prediction import CAREamicsFiring\n", +>>>>>>> Stashed 
changes + "from careamics.utils.metrics import psnr" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Import Dataset Portfolio" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Explore portfolio\n", + "portfolio = PortfolioManager()\n", + "print(portfolio.denoising)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Download and unzip the files\n", + "root_path = Path(\"data\")\n", + "files = portfolio.denoising.N2V_BSD68.download(root_path)\n", + "print(f\"List of downloaded files: {files}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "data_path = Path(root_path / \"denoising-N2V_BSD68.unzip/BSD68_reproducibility_data\")\n", + "train_path = data_path / \"train\"\n", + "val_path = data_path / \"val\"\n", + "test_path = data_path / \"test\" / \"images\"\n", + "gt_path = data_path / \"test\" / \"gt\"\n", + "\n", + "train_path.mkdir(parents=True, exist_ok=True)\n", + "val_path.mkdir(parents=True, exist_ok=True)\n", + "test_path.mkdir(parents=True, exist_ok=True)\n", + "gt_path.mkdir(parents=True, exist_ok=True)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Visualize training data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "train_image = tifffile.imread(next(iter(train_path.rglob(\"*.tiff\"))))[0]\n", + "print(f\"Train image shape: {train_image.shape}\")\n", + "plt.imshow(train_image, cmap=\"gray\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Visualize validation data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "val_image = 
tifffile.imread(next(iter(val_path.rglob(\"*.tiff\"))))[0]\n", + "print(f\"Validation image shape: {val_image.shape}\")\n", + "plt.imshow(val_image, cmap=\"gray\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Initialize the Model\n", + "\n", + "Create a Pytorch Lightning module\n", + "\n", + "Please take as look at the [documentation](https://careamics.github.io) to see the full list of parameters and configuration options" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# N2V2 requires changes to the UNet model and to the Dataset (augmentations)\n", + "use_n2v2 = False # change to True to use N2V2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model = CAREamicsModule(\n", + " algorithm=\"n2v\",\n", + " loss=\"n2v\",\n", + " architecture=\"UNet\",\n", + " model_parameters={\"n2v2\": use_n2v2},\n", + " optimizer_parameters={\"lr\": 1e-3},\n", + " lr_scheduler_parameters={\"factor\": 0.5, \"patience\": 10},\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Initialize the datamodule\n", + "\n", + "The data module can take a `Path` or `str` to a folder or file, or a `np.ndarray`.\n", + "\n", + "For custom types, you need to pass a read function and an extension_filter." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "train_data_module = CAREamicsTrainDataModule(\n", + " train_data=train_path,\n", + " val_data=val_path,\n", +<<<<<<< Updated upstream + " data_type=\"tiff\", # to use np.ndarray, set data_type to \"array\"\n", +======= + " data_type=\"tiff\",\n", +>>>>>>> Stashed changes + " patch_size=(64, 64),\n", + " axes=\"SYX\",\n", + " batch_size=128,\n", + " dataloader_params={\"num_workers\": 4},\n", + " use_n2v2=use_n2v2,\n", + " struct_n2v_axis=\"none\", # choice between \"horizontal\", \"vertical\", or \"none\" (no structN2V)\n", + " struct_n2v_span=7,\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Run training \n", + "\n", + "We need to specify the paths to training and validation data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "trainer = Trainer(max_epochs=1, default_root_dir=\"bsd_test\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "trainer.fit(model, datamodule=train_data_module)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Define a prediction datamodule" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pred_data_module = CAREamicsPredictDataModule(\n", + " pred_data=test_path,\n", + " data_type=\"tiff\",\n", + " tile_size=(256, 256),\n", + " tile_overlap=(48, 48),\n", + " axes=\"YX\",\n", + " batch_size=1,\n", + " tta_transforms=True,\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Run prediction\n", + "\n", + "First, we want to use CAREamics prediction loop, which allows tiling:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + 
"outputs": [], + "source": [ + "tiled_loop = CAREamicsPredictionLoop(trainer)\n", + "trainer.predict_loop = tiled_loop" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Then, we predict using the datamodule." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "preds = trainer.predict(model, datamodule=pred_data_module)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Visualize results and compute metrics\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create a list of ground truth images\n", + "\n", + "gts = [tifffile.imread(f) for f in sorted(gt_path.glob(\"*.tiff\"))]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Plot single image\n", + "\n", + "image_idx = 0\n", + "_, subplot = plt.subplots(1, 2, figsize=(10, 10))\n", + "\n", + "subplot[0].imshow(preds[image_idx].squeeze(), cmap=\"gray\")\n", + "subplot[0].set_title(\"Prediction\")\n", + "subplot[1].imshow(gts[image_idx], cmap=\"gray\")\n", + "subplot[1].set_title(\"Ground truth\")\n", + "\n", + "\n", + "# Calculate PSNR for single image\n", + "psnr_single = psnr(preds[image_idx].squeeze(), gts[image_idx].squeeze())\n", + "print(f\"PSNR for image {image_idx}: {psnr_single}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Plot single image\n", + "\n", + "image_idx = 0\n", + "_, subplot = plt.subplots(1, 2, figsize=(10, 10))\n", + "\n", + "subplot[0].imshow(preds[image_idx].squeeze(), cmap=\"gray\")\n", + "subplot[0].set_title(\"Prediction\")\n", + "subplot[1].imshow(gts[image_idx], cmap=\"gray\")\n", + "subplot[1].set_title(\"Ground truth\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# 
Calculate PSNR for single image\n", + "\n", + "psnr_single = psnr(gts[image_idx], preds[image_idx].squeeze())\n", + "print(f\"PSNR for image {image_idx}: {psnr_single}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "psnr_total = 0\n", + "\n", + "for pred, gt in zip(preds, gts):\n", + " psnr_total += psnr(gt, pred.squeeze())\n", + "\n", + "print(f\"PSNR total: {psnr_total / len(preds)}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.11.3 ('caremics')", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.17" + }, + "vscode": { + "interpreter": { + "hash": "0d2a5a3ab9ff26e8b66efec3883fa5121030bb852a7a4271db665831444e4e91" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/2D/example_SEM.ipynb b/examples/2D/n2v/example_SEM_lightning.ipynb similarity index 66% rename from examples/2D/example_SEM.ipynb rename to examples/2D/n2v/example_SEM_lightning.ipynb index 17a848d4..93e4469c 100644 --- a/examples/2D/example_SEM.ipynb +++ b/examples/2D/n2v/example_SEM_lightning.ipynb @@ -6,16 +6,20 @@ "metadata": {}, "outputs": [], "source": [ - "import pprint\n", "import shutil\n", "from pathlib import Path\n", "\n", "import matplotlib.pyplot as plt\n", "import tifffile\n", "from careamics_portfolio import PortfolioManager\n", - "from matplotlib.pyplot import imshow\n", + "from pytorch_lightning import Trainer\n", "\n", - "from careamics.engine import Engine" + "from careamics import CAREamicsModule\n", + "from careamics.lightning_datamodule import (\n", + " CAREamicsPredictDataModule,\n", + " 
CAREamicsTrainDataModule,\n", + ")\n", + "from careamics.lightning_prediction import CAREamicsPredictionLoop" ] }, { @@ -66,7 +70,7 @@ "# Load images\n", "train_image = tifffile.imread(files[0])\n", "print(f\"Train image shape: {train_image.shape}\")\n", - "imshow(train_image, cmap=\"gray\")" + "plt.imshow(train_image, cmap=\"gray\")" ] }, { @@ -85,7 +89,7 @@ "source": [ "val_image = tifffile.imread(files[1])\n", "print(f\"Validation image shape: {val_image.shape}\")\n", - "imshow(val_image, cmap=\"gray\")" + "plt.imshow(val_image, cmap=\"gray\")" ] }, { @@ -110,9 +114,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Initialize the Engine\n", - "\n", - "Engine contains the dataloading pipeline and the model training logic. We'll initialize the engine with the config file, but it can also be initialized from a pre-trained checkpoint.\n", + "### Initialize the Model\n", "\n", "Please take as look at the [documentation](https://careamics.github.io) to see the full list of parameters and configuration options" ] @@ -123,7 +125,24 @@ "metadata": {}, "outputs": [], "source": [ - "engine = Engine(config_path=\"n2v_2D_SEM.yml\")" + "# N2V2 requires changes to the UNet model and to the Dataset (augmentations)\n", + "use_n2v2 = False # change to True to use N2V2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model = CAREamicsModule(\n", + " algorithm=\"n2v\",\n", + " loss=\"n2v\",\n", + " architecture=\"UNet\",\n", + " model_parameters={\"n2v2\": False},\n", + " optimizer_parameters={\"lr\": 1e-3},\n", + " lr_scheduler_parameters={\"factor\": 0.5, \"patience\": 10},\n", + ")" ] }, { @@ -131,7 +150,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Visualize training configuration" + "### Initialize the datamodule" ] }, { @@ -140,7 +159,16 @@ "metadata": {}, "outputs": [], "source": [ - "pprint.PrettyPrinter(indent=2).pprint(engine.cfg.model_dump(exclude_optionals=False))" + 
"train_data_module = CAREamicsTrainDataModule(\n", + " train_data=train_path,\n", + " val_data=val_path,\n", + " data_type=\"tiff\",\n", + " patch_size=(64, 64),\n", + " axes=\"YX\",\n", + " batch_size=128,\n", + " dataloader_params={\"num_workers\": 0},\n", + " use_n2v2=use_n2v2,\n", + ")" ] }, { @@ -159,7 +187,16 @@ "metadata": {}, "outputs": [], "source": [ - "train_stats, val_stats = engine.train(train_path=train_path, val_path=val_path)" + "trainer = Trainer(max_epochs=1, default_root_dir=\"sem_n2v2_test_struct\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "trainer.fit(model, datamodule=train_data_module)" ] }, { @@ -167,7 +204,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Visualize statistics" + "### Define the prediction datamodule" ] }, { @@ -176,10 +213,14 @@ "metadata": {}, "outputs": [], "source": [ - "plt.plot([next(iter(d.values())) for d in train_stats], label=\"Train loss\")\n", - "plt.plot([next(iter(d.values())) for d in val_stats], label=\"Validation loss\")\n", - "plt.legend(loc=\"best\")\n", - "plt.xlabel(\"Epoch\")" + "pred_data_module = CAREamicsPredictDataModule(\n", + " pred_data=train_path,\n", + " data_type=\"tiff\",\n", + " tile_size=(256, 256),\n", + " axes=\"YX\",\n", + " batch_size=1,\n", + " tta_transforms=True,\n", + ")" ] }, { @@ -198,15 +239,8 @@ "metadata": {}, "outputs": [], "source": [ - "preds = engine.predict(input=train_path, tile_shape=[256, 256], overlaps=[48, 48])" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Visualize the prediction" + "tiled_loop = CAREamicsPredictionLoop(trainer)\n", + "trainer.predict_loop = tiled_loop" ] }, { @@ -215,7 +249,7 @@ "metadata": {}, "outputs": [], "source": [ - "imshow(preds.squeeze(), cmap=\"gray\")" + "preds = trainer.predict(model, datamodule=pred_data_module)" ] }, { @@ -223,7 +257,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - 
"### Export to bioimage.io" + "### Visualize the prediction" ] }, { @@ -232,7 +266,13 @@ "metadata": {}, "outputs": [], "source": [ - "engine.save_as_bioimage(engine.cfg.experiment_name + \"bioimage.zip\")" + "image_idx = 0\n", + "_, subplot = plt.subplots(1, 2, figsize=(10, 10))\n", + "\n", + "subplot[0].imshow(preds[0].squeeze(), cmap=\"gray\")\n", + "subplot[0].set_title(\"Prediction\")\n", + "subplot[1].imshow(train_image, cmap=\"gray\")\n", + "subplot[1].set_title(\"Initial image\")" ] }, { @@ -259,7 +299,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.3" + "version": "3.9.18" }, "vscode": { "interpreter": { diff --git a/examples/2D/n2v/n2v_2D_BSD.yml b/examples/2D/n2v/n2v_2D_BSD.yml new file mode 100644 index 00000000..fa2810a5 --- /dev/null +++ b/examples/2D/n2v/n2v_2D_BSD.yml @@ -0,0 +1,60 @@ +# Name, composed of letters, numbers, spaces, dashes and underscores +experiment_name: N2V_BSD_n2v2 + +algorithm: + algorithm: n2v + loss: n2v + model: + architecture: UNet + n2v2: False + optimizer_parameters: + lr: 1e-3 + lr_scheduler_parameters: + factor: 0.5 + patience: 10 + +training: + # Number of epochs, greater or equal than 1 + num_epochs: 1 + + # Batch size, greater or equal than 1 + batch_size: 128 + + # Optional, use WandB (True or False), must be installed in your conda environment + use_wandb: False + + callbacks: + # Optional, any parameter for Pytorch lightning callbacks + + # Optional, any parameter for Pytorch lightning ModelCheckpoint + checkpoint: + # Optional, monitor the validation loss (val_loss) or the validation accuracy (val_acc) + # parameters: + # dirpath: 'checkpoints' + # monitor : val_loss + save_last: True + save_top_k: 3 + # TODO should it be here or in predict ? 
+ # Optional, path to the checkpoint to load the model + # load_path: '' + + # Optional, automatic mixed precision + amp: + # Use (True or False) + use: True + +data: + # Extension of the data + data_type: tiff + + patch_size: [64, 64] + + # Axes, among STCZYX with constraints on order + axes: SYX + + tta_transforms: True + + batch_size: 128 + + # Optional, number of workers for data loading, greater or equal than 0 + num_workers: 0 \ No newline at end of file diff --git a/examples/2D/n2v_2D_BSD.yml b/examples/2D/n2v_2D_BSD.yml deleted file mode 100644 index 7e376c89..00000000 --- a/examples/2D/n2v_2D_BSD.yml +++ /dev/null @@ -1,77 +0,0 @@ ---- -# Name, composed of letters, numbers, spaces, dashes and underscores -experiment_name: N2V_BSD - -# Working directory (logs, models, etc), its parent folder must exist -working_directory: n2v_bsd - -algorithm: - # Loss, currently only n2v is supported - loss: n2v - - # Model, currently only UNet is supported - model: UNet - - # Dimensions 2D (False) or 3D (True) - is_3D: false - - # Optional, parameters of the model - model_parameters: - # Number of filters of the first level, must be divisible by 2 - num_channels_init: 32 - -training: - # Number of epochs, greater or equal than 1 - num_epochs: 100 - - # Patch size, 2D or 3D, divisible by 2 - patch_size: [64, 64] - - # Batch size, greater or equal than 1 - batch_size: 128 - - # Optimizer - optimizer: - # Name, one of Adam or SGD - name: Adam - # Optional, parameters of the optimizer - # see https://pytorch.org/docs/stable/optim.html#algorithms - parameters: - lr: 0.0004 - - # Learning rate scheduler - lr_scheduler: - # Name, one of ReduceLROnPlateau or StepLR - name: ReduceLROnPlateau - # Optional, parameters of the learning rate scheduler - # see https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate - parameters: - factor: 0.5 - - # Use augmentation (True or False) - augmentation: true - - # Optional, use WandB (True or False), must be installed in your 
conda environment - use_wandb: false - - # Optional, number of workers for data loading, greater or equal than 0 - num_workers: 4 - - # Optional, automatic mixed precision - amp: - # Use (True or False) - use: true - - # Torch.compile (True or False) - compile: - use: false - -data: - # Controls the type of dataloader to use. Set to False if data won't fit into memory - in_memory: true - - # Extension of the data, one of npy, tiff and tif - data_format: tiff - - # Axes, among STCZYX with constraints on order - axes: SYX diff --git a/examples/2D/n2v_2D_SEM.yml b/examples/2D/pn2v/pN2V_Convallaria.yml similarity index 66% rename from examples/2D/n2v_2D_SEM.yml rename to examples/2D/pn2v/pN2V_Convallaria.yml index 34fd2507..bf0172b1 100644 --- a/examples/2D/n2v_2D_SEM.yml +++ b/examples/2D/pn2v/pN2V_Convallaria.yml @@ -1,63 +1,71 @@ ---- # Name, composed of letters, numbers, spaces, dashes and underscores -experiment_name: N2V_SEM +experiment_name: pN2V_Convallaria # Working directory (logs, models, etc), its parent folder must exist -working_directory: n2v_sem +working_directory: pN2V_Convallaria algorithm: + algorithm_type: pn2v # Loss, currently only n2v is supported - loss: n2v + loss: pn2v # Model, currently only UNet is supported - model: UNet + model: UNet + # Noise model, Histogram or GMM + noise_model: + model_type: hist + parameters: + min_value: 350 + max_value: 6500 + bins: 256 + # Dimensions 2D (False) or 3D (True) - is_3D: false + is_3D: False training: # Number of epochs, greater or equal than 1 - num_epochs: 100 + num_epochs: 10 # Patch size, 2D or 3D, divisible by 2 - patch_size: [64, 64] + patch_size: [64, 64] # Batch size, greater or equal than 1 - batch_size: 128 + batch_size: 128 # Optimizer - optimizer: + optimizer: # Name, one of Adam or SGD - name: Adam + name: Adam # Optional, parameters of the optimizer # see https://pytorch.org/docs/stable/optim.html#algorithms - parameters: - lr: 0.0004 + parameters: + lr: 0.0004 # Learning rate scheduler - 
lr_scheduler: + lr_scheduler: # Name, one of ReduceLROnPlateau or StepLR - name: ReduceLROnPlateau + name: ReduceLROnPlateau # Optional, parameters of the learning rate scheduler # see https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate - parameters: - factor: 0.5 + parameters: + factor: 0.5 # Use augmentation (True or False) - augmentation: true + augmentation: True # Optional, use WandB (True or False), must be installed in your conda environment - use_wandb: false + use_wandb: False # Optional, number of workers for data loading, greater or equal than 0 - num_workers: 4 + num_workers: 0 data: # Controls the type of dataloader to use. Set to False if data won't fit into memory - in_memory: true + in_memory: True # Extension of the data, one of npy, tiff and tif - data_format: tif + data_format: tiff # Axes, among STCZYX with constraints on order - axes: YX + axes: YX diff --git a/examples/3D/example_flywing_3D.ipynb b/examples/3D/example_flywing_3D.ipynb index f7bb1567..512f8ef9 100644 --- a/examples/3D/example_flywing_3D.ipynb +++ b/examples/3D/example_flywing_3D.ipynb @@ -6,17 +6,31 @@ "metadata": {}, "outputs": [], "source": [ - "import pprint\n", "from pathlib import Path\n", "\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import tifffile\n", "from careamics_portfolio import PortfolioManager\n", - "from itkwidgets import compare, view # \"pip install itkwidgets \"if necessary\n", - "from matplotlib.pyplot import imshow\n", "\n", - "from careamics.engine import Engine" + "# from itkwidgets import compare, view # \"pip install itkwidgets \"if necessary\n", + "from pytorch_lightning import Trainer\n", + "\n", + "from careamics import CAREamicsModule\n", + "from careamics.lightning_datamodule import (\n", + " CAREamicsPredictDataModule,\n", + " CAREamicsTrainDataModule,\n", + ")\n", + "from careamics.lightning_prediction import CAREamicsPredictionLoop" + ] + }, + { + "cell_type": "code", + "execution_count": null, + 
"metadata": {}, + "outputs": [], + "source": [ + "%reload_ext autoreload" ] }, { @@ -86,7 +100,7 @@ "source": [ "train_image = tifffile.imread(next(iter(data_path.rglob(\"*.tif\"))))\n", "print(f\"Train image shape: {train_image.shape}\")\n", - "imshow(np.max(train_image, axis=0), cmap=\"magma\")" + "plt.imshow(np.max(train_image, axis=0), cmap=\"magma\")" ] }, { @@ -112,9 +126,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Initialize the Engine\n", + "### Initialize the Model\n", "\n", - "Engine contains the dataloading pipeline and the model training logic. We'll initialize the engine with the config file, but it can also be initialized from a pre-trained checkpoint.\n", + "Create a Pytorch Lightning module\n", "\n", "Please take as look at the [documentation](https://careamics.github.io) to see the full list of parameters and configuration options" ] @@ -125,15 +139,8 @@ "metadata": {}, "outputs": [], "source": [ - "engine = Engine(config_path=\"n2v_flywing_3D.yml\")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Visualize training configuration" + "# N2V2 requires changes to the UNet model and to the Dataset (augmentations)\n", + "use_n2v2 = False # change to True to use N2V2" ] }, { @@ -142,7 +149,14 @@ "metadata": {}, "outputs": [], "source": [ - "pprint.PrettyPrinter(indent=2).pprint(engine.cfg.model_dump(exclude_optionals=False))" + "model = CAREamicsModule(\n", + " algorithm=\"n2v\",\n", + " loss=\"n2v\",\n", + " architecture=\"UNet\",\n", + " model_parameters={\"n2v2\": use_n2v2, \"conv_dims\": 3},\n", + " optimizer_parameters={\"lr\": 1e-3},\n", + " lr_scheduler_parameters={\"factor\": 0.5, \"patience\": 10},\n", + ")" ] }, { @@ -150,8 +164,11 @@ "cell_type": "markdown", "metadata": {}, "source": [ + "### Initialize the datamodule\n", "\n", - "Start training" + "The data module can take a `Path` or `str` to a folder or file, or a `np.ndarray`.\n", + "\n", + "For custom types, you 
need to pass a read function and an extension_filter." ] }, { @@ -160,7 +177,17 @@ "metadata": {}, "outputs": [], "source": [ - "train_stats, val_stats = engine.train(train_path=data_path, val_path=data_path)" + "train_data_module = CAREamicsTrainDataModule(\n", + " train_data=train_image,\n", + " data_type=\"array\", # to use np.ndarray, set data_type to \"array\"\n", + " patch_size=(32, 64, 64),\n", + " axes=\"ZYX\",\n", + " batch_size=32,\n", + " dataloader_params={\"num_workers\": 0},\n", + " use_n2v2=use_n2v2,\n", + " struct_n2v_axis=\"none\", # choice between \"horizontal\", \"vertical\", or \"none\" (no # structN2V)\n", + " struct_n2v_span=7,\n", + ")" ] }, { @@ -168,7 +195,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Visualize statistics" + "### Run training \n", + "\n", + "We need to specify the paths to training and validation data" ] }, { @@ -177,10 +206,16 @@ "metadata": {}, "outputs": [], "source": [ - "plt.plot([next(iter(d.values())) for d in train_stats], label=\"Train loss\")\n", - "plt.plot([next(iter(d.values())) for d in val_stats], label=\"Validation loss\")\n", - "plt.legend(loc=\"best\")\n", - "plt.xlabel(\"Epoch\")" + "trainer = Trainer(max_epochs=1, default_root_dir=\"bsd_test\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "trainer.fit(model, datamodule=train_data_module)" ] }, { @@ -188,7 +223,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Predict" + "### Define a prediction datamodule" ] }, { @@ -197,7 +232,16 @@ "metadata": {}, "outputs": [], "source": [ - "preds = engine.predict(input=data_path, tile_shape=[32, 64, 64], overlaps=[24, 48, 48])" + "pred_data_module = CAREamicsPredictDataModule(\n", + " pred_data=train_image[:, :128, :128],\n", + " data_type=\"array\",\n", + " tile_size=(32, 64, 64),\n", + " tile_overlap=(16, 48, 48),\n", + " axes=\"ZYX\",\n", + " batch_size=1,\n", + " tta_transforms=True,\n", + " 
dataloader_params={\"num_workers\": 0},\n", + ")" ] }, { @@ -205,7 +249,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Visualize predictions\n" + "### Run prediction\n", + "\n", + "First, we want to use CAREamics prediction loop, which allows tiling:" ] }, { @@ -214,8 +260,8 @@ "metadata": {}, "outputs": [], "source": [ - "print(f\"Train image shape: {preds.shape}\")\n", - "imshow(np.max(preds.squeeze(), axis=0), cmap=\"magma\")" + "tiled_loop = CAREamicsPredictionLoop(trainer)\n", + "trainer.predict_loop = tiled_loop" ] }, { @@ -223,7 +269,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### [Optional] Visualize predictions in 3D" + "Then, we predict using the datamodule." ] }, { @@ -232,7 +278,7 @@ "metadata": {}, "outputs": [], "source": [ - "compare(train_image, preds.squeeze())" + "preds = trainer.predict(model, datamodule=pred_data_module)" ] }, { @@ -240,7 +286,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Save predictions" + "### Visualize predictions\n" ] }, { @@ -249,7 +295,8 @@ "metadata": {}, "outputs": [], "source": [ - "tifffile.imwrite(\"flywing_preds.tif\", preds.squeeze())" + "print(f\"Train image shape: {preds.shape}\")\n", + "plt.imshow(np.max(preds.squeeze(), axis=0), cmap=\"magma\")" ] }, { @@ -257,7 +304,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Export to bioimage.io" + "### [Optional] Visualize predictions in 3D" ] }, { @@ -266,7 +313,7 @@ "metadata": {}, "outputs": [], "source": [ - "engine.save_as_bioimage(engine.cfg.experiment_name + \"bioimage.zip\")" + "compare(train_image, preds.squeeze())" ] } ], @@ -286,7 +333,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.3" + "version": "3.9.18" }, "vscode": { "interpreter": { diff --git a/examples/3D/n2v_3D.yml b/examples/3D/n2v_3D.yml deleted file mode 100644 index a8f3f71a..00000000 --- a/examples/3D/n2v_3D.yml +++ /dev/null @@ -1,74 +0,0 @@ ---- -# Name, composed of 
letters, numbers, spaces, dashes and underscores -experiment_name: N2V_BSD - -# Working directory (logs, models, etc), its parent folder must exist -working_directory: n2v_bsd - -algorithm: - # Loss, currently only n2v is supported - loss: n2v - - # Model, currently only UNet is supported - model: UNet - - # Dimensions 2D (False) or 3D (True) - is_3D: true - - # Optional, parameters of the model - model_parameters: - # Number of filters of the first level, must be divisible by 2 - num_channels_init: 32 - -training: - # Number of epochs, greater or equal than 1 - num_epochs: 20 - - # Patch size, 2D or 3D, divisible by 2 - patch_size: [32, 64, 64] - - # Batch size, greater or equal than 1 - batch_size: 16 - - # Optimizer - optimizer: - # Name, one of Adam or SGD - name: Adam - - # Optional, parameters of the optimizer - # see https://pytorch.org/docs/stable/optim.html#algorithms - parameters: - lr: 0.0004 - - # Learning rate scheduler - lr_scheduler: - # Name, one of ReduceLROnPlateau or StepLR - name: ReduceLROnPlateau - # Optional, parameters of the learning rate scheduler - # see https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate - parameters: - factor: 0.5 - - # Extraction strategy, one of random or sequential - extraction_strategy: random - - # Use augmentation (True or False) - augmentation: true - - # Optional, use WandB (True or False), must be installed in your conda environment - use_wandb: false - - # Optional, number of workers for data loading, greater or equal than 0 - num_workers: 4 - - # Optional, automatic mixed precision - amp: - # Use (True or False) - use: true - -data: - # Extension of the data, one of npy, tiff and tif - data_format: tif - - # Axes, among STCZYX with constraints on order - axes: ZYX diff --git a/examples/3D/n2v_flywing_3D.yml b/examples/3D/n2v_flywing_3D.yml index 946afd55..733ade15 100644 --- a/examples/3D/n2v_flywing_3D.yml +++ b/examples/3D/n2v_flywing_3D.yml @@ -1,4 +1,3 @@ ---- # Name, composed of 
letters, numbers, spaces, dashes and underscores experiment_name: N2V_flywing_3D @@ -7,68 +6,68 @@ working_directory: n2v_bsd algorithm: # Loss, currently only n2v is supported - loss: n2v + loss: n2v # Model, currently only UNet is supported - model: UNet + model: UNet # Dimensions 2D (False) or 3D (True) - is_3D: true + is_3D: True # Optional, parameters of the model - model_parameters: + model_parameters: # Number of filters of the first level, must be divisible by 2 - num_channels_init: 32 + num_channels_init: 32 training: # Number of epochs, greater or equal than 1 - num_epochs: 50 + num_epochs: 50 # Patch size, 2D or 3D, divisible by 2 - patch_size: [32, 64, 64] + patch_size: [32, 64, 64] # Batch size, greater or equal than 1 - batch_size: 4 + batch_size: 4 # Optimizer - optimizer: + optimizer: # Name, one of Adam or SGD - name: Adam + name: Adam # Optional, parameters of the optimizer # see https://pytorch.org/docs/stable/optim.html#algorithms - parameters: - lr: 0.0004 + parameters: + lr: 0.0004 # Learning rate scheduler - lr_scheduler: + lr_scheduler: # Name, one of ReduceLROnPlateau or StepLR - name: ReduceLROnPlateau + name: ReduceLROnPlateau # Optional, parameters of the learning rate scheduler # see https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate - parameters: - factor: 0.5 + parameters: + factor: 0.5 # Use augmentation (True or False) - augmentation: true + augmentation: True # Optional, use WandB (True or False), must be installed in your conda environment - use_wandb: false + use_wandb: False # Optional, number of workers for data loading, greater or equal than 0 - num_workers: 4 + num_workers: 4 # Optional, automatic mixed precision - amp: + amp: # Use (True or False) - use: true + use: True data: # Controls the type of dataloader to use. 
Set to False if data won't fit into memory - in_memory: true - + in_memory: True + # Extension of the data, one of npy, tiff and tif - data_format: tif + data_format: tif # Axes, among STCZYX with constraints on order - axes: ZYX + axes: ZYX diff --git a/examples/n2v_full_reference.yml b/examples/n2v_full_reference.yml deleted file mode 100644 index 8e5dfa90..00000000 --- a/examples/n2v_full_reference.yml +++ /dev/null @@ -1,93 +0,0 @@ ---- -# Name, composed of letters, numbers, spaces, dashes and underscores -experiment_name: project_name - -# Working directory (logs, models, etc), its parent folder must exist -working_directory: path/to/working_directory - -# Optional -# Absolute or relative (wrt working_directory) path to model -trained_model: best_model.pth - -algorithm: - # Loss, currently only n2v is supported - loss: n2v - - # Model, currently only UNet is supported - model: UNet - - # Dimensions 2D (False) or 3D (True) - is_3D: false - - # Optional, masking strategy, currently only default is supported - masking_strategy: default - # Optional, percentage of masked pixel per patch (between 0.1 and 20%) - masked_pixel_percentage: 0.2 - # Optional, parameters of the model - model_parameters: - # Depth betwen 1 and 10 - depth: 2 - # Number of filters of the first level, must be divisible by 2 - num_filter_base: 32 - -training: - # Number of epochs, greater or equal than 1 - num_epochs: 100 - - # Patch size, 2D or 3D, divisible by 2 - patch_size: [64, 64] - - # Batch size, greater or equal than 1 - batch_size: 128 - - # Optimizer - optimizer: - # Name, one of Adam or SGD - name: Adam - - # Optional, parameters of the optimizer - # see https://pytorch.org/docs/stable/optim.html#algorithms - parameters: - lr: 0.0004 - - # Learning rate scheduler - lr_scheduler: - # Name, one of ReduceLROnPlateau or StepLR - name: ReduceLROnPlateau - # Optional, parameters of the learning rate scheduler - # see https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate 
- parameters: - mode: min - patience: 10 - factor: 0.5 - - # Use augmentation (True or False) - augmentation: true - - # Optional, use WandB (True or False), must be installed in your conda environment - use_wandb: false - - # Optional, number of workers for data loading, greater or equal than 0 - num_workers: 0 - - # Optional, automatic mixed precision - amp: - # Use (True or False) - use: true - # Optional, scaling parameter for mixed precision training, power of 2 recommended. - # minimum is 512, maximum is 65536 - init_scale: 1024 - - # Torch.compile (True or False) - compile: - use: false - -data: - # Controls the type of dataloader to use. Set to False if data won't fit into memory - in_memory: false - - # Extension of the data, e.g. tiff or zarr - data_format: tiff - - # Axes, among STCZYX with constraints on order - axes: SYX diff --git a/pyproject.toml b/pyproject.toml index 913da3a0..bdddf8a9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,12 +1,15 @@ +# https://peps.python.org/pep-0517/ + [build-system] requires = ["hatchling", "hatch-vcs"] build-backend = "hatchling.build" +# read more about configuring hatch at: +# https://hatch.pypa.io/latest/config/build/ # https://hatch.pypa.io/latest/config/metadata/ [tool.hatch.version] source = "vcs" -# https://hatch.pypa.io/latest/config/build/ [tool.hatch.build.targets.wheel] only-include = ["src"] sources = ["src"] @@ -22,8 +25,6 @@ license = { text = "BSD-3-Clause" } authors = [ { name = 'Igor Zubarev', email = 'igor.zubarev@fht.org' }, { name = 'Joran Deschamps', email = 'joran.deschamps@fht.org' }, - { name = 'Vera Galinova', email = 'vera.galinova@fht.org' }, - { name = 'Mehdi Seifi', email = 'mehdi.seifi@fht.org' }, ] classifiers = [ "Development Status :: 3 - Alpha", @@ -32,74 +33,73 @@ classifiers = [ "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", - "Programming Language :: Python :: 3.12", "License :: OSI Approved :: 
BSD License", "Typing :: Typed", ] dependencies = [ - 'torch', - 'torchvision', + 'torch>=2.0.0', + 'albumentations', + 'bioimageio.core>=0.6.0', 'tifffile', + 'psutil', + 'pydantic>=2.5', + 'pytorch_lightning>=2.2.0', 'pyyaml', - 'pydantic>=2.0', 'scikit-image', - 'bioimageio.core', 'zarr', ] [project.optional-dependencies] # development dependencies and tooling -dev = ["pre-commit", "pytest", "pytest-cov"] - -# for ci -test = ["pytest", "pytest-cov", "wandb"] - -# notebooks -notebooks = [ - "jupyter", - "careamics-portfolio", - "itkwidgets", - "torchsummary", - "ipython", - "wandb", -] - -# all -all = [ +dev = [ "pre-commit", "pytest", "pytest-cov", - "wandb", - "jupyter", - "careamics-portfolio", - "itkwidgets", - "torchsummary", - "ipython", + "sybil", # doctesting ] +# notebooks +examples = ["jupyter", "careamics-portfolio", "matplotlib"] + +# loggers +wandb = ["wandb"] +tensorboard = ["tensorboard", "protobuf==3.20.3"] + [project.urls] homepage = "https://careamics.github.io/" repository = "https://github.com/CAREamics/careamics" -# https://docs.astral.sh/ruff/ +# https://beta.ruff.rs/docs [tool.ruff] line-length = 88 target-version = "py38" src = ["src"] select = [ - "E", # style errors - "W", # style warnings - "F", # flakes - "I", # isort - "UP", # pyupgrade + "E", # style errors + "W", # style warnings + "F", # flakes + "D", # pydocstyle + "I", # isort + "UP", # pyupgrade + # "S", # bandit "C4", # flake8-comprehensions "B", # flake8-bugbear "A001", # flake8-builtins "RUF", # ruff-specific rules - "TCH", # flake8-type-checking - "TID", # flake8-tidy-imports ] ignore = [ + "D100", # Missing docstring in public module + "D107", # Missing docstring in __init__ + "D203", # 1 blank line required before class docstring + "D212", # Multi-line docstring summary should start at the first line + "D213", # Multi-line docstring summary should start at the second line + "D401", # First line should be in imperative mood + "D413", # Missing blank line after last section 
+ "D416", # Section name should end with a colon + + # incompatibility with mypy + "RUF005", # collection-literal-concatenation, in prediction_utils.py:30 + # version specific "UP006", # Replace typing.List by list, mandatory for py3.8 "UP007", # Replace Union by |, mandatory for py3.9 @@ -107,8 +107,12 @@ ignore = [ ignore-init-module-imports = true show-fixes = true +[tool.ruff.pydocstyle] +convention = "numpy" + [tool.ruff.per-file-ignores] -"tests/*.py" = ["S"] +"tests/*.py" = ["D", "S"] +"setup.py" = ["D"] [tool.black] line-length = 88 @@ -117,35 +121,32 @@ line-length = 88 [tool.mypy] files = "src/**/" strict = false -allow_untyped_defs = false -allow_untyped_calls = false -disallow_any_generics = false -ignore_missing_imports = false +# allow_untyped_defs = false +# allow_untyped_calls = false +# disallow_any_generics = false +# ignore_missing_imports = false + # https://docs.pytest.org/en/6.2.x/customize.html [tool.pytest.ini_options] minversion = "6.0" -testpaths = ["tests"] +testpaths = ["src/careamics", "tests"] # add src/careamics for doctest discovery filterwarnings = [ # "error", # "ignore::UserWarning", ] -markers = ["gpu: mark tests as requiring gpu"] +addopts = "-p no:doctest" +markers = ["gpu: marks tests as requiring gpu"] # https://coverage.readthedocs.io/en/6.4/config.html [tool.coverage.report] exclude_lines = [ "pragma: no cover", "if TYPE_CHECKING:", "@overload", + "except ImportError", "\\.\\.\\.", - "except ImportError:", "raise NotImplementedError()", - "except PackageNotFoundError:", - "if torch.cuda.is_available():", - "except UsageError as e:", - "except ModuleNotFoundError:", - "except KeyboardInterrupt:", ] [tool.coverage.run] @@ -159,16 +160,21 @@ ignore = [ ".github_changelog_generator", ".pre-commit-config.yaml", ".ruff_cache/**/*", + "setup.py", "tests/**/*", ] -# https://numpydoc.readthedocs.io/en/latest/format.html [tool.numpydoc_validation] checks = [ - "all", # report on all checks, except the ones below - "EX01", # No 
examples section found + "SA01", # See Also section not found - "ES01", # No extended summary found + "ES01", # Extended Summary not found + "GL01", # Docstring text (summary) should start in the line immediately + # after the opening quotes + "GL02", # Closing quotes should be placed in the line after the last text + # in the docstring + "GL03", # Double line break found ] exclude = [ # don't report on objects that match any of these regex "test_*", diff --git a/src/careamics/__init__.py b/src/careamics/__init__.py index 668fdf7c..222aff45 100644 --- a/src/careamics/__init__.py +++ b/src/careamics/__init__.py @@ -1,5 +1,4 @@ -"""Main module.""" - +"""Main CAREamics module.""" from importlib.metadata import PackageNotFoundError, version @@ -8,7 +7,18 @@ except PackageNotFoundError: __version__ = "uninstalled" -__all__ = ["Engine", "Configuration", "load_configuration", "save_configuration"] +__all__ = [ + "CAREamist", + "CAREamicsModule", + "Configuration", + "load_configuration", + "save_configuration", + "CAREamicsTrainDataModule", + "CAREamicsPredictDataModule", +] +from .careamist import CAREamist from .config import Configuration, load_configuration, save_configuration -from .engine import Engine as Engine +from .lightning_datamodule import CAREamicsTrainDataModule +from .lightning_module import CAREamicsModule +from .lightning_prediction_datamodule import CAREamicsPredictDataModule diff --git a/src/careamics/bioimage/__init__.py b/src/careamics/bioimage/__init__.py deleted file mode 100644 index 675ab54e..00000000 --- a/src/careamics/bioimage/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -"""Provide utilities for exporting models to BioImage model zoo.""" - -__all__ = [ - "save_bioimage_model", - "import_bioimage_model", - "get_default_model_specs", - "PYTORCH_STATE_DICT", -] - -from .io import ( - PYTORCH_STATE_DICT, - import_bioimage_model, - save_bioimage_model, -) -from .rdf
import get_default_model_specs diff --git a/src/careamics/bioimage/docs/Noise2Void.md b/src/careamics/bioimage/docs/Noise2Void.md deleted file mode 100644 index 6f49132c..00000000 --- a/src/careamics/bioimage/docs/Noise2Void.md +++ /dev/null @@ -1,5 +0,0 @@ -## Noise2Void -Learning Denoising From Single Noisy Images - -## Cite Noise2Void -A. Krull, T.-O. Buchholz and F. Jug, "Noise2Void - Learning Denoising From Single Noisy Images," 2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), 2019, pp. 2124-2132 \ No newline at end of file diff --git a/src/careamics/bioimage/docs/__init__.py b/src/careamics/bioimage/docs/__init__.py deleted file mode 100644 index a82f161c..00000000 --- a/src/careamics/bioimage/docs/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Default algorithm READMEs for bioimage.io format export.""" diff --git a/src/careamics/bioimage/io.py b/src/careamics/bioimage/io.py deleted file mode 100644 index f4d6ad92..00000000 --- a/src/careamics/bioimage/io.py +++ /dev/null @@ -1,182 +0,0 @@ -"""Export to bioimage.io format.""" -from pathlib import Path -from typing import Union - -import torch -from bioimageio.core import load_resource_description -from bioimageio.core.build_spec import build_model - -from careamics.config.config import Configuration -from careamics.utils.context import cwd - -PYTORCH_STATE_DICT = "pytorch_state_dict" - - -def save_bioimage_model( - path: Union[str, Path], - config: Configuration, - specs: dict, -) -> None: - """ - Build bioimage model zip file from model RDF data. - - Parameters - ---------- - path : Union[str, Path] - Path to the model zip file. - config : Configuration - Configuration object. - specs : dict - Model RDF dict. 
- """ - workdir = config.working_directory - - # temporary folder - temp_folder = Path.home().joinpath(".careamics", "bmz_tmp") - temp_folder.mkdir(exist_ok=True, parents=True) - - # change working directory to the temp folder - with cwd(temp_folder): - # load best checkpoint - checkpoint_path = workdir.joinpath( - f"{config.experiment_name}_best.pth" - ).absolute() - checkpoint = torch.load(checkpoint_path, map_location="cpu") - - # save chekpoint entries in separate files - weight_path = Path("model_weights.pth") - torch.save(checkpoint["model_state_dict"], weight_path) - - optim_path = Path("optim.pth") - torch.save(checkpoint["optimizer_state_dict"], optim_path) - - scheduler_path = Path("scheduler.pth") - torch.save(checkpoint["scheduler_state_dict"], scheduler_path) - - grad_path = Path("grad.pth") - torch.save(checkpoint["grad_scaler_state_dict"], grad_path) - - config_path = Path("config.pth") - torch.save(config.model_dump(), config_path) - - # create attachments - attachments = [ - str(optim_path), - str(scheduler_path), - str(grad_path), - str(config_path), - ] - - # create requirements file - requirements = Path("requirements.txt") - with open(requirements, "w") as f: - f.write("git+https://github.com/CAREamics/careamics.git") - - algo_config = config.algorithm - specs.update( - { - "weight_type": PYTORCH_STATE_DICT, - "weight_uri": str(weight_path), - "architecture": "careamics.models.unet.UNet", - "pytorch_version": torch.__version__, - "model_kwargs": { - "conv_dim": algo_config.get_conv_dim(), - "depth": algo_config.model_parameters.depth, - "num_channels_init": algo_config.model_parameters.num_channels_init, - }, - "dependencies": "pip:" + str(requirements), - "attachments": {"files": attachments}, - } - ) - - if config.algorithm.is_3D: - specs["tags"].append("3D") - else: - specs["tags"].append("2D") - - # build model zip - build_model( - output_path=Path(path).absolute(), - **specs, - ) - - # remove temporary files - for file in 
temp_folder.glob("*"): - file.unlink() - - # delete temporary folder - temp_folder.rmdir() - - -def import_bioimage_model(model_path: Union[str, Path]) -> Path: - """ - Load configuration and weights from a bioimage zip model. - - Parameters - ---------- - model_path : Union[str, Path] - Path to the bioimage.io archive. - - Returns - ------- - Path - Path to the checkpoint. - - Raises - ------ - ValueError - If the model format is invalid. - FileNotFoundError - If the checkpoint file was not found. - """ - model_path = Path(model_path) - - # check the model extension (should be a zip file). - if model_path.suffix != ".zip": - raise ValueError("Invalid model format. Expected bioimage model zip file.") - - # load the model - rdf = load_resource_description(model_path) - - # create a valid checkpoint file from weights and attached files - basedir = model_path.parent.joinpath("rdf_model") - basedir.mkdir(exist_ok=True) - optim_path = None - scheduler_path = None - grad_path = None - config_path = None - weight_path = None - - if rdf.weights.get(PYTORCH_STATE_DICT) is not None: - weight_path = rdf.weights.get(PYTORCH_STATE_DICT).source - - for file in rdf.attachments.files: - if file.name.endswith("optim.pth"): - optim_path = file - elif file.name.endswith("scheduler.pth"): - scheduler_path = file - elif file.name.endswith("grad.pth"): - grad_path = file - elif file.name.endswith("config.pth"): - config_path = file - - if ( - weight_path is None - or optim_path is None - or scheduler_path is None - or grad_path is None - or config_path is None - ): - raise FileNotFoundError(f"No valid checkpoint file was found in {model_path}.") - - checkpoint = { - "model_state_dict": torch.load(weight_path, map_location="cpu"), - "optimizer_state_dict": torch.load(optim_path, map_location="cpu"), - "scheduler_state_dict": torch.load(scheduler_path, map_location="cpu"), - "grad_scaler_state_dict": torch.load(grad_path, map_location="cpu"), - "config": torch.load(config_path, 
map_location="cpu"), - } - checkpoint_path = basedir.joinpath("checkpoint.pth") - torch.save(checkpoint, checkpoint_path) - - return checkpoint_path diff --git a/src/careamics/bioimage/rdf.py b/src/careamics/bioimage/rdf.py deleted file mode 100644 index be42d4e2..00000000 --- a/src/careamics/bioimage/rdf.py +++ /dev/null @@ -1,105 +0,0 @@ -"""RDF related methods.""" -from pathlib import Path - - -def _get_model_doc(name: str) -> str: - """ - Return markdown documentation path for the provided model. - - Parameters - ---------- - name : str - Model's name. - - Returns - ------- - str - Path to the model's markdown documentation. - - Raises - ------ - FileNotFoundError - If the documentation file was not found. - """ - doc = Path(__file__).parent.joinpath("docs").joinpath(f"{name}.md") - if doc.exists(): - return str(doc.absolute()) - else: - raise FileNotFoundError(f"Documentation for {name} was not found.") - - -def get_default_model_specs( - name: str, mean: float, std: float, is_3D: bool = False -) -> dict: - """ - Return the default bioimage.io specs for the provided model's name. - - Currently only supports `Noise2Void` model. - - Parameters - ---------- - name : str - Algorithm's name. - mean : float - Mean of the dataset. - std : float - Std of the dataset. - is_3D : bool, optional - Whether the model is 3D or not, by default False. - - Returns - ------- - dict - Model specs compatible with bioimage.io export. - """ - rdf = { - "name": "Noise2Void", - "description": "Self-supervised denoising.", - "license": "BSD-3-Clause", - "authors": [ - {"name": "Alexander Krull"}, - {"name": "Tim-Oliver Buchholz"}, - {"name": "Florian Jug"}, - ], - "cite": [ - { - "doi": "10.48550/arXiv.1811.10980", - "text": 'A. Krull, T.-O. Buchholz and F. Jug, "Noise2Void - Learning ' - 'Denoising From Single Noisy Images," 2019 IEEE/CVF ' - "Conference on Computer Vision and Pattern Recognition " - "(CVPR), 2019, pp. 
2124-2132", - } - ], - # "input_axes": ["bcyx"], <- overriden in save_as_bioimage - "preprocessing": [ # for multiple inputs - [ # multiple processes per input - { - "kwargs": { - "axes": "zyx" if is_3D else "yx", - "mean": [mean], - "mode": "fixed", - "std": [std], - }, - "name": "zero_mean_unit_variance", - } - ] - ], - # "output_axes": ["bcyx"], <- overriden in save_as_bioimage - "postprocessing": [ # for multiple outputs - [ # multiple processes per input - { - "kwargs": { - "axes": "zyx" if is_3D else "yx", - "gain": [std], - "offset": [mean], - }, - "name": "scale_linear", - } - ] - ], - "tags": ["unet", "denoising", "Noise2Void", "tensorflow", "napari"], - } - - rdf["documentation"] = _get_model_doc(name) - - return rdf diff --git a/src/careamics/callbacks/__init__.py b/src/careamics/callbacks/__init__.py new file mode 100644 index 00000000..74a82442 --- /dev/null +++ b/src/careamics/callbacks/__init__.py @@ -0,0 +1,6 @@ +"""Callbacks module.""" + +__all__ = ["HyperParametersCallback", "ProgressBarCallback"] + +from .hyperparameters_callback import HyperParametersCallback +from .progress_bar_callback import ProgressBarCallback diff --git a/src/careamics/callbacks/hyperparameters_callback.py b/src/careamics/callbacks/hyperparameters_callback.py new file mode 100644 index 00000000..d0609077 --- /dev/null +++ b/src/careamics/callbacks/hyperparameters_callback.py @@ -0,0 +1,42 @@ +from pytorch_lightning import LightningModule, Trainer +from pytorch_lightning.callbacks import Callback + +from careamics.config import Configuration + + +class HyperParametersCallback(Callback): + """ + Callback allowing saving CAREamics configuration as hyperparameters in the model. + + This allows saving the configuration as a dictionary in the checkpoints, and + loading it subsequently in a CAREamist instance. + + Attributes + ---------- + config : Configuration + CAREamics configuration to be saved as hyperparameter in the model.
+ """ + + def __init__(self, config: Configuration): + """ + Constructor. + + Parameters + ---------- + config : Configuration + CAREamics configuration to be saved as hyperparameter in the model. + """ + self.config = config + + def on_train_start(self, trainer: Trainer, pl_module: LightningModule): + """ + Update the hyperparameters of the model with the configuration on train start. + + Parameters + ---------- + trainer : Trainer + PyTorch Lightning trainer. + pl_module : LightningModule + PyTorch Lightning module. + """ + pl_module.hparams.update(self.config.model_dump()) diff --git a/src/careamics/callbacks/progress_bar_callback.py b/src/careamics/callbacks/progress_bar_callback.py new file mode 100644 index 00000000..d7862091 --- /dev/null +++ b/src/careamics/callbacks/progress_bar_callback.py @@ -0,0 +1,57 @@ +import sys +from typing import Dict, Union + +from pytorch_lightning import LightningModule, Trainer +from pytorch_lightning.callbacks import TQDMProgressBar +from tqdm import tqdm + + +class ProgressBarCallback(TQDMProgressBar): + """Progress bar for training and validation steps.""" + + def init_train_tqdm(self) -> tqdm: + """Override this to customize the tqdm bar for training.""" + bar = tqdm( + desc="Training", + position=(2 * self.process_position), + disable=self.is_disabled, + leave=True, + dynamic_ncols=True, + file=sys.stdout, + smoothing=0, + ) + return bar + + def init_validation_tqdm(self) -> tqdm: + """Override this to customize the tqdm bar for validation.""" + # The main progress bar doesn't exist in `trainer.validate()` + has_main_bar = self.train_progress_bar is not None + bar = tqdm( + desc="Validating", + position=(2 * self.process_position + has_main_bar), + disable=self.is_disabled, + leave=False, + dynamic_ncols=True, + file=sys.stdout, + ) + return bar + + def init_test_tqdm(self) -> tqdm: + """Override this to customize the tqdm bar for testing.""" + bar = tqdm( + desc="Testing", + position=(2 * self.process_position), + 
disable=self.is_disabled, + leave=True, + dynamic_ncols=False, + ncols=100, + file=sys.stdout, + ) + return bar + + def get_metrics( + self, trainer: Trainer, pl_module: LightningModule + ) -> Dict[str, Union[int, str, float, Dict[str, float]]]: + """Override this to customize the metrics displayed in the progress bar.""" + pbar_metrics = trainer.progress_bar_metrics + return {**pbar_metrics} diff --git a/src/careamics/careamist.py b/src/careamics/careamist.py new file mode 100644 index 00000000..34a10009 --- /dev/null +++ b/src/careamics/careamist.py @@ -0,0 +1,761 @@ +"""A class to train, predict and export models in CAREamics.""" + +from pathlib import Path +from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union, overload + +import numpy as np +from pytorch_lightning import Trainer +from pytorch_lightning.callbacks import ( + Callback, + EarlyStopping, + ModelCheckpoint, +) +from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger + +from careamics.callbacks import ProgressBarCallback +from careamics.config import ( + Configuration, + create_inference_configuration, + load_configuration, +) +from careamics.config.inference_model import TRANSFORMS_UNION +from careamics.config.support import SupportedAlgorithm, SupportedData, SupportedLogger +from careamics.lightning_datamodule import CAREamicsWood +from careamics.lightning_module import CAREamicsKiln +from careamics.lightning_prediction_datamodule import CAREamicsClay +from careamics.lightning_prediction_loop import CAREamicsPredictionLoop +from careamics.model_io import export_to_bmz, load_pretrained +from careamics.utils import check_path_exists, get_logger + +from .callbacks import HyperParametersCallback + +logger = get_logger(__name__) + +LOGGER_TYPES = Optional[Union[TensorBoardLogger, WandbLogger]] + + +# TODO napari callbacks +# TODO: how to do AMP? How to continue training? 
+class CAREamist: + """Main CAREamics class, allowing training and prediction using various algorithms. + + Parameters + ---------- + source : Union[Path, str, Configuration] + Path to a configuration file or a trained model. + work_dir : Optional[str], optional + Path to working directory in which to save checkpoints and logs, + by default None. + experiment_name : str, optional + Experiment name used for checkpoints, by default "CAREamics". + + Attributes + ---------- + model : CAREamicsKiln + CAREamics model. + cfg : Configuration + CAREamics configuration. + trainer : Trainer + PyTorch Lightning trainer. + experiment_logger : Optional[Union[TensorBoardLogger, WandbLogger]] + Experiment logger, "wandb" or "tensorboard". + work_dir : Path + Working directory. + train_datamodule : Optional[CAREamicsWood] + Training datamodule. + pred_datamodule : Optional[CAREamicsClay] + Prediction datamodule. + """ + + @overload + def __init__( # numpydoc ignore=GL08 + self, + source: Union[Path, str], + work_dir: Optional[str] = None, + experiment_name: str = "CAREamics", + ) -> None: + ... + + @overload + def __init__( # numpydoc ignore=GL08 + self, + source: Configuration, + work_dir: Optional[str] = None, + experiment_name: str = "CAREamics", + ) -> None: + ... + + def __init__( + self, + source: Union[Path, str, Configuration], + work_dir: Optional[Union[Path, str]] = None, + experiment_name: str = "CAREamics", + ) -> None: + """ + Initialize CAREamist with a configuration object or a path. + + A configuration object can be created using directly by calling `Configuration`, + using the configuration factory or loading a configuration from a yaml file. + + Path can contain either a yaml file with parameters, or a saved checkpoint. + + If no working directory is provided, the current working directory is used. + + If `source` is a checkpoint, then `experiment_name` is used to name the + checkpoint, and is recorded in the configuration. 
+ + Parameters + ---------- + source : Union[Path, str, Configuration] + Path to a configuration file or a trained model. + work_dir : Optional[str], optional + Path to working directory in which to save checkpoints and logs, + by default None. + experiment_name : str, optional + Experiment name used for checkpoints, by default "CAREamics". + + Raises + ------ + NotImplementedError + If the model is loaded from BioImage Model Zoo. + ValueError + If no hyper parameters are found in the checkpoint. + ValueError + If no data module hyper parameters are found in the checkpoint. + """ + super().__init__() + + # select current working directory if work_dir is None + if work_dir is None: + self.work_dir = Path.cwd() + logger.warning( + f"No working directory provided. Using current working directory: " + f"{self.work_dir}." + ) + else: + self.work_dir = Path(work_dir) + + # configuration object + if isinstance(source, Configuration): + self.cfg = source + + # instantiate model + self.model = CAREamicsKiln( + algorithm_config=self.cfg.algorithm_config, + ) + + # path to configuration file or model + else: + source = check_path_exists(source) + + # configuration file + if source.is_file() and ( + source.suffix == ".yaml" or source.suffix == ".yml" + ): + # load configuration + self.cfg = load_configuration(source) + + # instantiate model + self.model = CAREamicsKiln( + algorithm_config=self.cfg.algorithm_config, + ) + + # attempt loading a pre-trained model + else: + self.model, self.cfg = load_pretrained(source) + + # define the checkpoint saving callback + self.callbacks = self._define_callbacks() + + # instantiate logger + if self.cfg.training_config.has_logger(): + if self.cfg.training_config.logger == SupportedLogger.WANDB: + self.experiment_logger: LOGGER_TYPES = WandbLogger( + name=experiment_name, + save_dir=self.work_dir / Path("logs"), + ) + elif self.cfg.training_config.logger == SupportedLogger.TENSORBOARD: + self.experiment_logger = TensorBoardLogger( + 
save_dir=self.work_dir / Path("logs"), + ) + else: + self.experiment_logger = None + + # instantiate trainer + self.trainer = Trainer( + max_epochs=self.cfg.training_config.num_epochs, + callbacks=self.callbacks, + default_root_dir=self.work_dir, + logger=self.experiment_logger, + ) + + # change the prediction loop, necessary for tiled prediction + self.trainer.predict_loop = CAREamicsPredictionLoop(self.trainer) + + # place holder for the datamodules + self.train_datamodule: Optional[CAREamicsWood] = None + self.pred_datamodule: Optional[CAREamicsClay] = None + + def _define_callbacks(self) -> List[Callback]: + """ + Define the callbacks for the training loop. + + Returns + ------- + List[Callback] + List of callbacks to be used during training. + """ + # checkpoint callback saves checkpoints during training + self.callbacks = [ + HyperParametersCallback(self.cfg), + ModelCheckpoint( + dirpath=self.work_dir / Path("checkpoints"), + filename=self.cfg.experiment_name, + **self.cfg.training_config.checkpoint_callback.model_dump(), + ), + ProgressBarCallback(), + ] + + # early stopping callback (config fields unpacked as keyword arguments) + if self.cfg.training_config.early_stopping_callback is not None: + self.callbacks.append( + EarlyStopping(**self.cfg.training_config.early_stopping_callback.model_dump()) + ) + + return self.callbacks + + def train( + self, + *, + datamodule: Optional[CAREamicsWood] = None, + train_source: Optional[Union[Path, str, np.ndarray]] = None, + val_source: Optional[Union[Path, str, np.ndarray]] = None, + train_target: Optional[Union[Path, str, np.ndarray]] = None, + val_target: Optional[Union[Path, str, np.ndarray]] = None, + use_in_memory: bool = True, + val_percentage: float = 0.1, + val_minimum_split: int = 1, + ) -> None: + """ + Train the model on the provided data. + + If a datamodule is provided, then training will be performed using it. + Alternatively, the training data can be provided as arrays or paths.
+ + If `use_in_memory` is set to True, the source provided as Path or str will be + loaded in memory if it fits. Otherwise, training will be performed by loading + patches from the files one by one. Training on arrays is always performed + in memory. + + If no validation source is provided, then the validation is extracted from + the training data using `val_percentage` and `val_minimum_split`. In the case + of data provided as Path or str, the percentage and minimum number are applied + to the number of files. For arrays, it is the number of patches. + + Parameters + ---------- + datamodule : Optional[CAREamicsWood], optional + Datamodule to train on, by default None. + train_source : Optional[Union[Path, str, np.ndarray]], optional + Train source, if no datamodule is provided, by default None. + val_source : Optional[Union[Path, str, np.ndarray]], optional + Validation source, if no datamodule is provided, by default None. + train_target : Optional[Union[Path, str, np.ndarray]], optional + Train target source, if no datamodule is provided, by default None. + val_target : Optional[Union[Path, str, np.ndarray]], optional + Validation target source, if no datamodule is provided, by default None. + use_in_memory : bool, optional + Use in memory dataset if possible, by default True. + val_percentage : float, optional + Percentage of validation extracted from training data, by default 0.1. + val_minimum_split : int, optional + Minimum number of validation (patch or file) extracted from training data, + by default 1. + + Raises + ------ + ValueError + If both `datamodule` and `train_source` are provided. + ValueError + If sources are not of the same type (e.g. train is an array and val is + a Path). + ValueError + If the training target is provided to N2V. + ValueError + If neither a datamodule nor a source is provided. + """ + if datamodule is not None and train_source is not None: + raise ValueError( + "Only one of `datamodule` and `train_source` can be provided."
+ ) + + # check that inputs are the same type + source_types = { + type(s) + for s in (train_source, val_source, train_target, val_target) + if s is not None + } + if len(source_types) > 1: + raise ValueError("All sources should be of the same type.") + + # train + if datamodule is not None: + self._train_on_datamodule(datamodule=datamodule) + + else: + # raise error if target is provided to N2V + if self.cfg.algorithm_config.algorithm == SupportedAlgorithm.N2V.value: + if train_target is not None: + raise ValueError( + "Training target not compatible with N2V training." + ) + + # dispatch the training + if isinstance(train_source, np.ndarray): + # mypy checks + assert isinstance(val_source, np.ndarray) or val_source is None + assert isinstance(train_target, np.ndarray) or train_target is None + assert isinstance(val_target, np.ndarray) or val_target is None + + self._train_on_array( + train_source, + val_source, + train_target, + val_target, + val_percentage, + val_minimum_split, + ) + + elif isinstance(train_source, Path) or isinstance(train_source, str): + # mypy checks + assert ( + isinstance(val_source, Path) + or isinstance(val_source, str) + or val_source is None + ) + assert ( + isinstance(train_target, Path) + or isinstance(train_target, str) + or train_target is None + ) + assert ( + isinstance(val_target, Path) + or isinstance(val_target, str) + or val_target is None + ) + + self._train_on_path( + train_source, + val_source, + train_target, + val_target, + use_in_memory, + val_percentage, + val_minimum_split, + ) + + else: + raise ValueError( + f"Invalid input, expected a str, Path, array or CAREamicsWood " + f"instance (got {type(train_source)})." + ) + + def _train_on_datamodule(self, datamodule: CAREamicsWood) -> None: + """ + Train the model on the provided datamodule. + + Parameters + ---------- + datamodule : CAREamicsWood + Datamodule to train on. 
+ """ + # record datamodule + self.train_datamodule = datamodule + + self.trainer.fit(self.model, datamodule=datamodule) + + def _train_on_array( + self, + train_data: np.ndarray, + val_data: Optional[np.ndarray] = None, + train_target: Optional[np.ndarray] = None, + val_target: Optional[np.ndarray] = None, + val_percentage: float = 0.1, + val_minimum_split: int = 5, + ) -> None: + """ + Train the model on the provided data arrays. + + Parameters + ---------- + train_data : np.ndarray + Training data. + val_data : Optional[np.ndarray], optional + Validation data, by default None. + train_target : Optional[np.ndarray], optional + Train target data, by default None. + val_target : Optional[np.ndarray], optional + Validation target data, by default None. + val_percentage : float, optional + Percentage of patches to use for validation, by default 0.1. + val_minimum_split : int, optional + Minimum number of patches to use for validation, by default 5. + """ + # create datamodule + datamodule = CAREamicsWood( + data_config=self.cfg.data_config, + train_data=train_data, + val_data=val_data, + train_data_target=train_target, + val_data_target=val_target, + val_percentage=val_percentage, + val_minimum_split=val_minimum_split, + ) + + # train + self.train(datamodule=datamodule) + + def _train_on_path( + self, + path_to_train_data: Union[Path, str], + path_to_val_data: Optional[Union[Path, str]] = None, + path_to_train_target: Optional[Union[Path, str]] = None, + path_to_val_target: Optional[Union[Path, str]] = None, + use_in_memory: bool = True, + val_percentage: float = 0.1, + val_minimum_split: int = 1, + ) -> None: + """ + Train the model on the provided data paths. + + Parameters + ---------- + path_to_train_data : Union[Path, str] + Path to the training data. + path_to_val_data : Optional[Union[Path, str]], optional + Path to validation data, by default None. + path_to_train_target : Optional[Union[Path, str]], optional + Path to train target data, by default None. 
+ path_to_val_target : Optional[Union[Path, str]], optional + Path to validation target data, by default None. + use_in_memory : bool, optional + Use in memory dataset if possible, by default True. + val_percentage : float, optional + Percentage of files to use for validation, by default 0.1. + val_minimum_split : int, optional + Minimum number of files to use for validation, by default 1. + """ + # sanity check on data (path exists) + path_to_train_data = check_path_exists(path_to_train_data) + + if path_to_val_data is not None: + path_to_val_data = check_path_exists(path_to_val_data) + + if path_to_train_target is not None: + path_to_train_target = check_path_exists(path_to_train_target) + + if path_to_val_target is not None: + path_to_val_target = check_path_exists(path_to_val_target) + + # create datamodule + datamodule = CAREamicsWood( + data_config=self.cfg.data_config, + train_data=path_to_train_data, + val_data=path_to_val_data, + train_data_target=path_to_train_target, + val_data_target=path_to_val_target, + use_in_memory=use_in_memory, + val_percentage=val_percentage, + val_minimum_split=val_minimum_split, + ) + + # train + self.train(datamodule=datamodule) + + @overload + def predict( # numpydoc ignore=GL08 + self, + source: CAREamicsClay, + *, + checkpoint: Optional[Literal["best", "last"]] = None, + ) -> Union[list, np.ndarray]: + ... + + @overload + def predict( # numpydoc ignore=GL08 + self, + source: Union[Path, str], + *, + batch_size: int = 1, + tile_size: Optional[Tuple[int, ...]] = None, + tile_overlap: Tuple[int, ...] = (48, 48), + axes: Optional[str] = None, + data_type: Optional[Literal["tiff", "custom"]] = None, + transforms: Optional[List[TRANSFORMS_UNION]] = None, + tta_transforms: bool = True, + dataloader_params: Optional[Dict] = None, + read_source_func: Optional[Callable] = None, + extension_filter: str = "", + checkpoint: Optional[Literal["best", "last"]] = None, + ) -> Union[list, np.ndarray]: + ... 
+ + @overload + def predict( # numpydoc ignore=GL08 + self, + source: np.ndarray, + *, + batch_size: int = 1, + tile_size: Optional[Tuple[int, ...]] = None, + tile_overlap: Tuple[int, ...] = (48, 48), + axes: Optional[str] = None, + data_type: Optional[Literal["array"]] = None, + transforms: Optional[List[TRANSFORMS_UNION]] = None, + tta_transforms: bool = True, + dataloader_params: Optional[Dict] = None, + checkpoint: Optional[Literal["best", "last"]] = None, + ) -> Union[list, np.ndarray]: + ... + + def predict( + self, + source: Union[CAREamicsClay, Path, str, np.ndarray], + *, + batch_size: int = 1, + tile_size: Optional[Tuple[int, ...]] = None, + tile_overlap: Tuple[int, ...] = (48, 48), + axes: Optional[str] = None, + data_type: Optional[Literal["array", "tiff", "custom"]] = None, + transforms: Optional[List[TRANSFORMS_UNION]] = None, + tta_transforms: bool = True, + dataloader_params: Optional[Dict] = None, + read_source_func: Optional[Callable] = None, + extension_filter: str = "", + checkpoint: Optional[Literal["best", "last"]] = None, + **kwargs: Any, + ) -> Union[List[np.ndarray], np.ndarray]: + """ + Make predictions on the provided data. + + Input can be a CAREamicsClay instance, a path to a data file, or a numpy array. + + If `data_type`, `axes` and `tile_size` are not provided, the training + configuration parameters will be used, with the `patch_size` instead of + `tile_size`. + + The default transforms are defined in the `InferenceModel` Pydantic model. + + Test-time augmentation (TTA) can be switched off using the `tta_transforms` + parameter. + + Parameters + ---------- + source : Union[CAREamicsClay, Path, str, np.ndarray] + Data to predict on. + batch_size : int, optional + Batch size for prediction, by default 1. + tile_size : Optional[Tuple[int, ...]], optional + Size of the tiles to use for prediction, by default None. + tile_overlap : Tuple[int, ...], optional + Overlap between tiles, by default (48, 48). 
+ axes : Optional[str], optional + Axes of the input data, by default None. + data_type : Optional[Literal["array", "tiff", "custom"]], optional + Type of the input data, by default None. + transforms : Optional[List[TRANSFORMS_UNION]], optional + List of transforms to apply to the data, by default None. + tta_transforms : bool, optional + Whether to apply test-time augmentation, by default True. + dataloader_params : Optional[Dict], optional + Parameters to pass to the dataloader, by default None. + read_source_func : Optional[Callable], optional + Function to read the source data, by default None. + extension_filter : str, optional + Filter for the file extension, by default "". + checkpoint : Optional[Literal["best", "last"]], optional + Checkpoint to use for prediction, by default None. + **kwargs : Any + Unused. + + Returns + ------- + Union[List[np.ndarray], np.ndarray] + Predictions made by the model. + + Raises + ------ + ValueError + If the input is not a CAREamicsClay instance, a path or a numpy array. + """ + if isinstance(source, CAREamicsClay): + # record datamodule + self.pred_datamodule = source + + return self.trainer.predict( + model=self.model, datamodule=source, ckpt_path=checkpoint + ) + else: + if self.cfg is None: + raise ValueError( + "No configuration found. Train a model or load from a " + "checkpoint before predicting." 
+ ) + # create predict config, reuse training config if parameters missing + prediction_config = create_inference_configuration( + training_configuration=self.cfg, + tile_size=tile_size, + tile_overlap=tile_overlap, + data_type=data_type, + axes=axes, + transforms=transforms, + tta_transforms=tta_transforms, + batch_size=batch_size, + ) + + # remove batch from dataloader parameters (priority given to config) + if dataloader_params is None: + dataloader_params = {} + if "batch_size" in dataloader_params: + del dataloader_params["batch_size"] + + if isinstance(source, Path) or isinstance(source, str): + # Check the source + source_path = check_path_exists(source) + + # create datamodule + datamodule = CAREamicsClay( + prediction_config=prediction_config, + pred_data=source_path, + read_source_func=read_source_func, + extension_filter=extension_filter, + dataloader_params=dataloader_params, + ) + + # record datamodule + self.pred_datamodule = datamodule + + return self.trainer.predict( + model=self.model, datamodule=datamodule, ckpt_path=checkpoint + ) + + elif isinstance(source, np.ndarray): + # create datamodule + datamodule = CAREamicsClay( + prediction_config=prediction_config, + pred_data=source, + dataloader_params=dataloader_params, + ) + + # record datamodule + self.pred_datamodule = datamodule + + return self.trainer.predict( + model=self.model, datamodule=datamodule, ckpt_path=checkpoint + ) + + else: + raise ValueError( + f"Invalid input. Expected a CAREamicsClay instance, paths or " + f"np.ndarray (got {type(source)})." + ) + + def export_to_bmz( + self, + path: Union[Path, str], + name: str, + authors: List[dict], + input_array: Optional[np.ndarray] = None, + general_description: str = "", + channel_names: Optional[List[str]] = None, + data_description: Optional[str] = None, + ) -> None: + """Export the model to the BioImage Model Zoo format. + + Input array must be of shape SC(Z)YX, with S and C singleton dimensions.
+ + Parameters + ---------- + path : Union[Path, str] + Path to save the model. + name : str + Name of the model. + authors : List[dict] + List of authors of the model. + input_array : Optional[np.ndarray], optional + Input array for the model, must be of shape SC(Z)YX, by default None. + general_description : str + General description of the model, used in the metadata of the BMZ archive. + channel_names : Optional[List[str]], optional + Channel names, by default None. + data_description : Optional[str], optional + Description of the data, by default None. + """ + if input_array is None: + # generate images, priority is given to the prediction data module + if self.pred_datamodule is not None: + # unpack a batch, ignore masks or targets + input_patch, *_ = next(iter(self.pred_datamodule.predict_dataloader())) + + # convert torch.Tensor to numpy + input_patch = input_patch.numpy() + elif self.train_datamodule is not None: + input_patch, *_ = next(iter(self.train_datamodule.train_dataloader())) + input_patch = input_patch.numpy() + else: + if ( + self.cfg.data_config.mean is None + or self.cfg.data_config.std is None + ): + raise ValueError( + "Mean and std cannot be None in the configuration in order to" + " export to the BMZ format. Was the model trained?" + ) + + # create a random input array + input_patch = np.random.normal( + loc=self.cfg.data_config.mean, + scale=self.cfg.data_config.std, + size=self.cfg.data_config.patch_size, + ).astype(np.float32)[ + np.newaxis, np.newaxis, ... + ] # add S & C dimensions + else: + input_patch = input_array + + # if there is a batch dimension + if input_patch.shape[0] > 1: + input_patch = input_patch[0:1, ...] 
# keep singleton dim + + # axes need to be reformatted for the export because reshaping was done in the + # datamodule + if "Z" in self.cfg.data_config.axes: + axes = "SCZYX" + else: + axes = "SCYX" + + # predict output, remove extra dimensions for the purpose of the prediction + output_patch = self.predict( + input_patch, + data_type=SupportedData.ARRAY.value, + axes=axes, + tta_transforms=False, + ) + + if not isinstance(output_patch, np.ndarray): + raise ValueError( + f"Numpy array required for export to BioImage Model Zoo, got " + f"{type(output_patch)}." + ) + + export_to_bmz( + model=self.model, + config=self.cfg, + path=path, + name=name, + general_description=general_description, + authors=authors, + input_array=input_patch, + output_array=output_patch, + channel_names=channel_names, + data_description=data_description, + ) diff --git a/src/careamics/config/__init__.py b/src/careamics/config/__init__.py index db683e78..9ee70b7e 100644 --- a/src/careamics/config/__init__.py +++ b/src/careamics/config/__init__.py @@ -1,11 +1,35 @@ """Configuration module.""" -__all__ = ["Configuration", "load_configuration", "save_configuration"] +__all__ = [ + "AlgorithmModel", + "DataModel", + "Configuration", + "CheckpointModel", + "InferenceModel", + "load_configuration", + "save_configuration", + "TrainingModel", + "create_n2v_configuration", + "register_model", + "CustomModel", + "create_inference_configuration", + "clear_custom_models", + "ConfigurationInformation", +] -from .config import ( +from .algorithm_model import AlgorithmModel +from .architectures import CustomModel, clear_custom_models, register_model +from .callback_model import CheckpointModel +from .configuration_factory import ( + create_inference_configuration, + create_n2v_configuration, +) +from .configuration_model import ( Configuration, load_configuration, save_configuration, ) -from .torch_optim import get_parameters as get_parameters +from .data_model import DataModel +from .inference_model import 
InferenceModel +from .training_model import TrainingModel diff --git a/src/careamics/config/algorithm.py b/src/careamics/config/algorithm.py deleted file mode 100644 index 0237b5d9..00000000 --- a/src/careamics/config/algorithm.py +++ /dev/null @@ -1,231 +0,0 @@ -"""Algorithm configuration.""" -from enum import Enum -from typing import Dict, List - -from pydantic import BaseModel, ConfigDict, Field, field_validator - -from .config_filter import remove_default_optionals - - -# python 3.11: https://docs.python.org/3/library/enum.html -class Loss(str, Enum): - """ - Available loss functions. - - Currently supported losses: - - - n2v: Noise2Void loss. - """ - - N2V = "n2v" - - -class Models(str, Enum): - """ - Available models. - - Currently supported models: - - UNet: U-Net model. - """ - - UNET = "UNet" - - -class MaskingStrategy(str, Enum): - """ - Available masking strategy. - - Currently supported strategies: - - - default: default masking strategy of Noise2Void (uniform sampling of neighbors). - - median: median masking strategy of N2V2. - """ - - DEFAULT = "default" - MEDIAN = "median" - - -class ModelParameters(BaseModel): - """ - Deep-learning model parameters. - - The number of filters (base) must be even and minimum 8. - - Attributes - ---------- - depth : int - Depth of the model, between 1 and 10 (default 2). - num_channels_init : int - Number of filters of the first level of the network, should be even - and minimum 8 (default 96). - """ - - model_config = ConfigDict(validate_assignment=True) - - depth: int = Field(default=2, ge=1, le=10) - num_channels_init: int = Field(default=32, ge=8) - - # TODO revisit the constraints on num_channels_init - @field_validator("num_channels_init") - def even(cls, num_channels: int) -> int: - """ - Validate that num_channels_init is even. - - Parameters - ---------- - num_channels : int - Number of channels. - - Returns - ------- - int - Validated number of channels. 
- - Raises - ------ - ValueError - If the number of channels is odd. - """ - # if odd - if num_channels % 2 != 0: - raise ValueError( - f"Number of channels (init) must be even (got {num_channels})." - ) - - return num_channels - - -class Algorithm(BaseModel): - """ - Algorithm configuration. - - The minimum algorithm configuration is composed of the following fields: - - loss: - Loss to use, currently only supports n2v. - - model: - Model to use, currently only supports UNet. - - is_3D: - Whether to use a 3D model or not, this should be coherent with the - data configuration (axes). - - Other optional fields are: - - masking_strategy: - Masking strategy to use, currently only supports default masking. - - masked_pixel_percentage: - Percentage of pixels to be masked in each patch. - - roi_size: - Size of the region of interest to use in the masking algorithm. - - model_parameters: - Model parameters, see ModelParameters for more details. - - Attributes - ---------- - loss : List[Losses] - List of losses to use, currently only supports n2v. - model : Models - Model to use, currently only supports UNet. - is_3D : bool - Whether to use a 3D model or not. - masking_strategy : MaskingStrategies - Masking strategy to use, currently only supports default masking. - masked_pixel_percentage : float - Percentage of pixels to be masked in each patch. - roi_size : int - Size of the region of interest used in the masking scheme. - model_parameters : ModelParameters - Model parameters, see ModelParameters for more details. 
- """ - - # Pydantic class configuration - model_config = ConfigDict( - use_enum_values=True, - protected_namespaces=(), # allows to use model_* as a field name - validate_assignment=True, - ) - - # Mandatory fields - loss: Loss - model: Models - is_3D: bool - - # Optional fields, define a default value - masking_strategy: MaskingStrategy = MaskingStrategy.DEFAULT - masked_pixel_percentage: float = Field(default=0.2, ge=0.1, le=20) - roi_size: int = Field(default=11, ge=3, le=21) - model_parameters: ModelParameters = ModelParameters() - - def get_conv_dim(self) -> int: - """ - Get the convolution layers dimension (2D or 3D). - - Returns - ------- - int - Dimension (2 or 3). - """ - return 3 if self.is_3D else 2 - - @field_validator("roi_size") - def even(cls, roi_size: int) -> int: - """ - Validate that roi_size is odd. - - Parameters - ---------- - roi_size : int - Size of the region of interest in the masking scheme. - - Returns - ------- - int - Validated size of the region of interest. - - Raises - ------ - ValueError - If the size of the region of interest is even. - """ - # if even - if roi_size % 2 == 0: - raise ValueError(f"ROI size must be odd (got {roi_size}).") - - return roi_size - - def model_dump( - self, exclude_optionals: bool = True, *args: List, **kwargs: Dict - ) -> Dict: - """ - Override model_dump method. - - The purpose is to ensure export smooth import to yaml. It includes: - - remove entries with None value. - - remove optional values if they have the default value. - - Parameters - ---------- - exclude_optionals : bool, optional - Whether to exclude optional arguments if they are default, by default True. - *args : List - Positional arguments, unused. - **kwargs : Dict - Keyword arguments, unused. - - Returns - ------- - Dict - Dictionary representation of the model. 
- """ - dictionary = super().model_dump(exclude_none=True) - - if exclude_optionals is True: - # remove optional arguments if they are default - defaults = { - "masking_strategy": MaskingStrategy.DEFAULT.value, - "masked_pixel_percentage": 0.2, - "roi_size": 11, - "model_parameters": ModelParameters().model_dump(exclude_none=True), - } - - remove_default_optionals(dictionary, defaults) - - return dictionary diff --git a/src/careamics/config/algorithm_model.py b/src/careamics/config/algorithm_model.py new file mode 100644 index 00000000..41906528 --- /dev/null +++ b/src/careamics/config/algorithm_model.py @@ -0,0 +1,167 @@ +from __future__ import annotations + +from pprint import pformat +from typing import Literal, Union + +from pydantic import BaseModel, ConfigDict, Field, model_validator +from typing_extensions import Self + +from .architectures import CustomModel, UNetModel, VAEModel +from .optimizer_models import LrSchedulerModel, OptimizerModel + + +class AlgorithmModel(BaseModel): + """Algorithm configuration. + + This Pydantic model validates the parameters governing the components of the + training algorithm: which algorithm, loss function, model architecture, optimizer, + and learning rate scheduler to use. + + Currently, we only support N2V and custom algorithms. The `n2v` algorithm is only + compatible with `n2v` loss and `UNet` architecture. The `custom` algorithm allows + you to register your own architecture and select it using its name as + `name` in the custom pydantic model. + + Attributes + ---------- + algorithm : Literal["n2v", "custom"] + Algorithm to use. + loss : Literal["n2v", "mae", "mse"] + Loss function to use. + model : Union[UNetModel, VAEModel, CustomModel] + Model architecture to use. + optimizer : OptimizerModel, optional + Optimizer to use. + lr_scheduler : LrSchedulerModel, optional + Learning rate scheduler to use. + + Raises + ------ + ValueError + Algorithm parameter type validation errors. 
+ ValueError + If the algorithm, loss and model are not compatible. + + Examples + -------- + Minimum example: + >>> from careamics.config import AlgorithmModel + >>> config_dict = { + ... "algorithm": "n2v", + ... "loss": "n2v", + ... "model": { + ... "architecture": "UNet", + ... } + ... } + >>> config = AlgorithmModel(**config_dict) + + Using a custom model: + >>> from torch import nn, ones + >>> from careamics.config import AlgorithmModel, register_model + ... + >>> @register_model(name="linear_model") + ... class LinearModel(nn.Module): + ... def __init__(self, in_features, out_features, *args, **kwargs): + ... super().__init__() + ... self.in_features = in_features + ... self.out_features = out_features + ... self.weight = nn.Parameter(ones(in_features, out_features)) + ... self.bias = nn.Parameter(ones(out_features)) + ... def forward(self, input): + ... return (input @ self.weight) + self.bias + ... + >>> config_dict = { + ... "algorithm": "custom", + ... "loss": "mse", + ... "model": { + ... "architecture": "Custom", + ... "name": "linear_model", + ... "in_features": 10, + ... "out_features": 5, + ... } + ... } + >>> config = AlgorithmModel(**config_dict) + """ + + # Pydantic class configuration + model_config = ConfigDict( + protected_namespaces=(), # allows to use model_* as a field name + validate_assignment=True, + ) + + # Mandatory fields + algorithm: Literal["n2v", "care", "n2n", "custom"] # defined in SupportedAlgorithm + loss: Literal["n2v", "mae", "mse"] + model: Union[UNetModel, VAEModel, CustomModel] = Field(discriminator="architecture") + + # Optional fields + optimizer: OptimizerModel = OptimizerModel() + lr_scheduler: LrSchedulerModel = LrSchedulerModel() + + @model_validator(mode="after") + def algorithm_cross_validation(self: Self) -> Self: + """Validate the algorithm model based on `algorithm`. + + N2V: + - loss must be n2v + - model must be a `UNetModel` + + Returns + ------- + Self + The validated model. 
+ """ + # N2V + if self.algorithm == "n2v": + # n2v is only compatible with the n2v loss + if self.loss != "n2v": + raise ValueError( + f"Algorithm {self.algorithm} only supports loss `n2v`." + ) + + # n2v is only compatible with the UNet model + if not isinstance(self.model, UNetModel): + raise ValueError( + f"Model for algorithm {self.algorithm} must be a `UNetModel`." + ) + + # n2v requires the number of input and output channels to be the same + if self.model.in_channels != self.model.num_classes: + raise ValueError( + "N2V requires the same number of input and output channels. Make " + "sure that `in_channels` and `num_classes` are the same." + ) + + # N2N + if self.algorithm == "n2n": + # n2n is only compatible with the UNet model + if not isinstance(self.model, UNetModel): + raise ValueError( + f"Model for algorithm {self.algorithm} must be a `UNetModel`." + ) + + # n2n requires the number of input and output channels to be the same + if self.model.in_channels != self.model.num_classes: + raise ValueError( + "N2N requires the same number of input and output channels. Make " + "sure that `in_channels` and `num_classes` are the same." + ) + + if self.algorithm == "care" or self.algorithm == "n2n": + if self.loss == "n2v": + raise ValueError("Supervised algorithms do not support loss `n2v`.") + + if isinstance(self.model, VAEModel): + raise ValueError("VAE are currently not implemented.") + + return self + + def __str__(self) -> str: + """Pretty string representing the configuration. + + Returns + ------- + str + Pretty string. 
+ """ + return pformat(self.model_dump()) diff --git a/src/careamics/config/architectures/__init__.py b/src/careamics/config/architectures/__init__.py new file mode 100644 index 00000000..c65d97f0 --- /dev/null +++ b/src/careamics/config/architectures/__init__.py @@ -0,0 +1,17 @@ +"""Deep-learning model configurations.""" + +__all__ = [ + "ArchitectureModel", + "CustomModel", + "UNetModel", + "VAEModel", + "clear_custom_models", + "get_custom_model", + "register_model", +] + +from .architecture_model import ArchitectureModel +from .custom_model import CustomModel +from .register_model import clear_custom_models, get_custom_model, register_model +from .unet_model import UNetModel +from .vae_model import VAEModel diff --git a/src/careamics/config/architectures/architecture_model.py b/src/careamics/config/architectures/architecture_model.py new file mode 100644 index 00000000..28113112 --- /dev/null +++ b/src/careamics/config/architectures/architecture_model.py @@ -0,0 +1,29 @@ +from typing import Any, Dict + +from pydantic import BaseModel + + +class ArchitectureModel(BaseModel): + """ + Base Pydantic model for all model architectures. + + The `model_dump` method allows removing the `architecture` key from the model. + """ + + architecture: str + + def model_dump(self, **kwargs: Any) -> Dict[str, Any]: + """ + Dump the model as a dictionary, ignoring the architecture keyword. + + Returns + ------- + dict[str, Any] + Model as a dictionary. 
+ """ + model_dict = super().model_dump(**kwargs) + + # remove the architecture key + model_dict.pop("architecture") + + return model_dict diff --git a/src/careamics/config/architectures/custom_model.py b/src/careamics/config/architectures/custom_model.py new file mode 100644 index 00000000..3290344a --- /dev/null +++ b/src/careamics/config/architectures/custom_model.py @@ -0,0 +1,150 @@ +from __future__ import annotations + +from pprint import pformat +from typing import Any, Dict, Literal + +from pydantic import ConfigDict, field_validator, model_validator +from torch.nn import Module +from typing_extensions import Self + +from .architecture_model import ArchitectureModel +from .register_model import get_custom_model + + +class CustomModel(ArchitectureModel): + """Custom model configuration. + + This Pydantic model allows storing parameters for a custom model. In order for the + model to be valid, the specific model needs to be registered using the + `register_model` decorator, and its name correctly passed to this model + configuration (see Examples). + + Attributes + ---------- + architecture : Literal["Custom"] + Discriminator for the custom model, must be set to "Custom". + name : str + Name of the custom model. + parameters : CustomParametersModel + Parameters of the custom model. + + Raises + ------ + ValueError + If the custom model `name` is unknown. + ValueError + If the custom model is not a torch Module subclass. + ValueError + If the custom model parameters are not valid. + + Examples + -------- + >>> from torch import nn, ones + >>> from careamics.config import CustomModel, register_model + >>> # Register a custom model + >>> @register_model(name="my_linear") + ... class LinearModel(nn.Module): + ... def __init__(self, in_features, out_features, *args, **kwargs): + ... super().__init__() + ... self.in_features = in_features + ... self.out_features = out_features + ... self.weight = nn.Parameter(ones(in_features, out_features)) + ... 
self.bias = nn.Parameter(ones(out_features)) + ... def forward(self, input): + ... return (input @ self.weight) + self.bias + ... + >>> # Create a configuration + >>> config_dict = { + ... "architecture": "Custom", + ... "name": "my_linear", + ... "in_features": 10, + ... "out_features": 5, + ... } + >>> config = CustomModel(**config_dict) + """ + + # pydantic model config + model_config = ConfigDict( + extra="allow", + ) + + # discriminator used for choosing the pydantic model in Model + architecture: Literal["Custom"] + + # name of the custom model + name: str + + @field_validator("name") + @classmethod + def custom_model_is_known(cls, value: str) -> str: + """Check whether the custom model is known. + + Parameters + ---------- + value : str + Name of the custom model as registered using the `@register_model` + decorator. + """ + # delegate error to get_custom_model + model = get_custom_model(value) + + # check if it is a torch Module subclass + if not issubclass(model, Module): + raise ValueError( + f'Retrieved class {model} with name "{value}" is not a ' + f"torch.nn.Module subclass." + ) + + return value + + @model_validator(mode="after") + def check_parameters(self: Self) -> Self: + """Validate model by instantiating the model with the parameters. + + Returns + ------- + Self + The validated model. + """ + # instantiate model + try: + get_custom_model(self.name)(**self.model_dump()) + except Exception as e: + raise ValueError( + f"error while passing parameters to the model {e}. Verify that all " + f"mandatory parameters are provided, and that either the {e} accepts " + f"*args and **kwargs in its __init__() method, or that no additional" + f" parameter is provided." + ) from None + + return self + + def __str__(self) -> str: + """Pretty string representing the configuration. + + Returns + ------- + str + Pretty string. + """ + return pformat(self.model_dump()) + + def model_dump(self, **kwargs: Any) -> Dict[str, Any]: + """Dump the model configuration. 
+ + Parameters + ---------- + kwargs : Any + Additional keyword arguments from Pydantic BaseModel model_dump method. + + Returns + ------- + Dict[str, Any] + Model configuration. + """ + model_dict = super().model_dump() + + # remove the name key + model_dict.pop("name") + + return model_dict diff --git a/src/careamics/config/architectures/register_model.py b/src/careamics/config/architectures/register_model.py new file mode 100644 index 00000000..f35b0b88 --- /dev/null +++ b/src/careamics/config/architectures/register_model.py @@ -0,0 +1,101 @@ +from typing import Callable + +from torch.nn import Module + +CUSTOM_MODELS = {} # dictionary of custom models {"name": __class__} + + +def register_model(name: str) -> Callable: + """Decorator used to register a torch.nn.Module class with a given `name`. + + Parameters + ---------- + name : str + Name of the model. + + Returns + ------- + Callable + Function allowing to instantiate the wrapped Module class. + + Raises + ------ + ValueError + If a model is already registered with that name. + + Examples + -------- + ```python + @register_model(name="linear") + class LinearModel(nn.Module): + def __init__(self, in_features, out_features): + super().__init__() + + self.weight = nn.Parameter(ones(in_features, out_features)) + self.bias = nn.Parameter(ones(out_features)) + + def forward(self, input): + return (input @ self.weight) + self.bias + ``` + """ + if name is None or name == "": + raise ValueError("Model name cannot be empty.") + + if name in CUSTOM_MODELS: + raise ValueError( + f"Model {name} already exists. Choose a different name or run " + f"`clear_custom_models()` to empty the registry." + ) + + def add_custom_model(model: Module) -> Module: + """Add a custom model to the registry and return it. + + Parameters + ---------- + model : Module + Module class to register + + Returns + ------- + Module + The registered model. 
+ """ + # add model to the registry + CUSTOM_MODELS[name] = model + + return model + + return add_custom_model + + +def get_custom_model(name: str) -> Module: + """Get the custom model corresponding to `name` from the registry. + + Parameters + ---------- + name : str + Name of the model to retrieve. + + Returns + ------- + Module + The requested model. + + Raises + ------ + ValueError + If the model is not registered. + """ + if name not in CUSTOM_MODELS: + raise ValueError( + f"Model {name} is unknown. Have you registered it using " + f'@register_model("{name}") as decorator?' + ) + + return CUSTOM_MODELS[name] + + +def clear_custom_models() -> None: + """Clear the custom models registry.""" + # clear dictionary + CUSTOM_MODELS.clear() diff --git a/src/careamics/config/architectures/unet_model.py b/src/careamics/config/architectures/unet_model.py new file mode 100644 index 00000000..9a032e2b --- /dev/null +++ b/src/careamics/config/architectures/unet_model.py @@ -0,0 +1,96 @@ +from __future__ import annotations + +from typing import Literal + +from pydantic import ConfigDict, Field, field_validator + +from .architecture_model import ArchitectureModel + + +# TODO tests activation <-> pydantic model, test the literals! +# TODO annotations for the json schema? +class UNetModel(ArchitectureModel): + """ + Pydantic model for a N2V(2)-compatible UNet. + + Attributes + ---------- + depth : int + Depth of the model, between 1 and 10 (default 2). + num_channels_init : int + Number of filters of the first level of the network, should be even + and minimum 8 (default 32). 
+ """ + + # pydantic model config + model_config = ConfigDict(validate_assignment=True) + + # discriminator used for choosing the pydantic model in Model + architecture: Literal["UNet"] + + # parameters + # validate_defaults allow ignoring default values in the dump if they were not set + conv_dims: Literal[2, 3] = Field(default=2, validate_default=True) + num_classes: int = Field(default=1, ge=1, validate_default=True) + in_channels: int = Field(default=1, ge=1, validate_default=True) + depth: int = Field(default=2, ge=1, le=10, validate_default=True) + num_channels_init: int = Field(default=32, ge=8, le=1024, validate_default=True) + final_activation: Literal[ + "None", "Sigmoid", "Softmax", "Tanh", "ReLU", "LeakyReLU" + ] = Field(default="None", validate_default=True) + n2v2: bool = Field(default=False, validate_default=True) + + @field_validator("num_channels_init") + @classmethod + def validate_num_channels_init(cls, num_channels_init: int) -> int: + """ + Validate that num_channels_init is even. + + Parameters + ---------- + num_channels_init : int + Number of channels. + + Returns + ------- + int + Validated number of channels. + + Raises + ------ + ValueError + If the number of channels is odd. + """ + # if odd + if num_channels_init % 2 != 0: + raise ValueError( + f"Number of channels for the bottom layer must be even" + f" (got {num_channels_init})." + ) + + return num_channels_init + + def set_3D(self, is_3D: bool) -> None: + """ + Set 3D model by setting the `conv_dims` parameters. + + Parameters + ---------- + is_3D : bool + Whether the algorithm is 3D or not. + """ + if is_3D: + self.conv_dims = 3 + else: + self.conv_dims = 2 + + def is_3D(self) -> bool: + """ + Return whether the model is 3D or not. + + Returns + ------- + bool + Whether the model is 3D or not. 
+ """ + return self.conv_dims == 3 diff --git a/src/careamics/config/architectures/vae_model.py b/src/careamics/config/architectures/vae_model.py new file mode 100644 index 00000000..03c7eb60 --- /dev/null +++ b/src/careamics/config/architectures/vae_model.py @@ -0,0 +1,39 @@ +from typing import Literal + +from pydantic import ( + ConfigDict, +) + +from .architecture_model import ArchitectureModel + + +class VAEModel(ArchitectureModel): + """VAE model placeholder.""" + + model_config = ConfigDict( + use_enum_values=True, protected_namespaces=(), validate_assignment=True + ) + + architecture: Literal["VAE"] + + def set_3D(self, is_3D: bool) -> None: + """ + Set 3D model by setting the `conv_dims` parameters. + + Parameters + ---------- + is_3D : bool + Whether the algorithm is 3D or not. + """ + raise NotImplementedError("VAE is not implemented yet.") + + def is_3D(self) -> bool: + """ + Return whether the model is 3D or not. + + Returns + ------- + bool + Whether the model is 3D or not. + """ + raise NotImplementedError("VAE is not implemented yet.") diff --git a/src/careamics/config/callback_model.py b/src/careamics/config/callback_model.py new file mode 100644 index 00000000..a4b522ce --- /dev/null +++ b/src/careamics/config/callback_model.py @@ -0,0 +1,92 @@ +"""Checkpoint saving configuration.""" +from __future__ import annotations + +from datetime import timedelta +from typing import Literal, Optional + +from pydantic import ( + BaseModel, + ConfigDict, + Field, +) + + +class CheckpointModel(BaseModel): + """_summary_. 
+ + Parameters + ---------- + BaseModel : _type_ + _description_ + """ + + model_config = ConfigDict( + validate_assignment=True, + ) + + monitor: Literal["val_loss"] = Field(default="val_loss", validate_default=True) + verbose: bool = Field(default=False, validate_default=True) + save_weights_only: bool = Field(default=False, validate_default=True) + mode: Literal["min", "max"] = Field(default="min", validate_default=True) + auto_insert_metric_name: bool = Field(default=False, validate_default=True) + every_n_train_steps: Optional[int] = Field( + default=None, ge=1, le=10, validate_default=True + ) + train_time_interval: Optional[timedelta] = Field( + default=None, validate_default=True + ) + every_n_epochs: Optional[int] = Field( + default=None, ge=1, le=10, validate_default=True + ) + save_last: Optional[Literal[True, False, "link"]] = Field( + default=True, validate_default=True + ) + save_top_k: int = Field(default=3, ge=1, le=10, validate_default=True) + + +class EarlyStoppingModel(BaseModel): + """_summary_. 
+ + Parameters + ---------- + BaseModel : _type_ + _description_ + """ + + model_config = ConfigDict( + validate_assignment=True, + ) + + monitor: Literal["val_loss"] = Field(default="val_loss", validate_default=True) + patience: int = Field(default=3, ge=1, le=10, validate_default=True) + mode: Literal["min", "max", "auto"] = Field(default="min", validate_default=True) + min_delta: float = Field(default=0.0, ge=0.0, le=1.0, validate_default=True) + check_finite: bool = Field(default=True, validate_default=True) + stop_on_nan: bool = Field(default=True, validate_default=True) + verbose: bool = Field(default=False, validate_default=True) + restore_best_weights: bool = Field(default=True, validate_default=True) + auto_lr_find: bool = Field(default=False, validate_default=True) + auto_lr_find_patience: int = Field(default=3, ge=1, le=10, validate_default=True) + auto_lr_find_mode: Literal["min", "max", "auto"] = Field( + default="min", validate_default=True + ) + auto_lr_find_direction: Literal["forward", "backward"] = Field( + default="backward", validate_default=True + ) + auto_lr_find_max_lr: float = Field( + default=10.0, ge=0.0, le=1e6, validate_default=True + ) + auto_lr_find_min_lr: float = Field( + default=1e-8, ge=0.0, le=1e6, validate_default=True + ) + auto_lr_find_num_training: int = Field( + default=100, ge=1, le=1e6, validate_default=True + ) + auto_lr_find_divergence_threshold: float = Field( + default=5.0, ge=0.0, le=1e6, validate_default=True + ) + auto_lr_find_accumulate_grad_batches: int = Field( + default=1, ge=1, le=1e6, validate_default=True + ) + auto_lr_find_stop_divergence: bool = Field(default=True, validate_default=True) + auto_lr_find_step_scale: float = Field(default=0.1, ge=0.0, le=10) diff --git a/src/careamics/config/config.py b/src/careamics/config/config.py deleted file mode 100644 index 4e467cd9..00000000 --- a/src/careamics/config/config.py +++ /dev/null @@ -1,297 +0,0 @@ -"""Pydantic CAREamics configuration.""" -from __future__ 
import annotations - -import re -from pathlib import Path -from typing import Dict, List, Union - -import yaml -from pydantic import ( - BaseModel, - ConfigDict, - field_validator, - model_validator, -) - -# ignore typing-only-first-party-import in this file (flake8) -from .algorithm import Algorithm # noqa: TCH001 -from .config_filter import paths_to_str -from .data import Data # noqa: TCH001 -from .training import Training # noqa: TCH001 - - -class Configuration(BaseModel): - """ - CAREamics configuration. - - To change the configuration from 2D to 3D, we recommend using the following method: - >>> set_3D(is_3D, axes) - - Attributes - ---------- - experiment_name : str - Name of the experiment. - working_directory : Union[str, Path] - Path to the working directory. - algorithm : Algorithm - Algorithm configuration. - training : Training - Training configuration. - """ - - model_config = ConfigDict(validate_assignment=True) - - # required parameters - experiment_name: str - working_directory: Path - - # Sub-configurations - algorithm: Algorithm - data: Data - training: Training - - def set_3D(self, is_3D: bool, axes: str) -> None: - """ - Set 3D flag and axes. - - Parameters - ---------- - is_3D : bool - Whether the algorithm is 3D or not. - axes : str - Axes of the data. - """ - # set the flag and axes (this will not trigger validation at the config level) - self.algorithm.is_3D = is_3D - self.data.axes = axes - - # cheap hack: trigger validation - self.algorithm = self.algorithm - - @field_validator("experiment_name") - def no_symbol(cls, name: str) -> str: - """ - Validate experiment name. - - A valid experiment name is a non-empty string with only contains letters, - numbers, underscores, dashes and spaces. - - Parameters - ---------- - name : str - Name to validate. - - Returns - ------- - str - Validated name. - - Raises - ------ - ValueError - If the name is empty or contains invalid characters. 
- """ - if len(name) == 0 or name.isspace(): - raise ValueError("Experiment name is empty.") - - # Validate using a regex that it contains only letters, numbers, underscores, - # dashes and spaces - if not re.match(r"^[a-zA-Z0-9_\- ]*$", name): - raise ValueError( - f"Experiment name contains invalid characters (got {name}). " - f"Only letters, numbers, underscores, dashes and spaces are allowed." - ) - - return name - - @field_validator("working_directory") - def parent_directory_exists(cls, workdir: Union[str, Path]) -> Path: - """ - Validate working directory. - - A valid working directory is a directory whose parent directory exists. If the - working directory does not exist itself, it is then created. - - Parameters - ---------- - workdir : Union[str, Path] - Working directory to validate. - - Returns - ------- - Path - Validated working directory. - - Raises - ------ - ValueError - If the working directory is not a directory, or if the parent directory does - not exist. - """ - path = Path(workdir) - - # check if it is a directory - if path.exists() and not path.is_dir(): - raise ValueError(f"Working directory is not a directory (got {workdir}).") - - # check if parent directory exists - if not path.parent.exists(): - raise ValueError( - f"Parent directory of working directory does not exist (got {workdir})." - ) - - # create directory if it does not exist already - path.mkdir(exist_ok=True) - - return path - - @model_validator(mode="after") - def validate_3D(cls, config: Configuration) -> Configuration: - """ - Check 3D flag validity. - - Check that the algorithm is_3D flag is compatible with the axes in the - data configuration. - - Parameters - ---------- - config : Configuration - Configuration to validate. - - Returns - ------- - Configuration - Validated configuration. - - Raises - ------ - ValueError - If the algorithm is 3D but the data axes are not, or if the algorithm is - not 3D but the data axes are. 
- """ - # check that is_3D and axes are compatible - if config.algorithm.is_3D and "Z" not in config.data.axes: - raise ValueError( - f"Algorithm is 3D but data axes are not (got axes {config.data.axes})." - ) - elif not config.algorithm.is_3D and "Z" in config.data.axes: - raise ValueError( - f"Algorithm is not 3D but data axes are (got axes {config.data.axes})." - ) - - return config - - def model_dump( - self, exclude_optionals: bool = True, *args: List, **kwargs: Dict - ) -> Dict: - """ - Override model_dump method. - - The purpose is to ensure export smooth import to yaml. It includes: - - remove entries with None value. - - remove optional values if they have the default value. - - Parameters - ---------- - exclude_optionals : bool, optional - Whether to exclude optional fields with default values or not, by default - True. - *args : List - Positional arguments, unused. - **kwargs : Dict - Keyword arguments, unused. - - Returns - ------- - dict - Dictionary containing the model parameters. - """ - dictionary = super().model_dump(exclude_none=True) - - # remove paths - dictionary = paths_to_str(dictionary) - - dictionary["algorithm"] = self.algorithm.model_dump( - exclude_optionals=exclude_optionals - ) - dictionary["data"] = self.data.model_dump() - - dictionary["training"] = self.training.model_dump( - exclude_optionals=exclude_optionals - ) - - return dictionary - - -def load_configuration(path: Union[str, Path]) -> Configuration: - """ - Load configuration from a yaml file. - - Parameters - ---------- - path : Union[str, Path] - Path to the configuration. - - Returns - ------- - Configuration - Configuration. - - Raises - ------ - FileNotFoundError - If the configuration file does not exist. 
- """ - # load dictionary from yaml - if not Path(path).exists(): - raise FileNotFoundError( - f"Configuration file {path} does not exist in " f" {Path.cwd()!s}" - ) - - dictionary = yaml.load(Path(path).open("r"), Loader=yaml.SafeLoader) - - return Configuration(**dictionary) - - -def save_configuration(config: Configuration, path: Union[str, Path]) -> Path: - """ - Save configuration to path. - - Parameters - ---------- - config : Configuration - Configuration to save. - path : Union[str, Path] - Path to a existing folder in which to save the configuration or to an existing - configuration file. - - Returns - ------- - Path - Path object representing the configuration. - - Raises - ------ - ValueError - If the path does not point to an existing directory or .yml file. - """ - # make sure path is a Path object - config_path = Path(path) - - # check if path is pointing to an existing directory or .yml file - if config_path.exists(): - if config_path.is_dir(): - config_path = Path(config_path, "config.yml") - elif config_path.suffix != ".yml": - raise ValueError( - f"Path must be a directory or .yml file (got {config_path})." - ) - else: - if config_path.suffix != ".yml": - raise ValueError(f"Path must be a .yml file (got {config_path}).") - - # save configuration as dictionary to yaml - with open(config_path, "w") as f: - yaml.dump(config.model_dump(), f, default_flow_style=False) - - return config_path diff --git a/src/careamics/config/config_filter.py b/src/careamics/config/config_filter.py deleted file mode 100644 index 00a1affb..00000000 --- a/src/careamics/config/config_filter.py +++ /dev/null @@ -1,44 +0,0 @@ -"""Convenience functions to filter dictionaries resulting from a Pydantic export.""" -from pathlib import Path -from typing import Dict - - -def paths_to_str(dictionary: dict) -> dict: - """ - Replace Path objects in a dictionary by str. - - Parameters - ---------- - dictionary : dict - Dictionary to modify. 
- - Returns - ------- - dict - Modified dictionary. - """ - for k in dictionary.keys(): - if isinstance(dictionary[k], Path): - dictionary[k] = str(dictionary[k]) - - return dictionary - - -def remove_default_optionals(dictionary: Dict, default: Dict) -> None: - """ - Remove default arguments from a dictionary. - - The method removes arguments if they are equal to the provided default ones. - - Parameters - ---------- - dictionary : dict - Dictionary to modify. - default : dict - Dictionary containing the default values. - """ - dict_copy = dictionary.copy() - for k in dict_copy.keys(): - if k in default.keys(): - if dict_copy[k] == default[k]: - del dictionary[k] diff --git a/src/careamics/config/configuration_factory.py b/src/careamics/config/configuration_factory.py new file mode 100644 index 00000000..012cd65b --- /dev/null +++ b/src/careamics/config/configuration_factory.py @@ -0,0 +1,460 @@ +"""Convenience functions to create configurations for training and inference.""" + +from typing import Any, Dict, List, Literal, Optional, Tuple, Union + +from albumentations import Compose + +from .algorithm_model import AlgorithmModel +from .architectures import UNetModel +from .configuration_model import Configuration +from .data_model import DataModel +from .inference_model import InferenceModel +from .support import ( + SupportedAlgorithm, + SupportedArchitecture, + SupportedLoss, + SupportedPixelManipulation, + SupportedTransform, +) +from .training_model import TrainingModel + + +def create_n2n_configuration( + experiment_name: str, + data_type: Literal["array", "tiff", "custom"], + axes: str, + patch_size: List[int], + batch_size: int, + num_epochs: int, + use_augmentations: bool = True, + use_n2v2: bool = False, + n_channels: int = 1, + logger: Literal["wandb", "tensorboard", "none"] = "none", + model_kwargs: Optional[dict] = None, +) -> Configuration: + """ + Create a configuration for training N2V. 
+ + If "Z" is present in `axes`, then `patch_size` must be a list of length 3, otherwise + 2. + + By setting `use_augmentations` to False, the only transformation applied will be + normalization and N2V manipulation. + + The parameter `use_n2v2` overrides the corresponding `n2v2` that can be passed + in `model_kwargs`. + + If you pass "horizontal" or "vertical" to `struct_n2v_axis`, then structN2V mask + will be applied to each manipulated pixel. + + Parameters + ---------- + experiment_name : str + Name of the experiment. + data_type : Literal["array", "tiff", "custom"] + Type of the data. + axes : str + Axes of the data (e.g. SYX). + patch_size : List[int] + Size of the patches along the spatial dimensions (e.g. [64, 64]). + batch_size : int + Batch size. + num_epochs : int + Number of epochs. + use_augmentations : bool, optional + Whether to use augmentations, by default True. + use_n2v2 : bool, optional + Whether to use N2V2, by default False. + n_channels : int, optional + Number of channels (in and out), by default 1. + roi_size : int, optional + N2V pixel manipulation area, by default 11. + masked_pixel_percentage : float, optional + Percentage of pixels masked in each patch, by default 0.2. + struct_n2v_axis : Literal["horizontal", "vertical", "none"], optional + Axis along which to apply structN2V mask, by default "none". + struct_n2v_span : int, optional + Span of the structN2V mask, by default 5. + logger : Literal["wandb", "tensorboard", "none"], optional + Logger to use, by default "none". + model_kwargs : dict, optional + UNetModel parameters, by default {}. + + Returns + ------- + Configuration + Configuration for training N2V. 
+ """ + # model + if model_kwargs is None: + model_kwargs = {} + model_kwargs["n2v2"] = use_n2v2 + model_kwargs["conv_dims"] = 3 if "Z" in axes else 2 + model_kwargs["in_channels"] = n_channels + model_kwargs["num_classes"] = n_channels + + unet_model = UNetModel( + architecture=SupportedArchitecture.UNET.value, + **model_kwargs, + ) + + # algorithm model + algorithm = AlgorithmModel( + algorithm=SupportedAlgorithm.N2V.value, + loss=SupportedLoss.N2V.value, + model=unet_model, + ) + + # augmentations + if use_augmentations: + transforms: List[Dict[str, Any]] = [ + { + "name": SupportedTransform.NORMALIZE.value, + }, + { + "name": SupportedTransform.NDFLIP.value, + }, + { + "name": SupportedTransform.XY_RANDOM_ROTATE90.value, + }, + ] + else: + transforms = [ + { + "name": SupportedTransform.NORMALIZE.value, + }, + ] + + # data model + data = DataModel( + data_type=data_type, + axes=axes, + patch_size=patch_size, + batch_size=batch_size, + transforms=transforms, + ) + + # training model + training = TrainingModel( + num_epochs=num_epochs, + batch_size=batch_size, + logger=None if logger == "none" else logger, + ) + + # create configuration + configuration = Configuration( + experiment_name=experiment_name, + algorithm_config=algorithm, + data_config=data, + training_config=training, + ) + + return configuration + + +def create_n2v_configuration( + experiment_name: str, + data_type: Literal["array", "tiff", "custom"], + axes: str, + patch_size: List[int], + batch_size: int, + num_epochs: int, + use_augmentations: bool = True, + use_n2v2: bool = False, + n_channels: int = -1, + roi_size: int = 11, + masked_pixel_percentage: float = 0.2, + struct_n2v_axis: Literal["horizontal", "vertical", "none"] = "none", + struct_n2v_span: int = 5, + logger: Literal["wandb", "tensorboard", "none"] = "none", + model_kwargs: Optional[dict] = None, +) -> Configuration: + """ + Create a configuration for training N2V. 
+ + N2V uses a UNet model to denoise images in a self-supervised manner. To use its + variants structN2V and N2V2, set the `struct_n2v_axis` and `struct_n2v_span` + (structN2V) parameters, or set `use_n2v2` to True (N2V2). + + N2V2 modifies the UNet architecture by adding blur pool layers and removes the skip + connections, thus removing checkerboard artefacts. StructN2V is used when vertical + or horizontal correlations are present in the noise; it applies an additional mask + to the manipulated pixel neighbors. + + If "C" is present in `axes`, then you need to set `n_channels` to the number of + channels. + + If "Z" is present in `axes`, then `patch_size` must be a list of length 3, otherwise + 2. + + By setting `use_augmentations` to False, the only transformations applied will be + normalization and N2V manipulation. + + The `roi_size` parameter specifies the size of the area around each pixel that will + be manipulated by N2V. The `masked_pixel_percentage` parameter specifies how many + pixels per patch will be manipulated. + + The parameters of the UNet can be specified in the `model_kwargs` (passed as a + parameter-value dictionary). Note that `use_n2v2` and `n_channels` override the + corresponding parameters passed in `model_kwargs`. + + If you pass "horizontal" or "vertical" to `struct_n2v_axis`, then structN2V mask + will be applied to each manipulated pixel. + + Parameters + ---------- + experiment_name : str + Name of the experiment. + data_type : Literal["array", "tiff", "custom"] + Type of the data. + axes : str + Axes of the data (e.g. SYX). + patch_size : List[int] + Size of the patches along the spatial dimensions (e.g. [64, 64]). + batch_size : int + Batch size. + num_epochs : int + Number of epochs. + use_augmentations : bool, optional + Whether to use augmentations, by default True. + use_n2v2 : bool, optional + Whether to use N2V2, by default False. + n_channels : int, optional + Number of channels (in and out), by default -1. 
+ roi_size : int, optional + N2V pixel manipulation area, by default 11. + masked_pixel_percentage : float, optional + Percentage of pixels masked in each patch, by default 0.2. + struct_n2v_axis : Literal["horizontal", "vertical", "none"], optional + Axis along which to apply structN2V mask, by default "none". + struct_n2v_span : int, optional + Span of the structN2V mask, by default 5. + logger : Literal["wandb", "tensorboard", "none"], optional + Logger to use, by default "none". + model_kwargs : dict, optional + UNetModel parameters, by default {}. + + Returns + ------- + Configuration + Configuration for training N2V. + + Examples + -------- + Minimum example: + >>> config = create_n2v_configuration( + ... experiment_name="n2v_experiment", + ... data_type="array", + ... axes="YX", + ... patch_size=[64, 64], + ... batch_size=32, + ... num_epochs=100 + ... ) + + To use N2V2, simply pass the `use_n2v2` parameter: + >>> config = create_n2v_configuration( + ... experiment_name="n2v2_experiment", + ... data_type="tiff", + ... axes="YX", + ... patch_size=[64, 64], + ... batch_size=32, + ... num_epochs=100, + ... use_n2v2=True + ... ) + + For structN2V, there are two parameters to set, `struct_n2v_axis` and + `struct_n2v_span`: + >>> config = create_n2v_configuration( + ... experiment_name="structn2v_experiment", + ... data_type="tiff", + ... axes="YX", + ... patch_size=[64, 64], + ... batch_size=32, + ... num_epochs=100, + ... struct_n2v_axis="horizontal", + ... struct_n2v_span=7 + ... ) + + If you are training multiple channels together, then you need to specify the number + of channels: + >>> config = create_n2v_configuration( + ... experiment_name="n2v_experiment", + ... data_type="array", + ... axes="YXC", + ... patch_size=[64, 64], + ... batch_size=32, + ... num_epochs=100, + ... n_channels=3 + ... ) + + To turn off the augmentations, except normalization and N2V manipulation, use the + relevant keyword argument: + >>> config = create_n2v_configuration( + ... 
experiment_name="n2v_experiment", + ... data_type="array", + ... axes="YX", + ... patch_size=[64, 64], + ... batch_size=32, + ... num_epochs=100, + ... use_augmentations=False + ... ) + """ + # if there are channels, we need to specify their number + if "C" in axes and n_channels == -1: + raise ValueError( + f"Number of channels must be specified when using channels " + f"(got {n_channels} channel)." + ) + elif "C" not in axes and n_channels != -1: + raise ValueError( + f"C is not present in the axes, but number of channels is specified " + f"(got {n_channels} channel)." + ) + elif n_channels == -1: + n_channels = 1 + + # model + if model_kwargs is None: + model_kwargs = {} + model_kwargs["n2v2"] = use_n2v2 + model_kwargs["conv_dims"] = 3 if "Z" in axes else 2 + model_kwargs["in_channels"] = n_channels + model_kwargs["num_classes"] = n_channels + + unet_model = UNetModel( + architecture=SupportedArchitecture.UNET.value, + **model_kwargs, + ) + + # algorithm model + algorithm = AlgorithmModel( + algorithm=SupportedAlgorithm.N2V.value, + loss=SupportedLoss.N2V.value, + model=unet_model, + ) + + # augmentations + if use_augmentations: + transforms: List[Dict[str, Any]] = [ + { + "name": SupportedTransform.NORMALIZE.value, + }, + { + "name": SupportedTransform.NDFLIP.value, + }, + { + "name": SupportedTransform.XY_RANDOM_ROTATE90.value, + }, + ] + else: + transforms = [ + { + "name": SupportedTransform.NORMALIZE.value, + }, + ] + + # n2v2 and structn2v + nv2_transform = { + "name": SupportedTransform.N2V_MANIPULATE.value, + "strategy": SupportedPixelManipulation.MEDIAN.value + if use_n2v2 + else SupportedPixelManipulation.UNIFORM.value, + "roi_size": roi_size, + "masked_pixel_percentage": masked_pixel_percentage, + "struct_mask_axis": struct_n2v_axis, + "struct_mask_span": struct_n2v_span, + } + transforms.append(nv2_transform) + + # data model + data = DataModel( + data_type=data_type, + axes=axes, + patch_size=patch_size, + batch_size=batch_size, + 
transforms=transforms, + ) + + # training model + training = TrainingModel( + num_epochs=num_epochs, + batch_size=batch_size, + logger=None if logger == "none" else logger, + ) + + # create configuration + configuration = Configuration( + experiment_name=experiment_name, + algorithm_config=algorithm, + data_config=data, + training_config=training, + ) + + return configuration + + +# TODO add tests +def create_inference_configuration( + training_configuration: Configuration, + tile_size: Optional[Tuple[int, ...]] = None, + tile_overlap: Optional[Tuple[int, ...]] = None, + data_type: Optional[Literal["array", "tiff", "custom"]] = None, + axes: Optional[str] = None, + transforms: Optional[Union[List[Dict[str, Any]], Compose]] = None, + tta_transforms: bool = True, + batch_size: Optional[int] = 1, +) -> InferenceModel: + """ + Create a configuration for inference with N2V. + + If not provided, `data_type` and `axes` are taken from the training + configuration. If `transforms` are not provided, only normalization is applied. + + Parameters + ---------- + training_configuration : Configuration + Configuration used for training. + tile_size : Tuple[int, ...], optional + Size of the tiles. + tile_overlap : Tuple[int, ...], optional + Overlap of the tiles. + data_type : str, optional + Type of the data, by default "tiff". + axes : str, optional + Axes of the data, by default "YX". + transforms : List[Dict[str, Any]] or Compose, optional + Transformations to apply to the data, by default None. + tta_transforms : bool, optional + Whether to apply test-time augmentations, by default True. + batch_size : int, optional + Batch size, by default 1. + + Returns + ------- + InferenceConfiguration + Configuration for inference with N2V. 
+ """ + if ( + training_configuration.data_config.mean is None + or training_configuration.data_config.std is None + ): + raise ValueError("Mean and std must be provided in the training configuration.") + + if transforms is None: + transforms = [ + { + "name": SupportedTransform.NORMALIZE.value, + }, + ] + + return InferenceModel( + data_type=data_type or training_configuration.data_config.data_type, + tile_size=tile_size, + tile_overlap=tile_overlap, + axes=axes or training_configuration.data_config.axes, + mean=training_configuration.data_config.mean, + std=training_configuration.data_config.std, + transforms=transforms, + tta_transforms=tta_transforms, + batch_size=batch_size, + ) diff --git a/src/careamics/config/configuration_model.py b/src/careamics/config/configuration_model.py new file mode 100644 index 00000000..306bcd53 --- /dev/null +++ b/src/careamics/config/configuration_model.py @@ -0,0 +1,596 @@ +"""Pydantic CAREamics configuration.""" +from __future__ import annotations + +import re +from pathlib import Path +from pprint import pformat +from typing import Dict, List, Literal, Union + +import yaml +from bioimageio.spec.generic.v0_3 import CiteEntry +from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator +from typing_extensions import Self + +from .algorithm_model import AlgorithmModel +from .data_model import DataModel +from .references import ( + CARE, + CUSTOM, + N2N, + N2V, + N2V2, + STRUCT_N2V, + STRUCT_N2V2, + CAREDescription, + CARERef, + N2NDescription, + N2NRef, + N2V2Description, + N2V2Ref, + N2VDescription, + N2VRef, + StructN2V2Description, + StructN2VDescription, + StructN2VRef, +) +from .support import SupportedAlgorithm, SupportedPixelManipulation, SupportedTransform +from .training_model import TrainingModel +from .transformations.n2v_manipulate_model import ( + N2VManipulateModel, +) + + +class Configuration(BaseModel): + """ + CAREamics configuration. 
+ + The configuration defines all parameters used to build and train a CAREamics model. + These parameters are validated to ensure that they are compatible with each other. + + It contains three sub-configurations: + + - AlgorithmModel: configuration for the algorithm training, which includes the + architecture, loss function, optimizer, and other hyperparameters. + - DataModel: configuration for the dataloader, which includes the type of data, + transformations, mean/std and other parameters. + - TrainingModel: configuration for the training, which includes the number of + epochs or the callbacks. + + Attributes + ---------- + experiment_name : str + Name of the experiment, used when saving logs and checkpoints. + algorithm_config : AlgorithmModel + Algorithm configuration. + data_config : DataModel + Data configuration. + training_config : TrainingModel + Training configuration. + + Methods + ------- + set_3D(is_3D: bool, axes: str, patch_size: List[int]) -> None + Switch configuration between 2D and 3D. + set_N2V2(use_n2v2: bool) -> None + Switch N2V algorithm between N2V and N2V2. + set_structN2V( + mask_axis: Literal["horizontal", "vertical", "none"], mask_span: int) -> None + Set StructN2V parameters. + model_dump( + exclude_defaults: bool = False, exclude_none: bool = True, **kwargs: Dict + ) -> Dict + Export configuration to a dictionary. + + Raises + ------ + ValueError + Configuration parameter type validation errors. + ValueError + If the experiment name contains invalid characters or is empty. + ValueError + If the algorithm is 3D but there is no "Z" in the data axes, or 2D algorithm + with "Z" in data axes. + ValueError + Algorithm, data or training validation errors. + + Notes + ----- + We provide convenience methods to create standard configurations, for instance + for N2V, in the `careamics.config.configuration_factory` module. + >>> from careamics.config.configuration_factory import create_n2v_configuration + >>> config = create_n2v_configuration( + ... 
experiment_name="n2v_experiment", + ... data_type="array", + ... axes="YX", + ... patch_size=[64, 64], + ... batch_size=32, + ... num_epochs=100 + ... ) + + The configuration can be exported to a dictionary using the model_dump method: + >>> config_dict = config.model_dump() + + Configurations can also be exported or imported from yaml files: + >>> from careamics.config import save_configuration, load_configuration + >>> path_to_config = save_configuration(config, my_path / "config.yml") + >>> other_config = load_configuration(path_to_config) + + Examples + -------- + Minimum example: + >>> from careamics.config import Configuration + >>> config_dict = { + ... "experiment_name": "N2V_experiment", + ... "algorithm_config": { + ... "algorithm": "n2v", + ... "loss": "n2v", + ... "model": { + ... "architecture": "UNet", + ... }, + ... }, + ... "training_config": { + ... "num_epochs": 200, + ... }, + ... "data_config": { + ... "data_type": "tiff", + ... "patch_size": [64, 64], + ... "axes": "SYX", + ... }, + ... } + >>> config = Configuration(**config_dict) + """ + + model_config = ConfigDict( + validate_assignment=True, + set_arbitrary_types_allowed=True, + ) + + # version + version: Literal["0.1.0"] = Field( + default="0.1.0", description="Version of the CAREamics configuration." + ) + + # required parameters + experiment_name: str = Field( + ..., description="Name of the experiment, used to name logs and checkpoints." + ) + + # Sub-configurations + algorithm_config: AlgorithmModel + data_config: DataModel + training_config: TrainingModel + + @field_validator("experiment_name") + @classmethod + def no_symbol(cls, name: str) -> str: + """ + Validate experiment name. + + A valid experiment name is a non-empty string with only contains letters, + numbers, underscores, dashes and spaces. + + Parameters + ---------- + name : str + Name to validate. + + Returns + ------- + str + Validated name. 
+ + Raises + ------ + ValueError + If the name is empty or contains invalid characters. + """ + if len(name) == 0 or name.isspace(): + raise ValueError("Experiment name is empty.") + + # Validate using a regex that it contains only letters, numbers, underscores, + # dashes and spaces + if not re.match(r"^[a-zA-Z0-9_\- ]*$", name): + raise ValueError( + f"Experiment name contains invalid characters (got {name}). " + f"Only letters, numbers, underscores, dashes and spaces are allowed." + ) + + return name + + @model_validator(mode="after") + def validate_3D(self: Self) -> Self: + """ + Change algorithm dimensions to match data.axes. + + Only for non-custom algorithms. + + Returns + ------- + Self + Validated configuration. + """ + if self.algorithm_config.algorithm != SupportedAlgorithm.CUSTOM: + if "Z" in self.data_config.axes and not self.algorithm_config.model.is_3D(): + # change algorithm to 3D + self.algorithm_config.model.set_3D(True) + elif ( + "Z" not in self.data_config.axes and self.algorithm_config.model.is_3D() + ): + # change algorithm to 2D + self.algorithm_config.model.set_3D(False) + + return self + + @model_validator(mode="after") + def validate_algorithm_and_data(self: Self) -> Self: + """ + Validate algorithm and data compatibility. + + In particular, the validation does the following: + + - If N2V is used, it enforces the presence of N2V_Maniuplate in the transforms + - If N2V2 is used, it enforces the correct manipulation strategy + + Returns + ------- + Self + Validated configuration. 
+ """ + if self.algorithm_config.algorithm == SupportedAlgorithm.N2V: + # if we have a list of transform (as opposed to Compose) + if self.data_config.has_transform_list(): + # missing N2V_MANIPULATE + if not self.data_config.has_n2v_manipulate(): + self.data_config.transforms.append( + N2VManipulateModel( + name=SupportedTransform.N2V_MANIPULATE.value, + ) + ) + + median = SupportedPixelManipulation.MEDIAN.value + uniform = SupportedPixelManipulation.UNIFORM.value + strategy = median if self.algorithm_config.model.n2v2 else uniform + self.data_config.set_N2V2_strategy(strategy) + else: + # if we have a list of transform, remove N2V manipulate if present + if self.data_config.has_transform_list(): + if self.data_config.has_n2v_manipulate(): + self.data_config.remove_n2v_manipulate() + + return self + + def __str__(self) -> str: + """ + Pretty string reprensenting the configuration. + + Returns + ------- + str + Pretty string. + """ + return pformat(self.model_dump()) + + def set_3D(self, is_3D: bool, axes: str, patch_size: List[int]) -> None: + """ + Set 3D flag and axes. + + Parameters + ---------- + is_3D : bool + Whether the algorithm is 3D or not. + axes : str + Axes of the data. + patch_size : List[int] + Patch size. + """ + # set the flag and axes (this will not trigger validation at the config level) + self.algorithm_config.model.set_3D(is_3D) + self.data_config.set_3D(axes, patch_size) + + # cheap hack: trigger validation + self.algorithm_config = self.algorithm_config + + def set_N2V2(self, use_n2v2: bool) -> None: + """ + Switch N2V algorithm between N2V and N2V2. + + Parameters + ---------- + use_n2v2 : bool + Whether to use N2V2 or not. + + Raises + ------ + ValueError + If the algorithm is not N2V. 
+ """ + if self.algorithm_config.algorithm == SupportedAlgorithm.N2V: + self.algorithm_config.model.n2v2 = use_n2v2 + strategy = ( + SupportedPixelManipulation.MEDIAN.value + if use_n2v2 + else SupportedPixelManipulation.UNIFORM.value + ) + self.data_config.set_N2V2_strategy(strategy) + else: + raise ValueError("N2V2 can only be set for N2V algorithm.") + + def set_structN2V( + self, mask_axis: Literal["horizontal", "vertical", "none"], mask_span: int + ) -> None: + """ + Set StructN2V parameters. + + Parameters + ---------- + mask_axis : Literal["horizontal", "vertical", "none"] + Axis of the structural mask. + mask_span : int + Span of the structural mask. + """ + self.data_config.set_structN2V_mask(mask_axis, mask_span) + + def get_algorithm_flavour(self) -> str: + """ + Get the algorithm name. + + Returns + ------- + str + Algorithm name. + """ + if self.algorithm_config.algorithm == SupportedAlgorithm.N2V: + use_n2v2 = self.algorithm_config.model.n2v2 + use_structN2V = self.data_config.transforms[-1].struct_mask_axis != "none" + + # return the n2v flavour + if use_n2v2 and use_structN2V: + return STRUCT_N2V2 + elif use_n2v2: + return N2V2 + elif use_structN2V: + return STRUCT_N2V + else: + return N2V + elif self.algorithm_config.algorithm == SupportedAlgorithm.N2N: + return N2N + elif self.algorithm_config.algorithm == SupportedAlgorithm.CARE: + return CARE + else: + return CUSTOM + + def get_algorithm_description(self) -> str: + """ + Return a description of the algorithm. + + This method is used to generate the README of the BioImage Model Zoo export. + + Returns + ------- + str + Description of the algorithm. 
+ """ + algorithm_flavour = self.get_algorithm_flavour() + + if algorithm_flavour == CUSTOM: + return f"Custom algorithm, named {self.algorithm_config.model.name}" + else: # currently only N2V flavours + if algorithm_flavour == N2V: + return N2VDescription().description + elif algorithm_flavour == N2V2: + return N2V2Description().description + elif algorithm_flavour == STRUCT_N2V: + return StructN2VDescription().description + elif algorithm_flavour == STRUCT_N2V2: + return StructN2V2Description().description + elif algorithm_flavour == N2N: + return N2NDescription().description + elif algorithm_flavour == CARE: + return CAREDescription().description + + return "" + + def get_algorithm_citations(self) -> List[CiteEntry]: + """ + Return a list of citation entries of the current algorithm. + + This is used to generate the model description for the BioImage Model Zoo. + + Returns + ------- + List[CiteEntry] + List of citation entries. + """ + if self.algorithm_config.algorithm == SupportedAlgorithm.N2V: + use_n2v2 = self.algorithm_config.model.n2v2 + use_structN2V = self.data_config.transforms[-1].struct_mask_axis != "none" + + # return the (struct)N2V(2) references + if use_n2v2 and use_structN2V: + return [N2VRef, N2V2Ref, StructN2VRef] + elif use_n2v2: + return [N2VRef, N2V2Ref] + elif use_structN2V: + return [N2VRef, StructN2VRef] + else: + return [N2VRef] + elif self.algorithm_config.algorithm == SupportedAlgorithm.N2N: + return [N2NRef] + elif self.algorithm_config.algorithm == SupportedAlgorithm.CARE: + return [CARERef] + + raise ValueError("Citation not available for custom algorithm.") + + def get_algorithm_references(self) -> str: + """ + Get the algorithm references. + + This is used to generate the README of the BioImage Model Zoo export. + + Returns + ------- + str + Algorithm references. 
+ """ + if self.algorithm_config.algorithm == SupportedAlgorithm.N2V: + use_n2v2 = self.algorithm_config.model.n2v2 + use_structN2V = self.data_config.transforms[-1].struct_mask_axis != "none" + + references = [ + N2VRef.text + " doi: " + N2VRef.doi, + N2V2Ref.text + " doi: " + N2V2Ref.doi, + StructN2VRef.text + " doi: " + StructN2VRef.doi, + ] + + # return the (struct)N2V(2) references + if use_n2v2 and use_structN2V: + return "".join(references) + elif use_n2v2: + references.pop(-1) + return "".join(references) + elif use_structN2V: + references.pop(-2) + return "".join(references) + else: + return references[0] + + return "" + + def get_algorithm_keywords(self) -> List[str]: + """ + Get algorithm keywords. + + Returns + ------- + List[str] + List of keywords. + """ + if self.algorithm_config.algorithm == SupportedAlgorithm.N2V: + use_n2v2 = self.algorithm_config.model.n2v2 + use_structN2V = self.data_config.transforms[-1].struct_mask_axis != "none" + + keywords = [ + "denoising", + "restoration", + "UNet", + "3D" if "Z" in self.data_config.axes else "2D", + "CAREamics", + "pytorch", + N2V, + ] + + if use_n2v2: + keywords.append(N2V2) + if use_structN2V: + keywords.append(STRUCT_N2V) + else: + keywords = ["CAREamics"] + + return keywords + + def model_dump( + self, + exclude_defaults: bool = False, + exclude_none: bool = True, + **kwargs: Dict, + ) -> Dict: + """ + Override model_dump method in order to set default values. + + Parameters + ---------- + exclude_defaults : bool, optional + Whether to exclude fields with default values or not, by default + True. + exclude_none : bool, optional + Whether to exclude fields with None values or not, by default True. + **kwargs : Dict + Keyword arguments. + + Returns + ------- + dict + Dictionary containing the model parameters. 
+ """ + dictionary = super().model_dump( + exclude_none=exclude_none, exclude_defaults=exclude_defaults, **kwargs + ) + + return dictionary + + +def load_configuration(path: Union[str, Path]) -> Configuration: + """ + Load configuration from a yaml file. + + Parameters + ---------- + path : Union[str, Path] + Path to the configuration. + + Returns + ------- + Configuration + Configuration. + + Raises + ------ + FileNotFoundError + If the configuration file does not exist. + """ + # load dictionary from yaml + if not Path(path).exists(): + raise FileNotFoundError( + f"Configuration file {path} does not exist in " f" {Path.cwd()!s}" + ) + + dictionary = yaml.load(Path(path).open("r"), Loader=yaml.SafeLoader) + + return Configuration(**dictionary) + + +def save_configuration(config: Configuration, path: Union[str, Path]) -> Path: + """ + Save configuration to path. + + Parameters + ---------- + config : Configuration + Configuration to save. + path : Union[str, Path] + Path to a existing folder in which to save the configuration or to an existing + configuration file. + + Returns + ------- + Path + Path object representing the configuration. + + Raises + ------ + ValueError + If the path does not point to an existing directory or .yml file. + """ + # make sure path is a Path object + config_path = Path(path) + + # check if path is pointing to an existing directory or .yml file + if config_path.exists(): + if config_path.is_dir(): + config_path = Path(config_path, "config.yml") + elif config_path.suffix != ".yml" and config_path.suffix != ".yaml": + raise ValueError( + f"Path must be a directory or .yml or .yaml file (got {config_path})." + ) + else: + if config_path.suffix != ".yml" and config_path.suffix != ".yaml": + raise ValueError( + f"Path must be a directory or .yml or .yaml file (got {config_path})." 
+ ) + + # save configuration as dictionary to yaml + with open(config_path, "w") as f: + # dump configuration + yaml.dump(config.model_dump(), f, default_flow_style=False) + + return config_path diff --git a/src/careamics/config/data.py b/src/careamics/config/data.py deleted file mode 100644 index bcf0da23..00000000 --- a/src/careamics/config/data.py +++ /dev/null @@ -1,194 +0,0 @@ -"""Data configuration.""" -from __future__ import annotations - -from enum import Enum -from typing import Dict, List, Optional - -from pydantic import ( - BaseModel, - ConfigDict, - Field, - field_validator, - model_validator, -) - -from careamics.utils import check_axes_validity - - -class SupportedExtension(str, Enum): - """ - Supported extensions for input data. - - Currently supported: - - tif/tiff: .tiff files. - """ - - TIFF = "tiff" # TODO these should be a single one - TIF = "tif" - - @classmethod - def _missing_(cls, value: object) -> str: - """ - Override default behaviour for missing values. - - This method is called when `value` is not found in the enum values. It converts - `value` to lowercase, removes "." if it is the first character and tries to - match it with enum values. - - Parameters - ---------- - value : object - Value to be matched with enum values. - - Returns - ------- - str - Matched enum value. - """ - if isinstance(value, str): - lower_value = value.lower() - - if lower_value.startswith("."): - lower_value = lower_value[1:] - - # attempt to match lowercase value with enum values - for member in cls: - if member.value == lower_value: - return member - - # still missing - return super()._missing_(value) - - -class Data(BaseModel): - """ - Data configuration. - - If std is specified, mean must be specified as well. Note that setting the std first - and then the mean (if they were both `None` before) will raise a validation error. 
- Prefer instead the following: - >>> set_mean_and_std(mean, std) - - Attributes - ---------- - in_memory : bool - Whether to load the data in memory or not. - data_format : SupportedExtension - Extension of the data, without period. - axes : str - Axes of the data. - mean: Optional[float] - Expected data mean. - std: Optional[float] - Expected data standard deviation. - """ - - # Pydantic class configuration - model_config = ConfigDict( - use_enum_values=True, - validate_assignment=True, - ) - - # Mandatory fields - in_memory: bool - data_format: SupportedExtension - axes: str - - # Optional fields - mean: Optional[float] = Field(default=None, ge=0) - std: Optional[float] = Field(default=None, gt=0) - - def set_mean_and_std(self, mean: float, std: float) -> None: - """ - Set mean and standard deviation of the data. - - This method is preferred to setting the fields directly, as it ensures that the - mean is set first, then the std; thus avoiding a validation error to be thrown. - - Parameters - ---------- - mean : float - Mean of the data. - std : float - Standard deviation of the data. - """ - self.mean = mean - self.std = std - - @field_validator("axes") - def valid_axes(cls, axes: str) -> str: - """ - Validate axes. - - Axes must be a subset of STZYX, must contain YX, be in the right order - and not contain both S and T. - - Parameters - ---------- - axes : str - Axes of the training data. - - Returns - ------- - str - Validated axes of the training data. - - Raises - ------ - ValueError - If axes are not valid. - """ - # validate axes - check_axes_validity(axes) - - return axes - - @model_validator(mode="after") - def std_only_with_mean(cls, data_model: Data) -> Data: - """ - Check that mean and std are either both None, or both specified. - - If we enforce both None or both specified, we cannot set the values one by one - due to the ConfDict enforcing the validation on assignment. Therefore, we check - only when the std is not None and the mean is None. 
- - Parameters - ---------- - data_model : Data - Data model. - - Returns - ------- - Data - Validated data model. - - Raises - ------ - ValueError - If std is not None and mean is None. - """ - if data_model.std is not None and data_model.mean is None: - raise ValueError("Cannot have std non None if mean is None.") - - return data_model - - def model_dump(self, *args: List, **kwargs: Dict) -> dict: - """ - Override model_dump method. - - The purpose is to ensure export smooth import to yaml. It includes: - - remove entries with None value. - - Parameters - ---------- - *args : List - Positional arguments, unused. - **kwargs : Dict - Keyword arguments, unused. - - Returns - ------- - dict - Dictionary containing the model parameters. - """ - return super().model_dump(exclude_none=True) diff --git a/src/careamics/config/data_model.py b/src/careamics/config/data_model.py new file mode 100644 index 00000000..25a6b3e8 --- /dev/null +++ b/src/careamics/config/data_model.py @@ -0,0 +1,555 @@ +"""Data configuration.""" +from __future__ import annotations + +from pprint import pformat +from typing import Any, List, Literal, Optional, Union + +from albumentations import Compose +from pydantic import ( + BaseModel, + ConfigDict, + Discriminator, + Field, + field_validator, + model_validator, +) +from typing_extensions import Annotated, Self + +from .support import SupportedTransform +from .transformations.n2v_manipulate_model import N2VManipulateModel +from .transformations.nd_flip_model import NDFlipModel +from .transformations.normalize_model import NormalizeModel +from .transformations.xy_random_rotate90_model import XYRandomRotate90Model +from .validators import check_axes_validity, patch_size_ge_than_8_power_of_2 + +TRANSFORMS_UNION = Annotated[ + Union[ + NDFlipModel, + XYRandomRotate90Model, + NormalizeModel, + N2VManipulateModel, + ], + Discriminator("name"), # used to tell the different transform models apart +] + + +class DataModel(BaseModel): + """ + Data 
configuration. + + If std is specified, mean must be specified as well. Note that setting the std first + and then the mean (if they were both `None` before) will raise a validation error. + Prefer instead `set_mean_and_std` to set both at once. + + Examples + -------- + Minimum example: + + >>> data = DataModel( + ... data_type="array", # defined in SupportedData + ... patch_size=[128, 128], + ... batch_size=4, + ... axes="YX" + ... ) + + To change the mean and std of the data: + >>> data.set_mean_and_std(mean=214.3, std=84.5) + + One can pass also a list of transformations, by keyword, using the + SupportedTransform or the name of an Albumentation transform: + >>> from careamics.config.support import SupportedTransform + >>> data = DataModel( + ... data_type="tiff", + ... patch_size=[128, 128], + ... batch_size=4, + ... axes="YX", + ... transforms=[ + ... { + ... "name": SupportedTransform.NORMALIZE.value, + ... "mean": 167.6, + ... "std": 47.2, + ... }, + ... { + ... "name": "NDFlip", + ... "is_3D": True, + ... "flip_z": True, + ... } + ... ] + ... 
) + """ + + # Pydantic class configuration + model_config = ConfigDict( + validate_assignment=True, + arbitrary_types_allowed=True, # Allow Compose declaration + ) + + # Dataset configuration + data_type: Literal["array", "tiff", "custom"] # As defined in SupportedData + patch_size: Union[List[int]] = Field(..., min_length=2, max_length=3) + batch_size: int = Field(default=1, ge=1, validate_default=True) + axes: str + + # Optional fields + mean: Optional[float] = None + std: Optional[float] = None + + transforms: Union[List[TRANSFORMS_UNION], Compose] = Field( + default=[ + { + "name": SupportedTransform.NORMALIZE.value, + }, + { + "name": SupportedTransform.NDFLIP.value, + }, + { + "name": SupportedTransform.XY_RANDOM_ROTATE90.value, + }, + { + "name": SupportedTransform.N2V_MANIPULATE.value, + }, + ], + validate_default=True, + ) + + dataloader_params: Optional[dict] = None + + @field_validator("patch_size") + @classmethod + def all_elements_power_of_2_minimum_8( + cls, patch_list: Union[List[int]] + ) -> Union[List[int]]: + """ + Validate patch size. + + Patch size must be powers of 2 and minimum 8. + + Parameters + ---------- + patch_list : Union[List[int]] + Patch size. + + Returns + ------- + Union[List[int]] + Validated patch size. + + Raises + ------ + ValueError + If the patch size is smaller than 8. + ValueError + If the patch size is not a power of 2. + """ + patch_size_ge_than_8_power_of_2(patch_list) + + return patch_list + + @field_validator("axes") + @classmethod + def axes_valid(cls, axes: str) -> str: + """ + Validate axes. + + Axes must: + - be a combination of 'STCZYX' + - not contain duplicates + - contain at least 2 contiguous axes: X and Y + - contain at most 4 axes + - not contain both S and T axes + + Parameters + ---------- + axes : str + Axes to validate. + + Returns + ------- + str + Validated axes. + + Raises + ------ + ValueError + If axes are not valid. 
+ """ + # Validate axes + check_axes_validity(axes) + + return axes + + @field_validator("transforms") + @classmethod + def validate_prediction_transforms( + cls, transforms: Union[List[TRANSFORMS_UNION], Compose] + ) -> Union[List[TRANSFORMS_UNION], Compose]: + """ + Validate N2VManipulate transform position in the transform list. + + Parameters + ---------- + transforms : Union[List[Transformations_Union], Compose] + Transforms. + + Returns + ------- + Union[List[Transformations_Union], Compose] + Validated transforms. + + Raises + ------ + ValueError + If multiple instances of N2VManipulate are found. + """ + if not isinstance(transforms, Compose): + transform_list = [t.name for t in transforms] + + if SupportedTransform.N2V_MANIPULATE in transform_list: + # multiple N2V_MANIPULATE + if transform_list.count(SupportedTransform.N2V_MANIPULATE) > 1: + raise ValueError( + f"Multiple instances of " + f"{SupportedTransform.N2V_MANIPULATE} transforms " + f"are not allowed." + ) + + # N2V_MANIPULATE not the last transform + elif transform_list[-1] != SupportedTransform.N2V_MANIPULATE: + index = transform_list.index(SupportedTransform.N2V_MANIPULATE) + transform = transforms.pop(index) + transforms.append(transform) + + return transforms + + @model_validator(mode="after") + def std_only_with_mean(self: Self) -> Self: + """ + Check that mean and std are either both None, or both specified. + + Returns + ------- + Self + Validated data model. + + Raises + ------ + ValueError + If std is not None and mean is None. + """ + # check that mean and std are either both None, or both specified + if (self.mean is None) != (self.std is None): + raise ValueError( + "Mean and std must be either both None, or both specified." + ) + + return self + + @model_validator(mode="after") + def add_std_and_mean_to_normalize(self: Self) -> Self: + """ + Add mean and std to the Normalize transform if it is present. 
+ + Returns + ------- + Self + Data model with mean and std added to the Normalize transform. + """ + if self.mean is not None or self.std is not None: + # search in the transforms for Normalize and update parameters + if self.has_transform_list(): + for transform in self.transforms: + if transform.name == SupportedTransform.NORMALIZE.value: + transform.mean = self.mean + transform.std = self.std + + return self + + @model_validator(mode="after") + def validate_dimensions(self: Self) -> Self: + """ + Validate 2D/3D dimensions between axes, patch size and transforms. + + Returns + ------- + Self + Validated data model. + + Raises + ------ + ValueError + If the transforms are not valid. + """ + if "Z" in self.axes: + if len(self.patch_size) != 3: + raise ValueError( + f"Patch size must have 3 dimensions if the data is 3D " + f"({self.axes})." + ) + + if self.has_transform_list(): + for transform in self.transforms: + if transform.name == SupportedTransform.NDFLIP: + transform.is_3D = True + elif transform.name == SupportedTransform.XY_RANDOM_ROTATE90: + transform.is_3D = True + + else: + if len(self.patch_size) != 2: + raise ValueError( + f"Patch size must have 3 dimensions if the data is 3D " + f"({self.axes})." + ) + + if self.has_transform_list(): + for transform in self.transforms: + if transform.name == SupportedTransform.NDFLIP: + transform.is_3D = False + elif transform.name == SupportedTransform.XY_RANDOM_ROTATE90: + transform.is_3D = False + + return self + + def __str__(self) -> str: + """ + Pretty string reprensenting the configuration. + + Returns + ------- + str + Pretty string. + """ + return pformat(self.model_dump()) + + def _update(self, **kwargs: Any) -> None: + """ + Update multiple arguments at once. + + Parameters + ---------- + **kwargs : Any + Keyword arguments to update. 
+ """ + self.__dict__.update(kwargs) + self.__class__.model_validate(self.__dict__) + + def has_transform_list(self) -> bool: + """ + Check if the transforms are a list, as opposed to a Compose object. + + Returns + ------- + bool + True if the transforms are a list, False otherwise. + """ + return isinstance(self.transforms, list) + + def has_n2v_manipulate(self) -> bool: + """ + Check if the transforms contain N2VManipulate. + + Use `has_transform_list` to check if the transforms are a list. + + Returns + ------- + bool + True if the transforms contain N2VManipulate, False otherwise. + + Raises + ------ + ValueError + If the transforms are a Compose object. + """ + if self.has_transform_list(): + return any( + transform.name == SupportedTransform.N2V_MANIPULATE.value + for transform in self.transforms + ) + else: + raise ValueError( + "Checking for N2VManipulate with Compose transforms is not allowed. " + "Check directly in the Compose." + ) + + def add_n2v_manipulate(self) -> None: + """ + Add N2VManipulate to the transforms. + + Use `has_transform_list` to check if the transforms are a list. + + Raises + ------ + ValueError + If the transforms are a Compose object. + """ + if self.has_transform_list(): + if not self.has_n2v_manipulate(): + self.transforms.append( + N2VManipulateModel(name=SupportedTransform.N2V_MANIPULATE.value) + ) + else: + raise ValueError( + "Adding N2VManipulate with Compose transforms is not allowed. Add " + "N2VManipulate directly to the transform in the Compose." + ) + + def remove_n2v_manipulate(self) -> None: + """ + Remove N2VManipulate from the transforms. + + Use `has_transform_list` to check if the transforms are a list. + + Raises + ------ + ValueError + If the transforms are a Compose object. + """ + if self.has_transform_list() and self.has_n2v_manipulate(): + self.transforms.pop(-1) + else: + raise ValueError( + "Removing N2VManipulate with Compose transforms is not allowed. 
Remove " + "N2VManipulate directly from the transform in the Compose." + ) + + def set_mean_and_std(self, mean: float, std: float) -> None: + """ + Set mean and standard deviation of the data. + + This method should be used instead setting the fields directly, as it would + otherwise trigger a validation error. + + Parameters + ---------- + mean : float + Mean of the data. + std : float + Standard deviation of the data. + """ + self._update(mean=mean, std=std) + + # search in the transforms for Normalize and update parameters + if self.has_transform_list(): + for transform in self.transforms: + if transform.name == SupportedTransform.NORMALIZE.value: + transform.mean = mean + transform.std = std + else: + raise ValueError( + "Setting mean and std with Compose transforms is not allowed. Add " + "mean and std parameters directly to the transform in the Compose." + ) + + def set_3D(self, axes: str, patch_size: List[int]) -> None: + """ + Set 3D parameters. + + Parameters + ---------- + axes : str + Axes. + patch_size : List[int] + Patch size. + """ + self._update(axes=axes, patch_size=patch_size) + + def set_N2V2(self, use_n2v2: bool) -> None: + """ + Set N2V2. + + Parameters + ---------- + use_n2v2 : bool + Whether to use N2V2. + + Raises + ------ + ValueError + If the N2V pixel manipulate transform is not found in the transforms. + ValueError + If the transforms are a Compose object. + """ + if use_n2v2: + self.set_N2V2_strategy("median") + else: + self.set_N2V2_strategy("uniform") + + def set_N2V2_strategy(self, strategy: Literal["uniform", "median"]) -> None: + """ + Set N2V2 strategy. + + Parameters + ---------- + strategy : Literal["uniform", "median"] + Strategy to use for N2V2. + + Raises + ------ + ValueError + If the N2V pixel manipulate transform is not found in the transforms. + ValueError + If the transforms are a Compose object. 
+ """ + if isinstance(self.transforms, list): + found_n2v = False + + for transform in self.transforms: + if transform.name == SupportedTransform.N2V_MANIPULATE.value: + transform.strategy = strategy + found_n2v = True + + if not found_n2v: + transforms = [t.name for t in self.transforms] + raise ValueError( + f"N2V_Manipulate transform not found in the transforms list " + f"({transforms})." + ) + + else: + raise ValueError( + "Setting N2V2 strategy with Compose transforms is not allowed. Add " + "N2V2 strategy parameters directly to the transform in the Compose." + ) + + def set_structN2V_mask( + self, mask_axis: Literal["horizontal", "vertical", "none"], mask_span: int + ) -> None: + """ + Set structN2V mask parameters. + + Setting `mask_axis` to `none` will disable structN2V. + + Parameters + ---------- + mask_axis : Literal["horizontal", "vertical", "none"] + Axis along which to apply the mask. `none` will disable structN2V. + mask_span : int + Total span of the mask in pixels. + + Raises + ------ + ValueError + If the N2V pixel manipulate transform is not found in the transforms. + ValueError + If the transforms are a Compose object. + """ + if isinstance(self.transforms, list): + found_n2v = False + + for transform in self.transforms: + if transform.name == SupportedTransform.N2V_MANIPULATE.value: + transform.struct_mask_axis = mask_axis + transform.struct_mask_span = mask_span + found_n2v = True + + if not found_n2v: + transforms = [t.name for t in self.transforms] + raise ValueError( + f"N2V pixel manipulate transform not found in the transforms " + f"({transforms})." + ) + + else: + raise ValueError( + "Setting structN2VMask with Compose transforms is not allowed. Add " + "structN2VMask parameters directly to the transform in the Compose." 
+ ) diff --git a/src/careamics/config/inference_model.py b/src/careamics/config/inference_model.py new file mode 100644 index 00000000..16b0fc0e --- /dev/null +++ b/src/careamics/config/inference_model.py @@ -0,0 +1,283 @@ +"""Pydantic model representing CAREamics prediction configuration.""" +from __future__ import annotations + +from typing import Any, List, Literal, Optional, Union + +from albumentations import Compose +from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator +from typing_extensions import Self + +from .support import SupportedTransform +from .transformations.normalize_model import NormalizeModel +from .validators import check_axes_validity, patch_size_ge_than_8_power_of_2 + +TRANSFORMS_UNION = Union[NormalizeModel] + + +class InferenceModel(BaseModel): + """Configuration class for the prediction model.""" + + model_config = ConfigDict(validate_assignment=True, arbitrary_types_allowed=True) + + # Mandatory fields + data_type: Literal["array", "tiff", "custom"] # As defined in SupportedData + tile_size: Optional[Union[List[int]]] = Field( + default=None, min_length=2, max_length=3 + ) + tile_overlap: Optional[Union[List[int]]] = Field( + default=None, min_length=2, max_length=3 + ) + + axes: str + + mean: float + std: float = Field(..., ge=0.0) + + transforms: Union[List[TRANSFORMS_UNION], Compose] = Field( + default=[ + { + "name": SupportedTransform.NORMALIZE.value, + }, + ], + validate_default=True, + ) + + # only default TTAs are supported for now + tta_transforms: bool = Field(default=True) + + # Dataloader parameters + batch_size: int = Field(default=1, ge=1) + + @field_validator("tile_overlap") + @classmethod + def all_elements_non_zero_even( + cls, patch_list: Optional[Union[List[int]]] + ) -> Optional[Union[List[int]]]: + """ + Validate patch size. + + Patch size must be non-zero, positive and even. + + Parameters + ---------- + patch_list : Optional[Union[List[int]]] + Patch size. 
+ + Returns + ------- + Optional[Union[List[int]]] + Validated patch size. + + Raises + ------ + ValueError + If the patch size is 0. + ValueError + If the patch size is not even. + """ + if patch_list is not None: + for dim in patch_list: + if dim < 1: + raise ValueError( + f"Patch size must be non-zero positive (got {dim})." + ) + + if dim % 2 != 0: + raise ValueError(f"Patch size must be even (got {dim}).") + + return patch_list + + @field_validator("tile_size") + @classmethod + def tile_min_8_power_of_2( + cls, tile_list: Optional[Union[List[int]]] + ) -> Optional[Union[List[int]]]: + """ + Validate that each entry is greater or equal than 8 and a power of 2. + + Parameters + ---------- + tile_list : List[int] + Patch size. + + Returns + ------- + List[int] + Validated patch size. + + Raises + ------ + ValueError + If the patch size if smaller than 8. + ValueError + If the patch size is not a power of 2. + """ + patch_size_ge_than_8_power_of_2(tile_list) + + return tile_list + + @field_validator("axes") + @classmethod + def axes_valid(cls, axes: str) -> str: + """ + Validate axes. + + Axes must: + - be a combination of 'STCZYX' + - not contain duplicates + - contain at least 2 contiguous axes: X and Y + - contain at most 4 axes + - not contain both S and T axes + + Parameters + ---------- + axes : str + Axes to validate. + + Returns + ------- + str + Validated axes. + + Raises + ------ + ValueError + If axes are not valid. + """ + # Validate axes + check_axes_validity(axes) + + return axes + + @field_validator("transforms") + @classmethod + def validate_transforms( + cls, transforms: Union[List[TRANSFORMS_UNION], Compose] + ) -> Union[List[TRANSFORMS_UNION], Compose]: + """ + Validate that transforms do not have N2V pixel manipulate transforms. + + Parameters + ---------- + transforms : Union[List[TransformModel], Compose] + Transforms. + + Returns + ------- + Union[List[Transformations_Union], Compose] + Validated transforms. 
+ + Raises + ------ + ValueError + If transforms contain N2V pixel manipulate transforms. + """ + if not isinstance(transforms, Compose) and transforms is not None: + for transform in transforms: + if transform.name == SupportedTransform.N2V_MANIPULATE.value: + raise ValueError( + "N2V_Manipulate transform is not allowed in " + "prediction transforms." + ) + + return transforms + + @model_validator(mode="after") + def validate_dimensions(self: Self) -> Self: + """ + Validate 2D/3D dimensions between axes and tile size. + + Returns + ------- + Self + Validated prediction model. + """ + expected_len = 3 if "Z" in self.axes else 2 + + if self.tile_size is not None and self.tile_overlap is not None: + if len(self.tile_size) != expected_len: + raise ValueError( + f"Tile size must have {expected_len} dimensions given axes " + f"{self.axes} (got {self.tile_size})." + ) + + if len(self.tile_overlap) != expected_len: + raise ValueError( + f"Tile overlap must have {expected_len} dimensions given axes " + f"{self.axes} (got {self.tile_overlap})." + ) + + if any((i >= j) for i, j in zip(self.tile_overlap, self.tile_size)): + raise ValueError("Tile overlap must be smaller than tile size.") + + return self + + @model_validator(mode="after") + def std_only_with_mean(self: Self) -> Self: + """ + Check that mean and std are either both None, or both specified. + + Returns + ------- + Self + Validated prediction model. + + Raises + ------ + ValueError + If std is not None and mean is None. + """ + # check that mean and std are either both None, or both specified + if (self.mean is None) != (self.std is None): + raise ValueError( + "Mean and std must be either both None, or both specified." + ) + + return self + + @model_validator(mode="after") + def add_std_and_mean_to_normalize(self: Self) -> Self: + """ + Add mean and std to the Normalize transform if it is present. + + Returns + ------- + Self + Inference model with mean and std added to the Normalize transform. 
+ """ + if self.mean is not None or self.std is not None: + # search in the transforms for Normalize and update parameters + if not isinstance(self.transforms, Compose): + for transform in self.transforms: + if transform.name == SupportedTransform.NORMALIZE.value: + transform.mean = self.mean + transform.std = self.std + + return self + + def _update(self, **kwargs: Any) -> None: + """ + Update multiple arguments at once. + + Parameters + ---------- + **kwargs : Any + Key-value pairs of arguments to update. + """ + self.__dict__.update(kwargs) + self.__class__.model_validate(self.__dict__) + + def set_3D(self, axes: str, tile_size: List[int], tile_overlap: List[int]) -> None: + """ + Set 3D parameters. + + Parameters + ---------- + axes : str + Axes. + tile_size : List[int] + Tile size. + tile_overlap : List[int] + Tile overlap. + """ + self._update(axes=axes, tile_size=tile_size, tile_overlap=tile_overlap) diff --git a/src/careamics/config/noise_models.py b/src/careamics/config/noise_models.py new file mode 100644 index 00000000..6dd01fa4 --- /dev/null +++ b/src/careamics/config/noise_models.py @@ -0,0 +1,162 @@ +from __future__ import annotations + +from enum import Enum +from typing import Dict, Union + +from pydantic import BaseModel, ConfigDict, Field, field_validator + + +class NoiseModelType(str, Enum): + """ + Available noise models. + + Currently supported noise models: + + - hist: Histogram noise model. + - gmm: Gaussian mixture model noise model.F + """ + + NONE = "none" + HIST = "hist" + GMM = "gmm" + + # TODO add validator decorator + @classmethod + def validate_noise_model_type( + cls, noise_model: Union[str, NoiseModel], parameters: dict + ) -> None: + """_summary_. 
+ + Parameters + ---------- + noise_model : Union[str, NoiseModel] + _description_ + parameters : dict + _description_ + + Returns + ------- + BaseModel + _description_ + """ + if noise_model == NoiseModelType.HIST.value: + HistogramNoiseModel(**parameters) + return HistogramNoiseModel().model_dump() if not parameters else parameters + + elif noise_model == NoiseModelType.GMM.value: + GaussianMixtureNoiseModel(**parameters) + return ( + GaussianMixtureNoiseModel().model_dump() + if not parameters + else parameters + ) + + +class NoiseModel(BaseModel): + """_summary_. + + Parameters + ---------- + BaseModel : _type_ + _description_ + + Returns + ------- + _type_ + _description_ + + Raises + ------ + ValueError + _description_ + """ + + model_config = ConfigDict( + use_enum_values=True, + protected_namespaces=(), # allows to use model_* as a field name + validate_assignment=True, + ) + + model_type: NoiseModelType + parameters: Dict = Field(default_factory=dict, validate_default=True) + + @field_validator("parameters") + @classmethod + def validate_parameters(cls, data, values) -> Dict: + """_summary_. + + Parameters + ---------- + parameters : Dict + _description_ + + Returns + ------- + Dict + _description_ + """ + if values.data["model_type"] not in [NoiseModelType.GMM, NoiseModelType.HIST]: + raise ValueError( + f"Incorrect noise model {values.data['model_type']}." + f"Please refer to the documentation" # TODO add link to documentation + ) + + parameters = NoiseModelType.validate_noise_model_type( + values.data["model_type"], data + ) + return parameters + + +class HistogramNoiseModel(BaseModel): + """ + Histogram noise model. + + Attributes + ---------- + min_value : float + Minimum value in the input. + max_value : float + Maximum value in the input. + bins : int + Number of bins of the histogram. 
+ """ + + min_value: float = Field(default=350.0, ge=0.0, le=65535.0) + max_value: float = Field(default=6500.0, ge=0.0, le=65535.0) + bins: int = Field(default=256, ge=1) + + +class GaussianMixtureNoiseModel(BaseModel): + """ + Gaussian mixture model noise model. + + Attributes + ---------- + min_signal : float + Minimum signal intensity expected in the image. + max_signal : float + Maximum signal intensity expected in the image. + weight : array + A [3*n_gaussian, n_coeff] sized array containing the values of the weights + describing the noise model. + Each gaussian contributes three parameters (mean, standard deviation and weight), + hence the number of rows in `weight` are 3*n_gaussian. + If `weight = None`, the weight array is initialized using the `min_signal` and + `max_signal` parameters. + n_gaussian: int + Number of gaussians. + n_coeff: int + Number of coefficients to describe the functional relationship between gaussian + parameters and the signal. + 2 implies a linear relationship, 3 implies a quadratic relationship and so on. + device: device + GPU device. 
+ min_sigma: int + """ + + num_components: int = Field(default=3, ge=1) + min_value: float = Field(default=350.0, ge=0.0, le=65535.0) + max_value: float = Field(default=6500.0, ge=0.0, le=65535.0) + n_gaussian: int = Field(default=3, ge=1) + n_coeff: int = Field(default=2, ge=1) + min_sigma: int = Field(default=50, ge=1) diff --git a/src/careamics/config/optimizer_models.py b/src/careamics/config/optimizer_models.py new file mode 100644 index 00000000..43acbd6e --- /dev/null +++ b/src/careamics/config/optimizer_models.py @@ -0,0 +1,181 @@ +from __future__ import annotations + +from typing import Dict, Literal + +from pydantic import ( + BaseModel, + ConfigDict, + Field, + ValidationInfo, + field_validator, + model_validator, +) +from torch import optim +from typing_extensions import Self + +from careamics.utils.torch_utils import filter_parameters + +from .support import SupportedOptimizer + + +class OptimizerModel(BaseModel): + """ + Torch optimizer. + + Only parameters supported by the corresponding torch optimizer will be taken + into account. For more details, check: + https://pytorch.org/docs/stable/optim.html#algorithms + + Note that mandatory parameters (see the specific Optimizer signature in the + link above) must be provided. For example, SGD requires `lr`. + + Attributes + ---------- + name : TorchOptimizer + Name of the optimizer. + parameters : dict + Parameters of the optimizer (see torch documentation). + """ + + # Pydantic class configuration + model_config = ConfigDict( + validate_assignment=True, + ) + + # Mandatory field + name: Literal["Adam", "SGD"] = Field(default="Adam", validate_default=True) + + # Optional parameters, empty dict default value to allow filtering dictionary + parameters: dict = Field( + default={ + "lr": 1e-4, + }, + validate_default=True, + ) + + @field_validator("parameters") + @classmethod + def filter_parameters(cls, user_params: dict, values: ValidationInfo) -> Dict: + """ + Validate optimizer parameters. 
+ + This method filters out unknown parameters, given the optimizer name. + + Parameters + ---------- + user_params : dict + Parameters passed on to the torch optimizer. + values : ValidationInfo + Pydantic field validation info, used to get the optimizer name. + + Returns + ------- + Dict + Filtered optimizer parameters. + + Raises + ------ + ValueError + If the optimizer name is not specified. + """ + optimizer_name = values.data["name"] + + # retrieve the corresponding optimizer class + optimizer_class = getattr(optim, optimizer_name) + + # filter the user parameters according to the optimizer's signature + parameters = filter_parameters(optimizer_class, user_params) + + return parameters + + @model_validator(mode="after") + def sgd_lr_parameter(self) -> Self: + """ + Check that SGD optimizer has the mandatory `lr` parameter specified. + + This is specific for PyTorch < 2.2. + + Returns + ------- + Self + Validated optimizer. + + Raises + ------ + ValueError + If the optimizer is SGD and the lr parameter is not specified. + """ + if self.name == SupportedOptimizer.SGD and "lr" not in self.parameters: + raise ValueError( + "SGD optimizer requires `lr` parameter, check that it has correctly " + "been specified in `parameters`." + ) + + return self + + +class LrSchedulerModel(BaseModel): + """ + Torch learning rate scheduler. + + Only parameters supported by the corresponding torch lr scheduler will be taken + into account. For more details, check: + https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate + + Note that mandatory parameters (see the specific LrScheduler signature in the + link above) must be provided. For example, StepLR requires `step_size`. + + Attributes + ---------- + name : TorchLRScheduler + Name of the learning rate scheduler. + parameters : dict + Parameters of the learning rate scheduler (see torch documentation). 
+ """ + + # Pydantic class configuration + model_config = ConfigDict( + validate_assignment=True, + ) + + # Mandatory field + name: Literal["ReduceLROnPlateau", "StepLR"] = Field(default="ReduceLROnPlateau") + + # Optional parameters + parameters: dict = Field(default={}, validate_default=True) + + @field_validator("parameters") + @classmethod + def filter_parameters(cls, user_params: dict, values: ValidationInfo) -> Dict: + """Filter parameters based on the learning rate scheduler's signature. + + Parameters + ---------- + user_params : dict + User parameters. + values : ValidationInfo + Pydantic field validation info, used to get the scheduler name. + + Returns + ------- + Dict + Filtered scheduler parameters. + + Raises + ------ + ValueError + If the scheduler is StepLR and the step_size parameter is not specified. + """ + # retrieve the corresponding scheduler class + scheduler_class = getattr(optim.lr_scheduler, values.data["name"]) + + # filter the user parameters according to the scheduler's signature + parameters = filter_parameters(scheduler_class, user_params) + + if values.data["name"] == "StepLR" and "step_size" not in parameters: + raise ValueError( + "StepLR scheduler requires `step_size` parameter, check that it has " + "correctly been specified in `parameters`." 
+ ) + + return parameters diff --git a/src/careamics/config/references/__init__.py b/src/careamics/config/references/__init__.py new file mode 100644 index 00000000..b314f1c8 --- /dev/null +++ b/src/careamics/config/references/__init__.py @@ -0,0 +1,45 @@ +"""Module containing references to the algorithm used in CAREamics.""" + +__all__ = [ + "N2V2Ref", + "N2VRef", + "StructN2VRef", + "N2VDescription", + "N2V2Description", + "StructN2VDescription", + "StructN2V2Description", + "N2V", + "N2V2", + "STRUCT_N2V", + "STRUCT_N2V2", + "CUSTOM", + "N2N", + "CARE", + "CAREDescription", + "N2NDescription", + "CARERef", + "N2NRef", +] + +from .algorithm_descriptions import ( + CARE, + CUSTOM, + N2N, + N2V, + N2V2, + STRUCT_N2V, + STRUCT_N2V2, + CAREDescription, + N2NDescription, + N2V2Description, + N2VDescription, + StructN2V2Description, + StructN2VDescription, +) +from .references import ( + CARERef, + N2NRef, + N2V2Ref, + N2VRef, + StructN2VRef, +) diff --git a/src/careamics/config/references/algorithm_descriptions.py b/src/careamics/config/references/algorithm_descriptions.py new file mode 100644 index 00000000..97e4ba83 --- /dev/null +++ b/src/careamics/config/references/algorithm_descriptions.py @@ -0,0 +1,131 @@ +"""Descriptions of the algorithms used in CAREmics.""" +from pydantic import BaseModel + +CUSTOM = "Custom" +N2V = "Noise2Void" +N2V2 = "N2V2" +STRUCT_N2V = "StructN2V" +STRUCT_N2V2 = "StructN2V2" +N2N = "Noise2Noise" +CARE = "CARE" + + +N2V_DESCRIPTION = ( + "Noise2Void is a UNet-based self-supervised algorithm that " + "uses blind-spot training to denoise images. In short, in every " + "patches during training, random pixels are selected and their " + "value replaced by a neighboring pixel value. The network is then " + "trained to predict the original pixel value. 
The algorithm " + "relies on the continuity of the signal (neighboring pixels have " + "similar values) and the pixel-wise independence of the noise " + "(the noise in a pixel is not correlated with the noise in " + "neighboring pixels)." +) + + +class AlgorithmDescription(BaseModel): + """Description of an algorithm. + + Attributes + ---------- + description : str + Description of the algorithm. + """ + + description: str + + +class N2VDescription(AlgorithmDescription): + """Description of Noise2Void. + + Attributes + ---------- + description : str + Description of Noise2Void. + """ + + description: str = N2V_DESCRIPTION + + +class N2V2Description(AlgorithmDescription): + """Description of N2V2. + + Attributes + ---------- + description : str + Description of N2V2. + """ + + description: str = ( + "N2V2 is a variant of Noise2Void. " + + N2V_DESCRIPTION + + "\nN2V2 introduces blur-pool layers and removed skip " + "connections in the UNet architecture to remove checkboard " + "artefacts, a common artefacts ocurring in Noise2Void." + ) + + +class StructN2VDescription(AlgorithmDescription): + """Description of StructN2V. + + Attributes + ---------- + description : str + Description of StructN2V. + """ + + description: str = ( + "StructN2V is a variant of Noise2Void. " + + N2V_DESCRIPTION + + "\nStructN2V uses a linear mask (horizontal or vertical) to replace " + "the pixel values of neighbors of the masked pixels by a random " + "value. Such masking allows removing 1D structured noise from the " + "the images, the main failure case of the original N2V." + ) + + +class StructN2V2Description(AlgorithmDescription): + """Description of StructN2V2. + + Attributes + ---------- + description : str + Description of StructN2V2. + """ + + description: str = ( + "StructN2V2 is a a variant of Noise2Void that uses both " + "structN2V and N2V2. 
" + + N2V_DESCRIPTION + + "\nStructN2V2 uses a linear mask (horizontal or vertical) to replace " + "the pixel values of neighbors of the masked pixels by a random " + "value. Such masking allows removing 1D structured noise from the " + "the images, the main failure case of the original N2V." + "\nN2V2 introduces blur-pool layers and removed skip connections in " + "the UNet architecture to remove checkboard artefacts, a common " + "artefacts ocurring in Noise2Void." + ) + + +class N2NDescription(AlgorithmDescription): + """Description of Noise2Noise. + + Attributes + ---------- + description : str + Description of Noise2Noise. + """ + + description: str = "Noise2Noise" # TODO + + +class CAREDescription(AlgorithmDescription): + """Description of CARE. + + Attributes + ---------- + description : str + Description of CARE. + """ + + description: str = "CARE" # TODO diff --git a/src/careamics/config/references/references.py b/src/careamics/config/references/references.py new file mode 100644 index 00000000..60c8413f --- /dev/null +++ b/src/careamics/config/references/references.py @@ -0,0 +1,38 @@ +"""References for the CAREamics algorithms.""" +from bioimageio.spec.generic.v0_3 import CiteEntry + +N2VRef = CiteEntry( + text='Krull, A., Buchholz, T.O. and Jug, F., 2019. "Noise2Void - Learning ' + 'denoising from single noisy images". In Proceedings of the IEEE/CVF ' + "conference on computer vision and pattern recognition (pp. 2129-2137).", + doi="10.1109/cvpr.2019.00223", +) + +N2V2Ref = CiteEntry( + text="Hรถck, E., Buchholz, T.O., Brachmann, A., Jug, F. and Freytag, A., " + '2022. "N2V2 - Fixing Noise2Void checkerboard artifacts with modified ' + 'sampling strategies and a tweaked network architecture". In European ' + "Conference on Computer Vision (pp. 503-518).", + doi="10.1007/978-3-031-25069-9_33", +) + +StructN2VRef = CiteEntry( + text="Broaddus, C., Krull, A., Weigert, M., Schmidt, U. and Myers, G., 2020." 
+ '"Removing structured noise with self-supervised blind-spot ' + 'networks". In 2020 IEEE 17th International Symposium on Biomedical ' + "Imaging (ISBI) (pp. 159-163).", + doi="10.1109/isbi45749.2020.9098336", +) + +N2NRef = CiteEntry( + text="Lehtinen, J., Munkberg, J., Hasselgren, J., Laine, S., Karras, T., " + 'Aittala, M. and Aila, T., 2018. "Noise2Noise: Learning image restoration ' + 'without clean data". arXiv preprint arXiv:1803.04189.', + doi="10.48550/arXiv.1803.04189", +) + +CARERef = CiteEntry( + text='Weigert, Martin, et al. "Content-aware image restoration: pushing the ' + 'limits of fluorescence microscopy." Nature methods 15.12 (2018): 1090-1097.', + doi="10.1038/s41592-018-0216-7", +) diff --git a/src/careamics/config/support/__init__.py b/src/careamics/config/support/__init__.py new file mode 100644 index 00000000..abb8284c --- /dev/null +++ b/src/careamics/config/support/__init__.py @@ -0,0 +1,33 @@ +"""Supported configuration options. + +Used throughout the code to ensure consistency. These should be kept in sync with the +corresponding configuration options in the Pydantic models. 
+""" + +__all__ = [ + "SupportedArchitecture", + "SupportedActivation", + "SupportedOptimizer", + "SupportedScheduler", + "SupportedLoss", + "SupportedAlgorithm", + "SupportedPixelManipulation", + "SupportedTransform", + "SupportedData", + "SupportedExtractionStrategy", + "SupportedStructAxis", + "SupportedLogger", +] + + +from .supported_activations import SupportedActivation +from .supported_algorithms import SupportedAlgorithm +from .supported_architectures import SupportedArchitecture +from .supported_data import SupportedData +from .supported_extraction_strategies import SupportedExtractionStrategy +from .supported_loggers import SupportedLogger +from .supported_losses import SupportedLoss +from .supported_optimizers import SupportedOptimizer, SupportedScheduler +from .supported_pixel_manipulations import SupportedPixelManipulation +from .supported_struct_axis import SupportedStructAxis +from .supported_transforms import SupportedTransform diff --git a/src/careamics/config/support/supported_activations.py b/src/careamics/config/support/supported_activations.py new file mode 100644 index 00000000..d7c84ae3 --- /dev/null +++ b/src/careamics/config/support/supported_activations.py @@ -0,0 +1,24 @@ +from careamics.utils import BaseEnum + + +class SupportedActivation(str, BaseEnum): + """Supported activation functions. + + - None, no activation will be used. + - Sigmoid + - Softmax + - Tanh + - ReLU + - LeakyReLU + + All activations are defined in PyTorch. 
+ + See: https://pytorch.org/docs/stable/nn.html#loss-functions + """ + + NONE = "None" + SIGMOID = "Sigmoid" + SOFTMAX = "Softmax" + TANH = "Tanh" + RELU = "ReLU" + LEAKYRELU = "LeakyReLU" diff --git a/src/careamics/config/support/supported_algorithms.py b/src/careamics/config/support/supported_algorithms.py new file mode 100644 index 00000000..a44e179b --- /dev/null +++ b/src/careamics/config/support/supported_algorithms.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +from careamics.utils import BaseEnum + + +class SupportedAlgorithm(str, BaseEnum): + """Algorithms available in CAREamics. + + # TODO + """ + + N2V = "n2v" + CUSTOM = "custom" + CARE = "care" + N2N = "n2n" + # PN2V = "pn2v" + # HDN = "hdn" + # SEG = "segmentation" diff --git a/src/careamics/config/support/supported_architectures.py b/src/careamics/config/support/supported_architectures.py new file mode 100644 index 00000000..8246cf9c --- /dev/null +++ b/src/careamics/config/support/supported_architectures.py @@ -0,0 +1,18 @@ +from careamics.utils import BaseEnum + + +class SupportedArchitecture(str, BaseEnum): + """Supported architectures. + + # TODO add details, in particular where to find the API for the models + + - UNet: classical UNet compatible with N2V2 + - VAE: variational Autoencoder + - Custom: custom model registered with `@register_model` decorator + """ + + UNET = "UNet" + VAE = "VAE" + CUSTOM = ( + "Custom" # TODO all the others tags are small letters, except the architect + ) diff --git a/src/careamics/config/support/supported_data.py b/src/careamics/config/support/supported_data.py new file mode 100644 index 00000000..cb9fd700 --- /dev/null +++ b/src/careamics/config/support/supported_data.py @@ -0,0 +1,82 @@ +from __future__ import annotations + +from typing import Union + +from careamics.utils import BaseEnum + + +class SupportedData(str, BaseEnum): + """Supported data types. + + Attributes + ---------- + ARRAY : str + Array data. + TIFF : str + TIFF image data. 
+ CUSTOM : str + Custom data. + """ + + ARRAY = "array" + TIFF = "tiff" + CUSTOM = "custom" + # ZARR = "zarr" + + # TODO remove? + @classmethod + def _missing_(cls, value: object) -> str: + """ + Override default behaviour for missing values. + + This method is called when `value` is not found in the enum values. It converts + `value` to lowercase, removes "." if it is the first character and tries to + match it with enum values. + + Parameters + ---------- + value : object + Value to be matched with enum values. + + Returns + ------- + str + Matched enum value. + """ + if isinstance(value, str): + lower_value = value.lower() + + if lower_value.startswith("."): + lower_value = lower_value[1:] + + # attempt to match lowercase value with enum values + for member in cls: + if member.value == lower_value: + return member + + # still missing + return super()._missing_(value) + + @classmethod + def get_extension(cls, data_type: Union[str, SupportedData]) -> str: + """ + Path.rglob and fnmatch compatible extension. + + Parameters + ---------- + data_type : SupportedData + Data type. + + Returns + ------- + str + Corresponding extension. + """ + if data_type == cls.ARRAY: + raise NotImplementedError(f"Data {data_type} are not loaded from file.") + elif data_type == cls.TIFF: + return "*.tif*" + elif data_type == cls.CUSTOM: + return "*.*" + else: + raise ValueError(f"Data type {data_type} is not supported.") diff --git a/src/careamics/dataset/extraction_strategy.py b/src/careamics/config/support/supported_extraction_strategies.py similarity index 73% rename from src/careamics/dataset/extraction_strategy.py rename to src/careamics/config/support/supported_extraction_strategies.py index 8cab86ea..d08c513f 100644 --- a/src/careamics/dataset/extraction_strategy.py +++ b/src/careamics/config/support/supported_extraction_strategies.py @@ -3,19 +3,22 @@ This module defines the various extraction strategies available in CAREamics. 
""" -from enum import Enum +from careamics.utils import BaseEnum -class ExtractionStrategy(str, Enum): +class SupportedExtractionStrategy(str, BaseEnum): """ Available extraction strategies. Currently supported: - random: random extraction. + # TODO - sequential: grid extraction, can miss edge values. - tiled: tiled extraction, covers the whole image. """ RANDOM = "random" + RANDOM_ZARR = "random_zarr" SEQUENTIAL = "sequential" TILED = "tiled" + NONE = "none" diff --git a/src/careamics/config/support/supported_loggers.py b/src/careamics/config/support/supported_loggers.py new file mode 100644 index 00000000..b4d4842f --- /dev/null +++ b/src/careamics/config/support/supported_loggers.py @@ -0,0 +1,8 @@ +from careamics.utils import BaseEnum + + +class SupportedLogger(str, BaseEnum): + """Available loggers.""" + + WANDB = "wandb" + TENSORBOARD = "tensorboard" diff --git a/src/careamics/config/support/supported_losses.py b/src/careamics/config/support/supported_losses.py new file mode 100644 index 00000000..4235034e --- /dev/null +++ b/src/careamics/config/support/supported_losses.py @@ -0,0 +1,25 @@ +from careamics.utils import BaseEnum + + +# TODO register loss with custom_loss decorator? +class SupportedLoss(str, BaseEnum): + """Supported losses. + + Attributes + ---------- + MSE : str + Mean Squared Error loss. + MAE : str + Mean Absolute Error loss. + N2V : str + Noise2Void loss. + """ + + MSE = "mse" + MAE = "mae" + N2V = "n2v" + # PN2V = "pn2v" + # HDN = "hdn" + # CE = "ce" + # DICE = "dice" + # CUSTOM = "custom" # TODO create mechanism for that diff --git a/src/careamics/config/support/supported_optimizers.py b/src/careamics/config/support/supported_optimizers.py new file mode 100644 index 00000000..ab8b0ff9 --- /dev/null +++ b/src/careamics/config/support/supported_optimizers.py @@ -0,0 +1,55 @@ +from careamics.utils import BaseEnum + + +class SupportedOptimizer(str, BaseEnum): + """Supported optimizers. 
+ + Attributes + ---------- + Adam : str + Adam optimizer. + SGD : str + Stochastic Gradient Descent optimizer. + """ + + # ASGD = "ASGD" + # Adadelta = "Adadelta" + # Adagrad = "Adagrad" + Adam = "Adam" + # AdamW = "AdamW" + # Adamax = "Adamax" + # LBFGS = "LBFGS" + # NAdam = "NAdam" + # RAdam = "RAdam" + # RMSprop = "RMSprop" + # Rprop = "Rprop" + SGD = "SGD" + # SparseAdam = "SparseAdam" + + +class SupportedScheduler(str, BaseEnum): + """Supported schedulers. + + Attributes + ---------- + ReduceLROnPlateau : str + Reduce learning rate on plateau. + StepLR : str + Step learning rate. + """ + + # ChainedScheduler = "ChainedScheduler" + # ConstantLR = "ConstantLR" + # CosineAnnealingLR = "CosineAnnealingLR" + # CosineAnnealingWarmRestarts = "CosineAnnealingWarmRestarts" + # CyclicLR = "CyclicLR" + # ExponentialLR = "ExponentialLR" + # LambdaLR = "LambdaLR" + # LinearLR = "LinearLR" + # MultiStepLR = "MultiStepLR" + # MultiplicativeLR = "MultiplicativeLR" + # OneCycleLR = "OneCycleLR" + # PolynomialLR = "PolynomialLR" + ReduceLROnPlateau = "ReduceLROnPlateau" + # SequentialLR = "SequentialLR" + StepLR = "StepLR" diff --git a/src/careamics/config/support/supported_pixel_manipulations.py b/src/careamics/config/support/supported_pixel_manipulations.py new file mode 100644 index 00000000..84db6d05 --- /dev/null +++ b/src/careamics/config/support/supported_pixel_manipulations.py @@ -0,0 +1,15 @@ +from careamics.utils import BaseEnum + + +class SupportedPixelManipulation(str, BaseEnum): + """_summary_. + + - Uniform: Replace masked pixel value by a (uniformly) randomly selected neighbor + pixel value. + - Median: Replace masked pixel value by the mean of the neighborhood. 
+ """ + + # TODO docs + + UNIFORM = "uniform" + MEDIAN = "median" diff --git a/src/careamics/config/support/supported_struct_axis.py b/src/careamics/config/support/supported_struct_axis.py new file mode 100644 index 00000000..4d82307d --- /dev/null +++ b/src/careamics/config/support/supported_struct_axis.py @@ -0,0 +1,19 @@ +from careamics.utils import BaseEnum + + +class SupportedStructAxis(str, BaseEnum): + """Supported structN2V mask axes. + + Attributes + ---------- + HORIZONTAL : str + Horizontal axis. + VERTICAL : str + Vertical axis. + NONE : str + No axis, the mask is not applied. + """ + + HORIZONTAL = "horizontal" + VERTICAL = "vertical" + NONE = "none" diff --git a/src/careamics/config/support/supported_transforms.py b/src/careamics/config/support/supported_transforms.py new file mode 100644 index 00000000..b262c169 --- /dev/null +++ b/src/careamics/config/support/supported_transforms.py @@ -0,0 +1,23 @@ +from careamics.utils import BaseEnum + + +class SupportedTransform(str, BaseEnum): + """Transforms officially supported by CAREamics. + + - Flip: from Albumentations, randomly flip the input horizontally, vertically or + both, parameter `p` can be used to set the probability to apply the transform. + - XYRandomRotate90: #TODO + - Normalize # TODO add details, in particular about the parameters + - ManipulateN2V # TODO add details, in particular about the parameters + - NDFlip + + Note that while any Albumentations (see https://albumentations.ai/) transform can be + used in CAREamics, no check are implemented to verify the compatibility of any other + transforms than the ones officially supported. 
+ """ + + NDFLIP = "NDFlip" + XY_RANDOM_ROTATE90 = "XYRandomRotate90" + NORMALIZE = "Normalize" + N2V_MANIPULATE = "N2VManipulate" + # CUSTOM = "Custom" diff --git a/src/careamics/config/tile_information.py b/src/careamics/config/tile_information.py new file mode 100644 index 00000000..e018e0f1 --- /dev/null +++ b/src/careamics/config/tile_information.py @@ -0,0 +1,104 @@ +from __future__ import annotations + +from typing import Optional, Tuple + +from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator + + +class TileInformation(BaseModel): + """ + Pydantic model containing tile information. + + This model is used to represent the information required to stitch back a tile into + a larger image. It is used throughout the prediction pipeline of CAREamics. + """ + + model_config = ConfigDict(validate_default=True) + + array_shape: Tuple[int, ...] + tiled: bool = False + last_tile: bool = False + overlap_crop_coords: Optional[Tuple[Tuple[int, ...], ...]] = Field(default=None) + stitch_coords: Optional[Tuple[Tuple[int, ...], ...]] = Field(default=None) + + @field_validator("array_shape") + @classmethod + def no_singleton_dimensions(cls, v: Tuple[int, ...]): + """ + Check that the array shape does not have any singleton dimensions. + + Parameters + ---------- + v : Tuple[int, ...] + Array shape to check. + + Returns + ------- + Tuple[int, ...] + The array shape if it does not contain singleton dimensions. + + Raises + ------ + ValueError + If the array shape contains singleton dimensions. + """ + if any(dim == 1 for dim in v): + raise ValueError("Array shape must not contain singleton dimensions.") + return v + + @field_validator("last_tile") + @classmethod + def only_if_tiled(cls, v: bool, values: ValidationInfo): + """ + Check that the last tile flag is only set if tiling is enabled. + + Parameters + ---------- + v : bool + Last tile flag. + values : ValidationInfo + Validation information. + + Returns + ------- + bool + The last tile flag. 
+ """ + if not values.data["tiled"]: + return False + return v + + @field_validator("overlap_crop_coords", "stitch_coords") + @classmethod + def mandatory_if_tiled( + cls, v: Optional[Tuple[int, ...]], values: ValidationInfo + ) -> Optional[Tuple[int, ...]]: + """ + Check that the coordinates are not `None` if tiling is enabled. + + The method also return `None` if tiling is not enabled. + + Parameters + ---------- + v : Optional[Tuple[int, ...]] + Coordinates to check. + values : ValidationInfo + Validation information. + + Returns + ------- + Optional[Tuple[int, ...]] + The coordinates if tiling is enabled, otherwise `None`. + + Raises + ------ + ValueError + If the coordinates are `None` and tiling is enabled. + """ + if values.data["tiled"]: + if v is None: + raise ValueError("Value must be specified if tiling is enabled.") + + return v + else: + return None diff --git a/src/careamics/config/torch_optim.py b/src/careamics/config/torch_optim.py deleted file mode 100644 index 75db3f64..00000000 --- a/src/careamics/config/torch_optim.py +++ /dev/null @@ -1,118 +0,0 @@ -"""Convenience functions to instantiate torch.optim optimizers and schedulers.""" -import inspect -from enum import Enum -from typing import Dict - -from torch import optim - - -class TorchOptimizer(str, Enum): - """ - Supported optimizers. - - Currently only supports Adam and SGD. - """ - - # ASGD = "ASGD" - # Adadelta = "Adadelta" - # Adagrad = "Adagrad" - Adam = "Adam" - # AdamW = "AdamW" - # Adamax = "Adamax" - # LBFGS = "LBFGS" - # NAdam = "NAdam" - # RAdam = "RAdam" - # RMSprop = "RMSprop" - # Rprop = "Rprop" - SGD = "SGD" - # SparseAdam = "SparseAdam" - - -# TODO: Test which schedulers are compatible and if not, how to make them compatible -# (if we want to support them) -class TorchLRScheduler(str, Enum): - """ - Supported learning rate schedulers. - - Currently only supports ReduceLROnPlateau and StepLR. 
- """ - - # ChainedScheduler = "ChainedScheduler" - # ConstantLR = "ConstantLR" - # CosineAnnealingLR = "CosineAnnealingLR" - # CosineAnnealingWarmRestarts = "CosineAnnealingWarmRestarts" - # CyclicLR = "CyclicLR" - # ExponentialLR = "ExponentialLR" - # LambdaLR = "LambdaLR" - # LinearLR = "LinearLR" - # MultiStepLR = "MultiStepLR" - # MultiplicativeLR = "MultiplicativeLR" - # OneCycleLR = "OneCycleLR" - # PolynomialLR = "PolynomialLR" - ReduceLROnPlateau = "ReduceLROnPlateau" - # SequentialLR = "SequentialLR" - StepLR = "StepLR" - - -def get_parameters( - func: type, - user_params: dict, -) -> dict: - """ - Filter parameters according to the function signature. - - Parameters - ---------- - func : type - Class object. - user_params : Dict - User provided parameters. - - Returns - ------- - Dict - Parameters matching `func`'s signature. - """ - # Get the list of all default parameters - default_params = list(inspect.signature(func).parameters.keys()) - - # Filter matching parameters - params_to_be_used = set(user_params.keys()) & set(default_params) - - return {key: user_params[key] for key in params_to_be_used} - - -def get_optimizers() -> Dict[str, str]: - """ - Return the list of all optimizers available in torch.optim. - - Returns - ------- - Dict - Optimizers available in torch.optim. - """ - optims = {} - for name, obj in inspect.getmembers(optim): - if inspect.isclass(obj) and issubclass(obj, optim.Optimizer): - if name != "Optimizer": - optims[name] = name - return optims - - -def get_schedulers() -> Dict[str, str]: - """ - Return the list of all schedulers available in torch.optim.lr_scheduler. - - Returns - ------- - Dict - Schedulers available in torch.optim.lr_scheduler. 
- """ - schedulers = {} - for name, obj in inspect.getmembers(optim.lr_scheduler): - if inspect.isclass(obj) and issubclass(obj, optim.lr_scheduler.LRScheduler): - if "LRScheduler" not in name: - schedulers[name] = name - elif name == "ReduceLROnPlateau": # somewhat not a subclass of LRScheduler - schedulers[name] = name - return schedulers diff --git a/src/careamics/config/training.py b/src/careamics/config/training.py deleted file mode 100644 index b3bd1ff6..00000000 --- a/src/careamics/config/training.py +++ /dev/null @@ -1,534 +0,0 @@ -"""Training configuration.""" -from __future__ import annotations - -from typing import Dict, List - -from pydantic import ( - BaseModel, - ConfigDict, - Field, - FieldValidationInfo, - field_validator, - model_validator, -) -from torch import optim - -from .config_filter import remove_default_optionals -from .torch_optim import TorchLRScheduler, TorchOptimizer, get_parameters - - -class Optimizer(BaseModel): - """ - Torch optimizer. - - Only parameters supported by the corresponding torch optimizer will be taken - into account. For more details, check: - https://pytorch.org/docs/stable/optim.html#algorithms - - Note that mandatory parameters (see the specific Optimizer signature in the - link above) must be provided. For example, SGD requires `lr`. - - Attributes - ---------- - name : TorchOptimizer - Name of the optimizer. - parameters : dict - Parameters of the optimizer (see torch documentation). - """ - - # Pydantic class configuration - model_config = ConfigDict( - use_enum_values=True, - validate_assignment=True, - ) - - # Mandatory field - name: TorchOptimizer - - # Optional parameters - parameters: dict = {} - - @field_validator("parameters") - def filter_parameters(cls, user_params: dict, values: FieldValidationInfo) -> Dict: - """ - Validate optimizer parameters. - - This method filters out unknown parameters, given the optimizer name. 
- - Parameters - ---------- - user_params : dict - Parameters passed on to the torch optimizer. - values : FieldValidationInfo - Pydantic field validation info, used to get the optimizer name. - - Returns - ------- - Dict - Filtered optimizer parameters. - - Raises - ------ - ValueError - If the optimizer name is not specified. - """ - if "name" in values.data: - optimizer_name = values.data["name"] - - # retrieve the corresponding optimizer class - optimizer_class = getattr(optim, optimizer_name) - - # filter the user parameters according to the optimizer's signature - return get_parameters(optimizer_class, user_params) - else: - raise ValueError( - "Cannot validate optimizer parameters without `name`, check that it " - "has correctly been specified." - ) - - @model_validator(mode="after") - def sgd_lr_parameter(cls, optimizer: Optimizer) -> Optimizer: - """ - Check that SGD optimizer has the mandatory `lr` parameter specified. - - Parameters - ---------- - optimizer : Optimizer - Optimizer to validate. - - Returns - ------- - Optimizer - Validated optimizer. - - Raises - ------ - ValueError - If the optimizer is SGD and the lr parameter is not specified. - """ - if optimizer.name == TorchOptimizer.SGD and "lr" not in optimizer.parameters: - raise ValueError( - "SGD optimizer requires `lr` parameter, check that it has correctly " - "been specified in `parameters`." - ) - - return optimizer - - def model_dump( - self, exclude_optionals: bool = True, *args: List, **kwargs: Dict - ) -> Dict: - """ - Override model_dump method. - - The purpose of this method is to ensure smooth export to yaml. It - includes: - - removing entries with None value. - - removing optional values if they have the default value. - - Parameters - ---------- - exclude_optionals : bool, optional - Whether to exclude optional arguments if they are default, by default True. - *args : List - Positional arguments, unused. - **kwargs : Dict - Keyword arguments, unused. 
- - Returns - ------- - dict - Dictionary containing the model parameters. - """ - dictionary = super().model_dump(exclude_none=True) - - if exclude_optionals: - # remove optional arguments if they are default - default_optionals: dict = {"parameters": {}} - - remove_default_optionals(dictionary, default_optionals) - - return dictionary - - -class LrScheduler(BaseModel): - """ - Torch learning rate scheduler. - - Only parameters supported by the corresponding torch lr scheduler will be taken - into account. For more details, check: - https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate - - Note that mandatory parameters (see the specific LrScheduler signature in the - link above) must be provided. For example, StepLR requires `step_size`. - - Attributes - ---------- - name : TorchLRScheduler - Name of the learning rate scheduler. - parameters : dict - Parameters of the learning rate scheduler (see torch documentation). - """ - - # Pydantic class configuration - model_config = ConfigDict( - use_enum_values=True, - validate_assignment=True, - ) - - # Mandatory field - name: TorchLRScheduler - - # Optional parameters - parameters: dict = {} - - @field_validator("parameters") - def filter_parameters(cls, user_params: dict, values: FieldValidationInfo) -> Dict: - """ - Validate lr scheduler parameters. - - This method filters out unknown parameters, given the lr scheduler name. - - Parameters - ---------- - user_params : dict - Parameters passed on to the torch lr scheduler. - values : FieldValidationInfo - Pydantic field validation info, used to get the lr scheduler name. - - Returns - ------- - Dict - Filtered lr scheduler parameters. - - Raises - ------ - ValueError - If the lr scheduler name is not specified. 
- """ - if "name" in values.data: - lr_scheduler_name = values.data["name"] - - # retrieve the corresponding lr scheduler class - lr_scheduler_class = getattr(optim.lr_scheduler, lr_scheduler_name) - - # filter the user parameters according to the lr scheduler's signature - return get_parameters(lr_scheduler_class, user_params) - else: - raise ValueError( - "Cannot validate lr scheduler parameters without `name`, check that it " - "has correctly been specified." - ) - - @model_validator(mode="after") - def step_lr_step_size_parameter(cls, lr_scheduler: LrScheduler) -> LrScheduler: - """ - Check that StepLR lr scheduler has `step_size` parameter specified. - - Parameters - ---------- - lr_scheduler : LrScheduler - Lr scheduler to validate. - - Returns - ------- - LrScheduler - Validated lr scheduler. - - Raises - ------ - ValueError - If the lr scheduler is StepLR and the step_size parameter is not specified. - """ - if ( - lr_scheduler.name == TorchLRScheduler.StepLR - and "step_size" not in lr_scheduler.parameters - ): - raise ValueError( - "StepLR lr scheduler requires `step_size` parameter, check that it has " - "correctly been specified in `parameters`." - ) - - return lr_scheduler - - def model_dump( - self, exclude_optionals: bool = True, *args: List, **kwargs: Dict - ) -> Dict: - """ - Override model_dump method. - - The purpose of this method is to ensure smooth export to yaml. It includes: - - removing entries with None value. - - removing optional values if they have the default value. - - Parameters - ---------- - exclude_optionals : bool, optional - Whether to exclude optional arguments if they are default, by default True. - *args : List - Positional arguments, unused. - **kwargs : Dict - Keyword arguments, unused. - - Returns - ------- - dict - Dictionary containing the model parameters. 
- """ - dictionary = super().model_dump(exclude_none=True) - - if exclude_optionals: - # remove optional arguments if they are default - default_optionals: dict = {"parameters": {}} - remove_default_optionals(dictionary, default_optionals) - - return dictionary - - -class AMP(BaseModel): - """ - Automatic mixed precision (AMP) parameters. - - See: https://pytorch.org/docs/stable/amp.html. - - Attributes - ---------- - use : bool, optional - Whether to use AMP or not, default False. - init_scale : int, optional - Initial scale used for loss scaling, default 1024. - """ - - model_config = ConfigDict( - validate_assignment=True, - ) - - use: bool = False - - # TODO review init_scale and document better - init_scale: int = Field(default=1024, ge=512, le=65536) - - @field_validator("init_scale") - def power_of_two(cls, scale: int) -> int: - """ - Validate that init_scale is a power of two. - - Parameters - ---------- - scale : int - Initial scale used for loss scaling. - - Returns - ------- - int - Validated initial scale. - - Raises - ------ - ValueError - If the init_scale is not a power of two. - """ - if not scale & (scale - 1) == 0: - raise ValueError(f"Init scale must be a power of two (got {scale}).") - - return scale - - def model_dump( - self, exclude_optionals: bool = True, *args: List, **kwargs: Dict - ) -> Dict: - """ - Override model_dump method. - - The purpose is to ensure export smooth import to yaml. It includes: - - remove entries with None value. - - remove optional values if they have the default value. - - Parameters - ---------- - exclude_optionals : bool, optional - Whether to exclude optional arguments if they are default, by default True. - *args : List - Positional arguments, unused. - **kwargs : Dict - Keyword arguments, unused. - - Returns - ------- - dict - Dictionary containing the model parameters. 
- """ - dictionary = super().model_dump(exclude_none=True) - - if exclude_optionals: - # remove optional arguments if they are default - defaults = { - "init_scale": 1024, - } - - remove_default_optionals(dictionary, defaults) - - return dictionary - - -class Training(BaseModel): - """ - Parameters related to the training. - - Mandatory parameters are: - - num_epochs: number of epochs, greater than 0. - - patch_size: patch size, 2D or 3D, non-zero and divisible by 2. - - batch_size: batch size, greater than 0. - - optimizer: optimizer, see `Optimizer`. - - lr_scheduler: learning rate scheduler, see `LrScheduler`. - - augmentation: whether to use data augmentation or not (True or False). - - The other fields are optional: - - use_wandb: whether to use wandb or not (default True). - - num_workers: number of workers (default 0). - - amp: automatic mixed precision parameters (disabled by default). - - Attributes - ---------- - num_epochs : int - Number of epochs, greater than 0. - patch_size : conlist(int, min_length=2, max_length=3) - Patch size, 2D or 3D, non-zero and divisible by 2. - batch_size : int - Batch size, greater than 0. - optimizer : Optimizer - Optimizer. - lr_scheduler : LrScheduler - Learning rate scheduler. - augmentation : bool - Whether to use data augmentation or not. - use_wandb : bool - Optional, whether to use wandb or not (default True). - num_workers : int - Optional, number of workers (default 0). - amp : AMP - Optional, automatic mixed precision parameters (disabled by default). 
- """ - - # Pydantic class configuration - model_config = ConfigDict( - use_enum_values=True, - validate_assignment=True, - ) - - # Mandatory fields - num_epochs: int - patch_size: List[int] = Field(..., min_length=2, max_length=3) - batch_size: int - - optimizer: Optimizer - lr_scheduler: LrScheduler - - augmentation: bool - - # Optional fields - use_wandb: bool = False - num_workers: int = Field(default=0, ge=0) - amp: AMP = AMP() - - @field_validator("num_epochs", "batch_size") - def greater_than_0(cls, val: int) -> int: - """ - Validate number of epochs. - - Number of epochs must be greater than 0. - - Parameters - ---------- - val : int - Number of epochs. - - Returns - ------- - int - Validated number of epochs. - - Raises - ------ - ValueError - If the number of epochs is 0. - """ - if val < 1: - raise ValueError(f"Number of epochs must be greater than 0 (got {val}).") - - return val - - @field_validator("patch_size") - def all_elements_non_zero_divisible_by_2(cls, patch_list: List[int]) -> List[int]: - """ - Validate patch size. - - Patch size must be non-zero, positive and divisible by 2. - - Parameters - ---------- - patch_list : List[int] - Patch size. - - Returns - ------- - List[int] - Validated patch size. - - Raises - ------ - ValueError - If the patch size is 0. - ValueError - If the patch size is not divisible by 2. - """ - for dim in patch_list: - if dim < 1: - raise ValueError(f"Patch size must be non-zero positive (got {dim}).") - - if dim % 2 != 0: - raise ValueError(f"Patch size must be divisible by 2 (got {dim}).") - - return patch_list - - def model_dump( - self, exclude_optionals: bool = True, *args: List, **kwargs: Dict - ) -> Dict: - """ - Override model_dump method. - - The purpose is to ensure export smooth import to yaml. It includes: - - remove entries with None value. - - remove optional values if they have the default value. 
- - Parameters - ---------- - exclude_optionals : bool, optional - Whether to exclude optional arguments if they are default, by default True. - *args : List - Positional arguments, unused. - **kwargs : Dict - Keyword arguments, unused. - - Returns - ------- - dict - Dictionary containing the model parameters. - """ - dictionary = super().model_dump(exclude_none=True) - - dictionary["optimizer"] = self.optimizer.model_dump(exclude_optionals) - dictionary["lr_scheduler"] = self.lr_scheduler.model_dump(exclude_optionals) - - if self.amp is not None: - dictionary["amp"] = self.amp.model_dump(exclude_optionals) - - if exclude_optionals: - # remove optional arguments if they are default - defaults = { - "use_wandb": False, - "num_workers": 0, - "amp": AMP().model_dump(), - } - - remove_default_optionals(dictionary, defaults) - - return dictionary diff --git a/src/careamics/config/training_model.py b/src/careamics/config/training_model.py new file mode 100644 index 00000000..612843e7 --- /dev/null +++ b/src/careamics/config/training_model.py @@ -0,0 +1,65 @@ +"""Training configuration.""" +from __future__ import annotations + +from pprint import pformat +from typing import Literal, Optional + +from pydantic import ( + BaseModel, + ConfigDict, + Field, +) + +from .callback_model import CheckpointModel, EarlyStoppingModel + + +class TrainingModel(BaseModel): + """ + Parameters related to the training. + + Mandatory parameters are: + - num_epochs: number of epochs, greater than 0. + - batch_size: batch size, greater than 0. + - augmentation: whether to use data augmentation or not (True or False). + + Attributes + ---------- + num_epochs : int + Number of epochs, greater than 0. 
+ """ + + # Pydantic class configuration + model_config = ConfigDict( + validate_assignment=True, + ) + + num_epochs: int = Field(default=20, ge=1) + + logger: Optional[Literal["wandb", "tensorboard"]] = None + + checkpoint_callback: CheckpointModel = CheckpointModel() + + early_stopping_callback: Optional[EarlyStoppingModel] = Field( + default=None, validate_default=True + ) + # precision: Literal["64", "32", "16", "bf16"] = 32 + + def __str__(self) -> str: + """Pretty string representing the configuration. + + Returns + ------- + str + Pretty string. + """ + return pformat(self.model_dump()) + + def has_logger(self) -> bool: + """Check if the logger is defined. + + Returns + ------- + bool + Whether the logger is defined or not. + """ + return self.logger is not None diff --git a/src/careamics/config/transformations/__init__.py b/src/careamics/config/transformations/__init__.py new file mode 100644 index 00000000..d5aaa92e --- /dev/null +++ b/src/careamics/config/transformations/__init__.py @@ -0,0 +1,14 @@ +"""CAREamics transformation Pydantic models.""" + +__all__ = [ + "N2VManipulateModel", + "NDFlipModel", + "NormalizeModel", + "XYRandomRotate90Model", +] + + +from .n2v_manipulate_model import N2VManipulateModel +from .nd_flip_model import NDFlipModel +from .normalize_model import NormalizeModel +from .xy_random_rotate90_model import XYRandomRotate90Model diff --git a/src/careamics/config/transformations/n2v_manipulate_model.py b/src/careamics/config/transformations/n2v_manipulate_model.py new file mode 100644 index 00000000..5e4a0358 --- /dev/null +++ b/src/careamics/config/transformations/n2v_manipulate_model.py @@ -0,0 +1,63 @@ +"""Pydantic model for the N2VManipulate transform.""" +from typing import Literal + +from pydantic import ConfigDict, Field, field_validator + +from .transform_model import TransformModel + + +class N2VManipulateModel(TransformModel): + """ + Pydantic model used to represent N2V manipulation.
+ + Attributes + ---------- + name : Literal["N2VManipulate"] + Name of the transformation. + roi_size : int + Size of the masking region, by default 11. + masked_pixel_percentage : float + Percentage of masked pixels, by default 0.2. + strategy : Literal["uniform", "median"] + Strategy pixel value replacement, by default "uniform". + struct_mask_axis : Literal["horizontal", "vertical", "none"] + Axis of the structN2V mask, by default "none". + struct_mask_span : int + Span of the structN2V mask, by default 5. + """ + + model_config = ConfigDict( + validate_assignment=True, + ) + + name: Literal["N2VManipulate"] + roi_size: int = Field(default=11, ge=3, le=21) + masked_pixel_percentage: float = Field(default=0.2, ge=0.05, le=1.0) + strategy: Literal["uniform", "median"] = Field(default="uniform") + struct_mask_axis: Literal["horizontal", "vertical", "none"] = Field(default="none") + struct_mask_span: int = Field(default=5, ge=3, le=15) + + @field_validator("roi_size", "struct_mask_span") + @classmethod + def odd_value(cls, v: int) -> int: + """ + Validate that the value is odd. + + Parameters + ---------- + v : int + Value to validate. + + Returns + ------- + int + The validated value. + + Raises + ------ + ValueError + If the value is even. + """ + if v % 2 == 0: + raise ValueError("Size must be an odd number.") + return v diff --git a/src/careamics/config/transformations/nd_flip_model.py b/src/careamics/config/transformations/nd_flip_model.py new file mode 100644 index 00000000..9787a13a --- /dev/null +++ b/src/careamics/config/transformations/nd_flip_model.py @@ -0,0 +1,32 @@ +"""Pydantic model for the NDFlip transform.""" +from typing import Literal + +from pydantic import ConfigDict, Field + +from .transform_model import TransformModel + + +class NDFlipModel(TransformModel): + """ + Pydantic model used to represent NDFlip transformation. + + Attributes + ---------- + name : Literal["NDFlip"] + Name of the transformation. 
+ p : float + Probability of applying the transformation, by default 0.5. + is_3D : bool + Whether the transformation should be applied in 3D, by default False. + flip_z : bool + Whether to flip the z axis, by default True. + """ + + model_config = ConfigDict( + validate_assignment=True, + ) + + name: Literal["NDFlip"] + p: float = Field(default=0.5, ge=0.0, le=1.0) + is_3D: bool = Field(default=False) + flip_z: bool = Field(default=True) diff --git a/src/careamics/config/transformations/normalize_model.py b/src/careamics/config/transformations/normalize_model.py new file mode 100644 index 00000000..cc156dbf --- /dev/null +++ b/src/careamics/config/transformations/normalize_model.py @@ -0,0 +1,31 @@ +"""Pydantic model for the Normalize transform.""" +from typing import Literal + +from pydantic import ConfigDict, Field + +from .transform_model import TransformModel + + +class NormalizeModel(TransformModel): + """ + Pydantic model used to represent Normalize transformation. + + The Normalize transform is a zero mean and unit variance transformation. + + Attributes + ---------- + name : Literal["Normalize"] + Name of the transformation. + mean : float + Mean value for normalization. + std : float + Standard deviation value for normalization. + """ + + model_config = ConfigDict( + validate_assignment=True, + ) + + name: Literal["Normalize"] + mean: float = Field(default=0.485) # albumentations defaults + std: float = Field(default=0.229) diff --git a/src/careamics/config/transformations/transform_model.py b/src/careamics/config/transformations/transform_model.py new file mode 100644 index 00000000..ffbc6022 --- /dev/null +++ b/src/careamics/config/transformations/transform_model.py @@ -0,0 +1,44 @@ +"""Parent model for the transforms.""" +from typing import Any, Dict + +from pydantic import BaseModel, ConfigDict + + +class TransformModel(BaseModel): + """ + Pydantic model used to represent a transformation. 
+ + The `model_dump` method is overwritten to exclude the name field. + + Attributes + ---------- + name : str + Name of the transformation. + """ + + model_config = ConfigDict( + extra="forbid", # throw errors if the parameters are not properly passed + ) + + name: str + + def model_dump(self, **kwargs) -> Dict[str, Any]: + """ + Return the model as a dictionary. + + Parameters + ---------- + **kwargs + Pydantic BaseModel model_dump method keyword arguments. + + Returns + ------- + Dict[str, Any] + Dictionary representation of the model. + """ + model_dict = super().model_dump(**kwargs) + + # remove the name field + model_dict.pop("name") + + return model_dict diff --git a/src/careamics/config/transformations/xy_random_rotate90_model.py b/src/careamics/config/transformations/xy_random_rotate90_model.py new file mode 100644 index 00000000..af0cd142 --- /dev/null +++ b/src/careamics/config/transformations/xy_random_rotate90_model.py @@ -0,0 +1,29 @@ +"""Pydantic model for the XYRandomRotate90 transform.""" +from typing import Literal + +from pydantic import ConfigDict, Field + +from .transform_model import TransformModel + + +class XYRandomRotate90Model(TransformModel): + """ + Pydantic model used to represent XYRandomRotate90 transformation. + + Attributes + ---------- + name : Literal["XYRandomRotate90"] + Name of the transformation. + p : float + Probability of applying the transformation, by default 0.5. + is_3D : bool + Whether the transformation should be applied in 3D, by default False.
+ """ + + model_config = ConfigDict( + validate_assignment=True, + ) + + name: Literal["XYRandomRotate90"] + p: float = Field(default=0.5, ge=0.0, le=1.0) + is_3D: bool = Field(default=False) diff --git a/src/careamics/config/validators/__init__.py b/src/careamics/config/validators/__init__.py new file mode 100644 index 00000000..53ddbf8d --- /dev/null +++ b/src/careamics/config/validators/__init__.py @@ -0,0 +1,5 @@ +"""Validator utilities.""" + +__all__ = ["check_axes_validity", "patch_size_ge_than_8_power_of_2"] + +from .validator_utils import check_axes_validity, patch_size_ge_than_8_power_of_2 diff --git a/src/careamics/config/validators/validator_utils.py b/src/careamics/config/validators/validator_utils.py new file mode 100644 index 00000000..6743d249 --- /dev/null +++ b/src/careamics/config/validators/validator_utils.py @@ -0,0 +1,100 @@ +""" +Validator functions. + +These functions are used to validate dimensions and axes of inputs. +""" +from typing import List, Optional, Tuple, Union + +_AXES = "STCZYX" + + +def check_axes_validity(axes: str) -> None: + """ + Sanity check on axes. + + The constraints on the axes are the following: + - must be a combination of 'STCZYX' + - must not contain duplicates + - must contain at least 2 contiguous axes: X and Y + - must contain at most 6 axes + - cannot contain both S and T axes + + Axes do not need to be in the order 'STCZYX', as this depends on the user data. + + Parameters + ---------- + axes : str + Axes to validate. + """ + _axes = axes.upper() + + # Minimum is 2 (XY) and maximum is 6 (STCZYX) + if len(_axes) < 2 or len(_axes) > 6: + raise ValueError( + f"Invalid axes {axes}. Must contain at least 2 and at most 6 axes." + ) + + if "YX" not in _axes and "XY" not in _axes: + raise ValueError( + f"Invalid axes {axes}. Must contain at least X and Y axes consecutively." + ) + + # all characters must be in _AXES = 'STCZYX' + if not all(s in _AXES for s in _axes): + raise ValueError(f"Invalid axes {axes}.
Must be a combination of {_AXES}.") + + # check for repeating characters + for i, s in enumerate(_axes): + if i != _axes.rfind(s): + raise ValueError( + f"Invalid axes {axes}. Cannot contain duplicate axes" + f" (got multiple {axes[i]})." + ) + + +def value_ge_than_8_power_of_2( + value: int, +) -> None: + """ + Validate that the value is greater than or equal to 8 and a power of 2. + + Parameters + ---------- + value : int + Value to validate. + + Raises + ------ + ValueError + If the value is smaller than 8. + ValueError + If the value is not a power of 2. + """ + if value < 8: + raise ValueError(f"Value must be non-zero positive (got {value}).") + + if (value & (value - 1)) != 0: + raise ValueError(f"Value must be a power of 2 (got {value}).") + + +def patch_size_ge_than_8_power_of_2( + patch_list: Optional[Union[List[int], Union[Tuple[int, ...]]]], +) -> None: + """ + Validate that each entry is greater than or equal to 8 and a power of 2. + + Parameters + ---------- + patch_list : Optional[Union[List[int]]] + Patch size. + + Raises + ------ + ValueError + If the patch size is smaller than 8. + ValueError + If the patch size is not a power of 2. + """ + if patch_list is not None: + for dim in patch_list: + value_ge_than_8_power_of_2(dim) diff --git a/src/careamics/conftest.py b/src/careamics/conftest.py new file mode 100644 index 00000000..e0a1fae6 --- /dev/null +++ b/src/careamics/conftest.py @@ -0,0 +1,26 @@ +"""File used to discover python modules and run doctest.
+ +See https://sybil.readthedocs.io/en/latest/use.html#pytest +""" +from pathlib import Path + +import pytest +from pytest import TempPathFactory +from sybil import Sybil +from sybil.parsers.codeblock import PythonCodeBlockParser +from sybil.parsers.doctest import DocTestParser + + +@pytest.fixture(scope="module") +def my_path(tmpdir_factory: TempPathFactory) -> Path: + return tmpdir_factory.mktemp("my_path") + + +pytest_collect_file = Sybil( + parsers=[ + DocTestParser(), + PythonCodeBlockParser(future_imports=["print_function"]), + ], + pattern="*.py", + fixtures=["my_path"], +).pytest() diff --git a/src/careamics/dataset/__init__.py b/src/careamics/dataset/__init__.py index cf7c1ed0..b3c9cdba 100644 --- a/src/careamics/dataset/__init__.py +++ b/src/careamics/dataset/__init__.py @@ -1 +1,6 @@ """Dataset module.""" + +__all__ = ["InMemoryDataset", "PathIterableDataset"] + +from .in_memory_dataset import InMemoryDataset +from .iterable_dataset import PathIterableDataset diff --git a/src/careamics/dataset/dataset_utils.py b/src/careamics/dataset/dataset_utils.py deleted file mode 100644 index 2edb8ab4..00000000 --- a/src/careamics/dataset/dataset_utils.py +++ /dev/null @@ -1,111 +0,0 @@ -"""Convenience methods for datasets.""" -import logging -from pathlib import Path -from typing import List, Union - -import numpy as np -import tifffile - - -def list_files(data_path: Union[str, Path], data_format: str) -> List[Path]: - """ - Return a list of path to files in a directory. - - Parameters - ---------- - data_path : str - Path to the folder containing the data. - data_format : str - Extension of the files to load, without period, e.g. `tif`. - - Returns - ------- - List[Path] - List of pathlib.Path objects. - """ - files = sorted(Path(data_path).rglob(f"*.{data_format}*")) - return files - - -def _update_axes(array: np.ndarray, axes: str) -> np.ndarray: - """ - Update axes of the sample to match the config axes. - - This method concatenate the S and T axes. 
- - Parameters - ---------- - array : np.ndarray - Input array. - axes : str - Description of axes in format STCZYX. - - Returns - ------- - np.ndarray - Updated array. - """ - # concatenate ST axes to N, return NCZYX - if "S" in axes or "T" in axes: - new_axes_len = len(axes.replace("Z", "").replace("YX", "")) - # TODO test reshape as it can scramble data, moveaxis is probably better - array = array.reshape(-1, *array.shape[new_axes_len:]).astype(np.float32) - - else: - array = np.expand_dims(array, axis=0).astype(np.float32) - - return array - - -def read_tiff(file_path: Path, axes: str) -> np.ndarray: - """ - Read a tiff file and return a numpy array. - - Parameters - ---------- - file_path : Path - Path to a file. - axes : str - Description of axes in format STCZYX. - - Returns - ------- - np.ndarray - Resulting array. - - Raises - ------ - ValueError - If the file failed to open. - OSError - If the file failed to open. - ValueError - If the file is not a valid tiff. - ValueError - If the data dimensions are incorrect. - ValueError - If the axes length is incorrect. - """ - if file_path.suffix[:4] == ".tif": - try: - sample = tifffile.imread(file_path) - except (ValueError, OSError) as e: - logging.exception(f"Exception in file {file_path}: {e}, skipping it.") - raise e - else: - raise ValueError(f"File {file_path} is not a valid tiff.") - - sample = sample.squeeze() - - if len(sample.shape) < 2 or len(sample.shape) > 4: - raise ValueError( - f"Incorrect data dimensions. Must be 2, 3 or 4 (got {sample.shape} for" - f"file {file_path})." 
- ) - - # check number of axes - if len(axes) != len(sample.shape): - raise ValueError(f"Incorrect axes length (got {axes} for file {file_path}).") - sample = _update_axes(sample, axes) - - return sample diff --git a/src/careamics/dataset/dataset_utils/__init__.py b/src/careamics/dataset/dataset_utils/__init__.py new file mode 100644 index 00000000..01517da8 --- /dev/null +++ b/src/careamics/dataset/dataset_utils/__init__.py @@ -0,0 +1,19 @@ +"""Files and arrays utils used in the datasets.""" + + +__all__ = [ + "reshape_array", + "get_files_size", + "list_files", + "validate_source_target_files", + "read_tiff", + "get_read_func", + "read_zarr", +] + + +from .dataset_utils import reshape_array +from .file_utils import get_files_size, list_files, validate_source_target_files +from .read_tiff import read_tiff +from .read_utils import get_read_func +from .read_zarr import read_zarr diff --git a/src/careamics/dataset/dataset_utils/dataset_utils.py b/src/careamics/dataset/dataset_utils/dataset_utils.py new file mode 100644 index 00000000..ace44bc9 --- /dev/null +++ b/src/careamics/dataset/dataset_utils/dataset_utils.py @@ -0,0 +1,100 @@ +"""Convenience methods for datasets.""" +from typing import List, Tuple + +import numpy as np + +from careamics.utils.logging import get_logger + +logger = get_logger(__name__) + + +def _get_shape_order( + shape_in: Tuple[int, ...], axes_in: str, ref_axes: str = "STCZYX" +) -> Tuple[Tuple[int, ...], str, List[int]]: + """ + Compute a new shape for the array based on the reference axes. + + Parameters + ---------- + shape_in : Tuple + Input shape. + ref_axes : str + Reference axes. + axes_in : str + Input axes. + + Returns + ------- + Tuple[Tuple[int, ...], str, List[int]] + New shape, new axes, indices of axes in the new axes order. 
+ """ + indices = [axes_in.find(k) for k in ref_axes] + + # remove all non-existing axes (index == -1) + new_indices = list(filter(lambda k: k != -1, indices)) + + # find axes order and get new shape + new_axes = [axes_in[ind] for ind in new_indices] + new_shape = tuple([shape_in[ind] for ind in new_indices]) + + return new_shape, "".join(new_axes), new_indices + + +def reshape_array(x: np.ndarray, axes: str) -> np.ndarray: + """Reshape the data to (S, C, (Z), Y, X) by moving axes. + + If the data has both S and T axes, the two axes will be merged. A singleton + dimension is added if there are no C axis. + + Parameters + ---------- + x : np.ndarray + Input array. + axes : str + Description of axes in format `STCZYX`. + + Returns + ------- + np.ndarray + Reshaped array with shape (S, C, (Z), Y, X). + """ + _x = x + _axes = axes + + # sanity checks + if len(_axes) != len(_x.shape): + raise ValueError( + f"Incompatible data shape ({_x.shape}) and axes ({_axes}). Are the axes " + f"correct?" + ) + + # get new x shape + new_x_shape, new_axes, indices = _get_shape_order(_x.shape, _axes) + + # if S is not in the list of axes, then add a singleton S + if "S" not in new_axes: + new_axes = "S" + new_axes + _x = _x[np.newaxis, ...] 
+ new_x_shape = (1,) + new_x_shape + + # need to change the array of indices + indices = [0] + [1 + i for i in indices] + + # reshape by moving axes + destination = list(range(len(indices))) + _x = np.moveaxis(_x, indices, destination) + + # remove T if necessary + if "T" in new_axes: + new_x_shape = (-1,) + new_x_shape[2:] # remove T and S + new_axes = new_axes.replace("T", "") + + # reshape S and T together + _x = _x.reshape(new_x_shape) + + # add channel + if "C" not in new_axes: + # Add channel axis after S + _x = np.expand_dims(_x, new_axes.index("S") + 1) + + return _x diff --git a/src/careamics/dataset/dataset_utils/file_utils.py b/src/careamics/dataset/dataset_utils/file_utils.py new file mode 100644 index 00000000..67b65f40 --- /dev/null +++ b/src/careamics/dataset/dataset_utils/file_utils.py @@ -0,0 +1,140 @@ +from fnmatch import fnmatch +from pathlib import Path +from typing import List, Union + +import numpy as np + +from careamics.config.support import SupportedData +from careamics.utils.logging import get_logger + +logger = get_logger(__name__) + + +def get_files_size(files: List[Path]) -> float: + """ + Get files size in MB. + + Parameters + ---------- + files : List[Path] + List of files. + + Returns + ------- + float + Total size of the files in MB. + """ + return np.sum([f.stat().st_size / 1024**2 for f in files]) + + +def list_files( + data_path: Union[str, Path], + data_type: Union[str, SupportedData], + extension_filter: str = "", +) -> List[Path]: + """Creates a recursive list of files in `data_path`. + + If `data_path` is a file, its name is validated against the `data_type` using + `fnmatch`, and the method returns `data_path` itself. + + By default, if `data_type` is equal to `custom`, all files will be listed. To + further filter the files, use `extension_filter`. + + `extension_filter` must be compatible with `fnmatch` and `Path.rglob`, e.g. "*.npy" + or "*.czi". 
+ + Parameters + ---------- + data_path : Union[str, Path] + Path to the folder containing the data. + data_type : Union[str, SupportedData] + One of the supported data types (e.g. tif, custom). + extension_filter : str, optional + Extension filter, by default "". + + Returns + ------- + List[Path] + List of pathlib.Path objects. + + Raises + ------ + FileNotFoundError + If the data path does not exist. + ValueError + If the data path is empty or no files with the extension were found. + ValueError + If the file does not match the requested extension. + """ + # convert to Path + data_path = Path(data_path) + + # raise error if the path does not exist + if not data_path.exists(): + raise FileNotFoundError(f"Data path {data_path} does not exist.") + + # get extension compatible with fnmatch and rglob search + extension = SupportedData.get_extension(data_type) + + if data_type == SupportedData.CUSTOM and extension_filter != "": + extension = extension_filter + + # search recursively + if data_path.is_dir(): + # search recursively the path for files with the extension + files = sorted(data_path.rglob(extension)) + else: + # raise error if it has the wrong extension + if not fnmatch(str(data_path.absolute()), extension): + raise ValueError( + f"File {data_path} does not match the requested extension " + f'"{extension}".' + ) + + # save in list + files = [data_path] + + # raise error if no files were found + if len(files) == 0: + raise ValueError( + f'Data path {data_path} is empty or files with extension "{extension}" ' + f"were not found." + ) + + return files + + +def validate_source_target_files(src_files: List[Path], tar_files: List[Path]) -> None: + """ + Validate source and target path lists. + + The two lists should have the same number of files, and the filenames should match. + + Parameters + ---------- + src_files : List[Path] + List of source files. + tar_files : List[Path] + List of target files.
+ + Raises + ------ + ValueError + If the number of files in source and target folders is not the same. + ValueError + If some filenames in Train and target folders are not the same. + """ + # check equal length + if len(src_files) != len(tar_files): + raise ValueError( + f"The number of source files ({len(src_files)}) is not equal to the number " + f"of target files ({len(tar_files)})." + ) + + # check identical names + src_names = {f.name for f in src_files} + tar_names = {f.name for f in tar_files} + difference = src_names.symmetric_difference(tar_names) + + if len(difference) > 0: + raise ValueError(f"Source and target files have different names: {difference}.") diff --git a/src/careamics/dataset/dataset_utils/read_tiff.py b/src/careamics/dataset/dataset_utils/read_tiff.py new file mode 100644 index 00000000..7b4dd8e0 --- /dev/null +++ b/src/careamics/dataset/dataset_utils/read_tiff.py @@ -0,0 +1,61 @@ +import logging +from fnmatch import fnmatch +from pathlib import Path + +import numpy as np +import tifffile + +from careamics.config.support import SupportedData +from careamics.utils.logging import get_logger + +logger = get_logger(__name__) + + +def read_tiff(file_path: Path, *args: list, **kwargs: dict) -> np.ndarray: + """ + Read a tiff file and return a numpy array. + + Parameters + ---------- + file_path : Path + Path to a file. + axes : str + Description of axes in format STCZYX. + + Returns + ------- + np.ndarray + Resulting array. + + Raises + ------ + ValueError + If the file failed to open. + OSError + If the file failed to open. + ValueError + If the file is not a valid tiff. + ValueError + If the data dimensions are incorrect. + ValueError + If the axes length is incorrect. 
+ """ + if fnmatch(file_path.suffix, SupportedData.get_extension(SupportedData.TIFF)): + try: + array = tifffile.imread(file_path) + except (ValueError, OSError) as e: + logging.exception(f"Exception in file {file_path}: {e}, skipping it.") + raise e + else: + raise ValueError(f"File {file_path} is not a valid tiff.") + + # check dimensions + # TODO or should this really be done here? probably in the LightningDataModule + # TODO this should also be centralized somewhere else (validate_dimensions) + if len(array.shape) < 2 or len(array.shape) > 6: + raise ValueError( + f"Incorrect data dimensions. Must be 2, 3 or 4 (got {array.shape} for" + f"file {file_path})." + ) + + return array diff --git a/src/careamics/dataset/dataset_utils/read_utils.py b/src/careamics/dataset/dataset_utils/read_utils.py new file mode 100644 index 00000000..558626a8 --- /dev/null +++ b/src/careamics/dataset/dataset_utils/read_utils.py @@ -0,0 +1,25 @@ +from typing import Callable, Union + +from careamics.config.support import SupportedData + +from .read_tiff import read_tiff + + +def get_read_func(data_type: Union[SupportedData, str]) -> Callable: + """ + Get the read function for the data type. + + Parameters + ---------- + data_type : SupportedData + Data type. + + Returns + ------- + Callable + Read function. + """ + if data_type == SupportedData.TIFF: + return read_tiff + else: + raise NotImplementedError(f"Data type {data_type} is not supported.") diff --git a/src/careamics/dataset/dataset_utils/read_zarr.py b/src/careamics/dataset/dataset_utils/read_zarr.py new file mode 100644 index 00000000..5878a1cf --- /dev/null +++ b/src/careamics/dataset/dataset_utils/read_zarr.py @@ -0,0 +1,56 @@ +from typing import Union + +from zarr import Group, core, hierarchy, storage + + +def read_zarr( + zarr_source: Group, axes: str +) -> Union[core.Array, storage.DirectoryStore, hierarchy.Group]: + """Reads a file and returns a pointer. 
+ + Parameters + ---------- + file_path : Path + pathlib.Path object containing a path to a file + + Returns + ------- + np.ndarray + Pointer to zarr storage + + Raises + ------ + ValueError, OSError + if a file is not a valid tiff or damaged + ValueError + if data dimensions are not 2, 3 or 4 + ValueError + if axes parameter from config is not consistent with data dimensions + """ + if isinstance(zarr_source, hierarchy.Group): + array = zarr_source[0] + + elif isinstance(zarr_source, storage.DirectoryStore): + raise NotImplementedError("DirectoryStore not supported yet") + + elif isinstance(zarr_source, core.Array): + # array should be of shape (S, (C), (Z), Y, X), iterating over S ? + if zarr_source.dtype == "O": + raise NotImplementedError("Object type not supported yet") + else: + array = zarr_source + else: + raise ValueError(f"Unsupported zarr object type {type(zarr_source)}") + + # sanity check on dimensions + if len(array.shape) < 2 or len(array.shape) > 4: + raise ValueError( + f"Incorrect data dimensions. Must be 2, 3 or 4 (got {array.shape})." 
+ ) + + # sanity check on axes length + if len(axes) != len(array.shape): + raise ValueError(f"Incorrect axes length (got {axes}).") + + # arr = fix_axes(arr, axes) + return array diff --git a/src/careamics/dataset/in_memory_dataset.py b/src/careamics/dataset/in_memory_dataset.py index 69b6343b..becf3eaa 100644 --- a/src/careamics/dataset/in_memory_dataset.py +++ b/src/careamics/dataset/in_memory_dataset.py @@ -1,155 +1,356 @@ """In-memory dataset module.""" +from __future__ import annotations + +import copy from pathlib import Path -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, List, Optional, Tuple, Union import numpy as np -import torch - -from careamics.utils import normalize -from careamics.utils.logging import get_logger - -from .dataset_utils import ( - list_files, - read_tiff, +from torch.utils.data import Dataset + +from ..config import DataModel, InferenceModel +from ..config.tile_information import TileInformation +from ..utils.logging import get_logger +from .dataset_utils import read_tiff, reshape_array +from .patching.patch_transform import get_patch_transform +from .patching.patching import ( + prepare_patches_supervised, + prepare_patches_supervised_array, + prepare_patches_unsupervised, + prepare_patches_unsupervised_array, ) -from .extraction_strategy import ExtractionStrategy -from .patching import generate_patches +from .patching.tiled_patching import extract_tiles logger = get_logger(__name__) -class InMemoryDataset(torch.utils.data.Dataset): +class InMemoryDataset(Dataset): + """Dataset storing data in memory and allowing generating patches from it.""" + + def __init__( + self, + data_config: DataModel, + inputs: Union[np.ndarray, List[Path]], + data_target: Optional[Union[np.ndarray, List[Path]]] = None, + read_source_func: Callable = read_tiff, + **kwargs: Any, + ) -> None: + """ + Constructor. 
+ + # TODO + """ + self.data_config = data_config + self.inputs = inputs + self.data_target = data_target + self.axes = self.data_config.axes + self.patch_size = self.data_config.patch_size + + # read function + self.read_source_func = read_source_func + + # Generate patches + supervised = self.data_target is not None + patches = self._prepare_patches(supervised) + + # Add results to members + self.data, self.data_targets, computed_mean, computed_std = patches + + if not self.data_config.mean or not self.data_config.std: + self.mean, self.std = computed_mean, computed_std + logger.info(f"Computed dataset mean: {self.mean}, std: {self.std}") + + # if the transforms are not an instance of Compose + if self.data_config.has_transform_list(): + # update mean and std in configuration + # the object is mutable and should then be recorded in the CAREamist obj + self.data_config.set_mean_and_std(self.mean, self.std) + else: + self.mean, self.std = self.data_config.mean, self.data_config.std + + # get transforms + self.patch_transform = get_patch_transform( + patch_transforms=self.data_config.transforms, + with_target=self.data_target is not None, + ) + + def _prepare_patches( + self, supervised: bool + ) -> Tuple[np.ndarray, Optional[np.ndarray], float, float]: + """ + Iterate over data source and create an array of patches. + + Parameters + ---------- + supervised : bool + Whether the dataset is supervised or not. + + Returns + ------- + np.ndarray + Array of patches. 
+ """ + if supervised: + if isinstance(self.inputs, np.ndarray) and isinstance( + self.data_target, np.ndarray + ): + return prepare_patches_supervised_array( + self.inputs, + self.axes, + self.data_target, + self.patch_size, + ) + elif isinstance(self.inputs, list) and isinstance(self.data_target, list): + return prepare_patches_supervised( + self.inputs, + self.data_target, + self.axes, + self.patch_size, + self.read_source_func, + ) + else: + raise ValueError( + f"Data and target must be of the same type, either both numpy " + f"arrays or both lists of paths, got {type(self.inputs)} (data) " + f"and {type(self.data_target)} (target)." + ) + else: + if isinstance(self.inputs, np.ndarray): + return prepare_patches_unsupervised_array( + self.inputs, + self.axes, + self.patch_size, + ) + else: + return prepare_patches_unsupervised( + self.inputs, + self.axes, + self.patch_size, + self.read_source_func, + ) + + def __len__(self) -> int: + """ + Return the length of the dataset. + + Returns + ------- + int + Length of the dataset. + """ + return len(self.data) + + def __getitem__(self, index: int) -> Tuple[np.ndarray]: + """ + Return the patch corresponding to the provided index. + + Parameters + ---------- + index : int + Index of the patch to return. + + Returns + ------- + Tuple[np.ndarray] + Patch. + + Raises + ------ + ValueError + If dataset mean and std are not set. 
+ """ + patch = self.data[index] + + # if there is a target + if self.data_target is not None: + # get target + target = self.data_targets[index] + + # Albumentations requires Channel last + c_patch = np.moveaxis(patch, 0, -1) + c_target = np.moveaxis(target, 0, -1) + + # Apply transforms + transformed = self.patch_transform(image=c_patch, target=c_target) + + # move axes back + patch = np.moveaxis(transformed["image"], -1, 0) + target = np.moveaxis(transformed["target"], -1, 0) + + return patch, target + + elif self.data_config.has_n2v_manipulate(): + # Albumentations requires Channel last + patch = np.moveaxis(patch, 0, -1) + + # Apply transforms + transformed_patch = self.patch_transform(image=patch)["image"] + manip_patch, patch, mask = transformed_patch + + # move C axes back + manip_patch = np.moveaxis(manip_patch, -1, 0) + patch = np.moveaxis(patch, -1, 0) + mask = np.moveaxis(mask, -1, 0) + + return (manip_patch, patch, mask) + else: + raise ValueError( + "Something went wrong! No target provided (not supervised training) " + "and no N2V manipulation (no N2V training)." + ) + + def split_dataset( + self, + percentage: float = 0.1, + minimum_patches: int = 1, + ) -> InMemoryDataset: + """Split a new dataset away from the current one. + + This method is used to extract random validation patches from the dataset. + + Parameters + ---------- + percentage : float, optional + Percentage of patches to extract, by default 0.1. + minimum_patches : int, optional + Minimum number of patches to extract, by default 5. + + Returns + ------- + InMemoryDataset + New dataset with the extracted patches. + + Raises + ------ + ValueError + If `percentage` is not between 0 and 1. + ValueError + If `minimum_number` is not between 1 and the number of patches. 
+ """ + if percentage < 0 or percentage > 1: + raise ValueError(f"Percentage must be between 0 and 1, got {percentage}.") + + if minimum_patches < 1 or minimum_patches > len(self): + raise ValueError( + f"Minimum number of patches must be between 1 and " + f"{len(self)} (number of patches), got " + f"{minimum_patches}. Adjust the patch size or the minimum number of " + f"patches." + ) + + total_patches = len(self) + + # number of patches to extract (either percentage rounded or minimum number) + n_patches = max(round(total_patches * percentage), minimum_patches) + + # get random indices + indices = np.random.choice(total_patches, n_patches, replace=False) + + # extract patches + val_patches = self.data[indices] + + # remove patches from self.patch + self.data = np.delete(self.data, indices, axis=0) + + # same for targets + if self.data_targets is not None: + val_targets = self.data_targets[indices] + self.data_targets = np.delete(self.data_targets, indices, axis=0) + + # clone the dataset + dataset = copy.deepcopy(self) + + # reassign patches + dataset.data = val_patches + + # reassign targets + if self.data_targets is not None: + dataset.data_targets = val_targets + + return dataset + + +class InMemoryPredictionDataset(Dataset): """ Dataset storing data in memory and allowing generating patches from it. - Parameters - ---------- - data_path : Union[str, Path] - Path to the data, must be a directory. - data_format : str - Extension of the data files, without period. - axes : str - Description of axes in format STCZYX. - patch_extraction_method : ExtractionStrategies - Patch extraction strategy, as defined in extraction_strategy. - patch_size : Union[List[int], Tuple[int]] - Size of the patches along each axis, must be of dimension 2 or 3. - patch_overlap : Optional[Union[List[int], Tuple[int]]], optional - Overlap of the patches, must be of dimension 2 or 3, by default None. - mean : Optional[float], optional - Expected mean of the dataset, by default None. 
- std : Optional[float], optional - Expected standard deviation of the dataset, by default None. - patch_transform : Optional[Callable], optional - Patch transform to apply, by default None. - patch_transform_params : Optional[Dict], optional - Patch transform parameters, by default None. + # TODO """ def __init__( self, - data_path: Union[str, Path], - data_format: str, - axes: str, - patch_extraction_method: ExtractionStrategy, - patch_size: Union[List[int], Tuple[int]], - patch_overlap: Optional[Union[List[int], Tuple[int]]] = None, - mean: Optional[float] = None, - std: Optional[float] = None, - patch_transform: Optional[Callable] = None, - patch_transform_params: Optional[Dict] = None, + prediction_config: InferenceModel, + inputs: np.ndarray, + data_target: Optional[np.ndarray] = None, + read_source_func: Optional[Callable] = read_tiff, ) -> None: - """ - Constructor. + """Constructor. Parameters ---------- - data_path : Union[str, Path] - Path to the data, must be a directory. - data_format : str - Extension of the data files, without period. + array : np.ndarray + Array containing the data. axes : str Description of axes in format STCZYX. - patch_extraction_method : ExtractionStrategies - Patch extraction strategy, as defined in extraction_strategy. - patch_size : Union[List[int], Tuple[int]] - Size of the patches along each axis, must be of dimension 2 or 3. - patch_overlap : Optional[Union[List[int], Tuple[int]]], optional - Overlap of the patches, must be of dimension 2 or 3, by default None. - mean : Optional[float], optional - Expected mean of the dataset, by default None. - std : Optional[float], optional - Expected standard deviation of the dataset, by default None. - patch_transform : Optional[Callable], optional - Patch transform to apply, by default None. - patch_transform_params : Optional[Dict], optional - Patch transform parameters, by default None. Raises ------ ValueError If data_path is not a directory. 
""" - self.data_path = Path(data_path) - if not self.data_path.is_dir(): - raise ValueError("Path to data should be an existing folder.") - - self.data_format = data_format - self.axes = axes + self.pred_config = prediction_config + self.input_array = inputs + self.axes = self.pred_config.axes + self.tile_size = self.pred_config.tile_size + self.tile_overlap = self.pred_config.tile_overlap + self.mean = self.pred_config.mean + self.std = self.pred_config.std + self.data_target = data_target - self.patch_transform = patch_transform + # tiling only if both tile size and overlap are provided + self.tiling = self.tile_size is not None and self.tile_overlap is not None - self.files = list_files(self.data_path, self.data_format) - - self.patch_size = patch_size - self.patch_overlap = patch_overlap - self.patch_extraction_method = patch_extraction_method - self.patch_transform = patch_transform - self.patch_transform_params = patch_transform_params - - self.mean = mean - self.std = std + # read function + self.read_source_func = read_source_func # Generate patches - self.data, computed_mean, computed_std = self._prepare_patches() + self.data = self._prepare_tiles() + self.mean, self.std = self.pred_config.mean, self.pred_config.std - if not mean or not std: - self.mean, self.std = computed_mean, computed_std - logger.info(f"Computed dataset mean: {self.mean}, std: {self.std}") + # get transforms + self.patch_transform = get_patch_transform( + patch_transforms=self.pred_config.transforms, + with_target=self.data_target is not None, + ) - assert self.mean is not None - assert self.std is not None - - def _prepare_patches(self) -> Tuple[np.ndarray, float, float]: + def _prepare_tiles(self) -> List[Tuple[np.ndarray, TileInformation]]: """ Iterate over data source and create an array of patches. Returns ------- - np.ndarray - Array of patches. + List[XArrayTile] + List of tiles. 
""" - means, stds, num_samples = 0, 0, 0 - self.all_patches = [] - for filename in self.files: - sample = read_tiff(filename, self.axes) - means += sample.mean() - stds += np.std(sample) - num_samples += 1 - - # generate patches, return a generator - patches = generate_patches( - sample, - self.patch_extraction_method, - self.patch_size, - self.patch_overlap, + # reshape array + reshaped_sample = reshape_array(self.input_array, self.axes) + + if self.tiling: + # generate patches, which returns a generator + patch_generator = extract_tiles( + arr=reshaped_sample, + tile_size=self.tile_size, + overlaps=self.tile_overlap, ) + patches_list = list(patch_generator) - # convert generator to list and add to all_patches - self.all_patches.extend(list(patches)) + if len(patches_list) == 0: + raise ValueError("No tiles generated, ") - result_mean, result_std = means / num_samples, stds / num_samples - return np.concatenate(self.all_patches), result_mean, result_std + return patches_list + else: + array_shape = reshaped_sample.squeeze().shape + return [(reshaped_sample, TileInformation(array_shape=array_shape))] def __len__(self) -> int: """ @@ -160,10 +361,9 @@ def __len__(self) -> int: int Length of the dataset. """ - # convert to numpy array to convince mypy that it is not a generator - return sum(np.array(s).shape[0] for s in self.all_patches) + return len(self.data) - def __getitem__(self, index: int) -> Tuple[np.ndarray]: + def __getitem__(self, index: int) -> Tuple[np.ndarray, TileInformation]: """ Return the patch corresponding to the provided index. @@ -174,29 +374,18 @@ def __getitem__(self, index: int) -> Tuple[np.ndarray]: Returns ------- - Tuple[np.ndarray] - Patch. - - Raises - ------ - ValueError - If dataset mean and std are not set. + Tuple[np.ndarray, TileInformation] + Transformed patch. 
""" - patch = self.data[index].squeeze() + tile_array, tile_info = self.data[index] - if self.mean is not None and self.std is not None: - if isinstance(patch, tuple): - patch = normalize(img=patch[0], mean=self.mean, std=self.std) - patch = (patch, *patch[1:]) - else: - patch = normalize(img=patch, mean=self.mean, std=self.std) + # Albumentations requires channel last, use the XArrayTile array + patch = np.moveaxis(tile_array, 0, -1) - if self.patch_transform is not None: - # replace None self.patch_transform_params with empty dict - if self.patch_transform_params is None: - self.patch_transform_params = {} + # Apply transforms + transformed_patch = self.patch_transform(image=patch)["image"] - patch = self.patch_transform(patch, **self.patch_transform_params) - return patch - else: - raise ValueError("Dataset mean and std must be set before using it.") + # move C axes back + transformed_patch = np.moveaxis(transformed_patch, -1, 0) + + return transformed_patch, tile_info diff --git a/src/careamics/dataset/iterable_dataset.py b/src/careamics/dataset/iterable_dataset.py new file mode 100644 index 00000000..53df2e61 --- /dev/null +++ b/src/careamics/dataset/iterable_dataset.py @@ -0,0 +1,416 @@ +from __future__ import annotations + +import copy +from pathlib import Path +from typing import Any, Callable, Generator, List, Optional, Tuple, Union + +import numpy as np +from torch.utils.data import IterableDataset, get_worker_info + +from ..config import DataModel, InferenceModel +from ..config.tile_information import TileInformation +from ..utils.logging import get_logger +from .dataset_utils import read_tiff, reshape_array +from .patching import ( + get_patch_transform, +) +from .patching.random_patching import extract_patches_random +from .patching.tiled_patching import extract_tiles + +logger = get_logger(__name__) + + +class PathIterableDataset(IterableDataset): + """ + Dataset allowing extracting patches w/o loading whole data into memory. 
+ + Parameters + ---------- + data_path : Union[str, Path] + Path to the data, must be a directory. + axes : str + Description of axes in format STCZYX. + patch_extraction_method : Union[ExtractionStrategies, None] + Patch extraction strategy, as defined in extraction_strategy. + patch_size : Optional[Union[List[int], Tuple[int]]], optional + Size of the patches in each dimension, by default None. + patch_overlap : Optional[Union[List[int], Tuple[int]]], optional + Overlap of the patches in each dimension, by default None. + mean : Optional[float], optional + Expected mean of the dataset, by default None. + std : Optional[float], optional + Expected standard deviation of the dataset, by default None. + patch_transform : Optional[Callable], optional + Patch transform callable, by default None. + """ + + def __init__( + self, + data_config: Union[DataModel, InferenceModel], + src_files: List[Path], + target_files: Optional[List[Path]] = None, + read_source_func: Callable = read_tiff, + ) -> None: + self.data_config = data_config + self.data_files = src_files + self.target_files = target_files + self.data_config = data_config + self.read_source_func = read_source_func + + # compute mean and std over the dataset + if not data_config.mean or not data_config.std: + self.mean, self.std = self._calculate_mean_and_std() + + # if the transforms are not an instance of Compose + # Check if the data_config is an instance of DataModel or InferenceModel + # isinstance isn't working properly here + if hasattr(data_config, "has_transform_list"): + if data_config.has_transform_list(): + # update mean and std in configuration + # the object is mutable and should then be recorded in the CAREamist + data_config.set_mean_and_std(self.mean, self.std) + else: + data_config.set_mean_and_std(self.mean, self.std) + + else: + self.mean = data_config.mean + self.std = data_config.std + + # get transforms + self.patch_transform = get_patch_transform( + patch_transforms=data_config.transforms, + 
with_target=target_files is not None, + ) + + def _calculate_mean_and_std(self) -> Tuple[float, float]: + """ + Calculate mean and std of the dataset. + + Returns + ------- + Tuple[float, float] + Tuple containing mean and standard deviation. + """ + means, stds = 0, 0 + num_samples = 0 + + for sample, _ in self._iterate_over_files(): + means += sample.mean() + stds += sample.std() + num_samples += 1 + + if num_samples == 0: + raise ValueError("No samples found in the dataset.") + + result_mean = means / num_samples + result_std = stds / num_samples + + logger.info(f"Calculated mean and std for {num_samples} images") + logger.info(f"Mean: {result_mean}, std: {result_std}") + return result_mean, result_std + + def _iterate_over_files( + self, + ) -> Generator[Tuple[np.ndarray, Optional[np.ndarray]], None, None]: + """ + Iterate over data source and yield whole image. + + Yields + ------ + np.ndarray + Image. + """ + # When num_workers > 0, each worker process will have a different copy of the + # dataset object + # Configuring each copy independently to avoid having duplicate data returned + # from the workers + worker_info = get_worker_info() + worker_id = worker_info.id if worker_info is not None else 0 + num_workers = worker_info.num_workers if worker_info is not None else 1 + + # iterate over the files + for i, filename in enumerate(self.data_files): + # retrieve file corresponding to the worker id + if i % num_workers == worker_id: + try: + # read data + sample = self.read_source_func(filename, self.data_config.axes) + + # read target, if available + if self.target_files is not None: + if filename.name != self.target_files[i].name: + raise ValueError( + f"File {filename} does not match target file " + f"{self.target_files[i]}. Have you passed sorted " + f"arrays?" 
+ ) + + # read target + target = self.read_source_func( + self.target_files[i], self.data_config.axes + ) + + yield sample, target + else: + yield sample, None + + except Exception as e: + logger.error(f"Error reading file {filename}: {e}") + + def __iter__( + self, + ) -> Generator[Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]], None, None]: + """ + Iterate over data source and yield single patch. + + Yields + ------ + np.ndarray + Single patch. + """ + assert ( + self.mean is not None and self.std is not None + ), "Mean and std must be provided" + + # iterate over files + for sample_input, sample_target in self._iterate_over_files(): + reshaped_sample = reshape_array(sample_input, self.data_config.axes) + reshaped_target = ( + None + if sample_target is None + else reshape_array(sample_target, self.data_config.axes) + ) + + patches = extract_patches_random( + arr=reshaped_sample, + patch_size=self.data_config.patch_size, + target=reshaped_target, + ) + + # iterate over patches + # patches are tuples of (patch, target) if target is available + # or (patch, None) only if no target is available + # patch is of dimensions (C)ZYX + for patch_data in patches: + # if there is a target + if self.target_files is not None: + # Albumentations expects the channel dimension to be last + # Taking the first element because patch_data can include target + c_patch = np.moveaxis(patch_data[0], 0, -1) + c_target = np.moveaxis(patch_data[1], 0, -1) + + # apply the transform to the patch and the target + transformed = self.patch_transform( + image=c_patch, + target=c_target, + ) + + # move the axes back to the original position + c_patch = np.moveaxis(transformed["image"], -1, 0) + c_target = np.moveaxis(transformed["target"], -1, 0) + + yield (c_patch, c_target) + elif self.data_config.has_n2v_manipulate(): + # Albumentations expects the channel dimension to be last + # Taking the first element because patch_data can include target + patch = np.moveaxis(patch_data[0], 0, -1) + 
+ # apply transform + transformed = self.patch_transform(image=patch) + + # retrieve the output of ManipulateN2V + results = transformed["image"] + masked_patch: np.ndarray = results[0] + original_patch: np.ndarray = results[1] + mask: np.ndarray = results[2] + + # move C axes back + masked_patch = np.moveaxis(masked_patch, -1, 0) + original_patch = np.moveaxis(original_patch, -1, 0) + mask = np.moveaxis(mask, -1, 0) + + yield (masked_patch, original_patch, mask) + else: + raise ValueError( + "Something went wrong! Not target file (no supervised " + "training) and no N2V transform (no n2v training either)." + ) + + def get_number_of_files(self) -> int: + """ + Return the number of files in the dataset. + + Returns + ------- + int + Number of files in the dataset. + """ + return len(self.data_files) + + def split_dataset( + self, + percentage: float = 0.1, + minimum_number: int = 5, + ) -> PathIterableDataset: + """Split up dataset in two. + + Splits the datest sing a percentage of the data (files) to extract, or the + minimum number of the percentage is less than the minimum number. + + Parameters + ---------- + percentage : float, optional + Percentage of files to split up, by default 0.1 + minimum_number : int, optional + Minimum number of files to split up, by default 5 + + Returns + ------- + IterableDataset + Dataset containing the split data. + + Raises + ------ + ValueError + If the percentage is smaller than 0 or larger than 1. + ValueError + If the minimum number is smaller than 1 or larger than the number of files. + """ + if percentage < 0 or percentage > 1: + raise ValueError(f"Percentage must be between 0 and 1, got {percentage}.") + + if minimum_number < 1 or minimum_number > self.get_number_of_files(): + raise ValueError( + f"Minimum number of files must be between 1 and " + f"{self.get_number_of_files()} (number of files), got " + f"{minimum_number}." 
+ ) + + # compute number of files + total_files = self.get_number_of_files() + n_files = max(round(percentage * total_files), minimum_number) + + # get random indices + indices = np.random.choice(total_files, n_files, replace=False) + + # extract files + val_files = [self.data_files[i] for i in indices] + + # remove patches from self.patch + data_files = [] + for i, file in enumerate(self.data_files): + if i not in indices: + data_files.append(file) + self.data_files = data_files + + # same for targets + if self.target_files is not None: + val_target_files = [self.target_files[i] for i in indices] + + data_target_files = [] + for i, file in enumerate(self.target_files): + if i not in indices: + data_target_files.append(file) + self.target_files = data_target_files + + # clone the dataset + dataset = copy.deepcopy(self) + + # reassign patches + dataset.data_files = val_files + + # reassign targets + if self.target_files is not None: + dataset.target_files = val_target_files + + return dataset + + +class IterablePredictionDataset(PathIterableDataset): + """ + Dataset allowing extracting patches w/o loading whole data into memory. + + Parameters + ---------- + data_path : Union[str, Path] + Path to the data, must be a directory. + axes : str + Description of axes in format STCZYX. + mean : Optional[float], optional + Expected mean of the dataset, by default None. + std : Optional[float], optional + Expected standard deviation of the dataset, by default None. + patch_transform : Optional[Callable], optional + Patch transform callable, by default None. 
+ """ + + def __init__( + self, + prediction_config: InferenceModel, + src_files: List[Path], + read_source_func: Callable = read_tiff, + **kwargs: Any, + ) -> None: + super().__init__( + data_config=prediction_config, + src_files=src_files, + read_source_func=read_source_func, + ) + + self.prediction_config = prediction_config + self.axes = prediction_config.axes + self.tile_size = self.prediction_config.tile_size + self.tile_overlap = self.prediction_config.tile_overlap + self.read_source_func = read_source_func + + # tile only if both tile size and overlaps are provided + self.tile = self.tile_size is not None and self.tile_overlap is not None + + # get tta transforms + self.patch_transform = get_patch_transform( + patch_transforms=prediction_config.transforms, + with_target=False, + ) + + def __iter__( + self, + ) -> Generator[Tuple[np.ndarray, TileInformation], None, None]: + """ + Iterate over data source and yield single patch. + + Yields + ------ + np.ndarray + Single patch. + """ + assert ( + self.mean is not None and self.std is not None + ), "Mean and std must be provided" + + for sample, _ in self._iterate_over_files(): + # reshape array + reshaped_sample = reshape_array(sample, self.axes) + + if self.tile: + # generate patches, return a generator + patch_gen = extract_tiles( + arr=reshaped_sample, + tile_size=self.tile_size, + overlaps=self.tile_overlap, + ) + else: + # just wrap the sample in a generator with default tiling info + array_shape = reshaped_sample.squeeze().shape + patch_gen = ( + (reshaped_sample, TileInformation(array_shape=array_shape)) + for _ in range(1) + ) + + # apply transform to patches + for patch_array, tile_info in patch_gen: + # albumentations expects the channel dimension to be last + patch = np.moveaxis(patch_array, 0, -1) + transformed_patch = self.patch_transform(image=patch) + transformed_patch = np.moveaxis(transformed_patch["image"], -1, 0) + + yield transformed_patch, tile_info diff --git 
a/src/careamics/dataset/patching.py b/src/careamics/dataset/patching.py deleted file mode 100644 index 2860f51c..00000000 --- a/src/careamics/dataset/patching.py +++ /dev/null @@ -1,492 +0,0 @@ -""" -Tiling submodule. - -These functions are used to tile images into patches or tiles. -""" -import itertools -from typing import Generator, List, Optional, Tuple, Union - -import numpy as np -from skimage.util import view_as_windows - -from careamics.utils.logging import get_logger - -from .extraction_strategy import ExtractionStrategy - -logger = get_logger(__name__) - - -def _compute_number_of_patches( - arr: np.ndarray, patch_sizes: Union[List[int], Tuple[int, ...]] -) -> Tuple[int, ...]: - """ - Compute the number of patches that fit in each dimension. - - Array must have one dimension more than the patches (C dimension). - - Parameters - ---------- - arr : np.ndarray - Input array. - patch_sizes : Tuple[int] - Size of the patches. - - Returns - ------- - Tuple[int] - Number of patches in each dimension. - """ - n_patches = [ - np.ceil(arr.shape[i + 1] / patch_sizes[i]).astype(int) - for i in range(len(patch_sizes)) - ] - return tuple(n_patches) - - -def _compute_overlap( - arr: np.ndarray, patch_sizes: Union[List[int], Tuple[int, ...]] -) -> Tuple[int, ...]: - """ - Compute the overlap between patches in each dimension. - - Array must be of dimensions C(Z)YX, and patches must be of dimensions YX or ZYX. - If the array dimensions are divisible by the patch sizes, then the overlap is 0. - Otherwise, it is the result of the division rounded to the upper value. - - Parameters - ---------- - arr : np.ndarray - Input array 3 or 4 dimensions. - patch_sizes : Tuple[int] - Size of the patches. - - Returns - ------- - Tuple[int] - Overlap between patches in each dimension. 
- """ - n_patches = _compute_number_of_patches(arr, patch_sizes) - - overlap = [ - np.ceil( - np.clip(n_patches[i] * patch_sizes[i] - arr.shape[i + 1], 0, None) - / max(1, (n_patches[i] - 1)) - ).astype(int) - for i in range(len(patch_sizes)) - ] - return tuple(overlap) - - -def _compute_crop_and_stitch_coords_1d( - axis_size: int, tile_size: int, overlap: int -) -> Tuple[List[Tuple[int, int]], ...]: - """ - Compute the coordinates of each tile along an axis, given the overlap. - - Parameters - ---------- - axis_size : int - Length of the axis. - tile_size : int - Size of the tile for the given axis. - overlap : int - Size of the overlap for the given axis. - - Returns - ------- - Tuple[Tuple[int]] - Tuple of all coordinates for given axis. - """ - # Compute the step between tiles - step = tile_size - overlap - crop_coords = [] - stitch_coords = [] - overlap_crop_coords = [] - # Iterate over the axis with a certain step - for i in range(0, axis_size - overlap, step): - # Check if the tile fits within the axis - if i + tile_size <= axis_size: - # Add the coordinates to crop one tile - crop_coords.append((i, i + tile_size)) - # Add the pixel coordinates of the cropped tile in the original image space - stitch_coords.append( - ( - i + overlap // 2 if i > 0 else 0, - i + tile_size - overlap // 2 - if crop_coords[-1][1] < axis_size - else axis_size, - ) - ) - # Add the coordinates to crop the overlap from the prediction. 
- overlap_crop_coords.append( - ( - overlap // 2 if i > 0 else 0, - tile_size - overlap // 2 - if crop_coords[-1][1] < axis_size - else tile_size, - ) - ) - # If the tile does not fit within the axis, perform the abovementioned - # operations starting from the end of the axis - else: - # if (axis_size - tile_size, axis_size) not in crop_coords: - crop_coords.append((axis_size - tile_size, axis_size)) - last_tile_end_coord = stitch_coords[-1][1] - stitch_coords.append((last_tile_end_coord, axis_size)) - overlap_crop_coords.append( - (tile_size - (axis_size - last_tile_end_coord), tile_size) - ) - break - return crop_coords, stitch_coords, overlap_crop_coords - - -def _compute_patch_steps( - patch_sizes: Union[List[int], Tuple[int, ...]], overlaps: Tuple[int, ...] -) -> Tuple[int, ...]: - """ - Compute steps between patches. - - Parameters - ---------- - patch_sizes : Tuple[int] - Size of the patches. - overlaps : Tuple[int] - Overlap between patches. - - Returns - ------- - Tuple[int] - Steps between patches. - """ - steps = [ - min(patch_sizes[i] - overlaps[i], patch_sizes[i]) - for i in range(len(patch_sizes)) - ] - return tuple(steps) - - -def _compute_reshaped_view( - arr: np.ndarray, - window_shape: Tuple[int, ...], - step: Tuple[int, ...], - output_shape: Tuple[int, ...], -) -> np.ndarray: - """ - Compute reshaped views of an array, where views correspond to patches. - - Parameters - ---------- - arr : np.ndarray - Array from which the views are extracted. - window_shape : Tuple[int] - Shape of the views. - step : Tuple[int] - Steps between views. - output_shape : Tuple[int] - Shape of the output array. - - Returns - ------- - np.ndarray - Array with views dimension. 
- """ - rng = np.random.default_rng() - patches = view_as_windows(arr, window_shape=window_shape, step=step).reshape( - *output_shape - ) - rng.shuffle(patches, axis=0) - return patches - - -def _patches_sanity_check( - arr: np.ndarray, - patch_size: Union[List[int], Tuple[int, ...]], - is_3d_patch: bool, -) -> None: - """ - Check patch size and array compatibility. - - This method validates the patch sizes with respect to the array dimensions: - - The patch sizes must have one dimension fewer than the array (C dimension). - - Chack that patch sizes are smaller than array dimensions. - - Parameters - ---------- - arr : np.ndarray - Input array. - patch_size : Union[List[int], Tuple[int, ...]] - Size of the patches along each dimension of the array, except the first. - is_3d_patch : bool - Whether the patch is 3D or not. - - Raises - ------ - ValueError - If the patch size is not consistent with the array shape (one more array - dimension). - ValueError - If the patch size in Z is larger than the array dimension. - ValueError - If either of the patch sizes in X or Y is larger than the corresponding array - dimension. - """ - if len(patch_size) != len(arr.shape[1:]): - raise ValueError( - f"There must be a patch size for each spatial dimensions " - f"(got {patch_size} patches for dims {arr.shape})." - ) - - # Sanity checks on patch sizes versus array dimension - if is_3d_patch and patch_size[0] > arr.shape[-3]: - raise ValueError( - f"Z patch size is inconsistent with image shape " - f"(got {patch_size[0]} patches for dim {arr.shape[1]})." - ) - - if patch_size[-2] > arr.shape[-2] or patch_size[-1] > arr.shape[-1]: - raise ValueError( - f"At least one of YX patch dimensions is inconsistent with image shape " - f"(got {patch_size} patches for dims {arr.shape[-2:]})." 
- ) - - -# formerly : -# in dataloader.py#L52, 00d536c -def _extract_patches_sequential( - arr: np.ndarray, patch_size: Union[List[int], Tuple[int]] -) -> Generator[np.ndarray, None, None]: - """ - Generate patches from an array in a sequential manner. - - Array dimensions should be C(Z)YX, where C can be a singleton dimension. The patches - are generated sequentially and cover the whole array. - - Parameters - ---------- - arr : np.ndarray - Input image array. - patch_size : Tuple[int] - Patch sizes in each dimension. - - Returns - ------- - Generator[np.ndarray, None, None] - Generator of patches. - """ - # Patches sanity check - is_3d_patch = len(patch_size) == 3 - - _patches_sanity_check(arr, patch_size, is_3d_patch) - - # Compute overlap - overlaps = _compute_overlap(arr=arr, patch_sizes=patch_size) - - # Create view window and overlaps - window_steps = _compute_patch_steps(patch_sizes=patch_size, overlaps=overlaps) - - # Correct for first dimension for computing windowed views - window_shape = (1, *patch_size) - window_steps = (1, *window_steps) - - if is_3d_patch and patch_size[0] == 1: - output_shape = (-1,) + window_shape[1:] - else: - output_shape = (-1, *window_shape) - - # Generate a view of the input array containing pre-calculated number of patches - # in each dimension with overlap. - # Resulting array is resized to (n_patches, C, Z, Y, X) or (n_patches,C, Y, X) - patches = _compute_reshaped_view( - arr, window_shape=window_shape, step=window_steps, output_shape=output_shape - ) - logger.info(f"Extracted {patches.shape[0]} patches from input array.") - - # return a generator of patches - return (patches[i, ...] for i in range(patches.shape[0])) - - -def _extract_patches_random( - arr: np.ndarray, patch_size: Union[List[int], Tuple[int]] -) -> Generator[np.ndarray, None, None]: - """ - Generate patches from an array in a random manner. 
- - The method calculates how many patches the image can be divided into and then - extracts an equal number of random patches. - - Parameters - ---------- - arr : np.ndarray - Input image array. - patch_size : Tuple[int] - Patch sizes in each dimension. - - Yields - ------ - Generator[np.ndarray, None, None] - Generator of patches. - """ - is_3d_patch = len(patch_size) == 3 - - # Patches sanity check - _patches_sanity_check(arr, patch_size, is_3d_patch) - - rng = np.random.default_rng() - # shuffle the array along the first axis TODO do we need shuffling? - rng.shuffle(arr, axis=0) - - for sample_idx in range(arr.shape[0]): - sample = arr[sample_idx] - # calculate how many number of patches can image area be divided into - n_patches = np.ceil(np.prod(sample.shape) / np.prod(patch_size)).astype(int) - for _ in range(n_patches): - crop_coords = [ - rng.integers(0, arr.shape[i + 1] - patch_size[i]) - for i in range(len(patch_size)) - ] - patch = ( - sample[ - ( - ..., - *[ - slice(c, c + patch_size[i]) - for i, c in enumerate(crop_coords) - ], - ) - ] - .copy() - .astype(np.float32) - ) - yield patch - - -def _extract_tiles( - arr: np.ndarray, - tile_size: Union[List[int], Tuple[int]], - overlaps: Union[List[int], Tuple[int]], -) -> Generator: - """ - Generate tiles from the input array with specified overlap. - - The tiles cover the whole array. - - Parameters - ---------- - arr : np.ndarray - Array of shape (S, (Z), Y, X). - tile_size : Union[List[int], Tuple[int]] - Tile sizes in each dimension, of length 2 or 3. - overlaps : Union[List[int], Tuple[int]] - Overlap values in each dimension, of length 2 or 3. - - Yields - ------ - Generator - Tile generator that yields the tile with corresponding coordinates to stitch - back the tiles together. - """ - # Iterate over num samples (S) - for sample_idx in range(arr.shape[0]): - sample = arr[sample_idx] - - # Create an array of coordinates for cropping and stitching all axes. 
- # Shape: (axes, type_of_coord, tile_num, start/end coord) - crop_and_stitch_coords_list = [ - _compute_crop_and_stitch_coords_1d( - sample.shape[i], tile_size[i], overlaps[i] - ) - for i in range(len(tile_size)) - ] - - # Rearrange crop coordinates from a list of coordinate pairs per axis to a list - # grouped by type. - # For axis of size 35 and patch size of 32 compute_crop_and_stitch_coords_1d - # will output ([(0, 32), (3, 35)], [(0, 20), (20, 35)], [(0, 20), (17, 32)]), - # where the first list is crop coordinates for 1st axis. - all_crop_coords, all_stitch_coords, all_overlap_crop_coords = zip( - *crop_and_stitch_coords_list - ) - - # Iterate over generated coordinate pairs: - for tile_idx, (crop_coords, stitch_coords, overlap_crop_coords) in enumerate( - zip( - itertools.product(*all_crop_coords), - itertools.product(*all_stitch_coords), - itertools.product(*all_overlap_crop_coords), - ) - ): - tile = sample[(..., *[slice(c[0], c[1]) for c in list(crop_coords)])] - - # Check if we are at the end of the sample. - # To check that we compute the length of the array that contains all the - # tiles - if tile_idx == np.prod([len(axis) for axis in all_crop_coords]) - 1: - last_tile = True - else: - last_tile = False - yield ( - np.expand_dims(tile.astype(np.float32), 0), - last_tile, - arr.shape[1:], - overlap_crop_coords, - stitch_coords, - ) - - -def generate_patches( - sample: np.ndarray, - patch_extraction_method: ExtractionStrategy, - patch_size: Optional[Union[List[int], Tuple[int]]] = None, - patch_overlap: Optional[Union[List[int], Tuple[int]]] = None, -) -> Generator[np.ndarray, None, None]: - """ - Generate patches from a sample. - - Parameters - ---------- - sample : np.ndarray - Input array. - patch_extraction_method : ExtractionStrategies - Patch extraction method, as defined in extraction_strategy.ExtractionStrategy. - patch_size : Optional[Union[List[int], Tuple[int]]] - Size of the patches along each dimension of the array, except the first. 
- patch_overlap : Optional[Union[List[int], Tuple[int]]] - Overlap between patches. - - Returns - ------- - Generator[np.ndarray, None, None] - Generator yielding patches/tiles. - - Raises - ------ - ValueError - If overlap is not specified when using tiling. - ValueError - If patches is None. - """ - patches = None - - if patch_size is not None: - patches = None - - if patch_extraction_method == ExtractionStrategy.TILED: - if patch_overlap is None: - raise ValueError( - "Overlaps must be specified when using tiling (got None)." - ) - patches = _extract_tiles( - arr=sample, tile_size=patch_size, overlaps=patch_overlap - ) - - elif patch_extraction_method == ExtractionStrategy.SEQUENTIAL: - patches = _extract_patches_sequential(sample, patch_size=patch_size) - - else: - # random patching - patches = _extract_patches_random(sample, patch_size=patch_size) - - return patches - else: - # no patching, return a generator for the sample - return (sample for _ in range(1)) diff --git a/src/careamics/dataset/patching/__init__.py b/src/careamics/dataset/patching/__init__.py new file mode 100644 index 00000000..c9f5f219 --- /dev/null +++ b/src/careamics/dataset/patching/__init__.py @@ -0,0 +1,8 @@ +"""Patching and tiling functions.""" + + +__all__ = [ + "get_patch_transform", +] + +from .patch_transform import get_patch_transform diff --git a/src/careamics/dataset/patching/patch_transform.py b/src/careamics/dataset/patching/patch_transform.py new file mode 100644 index 00000000..15cde203 --- /dev/null +++ b/src/careamics/dataset/patching/patch_transform.py @@ -0,0 +1,44 @@ +from typing import List, Union + +import albumentations as Aug + +from careamics.config.data_model import TRANSFORMS_UNION +from careamics.transforms import get_all_transforms + + +# TODO add some explanations on how the additional_targets is used +def get_patch_transform( + patch_transforms: Union[List[TRANSFORMS_UNION], Aug.Compose], + with_target: bool, + normalize_mask: bool = True, +) -> Aug.Compose: + 
"""Return a pixel manipulation function.""" + # if we passed a Compose, we just return it + if isinstance(patch_transforms, Aug.Compose): + return patch_transforms + + # empty list of transforms is a NoOp + elif len(patch_transforms) == 0: + return Aug.Compose( + [Aug.NoOp()], + additional_targets={}, # TODO this part need be checked (wrt segmentation) + ) + + # else we have a list of transforms + else: + # retrieve all transforms + all_transforms = get_all_transforms() + + # instantiate all transforms + transforms = [ + all_transforms[transform.name](**transform.model_dump()) + for transform in patch_transforms + ] + + return Aug.Compose( + transforms, + # apply image aug to "target" + additional_targets={"target": "image"} + if (with_target and normalize_mask) # TODO check this + else {}, + ) diff --git a/src/careamics/dataset/patching/patching.py b/src/careamics/dataset/patching/patching.py new file mode 100644 index 00000000..def9ce25 --- /dev/null +++ b/src/careamics/dataset/patching/patching.py @@ -0,0 +1,212 @@ +""" +Tiling submodule. + +These functions are used to tile images into patches or tiles. +""" +from pathlib import Path +from typing import Callable, List, Tuple, Union + +import numpy as np + +from ...utils.logging import get_logger +from ..dataset_utils import reshape_array +from .sequential_patching import extract_patches_sequential + +logger = get_logger(__name__) + + +# called by in memory dataset +def prepare_patches_supervised( + train_files: List[Path], + target_files: List[Path], + axes: str, + patch_size: Union[List[int], Tuple[int]], + read_source_func: Callable, +) -> Tuple[np.ndarray, np.ndarray, float, float]: + """ + Iterate over data source and create an array of patches and corresponding targets. + + Returns + ------- + np.ndarray + Array of patches. 
+ """ + train_files.sort() + target_files.sort() + + means, stds, num_samples = 0, 0, 0 + all_patches, all_targets = [], [] + for train_filename, target_filename in zip(train_files, target_files): + try: + sample: np.ndarray = read_source_func(train_filename, axes) + target: np.ndarray = read_source_func(target_filename, axes) + means += sample.mean() + stds += sample.std() + num_samples += 1 + + # reshape array + sample = reshape_array(sample, axes) + target = reshape_array(target, axes) + + # generate patches, return a generator + patches, targets = extract_patches_sequential( + sample, patch_size=patch_size, target=target + ) + + # convert generator to list and add to all_patches + all_patches.append(patches) + + # ensure targets are not None (type checking) + if targets is not None: + all_targets.append(targets) + else: + raise ValueError(f"No target found for {target_filename}.") + + except Exception as e: + # emit warning and continue + logger.error(f"Failed to read {train_filename} or {target_filename}: {e}") + + # raise error if no valid samples found + if num_samples == 0 or len(all_patches) == 0: + raise ValueError( + f"No valid samples found in the input data: {train_files} and " + f"{target_files}." + ) + + result_mean, result_std = means / num_samples, stds / num_samples + + patch_array: np.ndarray = np.concatenate(all_patches, axis=0) + target_array: np.ndarray = np.concatenate(all_targets, axis=0) + logger.info(f"Extracted {patch_array.shape[0]} patches from input array.") + + return ( + patch_array, + target_array, + result_mean, + result_std, + ) + + +# called by in_memory_dataset +def prepare_patches_unsupervised( + train_files: List[Path], + axes: str, + patch_size: Union[List[int], Tuple[int]], + read_source_func: Callable, +) -> Tuple[np.ndarray, None, float, float]: + """ + Iterate over data source and create an array of patches. + + Returns + ------- + np.ndarray + Array of patches. 
+ """ + means, stds, num_samples = 0, 0, 0 + all_patches = [] + for filename in train_files: + try: + sample: np.ndarray = read_source_func(filename, axes) + means += sample.mean() + stds += sample.std() + num_samples += 1 + + # reshape array + sample = reshape_array(sample, axes) + + # generate patches, return a generator + patches, _ = extract_patches_sequential(sample, patch_size=patch_size) + + # convert generator to list and add to all_patches + all_patches.append(patches) + except Exception as e: + # emit warning and continue + logger.error(f"Failed to read {filename}: {e}") + + # raise error if no valid samples found + if num_samples == 0: + raise ValueError(f"No valid samples found in the input data: {train_files}.") + + result_mean, result_std = means / num_samples, stds / num_samples + + patch_array: np.ndarray = np.concatenate(all_patches) + logger.info(f"Extracted {patch_array.shape[0]} patches from input array.") + + return patch_array, _, result_mean, result_std # TODO return object? + + +# called on arrays by in memory dataset +def prepare_patches_supervised_array( + data: np.ndarray, + axes: str, + data_target: np.ndarray, + patch_size: Union[List[int], Tuple[int]], +) -> Tuple[np.ndarray, np.ndarray, float, float]: + """Iterate over data source and create an array of patches. + + This method expects an array of shape SC(Z)YX, where S and C can be singleton + dimensions. + + Patches returned are of shape SC(Z)YX, where S is now the patches dimension. + + Returns + ------- + np.ndarray + Array of patches. 
+ """ + # compute statistics + mean = data.mean() + std = data.std() + + # reshape array + reshaped_sample = reshape_array(data, axes) + reshaped_target = reshape_array(data_target, axes) + + # generate patches, return a generator + patches, patch_targets = extract_patches_sequential( + reshaped_sample, patch_size=patch_size, target=reshaped_target + ) + + if patch_targets is None: + raise ValueError("No target extracted.") + + logger.info(f"Extracted {patches.shape[0]} patches from input array.") + + return ( + patches, + patch_targets, + mean, + std, + ) + + +# called by in memory dataset +def prepare_patches_unsupervised_array( + data: np.ndarray, + axes: str, + patch_size: Union[List[int], Tuple[int]], +) -> Tuple[np.ndarray, None, float, float]: + """ + Iterate over data source and create an array of patches. + + This method expects an array of shape SC(Z)YX, where S and C can be singleton + dimensions. + + Patches returned are of shape SC(Z)YX, where S is now the patches dimension. + + Returns + ------- + np.ndarray + Array of patches. + """ + # calculate mean and std + mean = data.mean() + std = data.std() + + # reshape array + reshaped_sample = reshape_array(data, axes) + + # generate patches, return a generator + patches, _ = extract_patches_sequential(reshaped_sample, patch_size=patch_size) + + return patches, _, mean, std # TODO inelegant, replace by dataclass? 
diff --git a/src/careamics/dataset/patching/random_patching.py b/src/careamics/dataset/patching/random_patching.py new file mode 100644 index 00000000..c06c5bbd --- /dev/null +++ b/src/careamics/dataset/patching/random_patching.py @@ -0,0 +1,190 @@ +from typing import Generator, List, Optional, Tuple, Union + +import numpy as np +import zarr + +from .validate_patch_dimension import validate_patch_dimensions + + +# TOOD split in testable functions +def extract_patches_random( + arr: np.ndarray, + patch_size: Union[List[int], Tuple[int, ...]], + target: Optional[np.ndarray] = None, +) -> Generator[Tuple[np.ndarray, Optional[np.ndarray]], None, None]: + """ + Generate patches from an array in a random manner. + + The method calculates how many patches the image can be divided into and then + extracts an equal number of random patches. + + It returns a generator that yields the following: + + - patch: np.ndarray, dimension C(Z)YX. + - target_patch: np.ndarray, dimension C(Z)YX, if the target is present, None + otherwise. + + Parameters + ---------- + arr : np.ndarray + Input image array. + patch_size : Tuple[int] + Patch sizes in each dimension. + + Yields + ------ + Generator[np.ndarray, None, None] + Generator of patches. + """ + is_3d_patch = len(patch_size) == 3 + + # patches sanity check + validate_patch_dimensions(arr, patch_size, is_3d_patch) + + # Update patch size to encompass S and C dimensions + patch_size = [1, arr.shape[1], *patch_size] + + # random generator + rng = np.random.default_rng() + + # iterate over the number of samples (S or T) + for sample_idx in range(arr.shape[0]): + # get sample array + sample: np.ndarray = arr[sample_idx, ...] + + # same for target + if target is not None: + target_sample: np.ndarray = target[sample_idx, ...] 
+ + # calculate the number of patches + n_patches = np.ceil(np.prod(sample.shape) / np.prod(patch_size)).astype(int) + + # iterate over the number of patches + for _ in range(n_patches): + # get crop coordinates + crop_coords = [ + rng.integers(0, sample.shape[i] - patch_size[1:][i], endpoint=True) + for i in range(len(patch_size[1:])) + ] + + # extract patch + patch = ( + sample[ + ( + ..., # type: ignore + *[ # type: ignore + slice(c, c + patch_size[1:][i]) + for i, c in enumerate(crop_coords) + ], + ) + ] + .copy() + .astype(np.float32) + ) + + # same for target + if target is not None: + target_patch = ( + target_sample[ + ( + ..., # type: ignore + *[ # type: ignore + slice(c, c + patch_size[1:][i]) + for i, c in enumerate(crop_coords) + ], + ) + ] + .copy() + .astype(np.float32) + ) + # return patch and target patch + yield patch, target_patch + else: + # return patch + yield patch, None + + +def extract_patches_random_from_chunks( + arr: zarr.Array, + patch_size: Union[List[int], Tuple[int, ...]], + chunk_size: Union[List[int], Tuple[int, ...]], + chunk_limit: Optional[int] = None, +) -> Generator[np.ndarray, None, None]: + """ + Generate patches from an array in a random manner. + + The method calculates how many patches the image can be divided into and then + extracts an equal number of random patches. + + Parameters + ---------- + arr : np.ndarray + Input image array. + patch_size : Tuple[int] + Patch sizes in each dimension. + chunk_size : Tuple[int] + Chunk sizes to load from the. + + Yields + ------ + Generator[np.ndarray, None, None] + Generator of patches. 
+ """ + is_3d_patch = len(patch_size) == 3 + + # Patches sanity check + validate_patch_dimensions(arr, patch_size, is_3d_patch) + + rng = np.random.default_rng() + num_chunks = chunk_limit if chunk_limit else np.prod(arr._cdata_shape) + + # Iterate over num chunks in the array + for _ in range(num_chunks): + chunk_crop_coords = [ + rng.integers(0, max(0, arr.shape[i] - chunk_size[i]), endpoint=True) + for i in range(len(chunk_size)) + ] + chunk = arr[ + ( + ..., + *[slice(c, c + chunk_size[i]) for i, c in enumerate(chunk_crop_coords)], + ) + ].squeeze() + + # Add a singleton dimension if the chunk does not have a sample dimension + if len(chunk.shape) == len(patch_size): + chunk = np.expand_dims(chunk, axis=0) + + # Iterate over num samples (S) + for sample_idx in range(chunk.shape[0]): + spatial_chunk = chunk[sample_idx] + assert len(spatial_chunk.shape) == len( + patch_size + ), "Requested chunk shape is not equal to patch size" + + n_patches = np.ceil( + np.prod(spatial_chunk.shape) / np.prod(patch_size) + ).astype(int) + + # Iterate over the number of patches + for _ in range(n_patches): + patch_crop_coords = [ + rng.integers( + 0, spatial_chunk.shape[i] - patch_size[i], endpoint=True + ) + for i in range(len(patch_size)) + ] + patch = ( + spatial_chunk[ + ( + ..., + *[ + slice(c, c + patch_size[i]) + for i, c in enumerate(patch_crop_coords) + ], + ) + ] + .copy() + .astype(np.float32) + ) + yield patch diff --git a/src/careamics/dataset/patching/sequential_patching.py b/src/careamics/dataset/patching/sequential_patching.py new file mode 100644 index 00000000..a62f657a --- /dev/null +++ b/src/careamics/dataset/patching/sequential_patching.py @@ -0,0 +1,206 @@ +from typing import List, Optional, Tuple, Union + +import numpy as np +from skimage.util import view_as_windows + +from .validate_patch_dimension import validate_patch_dimensions + + +def _compute_number_of_patches( + arr_shape: Tuple[int, ...], patch_sizes: Union[List[int], Tuple[int, ...]] +) -> 
Tuple[int, ...]: + """ + Compute the number of patches that fit in each dimension. + + Parameters + ---------- + arr : Tuple[int, ...] + Shape of the input array. + patch_sizes : Tuple[int] + Shape of the patches. + + Returns + ------- + Tuple[int] + Number of patches in each dimension. + """ + if len(arr_shape) != len(patch_sizes): + raise ValueError( + f"Array shape {arr_shape} and patch size {patch_sizes} should have the " + f"same dimension, including singleton dimension for S and equal dimension " + f"for C." + ) + + try: + n_patches = [ + np.ceil(arr_shape[i] / patch_sizes[i]).astype(int) + for i in range(len(patch_sizes)) + ] + except IndexError as e: + raise ValueError( + f"Patch size {patch_sizes} is not compatible with array shape {arr_shape}" + ) from e + + return tuple(n_patches) + + +def _compute_overlap( + arr_shape: Tuple[int, ...], patch_sizes: Union[List[int], Tuple[int, ...]] +) -> Tuple[int, ...]: + """ + Compute the overlap between patches in each dimension. + + If the array dimensions are divisible by the patch sizes, then the overlap is + 0. Otherwise, it is the result of the division rounded to the upper value. + + Parameters + ---------- + arr : Tuple[int, ...] + Input array shape. + patch_sizes : Tuple[int] + Size of the patches. + + Returns + ------- + Tuple[int] + Overlap between patches in each dimension. + """ + n_patches = _compute_number_of_patches(arr_shape, patch_sizes) + + overlap = [ + np.ceil( + np.clip(n_patches[i] * patch_sizes[i] - arr_shape[i], 0, None) + / max(1, (n_patches[i] - 1)) + ).astype(int) + for i in range(len(patch_sizes)) + ] + return tuple(overlap) + + +def _compute_patch_steps( + patch_sizes: Union[List[int], Tuple[int, ...]], overlaps: Tuple[int, ...] +) -> Tuple[int, ...]: + """ + Compute steps between patches. + + Parameters + ---------- + patch_sizes : Tuple[int] + Size of the patches. + overlaps : Tuple[int] + Overlap between patches. + + Returns + ------- + Tuple[int] + Steps between patches. 
+ """ + steps = [ + min(patch_sizes[i] - overlaps[i], patch_sizes[i]) + for i in range(len(patch_sizes)) + ] + return tuple(steps) + + +# TODO why stack the target here and not on a different dimension before this function? +def _compute_patch_views( + arr: np.ndarray, + window_shape: List[int], + step: Tuple[int, ...], + output_shape: List[int], + target: Optional[np.ndarray] = None, +) -> np.ndarray: + """ + Compute views of an array corresponding to patches. + + Parameters + ---------- + arr : np.ndarray + Array from which the views are extracted. + window_shape : Tuple[int] + Shape of the views. + step : Tuple[int] + Steps between views. + output_shape : Tuple[int] + Shape of the output array. + + Returns + ------- + np.ndarray + Array with views dimension. + """ + rng = np.random.default_rng() + + if target is not None: + arr = np.stack([arr, target], axis=0) + window_shape = [arr.shape[0], *window_shape] + step = (arr.shape[0], *step) + output_shape = [arr.shape[0], -1, arr.shape[2], *output_shape[2:]] + + patches = view_as_windows(arr, window_shape=window_shape, step=step).reshape( + *output_shape + ) + if target is not None: + rng.shuffle(patches, axis=1) + else: + rng.shuffle(patches, axis=0) + return patches + + +def extract_patches_sequential( + arr: np.ndarray, + patch_size: Union[List[int], Tuple[int, ...]], + target: Optional[np.ndarray] = None, +) -> Tuple[np.ndarray, Optional[np.ndarray]]: + """ + Generate patches from an array in a sequential manner. + + Array dimensions should be SC(Z)YX, where S and C can be singleton dimensions. The + patches are generated sequentially and cover the whole array. + + Parameters + ---------- + arr : np.ndarray + Input image array. + patch_size : Tuple[int] + Patch sizes in each dimension. + + Returns + ------- + Generator[Tuple[np.ndarray, ...], None, None] + Generator of patches. 
+ """ + is_3d_patch = len(patch_size) == 3 + + # Patches sanity check + validate_patch_dimensions(arr, patch_size, is_3d_patch) + + # Update patch size to encompass S and C dimensions + patch_size = [1, arr.shape[1], *patch_size] + + # Compute overlap + overlaps = _compute_overlap(arr_shape=arr.shape, patch_sizes=patch_size) + + # Create view window and overlaps + window_steps = _compute_patch_steps(patch_sizes=patch_size, overlaps=overlaps) + + output_shape = [ + -1, + ] + patch_size[1:] + + # Generate a view of the input array containing pre-calculated number of patches + # in each dimension with overlap. + # Resulting array is resized to (n_patches, C, Z, Y, X) or (n_patches, C, Y, X) + patches = _compute_patch_views( + arr, + window_shape=patch_size, + step=window_steps, + output_shape=output_shape, + target=target, + ) + + if target is not None: + # target was concatenated to patches in _compute_reshaped_view + return (patches[0, ...], patches[1, ...]) # TODO in _compute_reshaped_view? + else: + return patches, None diff --git a/src/careamics/dataset/patching/tiled_patching.py b/src/careamics/dataset/patching/tiled_patching.py new file mode 100644 index 00000000..bd618dae --- /dev/null +++ b/src/careamics/dataset/patching/tiled_patching.py @@ -0,0 +1,158 @@ +import itertools +from typing import Generator, List, Tuple, Union + +import numpy as np + +from careamics.config.tile_information import TileInformation + + +def _compute_crop_and_stitch_coords_1d( + axis_size: int, tile_size: int, overlap: int +) -> Tuple[List[Tuple[int, ...]], ...]: + """ + Compute the coordinates of each tile along an axis, given the overlap. + + Parameters + ---------- + axis_size : int + Length of the axis. + tile_size : int + Size of the tile for the given axis. + overlap : int + Size of the overlap for the given axis. + + Returns + ------- + Tuple[Tuple[int, ...], ...] + Tuple of all coordinates for given axis. 
+ """ + # Compute the step between tiles + step = tile_size - overlap + crop_coords = [] + stitch_coords = [] + overlap_crop_coords = [] + + # Iterate over the axis with step + for i in range(0, max(1, axis_size - overlap), step): + # Check if the tile fits within the axis + if i + tile_size <= axis_size: + # Add the coordinates to crop one tile + crop_coords.append((i, i + tile_size)) + + # Add the pixel coordinates of the cropped tile in the original image space + stitch_coords.append( + ( + i + overlap // 2 if i > 0 else 0, + i + tile_size - overlap // 2 + if crop_coords[-1][1] < axis_size + else axis_size, + ) + ) + + # Add the coordinates to crop the overlap from the prediction. + overlap_crop_coords.append( + ( + overlap // 2 if i > 0 else 0, + tile_size - overlap // 2 + if crop_coords[-1][1] < axis_size + else tile_size, + ) + ) + + # If the tile does not fit within the axis, perform the abovementioned + # operations starting from the end of the axis + else: + # if (axis_size - tile_size, axis_size) not in crop_coords: + crop_coords.append((max(0, axis_size - tile_size), axis_size)) + last_tile_end_coord = stitch_coords[-1][1] if stitch_coords else 1 + stitch_coords.append((last_tile_end_coord, axis_size)) + overlap_crop_coords.append( + (tile_size - (axis_size - last_tile_end_coord), tile_size) + ) + break + return crop_coords, stitch_coords, overlap_crop_coords + + +def extract_tiles( + arr: np.ndarray, + tile_size: Union[List[int], Tuple[int, ...]], + overlaps: Union[List[int], Tuple[int, ...]], +) -> Generator[Tuple[np.ndarray, TileInformation], None, None]: + """ + Generate tiles from the input array with specified overlap. + + The tiles cover the whole array. The method returns a generator that yields + tuples of array and tile information, the latter includes whether + the tile is the last one, the coordinates of the overlap crop, and the coordinates + of the stitched tile. + + The array has shape C(Z)YX, where C can be a singleton. 
def extract_tiles(
    arr: np.ndarray,
    tile_size: Union[List[int], Tuple[int, ...]],
    overlaps: Union[List[int], Tuple[int, ...]],
) -> Generator[Tuple[np.ndarray, TileInformation], None, None]:
    """
    Yield overlapping tiles that together cover every sample of the array.

    Samples are taken one by one along the first axis of `arr`. For each sample,
    tile coordinates are computed independently for every spatial axis and all
    their combinations are visited. Each tile is yielded together with a
    `TileInformation` holding the squeezed sample shape, a flag marking the last
    tile of the sample, the coordinates used to crop the overlap from the tile
    and the coordinates of the cropped tile in the sample.

    Parameters
    ----------
    arr : np.ndarray
        Array of shape (S, C, (Z), Y, X).
    tile_size : Union[List[int], Tuple[int]]
        Tile sizes in each dimension, of length 2 or 3.
    overlaps : Union[List[int], Tuple[int]]
        Overlap values in each dimension, of length 2 or 3.

    Yields
    ------
    Generator[Tuple[np.ndarray, TileInformation], None, None]
        Tile generator, yields the tile and additional information.
    """
    for sample_idx in range(arr.shape[0]):
        sample: np.ndarray = arr[sample_idx, ...]

        # One (crop, stitch, overlap crop) triple of coordinate lists per spatial
        # axis; axis 0 of the sample is skipped (channel axis per the parameter
        # documentation).
        per_axis_coords = [
            _compute_crop_and_stitch_coords_1d(
                sample.shape[axis + 1], size, overlaps[axis]
            )
            for axis, size in enumerate(tile_size)
        ]

        # Regroup from "per axis" to "per coordinate type"
        crops_by_axis, stitches_by_axis, overlap_crops_by_axis = zip(
            *per_axis_coords
        )

        # Index of the final tile of this sample
        final_idx = int(np.prod([len(coords) for coords in crops_by_axis])) - 1

        # The three products are walked in lockstep, so the i-th entry of each
        # describes the same tile.
        per_tile_coords = zip(
            itertools.product(*crops_by_axis),
            itertools.product(*stitches_by_axis),
            itertools.product(*overlap_crops_by_axis),
        )

        for tile_idx, (crop, stitch, overlap_crop) in enumerate(per_tile_coords):
            # Slice the spatial axes only, leading axes are kept whole
            window = [slice(begin, end) for begin, end in crop]
            tile: np.ndarray = sample[(..., *window)]  # type: ignore

            tile_info = TileInformation(
                array_shape=sample.squeeze().shape,
                tiled=True,
                last_tile=tile_idx == final_idx,
                overlap_crop_coords=overlap_crop,
                stitch_coords=stitch,
            )

            yield tile, tile_info
def validate_patch_dimensions(
    arr: np.ndarray,
    patch_size: Union[List[int], Tuple[int, ...]],
    is_3d_patch: bool,
) -> None:
    """
    Check patch size and array compatibility.

    This method validates the patch sizes with respect to the array dimensions:

    - Patch must have two dimensions fewer than the array (S and C).
    - Patch sizes are not larger than the corresponding array dimensions.

    If one of these conditions is not met, a ValueError is raised.

    This method should be called after inputs have been resized.

    Parameters
    ----------
    arr : np.ndarray
        Input array.
    patch_size : Union[List[int], Tuple[int, ...]]
        Size of the patches along each spatial dimension of the array.
    is_3d_patch : bool
        Whether the patch is 3D or not.

    Raises
    ------
    ValueError
        If the number of patch sizes differs from the number of array dimensions
        minus two.
    ValueError
        If the patch size in Z is larger than the array dimension.
    ValueError
        If either of the patch sizes in X or Y is larger than the corresponding array
        dimension.
    """
    # The first two array axes are not patched, every other axis needs a patch size
    if len(patch_size) != len(arr.shape[2:]):
        raise ValueError(
            f"There must be a patch size for each spatial dimensions "
            f"(got {patch_size} patches for dims {arr.shape})."
        )

    # Sanity checks on patch sizes versus array dimension
    if is_3d_patch and patch_size[0] > arr.shape[-3]:
        # Report the Z dimension that was actually compared (third from last axis)
        raise ValueError(
            f"Z patch size is inconsistent with image shape "
            f"(got {patch_size[0]} patches for dim {arr.shape[-3]})."
        )

    if patch_size[-2] > arr.shape[-2] or patch_size[-1] > arr.shape[-1]:
        raise ValueError(
            f"At least one of YX patch dimensions is larger than the corresponding "
            f"image dimension (got {patch_size} patches for dims {arr.shape[-2:]})."
        )
-""" -from pathlib import Path -from typing import List, Optional, Union - -from careamics.config import Configuration -from careamics.manipulation import default_manipulate -from careamics.utils import check_tiling_validity - -from .extraction_strategy import ExtractionStrategy -from .in_memory_dataset import InMemoryDataset -from .tiff_dataset import TiffDataset - - -def get_train_dataset( - config: Configuration, train_path: str -) -> Union[TiffDataset, InMemoryDataset]: - """ - Create training dataset. - - Depending on the configuration, this methods return either a TiffDataset or an - InMemoryDataset. - - Parameters - ---------- - config : Configuration - Configuration. - train_path : Union[str, Path] - Path to training data. - - Returns - ------- - Union[TiffDataset, InMemoryDataset] - Dataset. - """ - if config.data.in_memory: - dataset = InMemoryDataset( - data_path=train_path, - data_format=config.data.data_format, - axes=config.data.axes, - mean=config.data.mean, - std=config.data.std, - patch_extraction_method=ExtractionStrategy.SEQUENTIAL, - patch_size=config.training.patch_size, - patch_transform=default_manipulate, - patch_transform_params={ - "mask_pixel_percentage": config.algorithm.masked_pixel_percentage, - "roi_size": config.algorithm.roi_size, - }, - ) - else: - dataset = TiffDataset( - data_path=train_path, - data_format=config.data.data_format, - axes=config.data.axes, - mean=config.data.mean, - std=config.data.std, - patch_extraction_method=ExtractionStrategy.RANDOM, - patch_size=config.training.patch_size, - patch_transform=default_manipulate, - patch_transform_params={ - "mask_pixel_percentage": config.algorithm.masked_pixel_percentage, - "roi_size": config.algorithm.roi_size, - }, - ) - return dataset - - -def get_validation_dataset(config: Configuration, val_path: str) -> InMemoryDataset: - """ - Create validation dataset. - - Validation dataset is kept in memory. - - Parameters - ---------- - config : Configuration - Configuration. 
- val_path : Union[str, Path] - Path to validation data. - - Returns - ------- - TiffDataset - In memory dataset. - """ - data_path = val_path - - dataset = InMemoryDataset( - data_path=data_path, - data_format=config.data.data_format, - axes=config.data.axes, - mean=config.data.mean, - std=config.data.std, - patch_extraction_method=ExtractionStrategy.SEQUENTIAL, - patch_size=config.training.patch_size, - patch_transform=default_manipulate, - patch_transform_params={ - "mask_pixel_percentage": config.algorithm.masked_pixel_percentage - }, - ) - - return dataset - - -def get_prediction_dataset( - config: Configuration, - pred_path: Union[str, Path], - *, - tile_shape: Optional[List[int]] = None, - overlaps: Optional[List[int]] = None, - axes: Optional[str] = None, -) -> TiffDataset: - """ - Create prediction dataset. - - To use tiling, both `tile_shape` and `overlaps` must be specified, have same - length, be divisible by 2 and greater than 0. Finally, the overlaps must be - smaller than the tiles. - - By default, axes are extracted from the configuration. To use images with - different axes, set the `axes` parameter. Note that the difference between - configuration and parameter axes must be S or T, but not any of the spatial - dimensions (e.g. 2D vs 3D). - - Parameters - ---------- - config : Configuration - Configuration. - pred_path : Union[str, Path] - Path to prediction data. - tile_shape : Optional[List[int]], optional - 2D or 3D shape of the tiles, by default None. - overlaps : Optional[List[int]], optional - 2D or 3D overlaps between tiles, by default None. - axes : Optional[str], optional - Axes of the data, by default None. - - Returns - ------- - TiffDataset - Dataset. 
- """ - use_tiling = False # default value - - # Validate tiles and overlaps - if tile_shape is not None and overlaps is not None: - check_tiling_validity(tile_shape, overlaps) - - # Use tiling - use_tiling = True - - # Extraction strategy - if use_tiling: - patch_extraction_method = ExtractionStrategy.TILED - else: - patch_extraction_method = None - - # Create dataset - dataset = TiffDataset( - data_path=pred_path, - data_format=config.data.data_format, - axes=config.data.axes if axes is None else axes, # supersede axes - mean=config.data.mean, - std=config.data.std, - patch_size=tile_shape, - patch_overlap=overlaps, - patch_extraction_method=patch_extraction_method, - patch_transform=None, - ) - - return dataset diff --git a/src/careamics/dataset/tiff_dataset.py b/src/careamics/dataset/tiff_dataset.py deleted file mode 100644 index f8a81d7d..00000000 --- a/src/careamics/dataset/tiff_dataset.py +++ /dev/null @@ -1,212 +0,0 @@ -""" -Tiff dataset module. - -This module contains the implementation of the TiffDataset class, which allows loading -tiff files. -""" -from pathlib import Path -from typing import Callable, Dict, Generator, List, Optional, Tuple, Union - -import numpy as np -import torch - -from careamics.utils import normalize -from careamics.utils.logging import get_logger - -from .dataset_utils import ( - list_files, - read_tiff, -) -from .extraction_strategy import ExtractionStrategy -from .patching import generate_patches - -logger = get_logger(__name__) - - -class TiffDataset(torch.utils.data.IterableDataset): - """ - Dataset allowing extracting patches from tiff images. - - Parameters - ---------- - data_path : Union[str, Path] - Path to the data, must be a directory. - data_format : str - Extension of the files to load, without the period. - axes : str - Description of axes in format STCZYX. - patch_extraction_method : Union[ExtractionStrategies, None] - Patch extraction strategy, as defined in extraction_strategy. 
- patch_size : Optional[Union[List[int], Tuple[int]]], optional - Size of the patches in each dimension, by default None. - patch_overlap : Optional[Union[List[int], Tuple[int]]], optional - Overlap of the patches in each dimension, by default None. - mean : Optional[float], optional - Expected mean of the dataset, by default None. - std : Optional[float], optional - Expected standard deviation of the dataset, by default None. - patch_transform : Optional[Callable], optional - Patch transform callable, by default None. - patch_transform_params : Optional[Dict], optional - Patch transform parameters, by default None. - """ - - def __init__( - self, - data_path: Union[str, Path], - data_format: str, # TODO: TiffDataset should not know that they are tiff - axes: str, - patch_extraction_method: Union[ExtractionStrategy, None], - patch_size: Optional[Union[List[int], Tuple[int]]] = None, - patch_overlap: Optional[Union[List[int], Tuple[int]]] = None, - mean: Optional[float] = None, - std: Optional[float] = None, - patch_transform: Optional[Callable] = None, - patch_transform_params: Optional[Dict] = None, - ) -> None: - """ - Constructor. - - Parameters - ---------- - data_path : Union[str, Path] - Path to the data, must be a directory. - data_format : str - Extension of the files to load, without the period. - axes : str - Description of axes in format STCZYX. - patch_extraction_method : Union[ExtractionStrategies, None] - Patch extraction strategy, as defined in extraction_strategy. - patch_size : Optional[Union[List[int], Tuple[int]]], optional - Size of the patches in each dimension, by default None. - patch_overlap : Optional[Union[List[int], Tuple[int]]], optional - Overlap of the patches in each dimension, by default None. - mean : Optional[float], optional - Mean of the dataset, by default None. - std : Optional[float], optional - Standard deviation of the dataset, by default None. 
- patch_transform : Optional[Callable], optional - Patch transform callable, by default None. - patch_transform_params : Optional[Dict], optional - Patch transform parameters, by default None. - - Raises - ------ - ValueError - If data_path is not a directory. - """ - self.data_path = Path(data_path) - if not self.data_path.is_dir(): - raise ValueError("Path to data should be an existing folder.") - - self.data_format = data_format - self.axes = axes - - self.patch_transform = patch_transform - - self.files = list_files(self.data_path, self.data_format) - - self.mean = mean - self.std = std - if not mean or not std: - self.mean, self.std = self._calculate_mean_and_std() - - self.patch_size = patch_size - self.patch_overlap = patch_overlap - self.patch_extraction_method = patch_extraction_method - self.patch_transform = patch_transform - self.patch_transform_params = patch_transform_params - - def _calculate_mean_and_std(self) -> Tuple[float, float]: - """ - Calculate mean and std of the dataset. - - Returns - ------- - Tuple[float, float] - Tuple containing mean and standard deviation. - """ - means, stds = 0, 0 - num_samples = 0 - - for sample in self._iterate_files(): - means += sample.mean() - stds += np.std(sample) - num_samples += 1 - - result_mean = means / num_samples - result_std = stds / num_samples - - logger.info(f"Calculated mean and std for {num_samples} images") - logger.info(f"Mean: {result_mean}, std: {result_std}") - return result_mean, result_std - - def _iterate_files(self) -> Generator: - """ - Iterate over data source and yield whole image. - - Yields - ------ - np.ndarray - Image. 
- """ - # When num_workers > 0, each worker process will have a different copy of the - # dataset object - # Configuring each copy independently to avoid having duplicate data returned - # from the workers - worker_info = torch.utils.data.get_worker_info() - worker_id = worker_info.id if worker_info is not None else 0 - num_workers = worker_info.num_workers if worker_info is not None else 1 - - for i, filename in enumerate(self.files): - if i % num_workers == worker_id: - sample = read_tiff(filename, self.axes) - yield sample - - def __iter__(self) -> Generator[np.ndarray, None, None]: - """ - Iterate over data source and yield single patch. - - Yields - ------ - np.ndarray - Single patch. - """ - assert ( - self.mean is not None and self.std is not None - ), "Mean and std must be provided" - for sample in self._iterate_files(): - # TODO patch_extraction_method should never be None! - if self.patch_extraction_method: - # TODO: move S and T unpacking logic from patch generator - patches = generate_patches( - sample, - self.patch_extraction_method, - self.patch_size, - self.patch_overlap, - ) - - for patch in patches: - if isinstance(patch, tuple): - normalized_patch = normalize( - img=patch[0], mean=self.mean, std=self.std - ) - patch = (normalized_patch, *patch[1:]) - else: - patch = normalize(img=patch, mean=self.mean, std=self.std) - - if self.patch_transform is not None: - assert self.patch_transform_params is not None - patch = self.patch_transform( - patch, **self.patch_transform_params - ) - - yield patch - - else: - # if S or T dims are not empty - assume every image is a separate - # sample in dim 0 - for i in range(sample.shape[0]): - item = np.expand_dims(sample[i], (0, 1)) - item = normalize(img=item, mean=self.mean, std=self.std) - yield item diff --git a/src/careamics/dataset/zarr_dataset.py b/src/careamics/dataset/zarr_dataset.py new file mode 100644 index 00000000..ee54fdd2 --- /dev/null +++ b/src/careamics/dataset/zarr_dataset.py @@ -0,0 +1,149 @@ 
+# from itertools import islice +# from typing import Callable, Dict, List, Optional, Tuple, Union + +# import numpy as np +# import torch +# import zarr + +# from careamics.utils import RunningStats +# from careamics.utils.logging import get_logger + +# from ..utils import normalize +# from .dataset_utils import read_zarr +# from .patching.patching import ( +# generate_patches_unsupervised, +# ) + +# logger = get_logger(__name__) + + +# class ZarrDataset(torch.utils.data.IterableDataset): +# """Dataset to extract patches from a zarr storage. + +# Parameters +# ---------- +# data_source : Union[zarr.Group, zarr.Array] +# Zarr storage. +# axes : str +# Description of axes in format STCZYX. +# patch_extraction_method : Union[ExtractionStrategies, None] +# Patch extraction strategy, as defined in extraction_strategy. +# patch_size : Optional[Union[List[int], Tuple[int]]], optional +# Size of the patches in each dimension, by default None. +# num_patches : Optional[int], optional +# Number of patches to extract, by default None. +# mean : Optional[float], optional +# Expected mean of the dataset, by default None. +# std : Optional[float], optional +# Expected standard deviation of the dataset, by default None. +# patch_transform : Optional[Callable], optional +# Patch transform callable, by default None. +# patch_transform_params : Optional[Dict], optional +# Patch transform parameters, by default None. +# running_stats_window_perc : float, optional +# Percentage of the dataset to use for calculating the initial mean and standard +# deviation, by default 0.01. +# mode : str, optional +# train/predict, controls running stats calculation. 
+# """ + +# def __init__( +# self, +# data_source: Union[zarr.Group, zarr.Array], +# axes: str, +# patch_extraction_method: Union[SupportedExtractionStrategy, None], +# patch_size: Optional[Union[List[int], Tuple[int]]] = None, +# num_patches: Optional[int] = None, +# mean: Optional[float] = None, +# std: Optional[float] = None, +# patch_transform: Optional[Callable] = None, +# patch_transform_params: Optional[Dict] = None, +# running_stats_window_perc: float = 0.01, +# mode: str = "train", +# ) -> None: +# self.data_source = data_source +# self.axes = axes +# self.patch_extraction_method = patch_extraction_method +# self.patch_size = patch_size +# self.num_patches = num_patches +# self.mean = mean +# self.std = std +# self.patch_transform = patch_transform +# self.patch_transform_params = patch_transform_params +# self.sample = read_zarr(self.data_source, self.axes) +# self.running_stats_window = int( +# np.prod(self.sample._cdata_shape) * running_stats_window_perc +# ) +# self.mode = mode +# self.running_stats = RunningStats() + +# self._calculate_initial_mean_std() + +# def _calculate_initial_mean_std(self): +# """Calculate initial mean and std of the dataset.""" +# if self.mean is None and self.std is None: +# idxs = np.random.randint( +# 0, +# np.prod(self.sample._cdata_shape), +# size=max(1, self.running_stats_window), +# ) +# random_chunks = self.sample[idxs] +# self.running_stats.init(random_chunks.mean(), random_chunks.std()) + +# def _generate_patches(self): +# """Generate patches from the dataset and calculates running stats. + +# Yields +# ------ +# np.ndarray +# Patch. 
+# """ +# patches = generate_patches_unsupervised( +# self.sample, +# self.patch_extraction_method, +# self.patch_size, +# ) + +# # num_patches = np.ceil( +# # np.prod(self.sample.chunks) +# # / (np.prod(self.patch_size) * self.running_stats_window) +# # ).astype(int) + +# for idx, patch in enumerate(patches): +# if self.mode != "predict": +# self.running_stats.update(patch.mean()) +# if isinstance(patch, tuple): +# normalized_patch = normalize( +# img=patch[0], +# mean=self.running_stats.avg_mean.value, +# std=self.running_stats.avg_std.value, +# ) +# patch = (normalized_patch, *patch[1:]) +# else: +# patch = normalize( +# img=patch, +# mean=self.running_stats.avg_mean.value, +# std=self.running_stats.avg_std.value, +# ) + +# if self.patch_transform is not None: +# assert self.patch_transform_params is not None +# patch = self.patch_transform(patch, **self.patch_transform_params) +# if self.num_patches is not None and idx >= self.num_patches: +# return +# else: +# yield patch +# self.mean = self.running_stats.avg_mean.value +# self.std = self.running_stats.avg_std.value + +# def __iter__(self): +# """ +# Iterate over data source and yield single patch. + +# Yields +# ------ +# np.ndarray +# """ +# worker_info = torch.utils.data.get_worker_info() +# num_workers = worker_info.num_workers if worker_info is not None else 1 +# yield from islice(self._generate_patches(), 0, None, num_workers) diff --git a/src/careamics/engine.py b/src/careamics/engine.py deleted file mode 100644 index 2aa78571..00000000 --- a/src/careamics/engine.py +++ /dev/null @@ -1,1014 +0,0 @@ -""" -Engine module. - -This module contains the main CAREamics class, the Engine. The Engine allows training -a model and using it for prediction. 
-""" -from logging import FileHandler -from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Union - -import numpy as np -import torch -from torch.utils.data import DataLoader, TensorDataset - -from .bioimage import ( - get_default_model_specs, - save_bioimage_model, -) -from .config import Configuration, load_configuration -from .dataset.prepare_dataset import ( - get_prediction_dataset, - get_train_dataset, - get_validation_dataset, -) -from .losses import create_loss_function -from .models import create_model -from .prediction import ( - stitch_prediction, - tta_backward, - tta_forward, -) -from .utils import ( - MetricTracker, - add_axes, - denormalize, - get_device, - normalize, -) -from .utils.logging import ProgressBar, get_logger - - -class Engine: - """ - Class allowing training of a model and subsequent prediction. - - There are three ways to instantiate an Engine: - 1. With a CAREamics model (.pth), by passing a path. - 2. With a configuration object. - 3. With a configuration file, by passing a path. - - In each case, the parameter name must be provided explicitly. For example: - >>> engine = Engine(config_path="path/to/config.yaml") - - Note that only one of these options can be used at a time, in the order listed - above. - - Parameters - ---------- - config : Optional[Configuration], optional - Configuration object, by default None. - config_path : Optional[Union[str, Path]], optional - Path to configuration file, by default None. - model_path : Optional[Union[str, Path]], optional - Path to model file, by default None. - seed : int, optional - Seed for reproducibility, by default 42. - - Attributes - ---------- - cfg : Configuration - Configuration. - device : torch.device - Device (CPU or GPU). - model : torch.nn.Module - Model. - optimizer : torch.optim.Optimizer - Optimizer. - lr_scheduler : torch.optim.lr_scheduler._LRScheduler - Learning rate scheduler. - scaler : torch.cuda.amp.GradScaler - Gradient scaler. 
- loss_func : Callable - Loss function. - logger : logging.Logger - Logger. - use_wandb : bool - Whether to use wandb. - """ - - def __init__( - self, - *, - config: Optional[Configuration] = None, - config_path: Optional[Union[str, Path]] = None, - model_path: Optional[Union[str, Path]] = None, - seed: Optional[int] = 42, - ) -> None: - """ - Constructor. - - To disable the seed, set it to None. - - Parameters - ---------- - config : Optional[Configuration], optional - Configuration object, by default None. - config_path : Optional[Union[str, Path]], optional - Path to configuration file, by default None. - model_path : Optional[Union[str, Path]], optional - Path to model file, by default None. - seed : int, optional - Seed for reproducibility, by default 42. - - Raises - ------ - ValueError - If all three parameters are None. - FileNotFoundError - If the model or configuration path is provided but does not exist. - TypeError - If the configuration is not a Configuration object. - UsageError - If wandb is not correctly installed. - ModuleNotFoundError - If wandb is not installed. - ValueError - If the configuration failed to configure. - """ - if model_path is not None: - if not Path(model_path).exists(): - raise FileNotFoundError( - f"Model path {model_path} is incorrect or" - f" does not exist. Current working directory is: {Path.cwd()!s}" - ) - - # Ensure that config is None - self.cfg = None - - elif config is not None: - # Check that config is a Configuration object - if not isinstance(config, Configuration): - raise TypeError( - f"config must be a Configuration object, got {type(config)}" - ) - self.cfg = config - elif config_path is not None: - self.cfg = load_configuration(config_path) - else: - raise ValueError( - "No configuration or path provided. One of configuration " - "object, configuration path or model path must be provided." 
- ) - - # get device, CPU or GPU - self.device = get_device() - - # Create model, optimizer, lr scheduler and gradient scaler and load everything - # to the specified device - ( - self.model, - self.optimizer, - self.lr_scheduler, - self.scaler, - self.cfg, - ) = create_model(config=self.cfg, model_path=model_path, device=self.device) - assert self.cfg is not None - - # create loss function - self.loss_func = create_loss_function(self.cfg) - - # Set logging - log_path = self.cfg.working_directory / "log.txt" - self.logger = get_logger(__name__, log_path=log_path) - - # wandb - self.use_wandb = self.cfg.training.use_wandb - - if self.use_wandb: - try: - from wandb.errors import UsageError - - from careamics.utils.wandb import WandBLogging - - try: - self.wandb = WandBLogging( - experiment_name=self.cfg.experiment_name, - log_path=self.cfg.working_directory, - config=self.cfg, - model_to_watch=self.model, - ) - except UsageError as e: - self.logger.warning( - f"Wandb usage error, using default logger. Check whether " - f"wandb correctly configured:\n" - f"{e}" - ) - self.use_wandb = False - - except ModuleNotFoundError: - self.logger.warning( - "Wandb not installed, using default logger. Try pip install " - "wandb" - ) - self.use_wandb = False - - # BMZ inputs/outputs placeholders, filled during validation - self._input = None - self._outputs = None - - # torch version - self.torch_version = torch.__version__ - - def train( - self, - train_path: str, - val_path: str, - ) -> Tuple[List[Any], List[Any]]: - """ - Train the network. - - The training and validation data given by the paths must be compatible with the - axes and data format provided in the configuration. - - Parameters - ---------- - train_path : Union[str, Path] - Path to the training data. - val_path : Union[str, Path] - Path to the validation data. - - Returns - ------- - Tuple[List[Any], List[Any]] - Tuple of training and validation statistics. 
- - Raises - ------ - ValueError - Raise a ValueError if the configuration is missing. - """ - if self.cfg is None: - raise ValueError("Configuration is not defined, cannot train.") - - # General func - train_loader = self._get_train_dataloader(train_path) - - # Set mean and std from train dataset of none - if self.cfg.data.mean is None or self.cfg.data.std is None: - self.cfg.data.set_mean_and_std( - train_loader.dataset.mean, train_loader.dataset.std - ) - - eval_loader = self._get_val_dataloader(val_path) - self.logger.info(f"Starting training for {self.cfg.training.num_epochs} epochs") - - val_losses = [] - - try: - train_stats = [] - eval_stats = [] - - # loop over the dataset multiple times - for epoch in range(self.cfg.training.num_epochs): - if hasattr(train_loader.dataset, "__len__"): - epoch_size = train_loader.__len__() - else: - epoch_size = None - - progress_bar = ProgressBar( - max_value=epoch_size, - epoch=epoch, - num_epochs=self.cfg.training.num_epochs, - mode="train", - ) - # train_epoch = train_op(self._train_single_epoch,) - # Perform training step - train_outputs, epoch_size = self._train_single_epoch( - train_loader, - progress_bar, - self.cfg.training.amp.use, - ) - # Perform validation step - eval_outputs = self._evaluate(eval_loader) - val_losses.append(eval_outputs["loss"]) - learning_rate = self.optimizer.param_groups[0]["lr"] - - progress_bar.add( - 1, - values=[ - ("train_loss", train_outputs["loss"]), - ("val loss", eval_outputs["loss"]), - ("lr", learning_rate), - ], - ) - # Add update scheduler rule based on type - self.lr_scheduler.step(eval_outputs["loss"]) - - if self.use_wandb: - metrics = { - "train": train_outputs, - "eval": eval_outputs, - "lr": learning_rate, - } - self.wandb.log_metrics(metrics) - - train_stats.append(train_outputs) - eval_stats.append(eval_outputs) - - checkpoint_path = self._save_checkpoint(epoch, val_losses, "state_dict") - self.logger.info(f"Saved checkpoint to {checkpoint_path}") - - except 
KeyboardInterrupt: - self.logger.info("Training interrupted") - - return train_stats, eval_stats - - def _train_single_epoch( - self, - loader: torch.utils.data.DataLoader, - progress_bar: ProgressBar, - amp: bool, - ) -> Tuple[Dict[str, float], int]: - """ - Train for a single epoch. - - Parameters - ---------- - loader : torch.utils.data.DataLoader - Training dataloader. - progress_bar : ProgressBar - Progress bar. - amp : bool - Whether to use automatic mixed precision. - - Returns - ------- - Tuple[Dict[str, float], int] - Tuple of training metrics and epoch size. - - Raises - ------ - ValueError - If the configuration is missing. - """ - if self.cfg is not None: - avg_loss = MetricTracker() - self.model.train() - epoch_size = 0 - - for i, (batch, *auxillary) in enumerate(loader): - self.optimizer.zero_grad(set_to_none=True) - - with torch.cuda.amp.autocast(enabled=amp): - outputs = self.model(batch.to(self.device)) - - loss = self.loss_func( - outputs, *[a.to(self.device) for a in auxillary], self.device - ) - self.scaler.scale(loss).backward() - avg_loss.update(loss.detach(), batch.shape[0]) - - progress_bar.update( - current_step=i, - batch_size=self.cfg.training.batch_size, - ) - - self.optimizer.step() - epoch_size += 1 - - return {"loss": avg_loss.avg.to(torch.float16).cpu().numpy()}, epoch_size - else: - raise ValueError("Configuration is not defined, cannot train.") - - def _evaluate(self, val_loader: torch.utils.data.DataLoader) -> Dict[str, float]: - """ - Perform validation step. - - Parameters - ---------- - val_loader : torch.utils.data.DataLoader - Validation dataloader. - - Returns - ------- - Dict[str, float] - Loss value on the validation set. 
- """ - self.model.eval() - avg_loss = MetricTracker() - - with torch.no_grad(): - for patch, *auxillary in val_loader: - # if inputs is None, record a single patch - if self._input is None: - # patch has dimension SC(Z)YX - self._input = patch.clone().detach().cpu().numpy() - - # evaluate - outputs = self.model(patch.to(self.device)) - loss = self.loss_func( - outputs, *[a.to(self.device) for a in auxillary], self.device - ) - avg_loss.update(loss.detach(), patch.shape[0]) - return {"loss": avg_loss.avg.to(torch.float16).cpu().numpy()} - - def predict( - self, - input: Union[np.ndarray, str, Path], - *, - tile_shape: Optional[List[int]] = None, - overlaps: Optional[List[int]] = None, - axes: Optional[str] = None, - tta: bool = True, - ) -> Union[np.ndarray, List[np.ndarray]]: - """ - Predict using the current model on an input array or a path to data. - - The Engine must have previously been trained and mean/std be specified in - its configuration. - - Data should be compatible with the axes, either from the configuration or - as passed using the `axes` parameter. If the batch and channel dimensions are - missing, then singleton dimensions are added. - - To use tiling, both `tile_shape` and `overlaps` must be specified, have same - length, be divisible by 2 and greater than 0. Finally, the overlaps must be - smaller than the tiles. - - By setting `tta` to `True`, the prediction is performed using test time - augmentation, meaning that the input is augmented and the prediction is averaged - over the augmentations. - - Parameters - ---------- - input : Union[np.ndarra, str, Path] - Input data, either an array or a path to the data. - tile_shape : Optional[List[int]], optional - 2D or 3D shape of the tiles to be predicted, by default None. - overlaps : Optional[List[int]], optional - 2D or 3D overlaps between tiles, by default None. - axes : Optional[str], optional - Axes of the input array if different from the one in the configuration, by - default None. 
- tta : bool, optional - Whether to use test time augmentation, by default True. - - Returns - ------- - Union[np.ndarray, List[np.ndarray]] - Predicted image array of the same shape as the input, or list of arrays - if the arrays have inconsistent shapes. - - Raises - ------ - ValueError - If the configuration is missing. - ValueError - If the mean or std are not specified in the configuration (untrained model). - """ - if self.cfg is None: - raise ValueError("Configuration is not defined, cannot predict.") - - # Check that the mean and std are there (= has been trained) - if not self.cfg.data.mean or not self.cfg.data.std: - raise ValueError( - "Mean or std are not specified in the configuration, prediction cannot " - "be performed." - ) - - # set model to eval mode - self.model.to(self.device) - self.model.eval() - - progress_bar = ProgressBar(num_epochs=1, mode="predict") - - # Get dataloader - pred_loader, tiled = self._get_predict_dataloader( - input=input, tile_shape=tile_shape, overlaps=overlaps, axes=axes - ) - - # Start prediction - self.logger.info("Starting prediction") - if tiled: - self.logger.info("Starting tiled prediction") - prediction = self._predict_tiled(pred_loader, progress_bar, tta) - else: - self.logger.info("Starting prediction on whole sample") - prediction = self._predict_full(pred_loader, progress_bar, tta) - - return prediction - - def _predict_tiled( - self, pred_loader: DataLoader, progress_bar: ProgressBar, tta: bool = True - ) -> Union[np.ndarray, List[np.ndarray]]: - """ - Predict using tiling. - - Parameters - ---------- - pred_loader : DataLoader - Prediction dataloader. - progress_bar : ProgressBar - Progress bar. - tta : bool, optional - Whether to use test time augmentation, by default True. - - Returns - ------- - Union[np.ndarray, List[np.ndarray]] - Predicted image, or list of predictions if the images have different sizes. 
- - Warns - ----- - UserWarning - If the samples have different shapes, the prediction then returns a list. - """ - # checks are done here to satisfy mypy - # check that configuration exists - if self.cfg is None: - raise ValueError("Configuration is not defined, cannot predict.") - - # Check that the mean and std are there (= has been trained) - if not self.cfg.data.mean or not self.cfg.data.std: - raise ValueError( - "Mean or std are not specified in the configuration, prediction cannot " - "be performed." - ) - - prediction = [] - tiles = [] - stitching_data = [] - - with torch.no_grad(): - for i, (tile, *auxillary) in enumerate(pred_loader): - # Unpack auxillary data into last tile indicator and data, required to - # stitch tiles together - if auxillary: - last_tile, *data = auxillary - - if tta: - augmented_tiles = tta_forward(tile) - predicted_augments = [] - for augmented_tile in augmented_tiles: - augmented_pred = self.model(augmented_tile.to(self.device)) - predicted_augments.append(augmented_pred.cpu()) - tiles.append(tta_backward(predicted_augments).squeeze()) - else: - tiles.append( - self.model(tile.to(self.device)).squeeze().cpu().numpy() - ) - - stitching_data.append(data) - - if last_tile: - # Stitch tiles together if sample is finished - predicted_sample = stitch_prediction(tiles, stitching_data) - predicted_sample = denormalize( - predicted_sample, - float(self.cfg.data.mean), - float(self.cfg.data.std), - ) - prediction.append(predicted_sample) - tiles.clear() - stitching_data.clear() - - progress_bar.update(i, 1) - if tta: - i = int(i / 8) - self.logger.info(f"Predicted {len(prediction)} samples, {i} tiles in total") - try: - return np.stack(prediction) - except ValueError: - self.logger.warning("Samples have different shapes, returning list.") - return prediction - - def _predict_full( - self, pred_loader: DataLoader, progress_bar: ProgressBar, tta: bool = True - ) -> np.ndarray: - """ - Predict whole image without tiling. 
- - Parameters - ---------- - pred_loader : DataLoader - Prediction dataloader. - progress_bar : ProgressBar - Progress bar. - tta : bool, optional - Whether to use test time augmentation, by default True. - - Returns - ------- - np.ndarray - Predicted image. - """ - # checks are done here to satisfy mypy - # check that configuration exists - if self.cfg is None: - raise ValueError("Configuration is not defined, cannot predict.") - - # Check that the mean and std are there (= has been trained) - if not self.cfg.data.mean or not self.cfg.data.std: - raise ValueError( - "Mean or std are not specified in the configuration, prediction cannot " - "be performed." - ) - - prediction = [] - with torch.no_grad(): - for i, sample in enumerate(pred_loader): - if tta: - augmented_preds = tta_forward(sample[0]) - predicted_augments = [] - for augmented_pred in augmented_preds: - augmented_pred = self.model(augmented_pred.to(self.device)) - predicted_augments.append(augmented_pred.cpu()) - prediction.append(tta_backward(predicted_augments).squeeze()) - else: - prediction.append( - self.model(sample[0].to(self.device)).squeeze().cpu().numpy() - ) - progress_bar.update(i, 1) - output = denormalize( - np.stack(prediction).squeeze(), - float(self.cfg.data.mean), - float(self.cfg.data.std), - ) - return output - - def _get_train_dataloader(self, train_path: str) -> DataLoader: - """ - Return a training dataloader. - - Parameters - ---------- - train_path : str - Path to the training data. - - Returns - ------- - DataLoader - Training data loader. - - Raises - ------ - ValueError - If the training configuration is None. 
- """ - if self.cfg is None: - raise ValueError("Configuration is not defined.") - - dataset = get_train_dataset(self.cfg, train_path) - dataloader = DataLoader( - dataset, - batch_size=self.cfg.training.batch_size, - num_workers=self.cfg.training.num_workers, - pin_memory=True, - ) - return dataloader - - def _get_val_dataloader(self, val_path: str) -> DataLoader: - """ - Return a validation dataloader. - - Parameters - ---------- - val_path : str - Path to the validation data. - - Returns - ------- - DataLoader - Validation data loader. - - Raises - ------ - ValueError - If the configuration is None. - """ - if self.cfg is None: - raise ValueError("Configuration is not defined.") - - dataset = get_validation_dataset(self.cfg, val_path) - dataloader = DataLoader( - dataset, - batch_size=self.cfg.training.batch_size, - num_workers=self.cfg.training.num_workers, - pin_memory=True, - ) - return dataloader - - def _get_predict_dataloader( - self, - input: Union[np.ndarray, str, Path], - *, - tile_shape: Optional[List[int]] = None, - overlaps: Optional[List[int]] = None, - axes: Optional[str] = None, - ) -> Tuple[DataLoader, bool]: - """ - Return a prediction dataloader. - - Parameters - ---------- - input : Union[np.ndarray, str, Path] - Input array or path to data. - tile_shape : Optional[List[int]], optional - 2D or 3D shape of the tiles, by default None. - overlaps : Optional[List[int]], optional - 2D or 3D overlaps between tiles, by default None. - axes : Optional[str], optional - Axes of the input array if different from the one in the configuration. - - Returns - ------- - Tuple[DataLoader, bool] - Tuple of prediction data loader, and whether the data is tiled. - - Raises - ------ - ValueError - If the configuration is None. - ValueError - If the mean or std are not specified in the configuration. - ValueError - If the input is None. 
- """ - if self.cfg is None: - raise ValueError("Configuration is not defined.") - - if self.cfg.data.mean is None or self.cfg.data.std is None: - raise ValueError( - "Mean or std are not specified in the configuration, prediction cannot " - "be performed. Was the model trained?" - ) - - if input is None: - raise ValueError("Input cannot be None.") - - # Create dataset - if isinstance(input, np.ndarray): # np.ndarray - # Validate axes and add missing dimensions (S)C if necessary - img_axes = self.cfg.data.axes if axes is None else axes - input_expanded = add_axes(input, img_axes) - - # Check if tiling requested - tiled = tile_shape is not None and overlaps is not None - - # Validate tiles and overlaps - if tiled: - raise NotImplementedError( - "Tiling with in memory array is currently not implemented." - ) - - # Normalize input and cast to float32 - normalized_input = normalize( - img=input_expanded, mean=self.cfg.data.mean, std=self.cfg.data.std - ) - normalized_input = normalized_input.astype(np.float32) - - # Create dataset - dataset = TensorDataset(torch.from_numpy(normalized_input)) - - elif isinstance(input, str) or isinstance(input, Path): # path - # Create dataset - dataset = get_prediction_dataset( - self.cfg, - pred_path=input, - tile_shape=tile_shape, - overlaps=overlaps, - axes=axes, - ) - - tiled = ( - hasattr(dataset, "patch_extraction_method") - and dataset.patch_extraction_method is not None - ) - return ( - DataLoader( - dataset, - batch_size=1, - num_workers=0, - pin_memory=True, - ), - tiled, - ) - - def _save_checkpoint( - self, epoch: int, losses: List[float], save_method: str - ) -> Path: - """ - Save checkpoint. - - Currently only supports saving using `save_method="state_dict"`. - - Parameters - ---------- - epoch : int - Last epoch. - losses : List[float] - List of losses. - save_method : str - Method to save the model. Currently only supports `state_dict`. - - Returns - ------- - Path - Path to the saved checkpoint. 
- - Raises - ------ - ValueError - If the configuration is None. - NotImplementedError - If the requested save method is not supported. - """ - if self.cfg is None: - raise ValueError("Configuration is not defined.") - - if epoch == 0 or losses[-1] == min(losses): - name = f"{self.cfg.experiment_name}_best.pth" - else: - name = f"{self.cfg.experiment_name}_latest.pth" - workdir = self.cfg.working_directory - workdir.mkdir(parents=True, exist_ok=True) - - if save_method == "state_dict": - checkpoint = { - "epoch": epoch, - "model_state_dict": self.model.state_dict(), - "optimizer_state_dict": self.optimizer.state_dict(), - "scheduler_state_dict": self.lr_scheduler.state_dict(), - "grad_scaler_state_dict": self.scaler.state_dict(), - "loss": losses[-1], - "config": self.cfg.model_dump(), - } - torch.save(checkpoint, workdir / name) - else: - raise NotImplementedError("Invalid save method.") - - return self.cfg.working_directory.absolute() / name - - def __del__(self) -> None: - """Exit logger.""" - if hasattr(self, "logger"): - for handler in self.logger.handlers: - if isinstance(handler, FileHandler): - self.logger.removeHandler(handler) - handler.close() - - def _get_sample_io_files( - self, - input_array: Optional[np.ndarray] = None, - axes: Optional[str] = None, - ) -> Tuple[List[str], List[str]]: - """ - Create numpy format for use as inputs and outputs in the bioimage.io archive. - - Parameters - ---------- - input_array : Optional[np.ndarray], optional - Input array to use for the bioimage.io model zoo, by default None. - axes : Optional[str], optional - Axes from the configuration. - - Returns - ------- - Tuple[List[str], List[str]] - Tuple of input and output file paths. - - Raises - ------ - ValueError - If the configuration is not defined. 
- """ - if self.cfg is not None and self._input is not None: - # use the input array if provided, otherwise use the first validation sample - if input_array is not None: - array_in = input_array - - # add axes to be compatible with the axes declared in the RDF specs - add_axes(array_in, axes) - else: - array_in = self._input - - # predict (no tta since BMZ does not apply it) - array_out = self.predict(array_in, tta=False) - - # add singleton dimensions (for compatibility with model axes) - # indeed, BMZ applies the model but CAREamics function are meant - # to work on user data (potentially with no S or C axe) - array_out = array_out[np.newaxis, np.newaxis, ...] - - # save numpy files - workdir = self.cfg.working_directory - in_file = workdir.joinpath("test_inputs.npy") - np.save(in_file, array_in) - out_file = workdir.joinpath("test_outputs.npy") - np.save(out_file, array_out) - - return [str(in_file.absolute())], [str(out_file.absolute())] - else: - raise ValueError("Configuration is not defined or model was not trained.") - - def _generate_rdf( - self, - *, - model_specs: Optional[dict] = None, - input_array: Optional[np.ndarray] = None, - ) -> dict: - """ - Generate rdf data for bioimage.io format export. - - Parameters - ---------- - model_specs : Optional[dict], optional - Custom specs if different than the default ones, by default None. - input_array : Optional[np.ndarray], optional - Input array to use for the bioimage.io model zoo, by default None. - - Returns - ------- - dict - RDF specs. - - Raises - ------ - ValueError - If the mean or std are not specified in the configuration. - ValueError - If the configuration is not defined. - """ - if self.cfg is not None: - if self.cfg.data.mean is None or self.cfg.data.std is None: - raise ValueError( - "Mean or std are not specified in the configuration, export to " - "bioimage.io format is not possible." 
- ) - - # set in/out axes from config - axes = self.cfg.data.axes.lower().replace("s", "") - if "c" not in axes: - axes = "c" + axes - if "b" not in axes: - axes = "b" + axes - - # get in/out samples' files - test_inputs, test_outputs = self._get_sample_io_files( - input_array, self.cfg.data.axes - ) - - specs = get_default_model_specs( - "Noise2Void", - self.cfg.data.mean, - self.cfg.data.std, - self.cfg.algorithm.is_3D, - ) - if model_specs is not None: - specs.update(model_specs) - - specs.update( - { - "test_inputs": test_inputs, - "test_outputs": test_outputs, - "input_axes": [axes], - "output_axes": [axes], - } - ) - return specs - else: - raise ValueError("Configuration is not defined or model was not trained.") - - def save_as_bioimage( - self, - output_zip: Union[Path, str], - model_specs: Optional[dict] = None, - input_array: Optional[np.ndarray] = None, - ) -> None: - """ - Export the current model to BioImage.io model zoo format. - - Custom specs can be passed in `model_specs (e.g. maintainers). For a description - of the model RDF, refer to - github.com/bioimage-io/spec-bioimage-io/blob/gh-pages/model_spec_latest.md. - - Parameters - ---------- - output_zip : Union[Path, str] - Where to save the model zip file. - model_specs : Optional[dict] - A dictionary with keys being the bioimage-core build_model parameters. If - None then it will be populated by the model default specs. - input_array : Optional[np.ndarray] - An array to use as input for the bioimage.io model zoo. If None then the - first validation sample will be used. Note that the array must have S and - C dimensions (e.g. SCYX), even if only singleton dimensions. - - Raises - ------ - ValueError - If the configuration is not defined. 
- """ - if self.cfg is not None: - # Generate specs - specs = self._generate_rdf(model_specs=model_specs, input_array=input_array) - - # Build model - save_bioimage_model( - path=output_zip, - config=self.cfg, - specs=specs, - ) - else: - raise ValueError("Configuration is not defined.") diff --git a/src/careamics/lightning_datamodule.py b/src/careamics/lightning_datamodule.py new file mode 100644 index 00000000..56f48d47 --- /dev/null +++ b/src/careamics/lightning_datamodule.py @@ -0,0 +1,665 @@ +from pathlib import Path +from typing import Any, Callable, Dict, List, Literal, Optional, Union + +import numpy as np +import pytorch_lightning as L +from albumentations import Compose +from torch.utils.data import DataLoader + +from careamics.config import DataModel +from careamics.config.data_model import TRANSFORMS_UNION +from careamics.config.support import SupportedData +from careamics.dataset.dataset_utils import ( + get_files_size, + get_read_func, + list_files, + validate_source_target_files, +) +from careamics.dataset.in_memory_dataset import ( + InMemoryDataset, +) +from careamics.dataset.iterable_dataset import ( + PathIterableDataset, +) +from careamics.utils import get_logger, get_ram_size + +DatasetType = Union[InMemoryDataset, PathIterableDataset] + +logger = get_logger(__name__) + + +class CAREamicsWood(L.LightningDataModule): + """ + LightningDataModule for training and validation datasets. + + The data module can be used with Path, str or numpy arrays. In the case of + numpy arrays, it loads and computes all the patches in memory. For Path and str + inputs, it calculates the total file size and estimate whether it can fit in + memory. If it does not, it iterates through the files. This behaviour can be + deactivated by setting `use_in_memory` to False, in which case it will + always use the iterating dataset to train on a Path or str. + + The data can be either a folder containing images or a single file. 
+ + Validation can be omitted, in which case the validation data is extracted from + the training data. The percentage of the training data to use for validation, + as well as the minimum number of patches or files to split from the training + data can be set using `val_percentage` and `val_minimum_split`, respectively. + + To read custom data types, you can set `data_type` to `custom` in `data_config` + and provide a function that returns a numpy array from a path as + `read_source_func` parameter. The function will receive a Path object and + an axies string as arguments, the axes being derived from the `data_config`. + + You can also provide a `fnmatch` and `Path.rglob` compatible expression (e.g. + "*.czi") to filter the files extension using `extension_filter`. + """ + + def __init__( + self, + data_config: DataModel, + train_data: Union[Path, str, np.ndarray], + val_data: Optional[Union[Path, str, np.ndarray]] = None, + train_data_target: Optional[Union[Path, str, np.ndarray]] = None, + val_data_target: Optional[Union[Path, str, np.ndarray]] = None, + read_source_func: Optional[Callable] = None, + extension_filter: str = "", + val_percentage: float = 0.1, + val_minimum_split: int = 5, + use_in_memory: bool = True, + ) -> None: + """ + Constructor. + + Parameters + ---------- + data_config : DataModel + Pydantic model for CAREamics data configuration. + train_data : Union[Path, str, np.ndarray] + Training data, can be a path to a folder, a file or a numpy array. + val_data : Optional[Union[Path, str, np.ndarray]], optional + Validation data, can be a path to a folder, a file or a numpy array, by + default None. + train_data_target : Optional[Union[Path, str, np.ndarray]], optional + Training target data, can be a path to a folder, a file or a numpy array, by + default None. + val_data_target : Optional[Union[Path, str, np.ndarray]], optional + Validation target data, can be a path to a folder, a file or a numpy array, + by default None. 
+ read_source_func : Optional[Callable], optional + Function to read the source data, by default None. Only used for `custom` + data type (see DataModel). + extension_filter : str, optional + Filter for file extensions, by default "". Only used for `custom` data types + (see DataModel). + val_percentage : float, optional + Percentage of the training data to use for validation, by default 0.1. Only + used if `val_data` is None. + val_minimum_split : int, optional + Minimum number of patches or files to split from the training data for + validation, by default 5. Only used if `val_data` is None. + + Raises + ------ + NotImplementedError + Raised if target data is provided. + ValueError + If the input types are mixed (e.g. Path and np.ndarray). + ValueError + If the data type is `custom` and no `read_source_func` is provided. + ValueError + If the data type is `array` and the input is not a numpy array. + ValueError + If the data type is `tiff` and the input is neither a Path nor a str. + """ + super().__init__() + + # check input types coherence (no mixed types) + inputs = [train_data, val_data, train_data_target, val_data_target] + types_set = {type(i) for i in inputs} + if len(types_set) > 2: # None + expected type + raise ValueError( + f"Inputs for `train_data`, `val_data`, `train_data_target` and " + f"`val_data_target` must be of the same type or None. Got " + f"{types_set}." + ) + + # check that a read source function is provided for custom types + if data_config.data_type == SupportedData.CUSTOM and read_source_func is None: + raise ValueError( + f"Data type {SupportedData.CUSTOM} is not allowed without " + f"specifying a `read_source_func`." + ) + + # and that arrays are passed, if array type specified + elif data_config.data_type == SupportedData.ARRAY and not isinstance( + train_data, np.ndarray + ): + raise ValueError( + f"Expected array input (see configuration.data.data_type), but got " + f"{type(train_data)} instead." 
+ ) + + # and that Path or str are passed, if tiff file type specified + elif data_config.data_type == SupportedData.TIFF and ( + not isinstance(train_data, Path) and not isinstance(train_data, str) + ): + raise ValueError( + f"Expected Path or str input (see configuration.data.data_type), " + f"but got {type(train_data)} instead." + ) + + # configuration + self.data_config = data_config + self.data_type = data_config.data_type + self.batch_size = data_config.batch_size + self.use_in_memory = use_in_memory + + # data + self.train_data = train_data + self.val_data = val_data + + self.train_data_target = train_data_target + self.val_data_target = val_data_target + self.val_percentage = val_percentage + self.val_minimum_split = val_minimum_split + + # read source function corresponding to the requested type + if data_config.data_type == SupportedData.CUSTOM: + # mypy check + assert read_source_func is not None + + self.read_source_func: Callable = read_source_func + + elif data_config.data_type != SupportedData.ARRAY: + self.read_source_func = get_read_func(data_config.data_type) + + self.extension_filter = extension_filter + + # Pytorch dataloader parameters + self.dataloader_params = ( + data_config.dataloader_params if data_config.dataloader_params else {} + ) + + def prepare_data(self) -> None: + """ + Hook used to prepare the data before calling `setup`. + + Here, we only need to examine the data if it was provided as a str or a Path. + + TODO: from lightning doc: + prepare_data is called from the main process. It is not recommended to assign + state here (e.g. self.x = y) since it is called on a single process and if you + assign states here then they won't be available for other processes. 
+ + https://lightning.ai/docs/pytorch/stable/data/datamodule.html + """ + # if the data is a Path or a str + if ( + not isinstance(self.train_data, np.ndarray) + and not isinstance(self.val_data, np.ndarray) + and not isinstance(self.train_data_target, np.ndarray) + and not isinstance(self.val_data_target, np.ndarray) + ): + # list training files + self.train_files = list_files( + self.train_data, self.data_type, self.extension_filter + ) + self.train_files_size = get_files_size(self.train_files) + + # list validation files + if self.val_data is not None: + self.val_files = list_files( + self.val_data, self.data_type, self.extension_filter + ) + + # same for target data + if self.train_data_target is not None: + self.train_target_files: List[Path] = list_files( + self.train_data_target, self.data_type, self.extension_filter + ) + + # verify that they match the training data + validate_source_target_files(self.train_files, self.train_target_files) + + if self.val_data_target is not None: + self.val_target_files = list_files( + self.val_data_target, self.data_type, self.extension_filter + ) + + # verify that they match the validation data + validate_source_target_files(self.val_files, self.val_target_files) + + def setup(self, *args: Any, **kwargs: Any) -> None: + """Hook called at the beginning of fit, validate, or predict.""" + # if numpy array + if self.data_type == SupportedData.ARRAY: + # train dataset + self.train_dataset: DatasetType = InMemoryDataset( + data_config=self.data_config, + inputs=self.train_data, + data_target=self.train_data_target, + ) + + # validation dataset + if self.val_data is not None: + # create its own dataset + self.val_dataset: DatasetType = InMemoryDataset( + data_config=self.data_config, + inputs=self.val_data, + data_target=self.val_data_target, + ) + else: + # extract validation from the training patches + self.val_dataset = self.train_dataset.split_dataset( + percentage=self.val_percentage, + 
minimum_patches=self.val_minimum_split, + ) + + # else we read files + else: + # Heuristics, if the file size is smaller than 80% of the RAM, + # we run the training in memory, otherwise we switch to iterable dataset + # The switch is deactivated if use_in_memory is False + if self.use_in_memory and self.train_files_size < get_ram_size() * 0.8: + # train dataset + self.train_dataset = InMemoryDataset( + data_config=self.data_config, + inputs=self.train_files, + data_target=self.train_target_files + if self.train_data_target + else None, + read_source_func=self.read_source_func, + ) + + # validation dataset + if self.val_data is not None: + self.val_dataset = InMemoryDataset( + data_config=self.data_config, + inputs=self.val_files, + data_target=self.val_target_files + if self.val_data_target + else None, + read_source_func=self.read_source_func, + ) + else: + # split dataset + self.val_dataset = self.train_dataset.split_dataset( + percentage=self.val_percentage, + minimum_patches=self.val_minimum_split, + ) + + # else if the data is too large, load file by file during training + else: + # create training dataset + self.train_dataset = PathIterableDataset( + data_config=self.data_config, + src_files=self.train_files, + target_files=self.train_target_files + if self.train_data_target + else None, + read_source_func=self.read_source_func, + ) + + # create validation dataset + if self.val_files is not None: + # create its own dataset + self.val_dataset = PathIterableDataset( + data_config=self.data_config, + src_files=self.val_files, + target_files=self.val_target_files + if self.val_data_target + else None, + read_source_func=self.read_source_func, + ) + elif len(self.train_files) <= self.val_minimum_split: + raise ValueError( + f"Not enough files to split a minimum of " + f"{self.val_minimum_split} files, got {len(self.train_files)} " + f"files." 
+ ) + else: + # extract validation from the training patches + self.val_dataset = self.train_dataset.split_dataset( + percentage=self.val_percentage, + minimum_files=self.val_minimum_split, + ) + + def train_dataloader(self) -> Any: + """ + Create a dataloader for training. + + Returns + ------- + Any + Training dataloader. + """ + return DataLoader( + self.train_dataset, batch_size=self.batch_size, **self.dataloader_params + ) + + def val_dataloader(self) -> Any: + """ + Create a dataloader for validation. + + Returns + ------- + Any + Validation dataloader. + """ + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + ) + + +class CAREamicsTrainDataModule(CAREamicsWood): + """ + LightningDataModule wrapper for training and validation datasets. + + Since the lightning datamodule has no access to the model, make sure that the + parameters passed to the datamodule are consistent with the model's requirements and + are coherent. + + The data module can be used with Path, str or numpy arrays. In the case of + numpy arrays, it loads and computes all the patches in memory. For Path and str + inputs, it calculates the total file size and estimate whether it can fit in + memory. If it does not, it iterates through the files. This behaviour can be + deactivated by setting `use_in_memory` to False, in which case it will + always use the iterating dataset to train on a Path or str. + + To use array data, set `data_type` to `array` and pass a numpy array to + `train_data`. + + In particular, N2V requires a specific transformation (N2V manipulates), which is + not compatible with supervised training. The default transformations applied to the + training patches are defined in `careamics.config.data_model`. To use different + transformations, pass a list of transforms or an albumentation `Compose` as + `transforms` parameter. See examples for more details. + + By default, CAREamics only supports types defined in + `careamics.config.support.SupportedData`. 
To read custom data types, you can set + `data_type` to `custom` and provide a function that returns a numpy array from a + path. Additionally, pass a `fnmatch` and `Path.rglob` compatible expression (e.g. + "*.jpeg") to filter the files extension using `extension_filter`. + + In the absence of validation data, the validation data is extracted from the + training data. The percentage of the training data to use for validation, as well as + the minimum number of patches to split from the training data for validation can be + set using `val_percentage` and `val_minimum_patches`, respectively. + + In `dataloader_params`, you can pass any parameter accepted by PyTorch dataloaders, + except for `batch_size`, which is set by the `batch_size` parameter. + + Finally, if you intend to use N2V family of algorithms, you can set `use_n2v2` to + use N2V2, and set the `struct_n2v_axis` and `struct_n2v_span` parameters to define + the axis and span of the structN2V mask. These parameters are without effect if + a `train_target_data` or if `transforms` are provided. + + Parameters + ---------- + train_data : Union[str, Path, np.ndarray] + Training data. + data_type : Union[str, SupportedData] + Data type, see `SupportedData` for available options. + patch_size : List[int] + Patch size, 2D or 3D patch size. + axes : str + Axes of the data, choosen amongst SCZYX. + batch_size : int + Batch size. + val_data : Optional[Union[str, Path]], optional + Validation data, by default None. + transforms : Optional[Union[List[TRANSFORMS_UNION], Compose]], optional + List of transforms to apply to training patches. If None, default transforms + are applied. + train_target_data : Optional[Union[str, Path]], optional + Training target data, by default None. + val_target_data : Optional[Union[str, Path]], optional + Validation target data, by default None. + read_source_func : Optional[Callable], optional + Function to read the source data, used if `data_type` is `custom`, by + default None. 
+ extension_filter : str, optional + Filter for file extensions, used if `data_type` is `custom`, by default "". + val_percentage : float, optional + Percentage of the training data to use for validation if no validation data + is given, by default 0.1. + val_minimum_patches : int, optional + Minimum number of patches to split from the training data for validation if + no validation data is given, by default 5. + dataloader_params : dict, optional + Pytorch dataloader parameters, by default {}. + use_in_memory : bool, optional + Use in memory dataset if possible, by default True. + use_n2v2 : bool, optional + Use N2V2 transformation during training, by default False. + struct_n2v_axis : Literal["horizontal", "vertical", "none"], optional + Axis for the structN2V mask, only applied if `struct_n2v_axis` is `none`, by + default "none". + struct_n2v_span : int, optional + Span for the structN2V mask, by default 5. + + Examples + -------- + Create a CAREamicsTrainDataModule with default transforms with a numpy array: + >>> import numpy as np + >>> from careamics import CAREamicsTrainDataModule + >>> my_array = np.arange(256).reshape(16, 16) + >>> data_module = CAREamicsTrainDataModule( + ... train_data=my_array, + ... data_type="array", + ... patch_size=(8, 8), + ... axes='YX', + ... batch_size=2, + ... ) + + For custom data types (those not supported by CAREamics), then one can pass a read + function and a filter for the files extension: + >>> import numpy as np + >>> from careamics import CAREamicsTrainDataModule + >>> + >>> def read_npy(path): + ... return np.load(path) + >>> + >>> data_module = CAREamicsTrainDataModule( + ... train_data="path/to/data", + ... data_type="custom", + ... patch_size=(8, 8), + ... axes='YX', + ... batch_size=2, + ... read_source_func=read_npy, + ... extension_filter="*.npy", + ... 
) + + If you want to use a different set of transformations, you can pass a list of + transforms: + >>> import numpy as np + >>> from careamics import CAREamicsTrainDataModule + >>> from careamics.config.support import SupportedTransform + >>> my_array = np.arange(256).reshape(16, 16) + >>> my_transforms = [ + ... { + ... "name": SupportedTransform.NORMALIZE.value, + ... "mean": 0, + ... "std": 1, + ... }, + ... { + ... "name": SupportedTransform.N2V_MANIPULATE.value, + ... } + ... ] + >>> data_module = CAREamicsTrainDataModule( + ... train_data=my_array, + ... data_type="array", + ... patch_size=(8, 8), + ... axes='YX', + ... batch_size=2, + ... transforms=my_transforms, + ... ) + """ + + def __init__( + self, + train_data: Union[str, Path, np.ndarray], + data_type: Union[Literal["array", "tiff", "custom"], SupportedData], + patch_size: List[int], + axes: str, + batch_size: int, + val_data: Optional[Union[str, Path]] = None, + transforms: Optional[Union[List[TRANSFORMS_UNION], Compose]] = None, + train_target_data: Optional[Union[str, Path]] = None, + val_target_data: Optional[Union[str, Path]] = None, + read_source_func: Optional[Callable] = None, + extension_filter: str = "", + val_percentage: float = 0.1, + val_minimum_patches: int = 5, + dataloader_params: Optional[dict] = None, + use_in_memory: bool = True, + use_n2v2: bool = False, + struct_n2v_axis: Literal["horizontal", "vertical", "none"] = "none", + struct_n2v_span: int = 5, + ) -> None: + """ + LightningDataModule wrapper for training and validation datasets. + + Since the lightning datamodule has no access to the model, make sure that the + parameters passed to the datamodule are consistent with the model's requirements + and are coherent. + + The data module can be used with Path, str or numpy arrays. In the case of + numpy arrays, it loads and computes all the patches in memory. For Path and str + inputs, it calculates the total file size and estimate whether it can fit in + memory. 
If it does not, it iterates through the files. This behaviour can be + deactivated by setting `use_in_memory` to False, in which case it will + always use the iterating dataset to train on a Path or str. + + To use array data, set `data_type` to `array` and pass a numpy array to + `train_data`. + + In particular, N2V requires a specific transformation (N2V manipulates), which + is not compatible with supervised training. The default transformations applied + to the training patches are defined in `careamics.config.data_model`. To use + different transformations, pass a list of transforms or an albumentation + `Compose` as `transforms` parameter. See examples for more details. + + By default, CAREamics only supports types defined in + `careamics.config.support.SupportedData`. To read custom data types, you can set + `data_type` to `custom` and provide a function that returns a numpy array from a + path. Additionally, pass a `fnmatch` and `Path.rglob` compatible expression + (e.g. "*.jpeg") to filter the files extension using `extension_filter`. + + In the absence of validation data, the validation data is extracted from the + training data. The percentage of the training data to use for validation, as + well as the minimum number of patches to split from the training data for + validation can be set using `val_percentage` and `val_minimum_patches`, + respectively. + + In `dataloader_params`, you can pass any parameter accepted by PyTorch + dataloaders, except for `batch_size`, which is set by the `batch_size` + parameter. + + Finally, if you intend to use N2V family of algorithms, you can set `use_n2v2` + to use N2V2, and set the `struct_n2v_axis` and `struct_n2v_span` parameters to + define the axis and span of the structN2V mask. These parameters are without + effect if a `train_target_data` or if `transforms` are provided. + + Parameters + ---------- + train_data : Union[str, Path, np.ndarray] + Training data. 
+ data_type : Union[str, SupportedData] + Data type, see `SupportedData` for available options. + patch_size : List[int] + Patch size, 2D or 3D patch size. + axes : str + Axes of the data, choosen amongst SCZYX. + batch_size : int + Batch size. + val_data : Optional[Union[str, Path]], optional + Validation data, by default None. + transforms : Optional[Union[List[TRANSFORMS_UNION], Compose]], optional + List of transforms to apply to training patches. If None, default transforms + are applied. + train_target_data : Optional[Union[str, Path]], optional + Training target data, by default None. + val_target_data : Optional[Union[str, Path]], optional + Validation target data, by default None. + read_source_func : Optional[Callable], optional + Function to read the source data, used if `data_type` is `custom`, by + default None. + extension_filter : str, optional + Filter for file extensions, used if `data_type` is `custom`, by default "". + val_percentage : float, optional + Percentage of the training data to use for validation if no validation data + is given, by default 0.1. + val_minimum_patches : int, optional + Minimum number of patches to split from the training data for validation if + no validation data is given, by default 5. + dataloader_params : dict, optional + Pytorch dataloader parameters, by default {}. + use_in_memory : bool, optional + Use in memory dataset if possible, by default True. + use_n2v2 : bool, optional + Use N2V2 transformation during training, by default False. + struct_n2v_axis : Literal["horizontal", "vertical", "none"], optional + Axis for the structN2V mask, only applied if `struct_n2v_axis` is `none`, by + default "none". + struct_n2v_span : int, optional + Span for the structN2V mask, by default 5. + + Raises + ------ + ValueError + If a target is set and N2V manipulation is present in the transforms. 
+ """ + if dataloader_params is None: + dataloader_params = {} + data_dict: Dict[str, Any] = { + "mode": "train", + "data_type": data_type, + "patch_size": patch_size, + "axes": axes, + "batch_size": batch_size, + "dataloader_params": dataloader_params, + } + + # if transforms are passed (otherwise it will use the default ones) + if transforms is not None: + data_dict["transforms"] = transforms + + # validate configuration + self.data_config = DataModel(**data_dict) + + # N2V specific checks, N2V, structN2V, and transforms + if ( + self.data_config.has_transform_list() + and self.data_config.has_n2v_manipulate() + ): + # there is not target, n2v2 and structN2V can be changed + if train_target_data is None: + self.data_config.set_N2V2(use_n2v2) + self.data_config.set_structN2V_mask(struct_n2v_axis, struct_n2v_span) + else: + raise ValueError( + "Cannot have both supervised training (target data) and " + "N2V manipulation in the transforms. Pass a list of transforms " + "that is compatible with your supervised training." 
+ ) + + # sanity check on the dataloader parameters + if "batch_size" in dataloader_params: + # remove it + del dataloader_params["batch_size"] + + super().__init__( + data_config=self.data_config, + train_data=train_data, + val_data=val_data, + train_data_target=train_target_data, + val_data_target=val_target_data, + read_source_func=read_source_func, + extension_filter=extension_filter, + val_percentage=val_percentage, + val_minimum_split=val_minimum_patches, + use_in_memory=use_in_memory, + ) diff --git a/src/careamics/lightning_module.py b/src/careamics/lightning_module.py new file mode 100644 index 00000000..6e3c4f52 --- /dev/null +++ b/src/careamics/lightning_module.py @@ -0,0 +1,292 @@ +from typing import Any, Optional, Union + +import pytorch_lightning as L +from torch import Tensor, nn + +from careamics.config import AlgorithmModel +from careamics.config.support import ( + SupportedAlgorithm, + SupportedArchitecture, + SupportedLoss, + SupportedOptimizer, + SupportedScheduler, +) +from careamics.losses import loss_factory +from careamics.models.model_factory import model_factory +from careamics.transforms import Denormalize, ImageRestorationTTA +from careamics.utils.torch_utils import get_optimizer, get_scheduler + + +class CAREamicsKiln(L.LightningModule): + """ + CAREamics Lightning module. + + This class encapsulates the a PyTorch model along with the training, validation, + and testing logic. It is configured using an `AlgorithmModel` Pydantic class. + + Attributes + ---------- + model : nn.Module + PyTorch model. + loss_func : nn.Module + Loss function. + optimizer_name : str + Optimizer name. + optimizer_params : dict + Optimizer parameters. + lr_scheduler_name : str + Learning rate scheduler name. + """ + + def __init__(self, algorithm_config: Union[AlgorithmModel, dict]) -> None: + """ + CAREamics Lightning module. + + This class encapsulates the a PyTorch model along with the training, validation, + and testing logic. 
It is configured using an `AlgorithmModel` Pydantic class. + + Parameters + ---------- + algorithm_config : Union[AlgorithmModel, dict] + Algorithm configuration. + """ + super().__init__() + # if loading from a checkpoint, AlgorithmModel needs to be instantiated + if isinstance(algorithm_config, dict): + algorithm_config = AlgorithmModel(**algorithm_config) + + # create model and loss function + self.model: nn.Module = model_factory(algorithm_config.model) + self.loss_func = loss_factory(algorithm_config.loss) + + # save optimizer and lr_scheduler names and parameters + self.optimizer_name = algorithm_config.optimizer.name + self.optimizer_params = algorithm_config.optimizer.parameters + self.lr_scheduler_name = algorithm_config.lr_scheduler.name + self.lr_scheduler_params = algorithm_config.lr_scheduler.parameters + + def forward(self, x: Any) -> Any: + """Forward pass. + + Parameters + ---------- + x : Any + Input tensor. + + Returns + ------- + Any + Output tensor. + """ + return self.model(x) + + def training_step(self, batch: Tensor, batch_idx: Any) -> Any: + """Training step. + + Parameters + ---------- + batch : Tensor + Input batch. + batch_idx : Any + Batch index. + + Returns + ------- + Any + Loss value. + """ + x, *aux = batch + out = self.model(x) + loss = self.loss_func(out, *aux) + self.log( + "train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True + ) + return loss + + def validation_step(self, batch: Tensor, batch_idx: Any) -> None: + """Validation step. + + Parameters + ---------- + batch : Tensor + Input batch. + batch_idx : Any + Batch index. + """ + x, *aux = batch + out = self.model(x) + val_loss = self.loss_func(out, *aux) + + # log validation loss + self.log( + "val_loss", + val_loss, + on_step=False, + on_epoch=True, + prog_bar=True, + logger=True, + ) + + def predict_step(self, batch: Tensor, batch_idx: Any) -> Any: + """Prediction step. + + Parameters + ---------- + batch : Tensor + Input batch. 
+ batch_idx : Any + Batch index. + + Returns + ------- + Any + Model output. + """ + x, *aux = batch + + # apply test-time augmentation if available + # TODO: probably wont work with batch size > 1 + if self._trainer.datamodule.prediction_config.tta_transforms: + tta = ImageRestorationTTA() + augmented_batch = tta.forward(batch[0]) # list of augmented tensors + augmented_output = [] + for augmented in augmented_batch: + augmented_pred = self.model(augmented) + augmented_output.append(augmented_pred) + output = tta.backward(augmented_output) + else: + output = self.model(x) + + # Denormalize the output + denorm = Denormalize( + mean=self._trainer.datamodule.predict_dataset.mean, + std=self._trainer.datamodule.predict_dataset.std, + ) + denormalized_output = denorm(image=output)["image"] + + if len(aux) > 0: + return denormalized_output, aux + else: + return denormalized_output + + def configure_optimizers(self) -> Any: + """Configure optimizers and learning rate schedulers. + + Returns + ------- + Any + Optimizer and learning rate scheduler. + """ + # instantiate optimizer + optimizer_func = get_optimizer(self.optimizer_name) + optimizer = optimizer_func(self.model.parameters(), **self.optimizer_params) + + # and scheduler + scheduler_func = get_scheduler(self.lr_scheduler_name) + scheduler = scheduler_func(optimizer, **self.lr_scheduler_params) + + return { + "optimizer": optimizer, + "lr_scheduler": scheduler, + "monitor": "val_loss", # otherwise triggers MisconfigurationException + } + + +class CAREamicsModule(CAREamicsKiln): + """Class defining the API for CAREamics Lightning layer. + + This class exposes parameters used to create an AlgorithmModel instance, triggering + parameters validation. + + Parameters + ---------- + algorithm : Union[SupportedAlgorithm, str] + Algorithm to use for training (see SupportedAlgorithm). + loss : Union[SupportedLoss, str] + Loss function to use for training (see SupportedLoss). 
+ architecture : Union[SupportedArchitecture, str] + Model architecture to use for training (see SupportedArchitecture). + model_parameters : dict, optional + Model parameters to use for training, by default {}. Model parameters are + defined in the relevant `torch.nn.Module` class, or Pyddantic model (see + `careamics.config.architectures`). + optimizer : Union[SupportedOptimizer, str], optional + Optimizer to use for training, by default "Adam" (see SupportedOptimizer). + optimizer_parameters : dict, optional + Optimizer parameters to use for training, as defined in `torch.optim`, by + default {}. + lr_scheduler : Union[SupportedScheduler, str], optional + Learning rate scheduler to use for training, by default "ReduceLROnPlateau" + (see SupportedScheduler). + lr_scheduler_parameters : dict, optional + Learning rate scheduler parameters to use for training, as defined in + `torch.optim`, by default {}. + """ + + def __init__( + self, + algorithm: Union[SupportedAlgorithm, str], + loss: Union[SupportedLoss, str], + architecture: Union[SupportedArchitecture, str], + model_parameters: Optional[dict] = None, + optimizer: Union[SupportedOptimizer, str] = "Adam", + optimizer_parameters: Optional[dict] = None, + lr_scheduler: Union[SupportedScheduler, str] = "ReduceLROnPlateau", + lr_scheduler_parameters: Optional[dict] = None, + ) -> None: + """ + Wrapper for the CAREamics model, exposing all algorithm configuration arguments. + + Parameters + ---------- + algorithm : Union[SupportedAlgorithm, str] + Algorithm to use for training (see SupportedAlgorithm). + loss : Union[SupportedLoss, str] + Loss function to use for training (see SupportedLoss). + architecture : Union[SupportedArchitecture, str] + Model architecture to use for training (see SupportedArchitecture). + model_parameters : dict, optional + Model parameters to use for training, by default {}. 
Model parameters are + defined in the relevant `torch.nn.Module` class, or Pyddantic model (see + `careamics.config.architectures`). + optimizer : Union[SupportedOptimizer, str], optional + Optimizer to use for training, by default "Adam" (see SupportedOptimizer). + optimizer_parameters : dict, optional + Optimizer parameters to use for training, as defined in `torch.optim`, by + default {}. + lr_scheduler : Union[SupportedScheduler, str], optional + Learning rate scheduler to use for training, by default "ReduceLROnPlateau" + (see SupportedScheduler). + lr_scheduler_parameters : dict, optional + Learning rate scheduler parameters to use for training, as defined in + `torch.optim`, by default {}. + """ + # create a AlgorithmModel compatible dictionary + if lr_scheduler_parameters is None: + lr_scheduler_parameters = {} + if optimizer_parameters is None: + optimizer_parameters = {} + if model_parameters is None: + model_parameters = {} + algorithm_configuration = { + "algorithm": algorithm, + "loss": loss, + "optimizer": { + "name": optimizer, + "parameters": optimizer_parameters, + }, + "lr_scheduler": { + "name": lr_scheduler, + "parameters": lr_scheduler_parameters, + }, + } + model_configuration = {"architecture": architecture} + model_configuration.update(model_parameters) + + # add model parameters to algorithm configuration + algorithm_configuration["model"] = model_configuration + + # call the parent init using an AlgorithmModel instance + super().__init__(AlgorithmModel(**algorithm_configuration)) + + # TODO add load_from_checkpoint wrapper diff --git a/src/careamics/lightning_prediction_datamodule.py b/src/careamics/lightning_prediction_datamodule.py new file mode 100644 index 00000000..654d44fd --- /dev/null +++ b/src/careamics/lightning_prediction_datamodule.py @@ -0,0 +1,390 @@ +from pathlib import Path +from typing import Any, Callable, List, Literal, Optional, Tuple, Union + +import numpy as np +import pytorch_lightning as L +from albumentations 
import Compose +from torch.utils.data import DataLoader +from torch.utils.data.dataloader import default_collate + +from careamics.config import InferenceModel +from careamics.config.support import SupportedData +from careamics.config.tile_information import TileInformation +from careamics.dataset.dataset_utils import ( + get_read_func, + list_files, +) +from careamics.dataset.in_memory_dataset import ( + InMemoryPredictionDataset, +) +from careamics.dataset.iterable_dataset import ( + IterablePredictionDataset, +) +from careamics.utils import get_logger + +PredictDatasetType = Union[InMemoryPredictionDataset, IterablePredictionDataset] + +logger = get_logger(__name__) + + +def _collate_tiles(batch: List[Tuple[np.ndarray, TileInformation]]) -> Any: + """ + Collate tiles received from CAREamics prediction dataloader. + + CAREamics prediction dataloader returns tuples of arrays and TileInformation. In + case of non-tiled data, this function will return the arrays. In case of tiled data, + it will return the arrays, the last tile flag, the overlap crop coordinates and the + stitch coordinates. + + Parameters + ---------- + batch : Tuple[Tuple[np.ndarray, TileInformation], ...] + Batch of tiles. + + Returns + ------- + Any + Collated batch. + """ + first_tile_info: TileInformation = batch[0][1] + # if not tiled, then return arrays + if not first_tile_info.tiled: + arrays, _ = zip(*batch) + + return default_collate(arrays) + # else we explicit the last_tile flag and coordinates + else: + new_batch = [ + (tile, t.last_tile, t.array_shape, t.overlap_crop_coords, t.stitch_coords) + for tile, t in batch + ] + + return default_collate(new_batch) + + +class CAREamicsClay(L.LightningDataModule): + """ + LightningDataModule for prediction dataset. + + The data module can be used with Path, str or numpy arrays. The data can be either + a folder containing images or a single file. 
+ + To read custom data types, you can set `data_type` to `custom` in `data_config` + and provide a function that returns a numpy array from a path as + `read_source_func` parameter. The function will receive a Path object and + an axies string as arguments, the axes being derived from the `data_config`. + + You can also provide a `fnmatch` and `Path.rglob` compatible expression (e.g. + "*.czi") to filter the files extension using `extension_filter`. + + Parameters + ---------- + prediction_config : InferenceModel + Pydantic model for CAREamics prediction configuration. + pred_data : Union[Path, str, np.ndarray] + Prediction data, can be a path to a folder, a file or a numpy array. + read_source_func : Optional[Callable], optional + Function to read custom types, by default None. + extension_filter : str, optional + Filter to filter file extensions for custom types, by default "". + dataloader_params : dict, optional + Dataloader parameters, by default {}. + """ + + def __init__( + self, + prediction_config: InferenceModel, + pred_data: Union[Path, str, np.ndarray], + read_source_func: Optional[Callable] = None, + extension_filter: str = "", + dataloader_params: Optional[dict] = None, + ) -> None: + """ + Constructor. + + The data module can be used with Path, str or numpy arrays. The data can be + either a folder containing images or a single file. + + To read custom data types, you can set `data_type` to `custom` in `data_config` + and provide a function that returns a numpy array from a path as + `read_source_func` parameter. The function will receive a Path object and + an axies string as arguments, the axes being derived from the `data_config`. + + You can also provide a `fnmatch` and `Path.rglob` compatible expression (e.g. + "*.czi") to filter the files extension using `extension_filter`. + + Parameters + ---------- + prediction_config : InferenceModel + Pydantic model for CAREamics prediction configuration. 
+ pred_data : Union[Path, str, np.ndarray] + Prediction data, can be a path to a folder, a file or a numpy array. + read_source_func : Optional[Callable], optional + Function to read custom types, by default None. + extension_filter : str, optional + Filter to filter file extensions for custom types, by default "". + dataloader_params : dict, optional + Dataloader parameters, by default {}. + + Raises + ------ + ValueError + If the data type is `custom` and no `read_source_func` is provided. + ValueError + If the data type is `array` and the input is not a numpy array. + ValueError + If the data type is `tiff` and the input is neither a Path nor a str. + """ + if dataloader_params is None: + dataloader_params = {} + if dataloader_params is None: + dataloader_params = {} + super().__init__() + + # check that a read source function is provided for custom types + if ( + prediction_config.data_type == SupportedData.CUSTOM + and read_source_func is None + ): + raise ValueError( + f"Data type {SupportedData.CUSTOM} is not allowed without " + f"specifying a `read_source_func`." + ) + + # and that arrays are passed, if array type specified + elif prediction_config.data_type == SupportedData.ARRAY and not isinstance( + pred_data, np.ndarray + ): + raise ValueError( + f"Expected array input (see configuration.data.data_type), but got " + f"{type(pred_data)} instead." + ) + + # and that Path or str are passed, if tiff file type specified + elif prediction_config.data_type == SupportedData.TIFF and not ( + isinstance(pred_data, Path) or isinstance(pred_data, str) + ): + raise ValueError( + f"Expected Path or str input (see configuration.data.data_type), " + f"but got {type(pred_data)} instead." 
+ ) + + # configuration data + self.prediction_config = prediction_config + self.data_type = prediction_config.data_type + self.batch_size = prediction_config.batch_size + self.dataloader_params = dataloader_params + + self.pred_data = pred_data + self.tile_size = prediction_config.tile_size + self.tile_overlap = prediction_config.tile_overlap + + # read source function + if prediction_config.data_type == SupportedData.CUSTOM: + # mypy check + assert read_source_func is not None + + self.read_source_func: Callable = read_source_func + elif prediction_config.data_type != SupportedData.ARRAY: + self.read_source_func = get_read_func(prediction_config.data_type) + + self.extension_filter = extension_filter + + def prepare_data(self) -> None: + """Hook used to prepare the data before calling `setup`.""" + # if the data is a Path or a str + if not isinstance(self.pred_data, np.ndarray): + self.pred_files = list_files( + self.pred_data, self.data_type, self.extension_filter + ) + + def setup(self, stage: Optional[str] = None) -> None: + """ + Hook called at the beginning of predict. + + Parameters + ---------- + stage : Optional[str], optional + Stage, by default None. + """ + # if numpy array + if self.data_type == SupportedData.ARRAY: + # prediction dataset + self.predict_dataset: PredictDatasetType = InMemoryPredictionDataset( + prediction_config=self.prediction_config, + inputs=self.pred_data, + ) + else: + self.predict_dataset = IterablePredictionDataset( + prediction_config=self.prediction_config, + src_files=self.pred_files, + read_source_func=self.read_source_func, + ) + + def predict_dataloader(self) -> DataLoader: + """ + Create a dataloader for prediction. + + Returns + ------- + DataLoader + Prediction dataloader. 
+ """ + return DataLoader( + self.predict_dataset, + batch_size=self.batch_size, + collate_fn=_collate_tiles, + **self.dataloader_params, + ) # TODO check workers are used + + +class CAREamicsPredictDataModule(CAREamicsClay): + """ + LightningDataModule wrapper of an inference dataset. + + Since the lightning datamodule has no access to the model, make sure that the + parameters passed to the datamodule are consistent with the model's requirements + and are coherent. + + The data module can be used with Path, str or numpy arrays. To use array data, set + `data_type` to `array` and pass a numpy array to `train_data`. + + The default transformations applied to the images are defined in + `careamics.config.inference_model`. To use different transformations, pass a list + of transforms or an albumentation `Compose` as `transforms` parameter. See examples + for more details. + + The `mean` and `std` parameters are only used if Normalization is defined either + in the default transformations or in the `transforms` parameter, but not with + a `Compose` object. If you pass a `Normalization` transform in a list as + `transforms`, then the mean and std parameters will be overwritten by those passed + to this method. + + By default, CAREamics only supports types defined in + `careamics.config.support.SupportedData`. To read custom data types, you can set + `data_type` to `custom` and provide a function that returns a numpy array from a + path. Additionally, pass a `fnmatch` and `Path.rglob` compatible expression + (e.g. "*.jpeg") to filter the files extension using `extension_filter`. + + In `dataloader_params`, you can pass any parameter accepted by PyTorch + dataloaders, except for `batch_size`, which is set by the `batch_size` + parameter. + + Parameters + ---------- + pred_data : Union[str, Path, np.ndarray] + Prediction data. + data_type : Union[Literal["array", "tiff", "custom"], SupportedData] + Data type, see `SupportedData` for available options. 
+ mean : float + Mean value for normalization, only used if Normalization is defined in the + transforms. + std : float + Standard deviation value for normalization, only used if Normalization is + defined in the transform. + tile_size : Tuple[int, ...] + Tile size, 2D or 3D tile size. + tile_overlap : Tuple[int, ...] + Tile overlap, 2D or 3D tile overlap. + axes : str + Axes of the data, choosen amongst SCZYX. + batch_size : int + Batch size. + tta_transforms : bool, optional + Use test time augmentation, by default True. + transforms : Optional[Union[List[TRANSFORMS_UNION], Compose]], optional + List of transforms to apply to prediction patches. If None, default + transforms are applied. + read_source_func : Optional[Callable], optional + Function to read the source data, used if `data_type` is `custom`, by + default None. + extension_filter : str, optional + Filter for file extensions, used if `data_type` is `custom`, by default "". + dataloader_params : dict, optional + Pytorch dataloader parameters, by default {}. + """ + + def __init__( + self, + pred_data: Union[str, Path, np.ndarray], + data_type: Union[Literal["array", "tiff", "custom"], SupportedData], + mean: float, + std: float, + tile_size: Optional[Tuple[int, ...]] = None, + tile_overlap: Optional[Tuple[int, ...]] = None, + axes: str = "YX", + batch_size: int = 1, + tta_transforms: bool = True, + transforms: Optional[Union[List, Compose]] = None, + read_source_func: Optional[Callable] = None, + extension_filter: str = "", + dataloader_params: Optional[dict] = None, + ) -> None: + """ + Constructor. + + Parameters + ---------- + pred_data : Union[str, Path, np.ndarray] + Prediction data. + data_type : Union[Literal["array", "tiff", "custom"], SupportedData] + Data type, see `SupportedData` for available options. + tile_size : List[int] + Tile size, 2D or 3D tile size. + tile_overlap : List[int] + Tile overlap, 2D or 3D tile overlap. + axes : str + Axes of the data, choosen amongst SCZYX. 
+ batch_size : int + Batch size. + tta_transforms : bool, optional + Use test time augmentation, by default True. + mean : Optional[float], optional + Mean value for normalization, only used if Normalization is defined, by + default None. + std : Optional[float], optional + Standard deviation value for normalization, only used if Normalization is + defined, by default None. + transforms : Optional[Union[List[TRANSFORMS_UNION], Compose]], optional + List of transforms to apply to prediction patches. If None, default + transforms are applied. + read_source_func : Optional[Callable], optional + Function to read the source data, used if `data_type` is `custom`, by + default None. + extension_filter : str, optional + Filter for file extensions, used if `data_type` is `custom`, by default "". + dataloader_params : dict, optional + Pytorch dataloader parameters, by default {}. + """ + if dataloader_params is None: + dataloader_params = {} + prediction_dict = { + "data_type": data_type, + "tile_size": tile_size, + "tile_overlap": tile_overlap, + "axes": axes, + "mean": mean, + "std": std, + "tta": tta_transforms, + "batch_size": batch_size, + } + + # if transforms are passed (otherwise it will use the default ones) + if transforms is not None: + prediction_dict["transforms"] = transforms + + # validate configuration + self.prediction_config = InferenceModel(**prediction_dict) + + # sanity check on the dataloader parameters + if "batch_size" in dataloader_params: + # remove it + del dataloader_params["batch_size"] + + super().__init__( + prediction_config=self.prediction_config, + pred_data=pred_data, + read_source_func=read_source_func, + extension_filter=extension_filter, + dataloader_params=dataloader_params, + ) diff --git a/src/careamics/lightning_prediction_loop.py b/src/careamics/lightning_prediction_loop.py new file mode 100644 index 00000000..c7e00fd2 --- /dev/null +++ b/src/careamics/lightning_prediction_loop.py @@ -0,0 +1,116 @@ +from typing import Optional + 
+import pytorch_lightning as L +from pytorch_lightning.loops.fetchers import _DataLoaderIterDataFetcher +from pytorch_lightning.loops.utilities import _no_grad_context +from pytorch_lightning.trainer import call +from pytorch_lightning.utilities.types import _PREDICT_OUTPUT + +from careamics.prediction import stitch_prediction + + +class CAREamicsPredictionLoop(L.loops._PredictionLoop): + """ + CAREamics prediction loop. + + This class extends the PyTorch Lightning `_PredictionLoop` class to include + the stitching of the tiles into a single prediction result. + """ + + def _on_predict_epoch_end(self) -> Optional[_PREDICT_OUTPUT]: + """ + Calls `on_predict_epoch_end` hook. + + Adapted from the parent method. + + Returns + ------- + the results for all dataloaders + """ + trainer = self.trainer + call._call_callback_hooks(trainer, "on_predict_epoch_end") + call._call_lightning_module_hook(trainer, "on_predict_epoch_end") + + if self.return_predictions: + ######################################################## + ################ CAREamics specific code ############### + if len(self.predicted_array) == 1: + # TODO does this make sense to here? (force numpy array) + return self.predicted_array[0].numpy() + else: + # TODO revisit logic + return [element.numpy() for element in self.predicted_array] + ######################################################## + return None + + @_no_grad_context + def run(self) -> Optional[_PREDICT_OUTPUT]: + """ + Runs the prediction loop. + + Adapted from the parent method in order to stitch the predictions. 
+ + Returns + ------- + Optional[_PREDICT_OUTPUT] + Prediction output + """ + self.setup_data() + if self.skip: + return None + self.reset() + self.on_run_start() + data_fetcher = self._data_fetcher + assert data_fetcher is not None + + self.predicted_array = [] + self.tiles = [] + self.stitching_data = [] + + while True: + try: + if isinstance(data_fetcher, _DataLoaderIterDataFetcher): + dataloader_iter = next(data_fetcher) + # hook's batch_idx and dataloader_idx arguments correctness cannot + # be guaranteed in this setting + batch = data_fetcher._batch + batch_idx = data_fetcher._batch_idx + dataloader_idx = data_fetcher._dataloader_idx + else: + dataloader_iter = None + batch, batch_idx, dataloader_idx = next(data_fetcher) + self.batch_progress.is_last_batch = data_fetcher.done + + # run step hooks + self._predict_step(batch, batch_idx, dataloader_idx, dataloader_iter) + + ######################################################## + ################ CAREamics specific code ############### + is_tiled = len(self.predictions[batch_idx]) == 2 + if is_tiled: + # extract the last tile flag and the coordinates (crop and stitch) + last_tile, *stitch_data = self.predictions[batch_idx][1] + + # append the tile and the coordinates to the lists + self.tiles.append(self.predictions[batch_idx][0]) + self.stitching_data.append(stitch_data) + + # if last tile, stitch the tiles and add array to the prediction + if any(last_tile): + predicted_batches = stitch_prediction( + self.tiles, self.stitching_data + ) + self.predicted_array.append(predicted_batches) + self.tiles.clear() + self.stitching_data.clear() + else: + # simply add the prediction to the list + self.predicted_array.append(self.predictions[batch_idx]) + ######################################################## + except StopIteration: + break + finally: + self._restarting = False + return self.on_run_end() + + # TODO predictions aren't stacked, list returned diff --git a/src/careamics/losses/__init__.py 
b/src/careamics/losses/__init__.py index 85a1d2f1..e984229f 100644 --- a/src/careamics/losses/__init__.py +++ b/src/careamics/losses/__init__.py @@ -1,4 +1,7 @@ """Losses module.""" -from .loss_factory import create_loss_function as create_loss_function +from .loss_factory import loss_factory + +# from .noise_model_factory import noise_model_factory as noise_model_factory +# from .noise_models import GaussianMixtureNoiseModel, HistogramNoiseModel diff --git a/src/careamics/losses/loss_factory.py b/src/careamics/losses/loss_factory.py index d24befa4..054f09c5 100644 --- a/src/careamics/losses/loss_factory.py +++ b/src/careamics/losses/loss_factory.py @@ -3,22 +3,21 @@ This module contains a factory function for creating loss functions. """ -from typing import Callable +from typing import Callable, Union -from careamics.config import Configuration -from careamics.config.algorithm import Loss +from ..config.support import SupportedLoss +from .losses import mae_loss, mse_loss, n2v_loss -from .losses import n2v_loss - -def create_loss_function(config: Configuration) -> Callable: - """ - Create loss function based on Configuration. +# TODO add tests +# TODO add custom? +def loss_factory(loss: Union[SupportedLoss, str]) -> Callable: + """Return loss function. Parameters ---------- - config : Configuration - Configuration. + loss: SupportedLoss + Requested loss. Returns ------- @@ -30,9 +29,20 @@ def create_loss_function(config: Configuration) -> Callable: NotImplementedError If the loss is unknown. 
""" - loss_type = config.algorithm.loss - - if loss_type == Loss.N2V: + if loss == SupportedLoss.N2V: return n2v_loss + + # elif loss_type == SupportedLoss.PN2V: + # return pn2v_loss + + elif loss == SupportedLoss.MAE: + return mae_loss + + elif loss == SupportedLoss.MSE: + return mse_loss + + # elif loss_type == SupportedLoss.DICE: + # return dice_loss + else: - raise NotImplementedError(f"Loss {loss_type} is not yet supported.") + raise NotImplementedError(f"Loss {loss} is not yet supported.") diff --git a/src/careamics/losses/losses.py b/src/careamics/losses/losses.py index b0c747c8..b6993742 100644 --- a/src/careamics/losses/losses.py +++ b/src/careamics/losses/losses.py @@ -3,14 +3,34 @@ This submodule contains the various losses used in CAREamics. """ + import torch +# TODO if we are only using the DiceLoss, can we just implement it? +# from segmentation_models_pytorch.losses import DiceLoss +from torch.nn import L1Loss, MSELoss + + +def mse_loss(samples: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: + """ + Mean squared error loss. + + Returns + ------- + torch.Tensor + Loss value. + """ + loss = MSELoss() + return loss(samples, labels) + def n2v_loss( - samples: torch.Tensor, labels: torch.Tensor, masks: torch.Tensor, device: str + manipulated_patches: torch.Tensor, + original_patches: torch.Tensor, + masks: torch.Tensor, ) -> torch.Tensor: """ - N2V Loss function (see Eq.7 in Krull et al). + N2V Loss function described in A Krull et al 2018. Parameters ---------- @@ -20,15 +40,55 @@ def n2v_loss( Noisy patches. masks : torch.Tensor Array containing masked pixel locations. - device : str - Device to use. Returns ------- torch.Tensor Loss value. 
""" - errors = (labels - samples) ** 2 + errors = (original_patches - manipulated_patches) ** 2 # Average over pixels and batch loss = torch.sum(errors * masks) / torch.sum(masks) return loss + + +def mae_loss(samples: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: + """ + N2N Loss function described in to J Lehtinen et al 2018. + + Parameters + ---------- + samples : torch.Tensor + Raw patches. + labels : torch.Tensor + Different subset of noisy patches. + + Returns + ------- + torch.Tensor + Loss value. + """ + loss = L1Loss() + return loss(samples, labels) + + +# def pn2v_loss( +# samples: torch.Tensor, +# labels: torch.Tensor, +# masks: torch.Tensor, +# noise_model: HistogramNoiseModel, +# ) -> torch.Tensor: +# """Probabilistic N2V loss function described in A Krull et al., CVF (2019).""" +# likelihoods = noise_model.likelihood(labels, samples) +# likelihoods_avg = torch.log(torch.mean(likelihoods, dim=0, keepdim=True)[0, ...]) + +# # Average over pixels and batch +# loss = -torch.sum(likelihoods_avg * masks) / torch.sum(masks) +# return loss + + +# def dice_loss( +# samples: torch.Tensor, labels: torch.Tensor, mode: str = "multiclass" +# ) -> torch.Tensor: +# """Dice loss function.""" +# return DiceLoss(mode=mode)(samples, labels.long()) diff --git a/src/careamics/losses/noise_model_factory.py b/src/careamics/losses/noise_model_factory.py new file mode 100644 index 00000000..fdab1182 --- /dev/null +++ b/src/careamics/losses/noise_model_factory.py @@ -0,0 +1,40 @@ +from typing import Type, Union + +from ..config.noise_models import NoiseModel, NoiseModelType +from .noise_models import GaussianMixtureNoiseModel, HistogramNoiseModel + + +def noise_model_factory( + noise_config: NoiseModel, +) -> Type[Union[HistogramNoiseModel, GaussianMixtureNoiseModel, None]]: + """Create loss model based on Configuration. + + Parameters + ---------- + config : Configuration + Configuration. 
+ + Returns + ------- + Noise model + + Raises + ------ + NotImplementedError + If the noise model is unknown. + """ + noise_model_type = noise_config.model_type if noise_config else None + + if noise_model_type == NoiseModelType.HIST: + return HistogramNoiseModel + + elif noise_model_type == NoiseModelType.GMM: + return GaussianMixtureNoiseModel + + elif noise_model_type is None: + return None + + else: + raise NotImplementedError( + f"Noise model {noise_model_type} is not yet supported." + ) diff --git a/src/careamics/losses/noise_models.py b/src/careamics/losses/noise_models.py new file mode 100644 index 00000000..5f4fc8ef --- /dev/null +++ b/src/careamics/losses/noise_models.py @@ -0,0 +1,524 @@ +from abc import ABC, abstractmethod + +import numpy as np +import torch + +from ..utils.logging import get_logger + +logger = get_logger(__name__) + + +# TODO here "Model" clashes a bit with the naming convention of the Pydantic Models +class NoiseModel(ABC): + """Base class for noise models.""" + + @abstractmethod + def instantiate(self): + """Instantiate the noise model. + + Method that should produce ready to use noise model. + """ + pass + + @abstractmethod + def likelihood(self, observations, signals): + """Function that returns the likelihood of observations given the signals.""" + pass + + +class HistogramNoiseModel(NoiseModel): + """Creates a NoiseModel object. + + Parameters + ---------- + histogram: numpy array + A histogram as create by the 'createHistogram(...)' method. + device: + The device your NoiseModel lives on, e.g. your GPU. + """ + + def __init__(self, **kwargs): + pass + + def instantiate(self, bins, min_value, max_value, observation, signal): + """Creates a nD histogram from 'observation' and 'signal'. + + Parameters + ---------- + bins: int + The number of bins in all dimensions. The total number of bins is + 'bins' ** number_of_dimensions. + min_value: float + the lower bound of the lowest bin. 
+ max_value: float + the highest bound of the highest bin. + observation: np.array + A stack of noisy images. The number has to be divisible by the number of + images in signal. N subsequent images in observation belong to one image + in the signal. + signal: np.array + A stack of clean images. + + Returns + ------- + histogram: numpy array + A 3D array: + 'histogram[0,...]' holds the normalized nD counts. + Each row sums to 1, describing p(x_i|s_i). + 'histogram[1,...]' holds the lower boundaries of each bin in y. + 'histogram[2,...]' holds the upper boundaries of each bin in y. + The values for x can be obtained by transposing 'histogram[1,...]' + and 'histogram[2,...]'. + """ + img_factor = int(observation.shape[0] / signal.shape[0]) + histogram = np.zeros((3, bins, bins)) + value_range = [min_value, max_value] + + for i in range(observation.shape[0]): + observation_i = observation[i].copy().ravel() + + signal_i = (signal[i // img_factor].copy()).ravel() + + histogram_i = np.histogramdd( + (signal_i, observation_i), bins=bins, range=[value_range, value_range] + ) + # Adding a constant for numerical stability + histogram[0] = histogram[0] + histogram_i[0] + 1e-30 + + for i in range(bins): + # Exclude empty rows from normalization + if np.sum(histogram[0, i, :]) > 1e-20: + # Normalize each non-empty row + histogram[0, i, :] /= np.sum(histogram[0, i, :]) + + for i in range(bins): + # The lower boundaries of each bin in y are stored in dimension 1 + histogram[1, :, i] = histogram_i[1][:-1] + # The upper boundaries of each bin in y are stored in dimension 2 + histogram[2, :, i] = histogram_i[1][1:] + # The accordent numbers for x are just transposed. + + return histogram + + def likelihood(self, observed, signal): + """Calculate the likelihood using a histogram based noise model. + + For every pixel in a tensor, calculate (x_i|s_i). To ensure differentiability + in the direction of s_i, we linearly interpolate in this direction. 
+ + Parameters + ---------- + observed: torch.Tensor + tensor holding your observed intesities x_i. + + signal: torch.Tensor + tensor holding hypotheses for the clean signal at every pixel s_i^k. + + Returns + ------- + Torch.tensor containing the observation likelihoods according to the + noise model. + """ + observed_float = self.get_index_observed_float(observed) + observed_long = observed_float.floor().long() + signal_float = self.get_index_signal_float(signal) + signal_long = signal_float.floor().long() + fact = signal_float - signal_long.float() + + # Finally we are looking ud the values and interpolate + return self.fullHist[signal_long, observed_long] * (1.0 - fact) + self.fullHist[ + torch.clamp((signal_long + 1).long(), 0, self.bins.long()), observed_long + ] * (fact) + + def get_index_observed_float(self, x: float): + """_summary_. + + Parameters + ---------- + x : _type_ + _description_ + + Returns + ------- + _type_ + _description_ + """ + return torch.clamp( + self.bins * (x - self.minv) / (self.maxv - self.minv), + min=0.0, + max=self.bins - 1 - 1e-3, + ) + + def get_index_signal_float(self, x): + """_summary_. + + Parameters + ---------- + x : _type_ + _description_ + + Returns + ------- + _type_ + _description_ + """ + return torch.clamp( + self.bins * (x - self.minv) / (self.maxv - self.minv), + min=0.0, + max=self.bins - 1 - 1e-3, + ) + + +# TODO refactor this into Pydantic model +class GaussianMixtureNoiseModel(NoiseModel): + """Describes a noise model parameterized as a mixture of gaussians. + + If you would like to initialize a new object from scratch, then set `params` = None + and specify the other parameters as keyword arguments. If you are instead loading + a model, use only `params`. + + Parameters + ---------- + **kwargs: keyworded, variable-length argument dictionary. + Arguments include: + min_signal : float + Minimum signal intensity expected in the image. + max_signal : float + Maximum signal intensity expected in the image. 
+ weight : array + A [3*n_gaussian, n_coeff] sized array containing the values of the weights + describing the noise model. + Each gaussian contributes three parameters (mean, standard deviation and weight), + hence the number of rows in `weight` are 3*n_gaussian. + If `weight = None`, the weight array is initialized using the `min_signal` and + `max_signal` parameters. + n_gaussian: int + Number of gaussians. + n_coeff: int + Number of coefficients to describe the functional relationship between gaussian + parameters and the signal. + 2 implies a linear relationship, 3 implies a quadratic relationship and so on. + device: device + GPU device. + min_sigma: int + All values of sigma (`standard deviation`) below min_sigma are clamped to become + equal to min_sigma. + params: dictionary + Use `params` if one wishes to load a model with trained weights. + While initializing a new object of the class `GaussianMixtureNoiseModel` from + scratch, set this to `None`. + """ + + def __init__(self, **kwargs): + if kwargs.get("params") is None: + weight = kwargs.get("weight") + n_gaussian = kwargs.get("n_gaussian") + n_coeff = kwargs.get("n_coeff") + min_signal = kwargs.get("min_signal") + max_signal = kwargs.get("max_signal") + self.device = kwargs.get("device") + self.path = kwargs.get("path") + self.min_sigma = kwargs.get("min_sigma") + if weight is None: + weight = np.random.randn(n_gaussian * 3, n_coeff) + weight[n_gaussian : 2 * n_gaussian, 1] = np.log(max_signal - min_signal) + weight = ( + torch.from_numpy(weight.astype(np.float32)).float().to(self.device) + ) + weight.requires_grad = True + self.n_gaussian = weight.shape[0] // 3 + self.n_coeff = weight.shape[1] + self.weight = weight + self.min_signal = torch.Tensor([min_signal]).to(self.device) + self.max_signal = torch.Tensor([max_signal]).to(self.device) + self.tol = torch.Tensor([1e-10]).to(self.device) + else: + params = kwargs.get("params") + self.device = kwargs.get("device") + + self.min_signal = 
torch.Tensor(params["min_signal"]).to(self.device) + self.max_signal = torch.Tensor(params["max_signal"]).to(self.device) + + self.weight = torch.Tensor(params["trained_weight"]).to(self.device) + self.min_sigma = np.ndarray.item(params["min_sigma"]) + self.n_gaussian = self.weight.shape[0] // 3 + self.n_coeff = self.weight.shape[1] + self.tol = torch.Tensor([1e-10]).to(self.device) + self.min_signal = torch.Tensor([self.min_signal]).to(self.device) + self.max_signal = torch.Tensor([self.max_signal]).to(self.device) + + def fast_shuffle(self, series, num): + """. + + Parameters + ---------- + series : _type_ + _description_ + num : _type_ + _description_ + + Returns + ------- + _type_ + _description_ + """ + length = series.shape[0] + for _i in range(num): + series = series[np.random.permutation(length), :] + return series + + def polynomial_regressor(self, weightParams, signals): + """Combines weight_parameters and signals to perform regression. + + Parameters + ---------- + weightParams : torch.cuda.FloatTensor + Corresponds to specific rows of the `self.weight' + + signals : torch.cuda.FloatTensor + Signals + + Returns + ------- + value : torch.cuda.FloatTensor + Corresponds to either of mean, standard deviation or weight, evaluated at + `signals` + """ + value = 0 + for i in range(weightParams.shape[0]): + value += weightParams[i] * ( + ((signals - self.min_signal) / (self.max_signal - self.min_signal)) ** i + ) + return value + + def normal_density(self, x, m_=0.0, std_=None): + """Evaluates the normal probability density. 
+ + Parameters + ---------- + x: torch.cuda.FloatTensor + Observations + m_: torch.cuda.FloatTensor + Mean + std_: torch.cuda.FloatTensor + Standard-deviation + + Returns + ------- + tmp: torch.cuda.FloatTensor + Normal probability density of `x` given `m_` and `std_` + + """ + tmp = -((x - m_) ** 2) + tmp = tmp / (2.0 * std_ * std_) + tmp = torch.exp(tmp) + tmp = tmp / torch.sqrt((2.0 * np.pi) * std_ * std_) + return tmp + + def likelihood(self, observations, signals): + """Evaluates the likelihood of observations. + + Given the signals and the corresponding gaussian parameters evaluates the + likelihood of observations. + + Parameters + ---------- + observations : torch.cuda.FloatTensor + Noisy observations + signals : torch.cuda.FloatTensor + Underlying signals + + Returns + ------- + value :p + self.tol + Likelihood of observations given the signals and the GMM noise model + + """ + gaussianParameters = self.getGaussianParameters(signals) + p = 0 + for gaussian in range(self.n_gaussian): + p += ( + self.normalDens( + observations, + gaussianParameters[gaussian], + gaussianParameters[self.n_gaussian + gaussian], + ) + * gaussianParameters[2 * self.n_gaussian + gaussian] + ) + return p + self.tol + + def get_gaussian_parameters(self, signals): + """Returns the noise model for given signals. 
+ + Parameters + ---------- + signals : torch.cuda.FloatTensor + Underlying signals + + Returns + ------- + noiseModel: list of torch.cuda.FloatTensor + Contains a list of `mu`, `sigma` and `alpha` for the `signals` + + """ + noiseModel = [] + mu = [] + sigma = [] + alpha = [] + kernels = self.weight.shape[0] // 3 + for num in range(kernels): + mu.append(self.polynomialRegressor(self.weight[num, :], signals)) + + sigmaTemp = self.polynomialRegressor( + torch.exp(self.weight[kernels + num, :]), signals + ) + sigmaTemp = torch.clamp(sigmaTemp, min=self.min_sigma) + sigma.append(torch.sqrt(sigmaTemp)) + alpha.append( + torch.exp( + self.polynomialRegressor(self.weight[2 * kernels + num, :], signals) + + self.tol + ) + ) + + sum_alpha = 0 + for al in range(kernels): + sum_alpha = alpha[al] + sum_alpha + for ker in range(kernels): + alpha[ker] = alpha[ker] / sum_alpha + + sum_means = 0 + for ker in range(kernels): + sum_means = alpha[ker] * mu[ker] + sum_means + + for ker in range(kernels): + mu[ker] = mu[ker] - sum_means + signals + + for i in range(kernels): + noiseModel.append(mu[i]) + for j in range(kernels): + noiseModel.append(sigma[j]) + for k in range(kernels): + noiseModel.append(alpha[k]) + + return noiseModel + + def get_signal_observation_pairs(self, signal, observation, lowerClip, upperClip): + """Returns the Signal-Observation pixel intensities as a two-column array. + + Parameters + ---------- + signal : numpy array + Clean Signal Data + observation: numpy array + Noisy observation Data + lowerClip: float + Lower percentile bound for clipping. + upperClip: float + Upper percentile bound for clipping. 
+ + Returns + ------- + noiseModel: list of torch floats + Contains a list of `mu`, `sigma` and `alpha` for the `signals` + + """ + lb = np.percentile(signal, lowerClip) + ub = np.percentile(signal, upperClip) + stepsize = observation[0].size + n_observations = observation.shape[0] + n_signals = signal.shape[0] + sig_obs_pairs = np.zeros((n_observations * stepsize, 2)) + + for i in range(n_observations): + j = i // (n_observations // n_signals) + sig_obs_pairs[stepsize * i : stepsize * (i + 1), 0] = signal[j].ravel() + sig_obs_pairs[stepsize * i : stepsize * (i + 1), 1] = observation[i].ravel() + sig_obs_pairs = sig_obs_pairs[ + (sig_obs_pairs[:, 0] > lb) & (sig_obs_pairs[:, 0] < ub) + ] + return self.fast_shuffle(sig_obs_pairs, 2) + + def train( + self, + signal, + observation, + learning_rate=1e-1, + batchSize=250000, + n_epochs=2000, + name="GMMNoiseModel.npz", + lowerClip=0, + upperClip=100, + ): + """Training to learn the noise model from signal - observation pairs. + + Parameters + ---------- + signal: numpy array + Clean Signal Data + observation: numpy array + Noisy Observation Data + learning_rate: float + Learning rate. Default = 1e-1. + batchSize: int + Nini-batch size. Default = 250000. + n_epochs: int + Number of epochs. Default = 2000. + name: string + Model name. Default is `GMMNoiseModel`. This model after being trained is + saved at the location `path`. + + lowerClip : int + Lower percentile for clipping. Default is 0. + upperClip : int + Upper percentile for clipping. Default is 100. 
+ + + """ + sig_obs_pairs = self.getSignalObservationPairs( + signal, observation, lowerClip, upperClip + ) + counter = 0 + optimizer = torch.optim.Adam([self.weight], lr=learning_rate) + for t in range(n_epochs): + jointLoss = 0 + if (counter + 1) * batchSize >= sig_obs_pairs.shape[0]: + counter = 0 + sig_obs_pairs = self.fast_shuffle(sig_obs_pairs, 1) + + batch_vectors = sig_obs_pairs[ + counter * batchSize : (counter + 1) * batchSize, : + ] + observations = batch_vectors[:, 1].astype(np.float32) + signals = batch_vectors[:, 0].astype(np.float32) + observations = ( + torch.from_numpy(observations.astype(np.float32)) + .float() + .to(self.device) + ) + signals = torch.from_numpy(signals).float().to(self.device) + p = self.likelihood(observations, signals) + loss = torch.mean(-torch.log(p)) + jointLoss = jointLoss + loss + + if t % 100 == 0: + print(t, jointLoss.item()) + + if t % (int(n_epochs * 0.5)) == 0: + trained_weight = self.weight.cpu().detach().numpy() + min_signal = self.min_signal.cpu().detach().numpy() + max_signal = self.max_signal.cpu().detach().numpy() + np.savez( + self.path + name, + trained_weight=trained_weight, + min_signal=min_signal, + max_signal=max_signal, + min_sigma=self.min_sigma, + ) + + optimizer.zero_grad() + jointLoss.backward() + optimizer.step() + counter += 1 + + logger.info(f"The trained parameters {name} is saved at location: " + self.path) diff --git a/src/careamics/manipulation/__init__.py b/src/careamics/manipulation/__init__.py deleted file mode 100644 index 59c38cc9..00000000 --- a/src/careamics/manipulation/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -"""Pixel manipulation functions for N2V.""" - - -from .pixel_manipulation import default_manipulate as default_manipulate diff --git a/src/careamics/manipulation/pixel_manipulation.py b/src/careamics/manipulation/pixel_manipulation.py deleted file mode 100644 index c60c35fc..00000000 --- a/src/careamics/manipulation/pixel_manipulation.py +++ /dev/null @@ -1,158 +0,0 @@ -""" 
-Pixel manipulation methods. - -Pixel manipulation is used in N2V and similar algorithm to replace the value of -masked pixels. -""" -from typing import Callable, Optional, Tuple - -import numpy as np - - -def _odd_jitter_func(step: float, rng: np.random.Generator) -> np.ndarray: - """ - Randomly sample a jitter to be applied to the masking grid. - - This is done to account for cases where the step size is not an integer. - - Parameters - ---------- - step : float - Step size of the grid, output of np.linspace. - rng : np.random.Generator - Random number generator. - - Returns - ------- - np.ndarray - Array of random jitter to be added to the grid. - """ - # Define the random jitter to be added to the grid - odd_jitter = np.where(np.floor(step) == step, 0, rng.integers(0, 2)) - - # Round the step size to the nearest integer depending on the jitter - return np.floor(step) if odd_jitter == 0 else np.ceil(step) - - -def get_stratified_coords( - mask_pixel_perc: float, - shape: Tuple[int, ...], -) -> np.ndarray: - """ - Generate coordinates of the pixels to mask. - - Randomly selects the coordinates of the pixels to mask in a stratified way, i.e. - the distance between masked pixels is approximately the same. - - Parameters - ---------- - mask_pixel_perc : float - Actual (quasi) percentage of masked pixels across the whole image. Used in - calculating the distance between masked pixels across each axis. - shape : Tuple[int, ...] - Shape of the input patch. - - Returns - ------- - np.ndarray - Array of coordinates of the masked pixels. 
- """ - rng = np.random.default_rng() - - # Define the approximate distance between masked pixels - mask_pixel_distance = np.round((100 / mask_pixel_perc) ** (1 / len(shape))).astype( - np.int32 - ) - - # Define a grid of coordinates for each axis in the input patch and the step size - pixel_coords = [] - for axis_size in shape: - # make sure axis size is evenly divisible by box size - num_pixels = int(np.ceil(axis_size / mask_pixel_distance)) - axis_pixel_coords, step = np.linspace( - 0, axis_size, num_pixels, dtype=np.int32, endpoint=False, retstep=True - ) - # explain - pixel_coords.append(axis_pixel_coords.T) - - # Create a meshgrid of coordinates for each axis in the input patch - coordinate_grid_list = np.meshgrid(*pixel_coords) - coordinate_grid = np.array(coordinate_grid_list).reshape(len(shape), -1).T - - grid_random_increment = rng.integers( - _odd_jitter_func(float(step), rng) - * np.ones_like(coordinate_grid).astype(np.int32) - - 1, - size=coordinate_grid.shape, - endpoint=True, - ) - coordinate_grid += grid_random_increment - coordinate_grid = np.clip(coordinate_grid, 0, np.array(shape) - 1) - return coordinate_grid - - -def default_manipulate( - patch: np.ndarray, - mask_pixel_percentage: float, - roi_size: int = 11, - augmentations: Optional[Callable] = None, -) -> Tuple[np.ndarray, ...]: - """ - Manipulate pixel in a patch, i.e. replace the masked value. - - Parameters - ---------- - patch : np.ndarray - Image patch, 2D or 3D, shape (y, x) or (z, y, x). - mask_pixel_percentage : floar - Approximate percentage of pixels to be masked. - roi_size : int - Size of the ROI the new pixel value is sampled from, by default 11. - augmentations : Callable, optional - Augmentations to apply, by default None. - - Returns - ------- - Tuple[np.ndarray] - Tuple containing the manipulated patch, the original patch and the mask. 
- """ - original_patch = patch.copy() - - # Get the coordinates of the pixels to be replaced - roi_centers = get_stratified_coords(mask_pixel_percentage, patch.shape) - rng = np.random.default_rng() - - # Generate coordinate grid for ROI - roi_span_full = np.arange(-np.floor(roi_size / 2), np.ceil(roi_size / 2)).astype( - np.int32 - ) - # Remove the center pixel from the grid - roi_span_wo_center = roi_span_full[roi_span_full != 0] - - # Randomly select coordinates from the grid - random_increment = rng.choice(roi_span_wo_center, size=roi_centers.shape) - - # Clip the coordinates to the patch size - replacement_coords = np.clip( - roi_centers + random_increment, - 0, - [patch.shape[i] - 1 for i in range(len(patch.shape))], - ) - # Get the replacement pixels from all rois - replacement_pixels = patch[tuple(replacement_coords.T.tolist())] - - # Replace the original pixels with the replacement pixels - patch[tuple(roi_centers.T.tolist())] = replacement_pixels - mask = np.where(patch != original_patch, 1, 0).astype(np.uint8) - - patch, original_patch, mask = ( - (patch, original_patch, mask) - if augmentations is None - else augmentations(patch, original_patch, mask) - ) - - return ( - np.expand_dims(patch, 0), - np.expand_dims(original_patch, 0), - np.expand_dims(mask, 0), - ) diff --git a/src/careamics/model_io/__init__.py b/src/careamics/model_io/__init__.py new file mode 100644 index 00000000..0e99771f --- /dev/null +++ b/src/careamics/model_io/__init__.py @@ -0,0 +1,8 @@ +"""Model I/O utilities.""" + + +__all__ = ["load_pretrained", "export_to_bmz"] + + +from .bmz_io import export_to_bmz +from .model_io_utils import load_pretrained diff --git a/src/careamics/model_io/bioimage/__init__.py b/src/careamics/model_io/bioimage/__init__.py new file mode 100644 index 00000000..f312bc7e --- /dev/null +++ b/src/careamics/model_io/bioimage/__init__.py @@ -0,0 +1,11 @@ +"""Bioimage Model Zoo format functions.""" + +__all__ = [ + "create_model_description", + 
"extract_model_path", + "get_unzip_path", + "create_env_text", +] + +from .bioimage_utils import create_env_text, get_unzip_path +from .model_description import create_model_description, extract_model_path diff --git a/src/careamics/model_io/bioimage/_readme_factory.py b/src/careamics/model_io/bioimage/_readme_factory.py new file mode 100644 index 00000000..e823f378 --- /dev/null +++ b/src/careamics/model_io/bioimage/_readme_factory.py @@ -0,0 +1,120 @@ +"""Functions used to create a README.md file for BMZ export.""" +from pathlib import Path +from typing import Optional + +import yaml + +from careamics.config import Configuration +from careamics.utils import cwd, get_careamics_home + + +def _yaml_block(yaml_str: str) -> str: + """Return a markdown code block with a yaml string. + + Parameters + ---------- + yaml_str : str + YAML string. + + Returns + ------- + str + Markdown code block with the YAML string. + """ + return f"```yaml\n{yaml_str}\n```" + + +def readme_factory( + config: Configuration, + careamics_version: str, + data_description: Optional[str] = None, +) -> Path: + """Create a README file for the model. + + `data_description` can be used to add more information about the content of the + data the model was trained on. + + Parameters + ---------- + config : Configuration + CAREamics configuration. + careamics_version : str + CAREamics version. + data_description : Optional[str], optional + Description of the data, by default None. + + Returns + ------- + Path + Path to the README file. 
+ """ + algorithm = config.algorithm_config + training = config.training_config + data = config.data_config + + # create file + # TODO use tempfile as in the bmz_io module + with cwd(get_careamics_home()): + readme = Path("README.md") + readme.touch() + + # algorithm pretty name + algorithm_flavour = config.get_algorithm_flavour() + algorithm_pretty_name = algorithm_flavour + " - CAREamics" + + description = [f"# {algorithm_pretty_name}\n\n"] + + # algorithm description + description.append("Algorithm description:\n\n") + description.append(config.get_algorithm_description()) + description.append("\n\n") + + # algorithm details + description.append( + f"{algorithm_flavour} was trained using CAREamics (version " + f"{careamics_version}) with the following algorithm " + f"parameters:\n\n" + ) + description.append( + _yaml_block(yaml.dump(algorithm.model_dump(exclude_none=True))) + ) + description.append("\n\n") + + # data description + description.append("## Data description\n\n") + if data_description is not None: + description.append(data_description) + description.append("\n\n") + + description.append("The data was processed using the following parameters:\n\n") + + description.append(_yaml_block(yaml.dump(data.model_dump(exclude_none=True)))) + description.append("\n\n") + + # training description + description.append("## Training description\n\n") + + description.append("The model was trained using the following parameters:\n\n") + + description.append( + _yaml_block(yaml.dump(training.model_dump(exclude_none=True))) + ) + description.append("\n\n") + + # references + reference = config.get_algorithm_references() + if reference != "": + description.append("## References\n\n") + description.append(reference) + description.append("\n\n") + + # links + description.append( + "## Links\n\n" + "- [CAREamics repository](https://github.com/CAREamics/careamics)\n" + "- [CAREamics documentation](https://careamics.github.io/latest/)\n" + ) + + 
readme.write_text("".join(description)) + + return readme diff --git a/src/careamics/model_io/bioimage/bioimage_utils.py b/src/careamics/model_io/bioimage/bioimage_utils.py new file mode 100644 index 00000000..1ce28bfc --- /dev/null +++ b/src/careamics/model_io/bioimage/bioimage_utils.py @@ -0,0 +1,48 @@ +"""Bioimage.io utils.""" +from pathlib import Path +from typing import Union + + +def get_unzip_path(zip_path: Union[Path, str]) -> Path: + """Generate unzipped folder path from the bioimage.io model path. + + Parameters + ---------- + zip_path : Path + Path to the bioimage.io model. + + Returns + ------- + Path + Path to the unzipped folder. + """ + zip_path = Path(zip_path) + + return zip_path.parent / (str(zip_path.name) + ".unzip") + + +def create_env_text(pytorch_version: str) -> str: + """Create environment text for the bioimage model. + + Parameters + ---------- + pytorch_version : str + Pytorch version. + + Returns + ------- + str + Environment text. + """ + env = ( + f"name: careamics\n" + f"dependencies:\n" + f" - python=3.8\n" + f" - pytorch={pytorch_version}\n" + f" - torchvision={pytorch_version}\n" + f" - pip\n" + f" - pip:\n" + f" - git+https://github.com/CAREamics/careamics.git@dl4mia\n" + ) + # TODO from pip with package version + return env diff --git a/src/careamics/model_io/bioimage/model_description.py b/src/careamics/model_io/bioimage/model_description.py new file mode 100644 index 00000000..1901b955 --- /dev/null +++ b/src/careamics/model_io/bioimage/model_description.py @@ -0,0 +1,318 @@ +"""Module use to build BMZ model description.""" +from pathlib import Path +from typing import List, Optional, Tuple, Union + +import numpy as np +from bioimageio.spec.model.v0_5 import ( + ArchitectureFromLibraryDescr, + Author, + AxisBase, + AxisId, + BatchAxis, + ChannelAxis, + EnvironmentFileDescr, + FileDescr, + FixedZeroMeanUnitVarianceDescr, + FixedZeroMeanUnitVarianceKwargs, + Identifier, + InputTensorDescr, + ModelDescr, + OutputTensorDescr, + 
PytorchStateDictWeightsDescr, + SpaceInputAxis, + SpaceOutputAxis, + TensorId, + Version, + WeightsDescr, +) + +from careamics.config import Configuration, DataModel + +from ._readme_factory import readme_factory + + +def _create_axes( + array: np.ndarray, + data_config: DataModel, + channel_names: Optional[List[str]] = None, + is_input: bool = True, +) -> List[AxisBase]: + """Create axes description. + + Array shape is expected to be SC(Z)YX. + + Parameters + ---------- + array : np.ndarray + Array. + data_config : DataModel + CAREamics data configuration. + channel_names : Optional[List[str]], optional + Channel names, by default None. + is_input : bool, optional + Whether the axes are input axes, by default True. + + Returns + ------- + List[AxisBase] + List of axes description. + + Raises + ------ + ValueError + If channel names are not provided when channel axis is present. + """ + # axes have to be SC(Z)YX + spatial_axes = data_config.axes.replace("S", "").replace("C", "") + + # batch is always present + axes_model = [BatchAxis()] + + if "C" in data_config.axes: + if channel_names is not None: + axes_model.append( + ChannelAxis(channel_names=[Identifier(name) for name in channel_names]) + ) + else: + raise ValueError( + f"Channel names must be provided if channel axis is present, axes: " + f"{data_config.axes}." 
+ ) + else: + # singleton channel + axes_model.append(ChannelAxis(channel_names=[Identifier("channel")])) + + # spatial axes + for ind, axes in enumerate(spatial_axes): + if axes in ["X", "Y", "Z"]: + if is_input: + axes_model.append( + SpaceInputAxis(id=AxisId(axes.lower()), size=array.shape[2 + ind]) + ) + else: + axes_model.append( + SpaceOutputAxis(id=AxisId(axes.lower()), size=array.shape[2 + ind]) + ) + + return axes_model + + +def _create_inputs_ouputs( + input_array: np.ndarray, + output_array: np.ndarray, + data_config: DataModel, + input_path: Union[Path, str], + output_path: Union[Path, str], + channel_names: Optional[List[str]] = None, +) -> Tuple[InputTensorDescr, OutputTensorDescr]: + """Create input and output tensor description. + + Input and output paths must point to a `.npy` file. + + Parameters + ---------- + input_array : np.ndarray + Input array. + output_array : np.ndarray + Output array. + data_config : DataModel + CAREamics data configuration. + input_path : Union[Path, str] + Path to input .npy file. + output_path : Union[Path, str] + Path to output .npy file. + channel_names : Optional[List[str]], optional + Channel names, by default None. + + Returns + ------- + Tuple[InputTensorDescr, OutputTensorDescr] + Input and output tensor descriptions. + """ + input_axes = _create_axes(input_array, data_config, channel_names) + output_axes = _create_axes(output_array, data_config, channel_names, False) + + # mean and std + assert data_config.mean is not None, "Mean cannot be None." + assert data_config.std is not None, "Std cannot be None." 
+ mean = data_config.mean + std = data_config.std + + # and the mean and std required to invert the normalization + # CAREamics denormalization: x = y * (std + eps) + mean + # BMZ normalization : x = (y - mean') / (std' + eps) + # to apply the BMZ normalization as a denormalization step, we need: + eps = 1e-6 + inv_mean = -mean / (std + eps) + inv_std = 1 / (std + eps) - eps + + # create input/output descriptions + input_descr = InputTensorDescr( + id=TensorId("input"), + axes=input_axes, + test_tensor=FileDescr(source=input_path), + preprocessing=[ + FixedZeroMeanUnitVarianceDescr( + kwargs=FixedZeroMeanUnitVarianceKwargs(mean=mean, std=std) + ) + ], + ) + output_descr = OutputTensorDescr( + id=TensorId("prediction"), + axes=output_axes, + test_tensor=FileDescr(source=output_path), + postprocessing=[ + FixedZeroMeanUnitVarianceDescr( + kwargs=FixedZeroMeanUnitVarianceKwargs( # invert normalization + mean=inv_mean, std=inv_std + ) + ) + ], + ) + + return input_descr, output_descr + + +def create_model_description( + config: Configuration, + name: str, + general_description: str, + authors: List[Author], + inputs: Union[Path, str], + outputs: Union[Path, str], + weights_path: Union[Path, str], + torch_version: str, + careamics_version: str, + config_path: Union[Path, str], + env_path: Union[Path, str], + channel_names: Optional[List[str]] = None, + data_description: Optional[str] = None, +) -> ModelDescr: + """Create model description. + + Parameters + ---------- + config : Configuration + CAREamics configuration. + name : str + Name fo the model. + general_description : str + General description of the model. + authors : List[Author] + Authors of the model. + inputs : Union[Path, str] + Path to input .npy file. + outputs : Union[Path, str] + Path to output .npy file. + weights_path : Union[Path, str] + Path to model weights. + torch_version : str + Pytorch version. + careamics_version : str + CAREamics version. 
+ config_path : Union[Path, str] + Path to model configuration. + env_path : Union[Path, str] + Path to environment file. + channel_names : Optional[List[str]], optional + Channel names, by default None. + data_description : Optional[str], optional + Description of the data, by default None. + + Returns + ------- + ModelDescr + Model description. + """ + # documentation + doc = readme_factory( + config, + careamics_version=careamics_version, + data_description=data_description, + ) + + # inputs, outputs + input_descr, output_descr = _create_inputs_ouputs( + input_array=np.load(inputs), + output_array=np.load(outputs), + data_config=config.data_config, + input_path=inputs, + output_path=outputs, + channel_names=channel_names, + ) + + # weights description + architecture_descr = ArchitectureFromLibraryDescr( + import_from="careamics.models", + callable=f"{config.algorithm_config.model.architecture}", + kwargs=config.algorithm_config.model.model_dump(), + ) + + weights_descr = WeightsDescr( + pytorch_state_dict=PytorchStateDictWeightsDescr( + source=weights_path, + architecture=architecture_descr, + pytorch_version=Version(torch_version), + dependencies=EnvironmentFileDescr(source=env_path), + ), + ) + + # overall model description + model = ModelDescr( + name=name, + authors=authors, + description=general_description, + documentation=doc, + inputs=[input_descr], + outputs=[output_descr], + tags=config.get_algorithm_keywords(), + links=[ + "https://github.com/CAREamics/careamics", + "https://careamics.github.io/latest/", + ], + license="BSD-3-Clause", + version="0.1.0", + weights=weights_descr, + attachments=[FileDescr(source=config_path)], + cite=config.get_algorithm_citations(), + config={ # conversion from float32 to float64 creates small differences... 
+ "bioimageio": { + "test_kwargs": { + "pytorch_state_dict": { + "decimals": 2, # ...so we relax the constraints on the decimals + } + } + } + }, + ) + + return model + + +def extract_model_path(model_desc: ModelDescr) -> Tuple[Path, Path]: + """Return the relative path to the weights and configuration files. + + Parameters + ---------- + model_desc : ModelDescr + Model description. + + Returns + ------- + Tuple[Path, Path] + Weights and configuration paths. + """ + weights_path = model_desc.weights.pytorch_state_dict.source.path + + if len(model_desc.attachments) == 1: + config_path = model_desc.attachments[0].source.path + else: + for file in model_desc.attachments: + if file.source.path.suffix == ".yml": + config_path = file.source.path + break + + if config_path is None: + raise ValueError("Configuration file not found.") + + return weights_path, config_path diff --git a/src/careamics/model_io/bmz_io.py b/src/careamics/model_io/bmz_io.py new file mode 100644 index 00000000..06e5d944 --- /dev/null +++ b/src/careamics/model_io/bmz_io.py @@ -0,0 +1,231 @@ +"""Function to export to the BioImage Model Zoo format.""" +import tempfile +from pathlib import Path +from typing import List, Optional, Tuple, Union + +import numpy as np +import pkg_resources +from bioimageio.core import load_description, test_model +from bioimageio.spec import ValidationSummary, save_bioimageio_package +from torch import __version__, load, save + +from careamics.config import Configuration, load_configuration, save_configuration +from careamics.config.support import SupportedArchitecture +from careamics.lightning_module import CAREamicsKiln + +from .bioimage import ( + create_env_text, + create_model_description, + extract_model_path, + get_unzip_path, +) + + +def _export_state_dict(model: CAREamicsKiln, path: Union[Path, str]) -> Path: + """ + Export the model state dictionary to a file. + + Parameters + ---------- + model : CAREamicsKiln + CAREamics model to export. 
+ path : Union[Path, str] + Path to the file where to save the model state dictionary. + + Returns + ------- + Path + Path to the saved model state dictionary. + """ + path = Path(path) + + # make sure it has the correct suffix + if path.suffix not in ".pth": + path = path.with_suffix(".pth") + + # save model state dictionary + # we save through the torch model itself to avoid the initial "model." in the + # layers naming, which is incompatible with the way the BMZ load torch state dicts + save(model.model.state_dict(), path) + + return path + + +def _load_state_dict(model: CAREamicsKiln, path: Union[Path, str]) -> None: + """ + Load a model from a state dictionary. + + Parameters + ---------- + model : CAREamicsKiln + CAREamics model to be updated with the weights. + path : Union[Path, str] + Path to the model state dictionary. + """ + path = Path(path) + + # load model state dictionary + # same as in _export_state_dict, we load through the torch model to be compatible + # witht bioimageio.core expectations for a torch state dict + state_dict = load(path) + model.model.load_state_dict(state_dict) + + +# TODO break down in subfunctions +def export_to_bmz( + model: CAREamicsKiln, + config: Configuration, + path: Union[Path, str], + name: str, + general_description: str, + authors: List[dict], + input_array: np.ndarray, + output_array: np.ndarray, + channel_names: Optional[List[str]] = None, + data_description: Optional[str] = None, +) -> None: + """Export the model to BioImage Model Zoo format. + + Arrays are expected to be SC(Z)YX with singleton dimensions allowed for S and C. + + Parameters + ---------- + model : CAREamicsKiln + CAREamics model to export. + config : Configuration + Model configuration. + path : Union[Path, str] + Path to the output file. + name : str + Model name. + general_description : str + General description of the model. + authors : List[dict] + Authors of the model. + input_array : np.ndarray + Input array. 
+ output_array : np.ndarray + Output array. + channel_names : Optional[List[str]], optional + Channel names, by default None. + data_description : Optional[str], optional + Description of the data, by default None. + + Raises + ------ + ValueError + If the model is a Custom model. + """ + path = Path(path) + + # method is not compatible with Custom models + if config.algorithm_config.model.architecture == SupportedArchitecture.CUSTOM: + raise ValueError( + "Exporting Custom models to BioImage Model Zoo format is not supported." + ) + + # make sure that input and output arrays have the same shape + assert input_array.shape == output_array.shape, ( + f"Input ({input_array.shape}) and output ({output_array.shape}) arrays " + f"have different shapes" + ) + + # make sure it has the correct suffix + if path.suffix not in ".zip": + path = path.with_suffix(".zip") + + # versions + pytorch_version = __version__ + careamics_version = pkg_resources.get_distribution("careamics").version + + # save files in temporary folder + with tempfile.TemporaryDirectory() as tmpdirname: + temp_path = Path(tmpdirname) + + # create environment file + # TODO move in bioimage module + env_path = temp_path / "environment.yml" + env_path.write_text(create_env_text(pytorch_version)) + + # export input and ouputs + inputs = temp_path / "inputs.npy" + np.save(inputs, input_array) + outputs = temp_path / "outputs.npy" + np.save(outputs, output_array) + + # export configuration + config_path = save_configuration(config, temp_path) + + # export model state dictionary + weight_path = _export_state_dict(model, temp_path / "weights.pth") + + # create model description + model_description = create_model_description( + config=config, + name=name, + general_description=general_description, + authors=authors, + inputs=inputs, + outputs=outputs, + weights_path=weight_path, + torch_version=pytorch_version, + careamics_version=careamics_version, + config_path=config_path, + env_path=env_path, + 
channel_names=channel_names, + data_description=data_description, + ) + + # test model description + summary: ValidationSummary = test_model(model_description) + if summary.status == "failed": + raise ValueError(f"Model description test failed: {summary}") + + # save bmz model + save_bioimageio_package(model_description, output_path=path) + + +def load_from_bmz(path: Union[Path, str]) -> Tuple[CAREamicsKiln, Configuration]: + """Load a model from a BioImage Model Zoo archive. + + Parameters + ---------- + path : Union[Path, str] + Path to the BioImage Model Zoo archive. + + Returns + ------- + Tuple[CAREamicsKiln, Configuration] + CAREamics model and configuration. + + Raises + ------ + ValueError + If the path is not a zip file. + """ + path = Path(path) + + if path.suffix != ".zip": + raise ValueError(f"Path must be a bioimage.io zip file, got {path}.") + + # load description, this creates an unzipped folder next to the archive + model_desc = load_description(path) + + # extract relative paths + weights_path, config_path = extract_model_path(model_desc) + + # create folder path and absolute paths + unzip_path = get_unzip_path(path) + weights_path = unzip_path / weights_path + config_path = unzip_path / config_path + + # load configuration + config = load_configuration(config_path) + + # create careamics lightning module + model = CAREamicsKiln(algorithm_config=config.algorithm_config) + + # load model state dictionary + _load_state_dict(model, weights_path) + + return model, config diff --git a/src/careamics/model_io/model_io_utils.py b/src/careamics/model_io/model_io_utils.py new file mode 100644 index 00000000..720ac49e --- /dev/null +++ b/src/careamics/model_io/model_io_utils.py @@ -0,0 +1,80 @@ +"""Utility functions to load pretrained models.""" + +from pathlib import Path +from typing import Tuple, Union + +from torch import load + +from careamics.config import Configuration +from careamics.lightning_module import CAREamicsKiln +from 
careamics.model_io.bmz_io import load_from_bmz +from careamics.utils import check_path_exists + + +def load_pretrained(path: Union[Path, str]) -> Tuple[CAREamicsKiln, Configuration]: + """ + Load a pretrained model from a checkpoint or a BioImage Model Zoo model. + + Expected formats are .ckpt or .zip files. + + Parameters + ---------- + path : Union[Path, str] + Path to the pretrained model. + + Returns + ------- + Tuple[CAREamicsKiln, Configuration] + Tuple of CAREamics model and its configuration. + + Raises + ------ + ValueError + If the model format is not supported. + """ + path = check_path_exists(path) + + if path.suffix == ".ckpt": + return _load_checkpoint(path) + elif path.suffix == ".zip": + return load_from_bmz(path) + else: + raise ValueError( + f"Invalid model format. Expected .ckpt or .zip, got {path.suffix}." + ) + + +def _load_checkpoint(path: Union[Path, str]) -> Tuple[CAREamicsKiln, Configuration]: + """ + Load a model from a checkpoint and return both model and configuration. + + Parameters + ---------- + path : Union[Path, str] + Path to the checkpoint. + + Returns + ------- + Tuple[CAREamicsKiln, Configuration] + Tuple of CAREamics model and its configuration. + + Raises + ------ + ValueError + If the checkpoint file does not contain hyper parameters (configuration). + """ + # load checkpoint + checkpoint: dict = load(path) + + # attempt to load configuration + try: + cfg_dict = checkpoint["hyper_parameters"] + except KeyError as e: + raise ValueError( + f"Invalid checkpoint file. 
No `hyper_parameters` found in the " + f"checkpoint: {checkpoint.keys()}" + ) from e + + model = CAREamicsKiln.load_from_checkpoint(path) + + return model, Configuration(**cfg_dict) diff --git a/src/careamics/models/__init__.py b/src/careamics/models/__init__.py index cb2f6bae..080514be 100644 --- a/src/careamics/models/__init__.py +++ b/src/careamics/models/__init__.py @@ -1,4 +1,7 @@ """Models package.""" -from .model_factory import create_model as create_model +__all__ = ["model_factory", "UNet"] + + +from .model_factory import model_factory from .unet import UNet as UNet diff --git a/src/careamics/models/activation.py b/src/careamics/models/activation.py new file mode 100644 index 00000000..c102fbc9 --- /dev/null +++ b/src/careamics/models/activation.py @@ -0,0 +1,35 @@ +from typing import Callable, Union + +import torch.nn as nn + +from ..config.support import SupportedActivation + + +def get_activation(activation: Union[SupportedActivation, str]) -> Callable: + """ + Get activation function. + + Parameters + ---------- + activation : str + Activation function name. + + Returns + ------- + Callable + Activation function. + """ + if activation == SupportedActivation.RELU: + return nn.ReLU() + elif activation == SupportedActivation.LEAKYRELU: + return nn.LeakyReLU() + elif activation == SupportedActivation.TANH: + return nn.Tanh() + elif activation == SupportedActivation.SIGMOID: + return nn.Sigmoid() + elif activation == SupportedActivation.SOFTMAX: + return nn.Softmax(dim=1) + elif activation == SupportedActivation.NONE: + return nn.Identity() + else: + raise ValueError(f"Activation {activation} not supported.") diff --git a/src/careamics/models/layers.py b/src/careamics/models/layers.py index 6fe378ed..fe46e3ed 100644 --- a/src/careamics/models/layers.py +++ b/src/careamics/models/layers.py @@ -3,8 +3,11 @@ This submodule contains layers used in the CAREamics models. 
""" +from typing import List, Optional, Tuple, Union + import torch import torch.nn as nn +from torch.nn import functional as F class Conv_Block(nn.Module): @@ -150,3 +153,244 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: if self.dropout is not None: x = self.dropout(x) return x + + +def _unpack_kernel_size( + kernel_size: Union[Tuple[int, ...], int], dim: int +) -> Tuple[int, ...]: + """Unpack kernel_size to a tuple of ints. + + Inspired by Kornia implementation. TODO: link + """ + if isinstance(kernel_size, int): + kernel_dims = tuple([kernel_size for _ in range(dim)]) + else: + kernel_dims = kernel_size + return kernel_dims + + +def _compute_zero_padding( + kernel_size: Union[Tuple[int, ...], int], dim: int +) -> Tuple[int, ...]: + """Utility function that computes zero padding tuple.""" + kernel_dims = _unpack_kernel_size(kernel_size, dim) + return tuple([(kd - 1) // 2 for kd in kernel_dims]) + + +def get_pascal_kernel_1d( + kernel_size: int, + norm: bool = False, + *, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, +) -> torch.Tensor: + """Generate Yang Hui triangle (Pascal's triangle) for a given number. + + Inspired by Kornia implementation. TODO link + + Parameters + ---------- + kernel_size: height and width of the kernel. + norm: if to normalize the kernel or not. Default: False. 
+ device: tensor device + dtype: tensor dtype + + Returns + ------- + kernel shaped as :math:`(kernel_size,)` + + Examples + -------- + >>> get_pascal_kernel_1d(1) + tensor([1.]) + >>> get_pascal_kernel_1d(2) + tensor([1., 1.]) + >>> get_pascal_kernel_1d(3) + tensor([1., 2., 1.]) + >>> get_pascal_kernel_1d(4) + tensor([1., 3., 3., 1.]) + >>> get_pascal_kernel_1d(5) + tensor([1., 4., 6., 4., 1.]) + >>> get_pascal_kernel_1d(6) + tensor([ 1., 5., 10., 10., 5., 1.]) + """ + pre: List[float] = [] + cur: List[float] = [] + for i in range(kernel_size): + cur = [1.0] * (i + 1) + + for j in range(1, i // 2 + 1): + value = pre[j - 1] + pre[j] + cur[j] = value + if i != 2 * j: + cur[-j - 1] = value + pre = cur + + out = torch.tensor(cur, device=device, dtype=dtype) + + if norm: + out = out / out.sum() + + return out + + +def _get_pascal_kernel_nd( + kernel_size: Union[Tuple[int, int], int], + norm: bool = True, + dim: int = 2, + *, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, +) -> torch.Tensor: + """Generate pascal filter kernel by kernel size. + + Inspired by Kornia implementation. + + Parameters + ---------- + kernel_size: height and width of the kernel. + norm: if to normalize the kernel or not. Default: True. 
+ device: tensor device + dtype: tensor dtype + + Returns + ------- + if kernel_size is an integer the kernel will be shaped as (kernel_size, kernel_size) + otherwise the kernel will be shaped as kernel_size + + Examples + -------- + >>> _get_pascal_kernel_nd(1) + tensor([[1.]]) + >>> _get_pascal_kernel_nd(4) + tensor([[0.0156, 0.0469, 0.0469, 0.0156], + [0.0469, 0.1406, 0.1406, 0.0469], + [0.0469, 0.1406, 0.1406, 0.0469], + [0.0156, 0.0469, 0.0469, 0.0156]]) + >>> _get_pascal_kernel_nd(4, norm=False) + tensor([[1., 3., 3., 1.], + [3., 9., 9., 3.], + [3., 9., 9., 3.], + [1., 3., 3., 1.]]) + """ + kernel_dims = _unpack_kernel_size(kernel_size, dim) + + kernel = [ + get_pascal_kernel_1d(kd, device=device, dtype=dtype) for kd in kernel_dims + ] + + if dim == 2: + kernel = kernel[0][:, None] * kernel[1][None, :] + elif dim == 3: + kernel = ( + kernel[0][:, None, None] + * kernel[1][None, :, None] + * kernel[2][None, None, :] + ) + if norm: + kernel = kernel / torch.sum(kernel) + return kernel + + +def _max_blur_pool_by_kernel2d( + x: torch.Tensor, + kernel: torch.Tensor, + stride: int, + max_pool_size: int, + ceil_mode: bool, +) -> torch.Tensor: + """Compute max_blur_pool by a given :math:`CxC_(out, None)xNxN` kernel. + + Inspired by Kornia implementation. + """ + # compute local maxima + x = F.max_pool2d( + x, kernel_size=max_pool_size, padding=0, stride=1, ceil_mode=ceil_mode + ) + # blur and downsample + padding = _compute_zero_padding((kernel.shape[-2], kernel.shape[-1]), dim=2) + return F.conv2d(x, kernel, padding=padding, stride=stride, groups=x.size(1)) + + +def _max_blur_pool_by_kernel3d( + x: torch.Tensor, + kernel: torch.Tensor, + stride: int, + max_pool_size: int, + ceil_mode: bool, +) -> torch.Tensor: + """Compute max_blur_pool by a given :math:`CxC_(out, None)xNxNxN` kernel. + + Inspired by Kornia implementation. 
+ """ + # compute local maxima + x = F.max_pool3d( + x, kernel_size=max_pool_size, padding=0, stride=1, ceil_mode=ceil_mode + ) + # blur and downsample + padding = _compute_zero_padding( + (kernel.shape[-3], kernel.shape[-2], kernel.shape[-1]), dim=3 + ) + return F.conv3d(x, kernel, padding=padding, stride=stride, groups=x.size(1)) + + +class MaxBlurPool(nn.Module): + """Compute pools and blurs and downsample a given feature map. + + Inspired by Kornia MaxBlurPool implementation. Equivalent to + ```nn.Sequential(nn.MaxPool2d(...), BlurPool2D(...))``` + + Parameters + ---------- + dim: int + Toggles between 2D and 3D + kernel_size: Union[Tuple[int, int], int] + Kernel size for max pooling. + stride: int + Stride for pooling. + max_pool_size: int + Max kernel size for max pooling. + ceil_mode: bool + Should be true to match output size of conv2d with same kernel size. + + Returns + ------- + torch.Tensor + The pooled and blurred tensor. + """ + + def __init__( + self, + dim: int, + kernel_size: Union[Tuple[int, int], int], + stride: int = 2, + max_pool_size: int = 2, + ceil_mode: bool = False, + ) -> None: + super().__init__() + self.dim = dim + self.kernel_size = kernel_size + self.stride = stride + self.max_pool_size = max_pool_size + self.ceil_mode = ceil_mode + self.kernel = _get_pascal_kernel_nd(kernel_size, norm=True, dim=self.dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass of the function.""" + self.kernel = torch.as_tensor(self.kernel, device=x.device, dtype=x.dtype) + if self.dim == 2: + return _max_blur_pool_by_kernel2d( + x, + self.kernel.repeat((x.size(1), 1, 1, 1)), + self.stride, + self.max_pool_size, + self.ceil_mode, + ) + else: + return _max_blur_pool_by_kernel3d( + x, + self.kernel.repeat((x.size(1), 1, 1, 1, 1)), + self.stride, + self.max_pool_size, + self.ceil_mode, + ) diff --git a/src/careamics/models/model_factory.py b/src/careamics/models/model_factory.py index de4507d6..5b93622b 100644 --- 
a/src/careamics/models/model_factory.py +++ b/src/careamics/models/model_factory.py @@ -3,31 +3,30 @@ Model creation factory functions. """ -from pathlib import Path -from typing import Dict, Optional, Tuple, Union +from typing import Union import torch -from careamics.bioimage import import_bioimage_model -from careamics.config import Configuration -from careamics.config.algorithm import Models -from careamics.utils.logging import get_logger - +from ..config.architectures import CustomModel, UNetModel, VAEModel, get_custom_model +from ..config.support import SupportedArchitecture +from ..utils import get_logger from .unet import UNet logger = get_logger(__name__) -def model_registry(model_name: str) -> torch.nn.Module: +def model_factory( + model_configuration: Union[UNetModel, VAEModel, CustomModel] +) -> torch.nn.Module: """ - Model factory. + Deep learning model factory. - Supported models are defined in config.algorithm.Models. + Supported models are defined in careamics.config.SupportedArchitecture. Parameters ---------- - model_name : str - Name of the model. + model_configuration : Union[UNetModel, VAEModel] + Model configuration Returns ------- @@ -37,215 +36,16 @@ def model_registry(model_name: str) -> torch.nn.Module: Raises ------ NotImplementedError - If the requested model is not implemented. - """ - if model_name == Models.UNET: - return UNet - else: - raise NotImplementedError(f"Model {model_name} is not implemented") - - -def create_model( - *, - model_path: Optional[Union[str, Path]] = None, - config: Optional[Configuration] = None, - device: Optional[torch.device] = None, -) -> Tuple[ - torch.nn.Module, - torch.optim.Optimizer, - Union[ - torch.optim.lr_scheduler.LRScheduler, - torch.optim.lr_scheduler.ReduceLROnPlateau, # not a subclass of LRScheduler - ], - torch.cuda.amp.GradScaler, - Configuration, -]: - """ - Instantiate a model from a model path or configuration. - - If both path and configuration are provided, the model path is used. 
The model - path should point to either a checkpoint (created during training) or a model - exported to the bioimage.io format. - - Parameters - ---------- - model_path : Optional[Union[str, Path]], optional - Path to a checkpoint or bioimage.io archive, by default None. - config : Optional[Configuration], optional - Configuration, by default None. - device : Optional[torch.device], optional - Torch device, by default None. - - Returns - ------- - torch.nn.Module - Instantiated model. - - Raises - ------ - ValueError - If the checkpoint path is invalid. - ValueError - If the checkpoint is invalid. - ValueError - If neither checkpoint nor configuration are provided. + If the requested architecture is not implemented. """ - if model_path is not None: - # Create model from checkpoint - model_path = Path(model_path) - if not model_path.exists() or model_path.suffix not in [".pth", ".zip"]: - raise ValueError( - f"Invalid model path: {model_path}. Current working dir: \ - {Path.cwd()!s}" - ) - - if model_path.suffix == ".zip": - model_path = import_bioimage_model(model_path) - - # Load checkpoint - checkpoint = torch.load(model_path, map_location=device) - - # Load the configuration - if "config" in checkpoint: - config = Configuration(**checkpoint["config"]) - algo_config = config.algorithm - model_config = algo_config.model_parameters - model_name = algo_config.model - else: - raise ValueError("Invalid checkpoint format, no configuration found.") - - # Create model - model: torch.nn.Module = model_registry(model_name)( - depth=model_config.depth, - conv_dim=algo_config.get_conv_dim(), - num_channels_init=model_config.num_channels_init, - ) - model.to(device) - - # Load the model state dict - if "model_state_dict" in checkpoint: - model.load_state_dict(checkpoint["model_state_dict"]) - logger.info("Loaded model state dict") - else: - raise ValueError("Invalid checkpoint format") - - # Load the optimizer and scheduler - optimizer, scheduler = 
get_optimizer_and_scheduler( - config, model, state_dict=checkpoint - ) - scaler = get_grad_scaler(config, state_dict=checkpoint) - - elif config is not None: - # Create model from configuration - algo_config = config.algorithm - model_config = algo_config.model_parameters - model_name = algo_config.model - - # Create model - model = model_registry(model_name)( - depth=model_config.depth, - conv_dim=algo_config.get_conv_dim(), - num_channels_init=model_config.num_channels_init, - ) - model.to(device) - optimizer, scheduler = get_optimizer_and_scheduler(config, model) - scaler = get_grad_scaler(config) - logger.info("Engine initialized from configuration") + if model_configuration.architecture == SupportedArchitecture.UNET: + return UNet(**model_configuration.model_dump()) + elif model_configuration.architecture == SupportedArchitecture.CUSTOM: + assert isinstance(model_configuration, CustomModel) + model = get_custom_model(model_configuration.name) + return model(**model_configuration.model_dump()) else: - raise ValueError("Either config or model_path must be provided") - - return model, optimizer, scheduler, scaler, config - - -def get_optimizer_and_scheduler( - config: Configuration, model: torch.nn.Module, state_dict: Optional[Dict] = None -) -> Tuple[ - torch.optim.Optimizer, - Union[ - torch.optim.lr_scheduler.LRScheduler, - torch.optim.lr_scheduler.ReduceLROnPlateau, # not a subclass of LRScheduler - ], -]: - """ - Create optimizer and learning rate schedulers. - - If a checkpoint state dictionary is provided, the optimizer and scheduler are - instantiated to the same state as the checkpoint's optimizer and scheduler. - - Parameters - ---------- - config : Configuration - Configuration. - model : torch.nn.Module - Model. - state_dict : Optional[Dict], optional - Checkpoint state dictionary, by default None. - - Returns - ------- - Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LRScheduler] - Optimizer and scheduler. 
- """ - # retrieve optimizer name and parameters from config - optimizer_name = config.training.optimizer.name - optimizer_params = config.training.optimizer.parameters - - # then instantiate it - optimizer_func = getattr(torch.optim, optimizer_name) - optimizer = optimizer_func(model.parameters(), **optimizer_params) - - # same for learning rate scheduler - scheduler_name = config.training.lr_scheduler.name - scheduler_params = config.training.lr_scheduler.parameters - scheduler_func = getattr(torch.optim.lr_scheduler, scheduler_name) - scheduler = scheduler_func(optimizer, **scheduler_params) - - # load state from ther checkpoint if available - if state_dict is not None: - if "optimizer_state_dict" in state_dict: - optimizer.load_state_dict(state_dict["optimizer_state_dict"]) - logger.info("Loaded optimizer state dict") - else: - logger.warning( - "No optimizer state dict found in checkpoint. Optimizer not loaded." - ) - if "scheduler_state_dict" in state_dict: - scheduler.load_state_dict(state_dict["scheduler_state_dict"]) - logger.info("Loaded LR scheduler state dict") - else: - logger.warning( - "No LR scheduler state dict found in checkpoint. " - "LR scheduler not loaded." - ) - return optimizer, scheduler - - -def get_grad_scaler( - config: Configuration, state_dict: Optional[Dict] = None -) -> torch.cuda.amp.GradScaler: - """ - Instantiate gradscaler. - - If a checkpoint state dictionary is provided, the scaler is instantiated to the - same state as the checkpoint's scaler. - - Parameters - ---------- - config : Configuration - Configuration. - state_dict : Optional[Dict], optional - Checkpoint state dictionary, by default None. - - Returns - ------- - torch.cuda.amp.GradScaler - Instantiated gradscaler. 
- """ - use = config.training.amp.use - scaling = config.training.amp.init_scale - scaler = torch.cuda.amp.GradScaler(init_scale=scaling, enabled=use) - if state_dict is not None and "scaler_state_dict" in state_dict: - scaler.load_state_dict(state_dict["scaler_state_dict"]) - logger.info("Loaded GradScaler state dict") - return scaler + raise NotImplementedError( + f"Model {model_configuration.architecture} is not implemented or unknown." + ) diff --git a/src/careamics/models/unet.py b/src/careamics/models/unet.py index 494c343b..79edd699 100644 --- a/src/careamics/models/unet.py +++ b/src/careamics/models/unet.py @@ -3,12 +3,14 @@ A UNet encoder, decoder and complete model. """ -from typing import Callable, List, Optional +from typing import Any, List, Union import torch import torch.nn as nn -from .layers import Conv_Block +from ..config.support import SupportedActivation +from .activation import get_activation +from .layers import Conv_Block, MaxBlurPool class UnetEncoder(nn.Module): @@ -42,6 +44,7 @@ def __init__( use_batch_norm: bool = True, dropout: float = 0.0, pool_kernel: int = 2, + n2v2: bool = False, ) -> None: """ Constructor. @@ -65,7 +68,11 @@ def __init__( """ super().__init__() - self.pooling = getattr(nn, f"MaxPool{conv_dim}d")(kernel_size=pool_kernel) + self.pooling = ( + getattr(nn, f"MaxPool{conv_dim}d")(kernel_size=pool_kernel) + if not n2v2 + else MaxBlurPool(dim=conv_dim, kernel_size=3, max_pool_size=pool_kernel) + ) encoder_blocks = [] @@ -82,7 +89,6 @@ def __init__( ) ) encoder_blocks.append(self.pooling) - self.encoder_blocks = nn.ModuleList(encoder_blocks) def forward(self, x: torch.Tensor) -> List[torch.Tensor]: @@ -134,6 +140,7 @@ def __init__( num_channels_init: int = 64, use_batch_norm: bool = True, dropout: float = 0.0, + n2v2: bool = False, ) -> None: """ Constructor. 
@@ -157,6 +164,9 @@ def __init__( scale_factor=2, mode="bilinear" if conv_dim == 2 else "trilinear" ) in_channels = out_channels = num_channels_init * 2 ** (depth - 1) + + self.n2v2 = n2v2 + self.bottleneck = Conv_Block( conv_dim, in_channels=in_channels, @@ -169,12 +179,18 @@ def __init__( decoder_blocks = [] for n in range(depth): decoder_blocks.append(upsampling) - in_channels = num_channels_init * 2 ** (depth - n) - out_channels = num_channels_init + in_channels = ( + num_channels_init ** (depth - n) + if (self.n2v2 and n == depth - 1) + else num_channels_init * 2 ** (depth - n) + ) + out_channels = in_channels // 2 decoder_blocks.append( Conv_Block( conv_dim, - in_channels=in_channels, + in_channels=in_channels + in_channels // 2 + if n > 0 + else in_channels, out_channels=out_channels, intermediate_channel_multiplier=2, dropout_perc=dropout, @@ -200,13 +216,19 @@ def forward(self, *features: List[torch.Tensor]) -> torch.Tensor: torch.Tensor Output of the decoder. """ - x = features[0] - skip_connections = features[1:][::-1] + x: torch.Tensor = features[0] + skip_connections: torch.Tensor = features[1:][::-1] + x = self.bottleneck(x) + for i, module in enumerate(self.decoder_blocks): x = module(x) if isinstance(module, nn.Upsample): - x = torch.cat([x, skip_connections[i // 2]], axis=1) + if self.n2v2: + if x.shape != skip_connections[-1].shape: + x = torch.cat([x, skip_connections[i // 2]], axis=1) + else: + x = torch.cat([x, skip_connections[i // 2]], axis=1) return x @@ -214,12 +236,12 @@ class UNet(nn.Module): """ UNet model. - Adapted for PyTorch from + Adapted for PyTorch from: https://github.com/juglab/n2v/blob/main/n2v/nets/unet_blocks.py. Parameters ---------- - conv_dim : int + conv_dims : int Number of dimensions of the convolution layers (2 or 3). num_classes : int, optional Number of classes to predict, by default 1. 
@@ -241,7 +263,7 @@ class UNet(nn.Module): def __init__( self, - conv_dim: int, + conv_dims: int, num_classes: int = 1, in_channels: int = 1, depth: int = 3, @@ -249,14 +271,16 @@ def __init__( use_batch_norm: bool = True, dropout: float = 0.0, pool_kernel: int = 2, - last_activation: Optional[Callable] = None, + final_activation: Union[SupportedActivation, str] = SupportedActivation.NONE, + n2v2: bool = False, + **kwargs: Any, ) -> None: """ Constructor. Parameters ---------- - conv_dim : int + conv_dims : int Number of dimensions of the convolution layers (2 or 3). num_classes : int, optional Number of classes to predict, by default 1. @@ -278,28 +302,30 @@ def __init__( super().__init__() self.encoder = UnetEncoder( - conv_dim, + conv_dims, in_channels=in_channels, depth=depth, num_channels_init=num_channels_init, use_batch_norm=use_batch_norm, dropout=dropout, pool_kernel=pool_kernel, + n2v2=n2v2, ) self.decoder = UnetDecoder( - conv_dim, + conv_dims, depth=depth, num_channels_init=num_channels_init, use_batch_norm=use_batch_norm, dropout=dropout, + n2v2=n2v2, ) - self.final_conv = getattr(nn, f"Conv{conv_dim}d")( + self.final_conv = getattr(nn, f"Conv{conv_dims}d")( in_channels=num_channels_init, out_channels=num_classes, kernel_size=1, ) - self.last_activation = last_activation if last_activation else nn.Identity() + self.final_activation = get_activation(final_activation) def forward(self, x: torch.Tensor) -> torch.Tensor: """ @@ -318,5 +344,5 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: encoder_features = self.encoder(x) x = self.decoder(*encoder_features) x = self.final_conv(x) - x = self.last_activation(x) + x = self.final_activation(x) return x diff --git a/src/careamics/prediction/__init__.py b/src/careamics/prediction/__init__.py index a11cfceb..852e65de 100644 --- a/src/careamics/prediction/__init__.py +++ b/src/careamics/prediction/__init__.py @@ -2,8 +2,6 @@ __all__ = [ "stitch_prediction", - "tta_backward", - "tta_forward", ] -from 
.prediction_utils import stitch_prediction, tta_backward, tta_forward +from .stitch_prediction import stitch_prediction diff --git a/src/careamics/prediction/prediction_utils.py b/src/careamics/prediction/prediction_utils.py deleted file mode 100644 index f0139de5..00000000 --- a/src/careamics/prediction/prediction_utils.py +++ /dev/null @@ -1,106 +0,0 @@ -""" -Prediction convenience functions. - -These functions are used during prediction. -""" -from typing import List - -import numpy as np -import torch - - -def stitch_prediction( - tiles: List[np.ndarray], - stitching_data: List, -) -> np.ndarray: - """ - Stitch tiles back together to form a full image. - - Parameters - ---------- - tiles : List[Tuple[np.ndarray, List[int]]] - Cropped tiles and their respective stitching coordinates. - stitching_data : List - List of coordinates obtained from - dataset.tiling.compute_crop_and_stitch_coords_1d. - - Returns - ------- - np.ndarray - Full image. - """ - # Get whole sample shape - input_shape = stitching_data[0][0] - predicted_image = np.zeros(input_shape, dtype=np.float32) - for tile, (_, overlap_crop_coords, stitch_coords) in zip(tiles, stitching_data): - # Compute coordinates for cropping predicted tile - slices = tuple([slice(c[0], c[1]) for c in overlap_crop_coords]) - - # Crop predited tile according to overlap coordinates - cropped_tile = tile.squeeze()[slices] - - # Insert cropped tile into predicted image using stitch coordinates - predicted_image[ - (..., *[slice(c[0], c[1]) for c in stitch_coords]) - ] = cropped_tile - return predicted_image - - -def tta_forward(x: torch.Tensor) -> List[torch.Tensor]: - """ - Augment 8-fold an array. - - The augmentation is performed using all 90 deg rotations and their flipped version, - as well as the original image flipped. - - Tensors should be of shape SC(Z)YX, with S and C potentially singleton dimensions. - - Parameters - ---------- - x : torch.Tensor - Data to augment. 
- - Returns - ------- - List - Stack of augmented images. - """ - x_aug = [ - x, - torch.rot90(x, 1, dims=(2, 3)), - torch.rot90(x, 2, dims=(2, 3)), - torch.rot90(x, 3, dims=(2, 3)), - ] - x_aug_flip = x_aug.copy() - for x_ in x_aug: - x_aug_flip.append(torch.flip(x_, dims=(1, 3))) - return x_aug_flip - - -def tta_backward(x_aug: List[torch.Tensor]) -> np.ndarray: - """ - Invert `tta_forward` and average the 8 images. - - The function takes a list of torch tensors and returns a numpy array. - - Parameters - ---------- - x_aug : List[torch.Tensor] - Stack of 8-fold augmented images. - - Returns - ------- - np.ndarray - Average of de-augmented x_aug. - """ - x_deaug = [ - x_aug[0].numpy(), - np.rot90(x_aug[1], -1, axes=(2, 3)), - np.rot90(x_aug[2], -2, axes=(2, 3)), - np.rot90(x_aug[3], -3, axes=(2, 3)), - np.flip(x_aug[4].numpy(), axis=(1, 3)), - np.rot90(np.flip(x_aug[5].numpy(), axis=(1, 3)), -1, axes=(2, 3)), - np.rot90(np.flip(x_aug[6].numpy(), axis=(1, 3)), -2, axes=(2, 3)), - np.rot90(np.flip(x_aug[7].numpy(), axis=(1, 3)), -3, axes=(2, 3)), - ] - return np.mean(x_deaug, 0) diff --git a/src/careamics/prediction/stitch_prediction.py b/src/careamics/prediction/stitch_prediction.py new file mode 100644 index 00000000..5e0ee7e1 --- /dev/null +++ b/src/careamics/prediction/stitch_prediction.py @@ -0,0 +1,73 @@ +""" +Prediction convenience functions. + +These functions are used during prediction. +""" +from typing import List + +import numpy as np +import torch + + +def stitch_prediction( + tiles: List[torch.Tensor], + stitching_data: List[List[torch.Tensor]], +) -> torch.Tensor: + """ + Stitch tiles back together to form a full image. + + Parameters + ---------- + tiles : List[torch.Tensor] + Cropped tiles and their respective stitching coordinates. + stitching_coords : List + List of information and coordinates obtained from + `dataset.tiled_patching.extract_tiles`. + + Returns + ------- + np.ndarray + Full image. 
+ """ + # retrieve whole array size, there is two cases to consider: + # 1. the tiles are stored in a list + # 2. the tiles are stored in a list with batches along the first dim + if tiles[0].shape[0] > 1: + input_shape = np.array( + [el.numpy() for el in stitching_data[0][0][0]], dtype=int + ).squeeze() + else: + input_shape = np.array( + [el.numpy() for el in stitching_data[0][0]], dtype=int + ).squeeze() + + # TODO should use torch.zeros instead of np.zeros + predicted_image = torch.Tensor(np.zeros(input_shape, dtype=np.float32)) + + for tile_batch, (_, overlap_crop_coords_batch, stitch_coords_batch) in zip( + tiles, stitching_data + ): + for batch_idx in range(tile_batch.shape[0]): + # Compute coordinates for cropping predicted tile + slices = tuple( + [ + slice(c[0][batch_idx], c[1][batch_idx]) + for c in overlap_crop_coords_batch + ] + ) + + # Crop predited tile according to overlap coordinates + cropped_tile = tile_batch[batch_idx].squeeze()[slices] + + # Insert cropped tile into predicted image using stitch coordinates + predicted_image[ + ( + ..., + *[ + slice(c[0][batch_idx], c[1][batch_idx]) + for c in stitch_coords_batch + ], + ) + ] = cropped_tile.to(torch.float32) + + return predicted_image diff --git a/src/careamics/transforms/__init__.py b/src/careamics/transforms/__init__.py new file mode 100644 index 00000000..dd4c2604 --- /dev/null +++ b/src/careamics/transforms/__init__.py @@ -0,0 +1,41 @@ +"""Transforms that are used to augment the data.""" + +__all__ = [ + "get_all_transforms", + "N2VManipulate", + "NDFlip", + "XYRandomRotate90", + "ImageRestorationTTA", + "Denormalize", + "Normalize", +] + + +from .n2v_manipulate import N2VManipulate +from .nd_flip import NDFlip +from .normalize import Denormalize, Normalize +from .tta import ImageRestorationTTA +from .xy_random_rotate90 import XYRandomRotate90 + +ALL_TRANSFORMS = { + "Normalize": Normalize, + "N2VManipulate": N2VManipulate, + "NDFlip": NDFlip, + "XYRandomRotate90": XYRandomRotate90, +} + + 
+def get_all_transforms() -> dict: + """Return all the transforms accepted by CAREamics. + + Note that while CAREamics accepts any `Compose` transforms from Albumentations (see + https://albumentations.ai/), only a few transformations are explicitely supported + (see `SupportedTransform`). + + Returns + ------- + dict + A dictionary with all the transforms accepted by CAREamics, where the keys are + the transform names and the values are the transform classes. + """ + return ALL_TRANSFORMS diff --git a/src/careamics/transforms/n2v_manipulate.py b/src/careamics/transforms/n2v_manipulate.py new file mode 100644 index 00000000..afc34b3b --- /dev/null +++ b/src/careamics/transforms/n2v_manipulate.py @@ -0,0 +1,113 @@ +from typing import Any, Literal, Optional, Tuple + +import numpy as np +from albumentations import ImageOnlyTransform + +from careamics.config.support import SupportedPixelManipulation, SupportedStructAxis + +from .pixel_manipulation import median_manipulate, uniform_manipulate +from .struct_mask_parameters import StructMaskParameters + + +class N2VManipulate(ImageOnlyTransform): + """ + Default augmentation for the N2V model. + + This transform expects (Z)YXC dimensions. + + Parameters + ---------- + mask_pixel_percentage : float + Approximate percentage of pixels to be masked. + roi_size : int + Size of the ROI the new pixel value is sampled from, by default 11. + """ + + def __init__( + self, + roi_size: int = 11, + masked_pixel_percentage: float = 0.2, + strategy: Literal[ + "uniform", "median" + ] = SupportedPixelManipulation.UNIFORM.value, + remove_center: bool = True, + struct_mask_axis: Literal["horizontal", "vertical", "none"] = "none", + struct_mask_span: int = 5, + ): + """Constructor. 
+ + Parameters + ---------- + roi_size : int, optional + Size of the replacement area, by default 11 + masked_pixel_percentage : float, optional + Percentage of pixels to mask, by default 0.2 + strategy : Literal[ "uniform", "median" ], optional + Replaccement strategy, uniform or median, by default uniform + remove_center : bool, optional + Whether to remove central pixel from patch, by default True + struct_mask_axis : Literal["horizontal", "vertical", "none"], optional + StructN2V mask axis, by default "none" + struct_mask_span : int, optional + StructN2V mask span, by default 5 + """ + super().__init__(p=1) + self.masked_pixel_percentage = masked_pixel_percentage + self.roi_size = roi_size + self.strategy = strategy + self.remove_center = remove_center + + if struct_mask_axis == SupportedStructAxis.NONE: + self.struct_mask: Optional[StructMaskParameters] = None + else: + self.struct_mask = StructMaskParameters( + axis=0 if struct_mask_axis == SupportedStructAxis.HORIZONTAL else 1, + span=struct_mask_span, + ) + + def apply( + self, patch: np.ndarray, **kwargs: Any + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """Apply the transform to the image. + + Parameters + ---------- + image : np.ndarray + Image or image patch, 2D or 3D, shape (y, x, c) or (z, y, x, c). 
+ """ + masked = np.zeros_like(patch) + mask = np.zeros_like(patch) + if self.strategy == SupportedPixelManipulation.UNIFORM: + # Iterate over the channels to apply manipulation separately + for c in range(patch.shape[-1]): + masked[..., c], mask[..., c] = uniform_manipulate( + patch=patch[..., c], + mask_pixel_percentage=self.masked_pixel_percentage, + subpatch_size=self.roi_size, + remove_center=self.remove_center, + struct_params=self.struct_mask, + ) + elif self.strategy == SupportedPixelManipulation.MEDIAN: + # Iterate over the channels to apply manipulation separately + for c in range(patch.shape[-1]): + masked[..., c], mask[..., c] = median_manipulate( + patch=patch[..., c], + mask_pixel_percentage=self.masked_pixel_percentage, + subpatch_size=self.roi_size, + struct_params=self.struct_mask, + ) + else: + raise ValueError(f"Unknown masking strategy ({self.strategy}).") + + # TODO why return patch? + return masked, patch, mask + + def get_transform_init_args_names(self) -> Tuple[str, ...]: + """Get the transform parameters. + + Returns + ------- + Tuple[str, ...] + Transform parameters. + """ + return ("roi_size", "masked_pixel_percentage", "strategy", "struct_mask") diff --git a/src/careamics/transforms/nd_flip.py b/src/careamics/transforms/nd_flip.py new file mode 100644 index 00000000..968b4d19 --- /dev/null +++ b/src/careamics/transforms/nd_flip.py @@ -0,0 +1,93 @@ +from typing import Any, Dict, Tuple + +import numpy as np +from albumentations import DualTransform + + +class NDFlip(DualTransform): + """Flip ND arrays on a single axis. + + This transform ignores singleton axes and randomly flips one of the other + axes, to the exception of the first and last axes (sample and channels). + + This transform expects (Z)YXC dimensions. + """ + + def __init__(self, p: float = 0.5, is_3D: bool = False, flip_z: bool = True): + """Constructor. 
+ + Parameters + ---------- + p : float, optional + Probability to apply the transform, by default 0.5 + is_3D : bool, optional + Whether the data is 3D, by default False + flip_z : bool, optional + Whether to flip Z dimension, by default True + """ + super().__init__(p=p) + + self.is_3D = is_3D + self.flip_z = flip_z + + # "flippable" axes + if is_3D: + self.axis_indices = [0, 1, 2] if flip_z else [1, 2] + else: + self.axis_indices = [0, 1] + + def get_params(self, **kwargs: Any) -> Dict[str, int]: + """Get the transform parameters. + + Returns + ------- + Dict[str, int] + Transform parameters. + """ + return {"flip_axis": np.random.choice(self.axis_indices)} + + def apply(self, patch: np.ndarray, flip_axis: int, **kwargs: Any) -> np.ndarray: + """Apply the transform to the image. + + Parameters + ---------- + patch : np.ndarray + Image or image patch, 2D or 3D, shape (y, x, c) or (z, y, x, c). + flip_axis : int + Axis along which to flip the patch. + """ + if len(patch.shape) == 3 and self.is_3D: + raise ValueError( + "Incompatible patch shape and dimensionality. ZYXC patch shape " + "expected, but got YXC shape." + ) + + return np.ascontiguousarray(np.flip(patch, axis=flip_axis)) + + def apply_to_mask( + self, mask: np.ndarray, flip_axis: int, **kwargs: Any + ) -> np.ndarray: + """Apply the transform to the mask. + + Parameters + ---------- + mask : np.ndarray + Mask or mask patch, 2D or 3D, shape (y, x, c) or (z, y, x, c). + """ + if len(mask.shape) == 3 and self.is_3D: + raise ValueError( + "Incompatible mask shape and dimensionality. ZYXC patch shape " + "expected, but got YXC shape." + ) + + return np.ascontiguousarray(np.flip(mask, axis=flip_axis)) + + def get_transform_init_args_names(self, **kwargs: Any) -> Tuple[str, ...]: + """Get the transform arguments names. + + Returns + ------- + Tuple[str, ...] + Transform arguments names. 
+ """ + return ("is_3D", "flip_z") diff --git a/src/careamics/transforms/normalize.py b/src/careamics/transforms/normalize.py new file mode 100644 index 00000000..4ec529b0 --- /dev/null +++ b/src/careamics/transforms/normalize.py @@ -0,0 +1,109 @@ +from typing import Any + +import numpy as np +from albumentations import DualTransform + + +class Normalize(DualTransform): + """ + Normalize an image or image patch. + + Normalization is a zero mean and unit variance. This transform expects (Z)YXC + dimensions. + + Not that an epsilon value of 1e-6 is added to the standard deviation to avoid + division by zero and that it returns a float32 image. + + Attributes + ---------- + mean : float + Mean value. + std : float + Standard deviation value. + eps : float + Epsilon value to avoid division by zero. + """ + + def __init__( + self, + mean: float, + std: float, + ): + super().__init__(always_apply=True, p=1) + + self.mean = mean + self.std = std + self.eps = 1e-6 + + def apply(self, patch: np.ndarray, **kwargs: Any) -> np.ndarray: + """ + Apply the transform to the image. + + Parameters + ---------- + patch : np.ndarray + Image or image patch, 2D or 3D, shape (y, x, c) or (z, y, x, c). + + Returns + ------- + np.ndarray + Normalized image or image patch. + """ + return ((patch - self.mean) / (self.std + self.eps)).astype(np.float32) + + def apply_to_mask(self, mask: np.ndarray, **kwargs: Any) -> np.ndarray: + """ + Apply the transform to the mask. + + The mask is returned as is. + + Parameters + ---------- + mask : np.ndarray + Mask or mask patch, 2D or 3D, shape (y, x, c) or (z, y, x, c). + """ + return mask + + +class Denormalize(DualTransform): + """ + Denormalize an image or image patch. + + Denormalization is performed expecting a zero mean and unit variance input. This + transform expects (Z)YXC dimensions. 
+ + Not that an epsilon value of 1e-6 is added to the standard deviation to avoid + division by zero during the normalization step, which is taken into account during + denormalization. + + Attributes + ---------- + mean : float + Mean value. + std : float + Standard deviation value. + eps : float + Epsilon value to avoid division by zero. + """ + + def __init__( + self, + mean: float, + std: float, + ): + super().__init__(always_apply=True, p=1) + + self.mean = mean + self.std = std + self.eps = 1e-6 + + def apply(self, patch: np.ndarray, **kwargs: Any) -> np.ndarray: + """ + Apply the transform to the image. + + Parameters + ---------- + patch : np.ndarray + Image or image patch, 2D or 3D, shape (y, x, c) or (z, y, x, c). + """ + return patch * (self.std + self.eps) + self.mean diff --git a/src/careamics/transforms/pixel_manipulation.py b/src/careamics/transforms/pixel_manipulation.py new file mode 100644 index 00000000..03e40e89 --- /dev/null +++ b/src/careamics/transforms/pixel_manipulation.py @@ -0,0 +1,383 @@ +""" +Pixel manipulation methods. + +Pixel manipulation is used in N2V and similar algorithm to replace the value of +masked pixels. +""" +from typing import Optional, Tuple, Union + +import numpy as np + +from .struct_mask_parameters import StructMaskParameters + + +def _apply_struct_mask( + patch: np.ndarray, coords: np.ndarray, struct_params: StructMaskParameters +) -> np.ndarray: + """Applies structN2V masks to patch. + + Each point in `coords` corresponds to the center of a mask, masks are paremeterized + by `struct_params` and pixels in the mask (with respect to `coords`) are replaced by + a random value. + + Note that the structN2V mask is applied in 2D at the coordinates given by `coords`. + + Parameters + ---------- + patch : np.ndarray + Patch to be manipulated, 2D or 3D. + coords : np.ndarray + Coordinates of the ROI(subpatch) centers. + struct_params : StructMaskParameters + Parameters for the structN2V mask (axis and span). 
+ + Returns + ------- + np.ndarray + Patch with the structN2V mask applied. + """ + # relative axis + moving_axis = -1 - struct_params.axis + + # Create a mask array + mask = np.expand_dims( + np.ones(struct_params.span), axis=list(range(len(patch.shape) - 1)) + ) # (1, 1, span) or (1, span) + + # Move the moving axis to the correct position + # i.e. the axis along which the coordinates should change + mask = np.moveaxis(mask, -1, moving_axis) + center = np.array(mask.shape) // 2 + + # Mark the center + mask[tuple(center.T)] = 0 + + # displacements from center + dx = np.indices(mask.shape)[:, mask == 1] - center[:, None] + + # combine all coords (ndim, npts,) with all displacements (ncoords,ndim,) + mix = dx.T[..., None] + coords.T[None] + mix = mix.transpose([1, 0, 2]).reshape([mask.ndim, -1]).T + + # delete entries that are out of bounds + mix = np.delete(mix, mix[:, moving_axis] < 0, axis=0) + + max_bound = patch.shape[moving_axis] - 1 + mix = np.delete(mix, mix[:, moving_axis] > max_bound, axis=0) + + # replace neighbouring pixels with random values from flat dist + patch[tuple(mix.T)] = np.random.uniform(patch.min(), patch.max(), size=mix.shape[0]) + + return patch + + +def _odd_jitter_func(step: float, rng: np.random.Generator) -> np.ndarray: + """ + Randomly sample a jitter to be applied to the masking grid. + + This is done to account for cases where the step size is not an integer. + + Parameters + ---------- + step : float + Step size of the grid, output of np.linspace. + rng : np.random.Generator + Random number generator. + + Returns + ------- + np.ndarray + Array of random jitter to be added to the grid. 
+ """ + # Define the random jitter to be added to the grid + odd_jitter = np.where(np.floor(step) == step, 0, rng.integers(0, 2)) + + # Round the step size to the nearest integer depending on the jitter + return np.floor(step) if odd_jitter == 0 else np.ceil(step) + + +def _get_stratified_coords( + mask_pixel_perc: float, shape: Union[Tuple[int, int], Tuple[int, int, int]] +) -> np.ndarray: + """ + Generate coordinates of the pixels to mask. + + Randomly selects the coordinates of the pixels to mask in a stratified way, i.e. + the distance between masked pixels is approximately the same. + + Parameters + ---------- + mask_pixel_perc : float + Actual (quasi) percentage of masked pixels across the whole image. Used in + calculating the distance between masked pixels across each axis. + shape : Tuple[int, ...] + Shape of the input patch. + + Returns + ------- + np.ndarray + Array of coordinates of the masked pixels. + """ + if len(shape) < 2 or len(shape) > 3: + raise ValueError( + "Calculating coordinates is only possible for 2D and 3D patches" + ) + + rng = np.random.default_rng() + + mask_pixel_distance = np.round((100 / mask_pixel_perc) ** (1 / len(shape))).astype( + np.int32 + ) + + # Define a grid of coordinates for each axis in the input patch and the step size + pixel_coords = [] + steps = [] + for axis_size in shape: + # make sure axis size is evenly divisible by box size + num_pixels = int(np.ceil(axis_size / mask_pixel_distance)) + axis_pixel_coords, step = np.linspace( + 0, axis_size, num_pixels, dtype=np.int32, endpoint=False, retstep=True + ) + # explain + pixel_coords.append(axis_pixel_coords.T) + steps.append(step) + + # Create a meshgrid of coordinates for each axis in the input patch + coordinate_grid_list = np.meshgrid(*pixel_coords) + coordinate_grid = np.array(coordinate_grid_list).reshape(len(shape), -1).T + + grid_random_increment = rng.integers( + _odd_jitter_func(float(max(steps)), rng) + * np.ones_like(coordinate_grid).astype(np.int32) + - 1, 
+ size=coordinate_grid.shape, + endpoint=True, + ) + coordinate_grid += grid_random_increment + coordinate_grid = np.clip(coordinate_grid, 0, np.array(shape) - 1) + return coordinate_grid + + +def _create_subpatch_center_mask( + subpatch: np.ndarray, center_coords: np.ndarray +) -> np.ndarray: + """Create a mask with the center of the subpatch masked. + + Parameters + ---------- + subpatch : np.ndarray + Subpatch to be manipulated. + center_coords : np.ndarray + Coordinates of the original center before possible crop. + + Returns + ------- + np.ndarray + Mask with the center of the subpatch masked. + """ + mask = np.ones(subpatch.shape) + mask[tuple(center_coords)] = 0 + return np.ma.make_mask(mask) # type: ignore + + +def _create_subpatch_struct_mask( + subpatch: np.ndarray, center_coords: np.ndarray, struct_params: StructMaskParameters +) -> np.ndarray: + """Create a structN2V mask for the subpatch. + + Parameters + ---------- + subpatch : np.ndarray + Subpatch to be manipulated. + center_coords : np.ndarray + Coordinates of the original center before possible crop. + struct_params : StructMaskParameters + Parameters for the structN2V mask (axis and span). + + Returns + ------- + np.ndarray + StructN2V mask for the subpatch. 
+ """ + # Create a mask with the center of the subpatch masked + mask_placeholder = np.ones(subpatch.shape) + + # reshape to move the struct axis to the first position + mask_reshaped = np.moveaxis(mask_placeholder, struct_params.axis, 0) + + # create the mask index for the struct axis + mask_index = slice( + max(0, center_coords.take(struct_params.axis) - (struct_params.span - 1) // 2), + min( + 1 + center_coords.take(struct_params.axis) + (struct_params.span - 1) // 2, + subpatch.shape[struct_params.axis], + ), + ) + mask_reshaped[struct_params.axis][mask_index] = 0 + + # reshape back to the original shape + mask = np.moveaxis(mask_reshaped, 0, struct_params.axis) + + return np.ma.make_mask(mask) # type: ignore + + +def uniform_manipulate( + patch: np.ndarray, + mask_pixel_percentage: float, + subpatch_size: int = 11, + remove_center: bool = True, + struct_params: Optional[StructMaskParameters] = None, +) -> Tuple[np.ndarray, np.ndarray]: + """ + Manipulate pixels by replacing them with a neighbor values. + + Manipulated pixels are selected unformly selected in a subpatch, away from a grid + with an approximate uniform probability to be selected across the whole patch. + If `struct_params` is not None, an additional structN2V mask is applied to the + data, replacing the pixels in the mask with random values (excluding the pixel + already manipulated). + + Parameters + ---------- + patch : np.ndarray + Image patch, 2D or 3D, shape (y, x) or (z, y, x). + mask_pixel_percentage : float + Approximate percentage of pixels to be masked. + subpatch_size : int + Size of the subpatch the new pixel value is sampled from, by default 11. + remove_center : bool + Whether to remove the center pixel from the subpatch, by default False. See + uniform with/without central pixel in the documentation. #TODO add link + struct_params: Optional[StructMaskParameters] + Parameters for the structN2V mask (axis and span). 
+ + Returns + ------- + Tuple[np.ndarray] + Tuple containing the manipulated patch and the corresponding mask. + """ + # Get the coordinates of the pixels to be replaced + transformed_patch = patch.copy() + + subpatch_centers = _get_stratified_coords(mask_pixel_percentage, patch.shape) + rng = np.random.default_rng() + + # Generate coordinate grid for subpatch + roi_span_full = np.arange( + -np.floor(subpatch_size / 2), np.ceil(subpatch_size / 2) + ).astype(np.int32) + + # Remove the center pixel from the grid if needed + roi_span = roi_span_full[roi_span_full != 0] if remove_center else roi_span_full + + # Randomly select coordinates from the grid + random_increment = rng.choice(roi_span, size=subpatch_centers.shape) + + # Clip the coordinates to the patch size + replacement_coords = np.clip( + subpatch_centers + random_increment, + 0, + [patch.shape[i] - 1 for i in range(len(patch.shape))], + ) + + # Get the replacement pixels from all subpatchs + replacement_pixels = patch[tuple(replacement_coords.T.tolist())] + + # Replace the original pixels with the replacement pixels + transformed_patch[tuple(subpatch_centers.T.tolist())] = replacement_pixels + mask = np.where(transformed_patch != patch, 1, 0).astype(np.uint8) + + if struct_params is not None: + transformed_patch = _apply_struct_mask( + transformed_patch, subpatch_centers, struct_params + ) + + return ( + transformed_patch, + mask, + ) + + +def median_manipulate( + patch: np.ndarray, + mask_pixel_percentage: float, + subpatch_size: int = 11, + struct_params: Optional[StructMaskParameters] = None, +) -> Tuple[np.ndarray, np.ndarray]: + """ + Manipulate pixels by replacing them with the median of their surrounding subpatch. + + N2V2 version, manipulated pixels are selected randomly away from a grid with an + approximate uniform probability to be selected across the whole patch. 
+ + If `struct_params` is not None, an additional structN2V mask is applied to the data, + replacing the pixels in the mask with random values (excluding the pixel already + manipulated). + + Parameters + ---------- + patch : np.ndarray + Image patch, 2D or 3D, shape (y, x) or (z, y, x). + mask_pixel_percentage : floar + Approximate percentage of pixels to be masked. + subpatch_size : int + Size of the subpatch the new pixel value is sampled from, by default 11. + struct_params: Optional[StructMaskParameters] + Parameters for the structN2V mask (axis and span). + + Returns + ------- + Tuple[np.ndarray] + Tuple containing the manipulated patch, the original patch and the mask. + """ + transformed_patch = patch.copy() + + # Get the coordinates of the pixels to be replaced + subpatch_centers = _get_stratified_coords(mask_pixel_percentage, patch.shape) + + # Generate coordinate grid for subpatch + roi_span = np.array( + [-np.floor(subpatch_size / 2), np.ceil(subpatch_size / 2)] + ).astype(np.int32) + + subpatch_crops_span_full = subpatch_centers[np.newaxis, ...].T + roi_span + + # Dimensions n dims, n centers, (min, max) + subpatch_crops_span_clipped = np.clip( + subpatch_crops_span_full, + a_min=np.zeros_like(patch.shape)[:, np.newaxis, np.newaxis], + a_max=np.array(patch.shape)[:, np.newaxis, np.newaxis], + ) + + for idx in range(subpatch_crops_span_clipped.shape[1]): + subpatch_coords = subpatch_crops_span_clipped[:, idx, ...] 
+ idxs = [ + slice(x[0], x[1]) if x[1] - x[0] > 0 else slice(0, 1) + for x in subpatch_coords + ] + subpatch = patch[tuple(idxs)] + subpatch_center_adjusted = subpatch_centers[idx] - subpatch_coords[:, 0] + + if struct_params is None: + subpatch_mask = _create_subpatch_center_mask( + subpatch, subpatch_center_adjusted + ) + else: + subpatch_mask = _create_subpatch_struct_mask( + subpatch, subpatch_center_adjusted, struct_params + ) + transformed_patch[tuple(subpatch_centers[idx])] = np.median( + subpatch[subpatch_mask] + ) + + mask = np.where(transformed_patch != patch, 1, 0).astype(np.uint8) + + if struct_params is not None: + transformed_patch = _apply_struct_mask( + transformed_patch, subpatch_centers, struct_params + ) + + return ( + transformed_patch, + mask, + ) diff --git a/src/careamics/transforms/struct_mask_parameters.py b/src/careamics/transforms/struct_mask_parameters.py new file mode 100644 index 00000000..48e49ef7 --- /dev/null +++ b/src/careamics/transforms/struct_mask_parameters.py @@ -0,0 +1,18 @@ +from dataclasses import dataclass +from typing import Literal + + +@dataclass +class StructMaskParameters: + """Parameters of structN2V masks. + + Parameters + ---------- + axis : Literal[0, 1] + Axis along which to apply the mask, horizontal (0) or vertical (1). + span : int + Span of the mask. + """ + + axis: Literal[0, 1] + span: int diff --git a/src/careamics/transforms/tta.py b/src/careamics/transforms/tta.py new file mode 100644 index 00000000..39e65950 --- /dev/null +++ b/src/careamics/transforms/tta.py @@ -0,0 +1,74 @@ +"""Test-time augmentations.""" +from typing import List + +import numpy as np +from torch import Tensor, flip, mean, rot90, stack + + +# TODO add tests +class ImageRestorationTTA: + """ + Test-time augmentation for image restoration tasks. + + The augmentation is performed using all 90 deg rotations and their flipped version, + as well as the original image flipped. 
+ + Tensors should be of shape SC(Z)YX. + + This transformation is used in the LightningModule in order to perform test-time + augmentation. + """ + + def __init__(self) -> None: + """Constructor.""" + pass + + def forward(self, x: Tensor) -> List[Tensor]: + """ + Apply test-time augmentation to the input tensor. + + Parameters + ---------- + x : Tensor + Input tensor, shape SC(Z)YX. + + Returns + ------- + List[Tensor] + List of augmented tensors. + """ + augmented = [ + x, + rot90(x, 1, dims=(-2, -1)), + rot90(x, 2, dims=(-2, -1)), + rot90(x, 3, dims=(-2, -1)), + ] + augmented_flip = augmented.copy() + for x_ in augmented: + augmented_flip.append(flip(x_, dims=(-3, -1))) + return augmented_flip + + def backward(self, x: List[Tensor]) -> np.ndarray: + """Undo the test-time augmentation. + + Parameters + ---------- + x : List[Tensor] + List of augmented tensors. + + Returns + ------- + Tensor + Mean of the tensors after undoing the augmentations. + """ + reverse = [ + x[0], + rot90(x[1], -1, dims=(-2, -1)), + rot90(x[2], -2, dims=(-2, -1)), + rot90(x[3], -3, dims=(-2, -1)), + flip(x[4], dims=(-3, -1)), + rot90(flip(x[5], dims=(-3, -1)), -1, dims=(-2, -1)), + rot90(flip(x[6], dims=(-3, -1)), -2, dims=(-2, -1)), + rot90(flip(x[7], dims=(-3, -1)), -3, dims=(-2, -1)), + ] + return mean(stack(reverse), dim=0) diff --git a/src/careamics/transforms/xy_random_rotate90.py b/src/careamics/transforms/xy_random_rotate90.py new file mode 100644 index 00000000..5333e036 --- /dev/null +++ b/src/careamics/transforms/xy_random_rotate90.py @@ -0,0 +1,95 @@ +from typing import Any, Dict, Tuple + +import numpy as np +from albumentations import DualTransform + + +class XYRandomRotate90(DualTransform): + """Applies random 90 degree rotations to the YX axis. + + This transform expects (Z)YXC dimensions.
+ + Parameters + ---------- + p : float, optional + Probability to apply the transform, by default 0.5 + is_3D : bool, optional + Whether the patches are 3D, by default False + """ + + def __init__(self, p: float = 0.5, is_3D: bool = False): + """Constructor. + + Parameters + ---------- + p : float, optional + Probability to apply the transform, by default 0.5 + is_3D : bool, optional + Whether the patches are 3D, by default False + """ + super().__init__(p=p) + + self.is_3D = is_3D + + # rotation axes + if is_3D: + self.axes = (1, 2) + else: + self.axes = (0, 1) + + def get_params(self, **kwargs: Any) -> Dict[str, int]: + """Get the transform parameters. + + Returns + ------- + Dict[str, int] + Transform parameters. + """ + return {"n_rotations": np.random.randint(1, 4)} + + def apply(self, patch: np.ndarray, n_rotations: int, **kwargs: Any) -> np.ndarray: + """Apply the transform to the image. + + Parameters + ---------- + patch : np.ndarray + Image or image patch, 2D or 3D, shape (y, x, c) or (z, y, x, c). + n_rotations : int + Number of 90 degree rotations to apply. + """ + if len(patch.shape) == 3 and self.is_3D: + raise ValueError( + "Incompatible patch shape and dimensionality. ZYXC patch shape " + "expected, but got YXC shape." + ) + + return np.ascontiguousarray(np.rot90(patch, k=n_rotations, axes=self.axes)) + + def apply_to_mask( + self, mask: np.ndarray, n_rotations: int, **kwargs: Any + ) -> np.ndarray: + """Apply the transform to the mask. + + Parameters + ---------- + mask : np.ndarray + Mask or mask patch, 2D or 3D, shape (y, x, c) or (z, y, x, c). + """ + if len(mask.shape) != 4 and self.is_3D: + raise ValueError( + "Incompatible mask shape and dimensionality. ZYXC patch shape " + "expected, but got YXC shape." + ) + + return np.ascontiguousarray(np.rot90(mask, k=n_rotations, axes=self.axes)) + + def get_transform_init_args_names(self) -> Tuple[str, str]: + """ + Get the transform arguments. + + Returns + ------- + Tuple[str, str] + Transform arguments.
+ """ + return ("p", "is_3D") diff --git a/src/careamics/utils/__init__.py b/src/careamics/utils/__init__.py index beeaa0dd..3207fe2b 100644 --- a/src/careamics/utils/__init__.py +++ b/src/careamics/utils/__init__.py @@ -2,19 +2,17 @@ __all__ = [ - "denormalize", - "normalize", - "get_device", - "check_axes_validity", - "add_axes", - "check_tiling_validity", "cwd", - "MetricTracker", + "get_ram_size", + "check_path_exists", + "BaseEnum", + "get_logger", + "get_careamics_home", ] -from .context import cwd -from .metrics import MetricTracker -from .normalization import denormalize, normalize -from .torch_utils import get_device -from .validators import add_axes, check_axes_validity, check_tiling_validity +from .base_enum import BaseEnum +from .context import cwd, get_careamics_home +from .logging import get_logger +from .path_utils import check_path_exists +from .ram import get_ram_size diff --git a/src/careamics/utils/ascii_logo.txt b/src/careamics/utils/ascii_logo.txt deleted file mode 100644 index bd15b3b0..00000000 --- a/src/careamics/utils/ascii_logo.txt +++ /dev/null @@ -1,9 +0,0 @@ - ...... ...... ........ ........ .... - -+++----+- -+++--+++- :+++---+++: :+++----- .--: -.+++ .: +++. .+++. :+++ :+++ :+++ :------. .---:----..:----. :--- :----: :----:. -.+++ .+++. .+++. :+++ -++= :+++ +=....=+++ :+++-..=+++-..=++= -+++ .+++-..++ +++-..=+. -.+++ .++++++++++. :++++++++=. :++++++: .+++. :+++ :+++ -+++ -+++ :+++ .+++=. -.+++ .+++. .+++. :+++ -+++ :+++ :=++==++++. :+++ :+++ -+++ -+++ :+++ .-=+++=: -.+++ .. .+++. .+++. :+++ :+++ :+++ .+++. .+++. :+++ :+++ -+++ -+++ :+++ .. .. :+++. - -++=-::-+= .+++. .+++. :+++ :+++ :+++-:::: =++=--=+++. :+++ :+++ -+++ -+++ =++=:-+= =+-:=++= - ...... ... ... ... ... ........ .... ... ... ... .... .... .... ..... 
\ No newline at end of file diff --git a/src/careamics/utils/augment.py b/src/careamics/utils/augment.py deleted file mode 100644 index 1190ffd5..00000000 --- a/src/careamics/utils/augment.py +++ /dev/null @@ -1,65 +0,0 @@ -"""Augmentation module.""" -from typing import Tuple - -import numpy as np - - -# TODO: unused? -def _flip_and_rotate( - image: np.ndarray, rotate_state: int, flip_state: int -) -> np.ndarray: - """ - Apply the given number of 90 degrees rotations and flip to an array. - - Parameters - ---------- - image : np.ndarray - Array containing single image or patch, 2D or 3D. - rotate_state : int - Number of 90 degree rotations to apply. - flip_state : int - 0 or 1, whether to flip the array or not. - - Returns - ------- - np.ndarray - Flipped and rotated array. - """ - rotated = np.rot90(image, k=rotate_state, axes=(-2, -1)) - flipped = np.flip(rotated, axis=-1) if flip_state == 1 else rotated - return flipped.copy() - - -def augment_batch( - patch: np.ndarray, - original_image: np.ndarray, - mask: np.ndarray, - seed: int = 42, -) -> Tuple[np.ndarray, ...]: - """ - Apply augmentation function to patches and masks. - - Parameters - ---------- - patch : np.ndarray - Array containing single image or patch, 2D or 3D with masked pixels. - original_image : np.ndarray - Array containing original image or patch, 2D or 3D. - mask : np.ndarray - Array containing only masked pixels, 2D or 3D. - seed : int, optional - Seed for random number generator, controls the rotation and falipping. - - Returns - ------- - Tuple[np.ndarray, ...] - Tuple of augmented arrays. 
- """ - rng = np.random.default_rng(seed=seed) - rotate_state = rng.integers(0, 4) - flip_state = rng.integers(0, 2) - return ( - _flip_and_rotate(patch, rotate_state, flip_state), - _flip_and_rotate(original_image, rotate_state, flip_state), - _flip_and_rotate(mask, rotate_state, flip_state), - ) diff --git a/src/careamics/utils/base_enum.py b/src/careamics/utils/base_enum.py new file mode 100644 index 00000000..8ff8bb6c --- /dev/null +++ b/src/careamics/utils/base_enum.py @@ -0,0 +1,32 @@ +from enum import Enum, EnumMeta +from typing import Any + + +class _ContainerEnum(EnumMeta): + def __contains__(cls, item: Any) -> bool: + try: + cls(item) + except ValueError: + return False + return True + + @classmethod + def has_value(cls, value: Any) -> bool: + return value in cls._value2member_map_ + + +class BaseEnum(Enum, metaclass=_ContainerEnum): + """Base Enum class, allowing checking if a value is in the enum. + + Example + ------- + >>> from careamics.utils.base_enum import BaseEnum + >>> # Define a new enum + >>> class BaseEnumExtension(BaseEnum): + ... VALUE = "value" + >>> # Check if value is in the enum + >>> "value" in BaseEnumExtension + True + """ + + pass diff --git a/src/careamics/utils/context.py b/src/careamics/utils/context.py index 6802b903..d39626c8 100644 --- a/src/careamics/utils/context.py +++ b/src/careamics/utils/context.py @@ -9,6 +9,24 @@ from typing import Iterator, Union +def get_careamics_home() -> Path: + """Return the CAREamics home directory. + + CAREamics home directory is a hidden folder in home. + + Returns + ------- + Path + CAREamics home directory path. + """ + home = Path.home() / ".careamics" + + if not home.exists(): + home.mkdir(parents=True, exist_ok=True) + + return home + + @contextmanager def cwd(path: Union[str, Path]) -> Iterator[None]: """ @@ -29,8 +47,10 @@ def cwd(path: Union[str, Path]) -> Iterator[None]: Examples -------- - >>> with cwd(path): - ... 
pass + The context is changed within the block and then restored to the original one. + + >>> with cwd(my_path): + ... pass # do something """ path = Path(path) diff --git a/src/careamics/utils/metrics.py b/src/careamics/utils/metrics.py index b961fc2b..23283c3a 100644 --- a/src/careamics/utils/metrics.py +++ b/src/careamics/utils/metrics.py @@ -112,49 +112,3 @@ def scale_invariant_psnr( range_parameter = (np.max(gt) - np.min(gt)) / np.std(gt) gt_ = _zero_mean(gt) / np.std(gt) return psnr(_zero_mean(gt_), _fix(gt_, pred), range_parameter) - - -class MetricTracker: - """ - Metric tracker class. - - This class is used to track values, sum, count and average of a metric over time. - - Attributes - ---------- - val : int - Last value of the metric. - avg : torch.Tensor.float - Average value of the metric. - sum : int - Sum of the metric values (times number of values). - count : int - Number of values. - """ - - def __init__(self) -> None: - """Constructor.""" - self.reset() - - def reset(self) -> None: - """Reset the metric tracker state.""" - self.val = 0.0 - self.avg: torch.Tensor.float = 0.0 - self.sum = 0.0 - self.count = 0.0 - - def update(self, value: int, n: int = 1) -> None: - """ - Update the metric tracker state. - - Parameters - ---------- - value : int - Value to update the metric tracker with. - n : int - Number of values, equals to batch size. - """ - self.val = value - self.sum += value * n - self.count += n - self.avg = self.sum / self.count diff --git a/src/careamics/utils/normalization.py b/src/careamics/utils/normalization.py deleted file mode 100644 index 42a8d417..00000000 --- a/src/careamics/utils/normalization.py +++ /dev/null @@ -1,55 +0,0 @@ -""" -Normalization submodule. - -These methods are used to normalize and denormalize images. -""" -import numpy as np - - -def normalize(img: np.ndarray, mean: float, std: float) -> np.ndarray: - """ - Normalize an image using mean and standard deviation.
- - Images are normalised by subtracting the mean and dividing by the standard - deviation. - - Parameters - ---------- - img : np.ndarray - Image to normalize. - mean : float - Mean. - std : float - Standard deviation. - - Returns - ------- - np.ndarray - Normalized array. - """ - zero_mean = img - mean - return zero_mean / std - - -def denormalize(img: np.ndarray, mean: float, std: float) -> np.ndarray: - """ - Denormalize an image using mean and standard deviation. - - Images are denormalised by multiplying by the standard deviation and adding the - mean. - - Parameters - ---------- - img : np.ndarray - Image to denormalize. - mean : float - Mean. - std : float - Standard deviation. - - Returns - ------- - np.ndarray - Denormalized array. - """ - return img * std + mean diff --git a/src/careamics/utils/path_utils.py b/src/careamics/utils/path_utils.py new file mode 100644 index 00000000..61bb744a --- /dev/null +++ b/src/careamics/utils/path_utils.py @@ -0,0 +1,24 @@ +from pathlib import Path +from typing import Union + + +def check_path_exists(path: Union[str, Path]) -> Path: + """Check if a path exists. If not, raise an error. + + Note that it returns `path` as a Path object. + + Parameters + ---------- + path : Union[str, Path] + Path to check. + + Returns + ------- + Path + Path as a Path object. + """ + path = Path(path) + if not path.exists(): + raise FileNotFoundError(f"Data path {path} is incorrect or does not exist.") + + return path diff --git a/src/careamics/utils/ram.py b/src/careamics/utils/ram.py new file mode 100644 index 00000000..2a26c781 --- /dev/null +++ b/src/careamics/utils/ram.py @@ -0,0 +1,13 @@ +import psutil + + +def get_ram_size() -> int: + """ + Get RAM size in megabytes. + + Returns + ------- + int + RAM size in megabytes.
+ """ + return psutil.virtual_memory().total / 1024**2 diff --git a/src/careamics/utils/receptive_field.py b/src/careamics/utils/receptive_field.py new file mode 100644 index 00000000..05de04bd --- /dev/null +++ b/src/careamics/utils/receptive_field.py @@ -0,0 +1,102 @@ +"""Receptive field calculation for computing the tile overlap.""" + +# Adapted from: https://github.com/frgfm/torch-scan + +import math +import warnings +from typing import Tuple, Union + +from torch import Tensor, nn +from torch.nn import Module +from torch.nn.modules.batchnorm import _BatchNorm +from torch.nn.modules.conv import _ConvNd, _ConvTransposeNd +from torch.nn.modules.pooling import ( + _AdaptiveAvgPoolNd, + _AdaptiveMaxPoolNd, + _AvgPoolNd, + _MaxPoolNd, +) + + +def module_rf(module: Module, inp: Tensor, out: Tensor) -> Tuple[float, float, float]: + """Estimate the spatial receptive field of the module. + + Args: + module (torch.nn.Module): PyTorch module + inp (torch.Tensor): input to the module + out (torch.Tensor): output of the module + Returns: + receptive field + effective stride + effective padding + """ + if isinstance( + module, + ( + nn.Identity, + nn.Flatten, + nn.ReLU, + nn.ELU, + nn.LeakyReLU, + nn.ReLU6, + nn.Tanh, + nn.Sigmoid, + _BatchNorm, + nn.Dropout, + nn.Linear, + ), + ): + return 1.0, 1.0, 0.0 + elif isinstance(module, _ConvTransposeNd): + return rf_convtransposend(module, inp, out) + elif isinstance(module, (_ConvNd, _MaxPoolNd, _AvgPoolNd)): + return rf_aggregnd(module, inp, out) + elif isinstance(module, (_AdaptiveMaxPoolNd, _AdaptiveAvgPoolNd)): + return rf_adaptive_poolnd(module, inp, out) + else: + warnings.warn( + f"Module type not supported: {module.__class__.__name__}", stacklevel=1 + ) + return 1.0, 1.0, 0.0 + + +def rf_convtransposend( + module: _ConvTransposeNd, _: Tensor, __: Tensor +) -> Tuple[float, float, float]: + k = ( + module.kernel_size[0] + if isinstance(module.kernel_size, tuple) + else module.kernel_size + ) + s = module.stride[0] if 
isinstance(module.stride, tuple) else module.stride + return -k, 1.0 / s, 0.0 + + +def rf_aggregnd( + module: Union[_ConvNd, _MaxPoolNd, _AvgPoolNd], _: Tensor, __: Tensor +) -> Tuple[float, float, float]: + k = ( + module.kernel_size[0] + if isinstance(module.kernel_size, tuple) + else module.kernel_size + ) + if hasattr(module, "dilation"): + d = ( + module.dilation[0] + if isinstance(module.dilation, tuple) + else module.dilation + ) + k = d * (k - 1) + 1 + s = module.stride[0] if isinstance(module.stride, tuple) else module.stride + p = module.padding[0] if isinstance(module.padding, tuple) else module.padding + return k, s, p # type: ignore[return-value] + + +def rf_adaptive_poolnd( + _: Union[_AdaptiveMaxPoolNd, _AdaptiveAvgPoolNd], inp: Tensor, out: Tensor +) -> Tuple[int, int, float]: + stride = math.ceil(inp.shape[-1] / out.shape[-1]) + kernel_size = stride + padding = (inp.shape[-1] - kernel_size * stride) / 2 + + return kernel_size, stride, padding diff --git a/src/careamics/utils/running_stats.py b/src/careamics/utils/running_stats.py new file mode 100644 index 00000000..1268d3e4 --- /dev/null +++ b/src/careamics/utils/running_stats.py @@ -0,0 +1,43 @@ +"""Running stats submodule, used in the Zarr dataset.""" + +# from multiprocessing import Value +# from typing import Tuple + +# import numpy as np + + +# class RunningStats: +# """Calculates running mean and std.""" + +# def __init__(self) -> None: +# self.reset() + +# def reset(self) -> None: +# """Reset the running stats.""" +# self.avg_mean = Value("d", 0) +# self.avg_std = Value("d", 0) +# self.m2 = Value("d", 0) +# self.count = Value("i", 0) + +# def init(self, mean: float, std: float) -> None: +# """Initialize running stats.""" +# with self.avg_mean.get_lock(): +# self.avg_mean.value += mean +# with self.avg_std.get_lock(): +# self.avg_std.value = std + +# def compute_std(self) -> Tuple[float, float]: +# """Compute std.""" +# if self.count.value >= 2: +# self.avg_std.value = np.sqrt(self.m2.value 
/ self.count.value) + +# def update(self, value: float) -> None: +# """Update running stats.""" +# with self.count.get_lock(): +# self.count.value += 1 +# delta = value - self.avg_mean.value +# with self.avg_mean.get_lock(): +# self.avg_mean.value += delta / self.count.value +# delta2 = value - self.avg_mean.value +# with self.m2.get_lock(): +# self.m2.value += delta * delta2 diff --git a/src/careamics/utils/torch_utils.py b/src/careamics/utils/torch_utils.py index d0553933..1cc0fb01 100644 --- a/src/careamics/utils/torch_utils.py +++ b/src/careamics/utils/torch_utils.py @@ -3,87 +3,124 @@ These functions are used to control certain aspects and behaviours of PyTorch. """ -import logging +import inspect +from typing import Dict, Union import torch +from careamics.config.support import SupportedOptimizer, SupportedScheduler -def get_device() -> torch.device: +from ..utils.logging import get_logger + +logger = get_logger(__name__) # TODO is the logger still needed? + + +def filter_parameters( + func: type, + user_params: dict, +) -> dict: + """ + Filter parameters according to the function signature. + + Parameters + ---------- + func : type + Class object. + user_params : Dict + User provided parameters. + + Returns + ------- + Dict + Parameters matching `func`'s signature. + """ + # Get the list of all default parameters + default_params = list(inspect.signature(func).parameters.keys()) + + # Filter matching parameters + params_to_be_used = set(user_params.keys()) & set(default_params) + + return {key: user_params[key] for key in params_to_be_used} + + +def get_optimizer(name: str) -> torch.optim.Optimizer: + """ + Return the optimizer class given its name. + + Parameters + ---------- + name : str + Optimizer name. + + Returns + ------- + torch.optim.Optimizer + Optimizer class.
+ """ + if name not in SupportedOptimizer: + raise NotImplementedError(f"Optimizer {name} is not yet supported.") + + return getattr(torch.optim, name) + + +def get_optimizers() -> Dict[str, str]: + """ + Return the list of all optimizers available in torch.optim. + + Returns + ------- + Dict + Optimizers available in torch.optim. + """ + optims = {} + for name, obj in inspect.getmembers(torch.optim): + if inspect.isclass(obj) and issubclass(obj, torch.optim.Optimizer): + if name != "Optimizer": + optims[name] = name + return optims + + +def get_scheduler( + name: str, +) -> Union[ + torch.optim.lr_scheduler.LRScheduler, + torch.optim.lr_scheduler.ReduceLROnPlateau, +]: + """ + Return the scheduler class given its name. + + Parameters + ---------- + name : str + Scheduler name. + + Returns + ------- + Union + Scheduler class. + """ + if name not in SupportedScheduler: + raise NotImplementedError(f"Scheduler {name} is not yet supported.") + + return getattr(torch.optim.lr_scheduler, name) + + +def get_schedulers() -> Dict[str, str]: """ - Select the device to use for training. + Return the list of all schedulers available in torch.optim.lr_scheduler. Returns ------- - torch.device - CUDA or CPU device, depending on availability of CUDA devices. + Dict + Schedulers available in torch.optim.lr_scheduler. """ - if torch.cuda.is_available(): - logging.info("CUDA available. Using GPU.") - device = torch.device("cuda") - else: - logging.info("CUDA not available. Using CPU.") - device = torch.device("cpu") - return device - - -# def compile_model(model: torch.nn.Module) -> torch.nn.Module: -# """ -# Torch.compile wrapper. - -# Parameters -# ---------- -# model : torch.nn.Module -# Model. - -# Returns -# ------- -# torch.nn.Module -# Compiled model if compile is available, the model itself otherwise. 
-# """ -# if hasattr(torch, "compile") and sys.version_info.minor <= 9: -# return torch.compile(model, mode="reduce-overhead") -# else: -# return model - - -# def seed_everything(seed: int) -> None: -# """ -# Seed all random number generators for reproducibility. - -# Parameters -# ---------- -# seed : int -# Seed. -# """ -# import random - -# import numpy as np - -# random.seed(seed) -# np.random.seed(seed) -# torch.manual_seed(seed) -# torch.cuda.manual_seed_all(seed) - - -# def setup_cudnn_reproducibility( -# deterministic: bool = True, benchmark: bool = True -# ) -> None: -# """ -# Prepare CuDNN benchmark and sets it to be deterministic/non-deterministic mode. - -# Parameters -# ---------- -# deterministic : bool -# Deterministic mode, if running CuDNN backend. -# benchmark : bool -# If True, uses CuDNN heuristics to figure out which algorithm will be most -# performant for your model architecture and input. False may slow down training -# """ -# if torch.cuda.is_available(): -# if deterministic: -# deterministic = os.environ.get("CUDNN_DETERMINISTIC", "True") == "True" -# torch.backends.cudnn.deterministic = deterministic - -# if benchmark: -# benchmark = os.environ.get("CUDNN_BENCHMARK", "True") == "True" -# torch.backends.cudnn.benchmark = benchmark + schedulers = {} + for name, obj in inspect.getmembers(torch.optim.lr_scheduler): + if inspect.isclass(obj) and issubclass( + obj, torch.optim.lr_scheduler.LRScheduler + ): + if "LRScheduler" not in name: + schedulers[name] = name + elif name == "ReduceLROnPlateau": # somewhat not a subclass of LRScheduler + schedulers[name] = name + return schedulers diff --git a/src/careamics/utils/validators.py b/src/careamics/utils/validators.py deleted file mode 100644 index fc69bbc8..00000000 --- a/src/careamics/utils/validators.py +++ /dev/null @@ -1,170 +0,0 @@ -""" -Validator functions. - -These functions are used to validate dimensions and axes of inputs. 
-""" -from typing import List - -import numpy as np - -AXES = "STCZYX" - - -def check_axes_validity(axes: str) -> None: - """ - Sanity check on axes. - - The constraints on the axes are the following: - - must be a combination of 'STCZYX' - - must not contain duplicates - - must contain at least 2 contiguous axes: X and Y - - must contain at most 4 axes - - cannot contain both S and T axes - - C is currently not allowed - - Parameters - ---------- - axes : str - Axes to validate. - """ - _axes = axes.upper() - - # Minimum is 2 (XY) and maximum is 4 (TZYX) - if len(_axes) < 2 or len(_axes) > 4: - raise ValueError( - f"Invalid axes {axes}. Must contain at least 2 and at most 4 axes." - ) - - # all characters must be in REF_AXES = 'STCZYX' - if not all(s in AXES for s in _axes): - raise ValueError(f"Invalid axes {axes}. Must be a combination of {AXES}.") - - # check for repeating characters - for i, s in enumerate(_axes): - if i != _axes.rfind(s): - raise ValueError( - f"Invalid axes {axes}. Cannot contain duplicate axes" - f" (got multiple {axes[i]})." - ) - - # currently no implementation for C - if "C" in _axes: - raise NotImplementedError("Currently, C axis is not supported.") - - # prevent S and T axes at the same time - if "T" in _axes and "S" in _axes: - raise NotImplementedError( - f"Invalid axes {axes}. Cannot contain both S and T axes." - ) - - # prior: X and Y contiguous (#FancyComments) - # right now the next check is invalidating this, but in the future, we might - # allow random order of axes (or at least XY and YX) - if "XY" not in _axes and "YX" not in _axes: - raise ValueError(f"Invalid axes {axes}. X and Y must be contiguous.") - - # check that the axes are in the right order - for i, s in enumerate(_axes): - if i < len(_axes) - 1: - index_s = AXES.find(s) - index_next = AXES.find(_axes[i + 1]) - - if index_s > index_next: - raise ValueError( - f"Invalid axes {axes}. Axes must be in the order {AXES}." 
- ) - - -def add_axes(input_array: np.ndarray, axes: str) -> np.ndarray: - """ - Add missing axes to the input, typically batch and channel. - - This method validates the axes first. Then it inspects the input array and add - missing dimensions if necessary. - - Parameters - ---------- - input_array : np.ndarray - Input array. - axes : str - Axes to add. - - Returns - ------- - np.ndarray - Array with new singleton axes. - """ - # validate axes - check_axes_validity(axes) - - # is 3D - is_3D = "Z" in axes - - # number of dims - n_dims = 5 if is_3D else 4 - - # array of dim 2, 3 or 4 - if len(input_array.shape) < n_dims: - if "S" not in axes and "T" not in axes: - input_array = input_array[np.newaxis, ...] - - # still missing C dimension - if len(input_array.shape) < n_dims: - input_array = input_array[:, np.newaxis, ...] - - return input_array - - -def check_tiling_validity(tile_shape: List[int], overlaps: List[int]) -> None: - """ - Check that the tiling parameters are valid. - - Parameters - ---------- - tile_shape : List[int] - Shape of the tiles. - overlaps : List[int] - Overlap between tiles. - - Raises - ------ - ValueError - If one of the parameters is None. - ValueError - If one of the element is zero. - ValueError - If one of the element is non-divisible by 2. - ValueError - If the number of elements in `overlaps` and `tile_shape` is different. - ValueError - If one of the overlaps is larger than the corresponding tile shape. - """ - # cannot be None - if tile_shape is None or overlaps is None: - raise ValueError( - "Cannot use tiling without specifying `tile_shape` and " - "`overlaps`, make sure they have been correctly specified." 
- ) - - # non-zero and divisible by two - for dims_list in [tile_shape, overlaps]: - for dim in dims_list: - if dim < 1: - raise ValueError(f"Entry must be non-null positive (got {dim}).") - - if dim % 2 != 0: - raise ValueError(f"Entry must be divisible by 2 (got {dim}).") - - # same length - if len(overlaps) != len(tile_shape): - raise ValueError( - f"Overlaps ({len(overlaps)}) and tile shape ({len(tile_shape)}) must " - f"have the same number of dimensions." - ) - - # overlaps smaller than tile shape - for overlap, tile_dim in zip(overlaps, tile_shape): - if overlap >= tile_dim: - raise ValueError( - f"Overlap ({overlap}) must be smaller than tile shape ({tile_dim})." - ) diff --git a/src/careamics/utils/wandb.py b/src/careamics/utils/wandb.py deleted file mode 100644 index 68576393..00000000 --- a/src/careamics/utils/wandb.py +++ /dev/null @@ -1,121 +0,0 @@ -""" -A WandB logger for CAREamics. - -Implements a WandB class for use within the Engine. -""" -import sys -from pathlib import Path -from typing import Dict, Union - -import torch -import wandb - -from careamics.config import Configuration - - -def is_notebook() -> bool: - """ - Check if the code is executed from a notebook or a qtconsole. - - Returns - ------- - bool - True if the code is executed from a notebooks, False otherwise. - """ - try: - from IPython import get_ipython - - shell = get_ipython().__class__.__name__ - if shell == "ZMQInteractiveShell": - return True # Jupyter notebook or qtconsole - else: - return False - except (NameError, ModuleNotFoundError): - return False - - -class WandBLogging: - """ - WandB logging class. - - Parameters - ---------- - experiment_name : str - Name of the experiment. - log_path : Path - Path in which to save the WandB log. - config : Configuration - Configuration of the model. - model_to_watch : torch.nn.Module - Model. - save_code : bool, optional - Whether to save the code, by default True. 
- """ - - def __init__( - self, - experiment_name: str, - log_path: Path, - config: Configuration, - model_to_watch: torch.nn.Module, - save_code: bool = True, - ): - """ - Constructor. - - Parameters - ---------- - experiment_name : str - Name of the experiment. - log_path : Path - Path in which to save the WandB log. - config : Configuration - Configuration of the model. - model_to_watch : torch.nn.Module - Model. - save_code : bool, optional - Whether to save the code, by default True. - """ - self.run = wandb.init( - project="careamics-restoration", - dir=log_path, - name=experiment_name, - config=config.model_dump() if config else None, - # save_code=save_code, - ) - if model_to_watch: - wandb.watch(model_to_watch, log="all", log_freq=1) - if save_code: - if is_notebook(): - # Get all sys path and select the root - code_path = Path([p for p in sys.path if "caremics" in p][-1]).parent - else: - code_path = Path("../") - self.log_code(code_path) - - def log_metrics(self, metric_dict: Dict) -> None: - """ - Log metrics to wandb. - - Parameters - ---------- - metric_dict : Dict - New metrics entry. - """ - self.run.log(metric_dict, commit=True) - - def log_code(self, code_path: Union[str, Path]) -> None: - """ - Log code to wandb. - - Parameters - ---------- - code_path : Union[str, Path] - Path to the code. 
- """ - self.run.log_code( - root=code_path, - include_fn=lambda path: path.endswith(".py") - or path.endswith(".yml") - or path.endswith(".yaml"), - ) diff --git a/tests/__init__.py b/tests/__init__.py deleted file mode 100644 index ddb7d4cb..00000000 --- a/tests/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""CAREamics tests.""" diff --git a/tests/bioimage/test_engine_bmz.py b/tests/bioimage/test_engine_bmz.py deleted file mode 100644 index b667fc62..00000000 --- a/tests/bioimage/test_engine_bmz.py +++ /dev/null @@ -1,123 +0,0 @@ -from pathlib import Path - -import numpy as np -import pytest - -from careamics import Configuration, Engine - - -def test_generate_rdf_without_mean_std(minimum_config: dict): - """Test that generating rdf without specifying mean and std - raises an error.""" - # create configuration and save it to disk - config = Configuration(**minimum_config) - - # create an engine to export the model - engine = Engine(config=config) - with pytest.raises(ValueError): - engine._generate_rdf() - - # test if error is raised when config is None - engine.config = None - with pytest.raises(ValueError): - engine._generate_rdf() - - -def test_bioimage_generate_rdf(minimum_config: dict): - """Test generating rdf using default specs.""" - # create configuration and save it to disk - mean = 666.666 - std = 42.420 - minimum_config["data"]["mean"] = mean - minimum_config["data"]["std"] = std - minimum_config["data"]["axes"] = "YX" - config = Configuration(**minimum_config) - - # create an engine to export the model - engine = Engine(config=config) - - # create a monkey patch for the input - engine._input = np.random.randint(0, 255, minimum_config["training"]["patch_size"]) - - # Sample files - axes = "bcyx" - test_inputs = Path(minimum_config["working_directory"]) / "test_inputs.npy" - test_outputs = Path(minimum_config["working_directory"]) / "test_outputs.npy" - - # Export rdf - rdf = engine._generate_rdf() - assert rdf["preprocessing"][0][0]["kwargs"]["mean"] == 
[mean] - assert rdf["preprocessing"][0][0]["kwargs"]["std"] == [std] - assert rdf["postprocessing"][0][0]["kwargs"]["offset"] == [mean] - assert rdf["postprocessing"][0][0]["kwargs"]["gain"] == [std] - assert rdf["test_inputs"] == [str(test_inputs)] - assert rdf["test_outputs"] == [str(test_outputs)] - assert rdf["input_axes"] == [axes] - assert rdf["output_axes"] == [axes] - - -def test_bioimage_generate_rdf_with_specs(minimum_config: dict): - """Test model export to bioimage format by using default specs.""" - # create configuration and save it to disk - mean = 666.666 - std = 42.420 - minimum_config["data"]["mean"] = mean - minimum_config["data"]["std"] = std - minimum_config["data"]["axes"] = "YX" - config = Configuration(**minimum_config) - - # create an engine to export the model - engine = Engine(config=config) - - # create a monkey patch for the input - engine._input = np.random.randint(0, 255, minimum_config["training"]["patch_size"]) - - # Test model specs - model_specs = {"description": "Some description", "license": "to kill"} - rdf = engine._generate_rdf(model_specs=model_specs) - assert rdf["description"] == model_specs["description"] - assert rdf["license"] == model_specs["license"] - - -@pytest.mark.parametrize( - "axes, shape", - [ - ("YX", (64, 128)), - ("ZYX", (8, 128, 64)), - ], -) -def test_bioimage_generate_rdf_with_input( - minimum_config: dict, ordered_array, axes, shape -): - """Test generating rdf using default specs.""" - # create configuration and save it to disk - mean = 666.666 - std = 42.420 - minimum_config["algorithm"]["is_3D"] = len(shape) == 3 - minimum_config["training"]["patch_size"] = ( - (64, 64) if len(shape) == 2 else (8, 64, 64) - ) - minimum_config["data"]["mean"] = mean - minimum_config["data"]["std"] = std - minimum_config["data"]["axes"] = axes - config = Configuration(**minimum_config) - - # create an engine to export the model - engine = Engine(config=config) - - # create a monkey patch for the input - monkey_input = 
np.random.randint(0, 255, minimum_config["training"]["patch_size"]) - engine._input = monkey_input - - # create other input - other_input = ordered_array(shape=shape) - - # create rdf - rdf = engine._generate_rdf(input_array=other_input) - - # inspect input/output - array_in = np.load(rdf["test_inputs"][0]) - assert (array_in.squeeze() == other_input).all() - - array_out = np.load(rdf["test_outputs"][0]) - assert array_out.max() != 0 diff --git a/tests/bioimage/test_io.py b/tests/bioimage/test_io.py deleted file mode 100644 index a8cafcda..00000000 --- a/tests/bioimage/test_io.py +++ /dev/null @@ -1,84 +0,0 @@ -from pathlib import Path - -import numpy as np -import pytest -import torch -from bioimageio.core import resource_tests - -from careamics.bioimage import import_bioimage_model -from careamics.config import Configuration -from careamics.engine import Engine -from careamics.models import create_model - - -def save_checkpoint(engine: Engine, config: Configuration) -> None: - # create a fake checkpoint - checkpoint = { - "epoch": 1, - "model_state_dict": engine.model.state_dict(), - "optimizer_state_dict": engine.optimizer.state_dict(), - "scheduler_state_dict": engine.lr_scheduler.state_dict(), - "grad_scaler_state_dict": engine.scaler.state_dict(), - "loss": 0.01, - "config": config, - } - checkpoint_path = ( - Path(config["working_directory"]) - .joinpath(f"{config['experiment_name']}_best.pth") - .absolute() - ) - torch.save(checkpoint, checkpoint_path) - - -@pytest.mark.parametrize( - "axes, patch", - [ - ("YX", [64, 64]), - ("ZYX", [32, 64, 64]), - ], -) -def test_bioimage_io(minimum_config: dict, tmp_path: Path, axes, patch): - """Test model export/import to bioimage format.""" - # create configuration - minimum_config["data"]["mean"] = 666.666 - minimum_config["data"]["std"] = 42.420 - minimum_config["data"]["axes"] = axes - minimum_config["training"]["patch_size"] = patch - minimum_config["algorithm"]["is_3D"] = len(axes) == 3 - - config = 
Configuration(**minimum_config) - - # create an engine to export the model - engine = Engine(config=config) - - # create a monkey patch for the input (array saved during first validation) - engine._input = np.random.randint(0, 255, minimum_config["training"]["patch_size"]) - engine._input = engine._input[np.newaxis, np.newaxis, ...] - - # save fake checkpoint - save_checkpoint(engine, minimum_config) - - # output zip file - zip_file = tmp_path / "tmp_model.bioimage.io.zip" - - engine.save_as_bioimage(zip_file.absolute()) - assert zip_file.exists() - - # load model - _, _, _, _, loaded_config = create_model(model_path=zip_file) - assert isinstance(loaded_config, Configuration) - - # check that the configuration is the same - assert loaded_config == config - - # validate model - results = resource_tests.test_model(zip_file) - for result in results: - assert result["status"] == "passed", f"Failed at {result['name']}." - - -def test_bioimage_wrong_path(tmp_path: Path): - """Test that the model export fails if the path is wrong.""" - path = tmp_path / "wrong_path.tiff" - with pytest.raises(ValueError): - import_bioimage_model(path) diff --git a/tests/bioimage/test_rdf.py b/tests/bioimage/test_rdf.py deleted file mode 100644 index c7266277..00000000 --- a/tests/bioimage/test_rdf.py +++ /dev/null @@ -1,37 +0,0 @@ -from pathlib import Path - -import pytest - -from careamics.bioimage.rdf import _get_model_doc, get_default_model_specs - - -@pytest.mark.parametrize("name", ["Noise2Void"]) -def test_get_model_doc(name): - doc = _get_model_doc(name) - assert Path(doc).exists() - - -def test_get_model_doc_error(): - with pytest.raises(FileNotFoundError): - _get_model_doc("NotAModel") - - -@pytest.mark.parametrize("name", ["Noise2Void"]) -@pytest.mark.parametrize("is_3D", [True, False]) -def test_default_model_specs(name, is_3D): - mean = 666.666 - std = 42.420 - - if is_3D: - axes = "zyx" - else: - axes = "yx" - - specs = get_default_model_specs(name, mean, std, is_3D=is_3D) - 
assert specs["name"] == name - assert specs["preprocessing"][0][0]["kwargs"]["mean"] == [mean] - assert specs["preprocessing"][0][0]["kwargs"]["std"] == [std] - assert specs["preprocessing"][0][0]["kwargs"]["axes"] == axes - assert specs["postprocessing"][0][0]["kwargs"]["offset"] == [mean] - assert specs["postprocessing"][0][0]["kwargs"]["gain"] == [std] - assert specs["postprocessing"][0][0]["kwargs"]["axes"] == axes diff --git a/tests/config/architectures/test_architecture_model.py b/tests/config/architectures/test_architecture_model.py new file mode 100644 index 00000000..b97ab7e9 --- /dev/null +++ b/tests/config/architectures/test_architecture_model.py @@ -0,0 +1,11 @@ +from careamics.config.architectures import ArchitectureModel + + +def test_model_dump(): + """Test that architecture keyword is removed from the model dump.""" + model_params = {"architecture": "LeCorbusier"} + model = ArchitectureModel(**model_params) + + # dump model + model_dict = model.model_dump() + assert model_dict == {} diff --git a/tests/config/architectures/test_custom_model.py b/tests/config/architectures/test_custom_model.py new file mode 100644 index 00000000..63b6c776 --- /dev/null +++ b/tests/config/architectures/test_custom_model.py @@ -0,0 +1,100 @@ +import pytest +from torch import nn, ones + +from careamics.config.architectures import CustomModel, get_custom_model, register_model +from careamics.config.support import SupportedArchitecture + + +@register_model(name="linear") +class LinearModel(nn.Module): + def __init__(self, in_features, out_features): + super().__init__() + + self.in_features = in_features + self.out_features = out_features + self.weight = nn.Parameter(ones(in_features, out_features)) + self.bias = nn.Parameter(ones(out_features)) + + def forward(self, input): + return (input @ self.weight) + self.bias + + +@register_model(name="not_a_model") +class NotAModel: + def __init__(self, id): + self.id = id + + def forward(self, input): + return input + + +def 
test_any_custom_parameters(): + """Test that the custom model can have any fields. + + Note that those fields are validated by instantiating the + model. + """ + CustomModel(architecture="Custom", name="linear", in_features=10, out_features=5) + + +def test_linear_model(): + """Test that the model can be retrieved and instantiated.""" + model = get_custom_model("linear") + model(in_features=10, out_features=5) + + +def test_not_a_model(): + """Test that the model can be retrieved and instantiated.""" + model = get_custom_model("not_a_model") + model(3) + + +def test_custom_model(): + """Test that the custom model can be instantiated.""" + # prepare model dictionary + model_dict = { + "architecture": SupportedArchitecture.CUSTOM.value, + "name": "linear", + "in_features": 10, + "out_features": 5, + } + + # create Pydantic model + pydantic_model = CustomModel(**model_dict) + + # instantiate model + model_class = get_custom_model(pydantic_model.name) + model = model_class(**pydantic_model.model_dump()) + + assert isinstance(model, LinearModel) + assert model.in_features == 10 + assert model.out_features == 5 + + +def test_custom_model_wrong_class(): + """Test that the Pydantic custom model raises an error if the model is not a + torch.nn.Module subclass.""" + # prepare model dictionary + model_dict = { + "architecture": "Custom", + "name": "not_a_model", + "parameters": {"id": 3}, + } + + # create Pydantic model + with pytest.raises(ValueError): + CustomModel(**model_dict) + + +def test_wrong_parameters(): + """Test that the custom model raises an error if the parameters are not valid.""" + # prepare model dictionary + model_dict = { + "architecture": "Custom", + "name": "linear", + "parameters": {"in_features": 10}, + } + + # create Pydantic model + with pytest.raises(ValueError): + CustomModel(**model_dict) diff --git a/tests/config/architectures/test_register_model.py b/tests/config/architectures/test_register_model.py new file mode 100644 index 00000000..e41cf4b0 
--- /dev/null +++ b/tests/config/architectures/test_register_model.py @@ -0,0 +1,45 @@ +import pytest + +from careamics.config.architectures import ( + clear_custom_models, + get_custom_model, + register_model, +) + + +# register a model +@register_model(name="mymodel") +class MyModel: + model_name: str + model_id: int + + +def test_register_model(): + """Test the register_model decorator.""" + + # get custom model + model = get_custom_model("mymodel") + + # check if it is a subclass of MyModel + assert issubclass(model, MyModel) + + +def test_wrong_model(): + """Test that an error is raised if an unknown model is requested.""" + get_custom_model("mymodel") + + with pytest.raises(ValueError): + get_custom_model("unknown_model") + + +def test_clear_custom_models(): + """Test that the custom models are cleared.""" + # retrieve model + get_custom_model("mymodel") + + # clear custom models + clear_custom_models() + + # request the model again + with pytest.raises(ValueError): + get_custom_model("mymodel") diff --git a/tests/config/architectures/test_unet_model.py b/tests/config/architectures/test_unet_model.py new file mode 100644 index 00000000..0f41a91b --- /dev/null +++ b/tests/config/architectures/test_unet_model.py @@ -0,0 +1,121 @@ +import pytest + +from careamics.config.architectures import UNetModel +from careamics.config.support import SupportedActivation + + +def test_instantiation(): + """Test that UNetModel can be instantiated.""" + model_params = { + "architecture": "UNet", + "conv_dim": 2, + "num_channels_init": 16, + } + + # instantiate model + UNetModel(**model_params) + + +def test_architecture_missing(): + """Test that UNetModel requires architecture.""" + model_params = { + "depth": 2, + "num_channels_init": 16, + } + + with pytest.raises(ValueError): + UNetModel(**model_params) + + +@pytest.mark.parametrize("num_channels_init", [8, 16, 32, 96, 128]) +def test_num_channels_init(num_channels_init: int): + """Test that UNetModel accepts 
num_channels_init as an even number and + minimum 8.""" + model_params = {"architecture": "UNet", "num_channels_init": num_channels_init} + + # instantiate model + UNetModel(**model_params) + + +@pytest.mark.parametrize("num_channels_init", [2, 17, 127]) +def test_wrong_num_channels_init(num_channels_init: int): + """Test that wrong num_channels_init causes an error.""" + model_params = {"architecture": "UNet", "num_channels_init": num_channels_init} + + with pytest.raises(ValueError): + UNetModel(**model_params) + + +def test_activations(): + """Test that UNetModel accepts all activations.""" + for act in SupportedActivation: + model_params = { + "architecture": "UNet", + "num_channels_init": 16, + "final_activation": act.value, + } + + # instantiate model + UNetModel(**model_params) + + +def test_all_activations_are_supported(): + """Test that all activations defined in the Literal are supported.""" + # list of supported activations + activations = list(SupportedActivation) + + # Algorithm json schema + schema = UNetModel.model_json_schema() + + # check that all activations are supported + for act in schema["properties"]["final_activation"]["enum"]: + assert act in activations + + +def test_activation_wrong_values(): + """Test that wrong values are not accepted.""" + model_params = { + "architecture": "UNet", + "num_channels_init": 16, + "final_activation": "wrong", + } + + with pytest.raises(ValueError): + UNetModel(**model_params) + + +def test_parameters_wrong_values_by_assigment(): + """Test that wrong values are not accepted through assignment.""" + model_params = {"architecture": "UNet", "num_channels_init": 16, "depth": 2} + model = UNetModel(**model_params) + + # depth + model.depth = model_params["depth"] + with pytest.raises(ValueError): + model.depth = -1 + + # number of channels + model.num_channels_init = model_params["num_channels_init"] + with pytest.raises(ValueError): + model.num_channels_init = 2 + + +def test_model_dump(): + """Test that 
default values are excluded from model dump.""" + model_params = { + "architecture": "UNet", + "num_channels_init": 16, # non-default value + "final_activation": "ReLU", # non-default value + } + model = UNetModel(**model_params) + + # dump model + model_dict = model.model_dump(exclude_defaults=True) + + # check that default values are excluded except the architecture + assert "architecture" not in model_dict + assert len(model_dict) == 2 + + # check that we get all the optional values with the exclude_defaults flag + model_dict = model.model_dump(exclude_defaults=False) + assert len(model_dict) == len(dict(model)) - 1 diff --git a/tests/config/support/test_supported_data.py b/tests/config/support/test_supported_data.py new file mode 100644 index 00000000..ef8bf22d --- /dev/null +++ b/tests/config/support/test_supported_data.py @@ -0,0 +1,76 @@ +from fnmatch import fnmatch +from pathlib import Path + +import numpy as np +import pytest +import tifffile + +from careamics.config.support import SupportedData + + +def test_extension_tiff_fnmatch(tmp_path: Path): + """Test that the TIFF extension is compatible with fnmatch.""" + path = tmp_path / "test.tif" + + # test as str + assert fnmatch(str(path), SupportedData.get_extension(SupportedData.TIFF)) + + # test as Path + assert fnmatch(path, SupportedData.get_extension(SupportedData.TIFF)) + + +def test_extension_tiff_rglob(tmp_path: Path): + """Test that the TIFF extension is compatible with Path.rglob.""" + # create text file + text_path = tmp_path / "test.txt" + text_path.write_text("test") + + # create image + path = tmp_path / "test.tif" + image = np.ones((10, 10)) + tifffile.imwrite(path, image) + + # search for files + files = list(tmp_path.rglob(SupportedData.get_extension(SupportedData.TIFF))) + assert len(files) == 1 + assert files[0] == path + + +def test_extension_custom_fnmatch(tmp_path: Path): + """Test that the custom extension is compatible with fnmatch.""" + path = tmp_path / "test.czi" + + # test as str 
+ assert fnmatch(str(path), SupportedData.get_extension(SupportedData.CUSTOM)) + + # test as Path + assert fnmatch(path, SupportedData.get_extension(SupportedData.CUSTOM)) + + +def test_extension_custom_rglob(tmp_path: Path): + """Test that the custom extension is compatible with Path.rglob.""" + # create text file + text_path = tmp_path / "test.txt" + text_path.write_text("test") + + # create image + path = tmp_path / "test.npy" + image = np.ones((10, 10)) + np.save(path, image) + + # search for files + files = list(tmp_path.rglob(SupportedData.get_extension(SupportedData.CUSTOM))) + assert len(files) == 2 + assert set(files) == {path, text_path} + + +def test_extension_array_error(): + """Test that the array extension raises NotImplementedError.""" + with pytest.raises(NotImplementedError): + SupportedData.get_extension(SupportedData.ARRAY) + + +def test_extension_any_error(): + """Test that any extension raises NotImplementedError.""" + with pytest.raises(ValueError): + SupportedData.get_extension("some random") diff --git a/tests/config/support/test_supported_optimizers.py b/tests/config/support/test_supported_optimizers.py new file mode 100644 index 00000000..8682fc33 --- /dev/null +++ b/tests/config/support/test_supported_optimizers.py @@ -0,0 +1,18 @@ +from torch import optim + +from careamics.config.support.supported_optimizers import ( + SupportedOptimizer, + SupportedScheduler, +) + + +def test_schedulers_exist(): + """Test that `SupportedScheduler` contains existing torch schedulers.""" + for scheduler in SupportedScheduler: + assert hasattr(optim.lr_scheduler, scheduler) + + +def test_optimizers_exist(): + """Test that `SupportedOptimizer` contains existing torch optimizers.""" + for optimizer in SupportedOptimizer: + assert hasattr(optim, optimizer) diff --git a/tests/config/test_algorithm.py b/tests/config/test_algorithm.py deleted file mode 100644 index 65d0b95e..00000000 --- a/tests/config/test_algorithm.py +++ /dev/null @@ -1,187 +0,0 @@ -import 
pytest - -from careamics.config.algorithm import Algorithm, ModelParameters - - -@pytest.mark.parametrize("depth", [1, 5, 10]) -def test_model_parameters_depth(complete_config: dict, depth: int): - """Test that ModelParameters accepts depth between 1 and 10.""" - model_params = complete_config["algorithm"]["model_parameters"] - model_params["depth"] = depth - - model = ModelParameters(**model_params) - assert model.depth == depth - - -@pytest.mark.parametrize("depth", [-1, 11]) -def test_model_parameters_wrong_depth(complete_config: dict, depth: int): - """Test that wrong depth cause an error.""" - model_params = complete_config["algorithm"]["model_parameters"] - model_params["depth"] = depth - - with pytest.raises(ValueError): - ModelParameters(**model_params) - - -@pytest.mark.parametrize("num_channels_init", [8, 16, 32, 96, 128]) -def test_model_parameters_num_channels_init( - complete_config: dict, num_channels_init: int -): - """Test that ModelParameters accepts num_channels_init as a power of two and - minimum 8.""" - model_params = complete_config["algorithm"]["model_parameters"] - model_params["num_channels_init"] = num_channels_init - - model = ModelParameters(**model_params) - assert model.num_channels_init == num_channels_init - - -@pytest.mark.parametrize("num_channels_init", [2, 17, 127]) -def test_model_parameters_wrong_num_channels_init( - complete_config: dict, num_channels_init: int -): - """Test that wrong num_channels_init cause an error.""" - model_params = complete_config["algorithm"]["model_parameters"] - model_params["num_channels_init"] = num_channels_init - - with pytest.raises(ValueError): - ModelParameters(**model_params) - - -@pytest.mark.parametrize("roi_size", [5, 9, 15]) -def test_model_parameters_roi_size(complete_config: dict, roi_size: int): - """Test that Algorithm accepts roi_size as an even number within the - range [3, 21].""" - params = complete_config["algorithm"] - params["roi_size"] = roi_size - - algorithm = 
Algorithm(**params) - assert algorithm.roi_size == roi_size - - -@pytest.mark.parametrize("roi_size", [2, 4, 23]) -def test_model_parameters_wrong_roi_size(complete_config: dict, roi_size: int): - """Test that wrong num_channels_init cause an error.""" - params = complete_config["algorithm"] - params["roi_size"] = roi_size - - with pytest.raises(ValueError): - Algorithm(**params) - - -def test_model_parameters_wrong_values_by_assigment(complete_config: dict): - """Test that wrong values are not accepted through assignment.""" - model_params = complete_config["algorithm"]["model_parameters"] - model = ModelParameters(**model_params) - - # depth - model.depth = model_params["depth"] - with pytest.raises(ValueError): - model.depth = -1 - - # number of channels - model.num_channels_init = model_params["num_channels_init"] - with pytest.raises(ValueError): - model.num_channels_init = 2 - - -@pytest.mark.parametrize("masked_pixel_percentage", [0.1, 0.2, 5, 20]) -def test_masked_pixel_percentage(complete_config: dict, masked_pixel_percentage: float): - """Test that Algorithm accepts the minimum configuration.""" - algorithm = complete_config["algorithm"] - algorithm["masked_pixel_percentage"] = masked_pixel_percentage - - algo = Algorithm(**algorithm) - assert algo.masked_pixel_percentage == masked_pixel_percentage - - -@pytest.mark.parametrize("masked_pixel_percentage", [0.01, 21]) -def test_wrong_masked_pixel_percentage( - complete_config: dict, masked_pixel_percentage: float -): - """Test that Algorithm accepts the minimum configuration.""" - algorithm = complete_config["algorithm"] - algorithm["masked_pixel_percentage"] = masked_pixel_percentage - - with pytest.raises(ValueError): - Algorithm(**algorithm) - - -def test_wrong_values_by_assigment(complete_config: dict): - """Test that wrong values are not accepted through assignment.""" - algorithm = complete_config["algorithm"] - algo = Algorithm(**algorithm) - - # loss - algo.loss = algorithm["loss"] - with 
pytest.raises(ValueError): - algo.loss = "mse" - - # model - algo.model = algorithm["model"] - with pytest.raises(ValueError): - algo.model = "unet" - - # is_3D - algo.is_3D = algorithm["is_3D"] - with pytest.raises(ValueError): - algo.is_3D = 3 - - # masking_strategy - algo.masking_strategy = algorithm["masking_strategy"] - with pytest.raises(ValueError): - algo.masking_strategy = "mean" - - # masked_pixel_percentage - algo.masked_pixel_percentage = algorithm["masked_pixel_percentage"] - with pytest.raises(ValueError): - algo.masked_pixel_percentage = 0.01 - - # model_parameters - algo.model_parameters = algorithm["model_parameters"] - with pytest.raises(ValueError): - algo.model_parameters = "params" - - -def test_algorithm_to_dict_minimum(minimum_config: dict): - """ "Test that export to dict does not include optional values.""" - algorithm_minimum = Algorithm(**minimum_config["algorithm"]).model_dump() - assert algorithm_minimum == minimum_config["algorithm"] - - assert "loss" in algorithm_minimum - assert "model" in algorithm_minimum - assert "is_3D" in algorithm_minimum - assert "masking_strategy" not in algorithm_minimum - assert "masked_pixel_percentage" not in algorithm_minimum - assert "model_parameters" not in algorithm_minimum - - -def test_algorithm_to_dict_complete(complete_config: dict): - """ "Test that export to dict does not include optional values.""" - algorithm_complete = Algorithm(**complete_config["algorithm"]).model_dump() - assert algorithm_complete == complete_config["algorithm"] - # TODO values are hardcoded in the fixture, is it ok ? 
- assert "loss" in algorithm_complete - assert "model" in algorithm_complete - assert "is_3D" in algorithm_complete - assert "masking_strategy" in algorithm_complete - assert "masked_pixel_percentage" in algorithm_complete - assert "roi_size" in algorithm_complete - assert "model_parameters" in algorithm_complete - assert "depth" in algorithm_complete["model_parameters"] - assert "num_channels_init" in algorithm_complete["model_parameters"] - - -def test_algorithm_to_dict_optionals(complete_config: dict): - """Test that export to dict does not include optional values.""" - # change optional value to the default - algo_config = complete_config["algorithm"] - algo_config["model_parameters"] = { - "depth": 2, - "num_channels_init": 32, - } - algo_config["masking_strategy"] = "default" - - algorithm_complete = Algorithm(**complete_config["algorithm"]).model_dump() - assert "model_parameters" not in algorithm_complete - assert "masking_strategy" not in algorithm_complete diff --git a/tests/config/test_algorithm_model.py b/tests/config/test_algorithm_model.py new file mode 100644 index 00000000..6321c46c --- /dev/null +++ b/tests/config/test_algorithm_model.py @@ -0,0 +1,102 @@ +import pytest + +from careamics.config.algorithm_model import AlgorithmModel +from careamics.config.support import ( + SupportedAlgorithm, + SupportedArchitecture, + SupportedLoss, +) + + +def test_all_algorithms_are_supported(): + """Test that all algorithms defined in the Literal are supported.""" + # list of supported algorithms + algorithms = list(SupportedAlgorithm) + + # Algorithm json schema to extract the literal value + schema = AlgorithmModel.model_json_schema() + + # check that all algorithms are supported + for algo in schema["properties"]["algorithm"]["enum"]: + assert algo in algorithms + + +def test_supported_losses(minimum_algorithm_custom): + """Test that all supported losses are accepted by the AlgorithmModel.""" + for loss in SupportedLoss: + minimum_algorithm_custom["loss"] = 
loss.value + AlgorithmModel(**minimum_algorithm_custom) + + +def test_all_losses_are_supported(): + """Test that all losses defined in the Literal are supported.""" + # list of supported losses + losses = list(SupportedLoss) + + # Algorithm json schema + schema = AlgorithmModel.model_json_schema() + + # check that all losses are supported + for loss in schema["properties"]["loss"]["enum"]: + assert loss in losses + + +def test_model_discriminator(minimum_algorithm_n2v): + """Test that discriminator permits correct assignment.""" + for model_name in SupportedArchitecture: + # TODO change once VAE are implemented + if model_name.value == "UNet": + minimum_algorithm_n2v["model"]["architecture"] = model_name.value + + algo = AlgorithmModel(**minimum_algorithm_n2v) + assert algo.model.architecture == model_name.value + + +@pytest.mark.parametrize( + "algorithm, loss, model", + [ + ("n2v", "n2v", {"architecture": "UNet", "n2v2": False}), + ("custom", "mae", {"architecture": "UNet", "n2v2": True}), + ], +) +def test_algorithm_constraints(algorithm: str, loss: str, model: dict): + """Test that constraints are passed for each algorithm.""" + AlgorithmModel(algorithm=algorithm, loss=loss, model=model) + + +@pytest.mark.parametrize("algorithm", ["n2v", "n2n"]) +def test_n_channels_n2v_and_n2n(algorithm): + """Check that an error is raised if n2v and n2n have different number of channels in + input and output.""" + model = { + "architecture": "UNet", + "in_channels": 1, + "num_classes": 2, + "n2v2": False, + } + loss = "mae" if algorithm == "n2n" else "n2v" + + with pytest.raises(ValueError): + AlgorithmModel(algorithm=algorithm, loss=loss, model=model) + + +@pytest.mark.parametrize( + "algorithm, n_in, n_out", + [ + ("n2v", 2, 2), + ("n2n", 3, 3), + ("care", 1, 2), + ], +) +def test_comaptiblity_of_number_of_channels(algorithm, n_in, n_out): + """Check that no error is thrown when instantiating the algorithm with a valid + number of in and out channels.""" + model = { + 
"architecture": "UNet", + "in_channels": n_in, + "num_classes": n_out, + "n2v2": False, + } + loss = "n2v" if algorithm == "n2v" else "mae" + + AlgorithmModel(algorithm=algorithm, loss=loss, model=model) diff --git a/tests/config/test_config.py b/tests/config/test_config.py deleted file mode 100644 index 7cca742f..00000000 --- a/tests/config/test_config.py +++ /dev/null @@ -1,196 +0,0 @@ -from pathlib import Path - -import pytest - -from careamics.config import ( - Configuration, - load_configuration, - save_configuration, -) - - -@pytest.mark.parametrize("name", ["Sn4K3", "C4_M e-L"]) -def test_config_valid_names(minimum_config: dict, name: str): - """Test valid names (letters, numbers, spaces, dashes and underscores).""" - minimum_config["experiment_name"] = name - myconf = Configuration(**minimum_config) - assert myconf.experiment_name == name - - -@pytest.mark.parametrize("name", ["", " ", "#", "/", "^", "%", ",", ".", "a=b"]) -def test_config_invalid_names(minimum_config: dict, name: str): - """Test that invalid names raise an error.""" - minimum_config["experiment_name"] = name - with pytest.raises(ValueError): - Configuration(**minimum_config) - - -@pytest.mark.parametrize("path", ["", "tmp"]) -def test_config_valid_working_directory( - tmp_path: Path, minimum_config: dict, path: str -): - """Test valid working directory. - - A valid working directory exists or its direct parent exists. - """ - path = tmp_path / path - minimum_config["working_directory"] = str(path) - myconf = Configuration(**minimum_config) - assert myconf.working_directory == path - - -def test_config_invalid_working_directory(tmp_path: Path, minimum_config: dict): - """Test that invalid working directory raise an error. - - Since its parent does not exist, this case is invalid. 
- """ - path = tmp_path / "tmp" / "tmp" - minimum_config["working_directory"] = str(path) - with pytest.raises(ValueError): - Configuration(**minimum_config) - - path = tmp_path / "tmp.txt" - path.touch() - minimum_config["working_directory"] = str(path) - with pytest.raises(ValueError): - Configuration(**minimum_config) - - -def test_3D_algorithm_and_data_compatibility(minimum_config: dict): - """Test that errors are raised if algithm `is_3D` and data axes are incompatible.""" - # 3D but no Z in axes - minimum_config["algorithm"]["is_3D"] = True - with pytest.raises(ValueError): - Configuration(**minimum_config) - - # 2D but Z in axes - minimum_config["algorithm"]["is_3D"] = False - minimum_config["data"]["axes"] = "ZYX" - with pytest.raises(ValueError): - Configuration(**minimum_config) - - -def test_set_3D(minimum_config: dict): - """Test the set 3D method.""" - conf = Configuration(**minimum_config) - - # set to 3D - conf.set_3D(True, "ZYX") - - # set to 2D - conf.set_3D(False, "SYX") - - # fails if they are not compatible - with pytest.raises(ValueError): - conf.set_3D(True, "SYX") - - with pytest.raises(ValueError): - conf.set_3D(False, "ZYX") - - -def test_wrong_values_by_assignment(complete_config: dict): - """Test that wrong values raise an error when assigned.""" - config = Configuration(**complete_config) - - # experiment name - config.experiment_name = "My name is Inigo Montoya" - with pytest.raises(ValueError): - config.experiment_name = "ยฏ\\_(ใƒ„)_/ยฏ" - - # working directory - config.working_directory = complete_config["working_directory"] - with pytest.raises(ValueError): - config.working_directory = "o/o" - - # data - config.data = complete_config["data"] - with pytest.raises(ValueError): - config.data = "I am not a data model" - - # algorithm - config.algorithm = complete_config["algorithm"] - with pytest.raises(ValueError): - config.algorithm = None - - # training - config.training = complete_config["training"] - with pytest.raises(ValueError): 
- config.training = "Hubert Blaine Wolfeschlegelsteinhausenbergerdorff Sr." - - # TODO Because algorithm is a sub-model of Configuration, and the validation is - # done at the level of the Configuration, this does not cause any error, although - # it should. - config.algorithm.is_3D = True - - -def test_minimum_config(minimum_config: dict): - """Test that we can instantiate a minimum config.""" - dictionary = Configuration(**minimum_config).model_dump() - assert dictionary == minimum_config - - -def test_complete_config(complete_config: dict): - """Test that we can instantiate a minimum config.""" - dictionary = Configuration(**complete_config).model_dump() - assert dictionary == complete_config - - -def test_config_to_dict_with_default_optionals(complete_config: dict): - """Test that the exclude optional options in model dump gives a full configuration, - including the default optional values. - - Note that None values are always excluded. - """ - # Algorithm default optional parameters - complete_config["algorithm"]["masking_strategy"] = "default" - complete_config["algorithm"]["masked_pixel_percentage"] = 0.2 - complete_config["algorithm"]["model_parameters"] = { - "depth": 2, - "num_channels_init": 32, - } - - # Training default optional parameters - complete_config["training"]["optimizer"]["parameters"] = {} - complete_config["training"]["lr_scheduler"]["parameters"] = {} - complete_config["training"]["use_wandb"] = True - complete_config["training"]["num_workers"] = 0 - complete_config["training"]["amp"] = { - "use": True, - "init_scale": 1024, - } - - # instantiate config - myconf = Configuration(**complete_config) - assert myconf.model_dump(exclude_optionals=False) == complete_config - - -def test_config_to_yaml(tmp_path: Path, minimum_config: dict): - """Test that we can export a config to yaml and load it back""" - - # test that we can instantiate a config - myconf = Configuration(**minimum_config) - - # export to yaml - yaml_path = 
save_configuration(myconf, tmp_path) - assert yaml_path.exists() - - # load from yaml - my_other_conf = load_configuration(yaml_path) - assert my_other_conf == myconf - - -def test_config_to_yaml_wrong_path(tmp_path: Path, minimum_config: dict): - """Test that an error is raised when the path is not a directory and not a .yml""" - - # test that we can instantiate a config - myconf = Configuration(**minimum_config) - - # export to yaml - yaml_path = tmp_path / "tmp.txt" - with pytest.raises(ValueError): - save_configuration(myconf, yaml_path) - - # existing file - yaml_path.touch() - with pytest.raises(ValueError): - save_configuration(myconf, yaml_path) diff --git a/tests/config/test_config_filters.py b/tests/config/test_config_filters.py deleted file mode 100644 index c2eac5c3..00000000 --- a/tests/config/test_config_filters.py +++ /dev/null @@ -1,42 +0,0 @@ -from pathlib import Path - -from careamics.config.config_filter import ( - paths_to_str, - remove_default_optionals, -) - - -def test_paths_to_str(): - """Test paths_to_str.""" - dictionary = { - "path1": Path("path1"), - "path2": "path2", - "path3": Path("path3"), - "path4": 3, - } - dictionary = paths_to_str(dictionary) - assert isinstance(dictionary["path1"], str) - assert isinstance(dictionary["path2"], str) - assert isinstance(dictionary["path3"], str) - assert isinstance(dictionary["path4"], int) - - -def test_remove_default_optionals(): - """Test remove default optionals.""" - dictionary = { - "key1": "value1", - "key2": 2, - "key3": "value3", - "key4": 5.5, - } - default = { - "key1": "value6", - "key2": 2, - "key3": "value3", - } - - remove_default_optionals(dictionary, default) - assert dictionary["key1"] == "value1" - assert "key2" not in dictionary.keys() - assert "key3" not in dictionary.keys() - assert dictionary["key4"] == 5.5 diff --git a/tests/config/test_configuration_factory.py b/tests/config/test_configuration_factory.py new file mode 100644 index 00000000..711c4375 --- /dev/null +++ 
b/tests/config/test_configuration_factory.py @@ -0,0 +1,222 @@ +import pytest + +from careamics.config import create_n2v_configuration +from careamics.config.support import ( + SupportedPixelManipulation, + SupportedStructAxis, + SupportedTransform, +) + + +def test_n2v_configuration(): + """Test that N2V configuration can be created.""" + config = create_n2v_configuration( + experiment_name="test", + data_type="tiff", + axes="YX", + patch_size=[64, 64], + batch_size=8, + num_epochs=100, + ) + assert ( + config.data_config.transforms[-1].name + == SupportedTransform.N2V_MANIPULATE.value + ) + assert ( + config.data_config.transforms[-1].strategy + == SupportedPixelManipulation.UNIFORM.value + ) + assert not config.data_config.transforms[-2].is_3D # XY_RANDOM_ROTATE90 + assert not config.data_config.transforms[-3].is_3D # NDFLIP + assert not config.algorithm_config.model.is_3D() + + +def test_n2v_3d_configuration(): + """Test that N2V configuration can be created in 3D.""" + config = create_n2v_configuration( + experiment_name="test", + data_type="tiff", + axes="ZYX", + patch_size=[64, 64, 64], + batch_size=8, + num_epochs=100, + ) + assert ( + config.data_config.transforms[-1].name + == SupportedTransform.N2V_MANIPULATE.value + ) + assert ( + config.data_config.transforms[-1].strategy + == SupportedPixelManipulation.UNIFORM.value + ) + assert config.data_config.transforms[-2].is_3D # XY_RANDOM_ROTATE90 + assert config.data_config.transforms[-3].is_3D # NDFLIP + assert config.algorithm_config.model.is_3D() + + +def test_n2v_3d_error(): + """Test that errors are raised if algorithm `is_3D` and data axes are + incompatible.""" + with pytest.raises(ValueError): + create_n2v_configuration( + experiment_name="test", + data_type="tiff", + axes="ZYX", + patch_size=[64, 64], + batch_size=8, + num_epochs=100, + ) + + with pytest.raises(ValueError): + create_n2v_configuration( + experiment_name="test", + data_type="tiff", + axes="YX", + patch_size=[64, 64, 64], + 
batch_size=8, + num_epochs=100, + ) + + +def test_n2v_model_parameters(): + """Test passing N2V UNet parameters, and that explicit parameters override the + model_kwargs ones.""" + config = create_n2v_configuration( + experiment_name="test", + data_type="tiff", + axes="YX", + patch_size=[64, 64], + batch_size=8, + num_epochs=100, + use_n2v2=False, + model_kwargs={ + "depth": 4, + "n2v2": True, + "in_channels": 2, + "num_classes": 5, + }, + ) + assert config.algorithm_config.model.depth == 4 + assert not config.algorithm_config.model.n2v2 + + # set to 1 because no C specified + assert config.algorithm_config.model.in_channels == 1 + assert config.algorithm_config.model.num_classes == 1 + + +def test_n2v_model_parameters_channels(): + """Test that the number of channels in the function call has priority over the + model kwargs.""" + config = create_n2v_configuration( + experiment_name="test", + data_type="tiff", + axes="YXC", + patch_size=[64, 64], + batch_size=8, + num_epochs=100, + n_channels=4, + model_kwargs={ + "depth": 4, + "n2v2": True, + "in_channels": 2, + "num_classes": 5, + }, + ) + assert config.algorithm_config.model.in_channels == 4 + assert config.algorithm_config.model.num_classes == 4 + + +def test_n2v_model_parameters_channels_error(): + """Test that an error is raised if the number of channels is not specified and + C in axes, or C in axes and number of channels not specified.""" + with pytest.raises(ValueError): + create_n2v_configuration( + experiment_name="test", + data_type="tiff", + axes="YXC", + patch_size=[64, 64], + batch_size=8, + num_epochs=100, + ) + + with pytest.raises(ValueError): + create_n2v_configuration( + experiment_name="test", + data_type="tiff", + axes="YX", + patch_size=[64, 64], + batch_size=8, + num_epochs=100, + n_channels=5, + ) + + +def test_n2v_no_aug(): + """Test that N2V configuration can be created without augmentation.""" + config = create_n2v_configuration( + experiment_name="test", + data_type="tiff", + axes="YX", 
+ patch_size=[64, 64], + batch_size=8, + num_epochs=100, + use_augmentations=False, + ) + assert len(config.data_config.transforms) == 2 + assert ( + config.data_config.transforms[-1].name + == SupportedTransform.N2V_MANIPULATE.value + ) + assert config.data_config.transforms[-2].name == SupportedTransform.NORMALIZE.value + + +def test_n2v_augmentation_parameters(): + """Test that N2V configuration can be created with augmentation parameters.""" + config = create_n2v_configuration( + experiment_name="test", + data_type="tiff", + axes="YX", + patch_size=[64, 64], + batch_size=8, + num_epochs=100, + roi_size=17, + masked_pixel_percentage=0.5, + ) + assert config.data_config.transforms[-1].roi_size == 17 + assert config.data_config.transforms[-1].masked_pixel_percentage == 0.5 + + +def test_n2v2(): + """Test that N2V2 configuration can be created.""" + config = create_n2v_configuration( + experiment_name="test", + data_type="tiff", + axes="YX", + patch_size=[64, 64], + batch_size=8, + num_epochs=100, + use_n2v2=True, + ) + assert ( + config.data_config.transforms[-1].strategy + == SupportedPixelManipulation.MEDIAN.value + ) + + +def test_structn2v(): + """Test that StructN2V configuration can be created.""" + config = create_n2v_configuration( + experiment_name="test", + data_type="tiff", + axes="YX", + patch_size=[64, 64], + batch_size=8, + num_epochs=100, + struct_n2v_axis=SupportedStructAxis.HORIZONTAL.value, + struct_n2v_span=7, + ) + assert ( + config.data_config.transforms[-1].struct_mask_axis + == SupportedStructAxis.HORIZONTAL.value + ) + assert config.data_config.transforms[-1].struct_mask_span == 7 diff --git a/tests/config/test_configuration_model.py b/tests/config/test_configuration_model.py new file mode 100644 index 00000000..353184cf --- /dev/null +++ b/tests/config/test_configuration_model.py @@ -0,0 +1,191 @@ +from pathlib import Path + +import pytest + +from careamics.config import ( + Configuration, + load_configuration, + save_configuration, +) 
+from careamics.config.support import ( + SupportedAlgorithm, + SupportedPixelManipulation, + SupportedTransform, +) + + +@pytest.mark.parametrize("name", ["Sn4K3", "C4_M e-L"]) +def test_valid_names(minimum_configuration: dict, name: str): + """Test valid names (letters, numbers, spaces, dashes and underscores).""" + minimum_configuration["experiment_name"] = name + myconf = Configuration(**minimum_configuration) + assert myconf.experiment_name == name + + +@pytest.mark.parametrize("name", ["", " ", "#", "/", "^", "%", ",", ".", "a=b"]) +def test_invalid_names(minimum_configuration: dict, name: str): + """Test that invalid names raise an error.""" + minimum_configuration["experiment_name"] = name + with pytest.raises(ValueError): + Configuration(**minimum_configuration) + + +def test_3D_algorithm_and_data_compatibility(minimum_configuration: dict): + """Test that errors are raised if algorithm `is_3D` and data axes are + incompatible. + """ + # 3D but no Z in axes + minimum_configuration["algorithm_config"]["model"]["conv_dims"] = 3 + config = Configuration(**minimum_configuration) + assert config.algorithm_config.model.conv_dims == 2 + + # 2D but Z in axes + minimum_configuration["algorithm_config"]["model"]["conv_dims"] = 2 + minimum_configuration["data_config"]["axes"] = "ZYX" + minimum_configuration["data_config"]["patch_size"] = [64, 64, 64] + config = Configuration(**minimum_configuration) + assert config.algorithm_config.model.conv_dims == 3 + + +def test_set_3D(minimum_configuration: dict): + """Test the set 3D method.""" + conf = Configuration(**minimum_configuration) + + # set to 3D + conf.set_3D(True, "ZYX", [64, 64, 64]) + assert conf.data_config.axes == "ZYX" + assert conf.data_config.patch_size == [64, 64, 64] + assert conf.algorithm_config.model.conv_dims == 3 + + # set to 2D + conf.set_3D(False, "SYX", [64, 64]) + assert conf.data_config.axes == "SYX" + assert conf.data_config.patch_size == [64, 64] + assert conf.algorithm_config.model.conv_dims == 
2 + + +def test_algorithm_and_data_default_transforms(minimum_configuration: dict): + """Test that the default data transforms are compatible with n2v.""" + minimum_configuration["algorithm_config"] = { + "algorithm": "n2v", + "loss": "n2v", + "model": { + "architecture": "UNet", + }, + } + Configuration(**minimum_configuration) + + +@pytest.mark.parametrize( + "algorithm, strategy", + [ + ("n2v", SupportedPixelManipulation.UNIFORM.value), + ("n2v", SupportedPixelManipulation.MEDIAN.value), + ("n2v2", SupportedPixelManipulation.UNIFORM.value), + ("n2v2", SupportedPixelManipulation.MEDIAN.value), + ], +) +def test_n2v2_and_transforms(minimum_configuration: dict, algorithm, strategy): + """Test that the manipulation strategy is corrected if the data transforms are + incompatible with n2v2.""" + use_n2v2 = algorithm == "n2v2" + minimum_configuration["algorithm_config"] = { + "algorithm": "n2v", + "loss": "n2v", + "model": { + "architecture": "UNet", + "n2v2": use_n2v2, + }, + } + + expected_strategy = ( + SupportedPixelManipulation.MEDIAN.value + if use_n2v2 + else SupportedPixelManipulation.UNIFORM.value + ) + + # missing ManipulateN2V + minimum_configuration["data_config"]["transforms"] = [ + {"name": SupportedTransform.NDFLIP.value} + ] + config = Configuration(**minimum_configuration) + assert len(config.data_config.transforms) == 2 + assert ( + config.data_config.transforms[-1].name + == SupportedTransform.N2V_MANIPULATE.value + ) + assert config.data_config.transforms[-1].strategy == expected_strategy + + # passing ManipulateN2V with the wrong strategy + minimum_configuration["data_config"]["transforms"] = [ + { + "name": SupportedTransform.N2V_MANIPULATE.value, + "strategy": strategy, + } + ] + config = Configuration(**minimum_configuration) + assert config.data_config.transforms[-1].strategy == expected_strategy + + +def test_setting_n2v2(minimum_configuration: dict): + # make sure we use n2v + minimum_configuration["algorithm_config"][ + "algorithm" + ] = 
SupportedAlgorithm.N2V.value + + # test config + config = Configuration(**minimum_configuration) + assert config.algorithm_config.algorithm == SupportedAlgorithm.N2V.value + assert not config.algorithm_config.model.n2v2 + assert ( + config.data_config.transforms[-1].strategy + == SupportedPixelManipulation.UNIFORM.value + ) + + # set N2V2 + config.set_N2V2(True) + assert config.algorithm_config.model.n2v2 + assert ( + config.data_config.transforms[-1].strategy + == SupportedPixelManipulation.MEDIAN.value + ) + + # set back to N2V + config.set_N2V2(False) + assert not config.algorithm_config.model.n2v2 + assert ( + config.data_config.transforms[-1].strategy + == SupportedPixelManipulation.UNIFORM.value + ) + + +def test_config_to_yaml(tmp_path: Path, minimum_configuration: dict): + """Test that we can export a config to yaml and load it back""" + + # test that we can instantiate a config + myconf = Configuration(**minimum_configuration) + + # export to yaml + yaml_path = save_configuration(myconf, tmp_path) + assert yaml_path.exists() + + # load from yaml + my_other_conf = load_configuration(yaml_path) + assert my_other_conf == myconf + + +def test_config_to_yaml_wrong_path(tmp_path: Path, minimum_configuration: dict): + """Test that an error is raised when the path is not a directory and not a .yml""" + + # test that we can instantiate a config + myconf = Configuration(**minimum_configuration) + + # export to yaml + yaml_path = tmp_path / "tmp.txt" + with pytest.raises(ValueError): + save_configuration(myconf, yaml_path) + + # existing file + yaml_path.touch() + with pytest.raises(ValueError): + save_configuration(myconf, yaml_path) diff --git a/tests/config/test_data.py b/tests/config/test_data.py deleted file mode 100644 index 67f6923c..00000000 --- a/tests/config/test_data.py +++ /dev/null @@ -1,142 +0,0 @@ -import pytest - -from careamics.config.data import Data, SupportedExtension - - -@pytest.mark.parametrize("ext", ["tiff", "tif", "TIFF", "TIF", ".TIF"]) 
-def test_supported_extensions_case_insensitive(ext: str): - """Test that SupportedExtension enum accepts all extensions in upper - cases and with .""" - sup_ext = SupportedExtension(ext) - - new_ext = ext.lower() - if ext.startswith("."): - new_ext = new_ext[1:] - - assert sup_ext.value == new_ext - - -@pytest.mark.parametrize("ext", ["nd2", "jpg", "png ", "zarr", "npy"]) -def test_wrong_extensions(minimum_config: dict, ext: str): - """Test that supported model raises ValueError for unsupported extensions.""" - data_config = minimum_config["data"] - data_config["data_format"] = ext - - # instantiate Data model - with pytest.raises(ValueError): - Data(**data_config) - - -@pytest.mark.parametrize("mean, std", [(0, 124.5), (12.6, 0.1)]) -def test_mean_std_non_negative(complete_config: dict, mean, std): - """Test that non negative mean and std are accepted.""" - complete_config["data"]["mean"] = mean - complete_config["data"]["std"] = std - - data_model = Data(**complete_config["data"]) - assert data_model.mean == mean - assert data_model.std == std - - -def test_mean_std_negative(complete_config: dict): - """Test that negative mean and std are not accepted.""" - complete_config["data"]["mean"] = -1 - complete_config["data"]["std"] = 10.4 - - with pytest.raises(ValueError): - Data(**complete_config["data"]) - - complete_config["data"]["mean"] = 10.4 - complete_config["data"]["std"] = -1 - - with pytest.raises(ValueError): - Data(**complete_config["data"]) - - -def test_mean_std_both_specified_or_none(complete_config: dict): - """Test an error is raised if std is specified but mean is None.""" - # No error if both are None - complete_config["data"].pop("mean") - complete_config["data"].pop("std") - Data(**complete_config["data"]) - - # No error if mean is defined - complete_config["data"]["mean"] = 10.4 - Data(**complete_config["data"]) - - # Error if only std is defined - complete_config["data"].pop("mean") - complete_config["data"]["std"] = 10.4 - - with 
pytest.raises(ValueError): - Data(**complete_config["data"]) - - -def test_set_mean_and_std(complete_config: dict): - """Test that mean and std can be set after initialization.""" - # they can be set both, when they are already set - data = Data(**complete_config["data"]) - data.set_mean_and_std(4.07, 14.07) - - # and if they are both None - complete_config["data"].pop("mean") - complete_config["data"].pop("std") - data = Data(**complete_config["data"]) - data.set_mean_and_std(10.4, 0.5) - - -def test_wrong_values_by_assigment(complete_config: dict): - """Test that wrong values are not accepted through assignment.""" - data_model = Data(**complete_config["data"]) - - # in memory - data_model.in_memory = complete_config["data"]["in_memory"] - with pytest.raises(ValueError): - data_model.in_memory = "Trues" - - # data format - data_model.data_format = complete_config["data"]["data_format"] # check assignment - with pytest.raises(ValueError): - data_model.data_format = "png" - - # axes - data_model.axes = complete_config["data"]["axes"] - with pytest.raises(ValueError): - data_model.axes = "-YX" - - # mean - data_model.mean = complete_config["data"]["mean"] - with pytest.raises(ValueError): - data_model.mean = -1 - - # std - data_model.std = complete_config["data"]["std"] - with pytest.raises(ValueError): - data_model.std = -1 - - -def test_data_to_dict_minimum(minimum_config: dict): - """Test that export to dict does not include None values and Paths. 
- - In the minimum config, only training+validation should be defined, all the other - paths are None.""" - data_minimum = Data(**minimum_config["data"]).model_dump() - assert data_minimum == minimum_config["data"] - - assert "in_memory" in data_minimum.keys() - assert "data_format" in data_minimum.keys() - assert "axes" in data_minimum.keys() - assert "mean" not in data_minimum.keys() - assert "std" not in data_minimum.keys() - - -def test_data_to_dict_complete(complete_config: dict): - """Test that export to dict does not include None values and Paths.""" - data_complete = Data(**complete_config["data"]).model_dump() - assert data_complete == complete_config["data"] - - assert "in_memory" in data_complete.keys() - assert "data_format" in data_complete.keys() - assert "axes" in data_complete.keys() - assert "mean" in data_complete.keys() - assert "std" in data_complete.keys() diff --git a/tests/config/test_data_model.py b/tests/config/test_data_model.py new file mode 100644 index 00000000..4e59d776 --- /dev/null +++ b/tests/config/test_data_model.py @@ -0,0 +1,382 @@ +import pytest +from albumentations import Compose + +from careamics.config.data_model import DataModel +from careamics.config.support import ( + SupportedPixelManipulation, + SupportedStructAxis, + SupportedTransform, +) +from careamics.config.transformations import ( + N2VManipulateModel, + NDFlipModel, + NormalizeModel, + XYRandomRotate90Model, +) +from careamics.transforms import get_all_transforms + + +@pytest.mark.parametrize("ext", ["nd2", "jpg", "png ", "zarr", "npy"]) +def test_wrong_extensions(minimum_data: dict, ext: str): + """Test that supported model raises ValueError for unsupported extensions.""" + minimum_data["data_type"] = ext + + # instantiate DataModel model + with pytest.raises(ValueError): + DataModel(**minimum_data) + + +@pytest.mark.parametrize("mean, std", [(0, 124.5), (12.6, 0.1)]) +def test_mean_std_non_negative(minimum_data: dict, mean, std): + """Test that non negative 
mean and std are accepted.""" + minimum_data["mean"] = mean + minimum_data["std"] = std + + data_model = DataModel(**minimum_data) + assert data_model.mean == mean + assert data_model.std == std + + +def test_mean_std_both_specified_or_none(minimum_data: dict): + """Test an error is raised if std is specified but mean is None.""" + # No error if both are None + DataModel(**minimum_data) + + # Error if only mean is defined + minimum_data["mean"] = 10.4 + with pytest.raises(ValueError): + DataModel(**minimum_data) + + # Error if only std is defined + minimum_data.pop("mean") + minimum_data["std"] = 10.4 + with pytest.raises(ValueError): + DataModel(**minimum_data) + + # No error if both are specified + minimum_data["mean"] = 10.4 + minimum_data["std"] = 10.4 + DataModel(**minimum_data) + + +def test_set_mean_and_std(minimum_data: dict): + """Test that mean and std can be set after initialization.""" + # they can be set both, when they None + mean = 4.07 + std = 14.07 + data = DataModel(**minimum_data) + data.set_mean_and_std(mean, std) + assert data.mean == mean + assert data.std == std + + # and if they are already set + minimum_data["mean"] = 10.4 + minimum_data["std"] = 3.2 + data = DataModel(**minimum_data) + data.set_mean_and_std(mean, std) + assert data.mean == mean + assert data.std == std + + +def test_mean_and_std_in_normalize(minimum_data: dict): + """Test that mean and std are added to the Normalize transform.""" + minimum_data["mean"] = 10.4 + minimum_data["std"] = 3.2 + minimum_data["transforms"] = [ + {"name": SupportedTransform.NORMALIZE.value}, + ] + data = DataModel(**minimum_data) + assert data.transforms[0].mean == 10.4 + assert data.transforms[0].std == 3.2 + + +def test_patch_size(minimum_data: dict): + """Test that non-zero even patch size are accepted.""" + # 2D + data_model = DataModel(**minimum_data) + + # 3D + minimum_data["patch_size"] = [16, 8, 8] + minimum_data["axes"] = "ZYX" + + data_model = DataModel(**minimum_data) + assert 
data_model.patch_size == minimum_data["patch_size"] + + +@pytest.mark.parametrize( + "patch_size", [[12], [0, 12, 12], [12, 12, 13], [16, 10, 16], [12, 12, 12, 12]] +) +def test_wrong_patch_size(minimum_data: dict, patch_size): + """Test that wrong patch sizes are not accepted (zero or odd, dims 1 or > 3).""" + minimum_data["axes"] = "ZYX" if len(patch_size) == 3 else "YX" + minimum_data["patch_size"] = patch_size + + with pytest.raises(ValueError): + DataModel(**minimum_data) + + +def test_set_3d(minimum_data: dict): + """Test that 3D can be set.""" + data = DataModel(**minimum_data) + assert "Z" not in data.axes + assert len(data.patch_size) == 2 + + # error if changing Z manually + with pytest.raises(ValueError): + data.axes = "ZYX" + + # or patch size + data = DataModel(**minimum_data) + with pytest.raises(ValueError): + data.patch_size = [64, 64, 64] + + # set 3D + data = DataModel(**minimum_data) + data.set_3D("ZYX", [64, 64, 64]) + assert "Z" in data.axes + assert len(data.patch_size) == 3 + + +@pytest.mark.parametrize( + "transforms", + [ + [ + {"name": SupportedTransform.NDFLIP.value}, + {"name": SupportedTransform.N2V_MANIPULATE.value}, + ], + [ + {"name": SupportedTransform.NDFLIP.value}, + ], + [ + {"name": SupportedTransform.NORMALIZE.value}, + {"name": SupportedTransform.NDFLIP.value}, + {"name": SupportedTransform.XY_RANDOM_ROTATE90.value}, + {"name": SupportedTransform.N2V_MANIPULATE.value}, + ], + ], +) +def test_passing_supported_transforms(minimum_data: dict, transforms): + """Test that list of supported transforms can be passed.""" + minimum_data["transforms"] = transforms + model = DataModel(**minimum_data) + + supported = { + "NDFlip": NDFlipModel, + "XYRandomRotate90": XYRandomRotate90Model, + "Normalize": NormalizeModel, + "N2VManipulate": N2VManipulateModel, + } + + for ind, t in enumerate(transforms): + assert t["name"] == model.transforms[ind].name + assert isinstance(model.transforms[ind], supported[t["name"]]) + + 
+@pytest.mark.parametrize( + "transforms", + [ + [ + {"name": SupportedTransform.N2V_MANIPULATE.value}, + {"name": SupportedTransform.NDFLIP.value}, + ], + [ + {"name": SupportedTransform.N2V_MANIPULATE.value}, + ], + [ + {"name": SupportedTransform.NORMALIZE.value}, + {"name": SupportedTransform.NDFLIP.value}, + {"name": SupportedTransform.N2V_MANIPULATE.value}, + {"name": SupportedTransform.XY_RANDOM_ROTATE90.value}, + ], + ], +) +def test_n2vmanipulate_last_transform(minimum_data: dict, transforms): + """Test that N2V Manipulate is moved to the last position if it is not.""" + minimum_data["transforms"] = transforms + model = DataModel(**minimum_data) + assert model.transforms[-1].name == SupportedTransform.N2V_MANIPULATE.value + + +def test_multiple_n2v_manipulate(minimum_data: dict): + """Test that passing multiple n2v manipulate raises an error.""" + minimum_data["transforms"] = [ + {"name": SupportedTransform.N2V_MANIPULATE.value}, + {"name": SupportedTransform.N2V_MANIPULATE.value}, + ] + with pytest.raises(ValueError): + DataModel(**minimum_data) + + +def test_remove_n2v_manipulate(minimum_data: dict): + """Test that N2V Manipulate can be removed.""" + minimum_data["transforms"] = [ + {"name": SupportedTransform.NDFLIP.value}, + {"name": SupportedTransform.N2V_MANIPULATE.value}, + ] + model = DataModel(**minimum_data) + model.remove_n2v_manipulate() + assert len(model.transforms) == 1 + assert model.transforms[-1].name == SupportedTransform.NDFLIP.value + + +def test_add_n2v_manipulate(minimum_data: dict): + """Test that N2V Manipulate can be added.""" + minimum_data["transforms"] = [ + {"name": SupportedTransform.NDFLIP.value}, + ] + model = DataModel(**minimum_data) + model.add_n2v_manipulate() + assert len(model.transforms) == 2 + assert model.transforms[-1].name == SupportedTransform.N2V_MANIPULATE.value + + # test that adding twice doesn't change anything + model.add_n2v_manipulate() + assert len(model.transforms) == 2 + assert 
model.transforms[-1].name == SupportedTransform.N2V_MANIPULATE.value + + +def test_correct_transform_parameters(minimum_data: dict): + """Test that the transforms have the correct parameters. + + This is important to know that the transforms are not all instantiated as + a generic transform. + """ + minimum_data["transforms"] = [ + {"name": SupportedTransform.NORMALIZE.value}, + {"name": SupportedTransform.NDFLIP.value}, + {"name": SupportedTransform.XY_RANDOM_ROTATE90.value}, + {"name": SupportedTransform.N2V_MANIPULATE.value}, + ] + model = DataModel(**minimum_data) + + # Normalize + params = model.transforms[0].model_dump() + assert "mean" in params + assert "std" in params + + # NDFlip + params = model.transforms[1].model_dump() + assert "p" in params + assert "is_3D" in params + assert "flip_z" in params + + # XYRandomRotate90 + params = model.transforms[2].model_dump() + assert "p" in params + assert "is_3D" in params + assert isinstance(model.transforms[2], XYRandomRotate90Model) + + # N2VManipulate + params = model.transforms[3].model_dump() + assert "roi_size" in params + assert "masked_pixel_percentage" in params + assert "strategy" in params + assert "struct_mask_axis" in params + assert "struct_mask_span" in params + + +def test_passing_empty_transforms(minimum_data: dict): + """Test that empty list of transforms can be passed.""" + minimum_data["transforms"] = [] + DataModel(**minimum_data) + + +def test_passing_incorrect_element(minimum_data: dict): + """Test that incorrect element in the list of transforms raises an error ( + e.g. 
passing un object rather than a string).""" + minimum_data["transforms"] = [ + {"name": get_all_transforms()[SupportedTransform.NDFLIP.value]()}, + ] + with pytest.raises(ValueError): + DataModel(**minimum_data) + + +def test_passing_compose_transform(minimum_data: dict): + """Test that Compose transform can be passed.""" + minimum_data["transforms"] = Compose( + [ + get_all_transforms()[SupportedTransform.NDFLIP](), + get_all_transforms()[SupportedTransform.N2V_MANIPULATE](), + ] + ) + DataModel(**minimum_data) + + +def test_3D_and_transforms(minimum_data: dict): + """Test that NDFlip is corrected if the data is 3D.""" + minimum_data["transforms"] = [ + { + "name": SupportedTransform.NDFLIP.value, + "is_3D": True, + "flip_z": True, + }, + { + "name": SupportedTransform.XY_RANDOM_ROTATE90.value, + "is_3D": True, + }, + ] + data = DataModel(**minimum_data) + assert data.transforms[0].is_3D is False + assert data.transforms[1].is_3D is False + + # change to 3D + data.set_3D("ZYX", [64, 64, 64]) + data.transforms[0].is_3D = True + data.transforms[1].is_3D = True + + +def test_set_n2v_strategy(minimum_data: dict): + """Test that the N2V strategy can be set.""" + uniform = SupportedPixelManipulation.UNIFORM.value + median = SupportedPixelManipulation.MEDIAN.value + + data = DataModel(**minimum_data) + assert data.transforms[-1].name == SupportedTransform.N2V_MANIPULATE.value + assert data.transforms[-1].strategy == uniform + + data.set_N2V2_strategy(median) + assert data.transforms[-1].strategy == median + + data.set_N2V2_strategy(uniform) + assert data.transforms[-1].strategy == uniform + + +def test_set_n2v_strategy_wrong_value(minimum_data: dict): + """Test that passing a wrong strategy raises an error.""" + data = DataModel(**minimum_data) + with pytest.raises(ValueError): + data.set_N2V2_strategy("wrong_value") + + +def test_set_struct_mask(minimum_data: dict): + """Test that the struct mask can be set.""" + none = SupportedStructAxis.NONE.value + vertical = 
SupportedStructAxis.VERTICAL.value + horizontal = SupportedStructAxis.HORIZONTAL.value + + data = DataModel(**minimum_data) + assert data.transforms[-1].name == SupportedTransform.N2V_MANIPULATE.value + assert data.transforms[-1].struct_mask_axis == none + assert data.transforms[-1].struct_mask_span == 5 + + data.set_structN2V_mask(vertical, 3) + assert data.transforms[-1].struct_mask_axis == vertical + assert data.transforms[-1].struct_mask_span == 3 + + data.set_structN2V_mask(horizontal, 7) + assert data.transforms[-1].struct_mask_axis == horizontal + assert data.transforms[-1].struct_mask_span == 7 + + data.set_structN2V_mask(none, 11) + assert data.transforms[-1].struct_mask_axis == none + assert data.transforms[-1].struct_mask_span == 11 + + +def test_set_struct_mask_wrong_value(minimum_data: dict): + """Test that passing a wrong struct mask axis raises an error.""" + data = DataModel(**minimum_data) + with pytest.raises(ValueError): + data.set_structN2V_mask("wrong_value", 3) + + with pytest.raises(ValueError): + data.set_structN2V_mask(SupportedStructAxis.VERTICAL.value, 1) diff --git a/tests/config/test_inference_model.py b/tests/config/test_inference_model.py new file mode 100644 index 00000000..885a4816 --- /dev/null +++ b/tests/config/test_inference_model.py @@ -0,0 +1,185 @@ +import pytest +from albumentations import Compose + +from careamics.config.inference_model import InferenceModel +from careamics.config.support import ( + SupportedTransform, +) +from careamics.transforms import get_all_transforms + + +@pytest.mark.parametrize("ext", ["nd2", "jpg", "png ", "zarr", "npy"]) +def test_wrong_extensions(minimum_inference: dict, ext: str): + """Test that supported model raises ValueError for unsupported extensions.""" + minimum_inference["data_type"] = ext + + # instantiate InferenceModel model + with pytest.raises(ValueError): + InferenceModel(**minimum_inference) + + +def test_mean_std_both_specified_or_none(minimum_inference: dict): + """Test error 
raising when setting mean and std.""" + # Errors if both are None + minimum_inference["mean"] = None + minimum_inference["std"] = None + with pytest.raises(ValueError): + InferenceModel(**minimum_inference) + + # Error if only mean is defined + minimum_inference["mean"] = 10.4 + with pytest.raises(ValueError): + InferenceModel(**minimum_inference) + + # Error if only std is defined + minimum_inference.pop("mean") + minimum_inference["std"] = 10.4 + with pytest.raises(ValueError): + InferenceModel(**minimum_inference) + + # No error if both are specified + minimum_inference["mean"] = 10.4 + minimum_inference["std"] = 10.4 + InferenceModel(**minimum_inference) + + +def test_tile_size(minimum_inference: dict): + """Test that non-zero even patch size are accepted.""" + # no tiling + prediction_model = InferenceModel(**minimum_inference) + + # 2D + minimum_inference["tile_size"] = [16, 8] + minimum_inference["tile_overlap"] = [2, 2] + minimum_inference["axes"] = "YX" + + prediction_model = InferenceModel(**minimum_inference) + assert prediction_model.tile_size == minimum_inference["tile_size"] + assert prediction_model.tile_overlap == minimum_inference["tile_overlap"] + + # 3D + minimum_inference["tile_size"] = [16, 8, 32] + minimum_inference["tile_overlap"] = [2, 2, 2] + minimum_inference["axes"] = "ZYX" + + prediction_model = InferenceModel(**minimum_inference) + assert prediction_model.tile_size == minimum_inference["tile_size"] + assert prediction_model.tile_overlap == minimum_inference["tile_overlap"] + + +@pytest.mark.parametrize( + "tile_size", [[12], [0, 12, 12], [12, 12, 13], [12, 12, 12, 12]] +) +def test_wrong_tile_size(minimum_inference: dict, tile_size): + """Test that wrong patch sizes are not accepted (zero or odd, dims 1 or > 3).""" + minimum_inference["axes"] = "ZYX" if len(tile_size) == 3 else "YX" + minimum_inference["tile_size"] = tile_size + + with pytest.raises(ValueError): + InferenceModel(**minimum_inference) + + +@pytest.mark.parametrize( + 
"tile_size, tile_overlap", [([12, 12], [2, 2, 2]), ([12, 12, 12], [14, 2, 2])] +) +def test_wrong_tile_overlap(minimum_inference: dict, tile_size, tile_overlap): + """Test that wrong patch sizes are not accepted (zero or odd, dims 1 or > 3).""" + minimum_inference["axes"] = "ZYX" if len(tile_size) == 3 else "YX" + minimum_inference["tile_size"] = tile_size + minimum_inference["tile_overlap"] = tile_overlap + + with pytest.raises(ValueError): + InferenceModel(**minimum_inference) + + +def test_set_3d(minimum_inference: dict): + """Test that 3D can be set.""" + minimum_inference["tile_size"] = [64, 64] + minimum_inference["tile_overlap"] = [32, 32] + + pred = InferenceModel(**minimum_inference) + assert "Z" not in pred.axes + assert len(pred.tile_size) == 2 + assert len(pred.tile_overlap) == 2 + + # error if changing Z manually + with pytest.raises(ValueError): + pred.axes = "ZYX" + + # or patch size + pred = InferenceModel(**minimum_inference) + with pytest.raises(ValueError): + pred.tile_size = [64, 64, 64] + + with pytest.raises(ValueError): + pred.tile_overlap = [64, 64, 64] + + # set 3D + pred = InferenceModel(**minimum_inference) + pred.set_3D("ZYX", [64, 64, 64], [32, 32, 32]) + assert "Z" in pred.axes + assert len(pred.tile_size) == 3 + assert len(pred.tile_overlap) == 3 + + +@pytest.mark.parametrize( + "transforms", + [ + [ + {"name": SupportedTransform.NORMALIZE.value}, + ], + ], +) +def test_passing_supported_transforms(minimum_inference: dict, transforms): + """Test that list of supported transforms can be passed.""" + minimum_inference["transforms"] = transforms + InferenceModel(**minimum_inference) + + +def test_cannot_pass_n2v_manipulate(minimum_inference: dict): + """Test that passing N2V pixel manipulate transform raises an error.""" + minimum_inference["transforms"] = [ + {"name": SupportedTransform.N2V_MANIPULATE.value}, + ] + with pytest.raises(ValueError): + InferenceModel(**minimum_inference) + + +def 
test_passing_empty_transforms(minimum_inference: dict): + """Test that empty list of transforms can be passed.""" + minimum_inference["transforms"] = [] + InferenceModel(**minimum_inference) + + +def test_passing_incorrect_element(minimum_inference: dict): + """Test that incorrect element in the list of transforms raises an error ( + e.g. passing un object rather than a string).""" + minimum_inference["transforms"] = [ + {"name": get_all_transforms()[SupportedTransform.NDFLIP.value]()}, + ] + with pytest.raises(ValueError): + InferenceModel(**minimum_inference) + + +def test_passing_compose_transform(minimum_inference: dict): + """Test that Compose transform can be passed.""" + minimum_inference["transforms"] = Compose( + [ + get_all_transforms()[SupportedTransform.NORMALIZE](mean=10.4, std=3.2), + get_all_transforms()[SupportedTransform.NDFLIP](), + ] + ) + InferenceModel(**minimum_inference) + + +def test_mean_and_std_in_normalize(minimum_inference: dict): + """Test that mean and std are added to the Normalize transform.""" + minimum_inference["mean"] = 10.4 + minimum_inference["std"] = 3.2 + minimum_inference["transforms"] = [ + {"name": SupportedTransform.NORMALIZE.value}, + ] + + data = InferenceModel(**minimum_inference) + assert data.transforms[0].mean == 10.4 + assert data.transforms[0].std == 3.2 diff --git a/tests/config/test_optimizers_model.py b/tests/config/test_optimizers_model.py new file mode 100644 index 00000000..ba91e152 --- /dev/null +++ b/tests/config/test_optimizers_model.py @@ -0,0 +1,146 @@ +import pytest + +from careamics.config.optimizer_models import LrSchedulerModel, OptimizerModel +from careamics.config.support.supported_optimizers import ( + SupportedOptimizer, + SupportedScheduler, +) + + +@pytest.mark.parametrize( + "optimizer_name, parameters", + [ + ( + SupportedOptimizer.Adam.value, + { + "lr": 0.08, + "betas": (0.1, 0.11), + "eps": 6e-08, + "weight_decay": 0.2, + "amsgrad": True, + }, + ), + ( + SupportedOptimizer.SGD.value, + { 
+ "lr": 0.11, + "momentum": 5, + "dampening": 1, + "weight_decay": 8, + "nesterov": True, + }, + ), + ], +) +def test_optimizer_parameters(optimizer_name: SupportedOptimizer, parameters: dict): + """Test optimizer parameters filtering. + + For parameters, see: + https://pytorch.org/docs/stable/optim.html#algorithms + """ + # add non valid parameter + new_parameters = parameters.copy() + new_parameters["some_random_one"] = 42 + + # create optimizer and check that the parameters are filtered + optimizer = OptimizerModel(name=optimizer_name, parameters=new_parameters) + assert optimizer.parameters == parameters + + +def test_sgd_missing_parameter(): + """Test that SGD optimizer fails if `lr` is not provided. + + Note: The SGD optimizer requires the `lr` parameter. + """ + with pytest.raises(ValueError): + OptimizerModel(name=SupportedOptimizer.SGD.value, parameters={}) + + # test that it works if lr is provided + optimizer = OptimizerModel( + name=SupportedOptimizer.SGD.value, parameters={"lr": 0.1} + ) + assert optimizer.parameters == {"lr": 0.1} + + +def test_optimizer_wrong_values_by_assignments(): + """Test that wrong values cause an error during assignment.""" + optimizer = OptimizerModel( + name=SupportedOptimizer.Adam.value, parameters={"lr": 0.08} + ) + + # name + optimizer.name = SupportedOptimizer.SGD.value + with pytest.raises(ValueError): + optimizer.name = "MyOptim" + + # parameters + optimizer.parameters = {"lr": 0.1} + with pytest.raises(ValueError): + optimizer.parameters = "lr = 0.3" + + +def test_optimizer_to_dict_optional(): + """ "Test that export to dict includes optional values.""" + config = { + "name": "Adam", + "parameters": { + "lr": 0.1, + "betas": (0.1, 0.11), + }, + } + + optim_minimum = OptimizerModel(**config).model_dump() + assert optim_minimum == config + + +@pytest.mark.parametrize( + "lr_scheduler_name, parameters", + [ + ( + SupportedScheduler.ReduceLROnPlateau.value, + { + "mode": "max", + "factor": 0.3, + "patience": 5, + 
"threshold": 0.003, + "threshold_mode": "abs", + "cooldown": 3, + "min_lr": 0.1, + "eps": 5e-08, + }, + ), + ( + SupportedScheduler.StepLR.value, + { + "step_size": 2, + "gamma": 0.3, + "last_epoch": -5, + }, + ), + ], +) +def test_scheduler_parameters(lr_scheduler_name: SupportedScheduler, parameters: dict): + """Test lr scheduler parameters filtering. + + For parameters, see: + https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate + """ + # add non valid parameter + new_parameters = parameters.copy() + new_parameters["some_random_one"] = 42 + + # create optimizer and check that the parameters are filtered + lr_scheduler = LrSchedulerModel(name=lr_scheduler_name, parameters=new_parameters) + assert lr_scheduler.parameters == parameters + + +def test_scheduler_missing_parameter(): + """Test that StepLR scheduler fails if `step_size` is not provided""" + with pytest.raises(ValueError): + LrSchedulerModel(name=SupportedScheduler.StepLR.value, parameters={}) + + # test that it works if lr is provided + lr_scheduler = LrSchedulerModel( + name=SupportedScheduler.StepLR.value, parameters={"step_size": "5"} + ) + assert lr_scheduler.parameters == {"step_size": "5"} diff --git a/tests/config/test_tile_information.py b/tests/config/test_tile_information.py new file mode 100644 index 00000000..78b24cc8 --- /dev/null +++ b/tests/config/test_tile_information.py @@ -0,0 +1,50 @@ +import numpy as np +import pytest + +from careamics.config.tile_information import TileInformation + + +def test_defaults(): + """Test instantiating time information with defaults.""" + tile_info = TileInformation(array_shape=np.zeros((6, 6)).shape) + + assert tile_info.array_shape == (6, 6) + assert not tile_info.tiled + assert not tile_info.last_tile + assert tile_info.overlap_crop_coords is None + assert tile_info.stitch_coords is None + + +def test_tiled(): + """Test instantiating time information with parameters.""" + tile_info = TileInformation( + array_shape=np.zeros((6, 
6)).shape, + tiled=True, + last_tile=True, + overlap_crop_coords=((1, 2),), + stitch_coords=((3, 4),), + ) + + assert tile_info.array_shape == (6, 6) + assert tile_info.tiled + assert tile_info.last_tile + assert tile_info.overlap_crop_coords == ((1, 2),) + assert tile_info.stitch_coords == ((3, 4),) + + +def test_validation_last_tile(): + """Test that last tile is only set if tiled is set.""" + tile_info = TileInformation(array_shape=(6, 6), last_tile=True) + assert not tile_info.last_tile + + +def test_error_on_coords(): + """Test than an error is raised if it is tiled but not coordinates are given.""" + with pytest.raises(ValueError): + TileInformation(array_shape=(6, 6), tiled=True) + + +def test_error_on_singleton_dims(): + """Test that an error is raised if the array shape contains singleton dimensions.""" + with pytest.raises(ValueError): + TileInformation(array_shape=(2, 1, 6, 6)) diff --git a/tests/config/test_torch_optimizer.py b/tests/config/test_torch_optimizer.py deleted file mode 100644 index 5f88fd7a..00000000 --- a/tests/config/test_torch_optimizer.py +++ /dev/null @@ -1,38 +0,0 @@ -from torch import optim - -from careamics.config.torch_optim import ( - TorchLRScheduler, - TorchOptimizer, - get_optimizers, - get_schedulers, -) - - -def test_get_schedulers_exist(): - """Test that the function `get_schedulers` return - existing torch schedulers. - """ - for scheduler in get_schedulers(): - assert hasattr(optim.lr_scheduler, scheduler) - - -def test_torch_schedulers_exist(): - """Test that the enum `TorchLRScheduler` contains - existing torch schedulers.""" - for scheduler in TorchLRScheduler: - assert hasattr(optim.lr_scheduler, scheduler) - - -def test_get_optimizers_exist(): - """Test that the function `get_optimizers` return - existing torch optimizers. 
- """ - for optimizer in get_optimizers(): - assert hasattr(optim, optimizer) - - -def test_optimizers_exist(): - """Test that the enum `TorchOptimizer` contains - existing torch optimizers.""" - for optimizer in TorchOptimizer: - assert hasattr(optim, optimizer) diff --git a/tests/config/test_training.py b/tests/config/test_training.py deleted file mode 100644 index 859848b6..00000000 --- a/tests/config/test_training.py +++ /dev/null @@ -1,468 +0,0 @@ -import pytest -from pydantic import conlist - -from careamics.config.torch_optim import ( - TorchLRScheduler, - TorchOptimizer, -) -from careamics.config.training import AMP, LrScheduler, Optimizer, Training - - -@pytest.mark.parametrize( - "optimizer_name, parameters", - [ - ( - TorchOptimizer.Adam, - { - "lr": 0.08, - "betas": (0.1, 0.11), - "eps": 6e-08, - "weight_decay": 0.2, - "amsgrad": True, - }, - ), - ( - TorchOptimizer.SGD, - { - "lr": 0.11, - "momentum": 5, - "dampening": 1, - "weight_decay": 8, - "nesterov": True, - }, - ), - ], -) -def test_optimizer_parameters(optimizer_name: TorchOptimizer, parameters: dict): - """Test optimizer parameters filtering. 
- - For parameters, see: - https://pytorch.org/docs/stable/optim.html#algorithms - """ - # add non valid parameter - new_parameters = parameters.copy() - new_parameters["some_random_one"] = 42 - - # create optimizer and check that the parameters are filtered - optimizer = Optimizer(name=optimizer_name, parameters=new_parameters) - assert optimizer.parameters == parameters - - -def test_sgd_missing_parameter(): - """Test that SGD optimizer fails if `lr` is not provided""" - with pytest.raises(ValueError): - Optimizer(name=TorchOptimizer.SGD, parameters={}) - - # test that it works if lr is provided - optimizer = Optimizer(name=TorchOptimizer.SGD, parameters={"lr": 0.1}) - assert optimizer.parameters == {"lr": 0.1} - - -def test_optimizer_wrong_values_by_assignments(): - """Test that wrong values cause an error during assignment.""" - optimizer = Optimizer(name=TorchOptimizer.Adam, parameters={"lr": 0.08}) - - # name - optimizer.name = TorchOptimizer.SGD - with pytest.raises(ValueError): - optimizer.name = "MyOptim" - - # parameters - optimizer.parameters = {"lr": 0.1} - with pytest.raises(ValueError): - optimizer.parameters = "lr = 0.3" - - -def test_optimizer_to_dict_minimum(minimum_config: dict): - """ "Test that export to dict does not include optional value.""" - optim_minimum = Optimizer(**minimum_config["training"]["optimizer"]).model_dump() - assert optim_minimum == minimum_config["training"]["optimizer"] - - assert "name" in optim_minimum.keys() - assert "parameters" not in optim_minimum.keys() - - -def test_optimizer_to_dict_complete(complete_config: dict): - """ "Test that export to dict does include optional value.""" - optim_minimum = Optimizer(**complete_config["training"]["optimizer"]).model_dump() - assert optim_minimum == complete_config["training"]["optimizer"] - - assert "name" in optim_minimum.keys() - assert "parameters" in optim_minimum.keys() - - -def test_optimizer_to_dict_optional(complete_config: dict): - """ "Test that export to dict does 
not include optional value.""" - optim_config = complete_config["training"]["optimizer"] - optim_config["parameters"] = {} - - optim_minimum = Optimizer(**optim_config).model_dump() - assert "name" in optim_minimum.keys() - assert "parameters" not in optim_minimum.keys() - - -@pytest.mark.parametrize( - "lr_scheduler_name, parameters", - [ - ( - TorchLRScheduler.ReduceLROnPlateau, - { - "mode": "max", - "factor": 0.3, - "patience": 5, - "threshold": 0.003, - "threshold_mode": "abs", - "cooldown": 3, - "min_lr": 0.1, - "eps": 5e-08, - }, - ), - ( - TorchLRScheduler.StepLR, - { - "step_size": 2, - "gamma": 0.3, - "last_epoch": -5, - }, - ), - ], -) -def test_scheduler_parameters(lr_scheduler_name: TorchLRScheduler, parameters: dict): - """Test lr scheduler parameters filtering. - - For parameters, see: - https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate - """ - # add non valid parameter - new_parameters = parameters.copy() - new_parameters["some_random_one"] = 42 - - # create optimizer and check that the parameters are filtered - lr_scheduler = LrScheduler(name=lr_scheduler_name, parameters=new_parameters) - assert lr_scheduler.parameters == parameters - - -def test_scheduler_missing_parameter(): - """Test that StepLR scheduler fails if `step_size` is not provided""" - with pytest.raises(ValueError): - LrScheduler(name=TorchLRScheduler.StepLR, parameters={}) - - # test that it works if lr is provided - lr_scheduler = LrScheduler( - name=TorchLRScheduler.StepLR, parameters={"step_size": "5"} - ) - assert lr_scheduler.parameters == {"step_size": "5"} - - -def test_scheduler_wrong_values_by_assignments(): - """Test that wrong values cause an error during assignment.""" - scheduler = LrScheduler( - name=TorchLRScheduler.ReduceLROnPlateau, parameters={"factor": 0.3} - ) - - # name - scheduler.name = TorchLRScheduler.ReduceLROnPlateau - with pytest.raises(ValueError): - # this fails because the step parameter is missing - scheduler.name = 
TorchLRScheduler.StepLR - - with pytest.raises(ValueError): - scheduler.name = "Schedule it yourself!" - - # parameters - scheduler.name = TorchLRScheduler.ReduceLROnPlateau - scheduler.parameters = {"factor": 0.1} - with pytest.raises(ValueError): - scheduler.parameters = "factor = 0.3" - - -def test_scheduler_to_dict_minimum(minimum_config: dict): - """ "Test that export to dict does not include optional value.""" - scheduler_minimum = LrScheduler( - **minimum_config["training"]["lr_scheduler"] - ).model_dump() - assert scheduler_minimum == minimum_config["training"]["lr_scheduler"] - - assert "name" in scheduler_minimum.keys() - assert "parameters" not in scheduler_minimum.keys() - - -def test_scheduler_to_dict_complete(complete_config: dict): - """ "Test that export to dict does include optional value.""" - scheduler_complete = LrScheduler( - **complete_config["training"]["lr_scheduler"] - ).model_dump() - assert scheduler_complete == complete_config["training"]["lr_scheduler"] - - assert "name" in scheduler_complete.keys() - assert "parameters" in scheduler_complete.keys() - - -def test_scheduler_to_dict_optional(complete_config: dict): - """ "Test that export to dict does not include optional value.""" - scheduler_config = complete_config["training"]["lr_scheduler"] - scheduler_config["parameters"] = {} - - scheduler_complete = LrScheduler(**scheduler_config).model_dump() - - assert "name" in scheduler_complete.keys() - assert "parameters" not in scheduler_complete.keys() - - -@pytest.mark.parametrize("init_scale", [512, 1024, 65536]) -def test_amp_init_scale(init_scale: int): - """Test AMP init_scale parameter.""" - amp = AMP(use=True, init_scale=init_scale) - assert amp.init_scale == init_scale - - -@pytest.mark.parametrize("init_scale", [511, 1088, 65537]) -def test_amp_wrong_init_scale(init_scale: int): - """Test wrong AMP init_scale parameter.""" - with pytest.raises(ValueError): - AMP(use=True, init_scale=init_scale) - - -def 
test_amp_wrong_values_by_assignments(): - """Test that wrong values cause an error during assignment.""" - amp = AMP(use=True, init_scale=1024) - - # use - amp.use = False - with pytest.raises(ValueError): - amp.use = None - - with pytest.raises(ValueError): - amp.use = 3 - - # init_scale - amp.init_scale = 512 - with pytest.raises(ValueError): - amp.init_scale = "1026" - - -def test_amp_to_dict(): - """ "Test export to dict.""" - # all values in there - vals = {"use": True, "init_scale": 512} - amp = AMP(**vals).model_dump() - assert amp == vals - - assert "use" in amp.keys() - assert "init_scale" in amp.keys() - - # optional value not in there if not specified (default) - vals = {"use": True} - amp = AMP(**vals).model_dump() - assert amp == vals - - assert "use" in amp.keys() - assert "init_scale" not in amp.keys() - - # optional value not in there if provided as default - vals = {"use": True, "init_scale": 1024} - amp = AMP(**vals).model_dump() - - assert "use" in amp.keys() - assert "init_scale" not in amp.keys() - - -@pytest.mark.parametrize("num_epochs", [1, 2, 4, 9000]) -def test_training_num_epochs(minimum_config: dict, num_epochs: int): - """Test that Training accepts num_epochs greater than 0.""" - training = minimum_config["training"] - training["num_epochs"] = num_epochs - - training = Training(**training) - assert training.num_epochs == num_epochs - - -@pytest.mark.parametrize("num_epochs", [-1, 0]) -def test_training_wrong_num_epochs(minimum_config: dict, num_epochs: int): - """Test that wrong number of epochs cause an error.""" - training = minimum_config["training"] - training["num_epochs"] = num_epochs - - with pytest.raises(ValueError): - Training(**training) - - -@pytest.mark.parametrize("batch_size", [1, 2, 4, 9000]) -def test_training_batch_size(minimum_config: dict, batch_size: int): - """Test batch size greater than 0.""" - training = minimum_config["training"] - training["batch_size"] = batch_size - - training = Training(**training) - assert 
training.batch_size == batch_size - - -@pytest.mark.parametrize("batch_size", [-1, 0]) -def test_training_wrong_batch_size(minimum_config: dict, batch_size: int): - """Test that wrong batch size cause an error.""" - training = minimum_config["training"] - training["batch_size"] = batch_size - - with pytest.raises(ValueError): - Training(**training) - - -@pytest.mark.parametrize("patch_size", [[2, 2], [2, 4, 2], [32, 96]]) -def test_training_patch_size(minimum_config: dict, patch_size: conlist): - """Test patch size greater than 0.""" - training = minimum_config["training"] - training["patch_size"] = patch_size - - training = Training(**training) - assert training.patch_size == patch_size - - -@pytest.mark.parametrize( - "patch_size", - [ - [ - 2, - ], - [2, 4, 2, 2], - [1, 1], - [2, 0], - [33, 32], - ], -) -def test_training_wrong_patch_size(minimum_config: dict, patch_size: conlist): - """Test that wrong patch size cause an error.""" - training = minimum_config["training"] - training["patch_size"] = patch_size - - with pytest.raises(ValueError): - Training(**training) - - -@pytest.mark.parametrize("num_workers", [0, 1, 4, 9000]) -def test_training_num_workers(complete_config: dict, num_workers: int): - """Test batch size greater than 0.""" - training = complete_config["training"] - training["num_workers"] = num_workers - - training = Training(**training) - assert training.num_workers == num_workers - - -@pytest.mark.parametrize("num_workers", [-1, -2]) -def test_training_wrong_num_workers(complete_config: dict, num_workers: int): - """Test that wrong batch size cause an error.""" - training = complete_config["training"] - training["num_workers"] = num_workers - - with pytest.raises(ValueError): - Training(**training) - - -def test_training_wrong_values_by_assignments(complete_config: dict): - """Test that wrong values cause an error during assignment.""" - training = Training(**complete_config["training"]) - - # num_epochs - training.num_epochs = 2 - with 
pytest.raises(ValueError): - training.num_epochs = -1 - - # batch_size - training.batch_size = 2 - with pytest.raises(ValueError): - training.batch_size = -1 - - # patch_size - training.patch_size = [2, 2] - with pytest.raises(ValueError): - training.patch_size = [5, 4] - - # optimizer - training.optimizer = Optimizer(name=TorchOptimizer.Adam, parameters={"lr": 0.1}) - with pytest.raises(ValueError): - training.optimizer = "I'd rather not to." - - # lr_scheduler - training.lr_scheduler = LrScheduler( - name=TorchLRScheduler.ReduceLROnPlateau, parameters={"factor": 0.1} - ) - with pytest.raises(ValueError): - training.lr_scheduler = "Why don't you schedule it for once? :)" - - # augmentation - training.augmentation = True - with pytest.raises(ValueError): - training.augmentation = None - - # use_wandb - training.use_wandb = True - with pytest.raises(ValueError): - training.use_wandb = None - - # amp - training.amp = AMP(use=True, init_scale=1024) - with pytest.raises(ValueError): - training.amp = "I don't want to use AMP." 
- - # num_workers - training.num_workers = 2 - with pytest.raises(ValueError): - training.num_workers = -1 - - -def test_training_to_dict_minimum(minimum_config: dict): - """Test that the minimum config get export to dict correctly.""" - training_minimum = Training(**minimum_config["training"]).model_dump() - assert training_minimum == minimum_config["training"] - - # Mandatory fields are present - assert "num_epochs" in training_minimum.keys() - assert "patch_size" in training_minimum.keys() - assert "batch_size" in training_minimum.keys() - - assert "optimizer" in training_minimum.keys() - assert "name" in training_minimum["optimizer"].keys() - # optional subfield - assert "parameters" not in training_minimum["optimizer"].keys() - - assert "lr_scheduler" in training_minimum.keys() - assert "name" in training_minimum["lr_scheduler"].keys() - # optional subfield - assert "parameters" not in training_minimum["lr_scheduler"].keys() - - assert "augmentation" in training_minimum.keys() - - # Optionals fields are absent - assert "wandb" not in training_minimum.keys() - assert "num_workers" not in training_minimum.keys() - assert "amp" not in training_minimum.keys() - - -def test_training_to_dict_optionals(complete_config: dict): - """Test that the optionals fields are omitted when default.""" - train_conf = complete_config["training"] - train_conf["amp"] = AMP(use=False, init_scale=1024) - train_conf["num_workers"] = 0 - train_conf["use_wandb"] = False - - training_complete = Training(**train_conf).model_dump() - - # Mandatory fields are present - assert "num_epochs" in training_complete.keys() - assert "patch_size" in training_complete.keys() - assert "batch_size" in training_complete.keys() - - assert "optimizer" in training_complete.keys() - assert "name" in training_complete["optimizer"].keys() - assert "parameters" in training_complete["optimizer"].keys() - - assert "lr_scheduler" in training_complete.keys() - assert "name" in 
training_complete["lr_scheduler"].keys() - assert "parameters" in training_complete["lr_scheduler"].keys() - - assert "augmentation" in training_complete.keys() - - # Optionals fields - assert "use_wandb" not in training_complete.keys() - assert "num_workers" not in training_complete.keys() - assert "amp" not in training_complete.keys() diff --git a/tests/config/test_training_model.py b/tests/config/test_training_model.py new file mode 100644 index 00000000..f37a6593 --- /dev/null +++ b/tests/config/test_training_model.py @@ -0,0 +1,13 @@ +import pytest + +from careamics.config.training_model import TrainingModel + + +def test_training_wrong_values_by_assignments(minimum_training: dict): + """Test that wrong values cause an error during assignment.""" + training = TrainingModel(**minimum_training) + + # num_epochs + training.num_epochs = 2 + with pytest.raises(ValueError): + training.num_epochs = -1 diff --git a/tests/config/transformations/test_n2v_manipulate_model.py b/tests/config/transformations/test_n2v_manipulate_model.py new file mode 100644 index 00000000..6939182c --- /dev/null +++ b/tests/config/transformations/test_n2v_manipulate_model.py @@ -0,0 +1,26 @@ +import pytest + +from careamics.config.transformations.n2v_manipulate_model import N2VManipulateModel + + +def test_odd_roi_and_mask(): + """Test that errors are thrown if we pass even roi and mask sizes.""" + # no error + model = N2VManipulateModel(name="N2VManipulate", roi_size=3, struct_mask_span=7) + assert model.roi_size == 3 + assert model.struct_mask_span == 7 + + # errors + with pytest.raises(ValueError): + N2VManipulateModel(name="N2VManipulate", roi_size=4, struct_mask_span=7) + + with pytest.raises(ValueError): + N2VManipulateModel(name="N2VManipulate", roi_size=3, struct_mask_span=6) + + +def test_extra_parameters(): + """Test that errors are thrown if we pass extra parameters.""" + with pytest.raises(ValueError): + N2VManipulateModel( + name="N2VManipulate", roi_size=3, struct_mask_span=7, 
extra_param=1 + ) diff --git a/tests/config/transformations/test_normalize_model.py b/tests/config/transformations/test_normalize_model.py new file mode 100644 index 00000000..324b635f --- /dev/null +++ b/tests/config/transformations/test_normalize_model.py @@ -0,0 +1,13 @@ +from careamics.config.transformations import NormalizeModel + + +def test_setting_mean_std(): + """Test that we can set the mean and std values.""" + model = NormalizeModel(name="Normalize", mean=0.5, std=0.5) + assert model.mean == 0.5 + assert model.std == 0.5 + + model.mean = 0.6 + model.std = 0.6 + assert model.mean == 0.6 + assert model.std == 0.6 diff --git a/tests/config/validators/test_validator_utils.py b/tests/config/validators/test_validator_utils.py new file mode 100644 index 00000000..740f10a5 --- /dev/null +++ b/tests/config/validators/test_validator_utils.py @@ -0,0 +1,70 @@ +import pytest + +from careamics.config.validators import ( + check_axes_validity, + patch_size_ge_than_8_power_of_2, +) + + +@pytest.mark.parametrize( + "axes, valid", + [ + # Passing + ("yx", True), + ("Yx", True), + ("Zyx", True), + ("STYX", True), + ("CYX", True), + ("YXC", True), + ("TzYX", True), + ("SZYX", True), + ("STZYX", True), + ("XY", True), + ("YXT", True), + ("ZTYX", True), + # non consecutive XY + ("YZX", False), + ("YZCXT", False), + # too few axes + ("", False), + ("X", False), + # no yx axes + ("ZT", False), + ("ZY", False), + # repeating characters + ("YYX", False), + ("YXY", False), + # invalid characters + ("YXm", False), + ("1YX", False), + ], +) +def test_are_axes_valid(axes, valid): + """Test if axes are valid""" + if valid: + check_axes_validity(axes) + else: + with pytest.raises((ValueError, NotImplementedError)): + check_axes_validity(axes) + + +@pytest.mark.parametrize( + "patch_size, error", + [ + ((2, 8, 8), True), + ((10,), True), + ((8, 10, 16), True), + ((8, 13), True), + ((8, 16, 4), True), + ((8,), False), + ((8, 8), False), + ((8, 64, 64), False), + ], +) +def 
test_patch_size(patch_size, error): + """Test if patch size is valid.""" + if error: + with pytest.raises(ValueError): + patch_size_ge_than_8_power_of_2(patch_size) + else: + patch_size_ge_than_8_power_of_2(patch_size) diff --git a/tests/conftest.py b/tests/conftest.py index 1bca73b2..3e41040c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,99 +1,190 @@ -import copy from pathlib import Path -from typing import Callable +from typing import Callable, Tuple import numpy as np import pytest +from careamics import CAREamist, Configuration +from careamics.config.support import SupportedData +from careamics.model_io import export_to_bmz + +# TODO add details about where each of these fixture is used (e.g. smoke test) @pytest.fixture -def minimum_config(tmp_path: Path) -> dict: - """Create a minimum configuration. +def create_tiff(path: Path, n_files: int): + """Create tiff files for testing.""" + if not path.exists(): + path.mkdir() + + for i in range(n_files): + file_path = path / f"file_{i}.tif" + file_path.touch() - Parameters - ---------- - tmp_path : Path - Temporary path for testing. + +@pytest.fixture +def minimum_algorithm_custom() -> dict: + """Create a minimum algorithm dictionary. Returns ------- dict - A minumum configuration example. + A minimum algorithm example. """ # create dictionary - configuration = { - "experiment_name": "LevitatingFrog", - "working_directory": str(tmp_path), - "algorithm": { - "loss": "n2v", - "model": "UNet", - "is_3D": False, - }, - "training": { - "num_epochs": 666, - "batch_size": 42, - "patch_size": [64, 64], - "optimizer": { - "name": "Adam", - }, - "lr_scheduler": {"name": "ReduceLROnPlateau"}, - "augmentation": True, + algorithm = { + "algorithm": "custom", + "loss": "mae", + "model": { + "architecture": "UNet", }, - "data": { - "in_memory": True, - "data_format": "tif", - "axes": "SYX", + } + + return algorithm + + +@pytest.fixture +def minimum_algorithm_n2v() -> dict: + """Create a minimum algorithm dictionary. 
+ + Returns + ------- + dict + A minimum algorithm example. + """ + # create dictionary + algorithm = { + "algorithm": "n2v", + "loss": "n2v", + "model": { + "architecture": "UNet", }, } - return configuration + return algorithm @pytest.fixture -def complete_config(minimum_config: dict) -> dict: - """Create a complete configuration. +def minimum_algorithm_supervised() -> dict: + """Create a minimum algorithm dictionary. - This configuration should not be used for testing an Engine. + Returns + ------- + dict + A minimum algorithm example. + """ + # create dictionary + algorithm = { + "algorithm": "n2n", + "loss": "mae", + "model": { + "architecture": "UNet", + }, + } - Parameters - ---------- - minimum_config : dict - A minimum configuration. + return algorithm + + +@pytest.fixture +def minimum_data() -> dict: + """Create a minimum data dictionary. Returns ------- dict - A complete configuration example. + A minimum data example. """ - # add to configuration - complete_config = copy.deepcopy(minimum_config) + # create dictionary + data = { + "data_type": SupportedData.ARRAY.value, + "patch_size": [8, 8], + "axes": "YX", + } + + return data - complete_config["algorithm"]["masking_strategy"] = "median" - complete_config["algorithm"]["masked_pixel_percentage"] = 0.6 - complete_config["algorithm"]["roi_size"] = 13 - complete_config["algorithm"]["model_parameters"] = { - "depth": 8, - "num_channels_init": 32, +@pytest.fixture +def minimum_inference() -> dict: + """Create a minimum inference dictionary. + + Returns + ------- + dict + A minimum data example. + """ + # create dictionary + predic = { + "data_type": SupportedData.ARRAY.value, + "axes": "YX", + "mean": 2.0, + "std": 1.0, } - complete_config["training"]["optimizer"]["parameters"] = { - "lr": 0.00999, + return predic + + +@pytest.fixture +def minimum_training() -> dict: + """Create a minimum training dictionary. + + Returns + ------- + dict + A minimum training example. 
+ """ + # create dictionary + training = { + "num_epochs": 1, } - complete_config["training"]["lr_scheduler"]["parameters"] = { - "patience": 22, + + return training + + +@pytest.fixture +def minimum_configuration( + minimum_algorithm_n2v: dict, minimum_data: dict, minimum_training: dict +) -> dict: + """Create a minimum configuration dictionary. + + Parameters + ---------- + tmp_path : Path + Temporary path for testing. + minimum_algorithm : dict + Minimum algorithm configuration. + minimum_data : dict + Minimum data configuration. + minimum_training : dict + Minimum training configuration. + + Returns + ------- + dict + A minumum configuration example. + """ + # create dictionary + configuration = { + "experiment_name": "LevitatingFrog", + "algorithm_config": minimum_algorithm_n2v, + "training_config": minimum_training, + "data_config": minimum_data, } - complete_config["training"]["use_wandb"] = True - complete_config["training"]["num_workers"] = 6 - complete_config["training"]["amp"] = { - "use": True, - "init_scale": 512, + + return configuration + + +@pytest.fixture +def supervised_configuration( + minimum_algorithm_supervised: dict, minimum_data: dict, minimum_training: dict +) -> dict: + configuration = { + "experiment_name": "LevitatingFrog", + "algorithm_config": minimum_algorithm_supervised, + "training_config": minimum_training, + "data_config": minimum_data, } - complete_config["data"]["in_memory"] = False - complete_config["data"]["mean"] = 666.666 - complete_config["data"]["std"] = 42.420 - return complete_config + return configuration @pytest.fixture @@ -119,24 +210,90 @@ def _ordered_array(shape: tuple, dtype=int) -> np.ndarray: @pytest.fixture -def array_2D(ordered_array) -> np.ndarray: - """A 2D array with shape (1, 10, 9). +def array_2D() -> np.ndarray: + """A 2D array with shape (1, 3, 10, 9). Returns ------- np.ndarray - 2D array with shape (1, 10, 9). + 2D array with shape (1, 3, 10, 9). 
""" - return ordered_array((1, 10, 9)) + return np.arange(90 * 3).reshape((1, 3, 10, 9)) @pytest.fixture -def array_3D(ordered_array) -> np.ndarray: - """A 3D array with shape (1, 5, 10, 9). +def array_3D() -> np.ndarray: + """A 3D array with shape (1, 3, 5, 10, 9). Returns ------- np.ndarray - 3D array with shape (1, 5, 10, 9). + 3D array with shape (1, 3, 5, 10, 9). """ - return ordered_array((1, 8, 16, 16)) + return np.arange(2048 * 3).reshape((1, 3, 8, 16, 16)) + + +@pytest.fixture +def patch_size() -> Tuple[int, int]: + return (64, 64) + + +@pytest.fixture +def overlaps() -> Tuple[int, int]: + return (32, 32) + + +@pytest.fixture +def pre_trained(tmp_path, minimum_configuration): + """Fixture to create a pre-trained CAREamics model.""" + # training data + train_array = np.arange(32 * 32).reshape((32, 32)) + + # create configuration + config = Configuration(**minimum_configuration) + config.training_config.num_epochs = 1 + config.data_config.axes = "YX" + config.data_config.batch_size = 2 + config.data_config.data_type = SupportedData.ARRAY.value + config.data_config.patch_size = (8, 8) + + # instantiate CAREamist + careamist = CAREamist(source=config, work_dir=tmp_path) + + # train CAREamist + careamist.train(train_source=train_array) + + # check that it trained + pre_trained_path: Path = tmp_path / "checkpoints" / "last.ckpt" + assert pre_trained_path.exists() + + return pre_trained_path + + +@pytest.fixture +def pre_trained_bmz(tmp_path, pre_trained) -> Path: + """Fixture to create a BMZ model.""" + # training data + train_array = np.ones((32, 32), dtype=np.float32) + + # instantiate CAREamist + careamist = CAREamist(source=pre_trained, work_dir=tmp_path) + + # predict (no tiling and no tta) + predicted = careamist.predict(train_array, tta_transforms=False) + + # export to BioImage Model Zoo + path = tmp_path / "model.zip" + export_to_bmz( + model=careamist.model, + config=careamist.cfg, + path=path, + name="TopModel", + general_description="A model that 
just walked in.", + authors=[{"name": "Amod", "affiliation": "El"}], + input_array=train_array[np.newaxis, np.newaxis, ...], + output_array=predicted, + ) + assert path.exists() + + return path diff --git a/tests/dataset/dataset_utils/test_list_files.py b/tests/dataset/dataset_utils/test_list_files.py new file mode 100644 index 00000000..595fa461 --- /dev/null +++ b/tests/dataset/dataset_utils/test_list_files.py @@ -0,0 +1,232 @@ +from pathlib import Path + +import numpy as np +import pytest +import tifffile + +from careamics.config.support import SupportedData +from careamics.dataset.dataset_utils import ( + get_files_size, + list_files, + validate_source_target_files, +) + + +def test_get_files_size_tiff(tmp_path: Path): + """Test getting size of multiple TIFF files.""" + # create array + image = np.ones((10, 10)) + + # save array to tiff + path1 = tmp_path / "test1.tif" + tifffile.imwrite(path1, image) + + path2 = tmp_path / "test2.tiff" + tifffile.imwrite(path2, image) + + # save text file + path3 = tmp_path / "test3.txt" + path3.write_text("test") + + # save file in subdirectory + subdirectory = tmp_path / "subdir" + subdirectory.mkdir() + path4 = subdirectory / "test3.tif" + tifffile.imwrite(path4, image) + + # create file list + files = [path1, path2, path4] + + # get files size + size = get_files_size(files) + assert size > 0 + + +def test_list_single_file_tiff(tmp_path: Path): + """Test listing a single TIFF file.""" + # create array + image = np.ones((10, 10)) + + # save array to tiff + path = tmp_path / "test.tif" + tifffile.imwrite(path, image) + + # list file using parent directory + files = list_files(tmp_path, SupportedData.TIFF) + assert len(files) == 1 + assert files[0] == path + + # list file using file path + files = list_files(path, SupportedData.TIFF) + assert len(files) == 1 + assert files[0] == path + + +def test_list_multiple_files_tiff(tmp_path: Path): + """Test listing multiple TIFF files in subdirectories with additional files.""" + # 
create array + image = np.ones((10, 10)) + + # save array to /npy + path1 = tmp_path / "test1.tif" + tifffile.imwrite(path1, image) + + path2 = tmp_path / "test2.tif" + tifffile.imwrite(path2, image) + + # save text file + path3 = tmp_path / "test3.txt" + path3.write_text("test") + + # save file in subdirectory + subdirectory = tmp_path / "subdir" + subdirectory.mkdir() + path4 = subdirectory / "test3.tif" + tifffile.imwrite(path4, image) + + # create file list + ref_files = [path1, path2, path4] + + # list files using parent directory + files = list_files(tmp_path, SupportedData.TIFF) + assert len(files) == 3 + assert set(files) == set(ref_files) + + +def test_list_single_file_custom(tmp_path): + """Test listing a single custom file.""" + # create array + image = np.ones((10, 10)) + + # save as .npy + path = tmp_path / "custom.npy" + np.save(path, image) + + # list files using parent directory + files = list_files(tmp_path, SupportedData.CUSTOM) + assert len(files) == 1 + assert files[0] == path + + # list files using file path + files = list_files(path, SupportedData.CUSTOM) + assert len(files) == 1 + assert files[0] == path + + +def test_list_multiple_files_custom(tmp_path: Path): + """Test listing multiple custom files in subdirectories with additional files.""" + # create array + image = np.ones((10, 10)) + + # save array to /npy + path1 = tmp_path / "test1.npy" + np.save(path1, image) + + path2 = tmp_path / "test2.npy" + np.save(path2, image) + + # save text file + path3 = tmp_path / "test3.txt" + path3.write_text("test") + + # save file in subdirectory + subdirectory = tmp_path / "subdir" + subdirectory.mkdir() + path4 = subdirectory / "test3.npy" + np.save(path4, image) + + # create file list (even the text file is selected) + ref_files = [path1, path2, path3, path4] + + # list files using parent directory + files = list_files(tmp_path, SupportedData.CUSTOM) + assert len(files) == 4 + assert set(files) == set(ref_files) + + # list files using the file 
extension filter + files = list_files(tmp_path, SupportedData.CUSTOM, "*.npy") + assert len(files) == 3 + assert set(files) == {path1, path2, path4} + + +def test_validate_source_target_files(tmp_path: Path): + """Test that it passes for two folders with same number of files and same names.""" + # create two subfolders + src = tmp_path / "src" + src.mkdir() + + tar = tmp_path / "tar" + tar.mkdir() + + # populate with files + filename_1 = "test1.txt" + filename_2 = "test2.txt" + + (tmp_path / "src" / filename_1).write_text("test") + (tmp_path / "tar" / filename_1).write_text("test") + + (tmp_path / "src" / filename_2).write_text("test") + (tmp_path / "tar" / filename_2).write_text("test") + + # list files + src_files = list_files(src, SupportedData.CUSTOM) + tar_files = list_files(tar, SupportedData.CUSTOM) + + # validate files + validate_source_target_files(src_files, tar_files) + + +def test_validate_source_target_files_wrong_names(tmp_path: Path): + """Test that an error is raised if filenames are different.""" + # create two subfolders + src = tmp_path / "src" + src.mkdir() + + tar = tmp_path / "tar" + tar.mkdir() + + # populate with files + filename_1 = "test1.txt" + filename_2 = "test2.txt" + filename_3 = "test3.txt" + + (tmp_path / "src" / filename_1).write_text("test") + (tmp_path / "tar" / filename_1).write_text("test") + + (tmp_path / "src" / filename_2).write_text("test") + (tmp_path / "tar" / filename_3).write_text("test") + + # list files + src_files = list_files(src, SupportedData.CUSTOM) + tar_files = list_files(tar, SupportedData.CUSTOM) + + # validate files + with pytest.raises(ValueError): + validate_source_target_files(src_files, tar_files) + + +def test_validate_source_target_files_wrong_number(tmp_path: Path): + """Test that an error is raised if filenames are different.""" + # create two subfolders + src = tmp_path / "src" + src.mkdir() + + tar = tmp_path / "tar" + tar.mkdir() + + # populate with files + filename_1 = "test1.txt" + filename_2 = 
"test2.txt" + + (tmp_path / "src" / filename_1).write_text("test") + (tmp_path / "tar" / filename_1).write_text("test") + + (tmp_path / "src" / filename_2).write_text("test") + + # list files + src_files = list_files(src, SupportedData.CUSTOM) + tar_files = list_files(tar, SupportedData.CUSTOM) + + # validate files + with pytest.raises(ValueError): + validate_source_target_files(src_files, tar_files) diff --git a/tests/dataset/dataset_utils/test_read_tiff.py b/tests/dataset/dataset_utils/test_read_tiff.py new file mode 100644 index 00000000..5d534b8f --- /dev/null +++ b/tests/dataset/dataset_utils/test_read_tiff.py @@ -0,0 +1,32 @@ +import numpy as np +import pytest +import tifffile + +from careamics.dataset.dataset_utils.read_tiff import read_tiff + + +def test_read_tiff(tmp_path, ordered_array): + """Test reading a tiff file.""" + # create an array + array: np.ndarray = ordered_array((10, 10)) + + # save files + file = tmp_path / "test.tiff" + tifffile.imwrite(file, array) + + # read files + array_read = read_tiff(file) + np.testing.assert_array_equal(array_read, array) + + +def test_read_tiff_invalid(tmp_path): + # invalid file type + file = tmp_path / "test.txt" + file.write_text("test") + with pytest.raises(ValueError): + read_tiff(file) + + # non-existing file + file = tmp_path / "test.tiff" + with pytest.raises(FileNotFoundError): + read_tiff(file) diff --git a/tests/dataset/patching/test_patching_utils.py b/tests/dataset/patching/test_patching_utils.py new file mode 100644 index 00000000..5ac4568c --- /dev/null +++ b/tests/dataset/patching/test_patching_utils.py @@ -0,0 +1,44 @@ +import numpy as np +import pytest + +from careamics.dataset.patching.validate_patch_dimension import ( + validate_patch_dimensions, +) + + +@pytest.mark.parametrize( + "arr_shape, patch_size", + [ + ((1, 1, 8, 8), (2, 2)), + ((1, 1, 8, 8, 8), (2, 2, 2)), + ], +) +def test_patches_sanity_check(arr_shape, patch_size): + arr = np.zeros(arr_shape) + is_3d_patch = len(patch_size) == 3 + 
# check if the patch is 2D or 3D. Subtract 1 because the first dimension is sample + validate_patch_dimensions(arr, patch_size, is_3d_patch) + + +@pytest.mark.parametrize( + "arr_shape, patch_size", + [ + # Wrong number of dimensions 2D + # minimum 3 dimensions CYX + ((10, 10), (5, 5, 5)), + # Wrong number of dimensions 3D + ((1, 1, 10, 10, 10), (5, 5)), + # Wrong z patch size + ((1, 10, 10), (5, 5, 5)), + ((10, 10, 10), (10, 5, 5)), + # Wrong YX patch sizes + ((1, 10, 10), (12, 5)), + ((1, 10, 10), (5, 11)), + ], +) +def test_patches_sanity_check_invalid_cases(arr_shape, patch_size): + arr = np.zeros(arr_shape) + is_3d_patch = len(patch_size) == 3 + # check if the patch is 2D or 3D. Subtract 1 because the first dimension is sample + with pytest.raises(ValueError): + validate_patch_dimensions(arr, patch_size, is_3d_patch) diff --git a/tests/dataset/patching/test_random_patching.py b/tests/dataset/patching/test_random_patching.py new file mode 100644 index 00000000..858cfe2e --- /dev/null +++ b/tests/dataset/patching/test_random_patching.py @@ -0,0 +1,105 @@ +import numpy as np +import pytest + +from careamics.dataset.patching.random_patching import extract_patches_random + + +@pytest.mark.parametrize( + "shape, patch_size", + [ + ((1, 1, 8, 8), (3, 3)), + ((1, 3, 8, 8), (3, 3)), + ((3, 1, 8, 8), (3, 3)), + ((2, 3, 8, 8), (3, 3)), + ((1, 1, 5, 8, 8), (3, 3, 3)), + ((1, 3, 5, 8, 8), (3, 3, 3)), + ((3, 1, 5, 8, 8), (3, 3, 3)), + ((2, 3, 5, 8, 8), (3, 3, 3)), + ], +) +def test_random_patching_unsupervised(ordered_array, shape, patch_size): + """Check that the patches are extracted correctly. + + Since extract patches is called on already shaped array, dimensions S and C are + present. 
+ """ + np.random.seed(42) + + # create array + array = ordered_array(shape) + is_3D = len(patch_size) == 3 + top_left = [] + + for _ in range(3): + patch_generator = extract_patches_random(array, patch_size=patch_size) + + # get all patches and targets + patches = [patch for patch, _ in patch_generator] + + # check patch shape + for patch in patches: + # account for C dimension + assert patch.shape[1:] == patch_size + + # get top_left index in the original array + if is_3D: + ind = np.where(array == patch[0, 0, 0, 0]) + else: + ind = np.where(array == patch[0, 0, 0]) + + top_left.append(np.array(ind)) + + # check randomness + coords = np.array(top_left).squeeze() + assert coords.min() == 0 + assert coords.max() == max(array.shape) - max(patch_size) + assert len(np.unique(coords, axis=0)) >= 0.7 * np.prod(shape) / np.prod(patch_size) + + +# @pytest.mark.parametrize( +# "patch_size", +# [ +# (2, 2), +# (4, 2), +# (4, 8), +# (8, 8), +# ], +# ) +# def test_extract_patches_random_2d(array_2D, patch_size): +# """Test extracting patches randomly in 2D.""" +# check_extract_patches_random(array_2D, "SYX", patch_size) + + +# @pytest.mark.parametrize( +# "patch_size", +# [ +# (2, 2), +# (4, 2), +# (4, 8), +# (8, 8), +# ], +# ) +# def test_extract_patches_random_supervised_2d(array_2D, patch_size): +# """Test extracting patches randomly in 2D.""" +# check_extract_patches_random( +# array_2D, +# "SYX", +# patch_size, +# target=array_2D +# ) + + +# @pytest.mark.parametrize( +# "patch_size", +# [ +# (2, 2, 4), +# (4, 2, 2), +# (2, 8, 4), +# (4, 8, 8), +# ], +# ) +# def test_extract_patches_random_3d(array_3D, patch_size): +# """Test extracting patches randomly in 3D. 
+ +# The 3D array is a fixture of shape (1, 8, 16, 16).""" +# check_extract_patches_random(array_3D, "SZYX", patch_size) diff --git a/tests/dataset/patching/test_sequential_patching.py b/tests/dataset/patching/test_sequential_patching.py new file mode 100644 index 00000000..33ec902e --- /dev/null +++ b/tests/dataset/patching/test_sequential_patching.py @@ -0,0 +1,138 @@ +import numpy as np +import pytest + +from careamics.dataset.patching.sequential_patching import ( + _compute_number_of_patches, + _compute_overlap, + _compute_patch_steps, + _compute_patch_views, + extract_patches_sequential, +) + + +def check_extract_patches_sequential(array: np.ndarray, patch_size: tuple): + """Check that the patches are extracted correctly. + + The array should have been generated using np.arange and np.reshape.""" + patches, _ = extract_patches_sequential(array, patch_size=patch_size) + + # check patch shape + assert patches.shape[2:] == patch_size + + # check that all values are covered by the patches + n_max = np.prod(array.shape) # maximum value in the array + unique = np.unique(np.array(patches)) # unique values in the patches + assert len(unique) == n_max + + +@pytest.mark.parametrize( + "patch_size", + [ + (2, 2), + (4, 2), + (4, 8), + (8, 8), + ], +) +def test_extract_patches_sequential_2d(array_2D, patch_size): + """Test extracting patches sequentially in 2D.""" + check_extract_patches_sequential(array_2D, patch_size) + + +@pytest.mark.parametrize( + "patch_size", + [ + (2, 2, 4), + (4, 2, 2), + (2, 8, 4), + (4, 8, 8), + ], +) +def test_extract_patches_sequential_3d(array_3D, patch_size): + """Test extracting patches sequentially in 3D. + + The 3D array is a fixture of shape (1, 8, 16, 16).""" + # TODO changed the fixture to (1, 8, 16, 16), uneven shape doesnt work. 
We need to + # discuss the function or the test cases + check_extract_patches_sequential(array_3D, patch_size) + + +@pytest.mark.parametrize( + "shape, patch_sizes, expected", + [ + ((1, 3, 10, 10), (1, 3, 10, 5), (1, 1, 1, 2)), + ((1, 1, 9, 9), (1, 1, 4, 3), (1, 1, 3, 3)), + ((1, 3, 10, 9), (1, 3, 3, 5), (1, 1, 4, 2)), + ((1, 1, 5, 9, 10), (1, 1, 2, 3, 5), (1, 1, 3, 3, 2)), + ], +) +def test_compute_number_of_patches(shape, patch_sizes, expected): + """Test computing number of patches""" + assert _compute_number_of_patches(shape, patch_sizes) == expected + + +@pytest.mark.parametrize( + "shape, patch_sizes, expected", + [ + ((1, 3, 10, 10), (1, 3, 10, 5), (0, 0, 0, 0)), + ((1, 1, 9, 9), (1, 1, 4, 3), (0, 0, 2, 0)), + ((1, 3, 10, 10, 9), (1, 3, 2, 3, 5), (0, 0, 0, 1, 1)), + ], +) +def test_compute_overlap(shape, patch_sizes, expected): + """Test computing overlap between patches""" + assert _compute_overlap(shape, patch_sizes) == expected + + +@pytest.mark.parametrize("dims", [2, 3]) +@pytest.mark.parametrize("patch_size", [2, 3]) +@pytest.mark.parametrize("overlap", [0, 1, 4]) +def test_compute_patch_steps(dims, patch_size, overlap): + """Test computing patch steps""" + patch_sizes = (patch_size,) * dims + overlaps = (overlap,) * dims + expected = (min(patch_size - overlap, patch_size),) * dims + + assert _compute_patch_steps(patch_sizes, overlaps) == expected + + +def check_compute_reshaped_view(array: np.ndarray, window_shape, steps): + """Check the number of patches""" + + output_shape = (-1, *window_shape) + + # compute views + output = _compute_patch_views(array, window_shape, steps, output_shape) + + # check the number of patches + n_patches = [ + np.ceil((array.shape[i] - window_shape[i] + 1) / steps[i]).astype(int) + for i in range(len(window_shape)) + ] + assert output.shape == (np.prod(n_patches), *window_shape) + + +@pytest.mark.parametrize( + "window_shape, steps", + [ + ((1, 3, 5, 5), (1, 1, 1, 1)), + ((1, 3, 5, 5), (1, 1, 2, 3)), + ((1, 3, 5, 7), (1, 
1, 1, 1)), + ], +) +def test_compute_reshaped_view_2d(array_2D, window_shape, steps): + """Test computing reshaped view of an array of shape (1, 10, 9).""" + check_compute_reshaped_view(array_2D, window_shape, steps) + + +@pytest.mark.parametrize( + "window_shape, steps", + [ + ((1, 3, 1, 5, 5), (1, 1, 2, 1, 2)), + ((1, 3, 2, 5, 5), (1, 1, 2, 3, 4)), + ((1, 3, 3, 7, 8), (1, 1, 1, 1, 3)), + ], +) +def test_compute_reshaped_view_3d(array_3D, window_shape, steps): + """Test computing reshaped view of an array of shape (1, 5, 10, 9).""" + check_compute_reshaped_view(array_3D, window_shape, steps) diff --git a/tests/dataset/patching/test_tiled_patching.py b/tests/dataset/patching/test_tiled_patching.py new file mode 100644 index 00000000..a7e135d4 --- /dev/null +++ b/tests/dataset/patching/test_tiled_patching.py @@ -0,0 +1,120 @@ +import numpy as np +import pytest + +from careamics.config.tile_information import TileInformation +from careamics.dataset.patching.tiled_patching import ( + _compute_crop_and_stitch_coords_1d, + extract_tiles, +) + + +def check_extract_tiles(array: np.ndarray, tile_size, overlaps): + """Test extracting patches randomly.""" + tile_data_generator = extract_tiles(array, tile_size, overlaps) + + tiles = [] + all_overlap_crop_coords = [] + all_stitch_coords = [] + + # Assemble all tiles and their respective coordinates + for tile_data in tile_data_generator: + tile = tile_data[0] + + tile_info: TileInformation = tile_data[1] + overlap_crop_coords = tile_info.overlap_crop_coords + stitch_coords = tile_info.stitch_coords + + # add data to lists + tiles.append(tile) + all_overlap_crop_coords.append(overlap_crop_coords) + all_stitch_coords.append(stitch_coords) + + # check tile shape, ignore sample dimension + assert tile.shape[1:] == tile_size + assert len(overlap_crop_coords) == len(stitch_coords) == len(tile_size) + + # check that each tile has a unique set of coordinates + assert len(tiles) == len(all_overlap_crop_coords) == len(all_stitch_coords) 
+ + # check that all values are covered by the tiles + n_max = np.prod(array.shape) # maximum value in the array + unique = np.unique(np.array(tiles)) # unique values in the patches + assert len(unique) >= n_max + + +@pytest.mark.parametrize( + "tile_size, overlaps", + [ + ((4, 4), (2, 2)), + ((8, 8), (4, 4)), + ], +) +def test_extract_tiles_2d(array_2D, tile_size, overlaps): + """Test extracting tiles for prediction in 2D.""" + check_extract_tiles(array_2D, tile_size, overlaps) + + +@pytest.mark.parametrize( + "tile_size, overlaps", + [ + ((4, 4, 4), (2, 2, 2)), + ((8, 8, 8), (4, 4, 4)), + ], +) +def test_extract_tiles_3d(array_3D, tile_size, overlaps): + """Test extracting tiles for prediction in 3D. + + The 3D array is a fixture of shape (1, 8, 16, 16).""" + check_extract_tiles(array_3D, tile_size, overlaps) + + +@pytest.mark.parametrize("axis_size", [32, 35, 40]) +@pytest.mark.parametrize("patch_size, overlap", [(16, 4), (8, 6), (16, 8), (32, 24)]) +def test_compute_crop_and_stitch_coords_1d(axis_size, patch_size, overlap): + ( + crop_coords, + stitch_coords, + overlap_crop_coords, + ) = _compute_crop_and_stitch_coords_1d(axis_size, patch_size, overlap) + + # check that the number of patches is sufficient to cover the whole axis and that + # the number of coordinates is + # the same for all three coordinate groups + num_patches = np.ceil((axis_size - overlap) / (patch_size - overlap)).astype(int) + assert ( + len(crop_coords) + == len(stitch_coords) + == len(overlap_crop_coords) + == num_patches + ) + # check if 0 is the first coordinate, axis_size is last coordinate in all three + # coordinate groups + assert all( + all((group[0][0] == 0, group[-1][1] == axis_size)) + for group in [crop_coords, stitch_coords] + ) + # check if neighboring stitch coordinates are equal + assert all( + stitch_coords[i][1] == stitch_coords[i + 1][0] + for i in range(len(stitch_coords) - 1) + ) + + # check that the crop coordinates cover the whole axis + assert ( + 
np.sum(np.array(crop_coords)[:, 1] - np.array(crop_coords)[:, 0]) + == patch_size * num_patches + ) + + # check that the overlap crop coordinates cover the whole axis + assert ( + np.sum( + np.array(overlap_crop_coords)[:, 1] - np.array(overlap_crop_coords)[:, 0] + ) + == axis_size + ) + + # check that shape of all cropped tiles is equal + assert np.array_equal( + np.array(overlap_crop_coords)[:, 1] - np.array(overlap_crop_coords)[:, 0], + np.array(stitch_coords)[:, 1] - np.array(stitch_coords)[:, 0], + ) diff --git a/tests/dataset/test_dataset_utils.py b/tests/dataset/test_dataset_utils.py deleted file mode 100644 index b4e35c0d..00000000 --- a/tests/dataset/test_dataset_utils.py +++ /dev/null @@ -1,149 +0,0 @@ -import numpy as np -import pytest - -from careamics.dataset.patching import ( - _compute_crop_and_stitch_coords_1d, - _compute_number_of_patches, - _compute_overlap, - _compute_patch_steps, - _compute_reshaped_view, -) - - -@pytest.mark.parametrize( - "shape, patch_sizes, expected", - [ - ((1, 10, 10), (10, 5), (1, 2)), - ((1, 9, 9), (4, 3), (3, 3)), - ((1, 10, 9), (3, 5), (4, 2)), - ((1, 5, 9, 10), (2, 3, 5), (3, 3, 2)), - ], -) -def test_compute_number_of_patches(shape, patch_sizes, expected): - """Test computing number of patches""" - arr = np.ones(shape) - - assert _compute_number_of_patches(arr, patch_sizes) == expected - - -@pytest.mark.parametrize( - "shape, patch_sizes, expected", - [ - ((1, 10, 10), (10, 5), (0, 0)), - ((1, 9, 9), (4, 3), (2, 0)), - ((1, 10, 9), (3, 5), (1, 1)), - ], -) -def test_compute_overlap(shape, patch_sizes, expected): - """Test computing overlap between patches""" - arr = np.ones(shape) - - assert _compute_overlap(arr, patch_sizes) == expected - - -@pytest.mark.parametrize("dims", [2, 3]) -@pytest.mark.parametrize("patch_size", [2, 3]) -@pytest.mark.parametrize("overlap", [0, 1, 4]) -def test_compute_patch_steps(dims, patch_size, overlap): - """Test computing patch steps""" - patch_sizes = (patch_size,) * dims - overlaps = 
(overlap,) * dims - expected = (min(patch_size - overlap, patch_size),) * dims - - assert _compute_patch_steps(patch_sizes, overlaps) == expected - - -def check_compute_reshaped_view(array, window_shape, steps): - """Check the number of patches""" - - win = (1, *window_shape) - step = (1, *steps) - output_shape = (-1, *window_shape) - - # compute views - output = _compute_reshaped_view(array, win, step, output_shape) - - # check the number of patches - n_patches = [ - np.ceil((array.shape[1 + i] - window_shape[i] + 1) / steps[i]).astype(int) - for i in range(len(window_shape)) - ] - assert output.shape == (np.prod(n_patches), *window_shape) - - -@pytest.mark.parametrize("axis_size", [32, 35, 40]) -@pytest.mark.parametrize("patch_size, overlap", [(16, 4), (8, 6), (16, 8), (32, 24)]) -def test_compute_crop_and_stitch_coords_1d(axis_size, patch_size, overlap): - ( - crop_coords, - stitch_coords, - overlap_crop_coords, - ) = _compute_crop_and_stitch_coords_1d(axis_size, patch_size, overlap) - - # check that the number of patches is sufficient to cover the whole axis and that - # the number of coordinates is - # the same for all three coordinate groups - num_patches = np.ceil((axis_size - overlap) / (patch_size - overlap)).astype(int) - assert ( - len(crop_coords) - == len(stitch_coords) - == len(overlap_crop_coords) - == num_patches - ) - # check if 0 is the first coordinate, axis_size is last coordinate in all three - # coordinate groups - assert all( - all((group[0][0] == 0, group[-1][1] == axis_size)) - for group in [crop_coords, stitch_coords] - ) - # check if neighboring stitch coordinates are equal - assert all( - stitch_coords[i][1] == stitch_coords[i + 1][0] - for i in range(len(stitch_coords) - 1) - ) - - # check that the crop coordinates cover the whole axis - assert ( - np.sum(np.array(crop_coords)[:, 1] - np.array(crop_coords)[:, 0]) - == patch_size * num_patches - ) - - # check that the overlap crop coordinates cover the whole axis - assert ( - np.sum( - 
np.array(overlap_crop_coords)[:, 1] - np.array(overlap_crop_coords)[:, 0] - ) - == axis_size - ) - - # check that shape of all cropped tiles is equal - assert np.array_equal( - np.array(overlap_crop_coords)[:, 1] - np.array(overlap_crop_coords)[:, 0], - np.array(stitch_coords)[:, 1] - np.array(stitch_coords)[:, 0], - ) - - -@pytest.mark.parametrize( - "window_shape, steps", - [ - ((5, 5), (1, 1)), - ((5, 5), (2, 3)), - ((5, 7), (1, 1)), - ], -) -def test_compute_reshaped_view_2d(array_2D, window_shape, steps): - """Test computing reshaped view of an array of shape (1, 10, 9).""" - check_compute_reshaped_view(array_2D, window_shape, steps) - - -@pytest.mark.parametrize( - "window_shape, steps", - [ - ((1, 5, 5), (2, 1, 2)), - ((2, 5, 5), (2, 3, 4)), - ((3, 7, 8), (1, 1, 3)), - ], -) -def test_compute_reshaped_view_3d(array_3D, window_shape, steps): - """Test computing reshaped view of an array of shape (1, 5, 10, 9).""" - check_compute_reshaped_view(array_3D, window_shape, steps) diff --git a/tests/dataset/test_in_memory_dataset.py b/tests/dataset/test_in_memory_dataset.py new file mode 100644 index 00000000..8d2c2738 --- /dev/null +++ b/tests/dataset/test_in_memory_dataset.py @@ -0,0 +1,112 @@ +import numpy as np +import pytest +import tifffile + +from careamics.config import DataModel +from careamics.config.support import SupportedData +from careamics.dataset import InMemoryDataset + + +def test_number_of_patches(ordered_array): + """Test the number of patches extracted from InMemoryDataset.""" + # create array + array = ordered_array((20, 20)) + + # create config + config_dict = { + "data_type": SupportedData.ARRAY.value, + "patch_size": [8, 8], + "axes": "YX", + } + config = DataModel(**config_dict) + + # create dataset + dataset = InMemoryDataset( + data_config=config, + inputs=array, + ) + + # check number of patches + assert len(dataset) == dataset.data.shape[0] + + +def test_compute_mean_std_transform(ordered_array): + """Test that mean and std are computed 
and correctly added to the configuration + and transform.""" + pass + + +@pytest.mark.parametrize("percentage", [0.1, 0.6]) +def test_extracting_val_array(ordered_array, percentage): + """Test extracting a validation set patches from InMemoryDataset.""" + # create array + array = ordered_array((32, 32)) + + # create config + config_dict = { + "data_type": SupportedData.ARRAY.value, + "patch_size": [8, 8], + "axes": "YX", + } + config = DataModel(**config_dict) + + # create dataset + dataset = InMemoryDataset( + data_config=config, + inputs=array, + ) + + # compute number of patches + total_n_patches = len(dataset) + minimum_patches = 5 + n_patches = max(round(percentage * total_n_patches), minimum_patches) + + # extract datset + valset = dataset.split_dataset(percentage, minimum_patches) + + # check number of patches + assert len(valset) == n_patches + assert len(dataset) == total_n_patches - n_patches + + # check that none of the validation patch values are in the original dataset + assert np.in1d(valset.data, dataset.data).sum() == 0 + + +@pytest.mark.parametrize("percentage", [0.1, 0.6]) +def test_extracting_val_files(tmp_path, ordered_array, percentage): + """Test extracting a validation set patches from InMemoryDataset.""" + # create array + array = ordered_array((32, 32)) + + # save array to file + file_path = tmp_path / "array.tif" + tifffile.imwrite(file_path, array) + + # create config + config_dict = { + "data_type": SupportedData.ARRAY.value, + "patch_size": [8, 8], + "axes": "YX", + } + config = DataModel(**config_dict) + + # create dataset + dataset = InMemoryDataset( + data_config=config, + inputs=[file_path], + ) + + # compute number of patches + total_n_patches = len(dataset) + minimum_patches = 5 + n_patches = max(round(percentage * total_n_patches), minimum_patches) + + # extract datset + valset = dataset.split_dataset(percentage, minimum_patches) + + # check number of patches + assert len(valset) == n_patches + assert len(dataset) == 
total_n_patches - n_patches + + # check that none of the validation patch values are in the original dataset + assert np.in1d(valset.data, dataset.data).sum() == 0 diff --git a/tests/dataset/test_iterable_dataset.py b/tests/dataset/test_iterable_dataset.py new file mode 100644 index 00000000..5347310c --- /dev/null +++ b/tests/dataset/test_iterable_dataset.py @@ -0,0 +1,141 @@ +import numpy as np +import pytest +import tifffile + +from careamics.config import DataModel +from careamics.config.support import SupportedData +from careamics.dataset import PathIterableDataset +from careamics.dataset.dataset_utils import read_tiff + + +@pytest.mark.parametrize( + "shape", + [ + # 2D + (32, 32), + # 3D + (32, 32, 32), + ], +) +def test_number_of_files(tmp_path, ordered_array, shape): + """Test number of files in PathIterableDataset.""" + # create array + array_size = 32 + patch_size = 8 + n_files = 3 + factor = len(shape) + axes = "YX" if factor == 2 else "ZYX" + patch_sizes = [patch_size] * factor + array = ordered_array(shape) + + # save three files + files = [] + for i in range(n_files): + file = tmp_path / f"array{i}.tif" + tifffile.imwrite(file, array) + files.append(file) + + # create config + config_dict = { + "data_type": SupportedData.TIFF.value, + "patch_size": patch_sizes, + "axes": axes, + } + config = DataModel(**config_dict) + + # create dataset + dataset = PathIterableDataset( + data_config=config, src_files=files, read_source_func=read_tiff + ) + + # check number of files + assert dataset.data_files == files + + # iterate over dataset + patches = list(dataset) + assert len(patches) == n_files * (array_size / patch_size) ** factor + + +def test_read_function(tmp_path, ordered_array): + """Test reading files in PathIterableDataset using a custom read function.""" + + # read function for .npy files + def read_npy(file_path, *args, **kwargs): + return np.load(file_path) + + array_size = 32 + patch_size = 8 + n_files = 3 + patch_sizes = [patch_size] * 2 + + # 
create array + array: np.ndarray = ordered_array((n_files, array_size, array_size)) + + # save each plane in a single .npy file + files = [] + for i in range(array.shape[0]): + file_path = tmp_path / f"array{i}.npy" + np.save(file_path, array[i]) + files.append(file_path) + + # create config + config_dict = { + "data_type": SupportedData.CUSTOM.value, + "patch_size": patch_sizes, + "axes": "YX", + } + config = DataModel(**config_dict) + + # create dataset + dataset = PathIterableDataset( + data_config=config, + src_files=files, + read_source_func=read_npy, + ) + assert dataset.data_files == files + + # iterate over dataset + patches = list(dataset) + assert len(patches) == n_files * (array_size / patch_size) ** 2 + + +@pytest.mark.parametrize("percentage", [0.1, 0.6]) +def test_extracting_val_files(tmp_path, ordered_array, percentage): + """Test extracting a validation set patches from PathIterableDataset.""" + # create array + array = ordered_array((20, 20)) + + # save array to 25 files + files = [] + for i in range(25): + file_path = tmp_path / f"array{i}.tif" + tifffile.imwrite(file_path, array) + files.append(file_path) + + # create config + config_dict = { + "data_type": SupportedData.TIFF.value, + "patch_size": [8, 8], + "axes": "YX", + } + config = DataModel(**config_dict) + + # create dataset + dataset = PathIterableDataset( + data_config=config, src_files=files, read_source_func=read_tiff + ) + + # compute number of patches + total_n_files = dataset.get_number_of_files() + minimum_files = 5 + n_files = max(round(percentage * total_n_files), minimum_files) + + # extract datset + valset = dataset.split_dataset(percentage, minimum_files) + + # check number of patches + assert valset.get_number_of_files() == n_files + assert dataset.get_number_of_files() == total_n_files - n_files + + # check that none of the validation files are in the original dataset + assert set(valset.data_files).isdisjoint(set(dataset.data_files)) diff --git 
a/tests/dataset/test_patching.py b/tests/dataset/test_patching.py deleted file mode 100644 index c6e34cb3..00000000 --- a/tests/dataset/test_patching.py +++ /dev/null @@ -1,211 +0,0 @@ -import numpy as np -import pytest - -from careamics.dataset.patching import ( - ExtractionStrategy, - _extract_patches_random, - _extract_patches_sequential, - _extract_tiles, - _patches_sanity_check, - generate_patches, -) - - -def test_generate_patches_tiled_without_overlap(): - """Check that generating tiled patches fails without overlap.""" - with pytest.raises(ValueError): - generate_patches( - np.zeros((1, 32, 32)), - patch_extraction_method=ExtractionStrategy.TILED, - patch_size=(8, 8), - patch_overlap=None, - ) - - -def check_extract_patches_sequential(array, patch_size): - """Check that the patches are extracted correctly. - - The array should have been generated using np.arange and np.reshape.""" - patch_generator = _extract_patches_sequential(array, patch_size) - - # check patch shape - patches = [] - for patch in patch_generator: - patches.append(patch) - assert patch.shape[1:] == patch_size - - # check that all values are covered by the patches - n_max = np.prod(array.shape) # maximum value in the array - unique = np.unique(np.array(patches)) # unique values in the patches - assert len(unique) == n_max - - -def check_extract_patches_random(array, patch_size): - """Check that the patches are extracted correctly. 
- - The array should have been generated using np.arange and np.reshape.""" - patch_generator = _extract_patches_random(array, patch_size) - - # check patch shape - patches = [] - for patch in patch_generator: - patches.append(patch) - assert patch.shape == patch_size - - -def check_extract_tiles(array, tile_size, overlaps): - """Test extracting patches randomly.""" - tile_data_generator = _extract_tiles(array, tile_size, overlaps) - - tiles = [] - all_overlap_crop_coords = [] - all_stitch_coords = [] - # Assemble all tiles and their respective coordinates - for tile_data in tile_data_generator: - tile, _, _, overlap_crop_coords, stitch_coords = tile_data - tiles.append(tile) - all_overlap_crop_coords.append(overlap_crop_coords) - all_stitch_coords.append(stitch_coords) - - # check tile shape, ignore sample dimension - assert tile.shape[1:] == tile_size - assert len(overlap_crop_coords) == len(stitch_coords) == len(tile_size) - - # check that each tile has a unique set of coordinates - assert len(tiles) == len(all_overlap_crop_coords) == len(all_stitch_coords) - - # check that all values are covered by the tiles - n_max = np.prod(array.shape) # maximum value in the array - unique = np.unique(np.array(tiles)) # unique values in the patches - assert len(unique) >= n_max - - -@pytest.mark.parametrize( - "arr_shape, patch_size", - [ - ((1, 8, 8), (2, 2)), - ((1, 8, 8, 8), (2, 2, 2)), - ], -) -def test_patches_sanity_check(arr_shape, patch_size): - arr = np.zeros(arr_shape) - is_3d_patch = len(patch_size) == 3 - # check if the patch is 2D or 3D. 
Subtract 1 because the first dimension is sample - _patches_sanity_check(arr, patch_size, is_3d_patch) - - -@pytest.mark.parametrize( - "arr_shape, patch_size", - [ - # Wrong number of dimensions 2D - ((10, 10), (5, 5)), - # minimum 3 dimensions CYX - ((10, 10), (5, 5, 5)), - ((1, 1, 10, 10), (5, 5)), - # Wrong number of dimensions 3D - ((10, 10, 10), (5, 5, 5)), - ((1, 10, 10, 10), (5, 5)), - ((1, 1, 10, 10, 10), (5, 5)), - ((1, 1, 10, 10, 10), (5, 5, 5)), - # Wrong z patch size - ((1, 10, 10), (5, 5, 5)), - ((10, 10, 10), (10, 5, 5)), - # Wrong YX patch sizes - ((1, 10, 10), (12, 5)), - ((1, 10, 10), (5, 11)), - ], -) -def test_patches_sanity_check_invalid_cases(arr_shape, patch_size): - arr = np.zeros(arr_shape) - is_3d_patch = len(patch_size) == 3 - # check if the patch is 2D or 3D. Subtract 1 because the first dimension is sample - with pytest.raises(ValueError): - _patches_sanity_check(arr, patch_size, is_3d_patch) - - -@pytest.mark.parametrize( - "patch_size", - [ - (2, 2), - (4, 2), - (4, 8), - (8, 8), - ], -) -def test_extract_patches_sequential_2d(array_2D, patch_size): - """Test extracting patches sequentially in 2D.""" - check_extract_patches_sequential(array_2D, patch_size) - - -@pytest.mark.parametrize( - "patch_size", - [ - (2, 2, 4), - (4, 2, 2), - (2, 8, 4), - (4, 8, 8), - ], -) -def test_extract_patches_sequential_3d(array_3D, patch_size): - """Test extracting patches sequentially in 3D. - - The 3D array is a fixture of shape (1, 8, 16, 16).""" - # TODO changed the fixture to (1, 8, 16, 16), uneven shape doesnt work. 
We need to - # discuss the function or the test cases - check_extract_patches_sequential(array_3D, patch_size) - - -@pytest.mark.parametrize( - "patch_size", - [ - (2, 2), - (4, 2), - (4, 8), - (8, 8), - ], -) -def test_extract_patches_random_2d(array_2D, patch_size): - """Test extracting patches randomly in 2D.""" - check_extract_patches_random(array_2D, patch_size) - - -@pytest.mark.parametrize( - "patch_size", - [ - (2, 2, 4), - (4, 2, 2), - (2, 8, 4), - (4, 8, 8), - ], -) -def test_extract_patches_random_3d(array_3D, patch_size): - """Test extracting patches randomly in 3D. - - The 3D array is a fixture of shape (1, 8, 16, 16).""" - check_extract_patches_random(array_3D, patch_size) - - -@pytest.mark.parametrize( - "tile_size, overlaps", - [ - ((4, 4), (2, 2)), - ((8, 8), (4, 4)), - ], -) -def test_extract_tiles_2d(array_2D, tile_size, overlaps): - """Test extracting tiles for prediction in 2D.""" - check_extract_tiles(array_2D, tile_size, overlaps) - - -@pytest.mark.parametrize( - "tile_size, overlaps", - [ - ((4, 4, 4), (2, 2, 2)), - ((8, 8, 8), (4, 4, 4)), - ], -) -def test_extract_tiles_3d(array_3D, tile_size, overlaps): - """Test extracting tiles for prediction in 3D. 
- - The 3D array is a fixture of shape (1, 8, 16, 16).""" - check_extract_tiles(array_3D, tile_size, overlaps) diff --git a/tests/dataset/test_tiff_dataset.py b/tests/dataset/test_tiff_dataset.py deleted file mode 100644 index 8b82afed..00000000 --- a/tests/dataset/test_tiff_dataset.py +++ /dev/null @@ -1,57 +0,0 @@ -import numpy as np -import pytest -import tifffile - -from careamics.dataset.extraction_strategy import ExtractionStrategy -from careamics.dataset.tiff_dataset import TiffDataset - - -@pytest.mark.parametrize( - "shape, axes", - [ - ((16, 16), "YX"), - ((8, 16, 16), "ZYX"), - ((2, 16, 16), "SYX"), - ], -) -def test_tiff_dataset(tmp_path, ordered_array, shape, axes): - """Test loading tiffs.""" - array = ordered_array(shape) - array2 = array * 2 - - # save arrays - tifffile.imwrite(tmp_path / "test1.tif", array) - tifffile.imwrite(tmp_path / "test2.tif", array2) - - # create dataset - dataset = TiffDataset( - data_path=tmp_path, - data_format="tif", - axes=axes, - patch_extraction_method=ExtractionStrategy.SEQUENTIAL, - ) - - # check mean and std - all_arrays = np.concatenate([array, array2], axis=0) - mean = np.mean(all_arrays) - std = np.mean([np.std(array), np.std(array2)]) - - assert dataset.mean == pytest.approx(mean) - assert dataset.std == pytest.approx(std) - - -def test_tiff_dataset_not_dir(tmp_path, ordered_array): - """Test loading tiffs.""" - array = ordered_array((16, 16)) - - # save array - tifffile.imwrite(tmp_path / "test1.tif", array) - - # create dataset - with pytest.raises(ValueError): - TiffDataset( - data_path=tmp_path / "test1.tif", - data_format="tif", - axes="YX", - patch_extraction_method=ExtractionStrategy.SEQUENTIAL, - ) diff --git a/tests/manipulation/test_pixel_manipulation.py b/tests/manipulation/test_pixel_manipulation.py deleted file mode 100644 index 9c9b2961..00000000 --- a/tests/manipulation/test_pixel_manipulation.py +++ /dev/null @@ -1,81 +0,0 @@ -import numpy as np -import pytest - -from 
careamics.manipulation.pixel_manipulation import ( - default_manipulate, - get_stratified_coords, -) - - -@pytest.mark.parametrize( - "mask_pixel_perc, shape, num_iterations", - [(0.4, (32, 32), 1000), (0.4, (10, 10, 10), 1000)], -) -def test_get_stratified_coords(mask_pixel_perc, shape, num_iterations): - """Test the get_stratified_coords function. - - Ensure that the array of coordinates is randomly distributed across the - image and doesn't demonstrate any strong pattern. - """ - # Define the dummy array - array = np.zeros(shape) - - # Iterate over the number of iterations and add the coordinates. This is an MC - # simulation to ensure that the coordinates are randomly distributed and not - # biased towards any particular region. - for _ in range(num_iterations): - # Get the coordinates of the pixels to be masked - coords = get_stratified_coords(mask_pixel_perc, shape) - # Check every pair in the array of coordinates - for coord_pair in coords: - # Check that the coordinates are of the same shape as the patch - assert len(coord_pair) == len(shape) - # Check that the coordinates are positive values - assert all(coord_pair) >= 0 - # Check that the coordinates are within the shape of the array - assert [c <= s for c, s in zip(coord_pair, shape)] - - # Add the 1 to the every coordinate location. - array[tuple(np.array(coords).T.tolist())] += 1 - - # Ensure that there's no strong pattern in the array and sufficient number of - # pixels is masked. - assert np.sum(array == 0) < np.sum(shape) - - -def test_default_manipulate_2d(array_2D): - """Test the default_manipulate function. - - Ensure that the function returns an array of the same shape as the input. - """ - # Get manipulated patch, original patch and mask - patch, original_patch, mask = default_manipulate(array_2D, 0.5) - - # Add sample dimension to the moch input array - array_2D = array_2D[np.newaxis, ...] 
- # Check that the shapes of the arrays are the same - assert patch.shape == array_2D.shape - assert original_patch.shape == array_2D.shape - assert mask.shape == array_2D.shape - - # Check that the manipulated patch is different from the original patch - assert not np.array_equal(patch, original_patch) - - -def test_default_manipulate_3d(array_3D): - """Test the default_manipulate function. - - Ensure that the function returns an array of the same shape as the input. - """ - # Get manipulated patch, original patch and mask - patch, original_patch, mask = default_manipulate(array_3D, 0.5) - - # Add sample dimension to the moch input array - array_3D = array_3D[np.newaxis, ...] - # Check that the shapes of the arrays are the same - assert patch.shape == array_3D.shape - assert original_patch.shape == array_3D.shape - assert mask.shape == array_3D.shape - - # Check that the manipulated patch is different from the original patch - assert not np.array_equal(patch, original_patch) diff --git a/tests/model/test_model.py b/tests/model/test_model.py deleted file mode 100644 index 201c3450..00000000 --- a/tests/model/test_model.py +++ /dev/null @@ -1,26 +0,0 @@ -import pytest - -from careamics.models.unet import UNet - - -@pytest.mark.parametrize("depth", [1, 3, 5]) -def test_unet_depth(depth): - """Test that the UNet has the correct number of down and up convs - with respect to the depth.""" - model = UNet(conv_dim=2, depth=depth) - - # check that encoder has the right number of down convs - counter_down = 0 - for _, layer in model.encoder.encoder_blocks.named_children(): - if type(layer).__name__ == "Conv_Block": - counter_down += 1 - - assert counter_down == depth - - # check that decoder has the right number of up convs - counter_up = 0 - for _, layer in model.decoder.decoder_blocks.named_children(): - if type(layer).__name__ == "Conv_Block": - counter_up += 1 - - assert counter_up == depth diff --git a/tests/model/test_model_factory.py 
b/tests/model/test_model_factory.py deleted file mode 100644 index 142d2a17..00000000 --- a/tests/model/test_model_factory.py +++ /dev/null @@ -1,2 +0,0 @@ -# TODO test model creation from model zoo -# TODO test import model... diff --git a/tests/model_io/test_bmz_io.py b/tests/model_io/test_bmz_io.py new file mode 100644 index 00000000..a031aa0d --- /dev/null +++ b/tests/model_io/test_bmz_io.py @@ -0,0 +1,66 @@ +import numpy as np +from torch import Tensor + +from careamics import CAREamist +from careamics.model_io import export_to_bmz, load_pretrained +from careamics.model_io.bmz_io import _export_state_dict, _load_state_dict + + +def test_state_dict_io(tmp_path, pre_trained): + """Test exporting and loading a state dict.""" + # training data + train_array = np.ones((32, 32), dtype=np.float32) + path = tmp_path / "model.pth" + + # instantiate CAREamist + careamist = CAREamist(source=pre_trained, work_dir=tmp_path) + + # predict (no tiling and no tta) + predicted = careamist.predict(train_array, tta_transforms=False) + + # save model + _export_state_dict(careamist.model, path) + assert path.exists() + + # load model + _load_state_dict(careamist.model, path) + + # predict (no tiling and no tta) + predicted_loaded = careamist.predict(train_array, tta_transforms=False) + assert (predicted_loaded == predicted).all() + + +def test_bmz_io(tmp_path, pre_trained): + """Test exporting and loading to the BMZ.""" + # training data + train_array = np.ones((32, 32), dtype=np.float32) + + # instantiate CAREamist + careamist = CAREamist(source=pre_trained, work_dir=tmp_path) + + # predict (no tiling and no tta) + predicted = careamist.predict(train_array, tta_transforms=False) + + # export to BioImage Model Zoo + path = tmp_path / "model.zip" + export_to_bmz( + model=careamist.model, + config=careamist.cfg, + path=path, + name="TopModel", + general_description="A model that just walked in.", + authors=[{"name": "Amod", "affiliation": "El"}], + input_array=train_array[np.newaxis, 
np.newaxis, ...], + output_array=predicted, + ) + assert path.exists() + + # load model + model, config = load_pretrained(path) + assert config == careamist.cfg + + # compare predictions + torch_array = Tensor(train_array[np.newaxis, np.newaxis, ...]) + predicted = careamist.model.forward(torch_array).detach().numpy().squeeze() + predicted_loaded = model.forward(torch_array).detach().numpy().squeeze() + assert (predicted_loaded == predicted).all() diff --git a/tests/models/test_model_factory.py b/tests/models/test_model_factory.py new file mode 100644 index 00000000..f4526aae --- /dev/null +++ b/tests/models/test_model_factory.py @@ -0,0 +1,63 @@ +import pytest +from torch import nn, ones + +from careamics.config.architectures import ( + CustomModel, + UNetModel, + VAEModel, + register_model, +) +from careamics.config.support import SupportedArchitecture +from careamics.models import UNet, model_factory + + +def test_model_registry_unet(): + """Test that""" + model_config = { + "architecture": "UNet", + } + + # instantiate model + model = model_factory(UNetModel(**model_config)) + assert isinstance(model, UNet) + + +def test_model_registry_custom(): + """Test that a custom model can be retrieved and instantiated.""" + + # create and register a custom model + @register_model(name="linear_model") + class LinearModel(nn.Module): + def __init__(self, in_features, out_features): + super().__init__() + + self.in_features = in_features + self.out_features = out_features + self.weight = nn.Parameter(ones(in_features, out_features)) + self.bias = nn.Parameter(ones(out_features)) + + def forward(self, input): + return (input @ self.weight) + self.bias + + model_config = { + "architecture": SupportedArchitecture.CUSTOM.value, + "name": "linear_model", + "in_features": 10, + "out_features": 5, + } + + # instantiate model + model = model_factory(CustomModel(**model_config)) + assert isinstance(model, LinearModel) + assert model.in_features == 10 + assert model.out_features == 5 
+ + +def test_vae(): + """Test that VAE are currently not supported.""" + model_config = { + "architecture": SupportedArchitecture.VAE.value, + } + + with pytest.raises(NotImplementedError): + model_factory(VAEModel(**model_config)) diff --git a/tests/models/test_unet.py b/tests/models/test_unet.py new file mode 100644 index 00000000..a29f1103 --- /dev/null +++ b/tests/models/test_unet.py @@ -0,0 +1,70 @@ +import pytest +import torch + +from careamics.models.layers import MaxBlurPool +from careamics.models.unet import UNet + + +@pytest.mark.parametrize("depth", [1, 3, 5]) +def test_unet_depth(depth): + """Test that the UNet has the correct number of down and up convs + with respect to the depth.""" + model = UNet(conv_dims=2, depth=depth) + + # check that encoder has the right number of down convs + counter_down = 0 + for _, layer in model.encoder.encoder_blocks.named_children(): + if type(layer).__name__ == "Conv_Block": + counter_down += 1 + + assert counter_down == depth + + # check that decoder has the right number of up convs + counter_up = 0 + for _, layer in model.decoder.decoder_blocks.named_children(): + if type(layer).__name__ == "Conv_Block": + counter_up += 1 + + assert counter_up == depth + + +@pytest.mark.parametrize( + "input_shape", + [ + (1, 1, 1024, 1024), + (1, 1, 512, 512), + (1, 1, 256, 256), + (1, 1, 128, 128), + (1, 1, 64, 64), + (1, 1, 32, 32), + (1, 1, 16, 16), + (1, 1, 8, 8), + ], +) +def test_blurpool2d(input_shape): + """Test that the BlurPool2d layer works as expected.""" + layer = MaxBlurPool(dim=2, kernel_size=3) + assert layer(torch.randn(input_shape)).shape == tuple( + [1, 1] + [i // 2 for i in input_shape[2:]] + ) + + +@pytest.mark.parametrize( + "input_shape", + [ + (1, 1, 256, 256, 256), + (1, 1, 128, 128, 128), + (1, 1, 64, 128, 128), + (1, 1, 64, 64, 64), + (1, 1, 32, 64, 64), + (1, 1, 32, 32, 32), + (1, 1, 16, 16, 16), + (1, 1, 8, 8, 8), + ], +) +def test_blurpool3d(input_shape): + """Test that the BlurPool3d layer works as 
expected.""" + layer = MaxBlurPool(dim=3, kernel_size=3) + assert layer(torch.randn(input_shape)).shape == tuple( + [1, 1] + [i // 2 for i in input_shape[2:]] + ) diff --git a/tests/prediction/test_stitch_prediction.py b/tests/prediction/test_stitch_prediction.py new file mode 100644 index 00000000..4908af23 --- /dev/null +++ b/tests/prediction/test_stitch_prediction.py @@ -0,0 +1,41 @@ +import pytest +from torch import from_numpy, tensor + +from careamics.dataset.patching.tiled_patching import extract_tiles +from careamics.prediction.stitch_prediction import stitch_prediction + + +@pytest.mark.parametrize( + "input_shape, tile_size, overlaps", + [ + ((1, 1, 8, 8), (4, 4), (2, 2)), + ((1, 1, 8, 8), (4, 4), (2, 2)), + ((1, 1, 7, 9), (4, 4), (2, 2)), + ((1, 1, 9, 7, 8), (4, 4, 4), (2, 2, 2)), + ((1, 1, 321, 481), (256, 256), (48, 48)), + ], +) +def test_stitch_prediction(ordered_array, input_shape, tile_size, overlaps): + """Test calculating stitching coordinates.""" + arr = ordered_array(input_shape, dtype=int) + tiles = [] + stitching_data = [] + + # extract tiles + tile_generator = extract_tiles(arr, tile_size, overlaps) + + # Assemble all tiles as it is done during the prediction stage + for tile_data, tile_info in tile_generator: + tiles.append(from_numpy(tile_data)) # need to convert to torch.Tensor + stitching_data.append( + ( # this is way too wacky + [tensor(i) for i in input_shape], # need to convert to torch.Tensor + [[tensor([j]) for j in i] for i in tile_info.overlap_crop_coords], + [[tensor([j]) for j in i] for i in tile_info.stitch_coords], + ) + ) + + # compute stitching coordinates, it returns a torch.Tensor + result = stitch_prediction(tiles, stitching_data) + + assert (result.numpy() == arr).all() diff --git a/tests/smoke_test.py b/tests/smoke_test.py deleted file mode 100644 index d4e01c45..00000000 --- a/tests/smoke_test.py +++ /dev/null @@ -1,134 +0,0 @@ -import tempfile -from pathlib import Path -from typing import Tuple - -import numpy as np 
-import pytest -import tifffile - -from careamics.config import Configuration -from careamics.config.algorithm import Algorithm -from careamics.config.data import Data -from careamics.config.training import LrScheduler, Optimizer, Training -from careamics.engine import Engine - - -@pytest.fixture -def temp_dir() -> Path: - with tempfile.TemporaryDirectory() as temp_dir: - yield Path(temp_dir) - - -@pytest.fixture -def example_data_path(temp_dir: Path) -> Tuple[Path, Path]: - def _example_data_path(image_size: Tuple[int, int]): - test_image = np.random.rand(*image_size) - test_image_predict = test_image[None, None, ...] - - train_path = temp_dir / "train" - val_path = temp_dir / "val" - test_path = temp_dir / "test" - train_path.mkdir() - val_path.mkdir() - test_path.mkdir() - - tifffile.imwrite(train_path / "train_image.tif", test_image) - tifffile.imwrite(val_path / "val_image.tif", test_image) - tifffile.imwrite(test_path / "test_image.tif", test_image_predict) - - return train_path, val_path, test_path - - return _example_data_path - - -@pytest.fixture -def base_configuration(temp_dir: Path, patch_size: Tuple[int, int]) -> Configuration: - def _base_configuration(axes: str) -> Configuration: - is_3d = "Z" in axes - configuration = Configuration( - experiment_name="smoke_test", - working_directory=temp_dir, - algorithm=Algorithm(loss="n2v", model="UNet", is_3D=str(is_3d)), - data=Data( - in_memory=True, - data_format="tif", - axes=axes, - ), - training=Training( - num_epochs=1, - patch_size=patch_size, - batch_size=1, - optimizer=Optimizer(name="Adam"), - lr_scheduler=LrScheduler(name="ReduceLROnPlateau"), - extraction_strategy="random", - augmentation=True, - num_workers=0, - use_wandb=False, - ), - ) - return configuration - - return _base_configuration - - -@pytest.mark.parametrize( - "image_size, axes, patch_size, overlaps", - [ - ((64, 64), "YX", (32, 32), (8, 8)), - ((2, 64, 64), "SYX", (32, 32), (8, 8)), - # ((16, 64, 64), "ZYX", (8, 32, 32), (2, 8, 8)), - 
# ((2, 16, 64, 64), "SZYX", (8, 32, 32), (2, 8, 8)), - ], -) -def test_is_engine_runnable( - base_configuration: Configuration, - example_data_path: Tuple[Path, Path], - image_size: Tuple[int, int], - axes: str, - patch_size: Tuple[int, int], - overlaps: Tuple[int, int], -): - """ - Test if basic workflow does not fail - train model and then predict - """ - train_path, val_path, test_path = example_data_path(image_size) - configuration = base_configuration(axes) - engine = Engine(config=configuration) - _ = engine.train(train_path, val_path) - - model_name = f"{engine.cfg.experiment_name}_best.pth" - result_model_path = engine.cfg.working_directory / model_name - - assert result_model_path.exists() - - # Test prediction with external input - test_image = np.random.rand(*image_size) - test_result = engine.predict(input=test_image) - - assert test_result is not None - - # Test prediction with pred_path without tiling - test_result = engine.predict(input=test_path) - - assert test_result is not None - - # save as bioimage - zip_path = Path(configuration.working_directory) / "model.bioimage.io.zip" - engine.save_as_bioimage(zip_path) - assert zip_path.exists() - - # Create engine from checkpoint - second_engine = Engine(model_path=result_model_path) - second_engine.cfg.data.in_memory = False - _ = second_engine.train(train_path, val_path) - - # Create engine from bioimage model - third_engine = Engine(model_path=zip_path) - third_engine.cfg.data.in_memory = False - _ = third_engine.train(train_path, val_path) - - # Test prediction with pred_path with tiling - test_result = third_engine.predict( - input=test_path, tile_shape=patch_size, overlaps=overlaps - ) - assert test_result is not None diff --git a/tests/test_augment.py b/tests/test_augment.py deleted file mode 100644 index a384295d..00000000 --- a/tests/test_augment.py +++ /dev/null @@ -1,46 +0,0 @@ -import numpy as np -import pytest - -from careamics.utils.augment import _flip_and_rotate - -ARRAY_2D = 
np.array([[1, 2, 3], [4, 5, 6]]) -AUG_ARRAY_2D = [ - ARRAY_2D, - np.array([[3, 2, 1], [6, 5, 4]]), # no rot + flip - np.array([[3, 6], [2, 5], [1, 4]]), # rot90 - np.array([[6, 3], [5, 2], [4, 1]]), # rot90 + flip - np.array([[6, 5, 4], [3, 2, 1]]), # rot180 - np.array([[4, 5, 6], [1, 2, 3]]), # rot180 + flip - np.array([[4, 1], [5, 2], [6, 3]]), # rot270 - np.array([[1, 4], [2, 5], [3, 6]]), # rot270 + flip -] - -ARRAY_3D = np.array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]) -AUG_ARRAY_3D = [ - ARRAY_3D, - np.array([[[3, 2, 1], [6, 5, 4]], [[9, 8, 7], [12, 11, 10]]]), # no rot + flip - np.array([[[3, 6], [2, 5], [1, 4]], [[9, 12], [8, 11], [7, 10]]]), # rot90 - np.array([[[6, 3], [5, 2], [4, 1]], [[12, 9], [11, 8], [10, 7]]]), # rot90 + flip - np.array([[[6, 5, 4], [3, 2, 1]], [[12, 11, 10], [9, 8, 7]]]), # rot180 - np.array([[[4, 5, 6], [1, 2, 3]], [[10, 11, 12], [7, 8, 9]]]), # rot180 + flip - np.array([[[4, 1], [5, 2], [6, 3]], [[10, 7], [11, 8], [12, 9]]]), # rot270 - np.array([[[1, 4], [2, 5], [3, 6]], [[7, 10], [8, 11], [9, 12]]]), # rot270 + flip -] - - -@pytest.mark.parametrize( - "array, possible_augmentations", - [ - (ARRAY_2D, AUG_ARRAY_2D), - (ARRAY_3D, AUG_ARRAY_3D), - ], -) -def test_flip_and_rotate(array, possible_augmentations): - """Test augmenting a single array with rotation and flips""" - for i_rot90 in range(4): - for j_flip in range(2): - aug_array = _flip_and_rotate(array, i_rot90, j_flip) - - assert np.array_equal( - aug_array, possible_augmentations[i_rot90 * 2 + j_flip] - ) diff --git a/tests/test_careamist.py b/tests/test_careamist.py new file mode 100644 index 00000000..f774aa37 --- /dev/null +++ b/tests/test_careamist.py @@ -0,0 +1,693 @@ +from pathlib import Path + +import numpy as np +import pytest +import tifffile + +from careamics import CAREamist, Configuration, save_configuration +from careamics.config.support import SupportedAlgorithm, SupportedData + +# TODO test 3D and channels + + +def test_no_parameters(): + 
"""Test that CAREamics cannot be instantiated without parameters.""" + with pytest.raises(TypeError): + CAREamist() + + +def test_minimum_configuration_via_object(tmp_path: Path, minimum_configuration: dict): + """Test that CAREamics can be instantiated with a minimum configuration object.""" + # create configuration + config = Configuration(**minimum_configuration) + + # instantiate CAREamist + CAREamist(source=config, work_dir=tmp_path) + + +def test_minimum_configuration_via_path(tmp_path: Path, minimum_configuration: dict): + """Test that CAREamics can be instantiated with a path to a minimum + configuration. + """ + # create configuration + config = Configuration(**minimum_configuration) + path_to_config = save_configuration(config, tmp_path) + + # instantiate CAREamist + CAREamist(source=path_to_config) + + +def test_train_error_target_unsupervised_algorithm( + tmp_path: Path, minimum_configuration: dict +): + """Test that an error is raised when a target is provided for N2V.""" + # create configuration + config = Configuration(**minimum_configuration) + config.algorithm_config.algorithm = SupportedAlgorithm.N2V.value + + # train error with Paths + config.data_config.data_type = SupportedData.TIFF.value + careamics = CAREamist(source=config, work_dir=tmp_path) + with pytest.raises(ValueError): + careamics.train( + train_source=tmp_path, + train_target=tmp_path, + ) + + # train error with strings + with pytest.raises(ValueError): + careamics.train( + train_source=str(tmp_path), + train_target=str(tmp_path), + ) + + # train error with arrays + config.data_config.data_type = SupportedData.ARRAY.value + careamics = CAREamist(source=config, work_dir=tmp_path) + with pytest.raises(ValueError): + careamics.train( + train_source=np.ones((32, 32)), + train_target=np.ones((32, 32)), + ) + + +def test_train_single_array_no_val(tmp_path: Path, minimum_configuration: dict): + """Test that CAREamics can be trained with arrays.""" + # training data + train_array = 
np.random.rand(32, 32) + + # create configuration + config = Configuration(**minimum_configuration) + config.training_config.num_epochs = 1 + config.data_config.axes = "YX" + config.data_config.batch_size = 2 + config.data_config.data_type = SupportedData.ARRAY.value + config.data_config.patch_size = (8, 8) + + # instantiate CAREamist + careamist = CAREamist(source=config, work_dir=tmp_path) + + # train CAREamist + careamist.train(train_source=train_array) + + # check that it trained + assert Path(tmp_path / "checkpoints" / "last.ckpt").exists() + + # export to BMZ + careamist.export_to_bmz( + path=tmp_path / "model.zip", + name="TopModel", + general_description="A model that just walked in.", + authors=[{"name": "Amod", "affiliation": "El"}], + ) + assert (tmp_path / "model.zip").exists() + + +def test_train_array(tmp_path: Path, minimum_configuration: dict): + """Test that CAREamics can be trained on arrays.""" + # training data + train_array = np.random.rand(32, 32) + val_array = np.random.rand(32, 32) + + # create configuration + config = Configuration(**minimum_configuration) + config.training_config.num_epochs = 1 + config.data_config.axes = "YX" + config.data_config.batch_size = 2 + config.data_config.data_type = SupportedData.ARRAY.value + config.data_config.patch_size = (8, 8) + + # instantiate CAREamist + careamist = CAREamist(source=config, work_dir=tmp_path) + + # train CAREamist + careamist.train(train_source=train_array, val_source=val_array) + + # check that it trained + assert Path(tmp_path / "checkpoints" / "last.ckpt").exists() + + # export to BMZ + careamist.export_to_bmz( + path=tmp_path / "model.zip", + name="TopModel", + general_description="A model that just walked in.", + authors=[{"name": "Amod", "affiliation": "El"}], + ) + assert (tmp_path / "model.zip").exists() + + +def test_train_array_channel(tmp_path: Path, minimum_configuration: dict): + """Test that CAREamics can be trained on arrays with channels.""" + # training data + 
train_array = np.random.rand(32, 32, 3) + val_array = np.random.rand(32, 32, 3) + + # create configuration + config = Configuration(**minimum_configuration) + config.training_config.num_epochs = 1 + config.data_config.axes = "YXC" + config.algorithm_config.model.in_channels = 3 + config.algorithm_config.model.num_classes = 3 + config.data_config.batch_size = 2 + config.data_config.data_type = SupportedData.ARRAY.value + config.data_config.patch_size = (8, 8) + + # instantiate CAREamist + careamist = CAREamist(source=config, work_dir=tmp_path) + + # train CAREamist + careamist.train(train_source=train_array, val_source=val_array) + + # check that it trained + assert Path(tmp_path / "checkpoints" / "last.ckpt").exists() + + # export to BMZ + careamist.export_to_bmz( + path=tmp_path / "model.zip", + name="TopModel", + general_description="A model that just walked in.", + authors=[{"name": "Amod", "affiliation": "El"}], + channel_names=["red", "green", "blue"], + ) + assert (tmp_path / "model.zip").exists() + + +def test_train_array_3d(tmp_path: Path, minimum_configuration: dict): + """Test that CAREamics can be trained on 3D arrays.""" + # training data + train_array = np.random.rand(8, 32, 32) + val_array = np.random.rand(8, 32, 32) + + # create configuration + minimum_configuration["data_config"]["axes"] = "ZYX" + minimum_configuration["data_config"]["patch_size"] = (8, 16, 16) + config = Configuration(**minimum_configuration) + config.training_config.num_epochs = 1 + config.data_config.batch_size = 2 + config.data_config.data_type = SupportedData.ARRAY.value + + # instantiate CAREamist + careamist = CAREamist(source=config, work_dir=tmp_path) + + # train CAREamist + careamist.train(train_source=train_array, val_source=val_array) + + # check that it trained + assert Path(tmp_path / "checkpoints" / "last.ckpt").exists() + + # export to BMZ + careamist.export_to_bmz( + path=tmp_path / "model.zip", + name="TopModel", + general_description="A model that just walked 
in.", + authors=[{"name": "Amod", "affiliation": "El"}], + ) + assert (tmp_path / "model.zip").exists() + + +def test_train_tiff_files_in_memory_no_val(tmp_path: Path, minimum_configuration: dict): + """Test that CAREamics can be trained with tiff files in memory.""" + # training data + train_array = np.random.rand(32, 32) + + # save files + train_file = tmp_path / "train.tiff" + tifffile.imwrite(train_file, train_array) + + # create configuration + config = Configuration(**minimum_configuration) + config.training_config.num_epochs = 1 + config.data_config.axes = "YX" + config.data_config.batch_size = 2 + config.data_config.data_type = SupportedData.TIFF.value + config.data_config.patch_size = (8, 8) + + # instantiate CAREamist + careamist = CAREamist(source=config, work_dir=tmp_path) + + # train CAREamist + careamist.train(train_source=train_file) + + # check that it trained + assert Path(tmp_path / "checkpoints" / "last.ckpt").exists() + + # export to BMZ + careamist.export_to_bmz( + path=tmp_path / "model.zip", + name="TopModel", + general_description="A model that just walked in.", + authors=[{"name": "Amod", "affiliation": "El"}], + ) + assert (tmp_path / "model.zip").exists() + + +def test_train_tiff_files_in_memory(tmp_path: Path, minimum_configuration: dict): + """Test that CAREamics can be trained with tiff files in memory.""" + # training data + train_array = np.random.rand(32, 32) + val_array = np.random.rand(32, 32) + + # save files + train_file = tmp_path / "train.tiff" + tifffile.imwrite(train_file, train_array) + + val_file = tmp_path / "val.tiff" + tifffile.imwrite(val_file, val_array) + + # create configuration + config = Configuration(**minimum_configuration) + config.training_config.num_epochs = 1 + config.data_config.axes = "YX" + config.data_config.batch_size = 2 + config.data_config.data_type = SupportedData.TIFF.value + config.data_config.patch_size = (8, 8) + + # instantiate CAREamist + careamist = CAREamist(source=config, work_dir=tmp_path) 
+ + # train CAREamist + careamist.train(train_source=train_file, val_source=val_file) + + # check that it trained + assert Path(tmp_path / "checkpoints" / "last.ckpt").exists() + + # export to BMZ + careamist.export_to_bmz( + path=tmp_path / "model.zip", + name="TopModel", + general_description="A model that just walked in.", + authors=[{"name": "Amod", "affiliation": "El"}], + ) + assert (tmp_path / "model.zip").exists() + + +def test_train_tiff_files(tmp_path: Path, minimum_configuration: dict): + """Test that CAREamics can be trained with tiff files by deactivating + the in memory dataset. + """ + # training data + train_array = np.random.rand(32, 32) + val_array = np.random.rand(32, 32) + + # save files + train_file = tmp_path / "train.tiff" + tifffile.imwrite(train_file, train_array) + + val_file = tmp_path / "val.tiff" + tifffile.imwrite(val_file, val_array) + + # create configuration + config = Configuration(**minimum_configuration) + config.training_config.num_epochs = 1 + config.data_config.axes = "YX" + config.data_config.batch_size = 2 + config.data_config.data_type = SupportedData.TIFF.value + config.data_config.patch_size = (8, 8) + + # instantiate CAREamist + careamist = CAREamist(source=config, work_dir=tmp_path) + + # train CAREamist + careamist.train(train_source=train_file, val_source=val_file, use_in_memory=False) + + # check that it trained + assert Path(tmp_path / "checkpoints" / "last.ckpt").exists() + + # export to BMZ + careamist.export_to_bmz( + path=tmp_path / "model.zip", + name="TopModel", + general_description="A model that just walked in.", + authors=[{"name": "Amod", "affiliation": "El"}], + ) + assert (tmp_path / "model.zip").exists() + + +def test_train_array_supervised(tmp_path: Path, supervised_configuration: dict): + """Test that CAREamics can be trained with arrays.""" + # training data + train_array = np.random.rand(32, 32) + val_array = np.random.rand(32, 32) + train_target = np.random.rand(32, 32) + val_target = 
np.random.rand(32, 32) + + # create configuration + config = Configuration(**supervised_configuration) + config.training_config.num_epochs = 1 + config.data_config.axes = "YX" + config.data_config.batch_size = 2 + config.data_config.data_type = SupportedData.ARRAY.value + config.data_config.patch_size = (8, 8) + + # instantiate CAREamist + careamist = CAREamist(source=config, work_dir=tmp_path) + + # train CAREamist + careamist.train( + train_source=train_array, + val_source=val_array, + train_target=train_target, + val_target=val_target, + ) + + # check that it trained + assert Path(tmp_path / "checkpoints" / "last.ckpt").exists() + + # export to BMZ + careamist.export_to_bmz( + path=tmp_path / "model.zip", + name="TopModel", + general_description="A model that just walked in.", + authors=[{"name": "Amod", "affiliation": "El"}], + ) + assert (tmp_path / "model.zip").exists() + + +def test_train_tiff_files_in_memory_supervised( + tmp_path: Path, supervised_configuration: dict +): + """Test that CAREamics can be trained with tiff files in memory.""" + # training data + train_array = np.random.rand(32, 32) + val_array = np.random.rand(32, 32) + train_target = np.random.rand(32, 32) + val_target = np.random.rand(32, 32) + + # save files + images = tmp_path / "images" + images.mkdir() + train_file = images / "train.tiff" + tifffile.imwrite(train_file, train_array) + + val_file = tmp_path / "images" / "val.tiff" + tifffile.imwrite(val_file, val_array) + + targets = tmp_path / "targets" + targets.mkdir() + train_target_file = targets / "train.tiff" + tifffile.imwrite(train_target_file, train_target) + + val_target_file = targets / "val.tiff" + tifffile.imwrite(val_target_file, val_target) + + # create configuration + config = Configuration(**supervised_configuration) + config.training_config.num_epochs = 1 + config.data_config.axes = "YX" + config.data_config.batch_size = 2 + config.data_config.data_type = SupportedData.TIFF.value + config.data_config.patch_size = (8, 8) 
+ + # instantiate CAREamist + careamist = CAREamist(source=config, work_dir=tmp_path) + + # train CAREamist + careamist.train( + train_source=train_file, + val_source=val_file, + train_target=train_target_file, + val_target=val_target_file, + ) + + # check that it trained + assert Path(tmp_path / "checkpoints" / "last.ckpt").exists() + + # export to BMZ + careamist.export_to_bmz( + path=tmp_path / "model.zip", + name="TopModel", + general_description="A model that just walked in.", + authors=[{"name": "Amod", "affiliation": "El"}], + ) + assert (tmp_path / "model.zip").exists() + + +def test_train_tiff_files_supervised(tmp_path: Path, supervised_configuration: dict): + """Test that CAREamics can be trained with tiff files by deactivating + the in memory dataset. + """ + # training data + train_array = np.random.rand(32, 32) + val_array = np.random.rand(32, 32) + train_target = np.random.rand(32, 32) + val_target = np.random.rand(32, 32) + + # save files + images = tmp_path / "images" + images.mkdir() + train_file = images / "train.tiff" + tifffile.imwrite(train_file, train_array) + + val_file = tmp_path / "images" / "val.tiff" + tifffile.imwrite(val_file, val_array) + + targets = tmp_path / "targets" + targets.mkdir() + train_target_file = targets / "train.tiff" + tifffile.imwrite(train_target_file, train_target) + + val_target_file = targets / "val.tiff" + tifffile.imwrite(val_target_file, val_target) + + # create configuration + config = Configuration(**supervised_configuration) + config.training_config.num_epochs = 1 + config.data_config.axes = "YX" + config.data_config.batch_size = 2 + config.data_config.data_type = SupportedData.TIFF.value + config.data_config.patch_size = (8, 8) + + # instantiate CAREamist + careamist = CAREamist(source=config, work_dir=tmp_path) + + # train CAREamist + careamist.train( + train_source=train_file, + val_source=val_file, + train_target=train_target_file, + val_target=val_target_file, + use_in_memory=False, + ) + + # check that 
it trained + assert Path(tmp_path / "checkpoints" / "last.ckpt").exists() + + # export to BMZ + careamist.export_to_bmz( + path=tmp_path / "model.zip", + name="TopModel", + general_description="A model that just walked in.", + authors=[{"name": "Amod", "affiliation": "El"}], + ) + assert (tmp_path / "model.zip").exists() + + +@pytest.mark.parametrize("batch_size", [1, 2]) +def test_predict_on_array_tiled( + tmp_path: Path, minimum_configuration: dict, batch_size +): + """Test that CAREamics can predict on arrays.""" + # training data + train_array = np.random.rand(32, 32) + + # create configuration + config = Configuration(**minimum_configuration) + config.training_config.num_epochs = 1 + config.data_config.axes = "YX" + config.data_config.batch_size = 2 + config.data_config.data_type = SupportedData.ARRAY.value + config.data_config.patch_size = (8, 8) + + # instantiate CAREamist + careamist = CAREamist(source=config, work_dir=tmp_path) + + # train CAREamist + careamist.train(train_source=train_array) + + # predict CAREamist + predicted = careamist.predict( + train_array, batch_size=batch_size, tile_size=(16, 16), tile_overlap=(4, 4) + ) + + assert predicted.squeeze().shape == train_array.shape + + # export to BMZ + careamist.export_to_bmz( + path=tmp_path / "model.zip", + name="TopModel", + general_description="A model that just walked in.", + authors=[{"name": "Amod", "affiliation": "El"}], + ) + assert (tmp_path / "model.zip").exists() + + +def test_predict_arrays_no_tiling(tmp_path: Path, minimum_configuration: dict): + """Test that CAREamics can predict on arrays without tiling.""" + # training data + train_array = np.random.rand(4, 32, 32) + + # create configuration + config = Configuration(**minimum_configuration) + config.training_config.num_epochs = 1 + config.data_config.axes = "SYX" + config.data_config.batch_size = 2 + config.data_config.data_type = SupportedData.ARRAY.value + config.data_config.patch_size = (8, 8) + + # instantiate CAREamist + 
careamist = CAREamist(source=config, work_dir=tmp_path) + + # train CAREamist + careamist.train(train_source=train_array) + + # predict CAREamist + predicted = careamist.predict(train_array) + + assert predicted.squeeze().shape == train_array.shape + + # export to BMZ + careamist.export_to_bmz( + path=tmp_path / "model.zip", + name="TopModel", + general_description="A model that just walked in.", + authors=[{"name": "Amod", "affiliation": "El"}], + ) + assert (tmp_path / "model.zip").exists() + + +@pytest.mark.parametrize("batch_size", [1, 2]) +def test_predict_path(tmp_path: Path, minimum_configuration: dict, batch_size): + """Test that CAREamics can predict with tiff files.""" + # training data + train_array = np.random.rand(32, 32) + + # save files + train_file = tmp_path / "train.tiff" + tifffile.imwrite(train_file, train_array) + + # create configuration + config = Configuration(**minimum_configuration) + config.training_config.num_epochs = 1 + config.data_config.axes = "YX" + config.data_config.batch_size = 2 + config.data_config.data_type = SupportedData.TIFF.value + config.data_config.patch_size = (8, 8) + + # instantiate CAREamist + careamist = CAREamist(source=config, work_dir=tmp_path) + + # train CAREamist + careamist.train(train_source=train_file) + + # predict CAREamist + predicted = careamist.predict(train_file, batch_size=batch_size) + + # check that it predicted + assert predicted.squeeze().shape == train_array.shape + + # export to BMZ + careamist.export_to_bmz( + path=tmp_path / "model.zip", + name="TopModel", + general_description="A model that just walked in.", + authors=[{"name": "Amod", "affiliation": "El"}], + ) + assert (tmp_path / "model.zip").exists() + + +def test_predict_pretrained_checkpoint(tmp_path: Path, pre_trained: Path): + """Test that CAREamics can be instantiated with a pre-trained network and predict + on an array.""" + # prediction data + source_array = np.random.rand(32, 32) + + # instantiate CAREamist + careamist = 
CAREamist(source=pre_trained, work_dir=tmp_path) + assert careamist.cfg.data_config.mean is not None + assert careamist.cfg.data_config.std is not None + + # predict + predicted = careamist.predict(source_array) + + # check that it predicted + assert predicted.squeeze().shape == source_array.shape + + +def test_predict_pretrained_bmz(tmp_path: Path, pre_trained_bmz: Path): + """Test that CAREamics can be instantiated with a BMZ archive and predict.""" + # prediction data + source_array = np.random.rand(32, 32) + + # instantiate CAREamist + careamist = CAREamist(source=pre_trained_bmz, work_dir=tmp_path) + + # predict + predicted = careamist.predict(source_array) + + # check that it predicted + assert predicted.squeeze().shape == source_array.shape + + +def test_export_bmz_pretrained_prediction(tmp_path: Path, pre_trained: Path): + """Test that CAREamics can be instantiated with a pre-trained network and exported + to BMZ after prediction. + + In this case, the careamist extracts the BMZ test data from the prediction + datamodule. + """ + # instantiate CAREamist + careamist = CAREamist(source=pre_trained, work_dir=tmp_path) + + # prediction data + source_array = np.random.rand(32, 32) + _ = careamist.predict(source_array) + assert len(careamist.pred_datamodule.predict_dataloader()) > 0 + + # export to BMZ (random array created) + careamist.export_to_bmz( + path=tmp_path / "model.zip", + name="TopModel", + general_description="A model that just walked in.", + authors=[{"name": "Amod", "affiliation": "El"}], + ) + assert (tmp_path / "model.zip").exists() + + +def test_export_bmz_pretrained_random_array(tmp_path: Path, pre_trained: Path): + """Test that CAREamics can be instantiated with a pre-trained network and exported + to BMZ. + + In this case, the careamist creates a random array for the BMZ archive test. 
+ """ + # instantiate CAREamist + careamist = CAREamist(source=pre_trained, work_dir=tmp_path) + + # export to BMZ (random array created) + careamist.export_to_bmz( + path=tmp_path / "model.zip", + name="TopModel", + general_description="A model that just walked in.", + authors=[{"name": "Amod", "affiliation": "El"}], + ) + assert (tmp_path / "model.zip").exists() + + +def test_export_bmz_pretrained_with_array(tmp_path: Path, pre_trained: Path): + """Test that CAREamics can be instantiated with a pre-trained network and exported + to BMZ. + + In this case, we provide an array to the BMZ archive test. + """ + # instantiate CAREamist + careamist = CAREamist(source=pre_trained, work_dir=tmp_path) + + # alternatively we can pass an array + array = np.random.rand(32, 32).astype(np.float32) + careamist.export_to_bmz( + path=tmp_path / "model2.zip", + name="TopModel", + input_array=array[np.newaxis, np.newaxis, ...], + general_description="A model that just walked in.", + authors=[{"name": "Amod", "affiliation": "El"}], + ) + assert (tmp_path / "model2.zip").exists() diff --git a/tests/test_conftest.py b/tests/test_conftest.py index 8ecec6cf..901aee32 100644 --- a/tests/test_conftest.py +++ b/tests/test_conftest.py @@ -1,35 +1,30 @@ -import copy +from careamics import Configuration +from careamics.config.algorithm_model import AlgorithmModel +from careamics.config.data_model import DataModel +from careamics.config.inference_model import InferenceModel +from careamics.config.training_model import TrainingModel -import pytest -from careamics import Configuration +def test_minimum_algorithm(minimum_algorithm_n2v): + # create algorithm configuration + AlgorithmModel(**minimum_algorithm_n2v) -def _instantiate_without_key(config: dict, key: str): - if isinstance(config[key], dict): - for k in config[key]: - _instantiate_without_key(config[key], k) - else: - # copy the dict - new_config = copy.deepcopy(config) +def test_minimum_data(minimum_data): + # create data configuration + 
DataModel(**minimum_data) - # remove the key - new_config[key] = None - # instantiate configuration - with pytest.raises(ValueError): - Configuration(**new_config) +def test_minimum_prediction(minimum_inference): + # create prediction configuration + InferenceModel(**minimum_inference) -def test_minimum_config(minimum_config): - """ - Test that the minimum config is indeed a minimal example. +def test_minimum_training(minimum_training): + # create training configuration + TrainingModel(**minimum_training) - First we check that we can instantiate a Configuration, then we test each - key in the dictionary by removing it and checking that it raises an error. - """ - # test if the configuration is valid - Configuration(**minimum_config) - for key in minimum_config: - _instantiate_without_key(minimum_config, key) +def test_minimum_configuration(minimum_configuration): + # create configuration + Configuration(**minimum_configuration) diff --git a/tests/test_engine.py b/tests/test_engine.py deleted file mode 100644 index 8c330ff2..00000000 --- a/tests/test_engine.py +++ /dev/null @@ -1,64 +0,0 @@ -import pytest - -from careamics.config import Configuration -from careamics.engine import Engine -from careamics.models import create_model - - -def test_engine_init_errors(): - with pytest.raises(ValueError): - Engine(config=None, config_path=None, model_path=None) - - with pytest.raises(TypeError): - Engine(config="config", config_path=None, model_path=None) - - with pytest.raises(FileNotFoundError): - Engine(config=None, config_path="some/path", model_path=None) - - with pytest.raises(FileNotFoundError): - Engine(config=None, config_path=None, model_path="some/other/path") - - -def test_engine_predict_errors(minimum_config: dict): - config = Configuration(**minimum_config) - engine = Engine(config=config) - - with pytest.raises(ValueError): - engine.predict(input=None) - - config.data.mean = None - config.data.std = None - with pytest.raises(ValueError): - 
engine.predict(input="some/path") - - -@pytest.mark.parametrize( - "epoch, losses", [(0, [1.0]), (1, [1.0, 0.5]), (2, [1.0, 0.5, 1.0])] -) -def test_engine_io_checkpoint(epoch, losses, minimum_config: dict): - init_config = Configuration(**minimum_config) - engine = Engine(config=init_config) - - # Mock engine attributes to test save_checkpoint - engine.optimizer.param_groups[0]["lr"] = 1 - engine.lr_scheduler.patience = 1 - path = engine._save_checkpoint(epoch=epoch, losses=losses, save_method="state_dict") - assert path.exists() - - if epoch == 0: - assert path.stem.split("_")[-1] == "best" - - if losses[-1] == min(losses): - assert path.stem.split("_")[-1] == "best" - else: - assert path.stem.split("_")[-1] == "latest" - - model, optimizer, scheduler, scaler, config = create_model(model_path=path) - assert all(model.children()) == all(engine.model.children()) - assert optimizer.__class__ == engine.optimizer.__class__ - assert scheduler.__class__ == engine.lr_scheduler.__class__ - assert scaler.__class__ == engine.scaler.__class__ - assert optimizer.param_groups[0]["lr"] == engine.optimizer.param_groups[0]["lr"] - assert optimizer.defaults["lr"] != engine.optimizer.param_groups[0]["lr"] - assert scheduler.patience == engine.lr_scheduler.patience - assert config == init_config diff --git a/tests/test_lightning_datamodule.py b/tests/test_lightning_datamodule.py new file mode 100644 index 00000000..33444971 --- /dev/null +++ b/tests/test_lightning_datamodule.py @@ -0,0 +1,140 @@ +import pytest + +from careamics import CAREamicsPredictDataModule, CAREamicsTrainDataModule +from careamics.config.support import SupportedPixelManipulation, SupportedStructAxis + + +@pytest.fixture +def simple_array(ordered_array): + return ordered_array((10, 10)) + + +def test_lightning_train_datamodule_wrong_type(simple_array): + """Test that an error is raised if the data type is not supported.""" + with pytest.raises(ValueError): + CAREamicsTrainDataModule( + train_data=simple_array, + 
data_type="wrong_type", + patch_size=(10, 10), + axes="YX", + batch_size=2, + ) + + +def test_lightning_train_datamodule_array(simple_array): + """Test that the data module is created correctly with an array.""" + # create data module + data_module = CAREamicsTrainDataModule( + train_data=simple_array, + data_type="array", + patch_size=(8, 8), + axes="YX", + batch_size=2, + val_minimum_patches=2, + ) + data_module.prepare_data() + data_module.setup() + + assert len(list(data_module.train_dataloader())) > 0 + + +def test_lightning_train_datamodule_supervised_n2v_throws_error(simple_array): + """Test that an error is raised if target data is passed but the transformations + (default ones) contain N2V manipulate.""" + with pytest.raises(ValueError): + CAREamicsTrainDataModule( + train_data=simple_array, + data_type="array", + patch_size=(10, 10), + axes="YX", + batch_size=2, + train_target_data=simple_array, + val_minimum_patches=2, + ) + + +@pytest.mark.parametrize( + "use_n2v2, strategy", + [ + (True, SupportedPixelManipulation.MEDIAN), + (False, SupportedPixelManipulation.UNIFORM), + ], +) +def test_lightning_train_datamodule_n2v2(simple_array, use_n2v2, strategy): + """Test that n2v2 parameter is correctly passed.""" + data_module = CAREamicsTrainDataModule( + train_data=simple_array, + data_type="array", + patch_size=(16, 16), + axes="YX", + batch_size=2, + use_n2v2=use_n2v2, + ) + assert data_module.data_config.transforms[-1].strategy == strategy + + +def test_lightning_train_datamodule_structn2v(simple_array): + """Test that structn2v parameter is correctly passed.""" + struct_axis = SupportedStructAxis.HORIZONTAL.value + struct_span = 11 + + data_module = CAREamicsTrainDataModule( + train_data=simple_array, + data_type="array", + patch_size=(16, 16), + axes="YX", + batch_size=2, + struct_n2v_axis=struct_axis, + struct_n2v_span=struct_span, + ) + assert data_module.data_config.transforms[-1].struct_mask_axis == struct_axis + assert 
data_module.data_config.transforms[-1].struct_mask_span == struct_span + + +def test_lightning_predict_datamodule_wrong_type(simple_array): + """Test that an error is raised if the data type is not supported.""" + with pytest.raises(ValueError): + CAREamicsPredictDataModule( + pred_data=simple_array, + data_type="wrong_type", + mean=0.5, + std=0.1, + axes="YX", + batch_size=2, + ) + + +def test_lightning_pred_datamodule_tiling(simple_array): + """Test that the data module is created correctly with an array.""" + # create data module + data_module = CAREamicsPredictDataModule( + pred_data=simple_array, + data_type="array", + mean=0.5, + std=0.1, + axes="YX", + batch_size=2, + tile_overlap=[2, 2], + tile_size=[8, 8], + ) + + data_module.prepare_data() + data_module.setup() + assert len(list(data_module.predict_dataloader())) == 2 + + +def test_lightning_pred_datamodule_no_tiling(simple_array): + """Test that the data module is created correctly with an array.""" + # create data module + data_module = CAREamicsPredictDataModule( + pred_data=simple_array, + data_type="array", + mean=0.5, + std=0.1, + axes="YX", + batch_size=2, + ) + + data_module.prepare_data() + data_module.setup() + assert len(list(data_module.predict_dataloader())) == 1 diff --git a/tests/test_lightning_module.py b/tests/test_lightning_module.py new file mode 100644 index 00000000..c4e86ade --- /dev/null +++ b/tests/test_lightning_module.py @@ -0,0 +1,267 @@ +import pytest +import torch + +from careamics.config import AlgorithmModel +from careamics.lightning_module import CAREamicsKiln, CAREamicsModule + + +def test_careamics_module(minimum_algorithm_n2v): + """Test that the minimum algorithm allows instantiating a the Lightning API + intermediate layer.""" + algo_config = AlgorithmModel(**minimum_algorithm_n2v) + + # extract model parameters + model_parameters = algo_config.model.model_dump(exclude_none=True) + + # instantiate CAREamicsModule + CAREamicsModule( + algorithm=algo_config.algorithm, + 
loss=algo_config.loss, + architecture=algo_config.model.architecture, + model_parameters=model_parameters, + optimizer=algo_config.optimizer.name, + optimizer_parameters=algo_config.optimizer.parameters, + lr_scheduler=algo_config.lr_scheduler.name, + lr_scheduler_parameters=algo_config.lr_scheduler.parameters, + ) + + +def test_careamics_kiln(minimum_algorithm_n2v): + """Test that the minimum algorithm allows instantiating a CAREamicsKiln.""" + algo_config = AlgorithmModel(**minimum_algorithm_n2v) + + # instantiate CAREamicsKiln + CAREamicsKiln(algo_config) + + +@pytest.mark.parametrize( + "shape", + [ + (8, 8), + (16, 16), + (32, 32), + ], +) +def test_careamics_kiln_unet_2D_depth_2_shape(shape): + algo_dict = { + "algorithm": "n2n", + "model": { + "architecture": "UNet", + "conv_dims": 2, + "in_channels": 1, + "num_classes": 1, + "depth": 2, + }, + "loss": "mae", + } + algo_config = AlgorithmModel(**algo_dict) + + # instantiate CAREamicsKiln + model = CAREamicsKiln(algo_config) + # set model to evaluation mode to avoid batch dimension error + model.model.eval() + # test forward pass + x = torch.rand((1, 1, *shape)) + y: torch.Tensor = model.forward(x) + assert y.shape == x.shape + + +@pytest.mark.parametrize( + "shape", + [ + (8, 8), + (16, 16), + (32, 32), + (64, 64), + (128, 128), + (256, 256), + ], +) +def test_careamics_kiln_unet_2D_depth_3_shape(shape): + algo_dict = { + "algorithm": "n2n", + "model": { + "architecture": "UNet", + "conv_dims": 2, + "in_channels": 1, + "num_classes": 1, + "depth": 3, + }, + "loss": "mae", + } + algo_config = AlgorithmModel(**algo_dict) + + # instantiate CAREamicsKiln + model = CAREamicsKiln(algo_config) + # set model to evaluation mode to avoid batch dimension error + model.model.eval() + # test forward pass + x = torch.rand((1, 1, *shape)) + y: torch.Tensor = model.forward(x) + assert y.shape == x.shape + + +@pytest.mark.parametrize( + "shape", + [ + (8, 32, 16), + (16, 32, 16), + (8, 32, 32), + (32, 64, 64), + ], +) +def 
test_careamics_kiln_unet_depth_2_3D(shape): + algo_dict = { + "algorithm": "n2n", + "model": { + "architecture": "UNet", + "conv_dims": 3, + "in_channels": 1, + "num_classes": 1, + "depth": 2, + }, + "loss": "mae", + } + algo_config = AlgorithmModel(**algo_dict) + + # instantiate CAREamicsKiln + model = CAREamicsKiln(algo_config) + # set model to evaluation mode to avoid batch dimension error + model.model.eval() + # test forward pass + x = torch.rand((1, 1, *shape)) + y: torch.Tensor = model.forward(x) + assert y.shape == x.shape + + +@pytest.mark.parametrize( + "shape", + [ + (8, 64, 64), + (16, 64, 64), + (16, 128, 128), + (32, 128, 128), + ], +) +def test_careamics_kiln_unet_depth_3_3D(shape): + algo_dict = { + "algorithm": "n2n", + "model": { + "architecture": "UNet", + "conv_dims": 3, + "in_channels": 1, + "num_classes": 1, + "depth": 3, + }, + "loss": "mae", + } + algo_config = AlgorithmModel(**algo_dict) + + # instantiate CAREamicsKiln + model = CAREamicsKiln(algo_config) + # set model to evaluation mode to avoid batch dimension error + model.model.eval() + # test forward pass + x = torch.rand((1, 1, *shape)) + y: torch.Tensor = model.forward(x) + assert y.shape == x.shape + + +@pytest.mark.parametrize("n_channels", [1, 3, 4]) +def test_careamics_kiln_unet_depth_2_channels_2D(n_channels): + algo_dict = { + "algorithm": "n2n", + "model": { + "architecture": "UNet", + "conv_dims": 2, + "in_channels": n_channels, + "num_classes": n_channels, + "depth": 2, + }, + "loss": "mae", + } + algo_config = AlgorithmModel(**algo_dict) + + # instantiate CAREamicsKiln + model = CAREamicsKiln(algo_config) + # set model to evaluation mode to avoid batch dimension error + model.model.eval() + # test forward pass + x = torch.rand((1, n_channels, 32, 32)) + y: torch.Tensor = model.forward(x) + assert y.shape == x.shape + + +@pytest.mark.parametrize("n_channels", [1, 3, 4]) +def test_careamics_kiln_unet_depth_3_channels_2D(n_channels): + algo_dict = { + "algorithm": "n2n", + 
"model": { + "architecture": "UNet", + "conv_dims": 2, + "in_channels": n_channels, + "num_classes": n_channels, + "depth": 3, + }, + "loss": "mae", + } + algo_config = AlgorithmModel(**algo_dict) + + # instantiate CAREamicsKiln + model = CAREamicsKiln(algo_config) + # set model to evaluation mode to avoid batch dimension error + model.model.eval() + # test forward pass + x = torch.rand((1, n_channels, 64, 64)) + y: torch.Tensor = model.forward(x) + assert y.shape == x.shape + + +@pytest.mark.parametrize("n_channels", [1, 3, 4]) +def test_careamics_kiln_unet_depth_2_channels_3D(n_channels): + algo_dict = { + "algorithm": "n2n", + "model": { + "architecture": "UNet", + "conv_dims": 3, + "in_channels": n_channels, + "num_classes": n_channels, + "depth": 2, + }, + "loss": "mae", + } + algo_config = AlgorithmModel(**algo_dict) + + # instantiate CAREamicsKiln + model = CAREamicsKiln(algo_config) + # set model to evaluation mode to avoid batch dimension error + model.model.eval() + # test forward pass + x = torch.rand((2, n_channels, 16, 32, 32)) + y: torch.Tensor = model.forward(x) + assert y.shape == x.shape + + +@pytest.mark.parametrize("n_channels", [1, 3, 4]) +def test_careamics_kiln_unet_depth_3_channels_3D(n_channels): + algo_dict = { + "algorithm": "n2n", + "model": { + "architecture": "UNet", + "conv_dims": 3, + "in_channels": n_channels, + "num_classes": n_channels, + "depth": 3, + }, + "loss": "mae", + } + algo_config = AlgorithmModel(**algo_dict) + + # instantiate CAREamicsKiln + model = CAREamicsKiln(algo_config) + # set model to evaluation mode to avoid batch dimension error + model.model.eval() + # test forward pass + x = torch.rand((1, n_channels, 16, 64, 64)) + y: torch.Tensor = model.forward(x) + assert y.shape == x.shape diff --git a/tests/test_prediction_utils.py b/tests/test_prediction_utils.py deleted file mode 100644 index be6e1b42..00000000 --- a/tests/test_prediction_utils.py +++ /dev/null @@ -1,86 +0,0 @@ -import pytest -from torch import 
from_numpy - -from careamics.dataset.patching import _extract_tiles -from careamics.prediction.prediction_utils import ( - stitch_prediction, - tta_backward, - tta_forward, -) - - -@pytest.mark.parametrize( - "input_shape, tile_size, overlaps", - [ - ((1, 8, 8), (4, 4), (2, 2)), - ((1, 7, 9), (4, 4), (2, 2)), - ((1, 9, 7, 8), (4, 4, 4), (2, 2, 2)), - ], -) -def test_stitch_prediction(input_shape, ordered_array, tile_size, overlaps): - """Test calculating stitching coordinates. - - Test cases include only valid inputs. - """ - arr = ordered_array(input_shape, dtype=int) - tiles = [] - stitching_data = [] - - # extract tiles - tiling_outputs = _extract_tiles(arr, tile_size, overlaps) - - # Assemble all tiles as it's done during the prediction stage - for tile_data in tiling_outputs: - tile, _, input_shape, overlap_crop_coords, stitch_coords = tile_data - - tiles.append(tile) - stitching_data.append( - ( - input_shape, - overlap_crop_coords, - stitch_coords, - ) - ) - # compute stitching coordinates - result = stitch_prediction(tiles, stitching_data) - assert (result == arr).all() - - -@pytest.mark.parametrize("shape", [(1, 1, 8, 8), (1, 1, 8, 8, 8)]) -def test_tta_forward(ordered_array, shape): - """Test TTA forward.""" - x = ordered_array(shape) - - # tta forward - x_aug = tta_forward(from_numpy(x)) - - # check output - assert len(x_aug) == 8 - for i, x_ in enumerate(x_aug): - # check correct shape - assert x_.shape == shape - - # arrays different (at least from the previous one) - if i > 0: - assert (x_ != x_aug[i - 1]).any() - - -@pytest.mark.parametrize( - "shape", - [ - (1, 1, 4, 4), - # (1, 1, 4, 4, 4) - ], -) -def test_tta_backward(ordered_array, shape): - """Test TTA backward.""" - x = ordered_array(shape) - - # tta forward - x_aug = tta_forward(from_numpy(x)) - - # tta backward - x_back = tta_backward(x_aug) - - # check that it returns the same array - assert (x_back == x).all() diff --git a/tests/transforms/test_manipulate_n2v.py 
b/tests/transforms/test_manipulate_n2v.py new file mode 100644 index 00000000..b43042fe --- /dev/null +++ b/tests/transforms/test_manipulate_n2v.py @@ -0,0 +1,33 @@ +import numpy as np +import pytest +from albumentations import Compose + +from careamics.config.support import SupportedPixelManipulation +from careamics.transforms import N2VManipulate + + +@pytest.mark.parametrize( + "strategy", + [SupportedPixelManipulation.UNIFORM.value, SupportedPixelManipulation.MEDIAN.value], +) +def test_manipulate_n2v(strategy): + """Test the N2V augmentation.""" + # create array, adding a channel to simulate a 2D image with channel last + array = np.arange(16 * 16).reshape((16, 16))[..., np.newaxis] + + # create augmentation + aug = Compose( + [N2VManipulate(roi_size=5, masked_pixel_percentage=5, strategy=strategy)] + ) + + # apply augmentation + augmented = aug(image=array) + assert "image" in augmented + assert len(augmented["image"]) == 3 # transformed_patch, original_patch, mask + + # assert that the difference between the original and transformed patch are the + # same pixels that are selected by the mask + tr_path, orig_patch, mask = augmented["image"] + diff_coords = np.array(np.where(tr_path != orig_patch)) + mask_coords = np.array(np.where(mask == 1)) + assert np.array_equal(diff_coords, mask_coords) diff --git a/tests/transforms/test_nd_flip.py b/tests/transforms/test_nd_flip.py new file mode 100644 index 00000000..626b6e5d --- /dev/null +++ b/tests/transforms/test_nd_flip.py @@ -0,0 +1,129 @@ +import numpy as np +import pytest + +from careamics.transforms import NDFlip + + +def test_randomness(ordered_array): + """Test randomness of the flipping using the `p` parameter.""" + # create array + array = ordered_array((2, 2)) + + # create augmentation that never applies + aug = NDFlip(p=0.0) + + # apply augmentation + augmented = aug(image=array)["image"] + assert np.array_equal(augmented, array) + + # create augmentation that always applies + aug = NDFlip(p=1.0) + + # 
apply augmentation + augmented = aug(image=array)["image"] + assert not np.array_equal(augmented, array) + + +@pytest.mark.parametrize( + "shape", + [ + # 2D + (2, 2, 1), + (2, 2, 2), + # 3D + (2, 2, 2, 1), + (2, 2, 2, 2), + ], +) +def test_flip_nd(ordered_array, shape): + """Test flipping for 2D and 3D arrays.""" + np.random.seed(42) + + # create array + array: np.ndarray = ordered_array(shape) + + # create augmentation + is_3D = len(shape) == 4 + aug = NDFlip(p=1, is_3D=is_3D, flip_z=True) + + # potential flips + axes = [0, 1, 2] if is_3D else [0, 1] + flips = [np.flip(array, axis=axis) for axis in axes] + + # apply augmentation 10 times + augs = [] + for _ in range(10): + augmented = aug(image=array)["image"] + + # check that the augmented array is one of the potential flips + which_axes = [np.array_equal(augmented, flip) for flip in flips] + + assert any(which_axes) + augs.append(which_axes.index(True)) + + # check that all flips were applied + assert set(augs) == set(axes) + + +def test_flip_z(ordered_array): + """Test turning the Z flipping off.""" + np.random.seed(42) + + # create array + array: np.ndarray = ordered_array((2, 2, 2, 2)) + + # create augmentation + aug = NDFlip(p=1, is_3D=True, flip_z=False) + + # potential flips on Y and X axes + flips = [np.flip(array, axis=1), np.flip(array, axis=2)] + + # apply augmentation 10 times + augs = [] + for _ in range(10): + augmented = aug(image=array)["image"] + + # check that the augmented array is one of the potential flips + which_axes = [np.array_equal(augmented, flip) for flip in flips] + + assert any(which_axes) + augs.append(which_axes.index(True)) + + # check that all flips were applied (first and second flip) + assert set(augs) == {0, 1} + + +def test_flip_mask(ordered_array): + """Test flipping masks in 3D.""" + np.random.seed(42) + + # create array + array: np.ndarray = ordered_array((2, 2, 2, 4)) + mask = array[..., 2:] + array = array[..., :2] + + # create augmentation + aug = NDFlip(p=1, 
is_3D=True, flip_z=True) + + # potential flips on Y and X axes + array_flips = [np.flip(array, axis=axis) for axis in range(3)] + mask_flips = [np.flip(mask, axis=axis) for axis in range(3)] + + # apply augmentation 10 times + for _ in range(10): + transfo = aug(image=array, mask=mask) + aug_array = transfo["image"] + aug_mask = transfo["mask"] + + # check that the augmented array is one of the potential flips + which_axes = [np.array_equal(aug_array, flip) for flip in array_flips] + assert any(which_axes) + img_axis = which_axes.index(True) + + # same for the masks + which_axes = [np.array_equal(aug_mask, flip) for flip in mask_flips] + assert any(which_axes) + mask_axis = which_axes.index(True) + + # same flip for array and mask + assert img_axis == mask_axis diff --git a/tests/transforms/test_normalize.py b/tests/transforms/test_normalize.py new file mode 100644 index 00000000..da61a911 --- /dev/null +++ b/tests/transforms/test_normalize.py @@ -0,0 +1,30 @@ +import numpy as np + +from careamics.transforms import Denormalize, Normalize + + +def test_normalize_denormalize(): + """Test the Normalize transform.""" + # Create data + array = np.arange(100).reshape((1, 1, 10, 10)) + + # Create the transform + norm = Normalize( + mean=50, + std=25, + ) + + # Apply the transform + normalized: np.array = norm(image=array)["image"] + assert np.abs(normalized.mean()) < 0.02 + assert np.abs(normalized.std() - 1) < 0.2 + + # Create the denormalize transform + denorm = Denormalize( + mean=50, + std=25, + ) + + # Apply the denormalize transform + denormalized: np.array = denorm(image=normalized)["image"] + assert np.isclose(denormalized, array).all() diff --git a/tests/transforms/test_pixel_manipulation.py b/tests/transforms/test_pixel_manipulation.py new file mode 100644 index 00000000..3488285d --- /dev/null +++ b/tests/transforms/test_pixel_manipulation.py @@ -0,0 +1,268 @@ +import numpy as np +import pytest + +from careamics.transforms.pixel_manipulation import ( + 
_apply_struct_mask, + _get_stratified_coords, + median_manipulate, + uniform_manipulate, +) +from careamics.transforms.struct_mask_parameters import StructMaskParameters + + +@pytest.mark.parametrize( + "mask_pixel_perc, shape, num_iterations", + [(0.4, (32, 32), 1000), (0.4, (10, 10, 10), 1000)], +) +def test_get_stratified_coords(mask_pixel_perc, shape, num_iterations): + """Test the get_stratified_coords function. + + Ensure that the array of coordinates is randomly distributed across the + image and that most pixels get selected. + """ + # Define the dummy array + array = np.zeros(shape) + + # Iterate over the number of iterations and add the coordinates. This is an MC + # simulation to ensure that the coordinates are randomly distributed and not + # biased towards any particular region. + for _ in range(num_iterations): + # Get the coordinates of the pixels to be masked + coords = _get_stratified_coords(mask_pixel_perc, shape) + + # Check that there is at least one coordinate choosen + assert len(coords) > 0 + + # Check every pair in the array of coordinates + for coord_pair in coords: + # Check that the coordinates are of the same shape as the patch dims + assert len(coord_pair) == len(shape) + + # Check that the coordinates are positive values + assert all(coord_pair) >= 0 + + # Check that the coordinates are within the shape of the array + assert [c <= s for c, s in zip(coord_pair, shape)] + + # Add the 1 to the every coordinate location. + array[tuple(np.array(coords).T.tolist())] += 1 + + # Ensure that there's no strong pattern in the array and sufficient number of + # pixels is masked. + assert np.sum(array == 0) < np.sum(shape) + + +@pytest.mark.parametrize("shape", [(8, 8), (3, 8, 8), (8, 8, 8)]) +def test_uniform_manipulate(ordered_array, shape): + """Test the uniform_manipulate function. + + Ensures that the mask corresponds to the manipulated pixels, and that the + manipulated pixels have a value taken from a ROI surrounding them. 
+ """ + # create the array + patch = ordered_array(shape) + + # manipulate the array + transform_patch, mask = uniform_manipulate( + patch, mask_pixel_percentage=10, subpatch_size=5 + ) + + # find pixels that have different values between patch and transformed patch + diff_coords = np.array(np.where(patch != transform_patch)) + + # find non-zero pixels in the mask + mask_coords = np.array(np.where(mask == 1)) + + # check that the transformed pixels correspond to the masked pixels + assert np.array_equal(diff_coords, mask_coords) + + # for each pixel masked, check that the manipulated pixel value is within the roi + for i in range(mask_coords.shape[-1]): + # get coordinates + coords = mask_coords[..., i] + + # get roi using slice in each dimension + slices = tuple( + [ + slice(max(0, coords[i] - 2), min(shape[i], coords[i] + 3)) + for i in range(-coords.shape[0] + 1, 0) # range -4, -3, -2, -1 + ] + ) + roi = patch[ + (...,) + slices + ] # TODO ellipsis needed bc singleton dim, might need to go away + + # TODO needs to be revisited ! + # check that the pixel value comes from the actual roi + assert transform_patch[tuple(coords)] in roi + + +@pytest.mark.parametrize("shape", [(8, 8), (3, 8, 8), (8, 8, 8)]) +def test_median_manipulate(ordered_array, shape): + """Test the uniform_manipulate function. + + Ensures that the mask corresponds to the manipulated pixels, and that the + manipulated pixels have a value taken from a ROI surrounding them. 
+ """ + # create the array + patch = ordered_array(shape).astype(np.float32) + + # manipulate the array + transform_patch, mask = median_manipulate( + patch, subpatch_size=5, mask_pixel_percentage=10 + ) + + # find pixels that have different values between patch and transformed patch + diff_coords = np.array(np.where(patch != transform_patch)) + + # find non-zero pixels in the mask + mask_coords = np.array(np.where(mask == 1)) + + # check that the transformed pixels correspond to the masked pixels + assert np.array_equal(diff_coords, mask_coords) + + # for each pixel masked, check that the manipulated pixel value is within the roi + for i in range(mask_coords.shape[-1]): + # get coordinates + coords = mask_coords[..., i] + + # get roi using slice in each dimension + slices = tuple( + [ + slice(max(0, coords[i] - 2), min(shape[i], coords[i] + 3)) + for i in range(coords.shape[0]) # range -4, -3, -2, -1 + ] + ) + roi = patch[tuple(slices)] + + # remove value of roi center from roi + roi = roi[roi != patch[tuple(coords)]] + + # check that the pixel value comes from the actual roi + assert transform_patch[tuple(coords)] == np.median(roi) + + +@pytest.mark.parametrize( + "coords, struct_axis, struct_span", + [((2, 2), 1, 5), ((3, 4), 0, 5), ((9, 0), 0, 5), (((1, 2), (3, 4)), 1, 5)], +) +def test_apply_struct_mask(coords, struct_axis, struct_span): + """Test the uniform_manipulate function. + + Ensures that the mask corresponds to the manipulated pixels, and that the + manipulated pixels have a value taken from a ROI surrounding them. 
+ """ + struct_params = StructMaskParameters(axis=struct_axis, span=struct_span) + + # create array + patch = np.arange( + 100, + ).reshape((10, 10)) + + # make a copy of the original patch for comparison + original_patch = patch.copy() + coords = np.array(coords) + + # expand the coords if only one roi is given + if coords.ndim == 1: + coords = coords[None, :] + + # manipulate the array + transform_patch = _apply_struct_mask( + patch, + coords=coords, + struct_params=struct_params, + ) + changed_values = patch[np.where(original_patch != transform_patch)] + + # check that the transformed pixels correspond to the masked pixels + transformed = [] + axis = 1 - struct_axis + for i in range(coords.shape[0]): + # get indices to mask + indices_to_mask = [ + c + for c in range( + max(0, coords[i, axis] - struct_span // 2), + min(transform_patch.shape[1], coords[i, axis] + struct_span // 2) + 1, + ) + if c != coords[i, axis] + ] + + # add to transform + if struct_axis == 0: + transformed.append(transform_patch[coords[i, 0]][indices_to_mask]) + else: + transformed.append(transform_patch[:, coords[i, 1]][indices_to_mask]) + + assert np.array_equal( + np.sort(changed_values), np.sort(np.concatenate(transformed, axis=0)) + ) + + +@pytest.mark.parametrize( + "coords, struct_axis, struct_span", + [ + ((1, 2, 2), 1, 5), + ((2, 3, 4), 0, 5), + ((0, 9, 0), 0, 5), + (((2, 1, 2), (1, 9, 0), (0, 3, 4)), 1, 5), + ], +) +def test_apply_struct_mask_3D(coords, struct_axis, struct_span): + """Test the uniform_manipulate function. + + Ensures that the mask corresponds to the manipulated pixels, and that the + manipulated pixels have a value taken from a ROI surrounding them. 
+ """ + struct_params = StructMaskParameters(axis=struct_axis, span=struct_span) + + # create array + patch = np.arange( + 100 * 3, + ).reshape((3, 10, 10)) + + # make a copy of the original patch for comparison + original_patch = patch.copy() + coords = np.array(coords) + + # expand the coords if only one roi is given + if coords.ndim == 1: + coords = coords[None, :] + + # manipulate the array + transform_patch = _apply_struct_mask( + patch, + coords=coords, + struct_params=struct_params, + ) + changed_values = patch[np.where(original_patch != transform_patch)] + + # check that the transformed pixels correspond to the masked pixels + transformed = [] + axis = -2 + 1 - struct_axis + for i in range(coords.shape[0]): + # get indices to mask + indices_to_mask = [ + c + for c in range( + max(0, coords[i, axis] - struct_span // 2), + min(transform_patch.shape[1] - 1, coords[i, axis] + struct_span // 2) + + 1, + ) + if c != coords[i, axis] + ] + + # add to transform + if struct_axis == 0: + transformed.append( + transform_patch[coords[i, 0], coords[i, 1]][indices_to_mask] + ) + else: + transformed.append( + transform_patch[coords[i, 0], :, coords[i, 2]][indices_to_mask] + ) + + assert np.array_equal( + np.sort(changed_values), np.sort(np.concatenate(transformed, axis=0)) + ) diff --git a/tests/transforms/test_supported_transforms.py b/tests/transforms/test_supported_transforms.py new file mode 100644 index 00000000..4026f0f2 --- /dev/null +++ b/tests/transforms/test_supported_transforms.py @@ -0,0 +1,8 @@ +from careamics.config.support import SupportedTransform +from careamics.transforms import get_all_transforms + + +def test_supported_transforms_in_accepted_transforms(): + """Test that all the supported transforms are in the accepted transforms.""" + for transform in SupportedTransform: + assert transform in get_all_transforms() diff --git a/tests/transforms/test_xy_random_rotate90.py b/tests/transforms/test_xy_random_rotate90.py new file mode 100644 index 
00000000..cea51b87 --- /dev/null +++ b/tests/transforms/test_xy_random_rotate90.py @@ -0,0 +1,117 @@ +import numpy as np +import pytest + +from careamics.transforms import XYRandomRotate90 + + +def test_randomness(ordered_array): + """Test randomness of the flipping using the `p` parameter.""" + # create array + array = ordered_array((1, 2, 2, 1)) + + # create augmentation that never applies + aug = XYRandomRotate90(p=0.0) + + # apply augmentation + augmented = aug(image=array)["image"] + assert np.array_equal(augmented, array) + + # create augmentation that always applies + aug = XYRandomRotate90(p=1.0) + + # apply augmentation + augmented = aug(image=array)["image"] + assert not np.array_equal(augmented, array) + + +@pytest.mark.parametrize( + "shape", + [ + # 2D + (2, 2, 1), + (2, 2, 2), + # 3D + (2, 2, 2, 1), + (2, 2, 2, 2), + ], +) +def test_xy_rotate(ordered_array, shape): + """Test rotation for 2D and 3D arrays.""" + np.random.seed(42) + + # create array + array: np.ndarray = ordered_array(shape) + + # create augmentation + is_3D = len(shape) == 4 + aug = XYRandomRotate90(p=1, is_3D=is_3D) + + # potential rotations + axes = (1, 2) if is_3D else (0, 1) + rots = [ + np.rot90(array, k=1, axes=axes), + np.rot90(array, k=2, axes=axes), + np.rot90(array, k=3, axes=axes), + ] + + # apply augmentation 10 times + augs = [] + for _ in range(10): + augmented = aug(image=array)["image"] + + # check that the augmented array is one of the potential rots + which_number = [np.array_equal(augmented, rot) for rot in rots] + + assert any(which_number) + augs.append(which_number.index(True)) + + # check that all rots were applied (indices of rots) + assert set(augs) == {0, 1, 2} + + +def test_mask_rotate(ordered_array): + """Test rotating masks in 3D.""" + np.random.seed(42) + + # create array + array: np.ndarray = ordered_array((2, 2, 2, 4)) + mask = array[..., 2:] + array = array[..., :2] + + # create augmentation + is_3D = len(array.shape) == 4 + aug = XYRandomRotate90(p=1, 
is_3D=is_3D) + + # potential rotations + axes = (1, 2) + array_rots = [ + np.rot90(array, k=1, axes=axes), + np.rot90(array, k=2, axes=axes), + np.rot90(array, k=3, axes=axes), + ] + mask_rots = [ + np.rot90(mask, k=1, axes=axes), + np.rot90(mask, k=2, axes=axes), + np.rot90(mask, k=3, axes=axes), + ] + + # apply augmentation 10 times + for _ in range(10): + augmented = aug(image=array, mask=mask) + aug_array = augmented["image"] + aug_mask = augmented["mask"] + + # check that the augmented array is one of the potential rots + which_number = [np.array_equal(aug_array, rot) for rot in array_rots] + + assert any(which_number) + img_n_rots = which_number.index(True) + + # same for the masks + which_number = [np.array_equal(aug_mask, rot) for rot in mask_rots] + + assert any(which_number) + mask_n_rots = which_number.index(True) + + # same rot for array and mask + assert img_n_rots == mask_n_rots diff --git a/tests/utils/test_base_enum.py b/tests/utils/test_base_enum.py new file mode 100644 index 00000000..a181d5b9 --- /dev/null +++ b/tests/utils/test_base_enum.py @@ -0,0 +1,12 @@ +from careamics.utils.base_enum import BaseEnum + + +class MyEnum(str, BaseEnum): + A = "a" + B = "b" + C = "c" + + +def test_base_enum(): + """Test that BaseEnum allows the `in` operator with values.""" + assert "b" in MyEnum diff --git a/tests/test_metrics.py b/tests/utils/test_metrics.py similarity index 59% rename from tests/test_metrics.py rename to tests/utils/test_metrics.py index 4d6c85d6..b91e6690 100644 --- a/tests/test_metrics.py +++ b/tests/utils/test_metrics.py @@ -2,7 +2,6 @@ import pytest from careamics.utils.metrics import ( - MetricTracker, _zero_mean, scale_invariant_psnr, ) @@ -29,24 +28,3 @@ def test_zero_mean(x): ) def test_scale_invariant_psnr(gt, pred, result): assert scale_invariant_psnr(gt, pred) == pytest.approx(result, rel=5e-3) - - -def test_metric_tracker(): - tracker = MetricTracker() - - # check initial state - assert tracker.sum == 0 - assert tracker.count == 0 
- assert tracker.avg == 0 - assert tracker.val == 0 - - # run a few updates - n = 5 - for i in range(n): - tracker.update(i, n) - - # check values - assert tracker.sum == n * (n * (n - 1)) / 2 - assert tracker.count == n * n - assert tracker.avg == (n - 1) / 2 - assert tracker.val == n - 1 diff --git a/tests/utils/test_torch_utils.py b/tests/utils/test_torch_utils.py index 947a3685..3b9c28af 100644 --- a/tests/utils/test_torch_utils.py +++ b/tests/utils/test_torch_utils.py @@ -1,22 +1,19 @@ -import pytest -import torch +from torch import optim -from careamics.utils.torch_utils import ( - get_device, -) +from careamics.utils.torch_utils import get_optimizers, get_schedulers -@pytest.mark.parametrize("device", ["cpu", "cuda"]) -def test_get_device(device): - device = get_device() - assert isinstance(device, torch.device) - assert device.type == "cuda" if torch.cuda.is_available() else "cpu" +def test_get_schedulers_exist(): + """Test that the function `get_schedulers` return + existing torch schedulers. + """ + for scheduler in get_schedulers(): + assert hasattr(optim.lr_scheduler, scheduler) -# @pytest.mark.gpu -# @pytest.mark.parametrize("deterministic", [True, False]) -# @pytest.mark.parametrize("benchmark", [True, False]) -# def test_setup_cudnn_reproducibility(deterministic, benchmark): -# setup_cudnn_reproducibility(deterministic=deterministic, benchmark=benchmark) -# assert torch.backends.cudnn.deterministic == deterministic -# assert torch.backends.cudnn.benchmark == benchmark +def test_get_optimizers_exist(): + """Test that the function `get_optimizers` return + existing torch optimizers. 
+ """ + for optimizer in get_optimizers(): + assert hasattr(optim, optimizer) diff --git a/tests/utils/test_validators.py b/tests/utils/test_validators.py deleted file mode 100644 index b20ee024..00000000 --- a/tests/utils/test_validators.py +++ /dev/null @@ -1,70 +0,0 @@ -import numpy as np -import pytest - -from careamics.utils import add_axes, check_axes_validity - - -@pytest.mark.parametrize( - "axes, valid", - [ - # Passing - ("yx", True), - ("Yx", True), - ("Zyx", True), - ("TzYX", True), - ("SZYX", True), - # Failing due to order - ("XY", False), - ("YXZ", False), - ("YXT", False), - ("ZTYX", False), - # too few axes - ("", False), - ("X", False), - # too many axes - ("STZYX", False), - # no yx axes - ("ZT", False), - ("ZY", False), - # unsupported axes or axes pair - ("STYX", False), - ("CYX", False), - # repeating characters - ("YYX", False), - ("YXY", False), - # invalid characters - ("YXm", False), - ("1YX", False), - ], -) -def test_are_axes_valid(axes, valid): - """Test if axes are valid""" - if valid: - check_axes_validity(axes) - else: - with pytest.raises((ValueError, NotImplementedError)): - check_axes_validity(axes) - - -@pytest.mark.parametrize( - "axes, input_shape, output_shape", - [ - ("YX", (8, 8), (1, 1, 8, 8)), - ("YX", (1, 1, 8, 8), (1, 1, 8, 8)), - ("ZYX", (8, 8, 8), (1, 1, 8, 8, 8)), - ("ZYX", (1, 1, 8, 8, 8), (1, 1, 8, 8, 8)), - ("SYX", (2, 8, 8), (2, 1, 8, 8)), - ("SYX", (2, 1, 8, 8), (2, 1, 8, 8)), - ("SZYX", (2, 8, 8, 8), (2, 1, 8, 8, 8)), - ("SZYX", (2, 1, 8, 8, 8), (2, 1, 8, 8, 8)), - ("TZYX", (2, 8, 8, 8), (2, 1, 8, 8, 8)), - ("TZYX", (2, 1, 8, 8, 8), (2, 1, 8, 8, 8)), - ], -) -def test_add_axes(axes, input_shape, output_shape): - """Test that axes are added correctly.""" - input_array = np.zeros(input_shape) - output = add_axes(input_array, axes) - - # check shape - assert output.shape == output_shape diff --git a/tests/utils/test_wandb.py b/tests/utils/test_wandb.py index 97019c97..0e293f36 100644 --- a/tests/utils/test_wandb.py 
+++ b/tests/utils/test_wandb.py @@ -1,27 +1,19 @@ -from pathlib import Path -from unittest import mock +# @mock.patch("careamics.utils.wandb.wandb") +# def test_wandb_logger(wandb, tmp_path: Path, minimum_config: dict): +# config = Configuration(**minimum_config) +# logger = WandBLogging( +# experiment_name="test", log_path=tmp_path, config=config, model_to_watch=None +# ) -from careamics.config import Configuration -from careamics.engine import Engine -from careamics.utils.wandb import WandBLogging +# logger.log_metrics({"acc": 0.0}) +# wandb.init().log.assert_called_once_with({"acc": 0.0}, commit=True) -@mock.patch("careamics.utils.wandb.wandb") -def test_wandb_logger(wandb, tmp_path: Path, minimum_config: dict): - config = Configuration(**minimum_config) - logger = WandBLogging( - experiment_name="test", log_path=tmp_path, config=config, model_to_watch=None - ) - - logger.log_metrics({"acc": 0.0}) - wandb.init().log.assert_called_once_with({"acc": 0.0}, commit=True) - - -@mock.patch("careamics.utils.wandb.wandb") -def test_wandb_logger_engine(wandb, minimum_config: dict): - config = Configuration(**minimum_config) - config.training.use_wandb = True - engine = Engine(config=config) - if engine.use_wandb: - assert engine.wandb is not None - assert wandb.run +# @mock.patch("careamics.utils.wandb.wandb") +# def test_wandb_logger_engine(wandb, minimum_config: dict): +# config = Configuration(**minimum_config) +# config.training.use_wandb = True +# engine = Engine(config=config) +# if engine.use_wandb: +# assert engine.wandb is not None +# assert wandb.run