diff --git a/dmlcloud/metrics.py b/dmlcloud/metrics.py index 3a2644a..6d70a5e 100644 --- a/dmlcloud/metrics.py +++ b/dmlcloud/metrics.py @@ -231,7 +231,7 @@ def register_metric(self, name, reduction=None, dim=None, globally=True): def track(self, name, value): if isinstance(value, torch.Tensor): - value = value.detach().cpu() + value = value.detach().to('cpu', non_blocking=True) if name not in self: raise ValueError(f'Metric {name} does not exist')