forked from saprmarks/dictionary_learning
-
Notifications
You must be signed in to change notification settings - Fork 0
/
kernels.py
433 lines (368 loc) · 11.6 KB
/
kernels.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
"""
Copied from https://github.com/openai/sparse_autoencoder/blob/main/sparse_autoencoder/kernels.py
"""
import torch
import triton
import triton.language as tl
def triton_sparse_transpose_dense_matmul(
    sparse_indices: torch.Tensor,
    sparse_values: torch.Tensor,
    dense: torch.Tensor,
    N: int,
    BLOCK_SIZE_AK=128,
) -> torch.Tensor:
    """
    Calculate ``sparse.T @ dense`` (i.e. reduce along the collated dimension of sparse).

    ``sparse`` is an implicit (A, N) matrix stored row-wise as K
    (index, value) pairs per row.

    Args:
        sparse_indices: shape (A, K), column index (< N) of each stored entry;
            must be contiguous.
        sparse_values: shape (A, K), value of each stored entry; must be contiguous.
        dense: shape (A, B); must be contiguous (i.e. contiguous along B).
        N: number of rows of the output (number of columns of ``sparse``).
        BLOCK_SIZE_AK: number of COO entries handled per Triton program.

    Returns:
        Tensor of shape (N, B) with dtype ``sparse_values.dtype``.
    """
    assert sparse_indices.shape == sparse_values.shape
    assert sparse_indices.is_contiguous()
    assert sparse_values.is_contiguous()
    assert dense.is_contiguous()  # contiguous along B

    K = sparse_indices.shape[1]
    A = dense.shape[0]
    B = dense.shape[1]
    assert sparse_indices.shape[0] == A

    if A * K == 0 or B == 0:
        # Nothing to accumulate. Returning early also avoids launching the
        # kernel with triton.next_power_of_2(0) == 0 as a block size.
        return torch.zeros(N, B, device=dense.device, dtype=sparse_values.dtype)

    # COO format, sorted by output row so entries sharing a row are adjacent.
    sorted_indices = sparse_indices.view(-1).sort()
    # Flat position p in the row-major (A, K) layout belongs to row p // K.
    rows = torch.div(sorted_indices.indices, K, rounding_mode="floor")
    coo_indices = torch.stack([rows, sorted_indices.values])  # shape (2, A * K)
    coo_values = sparse_values.view(-1)[sorted_indices.indices]  # shape (A * K,)
    return triton_coo_sparse_dense_matmul(
        coo_indices, coo_values, dense, N, BLOCK_SIZE_AK
    )
def triton_coo_sparse_dense_matmul(
    coo_indices: torch.Tensor,
    coo_values: torch.Tensor,
    dense: torch.Tensor,
    N: int,
    BLOCK_SIZE_AK=128,
) -> torch.Tensor:
    """
    Calculate ``sparse.T @ dense`` for a sparse operand given in COO format.

    Args:
        coo_indices: shape (2, AK). Row 0 holds the row of ``dense`` for each
            entry; row 1 holds the output row (< N). Entries are expected to be
            sorted by output row, as triton_sparse_transpose_dense_matmul
            produces them. The kernel flushes its accumulator whenever the
            output row changes.
        coo_values: shape (AK,), value of each entry.
        dense: shape (A, B).
        N: number of output rows.
        BLOCK_SIZE_AK: number of COO entries handled per Triton program.

    Returns:
        Tensor of shape (N, B) with dtype ``coo_values.dtype``.
    """
    AK = coo_indices.shape[1]
    B = dense.shape[1]

    out = torch.zeros(N, B, device=dense.device, dtype=coo_values.dtype)
    if AK == 0 or B == 0:
        # No entries / no columns: the result is all zeros. Returning early also
        # avoids triton.next_power_of_2(0) == 0 being used as BLOCK_SIZE_B.
        return out

    # One program per BLOCK_SIZE_AK consecutive COO entries; a single column block.
    grid = lambda META: (
        triton.cdiv(AK, META["BLOCK_SIZE_AK"]),
        1,
    )
    triton_sparse_transpose_dense_matmul_kernel[grid](
        coo_indices,
        coo_values,
        dense,
        out,
        stride_da=dense.stride(0),
        stride_db=dense.stride(1),
        B=B,
        N=N,
        AK=AK,
        BLOCK_SIZE_AK=BLOCK_SIZE_AK,
        BLOCK_SIZE_B=triton.next_power_of_2(B),
    )
    return out
