diff --git a/equilib/equi2pers/numpy.py b/equilib/equi2pers/numpy.py index ae43b846..dbec2f93 100644 --- a/equilib/equi2pers/numpy.py +++ b/equilib/equi2pers/numpy.py @@ -131,6 +131,7 @@ def run( z_down: bool, mode: str, override_func: Optional[Callable[[], Any]] = None, + clip_output: bool = True, ) -> np.ndarray: """Run Equi2Pers @@ -224,7 +225,7 @@ def run( out = ( out.astype(equi_dtype) - if equi_dtype == np.dtype(np.uint8) + if equi_dtype == np.dtype(np.uint8) or not clip_output else np.clip(out, 0.0, 1.0) ) diff --git a/equilib/equi2pers/torch.py b/equilib/equi2pers/torch.py index 1de02662..5ebf0923 100644 --- a/equilib/equi2pers/torch.py +++ b/equilib/equi2pers/torch.py @@ -117,6 +117,7 @@ def run( z_down: bool, mode: str, backend: str = "native", + clip_output: bool = True, ) -> torch.Tensor: """Run Equi2Pers @@ -242,7 +243,7 @@ def run( out = ( out.type(equi_dtype) - if equi_dtype == torch.uint8 + if equi_dtype == torch.uint8 or not clip_output else torch.clip(out, 0.0, 1.0) )