From 34c0d590b66bf97ebd25b23fd5873446e53c1927 Mon Sep 17 00:00:00 2001 From: Yael Balbastre Date: Fri, 17 Jun 2022 15:42:06 -0400 Subject: [PATCH] FIX(nd): try to remove warning in recent version of pytorch --- interpol/nd.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/interpol/nd.py b/interpol/nd.py index 59eebc8..232dcec 100644 --- a/interpol/nd.py +++ b/interpol/nd.py @@ -117,7 +117,8 @@ def pull(inp, grid, bound: List[Bound], spline: List[Spline], out1 = inp.gather(-1, idx) # apply sign - sign1 = make_sign([sgn[n] for sgn, n in zip(signs, nodes)]) + sign1 = make_sign([torch.jit.annotate(Optional[Tensor], sgn[n]) + for sgn, n in zip(signs, nodes)]) if sign1 is not None: out1 *= sign1 @@ -183,7 +184,8 @@ def push(inp, grid, shape: Optional[List[int]], bound: List[Bound], out1 = inp.clone() # apply sign - sign1 = make_sign([sgn[n] for sgn, n in zip(signs, nodes)]) + sign1 = make_sign([torch.jit.annotate(Optional[Tensor], sgn[n]) + for sgn, n in zip(signs, nodes)]) if sign1 is not None: out1 *= sign1 @@ -247,7 +249,8 @@ def grad(inp, grid, bound: List[Bound], spline: List[Spline], out0 = inp.gather(-1, idx) # apply sign - sign1 = make_sign([sgn[n] for sgn, n in zip(signs, nodes)]) + sign1 = make_sign([torch.jit.annotate(Optional[Tensor], sgn[n]) + for sgn, n in zip(signs, nodes)]) if sign1 is not None: out0 *= sign1 @@ -319,7 +322,8 @@ def pushgrad(inp, grid, shape: Optional[List[int]], bound: List[Bound], out0 = inp.clone() # apply sign - sign1 = make_sign([sgn[n] for sgn, n in zip(signs, nodes)]) + sign1 = make_sign([torch.jit.annotate(Optional[Tensor], sgn[n]) + for sgn, n in zip(signs, nodes)]) if sign1 is not None: out0 *= sign1.unsqueeze(-1) @@ -390,7 +394,8 @@ def hess(inp, grid, bound: List[Bound], spline: List[Spline], out1 = inp.gather(-1, idx) # apply sign - sign1 = make_sign([sgn[n] for sgn, n in zip(signs, nodes)]) + sign1 = make_sign([torch.jit.annotate(Optional[Tensor], sgn[n]) + for sgn, n in zip(signs, nodes)]) if sign1 is not None: out1 *= sign1 @@ -436,4 +441,4 @@ def hess(inp, grid, bound: List[Bound], spline: List[Spline], out.unbind(-1)[d2].unbind(-1)[d].copy_(out.unbind(-1)[d].unbind(-1)[d2]) out = out.reshape(list(out.shape[:2]) + oshape + list(out.shape[-2:])) - return out \ No newline at end of file + return out