@triton.jit
def triton_sparse_transpose_dense_matmul_kernel(
    coo_indices_ptr,
    coo_values_ptr,
    dense_ptr,
    out_ptr,
    stride_da,
    stride_db,
    B,
    N,
    AK,
    BLOCK_SIZE_AK: tl.constexpr,
    BLOCK_SIZE_B: tl.constexpr,
):
    """
    Accumulate ``out[k, :] += v * dense[a, :]`` for every COO entry (a, k, v).

    coo_indices is shape (2, AK): row 0 = a (row of dense), row 1 = k (row of out)
    coo_values is shape (AK,)
    dense is shape (A, B), contiguous along B
    out is shape (N, B)

    Each program handles BLOCK_SIZE_AK consecutive COO entries. It sums
    contributions into a float32 accumulator and flushes it into ``out`` with
    tl.atomic_add whenever k changes, and once more at the end. The adds are
    atomic because a run of equal k may span several programs.
    """
    pid_ak = tl.program_id(0)
    pid_b = tl.program_id(1)

    coo_offsets = tl.arange(0, BLOCK_SIZE_AK)
    b_offsets = tl.arange(0, BLOCK_SIZE_B)

    # COO positions covered by this program, and which of them exist.
    ak_idx = pid_ak * BLOCK_SIZE_AK + coo_offsets
    ak_mask = ak_idx < AK
    # Columns of dense / out covered by this program (same indexing for load and store).
    b_idx = BLOCK_SIZE_B * pid_b + b_offsets
    b_mask = b_idx < B

    A_coords = tl.load(coo_indices_ptr + ak_idx, mask=ak_mask)
    # Masked lanes are filled with N (greater than any valid k) so that tl.min
    # below only sees real indices; without `other` those lanes hold no
    # guaranteed value.
    K_coords = tl.load(coo_indices_ptr + ak_idx + AK, mask=ak_mask, other=N)
    values = tl.load(coo_values_ptr + ak_idx, mask=ak_mask)

    last_k = tl.min(K_coords)
    accum = tl.zeros((BLOCK_SIZE_B,), dtype=tl.float32)

    for ind in range(BLOCK_SIZE_AK):
        if ind + pid_ak * BLOCK_SIZE_AK < AK:
            # workaround to do A_coords[ind], K_coords[ind], values[ind]
            a = tl.sum(
                tl.where(
                    coo_offsets == ind,
                    A_coords,
                    tl.zeros((BLOCK_SIZE_AK,), dtype=tl.int64),
                )
            )
            k = tl.sum(
                tl.where(
                    coo_offsets == ind,
                    K_coords,
                    tl.zeros((BLOCK_SIZE_AK,), dtype=tl.int64),
                )
            )
            v = tl.sum(
                tl.where(
                    coo_offsets == ind,
                    values,
                    tl.zeros((BLOCK_SIZE_AK,), dtype=tl.float32),
                )
            )
            tl.device_assert(k < N)
            if k != last_k:
                # Output row changed: flush the finished row and start a fresh one.
                tl.atomic_add(out_ptr + last_k * B + b_idx, accum, mask=b_mask)
                # Re-zero explicitly: `accum *= 0` would keep NaN/Inf lanes alive.
                accum = tl.zeros((BLOCK_SIZE_B,), dtype=tl.float32)
                last_k = k
            if v != 0:
                accum += v * tl.load(
                    dense_ptr + a * stride_da + b_idx * stride_db, mask=b_mask
                )

    # Flush the last (possibly partial) row handled by this program.
    tl.atomic_add(out_ptr + last_k * B + b_idx, accum, mask=b_mask)
def triton_sparse_dense_matmul(
    sparse_indices: torch.Tensor,
    sparse_values: torch.Tensor,
    dense: torch.Tensor,
) -> torch.Tensor:
    """
    Calculate ``sparse @ dense`` (i.e. reduce along the uncollated dimension of sparse).

    ``sparse`` is an implicit (A, N) matrix stored row-wise as K
    (index, value) pairs per row.

    Args:
        sparse_indices: shape (A, K), row indices into ``dense`` (each < N);
            must be contiguous.
        sparse_values: shape (A, K), matching values; must be contiguous.
        dense: shape (N, B); must be contiguous (i.e. contiguous along B).

    Returns:
        Tensor of shape (A, B) with dtype ``sparse_values.dtype``.
    """
    N = dense.shape[0]
    assert sparse_indices.shape == sparse_values.shape
    assert sparse_indices.is_contiguous()
    assert sparse_values.is_contiguous()
    assert dense.is_contiguous()  # contiguous along B

    A = sparse_indices.shape[0]
    K = sparse_indices.shape[1]
    B = dense.shape[1]

    out = torch.zeros(A, B, device=dense.device, dtype=sparse_values.dtype)
    if K == 0 or out.numel() == 0:
        # Every output entry is an empty sum (or there are none). Returning early
        # also avoids triton.next_power_of_2(0) == 0 being used as a block size.
        return out

    # Make dense's device current for the kernel launch.
    with torch.cuda.device(dense.device.index):
        # One program per row of sparse / out.
        triton_sparse_dense_matmul_kernel[(A,)](
            sparse_indices,
            sparse_values,
            dense,
            out,
            stride_dn=dense.stride(0),
            stride_db=dense.stride(1),
            A=A,
            B=B,
            N=N,
            K=K,
            BLOCK_SIZE_K=triton.next_power_of_2(K),
            BLOCK_SIZE_B=triton.next_power_of_2(B),
        )
    return out
