forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathQuantizedLinear.cpp
309 lines (269 loc) · 10.4 KB
/
QuantizedLinear.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
#include "ATen/ATen.h"
#include "ATen/NativeFunctions.h"
#include "ATen/WrapDimUtilsMulti.h"
#ifdef USE_FBGEMM
#include "fbgemm/Fbgemm.h"
#include "fbgemm/QuantUtils.h"
#endif // USE_FBGEMM
#include <array>
#include <cctype>
#include <cmath>
#include <cstddef>
#include <sstream>
#include <string>
#include <vector>
#include <chrono>
namespace at {
namespace native {
#ifdef USE_FBGEMM
// Computes the linear layer `input @ weight^T + bias` with FBGEMM. The weight
// has already been quantized to int8 and packed; `input` is quantized to uint8
// on the fly using its observed min/max.
//
// Arguments:
//   input             float tensor of shape (..., K); all leading dims are
//                     flattened into M rows.
//   weight            only its shape is read here; it must be (N, K).
//   packed            tensor whose storage data_ptr holds an
//                     fbgemm::PackBMatrix<int8_t> (see
//                     fbgemm_pack_quantized_matrix).
//   col_offsets       int32 tensor with N per-output-column offsets (see
//                     fbgemm_linear_quantize_weight).
//   weight_scale      floating-point quantization scale of the weight.
//   weight_zero_point integral quantization zero point of the weight.
//   bias              float tensor of shape (N).
// Returns a float tensor of shape (..., N).
Tensor fbgemm_linear_int8_weight(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& packed,
    const Tensor& col_offsets,
    Scalar weight_scale,
    Scalar weight_zero_point,
    const Tensor& bias) {
  // We make a strong guarantee that models using these operators will have the
  // same numerics across different machines. Therefore, we do not provide a
  // fallback path and rather fail loudly if we cannot run FBGEMM.
  AT_CHECK(fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM.");

  // The checks below guard user-supplied arguments, not internal invariants,
  // so they use AT_CHECK (plain error message) rather than AT_ASSERT (whose
  // message asks the user to report a PyTorch bug).
  AT_CHECK(
      input.dim() >= 2,
      "fbgemm_linear_int8_weight: expected input with at least 2 dims, got ",
      input.dim());
  AT_CHECK(
      weight.dim() == 2,
      "fbgemm_linear_int8_weight: expected 2-D weight, got ",
      weight.dim(),
      " dims");
  const int64_t K = input.size(input.dim() - 1);
  AT_CHECK(
      K == weight.size(1),
      "fbgemm_linear_int8_weight: input has ",
      K,
      " features but weight expects ",
      weight.size(1));
  const int64_t N = weight.size(0);
  AT_CHECK(
      bias.dim() == 1,
      "fbgemm_linear_int8_weight: expected 1-D bias, got ",
      bias.dim(),
      " dims");
  AT_CHECK(
      bias.size(0) == N,
      "fbgemm_linear_int8_weight: expected bias with ",
      N,
      " elements, got ",
      bias.size(0));
  // ReQuantizeForFloat reads exactly N column offsets through a raw pointer.
  AT_CHECK(
      col_offsets.numel() == N,
      "fbgemm_linear_int8_weight: expected col_offsets with ",
      N,
      " elements, got ",
      col_offsets.numel());
  AT_CHECK(
      weight_scale.isFloatingPoint(),
      "fbgemm_linear_int8_weight: weight_scale must be a floating-point scalar");
  AT_CHECK(
      weight_zero_point.isIntegral(),
      "fbgemm_linear_int8_weight: weight_zero_point must be an integral scalar");

  // Flatten all leading dimensions of the input into M rows of length K.
  int64_t M = 1;
  for (int64_t i = 0; i + 1 < input.dim(); ++i) {
    M *= input.size(i);
  }

  auto input_contig = input.contiguous();
  const float* input_ptr = input_contig.data<float>();

  // Calculate statistics for quantization of the input Tensor
  float x_min, x_max;
  fbgemm::FindMinMax(
      /*m=*/input_ptr,
      /*min=*/&x_min,
      /*max=*/&x_max,
      /*len=*/input.numel());

  // Input tensor is quantized as 8-bit unsigned values
  static constexpr int precision = 8;
  static constexpr bool is_signed = false;

  // Calculate scale and zero point for quantization of input tensor
  auto q_params = fbgemm::ChooseQuantizationParams(
      /*min=*/x_min,
      /*max=*/x_max,
      /*qmin=*/is_signed ? -(1 << (precision - 1)) : 0,
      /*qmax=*/is_signed ? ((1 << (precision - 1)) - 1) : (1 << precision) - 1,
      /*preserve_sparsity=*/false);
  q_params.precision = precision;

  // This operation does the following:
  // 1) Quantizes the input matrix given the statistics we've calculated above
  // 2) Creates a "row buffer" vector with offset values that must be added
  //    to the integer matrix multiplication operation to ensure correctness
  // 3) Packs the resulting quantized matrix into vector-register and cache
  //    friendly tiles.
  //
  // Note this is not executed eagerly, but rather within the fbgemmPacked call
  // below.
  fbgemm::PackAWithQuantRowOffset<uint8_t> packA(
      /*trans=*/fbgemm::matrix_op_t::NoTranspose,
      /*nRow=*/M,
      /*nCol=*/K,
      /*smat=*/input_ptr,
      /*ld=*/K,
      /*pmat=*/nullptr, // packA manages ownership of `pmat`
      /*scale=*/q_params.scale,
      /*zero_pt=*/q_params.zero_point);

  // ReQuantizeForFloat requires pointers to the scale and zero point values,
  // since in the case of rowwise quantization these will be arrays rather than
  // scalars. Here we do whole-tensor quantization, so we just pass pointers to
  // single values (and internally ReQuantizeForFloat won't index past 0).
  float weight_scale_float = static_cast<float>(weight_scale.to<double>());
  int32_t weight_zero_point_int32 =
      static_cast<int32_t>(weight_zero_point.to<int64_t>());

  // This is the end of the pipeline, pass the resulting matrix through
  fbgemm::DoNothing<float, float> doNothingObj{};

  // Both per-column arrays are handed to FBGEMM as raw pointers indexed
  // 0..N-1, so they must be densely laid out.
  auto bias_contig = bias.contiguous();
  auto col_offsets_contig = col_offsets.contiguous();

  // After the uint8 * int8 matrix multiplication is performed, this operation
  // does:
  // 1) Add in row and column offsets to the rows and columns, respectively
  // 2) Dequantize the results into floating point
  // 3) Add in the bias term
  fbgemm::ReQuantizeForFloat<false /* FUSE_RELU*/> outputProcObj(
      /*nextop=*/doNothingObj,
      /*Aq_scale=*/q_params.scale,
      /*Bq_scale=*/&weight_scale_float,
      /*Aq_zero_point=*/q_params.zero_point,
      /*Bq_zero_point=*/&weight_zero_point_int32,
      /*row_offsets=*/packA.getRowOffsetBuffer(),
      /*col_offsets=*/col_offsets_contig.data<int32_t>(),
      /*bias=*/bias_contig.data<float>(),
      /*ncol=*/N);

  // Allocate output Tensor and a buffer for fbgemmPacked to use
  auto output = at::zeros({M, N}, bias.options().dtype(at::kFloat));
  auto buffer = at::zeros_like(output, output.options().dtype(at::kInt));

  // Pull out the PackBMatrix instance from the owning tensor
  auto* packB = reinterpret_cast<fbgemm::PackBMatrix<int8_t>*>(
      packed.storage().data_ptr().get());

  // Do the GEMM
  fbgemm::fbgemmPacked(
      /*packA=*/packA,
      /*packB=*/*packB,
      /*C=*/output.data<float>(),
      /*C_buffer=*/buffer.data<int32_t>(),
      /*ldc=*/N,
      /*outProcess=*/outputProcObj,
      /*thread_id=*/0,
      /*num_threads=*/1);

  // The resulting matrix here is 2-D, let's view it with the original
  // left hand dimensions of the input.
  std::vector<int64_t> out_sizes = input.sizes().vec();
  out_sizes.back() = N;
  return output.view(out_sizes);
}
namespace {
// Calculate the column offsets of an int8 weight matrix stored transposed,
// i.e. `Bint8` holds N rows of K values, row-major. For each row i:
//
//   col_offsets[i] = (sum over j of Bint8[i * K + j]) - B_zero_point * K
//
// Note this includes the sum of the columns as well as the scalar term
// B_zero_point * K, whereas the row_offsets created by PackAWithQuantRowOffset
// are only the sums of the A rows.
//
// `col_offsets` must have room for N values. Indices are 64-bit so that
// `i * K` cannot overflow and so the loop bounds compare signed to signed.
void calc_col_offsets_transpose(
    int K,
    int N,
    const int8_t* Bint8,
    int32_t B_zero_point,
    int32_t* col_offsets) {
  for (int64_t i = 0; i < N; ++i) {
    const int8_t* row = Bint8 + i * K;
    int32_t sum = 0;
    for (int64_t j = 0; j < K; ++j) {
      sum += row[j];
    }
    col_offsets[i] = sum - B_zero_point * K;
  }
}
} // namespace
std::tuple<Tensor, Tensor, double, int64_t> fbgemm_linear_quantize_weight(
const Tensor& weight) {
// We make a strong guarantee that models using these operators will have the
// same numerics across different machines. Therefore, we do not provide a
// fallback path and rather fail loudly if we cannot run FBGEMM.
AT_ASSERTM(fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM.");
auto weight_contig = weight.contiguous();
// Calculate weight statistics
float w_min, w_max;
fbgemm::FindMinMax(
/*m=*/weight_contig.data<float>(),
/*min=*/&w_min,
/*max=*/&w_max,
/*len=*/weight_contig.numel());
// Choose parameters for quantizing the weight as 8-bit signed integer
static constexpr bool is_signed = true;
static constexpr int precision = 8;
auto q_params = fbgemm::ChooseQuantizationParams(
/*min=*/w_min,
/*max=*/w_max,
/*qmin=*/is_signed ? -(1 << (precision - 1)) : 0,
/*qmax=*/is_signed ? ((1 << (precision - 1)) - 1) : (1 << precision) - 1,
/*preserve_sparsity=*/false);
q_params.precision = precision;
auto quantized = at::zeros_like(weight_contig).to(at::kChar).contiguous();
fbgemm::Quantize<int8_t>(
/*src=*/weight_contig.data<float>(),
/*dst=*/quantized.data<int8_t>(),
/*len=*/weight_contig.numel(),
/*qparams=*/q_params);
// Calculate column offsets of the weight and store them away in a tensor.
// Similarly to quantization, this can be done once and cached.
auto col_offsets =
at::zeros_like(quantized).sum({1}).to(at::kInt).contiguous();
calc_col_offsets_transpose(
/*K=*/quantized.size(1),
/*N=*/quantized.size(0),
/*Bint8=*/quantized.data<int8_t>(),
/*B_zero_point=*/q_params.zero_point,
/*col_offsets=*/col_offsets.data<int32_t>());
return std::make_tuple(
quantized, col_offsets, q_params.scale, q_params.zero_point);
}
// Reports whether FBGEMM kernels can run on this machine, as determined by
// fbgemm::fbgemmSupportedCPU(). The other operators in this file refuse to run
// when this is false.
bool fbgemm_is_cpu_supported() {
  const bool cpu_ok = fbgemm::fbgemmSupportedCPU();
  return cpu_ok;
}
// Packs an int8 weight into FBGEMM's PackBMatrix layout and returns a byte
// tensor whose storage owns the resulting fbgemm::PackBMatrix<int8_t> object.
//
// Arguments:
//   weight  int8 tensor holding N rows of K values, row-major. This is what
//           `trans=Transpose, nRow=K, nCol=N, ld=K` below describes. It is
//           made contiguous before packing.
//   K       number of input features.
//   N       number of output features.
// Returns a kByte tensor whose storage data_ptr points at a heap-allocated
// PackBMatrix. The storage's deleter frees it, and fbgemm_linear_int8_weight
// reads it back out of the storage.
Tensor fbgemm_pack_quantized_matrix(
    const Tensor& weight,
    int64_t K,
    int64_t N) {
  // We make a strong guarantee that models using these operators will have the
  // same numerics across different machines. Therefore, we do not provide a
  // fallback path and rather fail loudly if we cannot run FBGEMM.
  AT_CHECK(fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM.");
  // PackBMatrix reads K * N values from the source buffer, so a size mismatch
  // would be an out-of-bounds read rather than a clean error.
  AT_CHECK(
      K >= 0 && N >= 0 && weight.numel() == K * N,
      "fbgemm_pack_quantized_matrix: expected weight with K * N elements (K = ",
      K,
      ", N = ",
      N,
      "), got ",
      weight.numel());

  auto weight_contig = weight.contiguous();
  auto contiguous_ptr = weight_contig.data<int8_t>();
  auto* ptr = new fbgemm::PackBMatrix<int8_t>(
      /*trans=*/fbgemm::matrix_op_t::Transpose,
      /*nRow=*/K,
      /*nCol=*/N,
      /*smat=*/contiguous_ptr,
      /*ld=*/K,
      /*pmat=*/nullptr, // PackBMatrix manages ownership of pmat
      /*groups=*/1);

  // We store this instance away in a Tensor and register a deleter function
  // so that we do not leak memory. On the other side, we pull out the storage's
  // data_ptr and get the PackBMatrix's pointer.
  at::DataPtr at_ptr(
      ptr,
      ptr,
      [](void* raw) {
        // void* -> T* is a static_cast; reinterpret_cast is not needed here.
        delete static_cast<fbgemm::PackBMatrix<int8_t>*>(raw);
      },
      at::kCPU);
  auto retval = at::empty(
      {sizeof(fbgemm::PackBMatrix<int8_t>)}, weight.options().dtype(at::kByte));
  retval.storage().set_data_ptr(std::move(at_ptr));
  return retval;
}
#else // USE_FBGEMM
// Stub used when PyTorch is built without FBGEMM: always throws.
Tensor fbgemm_linear_int8_weight(
    const Tensor& /*input*/,
    const Tensor& /*weight*/,
    const Tensor& /*packed*/,
    const Tensor& /*col_offsets*/,
    Scalar /*weight_scale*/,
    Scalar /*weight_zero_point*/,
    const Tensor& /*bias*/) {
  // We make a strong guarantee that models using these operators will have the
  // same numerics across different machines. Therefore, we do not provide a
  // fallback path and rather fail loudly if we cannot run FBGEMM.
  //
  // AT_ERROR throws unconditionally, so this non-void function visibly never
  // falls off its end. A missing build feature is a user-facing error, not an
  // internal assertion failure, so AT_ASSERTM's "please report a bug" framing
  // does not apply.
  AT_ERROR("This PyTorch installation was not built with FBGEMM operators");
}
// Stub used when PyTorch is built without FBGEMM: always throws.
std::tuple<Tensor, Tensor, double, int64_t> fbgemm_linear_quantize_weight(
    const Tensor& /*weight*/) {
  // We make a strong guarantee that models using these operators will have the
  // same numerics across different machines. Therefore, we do not provide a
  // fallback path and rather fail loudly if we cannot run FBGEMM.
  //
  // AT_ERROR throws unconditionally (no missing-return path), and it reports
  // this as a user-facing error rather than an internal assertion failure.
  AT_ERROR("This PyTorch installation was not built with FBGEMM operators");
}
// Stub used when PyTorch is built without FBGEMM: always throws.
Tensor fbgemm_pack_quantized_matrix(
    const Tensor& /*weight*/,
    int64_t /*K*/,
    int64_t /*N*/) {
  // We make a strong guarantee that models using these operators will have the
  // same numerics across different machines. Therefore, we do not provide a
  // fallback path and rather fail loudly if we cannot run FBGEMM.
  //
  // AT_ERROR throws unconditionally (no missing-return path), and it reports
  // this as a user-facing error rather than an internal assertion failure.
  AT_ERROR("This PyTorch installation was not built with FBGEMM operators");
}
// Without FBGEMM compiled in, no CPU can run the FBGEMM operators, so this
// build always reports "unsupported".
bool fbgemm_is_cpu_supported() {
  constexpr bool kFbgemmAvailable = false;
  return kFbgemmAvailable;
}
#endif // USE_FBGEMM
}
} // namespace at