
[BUG] aten::max_pool3d_with_indices not implemented #326

Open
conradkun opened this issue Dec 12, 2024 · 5 comments
Labels: bug (Something isn't working), external issue (External issues to be monitored occasionally)

Comments

@conradkun
Contributor

Describe the bug

Unfortunately, I was able to reproduce the MPS bug where aten::max_pool3d_with_indices is missing from macOS GPU acceleration. Indeed, this operation is listed as "In Process" in the PyTorch MPS ops tracker, meaning it is not implemented yet. This error (understandably) only seems to pop up when processing 3D images with MPS on macOS.

To Reproduce

I created a new environment from scratch, following the guidelines in the CAREamics installation guide.

Then I just tried to train on a 3D image, albeit one with a custom file extension (and hence custom read_npy function). I'll be happy to try with a .tiff file if you point me to one.
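For reference, the custom read function essentially just loads the raw array with numpy. A minimal sketch of the idea (the real read_npy_custom and the .ims parsing are more involved; the name and signature here are illustrative):

```python
from pathlib import Path

import numpy as np


def read_npy_custom(file_path: Path, *args, **kwargs) -> np.ndarray:
    """Load an image stored in numpy's .npy format under a custom extension.

    Illustrative only: real .ims files need dedicated parsing.
    """
    return np.load(file_path, allow_pickle=False)
```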

Code snippet to reproduce the behaviour:

from careamics import CAREamist
from careamics.config import create_n2v_configuration
from careamics.lightning import TrainDataModule

config = create_n2v_configuration(
    experiment_name="test",
    data_type="custom",
    axes="ZYX",
    batch_size=32,
    patch_size=[8, 16, 16],
    num_epochs=5,
    logger="tensorboard",
)

data_module = TrainDataModule(
    data_config=config.data_config,
    train_data=file_path,
    read_source_func=read_npy_custom,
    extension_filter="*.ims",
)

engine = CAREamist(config)
engine.train(datamodule=data_module)

Error message:

{
	"name": "NotImplementedError",
	"message": "The operator 'aten::max_pool3d_with_indices' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.",
	"stack": "---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
Cell In[32], line 18
     16 engine = CAREamist(config)
     17 start_time = time()
---> 18 engine.train(datamodule=data_module)
     19 training_elapsed = time() - start_time

File ~/miniconda3/envs/careamics/lib/python3.10/site-packages/careamics/careamist.py:333, in CAREamist.train(self, datamodule, train_source, val_source, train_target, val_target, use_in_memory, val_percentage, val_minimum_split)
    331 # train
    332 if datamodule is not None:
--> 333     self._train_on_datamodule(datamodule=datamodule)
    335 else:
    336     # raise error if target is provided to N2V
    337     if self.cfg.algorithm_config.algorithm == SupportedAlgorithm.N2V.value:

File ~/miniconda3/envs/careamics/lib/python3.10/site-packages/careamics/careamist.py:410, in CAREamist._train_on_datamodule(self, datamodule)
    407 self.trainer.limit_val_batches = 1.0  # 100%
    409 # train
--> 410 self.trainer.fit(self.model, datamodule=datamodule)

File ~/miniconda3/envs/careamics/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:538, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    536 self.state.status = TrainerStatus.RUNNING
    537 self.training = True
--> 538 call._call_and_handle_interrupt(
    539     self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    540 )

File ~/miniconda3/envs/careamics/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py:47, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
     45     if trainer.strategy.launcher is not None:
     46         return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
---> 47     return trainer_fn(*args, **kwargs)
     49 except _TunerExitException:
     50     _call_teardown_hook(trainer)

File ~/miniconda3/envs/careamics/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:574, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    567 assert self.state.fn is not None
    568 ckpt_path = self._checkpoint_connector._select_ckpt_path(
    569     self.state.fn,
    570     ckpt_path,
    571     model_provided=True,
    572     model_connected=self.lightning_module is not None,
    573 )
--> 574 self._run(model, ckpt_path=ckpt_path)
    576 assert self.state.stopped
    577 self.training = False

File ~/miniconda3/envs/careamics/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:981, in Trainer._run(self, model, ckpt_path)
    976 self._signal_connector.register_signal_handlers()
    978 # ----------------------------
    979 # RUN THE TRAINER
    980 # ----------------------------
--> 981 results = self._run_stage()
    983 # ----------------------------
    984 # POST-Training CLEAN UP
    985 # ----------------------------
    986 log.debug(f\"{self.__class__.__name__}: trainer tearing down\")

File ~/miniconda3/envs/careamics/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1023, in Trainer._run_stage(self)
   1021 if self.training:
   1022     with isolate_rng():
-> 1023         self._run_sanity_check()
   1024     with torch.autograd.set_detect_anomaly(self._detect_anomaly):
   1025         self.fit_loop.run()

File ~/miniconda3/envs/careamics/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1052, in Trainer._run_sanity_check(self)
   1049 call._call_callback_hooks(self, \"on_sanity_check_start\")
   1051 # run eval step
-> 1052 val_loop.run()
   1054 call._call_callback_hooks(self, \"on_sanity_check_end\")
   1056 # reset logger connector

File ~/miniconda3/envs/careamics/lib/python3.10/site-packages/pytorch_lightning/loops/utilities.py:178, in _no_grad_context.<locals>._decorator(self, *args, **kwargs)
    176     context_manager = torch.no_grad
    177 with context_manager():
--> 178     return loop_run(self, *args, **kwargs)

File ~/miniconda3/envs/careamics/lib/python3.10/site-packages/pytorch_lightning/loops/evaluation_loop.py:135, in _EvaluationLoop.run(self)
    133     self.batch_progress.is_last_batch = data_fetcher.done
    134     # run step hooks
--> 135     self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
    136 except StopIteration:
    137     # this needs to wrap the `*_step` call too (not just `next`) for `dataloader_iter` support
    138     break

File ~/miniconda3/envs/careamics/lib/python3.10/site-packages/pytorch_lightning/loops/evaluation_loop.py:396, in _EvaluationLoop._evaluation_step(self, batch, batch_idx, dataloader_idx, dataloader_iter)
    390 hook_name = \"test_step\" if trainer.testing else \"validation_step\"
    391 step_args = (
    392     self._build_step_args_from_hook_kwargs(hook_kwargs, hook_name)
    393     if not using_dataloader_iter
    394     else (dataloader_iter,)
    395 )
--> 396 output = call._call_strategy_hook(trainer, hook_name, *step_args)
    398 self.batch_progress.increment_processed()
    400 if using_dataloader_iter:
    401     # update the hook kwargs now that the step method might have consumed the iterator

File ~/miniconda3/envs/careamics/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py:319, in _call_strategy_hook(trainer, hook_name, *args, **kwargs)
    316     return None
    318 with trainer.profiler.profile(f\"[Strategy]{trainer.strategy.__class__.__name__}.{hook_name}\"):
--> 319     output = fn(*args, **kwargs)
    321 # restore current_fx when nested context
    322 pl_module._current_fx_name = prev_fx_name

File ~/miniconda3/envs/careamics/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py:411, in Strategy.validation_step(self, *args, **kwargs)
    409 if self.model != self.lightning_module:
    410     return self._forward_redirection(self.model, self.lightning_module, \"validation_step\", *args, **kwargs)
--> 411 return self.lightning_module.validation_step(*args, **kwargs)

File ~/miniconda3/envs/careamics/lib/python3.10/site-packages/careamics/lightning/lightning_module.py:139, in FCNModule.validation_step(self, batch, batch_idx)
    129 \"\"\"Validation step.
    130 
    131 Parameters
   (...)
    136     Batch index.
    137 \"\"\"
    138 x, *aux = batch
--> 139 out = self.model(x)
    140 val_loss = self.loss_func(out, *aux)
    142 # log validation loss

File ~/miniconda3/envs/careamics/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/miniconda3/envs/careamics/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/miniconda3/envs/careamics/lib/python3.10/site-packages/careamics/models/unet.py:439, in UNet.forward(self, x)
    425 def forward(self, x: torch.Tensor) -> torch.Tensor:
    426     \"\"\"
    427     Forward pass.
    428 
   (...)
    437         Output of the model.
    438     \"\"\"
--> 439     encoder_features = self.encoder(x)
    440     x = self.decoder(*encoder_features)
    441     x = self.final_conv(x)

File ~/miniconda3/envs/careamics/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/miniconda3/envs/careamics/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/miniconda3/envs/careamics/lib/python3.10/site-packages/careamics/models/unet.py:124, in UnetEncoder.forward(self, x)
    122 encoder_features = []
    123 for module in self.encoder_blocks:
--> 124     x = module(x)
    125     if isinstance(module, Conv_Block):
    126         encoder_features.append(x)

File ~/miniconda3/envs/careamics/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/miniconda3/envs/careamics/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/miniconda3/envs/careamics/lib/python3.10/site-packages/torch/nn/modules/pooling.py:296, in MaxPool3d.forward(self, input)
    295 def forward(self, input: Tensor):
--> 296     return F.max_pool3d(
    297         input,
    298         self.kernel_size,
    299         self.stride,
    300         self.padding,
    301         self.dilation,
    302         ceil_mode=self.ceil_mode,
    303         return_indices=self.return_indices,
    304     )

File ~/miniconda3/envs/careamics/lib/python3.10/site-packages/torch/_jit_internal.py:624, in boolean_dispatch.<locals>.fn(*args, **kwargs)
    622     return if_true(*args, **kwargs)
    623 else:
--> 624     return if_false(*args, **kwargs)

File ~/miniconda3/envs/careamics/lib/python3.10/site-packages/torch/nn/functional.py:920, in _max_pool3d(input, kernel_size, stride, padding, dilation, ceil_mode, return_indices)
    918 if stride is None:
    919     stride = torch.jit.annotate(List[int], [])
--> 920 return torch.max_pool3d(input, kernel_size, stride, padding, dilation, ceil_mode)

NotImplementedError: The operator 'aten::max_pool3d_with_indices' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS."
}

System


  • OS: macOS 13.6.2
  • Python version 3.10.15
  • PyTorch version 2.5.1
  • PyTorch Lightning version 2.4.0
  • CAREamics version 0.0.4.2

Environment

channels:
  - pytorch
  - defaults
  - conda-forge
dependencies:
  - blas=1.0=openblas
  - brotli-python=1.0.9=py310h313beb8_8
  - bzip2=1.0.8=h80987f9_6
  - ca-certificates=2024.11.26=hca03da5_0
  - certifi=2024.8.30=py310hca03da5_0
  - charset-normalizer=3.3.2=pyhd3eb1b0_0
  - filelock=3.13.1=py310hca03da5_0
  - freetype=2.12.1=h1192e45_0
  - giflib=5.2.2=h80987f9_0
  - gmp=6.2.1=hc377ac9_3
  - gmpy2=2.1.2=py310h8c48613_0
  - idna=3.7=py310hca03da5_0
  - jinja2=3.1.4=py310hca03da5_1
  - jpeg=9e=h80987f9_3
  - lcms2=2.12=hba8e193_0
  - lerc=3.0=hc377ac9_0
  - libcxx=14.0.6=h848a8c0_0
  - libdeflate=1.17=h80987f9_1
  - libffi=3.4.4=hca03da5_1
  - libgfortran=5.0.0=11_3_0_hca03da5_28
  - libgfortran5=11.3.0=h009349e_28
  - libjpeg-turbo=2.0.0=h1a28f6b_0
  - libopenblas=0.3.21=h269037a_0
  - libpng=1.6.39=h80987f9_0
  - libtiff=4.5.1=h313beb8_0
  - libwebp=1.3.2=ha3663a8_0
  - libwebp-base=1.3.2=h80987f9_1
  - llvm-openmp=14.0.6=hc6e5704_0
  - lz4-c=1.9.4=h313beb8_1
  - markupsafe=2.1.3=py310h80987f9_0
  - mpc=1.1.0=h8c48613_1
  - mpfr=4.0.2=h695f6f0_1
  - mpmath=1.3.0=py310hca03da5_0
  - ncurses=6.4=h313beb8_0
  - networkx=3.2.1=py310hca03da5_0
  - openjpeg=2.5.2=h54b8e55_0
  - openssl=3.0.15=h80987f9_0
  - pillow=11.0.0=py310hfaf4e14_0
  - pip=24.2=py310hca03da5_0
  - pysocks=1.7.1=py310hca03da5_0
  - python=3.10.15=hb885b13_1
  - pytorch=2.5.1=py3.10_0
  - pyyaml=6.0.2=py310h80987f9_0
  - readline=8.2=h1a28f6b_0
  - requests=2.32.3=py310hca03da5_1
  - setuptools=75.1.0=py310hca03da5_0
  - sqlite=3.45.3=h80987f9_0
  - tk=8.6.14=h6ba3021_0
  - torchvision=0.20.1=py310_cpu
  - typing_extensions=4.11.0=py310hca03da5_0
  - urllib3=2.2.3=py310hca03da5_0
  - wheel=0.44.0=py310hca03da5_0
  - xz=5.4.6=h80987f9_1
  - yaml=0.2.5=h1a28f6b_0
  - zlib=1.2.13=h18a0788_1
  - zstd=1.5.6=hfb09047_0
  - pip:
      - absl-py==2.1.0
      - aiohappyeyeballs==2.4.4
      - aiohttp==3.11.10
      - aiosignal==1.3.1
      - annotated-types==0.7.0
      - anyio==4.7.0
      - appnope==0.1.4
      - argon2-cffi==23.1.0
      - argon2-cffi-bindings==21.2.0
      - arrow==1.3.0
      - asciitree==0.3.3
      - asttokens==3.0.0
      - async-lru==2.0.4
      - async-timeout==5.0.1
      - attrs==24.2.0
      - babel==2.16.0
      - beautifulsoup4==4.12.3
      - bioimageio-core==0.7.0
      - bioimageio-spec==0.5.3.5
      - bleach==6.2.0
      - blinker==1.9.0
      - careamics==0.0.4.2
      - careamics-portfolio==0.0.14
      - cffi==1.17.1
      - click==8.1.7
      - comm==0.2.2
      - contourpy==1.3.1
      - cycler==0.12.1
      - dash==2.18.2
      - dash-core-components==2.0.0
      - dash-html-components==2.0.0
      - dash-table==5.0.0
      - debugpy==1.8.9
      - decorator==5.1.1
      - defusedxml==0.7.1
      - distro==1.9.0
      - dnspython==2.7.0
      - email-validator==2.2.0
      - exceptiongroup==1.2.2
      - executing==2.1.0
      - fasteners==0.19
      - fastjsonschema==2.21.1
      - flask==3.0.3
      - fonttools==4.55.2
      - fqdn==1.5.1
      - frozenlist==1.5.0
      - fsspec==2024.10.0
      - grpcio==1.68.1
      - h11==0.14.0
      - h5py==3.12.1
      - httpcore==1.0.7
      - httpx==0.28.1
      - imageio==2.36.1
      - importlib-metadata==8.5.0
      - ipykernel==6.29.5
      - ipython==8.30.0
      - ipywidgets==8.1.5
      - isoduration==20.11.0
      - itsdangerous==2.2.0
      - jedi==0.19.2
      - json5==0.10.0
      - jsonpointer==3.0.0
      - jsonschema==4.23.0
      - jsonschema-specifications==2024.10.1
      - jupyter==1.1.1
      - jupyter-client==8.6.3
      - jupyter-console==6.6.3
      - jupyter-core==5.7.2
      - jupyter-events==0.10.0
      - jupyter-lsp==2.2.5
      - jupyter-server==2.14.2
      - jupyter-server-terminals==0.5.3
      - jupyterlab==4.3.3
      - jupyterlab-pygments==0.3.0
      - jupyterlab-server==2.27.3
      - jupyterlab-widgets==3.0.13
      - kiwisolver==1.4.7
      - lazy-loader==0.4
      - lightning-utilities==0.11.9
      - loguru==0.7.3
      - markdown==3.7
      - markdown-it-py==3.0.0
      - matplotlib==3.9.3
      - matplotlib-inline==0.1.7
      - mdurl==0.1.2
      - mistune==3.0.2
      - multidict==6.1.0
      - nbclient==0.10.1
      - nbconvert==7.16.4
      - nbformat==5.10.4
      - nest-asyncio==1.6.0
      - notebook==7.3.1
      - notebook-shim==0.2.4
      - numcodecs==0.13.1
      - numpy==1.26.4
      - overrides==7.7.0
      - packaging==24.2
      - pandas==2.2.3
      - pandocfilters==1.5.1
      - parso==0.8.4
      - pexpect==4.9.0
      - platformdirs==4.3.6
      - plotly==5.24.1
      - pooch==1.8.2
      - prometheus-client==0.21.1
      - prompt-toolkit==3.0.48
      - propcache==0.2.1
      - protobuf==3.20.3
      - psutil==6.1.0
      - ptyprocess==0.7.0
      - pure-eval==0.2.3
      - pycparser==2.22
      - pydantic==2.8.2
      - pydantic-core==2.20.1
      - pydantic-settings==2.6.1
      - pygments==2.18.0
      - pyparsing==3.2.0
      - python-dateutil==2.9.0.post0
      - python-dotenv==1.0.1
      - python-json-logger==2.0.7
      - pytorch-lightning==2.4.0
      - pytz==2024.2
      - pyzmq==26.2.0
      - referencing==0.35.1
      - retrying==1.3.4
      - rfc3339-validator==0.1.4
      - rfc3986-validator==0.1.1
      - rich==13.9.4
      - rpds-py==0.22.3
      - ruyaml==0.91.0
      - scikit-image==0.23.2
      - scipy==1.14.1
      - send2trash==1.8.3
      - shellingham==1.5.4
      - six==1.17.0
      - sniffio==1.3.1
      - soupsieve==2.6
      - stack-data==0.6.3
      - sympy==1.13.1
      - tenacity==9.0.0
      - tensorboard==2.18.0
      - tensorboard-data-server==0.7.2
      - terminado==0.18.1
      - tifffile==2024.9.20
      - tinycss2==1.4.0
      - tomli==2.2.1
      - torchmetrics==1.6.0
      - tornado==6.4.2
      - tqdm==4.67.1
      - traitlets==5.14.3
      - typer==0.12.3
      - types-python-dateutil==2.9.0.20241206
      - tzdata==2024.2
      - uri-template==1.3.0
      - wcwidth==0.2.13
      - webcolors==24.11.1
      - webencodings==0.5.1
      - websocket-client==1.8.0
      - werkzeug==3.0.6
      - widgetsnbextension==4.0.13
      - xarray==2024.11.0
      - yarl==1.18.3
      - zarr==2.18.3
      - zipp==3.21.0
@conradkun conradkun added the bug Something isn't working label Dec 12, 2024
@conradkun conradkun changed the title [BUG] [BUG] aten::max_pool3d_with_indices not implemented Dec 12, 2024
@jdeschamps
Member

jdeschamps commented Dec 12, 2024

Hi Conrad!

Thanks a lot for giving it a go. It never clicked in my head that the culprit was the 3D max pool; we probably tested only with a 2D dataset...

We will need to figure out a workaround...

The error states: "As a temporary fix, you can set the environment variable PYTORCH_ENABLE_MPS_FALLBACK=1 to use the CPU as a fallback for this op."

Have you tried this?

conda env config vars set PYTORCH_ENABLE_MPS_FALLBACK=1
conda activate <test-env>
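For what it's worth, the same flag can also be set from Python, as long as it happens before torch is first imported in the session (a sketch; setting it after torch has initialized may have no effect):

```python
import os

# Enable CPU fallback for operators the MPS backend does not implement yet.
# This must run before the first `import torch` in the process.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
```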

@jdeschamps jdeschamps added this to the v0.0.6 milestone Dec 12, 2024
@conradkun
Contributor Author

Hi Joran,

I haven't tried it, because it kind of defeats the purpose... Obviously it is possible to run everything on the CPU; what I'm interested in is the ballpark time it would take to process different-sized images on macOS.

I took a quick look at the PyTorch issue for max_pool3d_with_indices. Maybe I'll muster the motivation to give it a go over the holidays...

@jdeschamps
Member

jdeschamps commented Dec 17, 2024

My understanding of the flag is that it only runs that single operation (or any other still-unsupported operation) on the CPU, not the whole training. I am not familiar with the MPS architecture, so I don't know how much overhead transferring tensors between the Apple silicon GPU and the CPU that often incurs on macOS.

It is on our todo list to test it out quantitatively, but the end of the year is a bit tough...

@jdeschamps jdeschamps added the external issue External issues to be monitored occasionally label Dec 17, 2024
@melisande-c
Member

Hi @conradkun,

I have just taken the time to do a quick comparison of running 3D N2V on Mac with CPU only vs. MPS with PYTORCH_ENABLE_MPS_FALLBACK=1.

To do this I ran the following script (use pip install "careamics[examples]" to get careamics-portfolio, which provides some example data):

from pathlib import Path
import timeit

import tifffile

from careamics import CAREamist
from careamics.config import create_n2v_configuration
from careamics_portfolio import PortfolioManager

# instantiate the data portfolio manager
portfolio = PortfolioManager()

# and download the data
root_path = Path("./data")
file = portfolio.denoising.Flywing.download(root_path)

train_image = tifffile.imread(file[0])

n_epochs = 1
config = create_n2v_configuration(
    experiment_name="flywing_n2v",
    data_type="array",
    axes="ZYX",
    patch_size=(16, 64, 64),
    batch_size=2,
    num_epochs=n_epochs,
    augmentations=[],  # remove augmentations
)

# train
n_repeats = 10
total_train_time = timeit.timeit(
    lambda: CAREamist(source=config).train(
        train_source=train_image,
        val_percentage=0.0,
        val_minimum_split=10,  # use 10 patches as validation
    ),
    number=n_repeats,
)
average_train_time = total_train_time/n_repeats

print(f"Average time taken to train {n_epochs} epoch(s): {average_train_time:.2f}s")

For CPU only the print out was

Average time taken to train 1 epoch(s): 541.62s

For MPS with PYTORCH_ENABLE_MPS_FALLBACK=1 the print out was

Average time taken to train 1 epoch(s): 133.83s

This is still about a 4-fold speed-up so the overhead of transferring the tensors between the GPU and CPU doesn't seem too bad. Hopefully the MPS implementation of max_pool3d_with_indices is completed quickly, but in the meantime this compromise still provides a decent speed-up.
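Explicitly, the ratio of the two averages above works out to:

```python
cpu_s = 541.62           # average CPU-only epoch time, seconds
mps_fallback_s = 133.83  # average MPS + CPU-fallback epoch time, seconds

speedup = cpu_s / mps_fallback_s
print(f"speed-up: {speedup:.2f}x")  # → speed-up: 4.05x
```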

Let me know if you have any questions 😊

Specs
Mac: M1, 2020
Chip: Apple M1
Memory: 16 GB

@jdeschamps
Member

And if anyone else runs into this issue and stumbles upon this message, please give a thumbs-up to the following comment: pytorch/pytorch#77764 (comment)

On the off chance that it will push the operation up the priority list!
