forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
SparseBlasImpl.cpp
289 lines (256 loc) · 10.2 KB
/
SparseBlasImpl.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Config.h>
#include <ATen/native/mkl/SparseBlasImpl.h>
#include <ATen/native/sparse/SparseBlasImpl.h>
#include <ATen/SparseCsrTensorUtils.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#include <ATen/Operators.h>
#else
#include <ATen/ops/_convert_indices_from_csr_to_coo.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/zeros.h>
#endif
namespace at {
namespace native {
namespace sparse {
namespace impl {
/*
  Computes the matrix product of a 2D sparse compressed-row tensor (CSR or
  BSR) with a (possibly batched) strided tensor:
    result <- compressed @ strided
  Args:
  * `compressed` - 2D sparse tensor of size m x k with layout kSparseCsr or
                   kSparseBsr.
  * `strided`    - tensor of size (..., k, n), at least 2D, with the same
                   dtype as `compressed`.
  * `result`     - [in] strided tensor that must already have size (..., m, n).
                   [out] the product; previous contents are overwritten.
  All three tensors must be on the same device. Half and BFloat16 results are
  accumulated in float and then copied into `result`.
  Returns `result`. Violated preconditions raise via TORCH_CHECK.
*/
Tensor& _compressed_row_strided_mm_out(const Tensor& compressed, const Tensor& strided, Tensor& result) {
  const auto compressed_layout = compressed.layout();
  const auto compressed_layout_str = at::sparse_csr::layoutToString(compressed_layout);

  // Device restrictions
  TORCH_CHECK(compressed.device() == strided.device()
      && compressed.device() == result.device(),
      "spmm_out(): all input arguments are expected to be on the same device.");

  // Layout restrictions.
  TORCH_CHECK(compressed_layout == kSparseCsr || compressed_layout == kSparseBsr,
      "spmm(", compressed_layout_str, ", Strided): only Csr and Bsr formats are supported for the sparse argument.");
  TORCH_CHECK(result.layout() == kStrided,
      "spmm_out(): out argument is expected to be strided.");

  // Dtype restrictions.
  TORCH_CHECK(compressed.scalar_type() == strided.scalar_type(),
      "spmm(", compressed_layout_str, ", Strided): arguments expected to have the same dtype.");

  // Dim restrictions.
  TORCH_CHECK(compressed.dim() == 2,
      "spmm(", compressed_layout_str, ", Strided): sparse arguments which are not 2D are not supported.");
  TORCH_CHECK(strided.dim() >= 2,
      "spmm(", compressed_layout_str, ", Strided): expects strided inputs to be at least 2D.");

  const auto m = compressed.sizes()[0];
  const auto k = compressed.sizes()[1];
  const auto n = strided.size(-1);
  // Matrix product size compatibility.
  TORCH_CHECK(strided.size(-2) == k,
      "spmm(", compressed_layout_str, ", Strided): argument sizes are not compatible for matrix multiplication. ",
      "Got ", compressed_layout_str, ".sizes(-1) == ", k, " is not equal to ",
      "Strided.sizes(-2) == ", strided.size(-2), ".");

  // We assume that result is properly resized.
  auto result_expected_size = at::DimVector(strided.sizes().slice(0, strided.dim() - 2));
  result_expected_size.push_back(m);
  result_expected_size.push_back(n);
  TORCH_CHECK(result.sizes() == result_expected_size,
      "spmm_out(): out argument has wrong size. ",
      "Expected (", result_expected_size, ") but got (", result.sizes(), ").");

  // An empty output has nothing to compute. Returning here is also required
  // for correctness: for BSR inputs with n == 0 the tile width used below
  // (outer_blocksize == n) would be 0 and tile_tensor would divide by it.
  if (result.numel() == 0) {
    return result;
  }

  auto values = compressed.values();

  using Blocksize = std::array<int64_t, 2>;
  // We refer to these as (b0, b1) in the comments below.
  Blocksize blocksize = {1, 1};
  if (compressed_layout == kSparseBsr) {
    blocksize = {values.size(-2), values.size(-1)};
  }

  // (..., r, c) -> (..., r / b0, c / b1, b0, b1)
  // NOTE: this function ALWAYS creates a view upon successful execution.
  // The reshape below only splits each of the last two dims into two, which
  // is always expressible with strides, so writes through the returned
  // tensor land in `t` (this is relied upon for `result_tiled`).
  const auto tile_tensor = [compressed_layout](
      const Tensor& t, Blocksize blocksize) -> Tensor {
    if (compressed_layout == kSparseCsr) {
      return t.unsqueeze(-1).unsqueeze_(-1);
    }
    else {
      const auto size_neg_2_blocked = t.size(-2) / blocksize[0];
      const auto size_neg_1_blocked = t.size(-1) / blocksize[1];
      auto tiled_sizes = at::DimVector(t.sizes().slice(0, t.dim() - 2));
      tiled_sizes.push_back(size_neg_2_blocked);
      tiled_sizes.push_back(blocksize[0]);
      tiled_sizes.push_back(size_neg_1_blocked);
      tiled_sizes.push_back(blocksize[1]);
      return t.reshape(tiled_sizes).transpose(-3, -2);
    }
  };

  // Note that sparse values are (..., b0, b1). This means that
  // the strided input has to be "tilable" to (..., b1, x) with
  // any x >= 1 such that all the shapes are (block) matrix product
  // compatible. The matrix product will then have shape (..., b0, x).
  // This in turn means the result has to be "tilable" to
  // (..., b0, x).
  //
  // These observations imply the following restrictions:
  // 1. strided.size(-2) has to be divisible by b1.
  // 2. result.size(-2) has to be divisible by b0.
  // 3. both strided.size(-1) and result.size(-1)
  //    have to be divisible by x.
  //
  // Restrictions 1 and 2 are trivially satisfied.
  // Regarding restriction 3:
  // it would make sense to take the largest possible x for better
  // performance since it is very likely that the last dimension
  // is contiguous. As such, this value is exactly
  // x = strided.size(-1), since strided.size(-1) == result.size(-1)
  // (checked above). This is our x.
  const auto outer_blocksize = n;

  Blocksize strided_blocksize = {blocksize[1], outer_blocksize};
  const auto strided_tiled = tile_tensor(strided, strided_blocksize);

  // Left argument is (..., b0, b1) and right is (..., b1, x).
  // This naturally implies the result should be "tilable" as
  // (..., b0, x).
  Blocksize result_blocksize = {blocksize[0], outer_blocksize};
  auto result_tiled = tile_tensor(result, result_blocksize);

  if (compressed_layout == kSparseCsr) {
    values.unsqueeze_(-1).unsqueeze_(-1);
  }

  Tensor compressed_indices, plain_indices;
  std::tie(compressed_indices, plain_indices) = at::sparse_csr::getCompressedPlainIndices(compressed);

  // Select block rows of the strided input that intersect with the block columns of the sparse input.
  auto strided_tiled_selected_rows = strided_tiled.index_select(-4, plain_indices);

  // Promote to float if output is half or bfloat16 for better precision
  const auto mm_dtype = (result.scalar_type() == kHalf || result.scalar_type() == kBFloat16)
    ? kFloat : result.scalar_type();
  // Now that we know which block rows intersect with which block columns,
  // we can perform matrix products between pairs of blocks.
  // NOTE: .to is a no-op when result.scalar_type() == mm_dtype.
  const auto pairwise_block_mm = values.unsqueeze(-3).to(mm_dtype)
    .matmul(strided_tiled_selected_rows.to(mm_dtype));

  // Having pairwise block matrix products stored in pairwise_block_mm,
  // it is sufficient to sum all the block products that share the same row
  // encoded in the sparse index. Since the reduction step is done via
  // advanced indexing methods, the compressed index ought to get converted
  // to the COO format.
  const auto compressed_indices_coo = at::_convert_indices_from_csr_to_coo(
      compressed_indices,
      plain_indices,
      compressed_indices.scalar_type() == kInt).select(0, 0);

  // Reduction step.
  // If result is neither half nor bfloat16, do everything in-place.
  if (result.scalar_type() == mm_dtype) {
    // Zero out and sum over the blocks that share the same row indices.
    result_tiled.zero_();
    result_tiled.index_add_(
        /*dim=*/-4,
        /*index=*/compressed_indices_coo,
        /*source=*/pairwise_block_mm);
  }
  // Otherwise accumulate into a buffer and then copy.
  else {
    // No need to zero out, sum over the blocks goes into a buffer
    // followed by a copy into result.
    auto promoted_result_tiled = at::zeros(
        result_tiled.sizes(),
        result_tiled.options().dtype(mm_dtype));
    promoted_result_tiled.index_add_(
        /*dim=*/-4,
        /*index=*/compressed_indices_coo,
        /*source=*/pairwise_block_mm);
    result_tiled.copy_(promoted_result_tiled);
  }

  return result;
}
/*
  Computes
    result <- beta * self + alpha * (mat1 @ mat2)
  where the product is evaluated by _compressed_row_strided_mm_out
  (`mat1` sparse CSR/BSR, `mat2` strided).
  Args:
  * `self`   - tensor scaled by `beta` and added to the product. It is not
               read when beta == 0.
  * `mat1`   - sparse compressed-row matrix.
  * `mat2`   - strided matrix.
  * `beta`, `alpha` - scalar factors.
  * `result` - [out] receives the result. It may be the very same tensor as
               `self`; it must already have the product's size.
  Returns `result`.
*/
Tensor& _compressed_row_strided_addmm_out(
    const Tensor& self,
    const Tensor& mat1,
    const Tensor& mat2,
    const Scalar& beta,
    const Scalar& alpha,
    Tensor& result) {
  const auto alpha_val = alpha.toComplexDouble();
  const auto beta_val = beta.toComplexDouble();
  // Multiplying by exactly 1 is a no-op, so those elementwise passes are skipped.
  const bool scale_by_alpha = alpha_val != 1.;

  // If result is not the same as self, it could always be used as out argument to mm.
  if (!result.is_same(self)) {
    _compressed_row_strided_mm_out(mat1, mat2, result);
    if (scale_by_alpha) {
      result.mul_(alpha);
    }
    // Process beta; beta == 0 means self does not contribute at all.
    if (beta_val != 0.) {
      if (beta_val == 1.) {
        result.add_(self);
      } else {
        result.add_(self.mul(beta));
      }
    }
  }
  // result is self, but self does not contribute: mm can write into it directly.
  else if (beta_val == 0.) {
    _compressed_row_strided_mm_out(mat1, mat2, result);
    if (scale_by_alpha) {
      result.mul_(alpha);
    }
  }
  // result is self and self contributes: mm needs its own buffer.
  else {
    // The product is computed BEFORE result (== self) is rescaled, so that
    // inputs sharing memory with self are read with their original values.
    auto mm = at::empty_like(result);
    _compressed_row_strided_mm_out(mat1, mat2, mm);
    if (scale_by_alpha) {
      mm.mul_(alpha);
    }
    if (beta_val != 1.) {
      result.mul_(beta);
    }
    result.add_(mm);
  }

  return result;
}
namespace cpu {
/*
Computes a sparse matrix-dense vector product defined as
y <- alpha*op(A)*x + beta*y
Args:
* `mat` - Tensor storing sparse m x n matrix A.
* `vec` - Tensor storing dense vector x of size n.
* `result` - [in] Tensor storing dense vector y of size m.
[out] result of the operation.
*/
void addmv_out_sparse_csr(
const Tensor& mat,
const Tensor& vec,
const Scalar& beta,
const Scalar& alpha,
const Tensor& result) {
#if !AT_MKL_ENABLED()
TORCH_CHECK(
false,
"Calling addmv on a sparse CPU tensor requires compiling PyTorch with MKL. ",
"Please use PyTorch built MKL support.");
#else
sparse::impl::mkl::addmv_out_sparse_csr(mat, vec, beta, alpha, result);
#endif
}
/*
Computes a sum of two sparse matrices defined as
result <- mat1 + alpha*mat2
Args:
* `mat1` - CSR Tensor storing sparse m x n matrix.
* `mat2` - CSR Tensor storing sparse m x n matrix.
* `result` - [in] CSR Tensor storing sparse m x n matrix.
[out] result of the operation.
*/
void add_out_sparse_csr(
const Tensor& mat1,
const Tensor& mat2,
const Scalar& alpha,
const Tensor& result) {
#if !AT_MKL_ENABLED()
TORCH_CHECK(
false,
"Calling add on a sparse CPU tensor requires compiling PyTorch with MKL. ",
"Please use PyTorch built MKL support.");
#else
sparse::impl::mkl::add_out_sparse_csr(mat1, mat2, alpha, result);
#endif
}
void triangular_solve_out_sparse_csr(
const Tensor& A,
const Tensor& B,
const Tensor& X,
bool upper,
bool transpose,
bool unitriangular) {
#if !AT_MKL_ENABLED()
TORCH_CHECK(
false,
"Calling triangular_solve on a sparse CPU tensor requires compiling PyTorch with MKL. ",
"Please use PyTorch built MKL support.");
#else
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(A.layout() == kSparseCsr || A.layout() == kSparseBsr);
sparse::impl::mkl::triangular_solve_out_sparse_csr(A, B, X, upper, transpose, unitriangular);
#endif
}
} // namespace cpu
} // namespace impl
} // namespace sparse
} // namespace native
} // namespace at