Skip to content

Commit

Permalink
Fix bug: noise model not training on GPU (#360)
Browse files Browse the repository at this point in the history
### Description
Fixes a bug where the noise model was unintentionally moved to the CPU
before training, and thus did not train on GPU.

- **What**: When `_set_model_mode` was called, the model was moved to
the CPU by default, which caused training to happen on the CPU instead
of the GPU.

### Changes Made
- **Modified**: Removed device handling from `_set_model_mode`, moved
final move to CPU to the end of the `fit` function.


---

**Please ensure your PR meets the following requirements:**

- [ ] Code builds and passes tests locally, including doctests
- [ ] New tests have been added (for bug fixes/features)
- [ ] Pre-commit passes
- [ ] PR to the documentation exists (for bug fixes / features)

Co-authored-by: Joran Deschamps <[email protected]>
  • Loading branch information
veegalinova and jdeschamps authored Jan 20, 2025
1 parent aad0c19 commit cf3a094
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions src/careamics/models/lvae/noise_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ class GaussianMixtureNoiseModel(nn.Module):
# TODO training a NM relies on getting a clean data(N2V e.g,)
def __init__(self, config: GaussianMixtureNMConfig) -> None:
super().__init__()
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.device = torch.device("cpu")

if config.path is not None:
params = np.load(config.path)
Expand Down Expand Up @@ -319,10 +319,8 @@ def _set_model_mode(self, mode: str) -> None:
"""Move parameters to the device and set weights' requires_grad depending on the mode"""
if mode == "train":
self.weight.requires_grad = True
self.to_device(self.device)
else:
self.weight.requires_grad = False
self.to_device(torch.device("cpu"))

def polynomial_regressor(
self, weight_params: torch.Tensor, signals: torch.Tensor
Expand Down Expand Up @@ -548,6 +546,8 @@ def fit(
Upper percentile for clipping. Default is 100.
"""
self._set_model_mode(mode="train")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.to_device(device)
optimizer = torch.optim.Adam([self.weight], lr=learning_rate)

sig_obs_pairs = self.get_signal_observation_pairs(
Expand Down Expand Up @@ -589,6 +589,7 @@ def fit(
counter += 1

self._set_model_mode(mode="prediction")
self.to_device(torch.device("cpu"))
print("===================\n")
return train_losses

Expand Down

0 comments on commit cf3a094

Please sign in to comment.