forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
LossMultiLabelMargin.cpp
325 lines (282 loc) · 9.88 KB
/
LossMultiLabelMargin.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/TensorUtils.h>
#include <ATen/native/LossMulti.h>

#include <limits>
namespace at {
namespace native {
namespace {
template <typename scalar_t>
inline scalar_t multilabel_margin_loss_forward_inner_sum_cpu(
scalar_t* input_data,
int64_t* target_data,
scalar_t* is_target_data,
int64_t dim) {
using accscalar_t = at::acc_type<scalar_t, false>;
accscalar_t sum = 0;
for (int64_t ddt = 0; ddt < dim; ddt++) {
int64_t target_idx = target_data[ddt];
if (target_idx < 0) {
break;
}
is_target_data[target_idx] = 1;
}
for (int64_t dt = 0; dt < dim; dt++) {
int64_t target_idx = target_data[dt];
if (target_idx < 0) {
break;
}
scalar_t input_target = input_data[target_idx];
for (int64_t d = 0; d < dim; d++) {
if (!is_target_data[d]) {
scalar_t z = 1 - input_target + input_data[d];
if (z > 0) {
sum += z;
}
}
}
}
return sum;
}
template <typename scalar_t>
static void multilabel_margin_loss_forward_out_frame(
const Tensor& input_contiguous,
const Tensor& target_contiguous,
Tensor& output,
Tensor& is_target,
int64_t reduction,
int64_t nframe,
int64_t dim) {
using accscalar_t = at::acc_type<scalar_t, false>;
scalar_t* input_data = input_contiguous.data_ptr<scalar_t>();
int64_t* target_data = target_contiguous.data_ptr<int64_t>();
scalar_t* is_target_data = is_target.data_ptr<scalar_t>();
if (reduction != Reduction::None || output.dim() == 0) {
scalar_t* output_data = output.data_ptr<scalar_t>();
accscalar_t sum = 0;
for (int64_t t = 0; t < nframe; t++) {
sum += multilabel_margin_loss_forward_inner_sum_cpu(
input_data, target_data, is_target_data, dim);
input_data += dim;
target_data += dim;
is_target_data += dim;
}
sum /= dim;
if (reduction == Reduction::Mean) {
sum /= nframe;
}
*output_data = sum; // write scalar output value
} else {
auto output_acc = output.accessor<scalar_t, 1>();
for (int64_t t = 0; t < nframe; t++) {
scalar_t sum = multilabel_margin_loss_forward_inner_sum_cpu(
input_data, target_data, is_target_data, dim);
sum /= dim;
output_acc[t] = sum;
input_data += dim;
target_data += dim;
is_target_data += dim;
}
}
}
static void multilabel_margin_loss_forward_out_cpu_template(
const Tensor& input,
const Tensor& target,
Tensor& output,
Tensor& is_target,
int64_t reduction) {
auto target_arg = TensorArg(target, "target", 2);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int64_t nframe, dim;
const int64_t ndims = input.dim();
if (ndims <= 1) {
nframe = 1;
dim = ndims == 0 ? 1 : input.size(0);
}
else {
nframe = input.size(0);
dim = input.size(1);
}
multilabel_margin_loss_shape_check(nframe, dim, ndims, target_arg, input, target);
// special case target.dim() <= 1: produce scalar output for scalar inputs
// even if reduction == Reduction::None
if (reduction != Reduction::None || target.dim() <= 1) {
output.resize_({});
} else {
output.resize_({nframe});
}
is_target.resize_as_(target);
TORCH_CHECK(is_target.is_contiguous(), "is_target must be contiguous");
is_target.zero_();
if (input.numel() == 0) {
return;
}
TORCH_CHECK(
target.min().item<int64_t>() >= -1, target_arg, " is out of range");
TORCH_CHECK(
target.max().item<int64_t>() < dim, target_arg, " is out of range");
auto input_contiguous = input.contiguous();
auto target_contiguous = target.contiguous();
AT_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "multilabel_margin_loss_forward_out_frame", [&] {
multilabel_margin_loss_forward_out_frame<scalar_t>(
input_contiguous, target_contiguous, output, is_target, reduction, nframe, dim);
});
}
// Per-dtype backward kernel for the multilabel margin loss.
//
// Pass 1: for each row and each listed target j (a row's list ends at its
// first negative index), every non-target class i with
// 1 - input[j] + input[i] > 0 adds -g to grad_input[j] and +g to
// grad_input[i]. Here g = 1 / dim, or 1 / (nframe * dim) for Reduction::Mean.
// Pass 2: the result is multiplied by `grad_output`. A single scalar is used
// when the reduction is not None or grad_output is 0-dim; otherwise the
// per-row values are used, and grad_output must have nframe entries.
//
// Expects `grad_input` to be contiguous and zero-filled; the CPU backward
// template checks and zeroes it before dispatching here. `is_target_contiguous`
// is range-checked to lie in [0, 1].
template <typename scalar_t>
static void multilabel_margin_loss_backward_out_frame(
    Tensor& grad_input,
    const Tensor& grad_output,
    const Tensor& input_contiguous,
    const Tensor& target_contiguous,
    int64_t reduction,
    const Tensor& is_target_contiguous,
    int64_t nframe,
    int64_t dim) {
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  auto is_target_arg = TensorArg(is_target_contiguous, "is_target", 5);
  TORCH_CHECK(
      is_target_contiguous.min().item<scalar_t>() >= 0, is_target_arg, " is out of range");
  TORCH_CHECK(
      is_target_contiguous.max().item<scalar_t>() <= 1, is_target_arg, " is out of range");

  scalar_t* input_data = input_contiguous.data_ptr<scalar_t>();
  int64_t* target_data = target_contiguous.data_ptr<int64_t>();
  scalar_t* is_target_data = is_target_contiguous.data_ptr<scalar_t>();
  scalar_t g = static_cast<scalar_t>(
      // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
      reduction == Reduction::Mean ? 1. / (nframe * dim) : 1. / dim);

  // Pass 1: unscaled hinge gradients, row by row.
  scalar_t* grad_input_row_data = grad_input.data_ptr<scalar_t>();
  for (int64_t t = 0; t < nframe; t++) {
    for (int64_t dt = 0; dt < dim; dt++) {
      int64_t target_idx = target_data[dt];
      if (target_idx < 0) {
        break;
      }
      scalar_t input_target = input_data[target_idx];
      for (int64_t d = 0; d < dim; d++) {
        if (!is_target_data[d]) {
          scalar_t z = 1 - input_target + input_data[d];
          if (z > 0) {
            grad_input_row_data[target_idx] -= g;
            grad_input_row_data[d] += g;
          }
        }
      }
    }
    input_data += dim;
    target_data += dim;
    is_target_data += dim;
    grad_input_row_data += dim;
  }

  // Pass 2: chain rule with the incoming gradient.
  scalar_t* grad_input_data = grad_input.data_ptr<scalar_t>();
  if (reduction != Reduction::None || grad_output.dim() == 0) {
    // A 0-dim grad_output with Reduction::None is only expected for a single row.
    assert(
        reduction != Reduction::None || grad_output.dim() > 0 || nframe == 1);
    const auto d = *grad_output.data_ptr<scalar_t>();
    for (int64_t t = 0; t < nframe * dim; t++) {
      grad_input_data[t] *= d;
    }
  } else {
    check_dim_size(grad_output, 1, 0, nframe);
    auto grad_output_acc = grad_output.accessor<scalar_t, 1>();
    for (int64_t t = 0; t < nframe; t++) {
      for (int64_t d = 0; d < dim; d++) {
        grad_input_data[t * dim + d] *= grad_output_acc[t];
      }
    }
  }
}
// CPU backward driver: validates shapes, the is_target size and the target
// range, prepares a zeroed contiguous `grad_input` shaped like `input`, and
// dispatches the per-dtype kernel.
static void multilabel_margin_loss_backward_out_cpu_template(
    Tensor& grad_input,
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& target,
    int64_t reduction,
    const Tensor& is_target) {
  const CheckedFrom checked_from = "multilabel_margin_loss_backward_cpu_template";
  auto target_arg = TensorArg(target, "target", 3);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  auto is_target_arg = TensorArg(is_target, "is_target", 5);

  // Left unset here: nframe and dim are used after the shape check below,
  // which is therefore relied on to fill them in.
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  int64_t nframe, dim;
  const int64_t ndims = input.dim();
  multilabel_margin_loss_shape_check(nframe, dim, ndims, target_arg, input, target);
  checkSameSize(checked_from, target_arg, is_target_arg);

  grad_input.resize_as_(input);
  if (grad_input.numel() == 0) {
    return;
  }
  TORCH_CHECK(grad_input.is_contiguous(), "grad_input must be contiguous");
  grad_input.zero_();

  const int64_t smallest_target = target.min().item<int64_t>();
  TORCH_CHECK(smallest_target >= -1, target_arg, " is out of range");
  const int64_t largest_target = target.max().item<int64_t>();
  TORCH_CHECK(largest_target < dim, target_arg, " is out of range");

  const Tensor input_c = input.contiguous();
  const Tensor target_c = target.contiguous();
  const Tensor is_target_c = is_target.contiguous();

  AT_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "multilabel_margin_loss_backward_out_frame", [&] {
        multilabel_margin_loss_backward_out_frame<scalar_t>(
            grad_input,
            grad_output,
            input_c,
            target_c,
            reduction,
            is_target_c,
            nframe,
            dim);
      });
}
} // namespace
std::tuple<Tensor&, Tensor&> multilabel_margin_loss_forward_out_cpu(const Tensor& self,
const Tensor& target,
int64_t reduction,
Tensor& output,
Tensor& is_target) {
multilabel_margin_loss_forward_out_cpu_template(
self, target, output, is_target, reduction);
return std::tuple<Tensor&, Tensor&>(output, is_target);
}
std::tuple<Tensor, Tensor> multilabel_margin_loss_forward_cpu(
const Tensor& self,
const Tensor& target,
int64_t reduction) {
auto output = at::empty({0}, self.options());
auto is_target = at::empty({0}, self.options());
at::native::multilabel_margin_loss_forward_out_cpu(
self, target, reduction, output, is_target);
return std::make_tuple(output, is_target);
}
// Out-variant of the CPU backward pass. Writes d(loss)/d(self) into
// `grad_input` and returns it.
Tensor& multilabel_margin_loss_backward_cpu_out(
    const Tensor& grad_output,
    const Tensor& self,
    const Tensor& target,
    int64_t reduction,
    const Tensor& is_target,
    Tensor& grad_input) {
  multilabel_margin_loss_backward_out_cpu_template(
      grad_input,
      grad_output,
      self,
      target,
      reduction,
      is_target);
  return grad_input;
}
// Functional CPU backward pass. Allocates a zero gradient buffer like `self`
// (legacy contiguous layout) and fills it through the out-variant.
Tensor multilabel_margin_loss_backward_cpu(
    const Tensor& grad_output,
    const Tensor& self,
    const Tensor& target,
    int64_t reduction,
    const Tensor& is_target) {
  Tensor grad_buffer = at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  return at::native::multilabel_margin_loss_backward_cpu_out(
      grad_output, self, target, reduction, is_target, grad_buffer);
}
// Out-variant of the public loss. Only `output` is exposed; the target mask
// goes to a scratch tensor that is discarded.
Tensor & multilabel_margin_loss_out(const Tensor & self, const Tensor & target, int64_t reduction, Tensor & output) {
  Tensor scratch_mask = at::empty({0}, self.options());
  auto written = at::multilabel_margin_loss_forward_out(
      output, scratch_mask, self, target, reduction);
  return std::get<0>(written);
}
// Public functional loss. Runs the forward op and returns only the loss,
// dropping the is_target mask.
Tensor multilabel_margin_loss(const Tensor & self, const Tensor & target, int64_t reduction) {
  auto loss_and_mask = at::multilabel_margin_loss_forward(self, target, reduction);
  return std::get<0>(loss_and_mask);
}
} // namespace native
} // namespace at