diff --git a/CHANGELOG.md b/CHANGELOG.md
index a253f90f6ef..5d5c9554ef0 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -27,7 +27,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Fixed
 
--
+- Removed unnecessary `Device2Host` transfer caused by communication between device and host ([#2840](https://github.com/PyTorchLightning/metrics/pull/2840))
 
 ---
 
diff --git a/src/torchmetrics/functional/classification/precision_recall_curve.py b/src/torchmetrics/functional/classification/precision_recall_curve.py
index 3c5a840efa1..991606e4b25 100644
--- a/src/torchmetrics/functional/classification/precision_recall_curve.py
+++ b/src/torchmetrics/functional/classification/precision_recall_curve.py
@@ -182,8 +182,14 @@ def _binary_precision_recall_curve_format(
         preds = preds[idx]
         target = target[idx]
 
-    if not torch.all((preds >= 0) * (preds <= 1)):
-        preds = preds.sigmoid()
+    # "sigmoid_cpu" not implemented for 'Half'
+    if preds.dtype != torch.float16 or preds.device != torch.device("cpu"):
+        out_of_bounds = (preds < 0) | (preds > 1)
+        out_of_bounds = out_of_bounds.any()
+        preds = torch.where(out_of_bounds, preds.sigmoid(), preds)
+    else:
+        if not torch.all((preds >= 0) * (preds <= 1)):
+            preds = preds.sigmoid()
 
     thresholds = _adjust_threshold_arg(thresholds, preds.device)
     return preds, target, thresholds
@@ -761,8 +767,15 @@ def _multilabel_precision_recall_curve_format(
     """
    preds = preds.transpose(0, 1).reshape(num_labels, -1).T
    target = target.transpose(0, 1).reshape(num_labels, -1).T
-    if not torch.all((preds >= 0) * (preds <= 1)):
-        preds = preds.sigmoid()
+
+    # "sigmoid_cpu" not implemented for 'Half'
+    if preds.dtype != torch.float16 or preds.device != torch.device("cpu"):
+        out_of_bounds = (preds < 0) | (preds > 1)
+        out_of_bounds = out_of_bounds.any()
+        preds = torch.where(out_of_bounds, preds.sigmoid(), preds)
+    else:
+        if not torch.all((preds >= 0) * (preds <= 1)):
+            preds = preds.sigmoid()
 
     thresholds = _adjust_threshold_arg(thresholds, preds.device)
     if ignore_index is not None and thresholds is not None:
diff --git a/tests/unittests/classification/test_precision_recall_curve.py b/tests/unittests/classification/test_precision_recall_curve.py
index 7c034c528e6..3772073536c 100644
--- a/tests/unittests/classification/test_precision_recall_curve.py
+++ b/tests/unittests/classification/test_precision_recall_curve.py
@@ -106,15 +106,27 @@ def test_binary_precision_recall_curve_dtype_cpu(self, inputs, dtype):
         """Test dtype support of the metric on CPU."""
         preds, target = inputs
         if (preds < 0).any() and dtype == torch.half:
-            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
-        self.run_precision_test_cpu(
-            preds=preds,
-            target=target,
-            metric_module=BinaryPrecisionRecallCurve,
-            metric_functional=binary_precision_recall_curve,
-            metric_args={"thresholds": None},
-            dtype=dtype,
-        )
+            try:
+                self.run_precision_test_cpu(
+                    preds=preds,
+                    target=target,
+                    metric_module=BinaryPrecisionRecallCurve,
+                    metric_functional=binary_precision_recall_curve,
+                    metric_args={"thresholds": None},
+                    dtype=dtype,
+                )
+            except Exception as e:
+                print(f"An unexpected error occurred: {e}")
+                pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
+        else:
+            self.run_precision_test_cpu(
+                preds=preds,
+                target=target,
+                metric_module=BinaryPrecisionRecallCurve,
+                metric_functional=binary_precision_recall_curve,
+                metric_args={"thresholds": None},
+                dtype=dtype,
+            )
 
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda")
     @pytest.mark.parametrize("dtype", [torch.half, torch.double])
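
Note on the pattern: the `Device2Host` transfer this patch removes is a subtle one. Evaluating `torch.all(...)` inside a Python `if` invokes `Tensor.__bool__`, and on a CUDA tensor that forces a blocking copy of the boolean result to the host before the branch can be chosen. Rewriting the check with `torch.where` keeps the 0-dim boolean on the device. Below is a minimal, self-contained sketch of the same idea; `normalize_preds` is a hypothetical helper name used for illustration, not the torchmetrics function:

    import torch

    def normalize_preds(preds: torch.Tensor) -> torch.Tensor:
        """Apply sigmoid only when preds fall outside [0, 1]."""
        if preds.dtype == torch.float16 and preds.device.type == "cpu":
            # Fallback: "sigmoid_cpu" is not implemented for 'Half', so only
            # call sigmoid when it is actually needed. torch.all() is cheap
            # here because there is no device/host boundary to cross.
            if not torch.all((preds >= 0) & (preds <= 1)):
                preds = preds.sigmoid()
            return preds
        # Sync-free path: .any() yields a 0-dim boolean that stays on the
        # device, so no Device2Host transfer is triggered by a Python `if`.
        out_of_bounds = ((preds < 0) | (preds > 1)).any()
        return torch.where(out_of_bounds, preds.sigmoid(), preds)

The trade-off is that on the sync-free path `preds.sigmoid()` is computed unconditionally, even when `preds` are already valid probabilities; the patch accepts that extra elementwise kernel in exchange for removing the synchronization point.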