@triton.jit
def triton_sparse_dense_matmul_kernel(
    sparse_indices_ptr,
    sparse_values_ptr,
    dense_ptr,
    out_ptr,
    stride_dn,
    stride_db,
    A,
    B,
    N,
    K,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_B: tl.constexpr,
):
    """
    Compute one row of ``out = sparse @ dense``.

    sparse_indices is shape (A, K), row-major
    sparse_values is shape (A, K), row-major
    dense is shape (N, B), contiguous along B
    out is shape (A, B)

    Program ``pid`` reads row ``pid`` of sparse_indices / sparse_values and
    writes row ``pid`` of out: sum over j of values[pid, j] * dense[indices[pid, j], :].
    Accumulation is done in float32 and cast to the values' dtype on store.
    """
    pid = tl.program_id(0)
    offsets_k = tl.arange(0, BLOCK_SIZE_K)
    # Lanes >= K are masked; their contents are never extracted by the loop below.
    sparse_indices = tl.load(
        sparse_indices_ptr + pid * K + offsets_k, mask=offsets_k < K
    ) # shape (BLOCK_SIZE_K,), first K lanes valid
    sparse_values = tl.load(
        sparse_values_ptr + pid * K + offsets_k, mask=offsets_k < K
    ) # shape (BLOCK_SIZE_K,), first K lanes valid
    accum = tl.zeros((BLOCK_SIZE_B,), dtype=tl.float32)
    offsets_b = tl.arange(0, BLOCK_SIZE_B)
    for k in range(K):
        # workaround to do sparse_indices[k]: select lane k, zero the rest, sum
        i = tl.sum(
            tl.where(
                tl.arange(0, BLOCK_SIZE_K) == k,
                sparse_indices,
                tl.zeros((BLOCK_SIZE_K,), dtype=tl.int64),
            )
        )
        # workaround to do sparse_values[k]
        v = tl.sum(
            tl.where(
                tl.arange(0, BLOCK_SIZE_K) == k,
                sparse_values,
                tl.zeros((BLOCK_SIZE_K,), dtype=tl.float32),
            )
        )
        tl.device_assert(i < N)
        # Skip the dense row load entirely for zero values.
        if v != 0:
            # Lanes >= B are masked; they are also masked out on the final store.
            accum += v * tl.load(
                dense_ptr + i * stride_dn + offsets_b * stride_db, mask=offsets_b < B
            )
    tl.store(
        out_ptr + pid * B + offsets_b, accum.to(sparse_values.dtype), mask=offsets_b < B
    )
def triton_dense_dense_sparseout_matmul(
    dense1: torch.Tensor,
    dense2: torch.Tensor,
    at_indices: torch.Tensor,
) -> torch.Tensor:
    """
    Calculate ``dense1 @ dense2`` only at the columns listed in ``at_indices``.

    Equivalent to ``(dense1 @ dense2).gather(1, at_indices)``.

    Args:
        dense1: shape (A, B); must have stride 1 along B.
        dense2: shape (B, N); must have stride 1 along B.
        at_indices: shape (A, K), column indices (< N) per row; must be contiguous.

    Returns:
        Tensor of shape (A, K).
    """
    A, B = dense1.shape
    N = dense2.shape[1]
    assert dense2.shape[0] == B
    assert at_indices.shape[0] == A
    K = at_indices.shape[1]
    assert at_indices.is_contiguous()
    assert dense1.stride(1) == 1, "dense1 must be contiguous along B"
    assert dense2.stride(0) == 1, "dense2 must be contiguous along B"

    if A == 0 or K == 0 or B == 0:
        # No outputs, or every output is an empty dot product. Returning early
        # also avoids triton.next_power_of_2(0) == 0 being used as a block size.
        return torch.zeros(A, K, device=dense1.device, dtype=dense1.dtype)

    if K > 512:
        # naive is more efficient for large K
        return (dense1 @ dense2).gather(1, at_indices)

    out = torch.zeros(A, K, device=dense1.device, dtype=dense1.dtype)
    # One program per row of dense1 / out.
    triton_dense_dense_sparseout_matmul_kernel[(A,)](
        dense1,
        dense2,
        at_indices,
        out,
        stride_d1a=dense1.stride(0),
        stride_d1b=dense1.stride(1),
        stride_d2b=dense2.stride(0),
        stride_d2n=dense2.stride(1),
        A=A,
        B=B,
        N=N,
        K=K,
        BLOCK_SIZE_B=triton.next_power_of_2(B),
        BLOCK_SIZE_N=triton.next_power_of_2(N),
        BLOCK_SIZE_K=triton.next_power_of_2(K),
    )
    return out
