Skip to content

Commit

Permalink
FIX: remove warnings with recent versions of pytorch
Browse files Browse the repository at this point in the history
  • Loading branch information
balbasty committed Jun 22, 2022
1 parent 5ac4bd5 commit 6b07fc8
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions interpol/nd.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,8 @@ def pull(inp, grid, bound: List[Bound], spline: List[Spline],
out1 = inp.gather(-1, idx)

# apply sign
sign1 = make_sign([sgn[n] for sgn, n in zip(signs, nodes)])
sign0: List[Optional[Tensor]] = [sgn[n] for sgn, n in zip(signs, nodes)]
sign1: Optional[Tensor] = make_sign(sign0)
if sign1 is not None:
out1 *= sign1

Expand Down Expand Up @@ -183,7 +184,8 @@ def push(inp, grid, shape: Optional[List[int]], bound: List[Bound],
out1 = inp.clone()

# apply sign
sign1 = make_sign([sgn[n] for sgn, n in zip(signs, nodes)])
sign0: List[Optional[Tensor]] = [sgn[n] for sgn, n in zip(signs, nodes)]
sign1: Optional[Tensor] = make_sign(sign0)
if sign1 is not None:
out1 *= sign1

Expand Down Expand Up @@ -247,7 +249,8 @@ def grad(inp, grid, bound: List[Bound], spline: List[Spline],
out0 = inp.gather(-1, idx)

# apply sign
sign1 = make_sign([sgn[n] for sgn, n in zip(signs, nodes)])
sign0: List[Optional[Tensor]] = [sgn[n] for sgn, n in zip(signs, nodes)]
sign1: Optional[Tensor] = make_sign(sign0)
if sign1 is not None:
out0 *= sign1

Expand Down Expand Up @@ -319,7 +322,8 @@ def pushgrad(inp, grid, shape: Optional[List[int]], bound: List[Bound],
out0 = inp.clone()

# apply sign
sign1 = make_sign([sgn[n] for sgn, n in zip(signs, nodes)])
sign0: List[Optional[Tensor]] = [sgn[n] for sgn, n in zip(signs, nodes)]
sign1: Optional[Tensor] = make_sign(sign0)
if sign1 is not None:
out0 *= sign1.unsqueeze(-1)

Expand Down Expand Up @@ -390,7 +394,8 @@ def hess(inp, grid, bound: List[Bound], spline: List[Spline],
out1 = inp.gather(-1, idx)

# apply sign
sign1 = make_sign([sgn[n] for sgn, n in zip(signs, nodes)])
sign0: List[Optional[Tensor]] = [sgn[n] for sgn, n in zip(signs, nodes)]
sign1: Optional[Tensor] = make_sign(sign0)
if sign1 is not None:
out1 *= sign1

Expand Down

0 comments on commit 6b07fc8

Please sign in to comment.