Skip to content

Commit

Permalink
FIX(nd): try to remove warning in recent version of pytorch
Browse files Browse the repository at this point in the history
  • Loading branch information
balbasty committed Jun 17, 2022
1 parent 5ac4bd5 commit 34c0d59
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions interpol/nd.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,8 @@ def pull(inp, grid, bound: List[Bound], spline: List[Spline],
out1 = inp.gather(-1, idx)

# apply sign
sign1 = make_sign([sgn[n] for sgn, n in zip(signs, nodes)])
sign1 = make_sign([torch.jit.annotate(Optional[Tensor], sgn[n])
for sgn, n in zip(signs, nodes)])
if sign1 is not None:
out1 *= sign1

Expand Down Expand Up @@ -183,7 +184,8 @@ def push(inp, grid, shape: Optional[List[int]], bound: List[Bound],
out1 = inp.clone()

# apply sign
sign1 = make_sign([sgn[n] for sgn, n in zip(signs, nodes)])
sign1 = make_sign([torch.jit.annotate(Optional[Tensor], sgn[n])
for sgn, n in zip(signs, nodes)])
if sign1 is not None:
out1 *= sign1

Expand Down Expand Up @@ -247,7 +249,8 @@ def grad(inp, grid, bound: List[Bound], spline: List[Spline],
out0 = inp.gather(-1, idx)

# apply sign
sign1 = make_sign([sgn[n] for sgn, n in zip(signs, nodes)])
sign1 = make_sign([torch.jit.annotate(Optional[Tensor], sgn[n])
for sgn, n in zip(signs, nodes)])
if sign1 is not None:
out0 *= sign1

Expand Down Expand Up @@ -319,7 +322,8 @@ def pushgrad(inp, grid, shape: Optional[List[int]], bound: List[Bound],
out0 = inp.clone()

# apply sign
sign1 = make_sign([sgn[n] for sgn, n in zip(signs, nodes)])
sign1 = make_sign([torch.jit.annotate(Optional[Tensor], sgn[n])
for sgn, n in zip(signs, nodes)])
if sign1 is not None:
out0 *= sign1.unsqueeze(-1)

Expand Down Expand Up @@ -390,7 +394,8 @@ def hess(inp, grid, bound: List[Bound], spline: List[Spline],
out1 = inp.gather(-1, idx)

# apply sign
sign1 = make_sign([sgn[n] for sgn, n in zip(signs, nodes)])
sign1 = make_sign([torch.jit.annotate(Optional[Tensor], sgn[n])
for sgn, n in zip(signs, nodes)])
if sign1 is not None:
out1 *= sign1

Expand Down Expand Up @@ -436,4 +441,4 @@ def hess(inp, grid, bound: List[Bound], spline: List[Spline],
out.unbind(-1)[d2].unbind(-1)[d].copy_(out.unbind(-1)[d].unbind(-1)[d2])

out = out.reshape(list(out.shape[:2]) + oshape + list(out.shape[-2:]))
return out
return out

0 comments on commit 34c0d59

Please sign in to comment.