-
Notifications
You must be signed in to change notification settings - Fork 0
/
image_corr.py
156 lines (137 loc) · 4.08 KB
/
image_corr.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
from torch import nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair
import spatial_correlation_sampler_backend as correlation
def spatial_correlation_sample(
    input1,
    input2,
    kernel_size=1,
    patch_size=1,
    stride=1,
    padding=0,
    dilation=1,
    dilation_patch=1,
):
    """Apply spatial correlation sampling from ``input1`` to ``input2``.

    Every parameter except ``input1`` and ``input2`` can be either a single
    int or a pair of ints ``(H, W)``; a single int is expanded to a pair with
    ``_pair`` inside :class:`SpatialCorrelationSamplerFunction`. For more
    information about spatial correlation sampling, see:
    https://lmb.informatik.uni-freiburg.de/Publications/2015/DFIB15/

    Gradients flow back to ``input1`` and ``input2`` only, and the operation
    supports a single backward pass (its backward is ``once_differentiable``).

    Args:
        input1: The first tensor to correlate. NOTE(review): presumably a
            4-D ``(B, C, H, W)`` tensor; the layout is enforced by the
            compiled backend, not here — confirm.
        input2: The second tensor to correlate; presumably the same shape as
            ``input1`` — confirm against the backend.
        kernel_size: Total size of the correlation kernel, in pixels.
        patch_size: Total size of the patch, determining how many different
            shifts will be applied.
        stride: Stride of the spatial sampler; modifies the output height
            and width.
        padding: Padding applied to ``input1`` and ``input2`` before the
            correlation sampling; modifies the output height and width.
        dilation: Dilation of the correlation kernel (forwarded to the
            backend as ``dilationH, dilationW``).
        dilation_patch: Step between consecutive shifts in the patch.

    Returns:
        Tensor: Result of the correlation sampling, as produced by the
        compiled backend's ``forward``.
    """
    return SpatialCorrelationSamplerFunction.apply(
        input1,
        input2,
        kernel_size,
        patch_size,
        stride,
        padding,
        dilation,
        dilation_patch,
    )
class SpatialCorrelationSamplerFunction(Function):
    """Autograd bridge to the compiled ``spatial_correlation_sampler_backend``.

    ``forward`` normalises each geometry argument to an ``(H, W)`` pair with
    ``_pair``, records the pairs on ``ctx`` and passes the two input tensors
    followed by the twelve flattened geometry ints to ``correlation.forward``.
    ``backward`` replays the same geometry through ``correlation.backward``.
    """

    @staticmethod
    def forward(
        ctx,
        input1,
        input2,
        kernel_size=1,
        patch_size=1,
        stride=1,
        padding=0,
        dilation=1,
        dilation_patch=1,
    ):
        """Run the backend forward pass and stash what ``backward`` needs.

        Returns:
            The tensor returned by ``correlation.forward``.
        """
        ctx.save_for_backward(input1, input2)
        # The backend takes the geometry flattened in this exact order:
        # kernel, patch, padding, dilation, dilation_patch, stride. Note that
        # stride goes last even though it precedes padding in the signature.
        geometry = []
        for attr_name, raw_value in (
            ("kernel_size", kernel_size),
            ("patch_size", patch_size),
            ("padding", padding),
            ("dilation", dilation),
            ("dilation_patch", dilation_patch),
            ("stride", stride),
        ):
            pair = _pair(raw_value)
            # Unpacking rejects anything that is not exactly two values.
            height, width = pair
            setattr(ctx, attr_name, pair)
            geometry.extend((height, width))
        return correlation.forward(input1, input2, *geometry)

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        """Compute gradients for ``input1`` and ``input2`` via the backend.

        The six geometry arguments are plain ints, so their gradient slots
        are ``None``.
        """
        input1, input2 = ctx.saved_tensors
        # Same field order as forward() used when flattening the geometry.
        geometry = []
        for attr_name in (
            "kernel_size",
            "patch_size",
            "padding",
            "dilation",
            "dilation_patch",
            "stride",
        ):
            height, width = getattr(ctx, attr_name)
            geometry.extend((height, width))
        grad_input1, grad_input2 = correlation.backward(
            input1, input2, grad_output, *geometry
        )
        # One entry per forward() input: two tensors, then six geometry args.
        return (grad_input1, grad_input2) + (None,) * 6
class SpatialCorrelationSampler(nn.Module):
    """``nn.Module`` wrapper around :class:`SpatialCorrelationSamplerFunction`.

    Stores the sampling geometry at construction time and applies it to every
    ``(input1, input2)`` pair passed to :meth:`forward`. Each geometry argument
    may be a single int or a pair of ints ``(H, W)``; values are stored as
    given and only expanded to pairs when the autograd function runs.

    Args:
        kernel_size: Total size of the correlation kernel, in pixels.
        patch_size: Total size of the patch, determining how many different
            shifts will be applied.
        stride: Stride of the spatial sampler.
        padding: Padding applied to both inputs before sampling.
        dilation: Dilation of the correlation kernel.
        dilation_patch: Step between consecutive shifts in the patch.
    """

    def __init__(
        self,
        kernel_size=1,
        patch_size=1,
        stride=1,
        padding=0,
        dilation=1,
        dilation_patch=1,
    ):
        super().__init__()
        self.kernel_size = kernel_size
        self.patch_size = patch_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.dilation_patch = dilation_patch

    def extra_repr(self):
        """Expose the sampling geometry in ``repr(module)``, like built-in layers."""
        return (
            f"kernel_size={self.kernel_size}, patch_size={self.patch_size}, "
            f"stride={self.stride}, padding={self.padding}, "
            f"dilation={self.dilation}, dilation_patch={self.dilation_patch}"
        )

    def forward(self, input1, input2):
        """Correlate ``input1`` with ``input2`` using the stored geometry."""
        return SpatialCorrelationSamplerFunction.apply(
            input1,
            input2,
            self.kernel_size,
            self.patch_size,
            self.stride,
            self.padding,
            self.dilation,
            self.dilation_patch,
        )