diff --git a/equilib/equi2pers/torch.py b/equilib/equi2pers/torch.py index 1de02662..5ebf0923 100644 --- a/equilib/equi2pers/torch.py +++ b/equilib/equi2pers/torch.py @@ -117,6 +117,7 @@ def run( z_down: bool, mode: str, backend: str = "native", + clip_output: bool = True, ) -> torch.Tensor: """Run Equi2Pers @@ -242,7 +243,7 @@ def run( out = ( out.type(equi_dtype) - if equi_dtype == torch.uint8 + if equi_dtype == torch.uint8 or not clip_output else torch.clip(out, 0.0, 1.0) )