Skip to content

Commit

Permalink
FIX(nd.pull): expand index in channel dimension
Browse files Browse the repository at this point in the history
  • Loading branch information
balbasty committed Jan 3, 2023
1 parent da85e83 commit ff720c7
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion interpol/nd.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def pull(inp, grid, bound: List[Bound], spline: List[Spline],
# gather
idx = [c[n] for c, n in zip(coords, nodes)]
idx = sub2ind_list(idx, shape).unsqueeze(1)
idx = idx.expand([batch, 1, idx.shape[-1]])
idx = idx.expand([batch, channel, idx.shape[-1]])
out1 = inp.gather(-1, idx)

# apply sign
Expand Down

0 comments on commit ff720c7

Please sign in to comment.