diff --git a/src/graphnet/models/task/task.py b/src/graphnet/models/task/task.py index a1c3c52ed..0b9101107 100644 --- a/src/graphnet/models/task/task.py +++ b/src/graphnet/models/task/task.py @@ -461,7 +461,7 @@ def forward( labels = labels.to(self.dtype) # Set the initial parameters of flow close to truth # This speeds up training and helps with NaN - if self._initialized is False: + if (self._initialized is False) & (self.training): self._flow.init_params(data=deepcopy(labels).cpu()) self._flow.to(self.device) self._initialized = True # This is only done once