From bb9268b10d0640e16c87fce10a44746f3274d1ca Mon Sep 17 00:00:00 2001
From: Jeff Rhoades <37990507+rhoadesScholar@users.noreply.github.com>
Date: Thu, 24 Oct 2024 10:58:18 +0200
Subject: [PATCH] =?UTF-8?q?feat:=20=E2=9C=A8=20Add=20support=20for=20train?=
 =?UTF-8?q?ing=20on=20Apple=20M1/M2/M3=20(mps)=20devices.?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Also log the device used for training.
---
 dacapo/compute_context/local_torch.py | 2 ++
 dacapo/train.py                       | 1 +
 2 files changed, 3 insertions(+)

diff --git a/dacapo/compute_context/local_torch.py b/dacapo/compute_context/local_torch.py
index 5a0371a43..045300790 100644
--- a/dacapo/compute_context/local_torch.py
+++ b/dacapo/compute_context/local_torch.py
@@ -60,6 +60,8 @@ def device(self):
             if free < self.oom_limit:  # less than 1 GB free, decrease chance of OOM
                 return torch.device("cpu")
             return torch.device("cuda")
+        elif torch.backends.mps.is_available():
+            return torch.device("mps")
         else:
             return torch.device("cpu")
         return torch.device(self._device)
diff --git a/dacapo/train.py b/dacapo/train.py
index 70b845db2..4e4101f8d 100644
--- a/dacapo/train.py
+++ b/dacapo/train.py
@@ -135,6 +135,7 @@ def train_run(run: Run, do_validate=True):
     compute_context = create_compute_context()
     run.model = run.model.to(compute_context.device)
     run.move_optimizer(compute_context.device)
+    logger.info(f"Training on {compute_context.device}")
 
     array_store = create_array_store()
     run.trainer.iteration = trained_until
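
A minimal sketch of the device fallback this patch produces, for anyone
testing on Apple silicon. It assumes a machine with no CUDA device but
with Metal support, and that create_compute_context is importable from
dacapo.compute_context as it is used in train.py above:

    import torch
    from dacapo.compute_context import create_compute_context

    # After this patch, LocalTorch.device resolves in order: CUDA if
    # available (falling back to CPU when free GPU memory is below
    # oom_limit), else MPS, else CPU.
    context = create_compute_context()
    print(context.device)  # "mps" on an M1/M2/M3 Mac without CUDA

    # The same backend check the patch adds, straight from PyTorch:
    if torch.backends.mps.is_available():
        x = torch.ones(2, 2, device="mps")  # tensor placed on the Apple GPU
        print(x.device)                      # mps:0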