From ff720c79809bc15b9ab4ca7c5e90a1711cd19213 Mon Sep 17 00:00:00 2001 From: Yael Balbastre Date: Tue, 3 Jan 2023 17:09:46 -0500 Subject: [PATCH] FIX(nd.pull): expand index in channel dimension --- interpol/nd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/interpol/nd.py b/interpol/nd.py index 2ce7749..9e6ba63 100644 --- a/interpol/nd.py +++ b/interpol/nd.py @@ -117,7 +117,7 @@ def pull(inp, grid, bound: List[Bound], spline: List[Spline], # gather idx = [c[n] for c, n in zip(coords, nodes)] idx = sub2ind_list(idx, shape).unsqueeze(1) - idx = idx.expand([batch, 1, idx.shape[-1]]) + idx = idx.expand([batch, channel, idx.shape[-1]]) out1 = inp.gather(-1, idx) # apply sign