From 9a9965e02da42958df123b12b7082bf75ca8f7b4 Mon Sep 17 00:00:00 2001 From: Yael Balbastre Date: Tue, 5 Jul 2022 12:33:52 -0400 Subject: [PATCH] FIX(make_sign): workaround some weird JIT bug + PERF(pull): use less inplace operations --- interpol/iso0.py | 16 ++--- interpol/iso1.py | 143 +++++++++++++++++++++--------------------- interpol/jit_utils.py | 31 +++++++-- interpol/nd.py | 40 ++++++------ 4 files changed, 125 insertions(+), 105 deletions(-) diff --git a/interpol/iso0.py b/interpol/iso0.py index c86fdb3..7f43a81 100644 --- a/interpol/iso0.py +++ b/interpol/iso0.py @@ -150,9 +150,9 @@ def pull2d(inp, g, bound: List[Bound], extrapolate: int = 1): out = inp.gather(-1, idx) sign = make_sign([signx, signy]) if sign is not None: - out *= sign + out = out * sign if mask is not None: - out *= mask + out = out * mask out = out.reshape(out.shape[:2] + oshape) return out @@ -198,9 +198,9 @@ def push2d(inp, g, shape: Optional[List[int]], bound: List[Bound], if sign is not None or mask is not None: inp = inp.clone() if sign is not None: - inp *= sign + inp = inp * sign if mask is not None: - inp *= mask + inp = inp * mask out.scatter_add_(-1, idx, inp) out = out.reshape(out.shape[:2] + shape) @@ -244,9 +244,9 @@ def pull1d(inp, g, bound: List[Bound], extrapolate: int = 1): out = inp.gather(-1, idx) sign = signx if sign is not None: - out *= sign + out = out * sign if mask is not None: - out *= mask + out = out * mask out = out.reshape(out.shape[:2] + oshape) return out @@ -291,9 +291,9 @@ def push1d(inp, g, shape: Optional[List[int]], bound: List[Bound], if sign is not None or mask is not None: inp = inp.clone() if sign is not None: - inp *= sign + inp = inp * sign if mask is not None: - inp *= mask + inp = inp * mask out.scatter_add_(-1, idx, inp) out = out.reshape(out.shape[:2] + shape) diff --git a/interpol/iso1.py b/interpol/iso1.py index 1d64c40..6c86a88 100644 --- a/interpol/iso1.py +++ b/interpol/iso1.py @@ -3,12 +3,13 @@ from .bounds import Bound from 
.jit_utils import (sub2ind_list, make_sign, inbounds_mask_3d, inbounds_mask_2d, inbounds_mask_1d) -from typing import List, Optional +from typing import List, Tuple, Optional Tensor = torch.Tensor @torch.jit.script -def get_weights_and_indices(g, n: int, bound: Bound): +def get_weights_and_indices(g, n: int, bound: Bound) \ + -> Tuple[Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor]]: g0 = g.floor().long() g1 = g0 + 1 sign1 = bound.transform(g1, n) @@ -60,71 +61,71 @@ def pull3d(inp, g, bound: List[Bound], extrapolate: int = 1): out = inp.gather(-1, idx) sign = make_sign([signx0, signy0, signz0]) if sign is not None: - out *= sign - out *= (1 - gx) * (1 - gy) * (1 - gz) + out = out * sign + out = out * ((1 - gx) * (1 - gy) * (1 - gz)) # - corner 001 idx = sub2ind_list([gx0, gy0, gz1], shape) idx = idx.expand([batch, channel, idx.shape[-1]]) out1 = inp.gather(-1, idx) sign = make_sign([signx0, signy0, signz1]) if sign is not None: - out1 *= sign - out1 *= (1 - gx) * (1 - gy) * gz - out += out1 + out1 = out1 * sign + out1 = out1 * ((1 - gx) * (1 - gy) * gz) + out = out + out1 # - corner 010 idx = sub2ind_list([gx0, gy1, gz0], shape) idx = idx.expand([batch, channel, idx.shape[-1]]) out1 = inp.gather(-1, idx) sign = make_sign([signx0, signy1, signz0]) if sign is not None: - out1 *= sign - out1 *= (1 - gx) * gy * (1 - gz) - out += out1 + out1 = out1 * sign + out1 = out1 * ((1 - gx) * gy * (1 - gz)) + out = out + out1 # - corner 011 idx = sub2ind_list([gx0, gy1, gz1], shape) idx = idx.expand([batch, channel, idx.shape[-1]]) out1 = inp.gather(-1, idx) sign = make_sign([signx0, signy1, signz1]) if sign is not None: - out1 *= sign - out1 *= (1 - gx) * gy * gz - out += out1 + out1 = out1 * sign + out1 = out1 * ((1 - gx) * gy * gz) + out = out + out1 # - corner 100 idx = sub2ind_list([gx1, gy0, gz0], shape) idx = idx.expand([batch, channel, idx.shape[-1]]) out1 = inp.gather(-1, idx) sign = make_sign([signx1, signy0, signz0]) if sign is not None: - out1 *= sign - 
out1 *= gx * (1 - gy) * (1 - gz) - out += out1 + out1 = out1 * sign + out1 = out1 * (gx * (1 - gy) * (1 - gz)) + out = out + out1 # - corner 101 idx = sub2ind_list([gx1, gy0, gz1], shape) idx = idx.expand([batch, channel, idx.shape[-1]]) out1 = inp.gather(-1, idx) sign = make_sign([signx1, signy0, signz1]) if sign is not None: - out1 *= sign - out1 *= gx * (1 - gy) * gz - out += out1 + out1 = out1 * sign + out1 = out1 * (gx * (1 - gy) * gz) + out = out + out1 # - corner 110 idx = sub2ind_list([gx1, gy1, gz0], shape) idx = idx.expand([batch, channel, idx.shape[-1]]) out1 = inp.gather(-1, idx) sign = make_sign([signx1, signy1, signz0]) if sign is not None: - out1 *= sign - out1 *= gx * gy * (1 - gz) - out += out1 + out1 = out1 * sign + out1 = out1 * (gx * gy * (1 - gz)) + out = out + out1 # - corner 111 idx = sub2ind_list([gx1, gy1, gz1], shape) idx = idx.expand([batch, channel, idx.shape[-1]]) out1 = inp.gather(-1, idx) sign = make_sign([signx1, signy1, signz1]) if sign is not None: - out1 *= sign - out1 *= gx * gy * gz - out += out1 + out1 = out1 * sign + out1 = out1 * (gx * gy * gz) + out = out + out1 if mask is not None: out *= mask @@ -177,10 +178,10 @@ def push3d(inp, g, shape: Optional[List[int]], bound: List[Bound], out1 = inp.clone() sign = make_sign([signx0, signy0, signz0]) if sign is not None: - out1 *= sign + out1 = out1 * sign if mask is not None: - out1 *= mask - out1 *= (1 - gx) * (1 - gy) * (1 - gz) + out1 = out1 * mask + out1 = out1 * ((1 - gx) * (1 - gy) * (1 - gz)) out.scatter_add_(-1, idx, out1) # - corner 001 idx = sub2ind_list([gx0, gy0, gz1], shape) @@ -188,10 +189,10 @@ def push3d(inp, g, shape: Optional[List[int]], bound: List[Bound], out1 = inp.clone() sign = make_sign([signx0, signy0, signz1]) if sign is not None: - out1 *= sign + out1 = out1 * sign if mask is not None: - out1 *= mask - out1 *= (1 - gx) * (1 - gy) * gz + out1 = out1 * mask + out1 = out1 * ((1 - gx) * (1 - gy) * gz) out.scatter_add_(-1, idx, out1) # - corner 010 idx = 
sub2ind_list([gx0, gy1, gz0], shape) @@ -199,10 +200,10 @@ def push3d(inp, g, shape: Optional[List[int]], bound: List[Bound], out1 = inp.clone() sign = make_sign([signx0, signy1, signz0]) if sign is not None: - out1 *= sign + out1 = out1 * sign if mask is not None: - out1 *= mask - out1 *= (1 - gx) * gy * (1 - gz) + out1 = out1 * mask + out1 = out1 * ((1 - gx) * gy * (1 - gz)) out.scatter_add_(-1, idx, out1) # - corner 011 idx = sub2ind_list([gx0, gy1, gz1], shape) @@ -210,10 +211,10 @@ def push3d(inp, g, shape: Optional[List[int]], bound: List[Bound], out1 = inp.clone() sign = make_sign([signx0, signy1, signz1]) if sign is not None: - out1 *= sign + out1 = out1 * sign if mask is not None: - out1 *= mask - out1 *= (1 - gx) * gy * gz + out1 = out1 * mask + out1 = out1 * ((1 - gx) * gy * gz) out.scatter_add_(-1, idx, out1) # - corner 100 idx = sub2ind_list([gx1, gy0, gz0], shape) @@ -221,10 +222,10 @@ def push3d(inp, g, shape: Optional[List[int]], bound: List[Bound], out1 = inp.clone() sign = make_sign([signx1, signy0, signz0]) if sign is not None: - out1 *= sign + out1 = out1 * sign if mask is not None: - out1 *= mask - out1 *= gx * (1 - gy) * (1 - gz) + out1 = out1 * mask + out1 = out1 * (gx * (1 - gy) * (1 - gz)) out.scatter_add_(-1, idx, out1) # - corner 101 idx = sub2ind_list([gx1, gy0, gz1], shape) @@ -232,10 +233,10 @@ def push3d(inp, g, shape: Optional[List[int]], bound: List[Bound], out1 = inp.clone() sign = make_sign([signx1, signy0, signz1]) if sign is not None: - out1 *= sign + out1 = out1 * sign if mask is not None: - out1 *= mask - out1 *= gx * (1 - gy) * gz + out1 = out1 * mask + out1 = out1 * (gx * (1 - gy) * gz) out.scatter_add_(-1, idx, out1) # - corner 110 idx = sub2ind_list([gx1, gy1, gz0], shape) @@ -243,10 +244,10 @@ def push3d(inp, g, shape: Optional[List[int]], bound: List[Bound], out1 = inp.clone() sign = make_sign([signx1, signy1, signz0]) if sign is not None: - out1 *= sign + out1 = out1 * sign if mask is not None: - out1 *= mask - out1 *= 
gx * gy * (1 - gz) + out1 = out1 * mask + out1 = out1 * (gx * gy * (1 - gz)) out.scatter_add_(-1, idx, out1) # - corner 111 idx = sub2ind_list([gx1, gy1, gz1], shape) @@ -254,10 +255,10 @@ def push3d(inp, g, shape: Optional[List[int]], bound: List[Bound], out1 = inp.clone() sign = make_sign([signx1, signy1, signz1]) if sign is not None: - out1 *= sign + out1 = out1 * sign if mask is not None: - out1 *= mask - out1 *= gx * gy * gz + out1 = out1 * mask + out1 = out1 * (gx * gy * gz) out.scatter_add_(-1, idx, out1) out = out.reshape(list(out.shape[:2]) + shape) @@ -714,35 +715,35 @@ def pull2d(inp, g, bound: List[Bound], extrapolate: int = 1): out = inp.gather(-1, idx) sign = make_sign([signx0, signy0]) if sign is not None: - out *= sign - out *= (1 - gx) * (1 - gy) + out = out * sign + out = out * ((1 - gx) * (1 - gy)) # - corner 01 idx = sub2ind_list([gx0, gy1], shape) idx = idx.expand([batch, channel, idx.shape[-1]]) out1 = inp.gather(-1, idx) sign = make_sign([signx0, signy1]) if sign is not None: - out1 *= sign - out1 *= (1 - gx) * gy - out += out1 + out1 = out1 * sign + out1 = out1 * ((1 - gx) * gy) + out = out + out1 # - corner 10 idx = sub2ind_list([gx1, gy0], shape) idx = idx.expand([batch, channel, idx.shape[-1]]) out1 = inp.gather(-1, idx) sign = make_sign([signx1, signy0]) if sign is not None: - out1 *= sign - out1 *= gx * (1 - gy) - out += out1 + out1 = out1 * sign + out1 = out1 * (gx * (1 - gy)) + out = out + out1 # - corner 11 idx = sub2ind_list([gx1, gy1], shape) idx = idx.expand([batch, channel, idx.shape[-1]]) out1 = inp.gather(-1, idx) sign = make_sign([signx1, signy1]) if sign is not None: - out1 *= sign - out1 *= gx * gy - out += out1 + out1 = out1 * sign + out1 = out1 * (gx * gy) + out = out + out1 if mask is not None: out *= mask @@ -1124,17 +1125,17 @@ def pull1d(inp, g, bound: List[Bound], extrapolate: int = 1): out = inp.gather(-1, idx) sign = signx0 if sign is not None: - out *= sign - out *= (1 - gx) + out = out * sign + out = out * (1 - 
gx) # - corner 1 idx = gx1 idx = idx.expand([batch, channel, idx.shape[-1]]) out1 = inp.gather(-1, idx) sign = signx1 if sign is not None: - out1 *= sign - out1 *= gx - out += out1 + out1 = out1 * sign + out1 = out1 * gx + out = out + out1 if mask is not None: out *= mask @@ -1185,10 +1186,10 @@ def push1d(inp, g, shape: Optional[List[int]], bound: List[Bound], out1 = inp.clone() sign = signx0 if sign is not None: - out1 *= sign + out1 = out1 * sign if mask is not None: - out1 *= mask - out1 *= (1 - gx) + out1 = out1 * mask + out1 = out1 * (1 - gx) out.scatter_add_(-1, idx, out1) # - corner 1 idx = gx1 @@ -1196,10 +1197,10 @@ def push1d(inp, g, shape: Optional[List[int]], bound: List[Bound], out1 = inp.clone() sign = signx1 if sign is not None: - out1 *= sign + out1 = out1 * sign if mask is not None: - out1 *= mask - out1 *= gx + out1 = out1 * mask + out1 = out1 * gx out.scatter_add_(-1, idx, out1) out = out.reshape(list(out.shape[:2]) + shape) diff --git a/interpol/jit_utils.py b/interpol/jit_utils.py index 18224fb..b58f18c 100644 --- a/interpol/jit_utils.py +++ b/interpol/jit_utils.py @@ -68,6 +68,26 @@ def list_sum_int(x: List[int]) -> int: return x0 +@torch.jit.script +def list_prod_tensor(x: List[Tensor]) -> Tensor: + if len(x) == 0: + return torch.ones([]) + x0 = x[0] + for x1 in x[1:]: + x0 = x0 * x1 + return x0 + + +@torch.jit.script +def list_sum_tensor(x: List[Tensor]) -> Tensor: + if len(x) == 0: + return torch.zeros([]) + x0 = x[0] + for x1 in x[1:]: + x0 = x0 + x1 + return x0 + + @torch.jit.script def list_reverse_int(x: List[int]) -> List[int]: if len(x) == 0: @@ -313,14 +333,13 @@ def inbounds_mask_1d(extrapolate: int, gx, nx: int) -> Optional[Tensor]: @torch.jit.script def make_sign(sign: List[Optional[Tensor]]) -> Optional[Tensor]: - osign: Optional[Tensor] = None + if all([s is None for s in sign]): + return None + filt_sign: List[Tensor] = [] for s in sign: if s is not None: - if osign is None: - osign = s - else: - osign = osign * s - return 
osign + filt_sign.append(s) + return list_prod_tensor(filt_sign) @torch.jit.script diff --git a/interpol/nd.py b/interpol/nd.py index 0cfdef3..55ad1f1 100644 --- a/interpol/nd.py +++ b/interpol/nd.py @@ -120,18 +120,18 @@ def pull(inp, grid, bound: List[Bound], spline: List[Spline], sign0: List[Optional[Tensor]] = [sgn[n] for sgn, n in zip(signs, nodes)] sign1: Optional[Tensor] = make_sign(sign0) if sign1 is not None: - out1 *= sign1 + out1 = out1 * sign1 # apply weights for weight, n in zip(weights, nodes): - out1 *= weight[n] + out1 = out1 * weight[n] # accumulate - out += out1 + out = out + out1 # out-of-bounds mask if mask is not None: - out *= mask + out = out * mask out = out.reshape(list(out.shape[:2]) + oshape) return out @@ -187,15 +187,15 @@ def push(inp, grid, shape: Optional[List[int]], bound: List[Bound], sign0: List[Optional[Tensor]] = [sgn[n] for sgn, n in zip(signs, nodes)] sign1: Optional[Tensor] = make_sign(sign0) if sign1 is not None: - out1 *= sign1 + out1 = out1 * sign1 # out-of-bounds mask if mask is not None: - out1 *= mask + out1 = out1 * mask # apply weights for weight, n in zip(weights, nodes): - out1 *= weight[n] + out1 = out1 * weight[n] # accumulate out.scatter_add_(-1, idx, out1) @@ -252,7 +252,7 @@ def grad(inp, grid, bound: List[Bound], spline: List[Spline], sign0: List[Optional[Tensor]] = [sgn[n] for sgn, n in zip(signs, nodes)] sign1: Optional[Tensor] = make_sign(sign0) if sign1 is not None: - out0 *= sign1 + out0 = out0 * sign1 for d in range(dim): out1 = out0.clone() @@ -261,16 +261,16 @@ def grad(inp, grid, bound: List[Bound], spline: List[Spline], if d == dd: grad11 = grad1[n] if grad11 is not None: - out1 *= grad11 + out1 = out1 * grad11 else: - out1 *= weight[n] + out1 = out1 * weight[n] # accumulate out.unbind(-1)[d].add_(out1) # out-of-bounds mask if mask is not None: - out *= mask.unsqueeze(-1) + out = out * mask.unsqueeze(-1) out = out.reshape(list(out.shape[:2]) + oshape + list(out.shape[-1:])) return out @@ -325,11 
+325,11 @@ def pushgrad(inp, grid, shape: Optional[List[int]], bound: List[Bound], sign0: List[Optional[Tensor]] = [sgn[n] for sgn, n in zip(signs, nodes)] sign1: Optional[Tensor] = make_sign(sign0) if sign1 is not None: - out0 *= sign1.unsqueeze(-1) + out0 = out0 * sign1.unsqueeze(-1) # out-of-bounds mask if mask is not None: - out0 *= mask.unsqueeze(-1) + out0 = out0 * mask.unsqueeze(-1) for d in range(dim): out1 = out0.unbind(-1)[d].clone() @@ -338,9 +338,9 @@ def pushgrad(inp, grid, shape: Optional[List[int]], bound: List[Bound], if d == dd: grad11 = grad1[n] if grad11 is not None: - out1 *= grad11 + out1 = out1 * grad11 else: - out1 *= weight[n] + out1 = out1 * weight[n] # accumulate out.scatter_add_(-1, idx, out1) @@ -397,7 +397,7 @@ def hess(inp, grid, bound: List[Bound], spline: List[Spline], sign0: List[Optional[Tensor]] = [sgn[n] for sgn, n in zip(signs, nodes)] sign1: Optional[Tensor] = make_sign(sign0) if sign1 is not None: - out1 *= sign1 + out1 = out1 * sign1 for d in range(dim): # -- diagonal -- @@ -410,7 +410,7 @@ def hess(inp, grid, bound: List[Bound], spline: List[Spline], if hess11 is not None: out1 *= hess11 else: - out1 *= weight[n] + out1 = out1 * weight[n] # accumulate out.unbind(-1)[d].unbind(-1)[d].add_(out1) @@ -424,16 +424,16 @@ def hess(inp, grid, bound: List[Bound], spline: List[Spline], if dd in (d, d2): grad11 = grad1[n] if grad11 is not None: - out1 *= grad11 + out1 = out1 * grad11 else: - out1 *= weight[n] + out1 = out1 * weight[n] # accumulate out.unbind(-1)[d].unbind(-1)[d2].add_(out1) # out-of-bounds mask if mask is not None: - out *= mask.unsqueeze(-1) + out = out * mask.unsqueeze(-1) # fill lower triangle for d in range(dim):