Skip to content

Commit

Permalink
FIX(make_sign): workaround some weitd JIT bug + PERF(pull): use less …
Browse files Browse the repository at this point in the history
…inplace operations
  • Loading branch information
balbasty committed Jul 5, 2022
1 parent 3336f8a commit 9a9965e
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 105 deletions.
16 changes: 8 additions & 8 deletions interpol/iso0.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,9 @@ def pull2d(inp, g, bound: List[Bound], extrapolate: int = 1):
out = inp.gather(-1, idx)
sign = make_sign([signx, signy])
if sign is not None:
out *= sign
out = out * sign
if mask is not None:
out *= mask
out = mask * mask
out = out.reshape(out.shape[:2] + oshape)
return out

Expand Down Expand Up @@ -198,9 +198,9 @@ def push2d(inp, g, shape: Optional[List[int]], bound: List[Bound],
if sign is not None or mask is not None:
inp = inp.clone()
if sign is not None:
inp *= sign
inp = inp * sign
if mask is not None:
inp *= mask
inp = inp * mask
out.scatter_add_(-1, idx, inp)

out = out.reshape(out.shape[:2] + shape)
Expand Down Expand Up @@ -244,9 +244,9 @@ def pull1d(inp, g, bound: List[Bound], extrapolate: int = 1):
out = inp.gather(-1, idx)
sign = signx
if sign is not None:
out *= sign
out = out * sign
if mask is not None:
out *= mask
out = out * mask
out = out.reshape(out.shape[:2] + oshape)
return out

Expand Down Expand Up @@ -291,9 +291,9 @@ def push1d(inp, g, shape: Optional[List[int]], bound: List[Bound],
if sign is not None or mask is not None:
inp = inp.clone()
if sign is not None:
inp *= sign
inp = inp * sign
if mask is not None:
inp *= mask
inp = inp * mask
out.scatter_add_(-1, idx, inp)

out = out.reshape(out.shape[:2] + shape)
Expand Down
143 changes: 72 additions & 71 deletions interpol/iso1.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
from .bounds import Bound
from .jit_utils import (sub2ind_list, make_sign,
inbounds_mask_3d, inbounds_mask_2d, inbounds_mask_1d)
from typing import List, Optional
from typing import List, Tuple, Optional
Tensor = torch.Tensor


@torch.jit.script
def get_weights_and_indices(g, n: int, bound: Bound):
def get_weights_and_indices(g, n: int, bound: Bound) \
-> Tuple[Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
g0 = g.floor().long()
g1 = g0 + 1
sign1 = bound.transform(g1, n)
Expand Down Expand Up @@ -60,71 +61,71 @@ def pull3d(inp, g, bound: List[Bound], extrapolate: int = 1):
out = inp.gather(-1, idx)
sign = make_sign([signx0, signy0, signz0])
if sign is not None:
out *= sign
out *= (1 - gx) * (1 - gy) * (1 - gz)
out = out * sign
out = out * ((1 - gx) * (1 - gy) * (1 - gz))
# - corner 001
idx = sub2ind_list([gx0, gy0, gz1], shape)
idx = idx.expand([batch, channel, idx.shape[-1]])
out1 = inp.gather(-1, idx)
sign = make_sign([signx0, signy0, signz1])
if sign is not None:
out1 *= sign
out1 *= (1 - gx) * (1 - gy) * gz
out += out1
out1 = out1 * sign
out1 = out1 * ((1 - gx) * (1 - gy) * gz)
out = out + out1
# - corner 010
idx = sub2ind_list([gx0, gy1, gz0], shape)
idx = idx.expand([batch, channel, idx.shape[-1]])
out1 = inp.gather(-1, idx)
sign = make_sign([signx0, signy1, signz0])
if sign is not None:
out1 *= sign
out1 *= (1 - gx) * gy * (1 - gz)
out += out1
out1 = out1 * sign
out1 = out1 * ((1 - gx) * gy * (1 - gz))
out = out + out1
# - corner 011
idx = sub2ind_list([gx0, gy1, gz1], shape)
idx = idx.expand([batch, channel, idx.shape[-1]])
out1 = inp.gather(-1, idx)
sign = make_sign([signx0, signy1, signz1])
if sign is not None:
out1 *= sign
out1 *= (1 - gx) * gy * gz
out += out1
out1 = out1 * sign
out1 = out1 * ((1 - gx) * gy * gz)
out = out + out1
# - corner 100
idx = sub2ind_list([gx1, gy0, gz0], shape)
idx = idx.expand([batch, channel, idx.shape[-1]])
out1 = inp.gather(-1, idx)
sign = make_sign([signx1, signy0, signz0])
if sign is not None:
out1 *= sign
out1 *= gx * (1 - gy) * (1 - gz)
out += out1
out1 = out1 * sign
out1 = out1 * (gx * (1 - gy) * (1 - gz))
out = out + out1
# - corner 101
idx = sub2ind_list([gx1, gy0, gz1], shape)
idx = idx.expand([batch, channel, idx.shape[-1]])
out1 = inp.gather(-1, idx)
sign = make_sign([signx1, signy0, signz1])
if sign is not None:
out1 *= sign
out1 *= gx * (1 - gy) * gz
out += out1
out1 = out1 * sign
out1 = out1 * (gx * (1 - gy) * gz)
out = out + out1
# - corner 110
idx = sub2ind_list([gx1, gy1, gz0], shape)
idx = idx.expand([batch, channel, idx.shape[-1]])
out1 = inp.gather(-1, idx)
sign = make_sign([signx1, signy1, signz0])
if sign is not None:
out1 *= sign
out1 *= gx * gy * (1 - gz)
out += out1
out1 = out1 * sign
out1 = out1 * (gx * gy * (1 - gz))
out = out + out1
# - corner 111
idx = sub2ind_list([gx1, gy1, gz1], shape)
idx = idx.expand([batch, channel, idx.shape[-1]])
out1 = inp.gather(-1, idx)
sign = make_sign([signx1, signy1, signz1])
if sign is not None:
out1 *= sign
out1 *= gx * gy * gz
out += out1
out1 = out1 * sign
out1 = out1 * (gx * gy * gz)
out = out + out1

if mask is not None:
out *= mask
Expand Down Expand Up @@ -177,87 +178,87 @@ def push3d(inp, g, shape: Optional[List[int]], bound: List[Bound],
out1 = inp.clone()
sign = make_sign([signx0, signy0, signz0])
if sign is not None:
out1 *= sign
out1 = out1 * sign
if mask is not None:
out1 *= mask
out1 *= (1 - gx) * (1 - gy) * (1 - gz)
out1 = out1 * mask
out1 = out1 * ((1 - gx) * (1 - gy) * (1 - gz))
out.scatter_add_(-1, idx, out1)
# - corner 001
idx = sub2ind_list([gx0, gy0, gz1], shape)
idx = idx.expand([batch, channel, idx.shape[-1]])
out1 = inp.clone()
sign = make_sign([signx0, signy0, signz1])
if sign is not None:
out1 *= sign
out1 = out1 * sign
if mask is not None:
out1 *= mask
out1 *= (1 - gx) * (1 - gy) * gz
out1 = out1 * mask
out1 = out1 * ((1 - gx) * (1 - gy) * gz)
out.scatter_add_(-1, idx, out1)
# - corner 010
idx = sub2ind_list([gx0, gy1, gz0], shape)
idx = idx.expand([batch, channel, idx.shape[-1]])
out1 = inp.clone()
sign = make_sign([signx0, signy1, signz0])
if sign is not None:
out1 *= sign
out1 = out1 * sign
if mask is not None:
out1 *= mask
out1 *= (1 - gx) * gy * (1 - gz)
out1 = out1 * mask
out1 = out1 * ((1 - gx) * gy * (1 - gz))
out.scatter_add_(-1, idx, out1)
# - corner 011
idx = sub2ind_list([gx0, gy1, gz1], shape)
idx = idx.expand([batch, channel, idx.shape[-1]])
out1 = inp.clone()
sign = make_sign([signx0, signy1, signz1])
if sign is not None:
out1 *= sign
out1 = out1 * sign
if mask is not None:
out1 *= mask
out1 *= (1 - gx) * gy * gz
out1 = out1 * mask
out1 = out1 * ((1 - gx) * gy * gz)
out.scatter_add_(-1, idx, out1)
# - corner 100
idx = sub2ind_list([gx1, gy0, gz0], shape)
idx = idx.expand([batch, channel, idx.shape[-1]])
out1 = inp.clone()
sign = make_sign([signx1, signy0, signz0])
if sign is not None:
out1 *= sign
out1 = out1 * sign
if mask is not None:
out1 *= mask
out1 *= gx * (1 - gy) * (1 - gz)
out1 = out1 * mask
out1 = out1 * (gx * (1 - gy) * (1 - gz))
out.scatter_add_(-1, idx, out1)
# - corner 101
idx = sub2ind_list([gx1, gy0, gz1], shape)
idx = idx.expand([batch, channel, idx.shape[-1]])
out1 = inp.clone()
sign = make_sign([signx1, signy0, signz1])
if sign is not None:
out1 *= sign
out1 = out1 * sign
if mask is not None:
out1 *= mask
out1 *= gx * (1 - gy) * gz
out1 = out1 * mask
out1 = out1 * (gx * (1 - gy) * gz)
out.scatter_add_(-1, idx, out1)
# - corner 110
idx = sub2ind_list([gx1, gy1, gz0], shape)
idx = idx.expand([batch, channel, idx.shape[-1]])
out1 = inp.clone()
sign = make_sign([signx1, signy1, signz0])
if sign is not None:
out1 *= sign
out1 = out1 * sign
if mask is not None:
out1 *= mask
out1 *= gx * gy * (1 - gz)
out1 = out1 * mask
out1 = out1 * (gx * gy * (1 - gz))
out.scatter_add_(-1, idx, out1)
# - corner 111
idx = sub2ind_list([gx1, gy1, gz1], shape)
idx = idx.expand([batch, channel, idx.shape[-1]])
out1 = inp.clone()
sign = make_sign([signx1, signy1, signz1])
if sign is not None:
out1 *= sign
out1 = out1 * sign
if mask is not None:
out1 *= mask
out1 *= gx * gy * gz
out1 = out1 * mask
out1 = out1 * (gx * gy * gz)
out.scatter_add_(-1, idx, out1)

out = out.reshape(list(out.shape[:2]) + shape)
Expand Down Expand Up @@ -714,35 +715,35 @@ def pull2d(inp, g, bound: List[Bound], extrapolate: int = 1):
out = inp.gather(-1, idx)
sign = make_sign([signx0, signy0])
if sign is not None:
out *= sign
out *= (1 - gx) * (1 - gy)
out = out * sign
out = out * ((1 - gx) * (1 - gy))
# - corner 01
idx = sub2ind_list([gx0, gy1], shape)
idx = idx.expand([batch, channel, idx.shape[-1]])
out1 = inp.gather(-1, idx)
sign = make_sign([signx0, signy1])
if sign is not None:
out1 *= sign
out1 *= (1 - gx) * gy
out += out1
out1 = out1 * sign
out1 = out1 * ((1 - gx) * gy)
out = out + out1
# - corner 10
idx = sub2ind_list([gx1, gy0], shape)
idx = idx.expand([batch, channel, idx.shape[-1]])
out1 = inp.gather(-1, idx)
sign = make_sign([signx1, signy0])
if sign is not None:
out1 *= sign
out1 *= gx * (1 - gy)
out += out1
out1 = out1 * sign
out1 = out1 * (gx * (1 - gy))
out = out + out1
# - corner 11
idx = sub2ind_list([gx1, gy1], shape)
idx = idx.expand([batch, channel, idx.shape[-1]])
out1 = inp.gather(-1, idx)
sign = make_sign([signx1, signy1])
if sign is not None:
out1 *= sign
out1 *= gx * gy
out += out1
out1 = out1 * sign
out1 = out1 * (gx * gy)
out = out + out1

if mask is not None:
out *= mask
Expand Down Expand Up @@ -1124,17 +1125,17 @@ def pull1d(inp, g, bound: List[Bound], extrapolate: int = 1):
out = inp.gather(-1, idx)
sign = signx0
if sign is not None:
out *= sign
out *= (1 - gx)
out = out * sign
out = out * (1 - gx)
# - corner 1
idx = gx1
idx = idx.expand([batch, channel, idx.shape[-1]])
out1 = inp.gather(-1, idx)
sign = signx1
if sign is not None:
out1 *= sign
out1 *= gx
out += out1
out1 = out1 * sign
out1 = out1 * gx
out = out + out1

if mask is not None:
out *= mask
Expand Down Expand Up @@ -1185,21 +1186,21 @@ def push1d(inp, g, shape: Optional[List[int]], bound: List[Bound],
out1 = inp.clone()
sign = signx0
if sign is not None:
out1 *= sign
out1 = out1 * sign
if mask is not None:
out1 *= mask
out1 *= (1 - gx)
out1 = out1 * mask
out1 = out1 * (1 - gx)
out.scatter_add_(-1, idx, out1)
# - corner 1
idx = gx1
idx = idx.expand([batch, channel, idx.shape[-1]])
out1 = inp.clone()
sign = signx1
if sign is not None:
out1 *= sign
out1 = out1 * sign
if mask is not None:
out1 *= mask
out1 *= gx
out1 = out1 * mask
out1 = out1 * gx
out.scatter_add_(-1, idx, out1)

out = out.reshape(list(out.shape[:2]) + shape)
Expand Down
Loading

0 comments on commit 9a9965e

Please sign in to comment.