@triton.jit
def triton_dense_dense_sparseout_matmul_kernel(
    dense1_ptr,
    dense2_ptr,
    at_indices_ptr,
    out_ptr,
    stride_d1a,
    stride_d1b,
    stride_d2b,
    stride_d2n,
    A,
    B,
    N,
    K,
    BLOCK_SIZE_B: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    """
    Compute one row of ``(dense1 @ dense2).gather(1, at_indices)``.

    dense1: shape (A, B)
    dense2: shape (B, N)
    at_indices: shape (A, K), row-major
    out values: shape (A, K)

    Program ``pid`` computes out[pid, j] = dot(dense1[pid, :], dense2[:, at_indices[pid, j]])
    for j < K, accumulating in float32.
    """
    pid = tl.program_id(0)
    offsets_k = tl.arange(0, BLOCK_SIZE_K)
    # Lanes >= K are masked; they are never extracted by the loop below.
    at_indices = tl.load(
        at_indices_ptr + pid * K + offsets_k, mask=offsets_k < K
    )  # shape (BLOCK_SIZE_K,), first K lanes valid
    offsets_b = tl.arange(0, BLOCK_SIZE_B)
    b_mask = offsets_b < B
    # Padding lanes (>= B) are reduced by tl.sum below, so they must be zero.
    dense1 = tl.load(
        dense1_ptr + pid * stride_d1a + offsets_b * stride_d1b, mask=b_mask, other=0.0
    )  # shape (BLOCK_SIZE_B,)
    accum = tl.zeros((BLOCK_SIZE_K,), dtype=tl.float32)
    for k in range(K):
        # workaround to do at_indices[k]: select lane k, zero the rest, sum
        i = tl.sum(
            tl.where(
                offsets_k == k,
                at_indices,
                tl.zeros((BLOCK_SIZE_K,), dtype=tl.int64),
            )
        )
        tl.device_assert(i < N)
        dense2col = tl.load(
            dense2_ptr + offsets_b * stride_d2b + i * stride_d2n, mask=b_mask, other=0.0
        )  # shape (BLOCK_SIZE_B,)
        # Write the dot product into lane k of the accumulator.
        accum += tl.where(
            offsets_k == k,
            tl.sum(dense1 * dense2col),
            tl.zeros((BLOCK_SIZE_K,), dtype=tl.float32),
        )
    tl.store(out_ptr + pid * K + offsets_k, accum, mask=offsets_k < K)
class TritonDecoder(torch.autograd.Function):
    """
    Autograd function computing ``sparse @ decoder_weight.T`` with Triton kernels.

    The sparse operand has shape (A, N) and is given row-wise by
    ``sparse_indices`` / ``sparse_values``, each of shape (A, K).
    ``decoder_weight`` has shape (B, N), and ``decoder_weight.T`` must be
    contiguous. The output has shape (A, B).
    """

    @staticmethod
    def forward(ctx, sparse_indices, sparse_values, decoder_weight):
        ctx.save_for_backward(sparse_indices, sparse_values, decoder_weight)
        return triton_sparse_dense_matmul(
            sparse_indices, sparse_values, decoder_weight.T
        )

    @staticmethod
    def backward(ctx, grad_output):
        """
        Return gradients for (sparse_indices, sparse_values, decoder_weight).

        sparse_indices never receives a gradient. The other two gradients are
        computed only when autograd requests them.
        """
        sparse_indices, sparse_values, decoder_weight = ctx.saved_tensors
        assert (
            grad_output.is_contiguous()
        ), "grad_output must be contiguous; this is probably because the subsequent op was a .sum() or something like that, which returns a non contiguous gradient"
        _, needs_values_grad, needs_decoder_grad = ctx.needs_input_grad

        decoder_grad = None
        if needs_decoder_grad:
            # (N, B) result transposed to (B, N); decoder is contiguous when
            # transposed so this is a matching layout.
            decoder_grad = triton_sparse_transpose_dense_matmul(
                sparse_indices, sparse_values, grad_output, N=decoder_weight.shape[1]
            ).T

        values_grad = None
        if needs_values_grad:
            # d out[a, :] / d values[a, j] = decoder_weight[:, indices[a, j]], so the
            # gradient is (grad_output @ decoder_weight) gathered at sparse_indices.
            values_grad = triton_dense_dense_sparseout_matmul(
                grad_output, decoder_weight, sparse_indices
            )

        # Exactly one gradient per forward input.
        return None, values_grad, decoder_grad