From aceaa0481edffca445401798e8b497e7debfe133 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ya=C3=ABl=20Balbastre?= Date: Sat, 26 Aug 2023 10:29:58 -0400 Subject: [PATCH] FIX(coeff) (#13) --- interpol/coeff.py | 66 ++++++++++++++++++----------------------------- 1 file changed, 25 insertions(+), 41 deletions(-) diff --git a/interpol/coeff.py b/interpol/coeff.py index b7b7a04..fcbd116 100644 --- a/interpol/coeff.py +++ b/interpol/coeff.py @@ -73,6 +73,11 @@ def get_gain(poles: List[float]) -> float: return lam +@torch.jit.script +def _dot(x, y): + return x.unsqueeze(-2).matmul(y.unsqueeze(-1)).squeeze(-1).squeeze(-1) + + @torch.jit.script def dft_initial(inp, pole: float, dim: int = -1, keepdim: bool = False): @@ -90,15 +95,13 @@ def dft_initial(inp, pole: float, dim: int = -1, keepdim: bool = False): inp0 = inp[0] inp = inp[1-max_iter:] inp = movedim1(inp, 0, -1) - out = torch.matmul(inp.unsqueeze(-2), poles.unsqueeze(-1)).squeeze(-1) - out = out + inp0.unsqueeze(-1) - if keepdim: - out = movedim1(out, -1, dim) - else: - out = out.squeeze(-1) + out = _dot(inp, poles) + inp0 pole = pole ** max_iter out = out / (1 - pole) + + if keepdim: + out = out.unsqueeze(dim) return out @@ -112,29 +115,20 @@ def dct1_initial(inp, pole: float, dim: int = -1, keepdim: bool = False): poles = torch.as_tensor(pole, dtype=inp.dtype, device=inp.device) poles = poles.pow( - torch.arange(1, max_iter, dtype=inp.dtype, device=inp.device) + torch.arange(0, max_iter, dtype=inp.dtype, device=inp.device) ) inp = movedim1(inp, dim, 0) - inp0 = inp[0] - inp = inp[1:max_iter] + inp = inp[:max_iter] inp = movedim1(inp, 0, -1) - out = torch.matmul(inp.unsqueeze(-2), poles.unsqueeze(-1)).squeeze(-1) - out = out + inp0.unsqueeze(-1) - if keepdim: - out = movedim1(out, -1, dim) - else: - out = out.squeeze(-1) + out = _dot(inp, poles) else: max_iter = n polen = pole ** (n - 1) - inp0 = inp[0] + polen * inp[-1] - inp = inp[1:-1] - inp = movedim1(inp, 0, -1) - out = inp0.unsqueeze(-1) + out = inp[0] + polen * inp[-1] if n > 2: poles = torch.as_tensor(pole, dtype=inp.dtype, device=inp.device) @@ -143,26 +137,18 @@ def dct1_initial(inp, pole: float, dim: int = -1, keepdim: bool = False): ) poles = poles + (polen * polen) / poles - out = out + torch.matmul( - inp.unsqueeze(-2), poles.unsqueeze(-1) - ).squeeze(-1) - - if keepdim: - out = movedim1(out, -1, dim) - else: - out = out.squeeze(-1) + inp = inp[1:-1] + inp = movedim1(inp, 0, -1) + out = out + _dot(inp, poles) pole = pole ** (max_iter - 1) out = out / (1 - pole * pole) + if keepdim: + out = out.unsqueeze(dim) return out -@torch.jit.script -def _dot(x, y): - return x.unsqueeze(-2).matmul(y.unsqueeze(-1)).squeeze(-1).squeeze(-1) - - @torch.jit.script def dct2_initial(inp, pole: float, dim: int = -1, keepdim: bool = False): # Ported from scipy: @@ -190,7 +176,6 @@ def dct2_initial(inp, pole: float, dim: int = -1, keepdim: bool = False): if keepdim: out = out.unsqueeze(dim) - return out @@ -210,15 +195,14 @@ def dft_final(inp, pole: float, dim: int = -1, keepdim: bool = False): inp0 = inp[-1] inp = inp[:max_iter-1] inp = movedim1(inp, 0, -1) - out = torch.matmul(inp.unsqueeze(-2), poles.unsqueeze(-1)).squeeze(-1) - out = out.add(inp0.unsqueeze(-1), alpha=pole) - if keepdim: - out = movedim1(out, -1, dim) - else: - out = out.squeeze(-1) + out = _dot(inp, poles) + out = out.add(inp0, alpha=pole) pole = pole ** max_iter out = out / (pole - 1) + + if keepdim: + out = out.unsqueeze(dim) return out @@ -228,7 +212,7 @@ def dct1_final(inp, pole: float, dim: int = -1, keepdim: bool = False): out = pole * inp[-2] + inp[-1] out = out * (pole / (pole*pole - 1)) if keepdim: - out = movedim1(out.unsqueeze(0), 0, dim) + out = out.unsqueeze(dim) return out @@ -239,7 +223,7 @@ def dct2_final(inp, pole: float, dim: int = -1, keepdim: bool = False): inp = movedim1(inp, dim, 0) out = inp[-1] * (pole / (pole - 1)) if keepdim: - out = movedim1(out.unsqueeze(0), 0, dim) + out = out.unsqueeze(dim) return out