forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathAffineGridGenerator.cpp
123 lines (108 loc) · 3.54 KB
/
AffineGridGenerator.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
namespace at { namespace native {
// Returns `num_steps` evenly spaced values from -1 to 1 (both endpoints
// included), created with the options (dtype/device) of `grid`.
// When num_steps <= 1, a tensor holding the single value -1 is returned
// instead of calling linspace.
at::Tensor linspace_from_neg_one(const Tensor& grid, int64_t num_steps) {
  // Degenerate case: at most one sample along this axis.
  if (num_steps <= 1) {
    return at::tensor(-1, grid.options());
  }
  return at::linspace(-1, 1, num_steps, grid.options());
}
// Builds the homogeneous base grid for the 4-D (spatial) case.
// Result shape is {N, H, W, 3}. Along the last dimension:
//   [0] = x coordinate, varying along W, from linspace_from_neg_one
//   [1] = y coordinate, varying along H, from linspace_from_neg_one
//   [2] = constant 1 (homogeneous coordinate)
// `theta` only supplies tensor options (dtype/device); `C` is unused.
Tensor make_base_grid_4D(
    const Tensor& theta,
    int64_t N,
    int64_t C,
    int64_t H,
    int64_t W) {
  auto grid = at::empty({N, H, W, 3}, theta.options());

  // Coordinate ramps shaped so that copy_ broadcasts each one across the
  // remaining axes of its {N, H, W} slice: xs along W, ys ({H, 1}) along H.
  auto xs = linspace_from_neg_one(theta, W);
  auto ys = linspace_from_neg_one(theta, H).unsqueeze_(-1);

  grid.select(-1, 0).copy_(xs);
  grid.select(-1, 1).copy_(ys);
  grid.select(-1, 2).fill_(1);
  return grid;
}
// Builds the homogeneous base grid for the 5-D (volumetric) case.
// Result shape is {N, D, H, W, 4}. Along the last dimension:
//   [0] = x coordinate, varying along W
//   [1] = y coordinate, varying along H
//   [2] = z coordinate, varying along D
//   [3] = constant 1 (homogeneous coordinate)
// Each coordinate comes from linspace_from_neg_one.
// `theta` only supplies tensor options (dtype/device); `C` is unused.
Tensor make_base_grid_5D(
    const Tensor& theta,
    int64_t N,
    int64_t C,
    int64_t D,
    int64_t H,
    int64_t W) {
  auto grid = at::empty({N, D, H, W, 4}, theta.options());

  // Trailing singleton dims make copy_ broadcast each ramp along its own
  // axis: xs is {W}, ys is {H, 1}, zs is {D, 1, 1}.
  auto xs = linspace_from_neg_one(theta, W);
  auto ys = linspace_from_neg_one(theta, H).unsqueeze_(-1);
  auto zs = linspace_from_neg_one(theta, D).unsqueeze_(-1).unsqueeze_(-1);

  grid.select(-1, 0).copy_(xs);
  grid.select(-1, 1).copy_(ys);
  grid.select(-1, 2).copy_(zs);
  grid.select(-1, 3).fill_(1);
  return grid;
}
// Forward pass for the 4-D case.
// Every homogeneous base point (x, y, 1) is multiplied by theta^T, giving a
// grid of shape {N, H, W, 2}. The bmm against {N, H*W, 3} points plus the
// final view effectively require theta to be {N, 2, 3}.
Tensor affine_grid_generator_4D(
    const Tensor& theta,
    int64_t N,
    int64_t C,
    int64_t H,
    int64_t W) {
  const int64_t num_points = H * W;
  // The freshly allocated base grid is contiguous, so view is safe here.
  auto points = make_base_grid_4D(theta, N, C, H, W).view({N, num_points, 3});
  auto transformed = points.bmm(theta.transpose(1, 2));
  return transformed.view({N, H, W, 2});
}
// Forward pass for the 5-D case.
// Every homogeneous base point (x, y, z, 1) is multiplied by theta^T, giving
// a grid of shape {N, D, H, W, 3}. The bmm against {N, D*H*W, 4} points plus
// the final view effectively require theta to be {N, 3, 4}.
Tensor affine_grid_generator_5D(
    const Tensor& theta,
    int64_t N,
    int64_t C,
    int64_t D,
    int64_t H,
    int64_t W) {
  const int64_t num_points = D * H * W;
  // The freshly allocated base grid is contiguous, so view is safe here.
  auto points =
      make_base_grid_5D(theta, N, C, D, H, W).view({N, num_points, 4});
  auto transformed = points.bmm(theta.transpose(1, 2));
  return transformed.view({N, D, H, W, 3});
}
// Generates a sampling grid from a batch of affine matrices `theta` for an
// output of the given `size`:
//   size = {N, C, H, W}    -> theta must be {N, 2, 3}; returns {N, H, W, 2}
//   size = {N, C, D, H, W} -> theta must be {N, 3, 4}; returns {N, D, H, W, 3}
// Throws (via AT_CHECK) if `size` has a different length or if theta's shape
// does not match `size`.
Tensor affine_grid_generator(const Tensor& theta, IntArrayRef size) {
  AT_CHECK(
      size.size() == 4 || size.size() == 5,
      "AffineGridGenerator needs 4d (spatial) or 5d (volumetric) inputs.");

  // Number of spatial dimensions: 2 for a 4-D size, 3 for a 5-D size.
  const int64_t spatial_dims = static_cast<int64_t>(size.size()) - 2;

  // Validate theta up front. Otherwise a mismatched theta only surfaces as an
  // opaque bmm/view error, after the whole base grid has been allocated.
  AT_CHECK(
      theta.dim() == 3 && theta.size(0) == size[0] &&
          theta.size(1) == spatial_dims &&
          theta.size(2) == spatial_dims + 1,
      "AffineGridGenerator expected theta of size [", size[0], ", ",
      spatial_dims, ", ", spatial_dims + 1, "] for output size ", size,
      ", but got ", theta.sizes());

  if (size.size() == 4) {
    return affine_grid_generator_4D(theta, size[0], size[1], size[2], size[3]);
  } else {
    return affine_grid_generator_5D(
        theta, size[0], size[1], size[2], size[3], size[4]);
  }
}
// Gradient of the 4-D affine grid generator with respect to theta.
// grad_grid must have shape {N, H, W, 2}.
// Computes (base^T @ grad)^T: base {N, H*W, 3} transposed is {N, 3, H*W};
// bmm with grad {N, H*W, 2} gives {N, 3, 2}, which is returned transposed
// as {N, 2, 3}.
Tensor affine_grid_generator_4D_backward(
    const Tensor& grad_grid,
    int64_t N,
    int64_t C,
    int64_t H,
    int64_t W) {
  // This op is callable directly with a user-supplied gradient, so report a
  // shape mismatch as a user error (AT_CHECK), not an internal invariant
  // failure (AT_ASSERT). Check before allocating the base grid.
  const int64_t expected_sizes[] = {N, H, W, 2};
  AT_CHECK(
      grad_grid.sizes() == IntArrayRef(expected_sizes),
      "affine_grid_generator_backward: expected grad of size ",
      IntArrayRef(expected_sizes), ", but got ", grad_grid.sizes());

  auto base_grid = make_base_grid_4D(grad_grid, N, C, H, W);
  // reshape rather than view: the incoming gradient is not guaranteed to be
  // contiguous (e.g. if the grid was permuted/sliced downstream), and view
  // throws on incompatible strides. base_grid is freshly allocated, so view
  // is fine for it.
  auto grad_theta = base_grid.view({N, H * W, 3})
                        .transpose(1, 2)
                        .bmm(grad_grid.reshape({N, H * W, 2}));
  return grad_theta.transpose(1, 2);
}
// Gradient of the 5-D affine grid generator with respect to theta.
// grad_grid must have shape {N, D, H, W, 3}.
// Computes (base^T @ grad)^T: base {N, D*H*W, 4} transposed is
// {N, 4, D*H*W}; bmm with grad {N, D*H*W, 3} gives {N, 4, 3}, which is
// returned transposed as {N, 3, 4}.
Tensor affine_grid_generator_5D_backward(
    const Tensor& grad_grid,
    int64_t N,
    int64_t C,
    int64_t D,
    int64_t H,
    int64_t W) {
  // This op is callable directly with a user-supplied gradient, so report a
  // shape mismatch as a user error (AT_CHECK), not an internal invariant
  // failure (AT_ASSERT). Check before allocating the base grid.
  const int64_t expected_sizes[] = {N, D, H, W, 3};
  AT_CHECK(
      grad_grid.sizes() == IntArrayRef(expected_sizes),
      "affine_grid_generator_backward: expected grad of size ",
      IntArrayRef(expected_sizes), ", but got ", grad_grid.sizes());

  auto base_grid = make_base_grid_5D(grad_grid, N, C, D, H, W);
  // reshape rather than view: the incoming gradient is not guaranteed to be
  // contiguous, and view throws on incompatible strides. base_grid is
  // freshly allocated, so view is fine for it.
  auto grad_theta = base_grid.view({N, D * H * W, 4})
                        .transpose(1, 2)
                        .bmm(grad_grid.reshape({N, D * H * W, 3}));
  return grad_theta.transpose(1, 2);
}
// Backward entry point: dispatches on the length of `size` to the 4-D
// ({N, C, H, W}) or 5-D ({N, C, D, H, W}) gradient implementation.
// Throws (via AT_CHECK) for any other length.
Tensor affine_grid_generator_backward(const Tensor& grad, IntArrayRef size) {
  const auto ndim = size.size();
  AT_CHECK(
      ndim == 4 || ndim == 5,
      "AffineGridGenerator needs 4d (spatial) or 5d (volumetric) inputs.");
  if (ndim == 5) {
    return affine_grid_generator_5D_backward(
        grad, size[0], size[1], size[2], size[3], size[4]);
  }
  return affine_grid_generator_4D_backward(
      grad, size[0], size[1], size[2], size[3]);
}
}} // namespace at::native