AfterSessionExercises.txt

(1) Consider the function f(A,B) = jax.nn.relu(A@B). (Assume A and B are square matrices.)
What percentage faster will jit(f) be than f? Does it depend on the size of A and B?
(2) Assume A is a [SEQUENCE, KV_HEAD] matrix, B is a [KV_HEAD, SEQUENCE] matrix, and C is a [KV_HEAD, SEQUENCE] matrix.
So (A@B) is [SEQUENCE, SEQUENCE] and C @ jax.nn.relu(A @ B) is a [KV_HEAD, SEQUENCE] matrix.
Assume KV_HEAD is 128 and SEQUENCE is 10000. How many bytes are the input and the output? How many FLOPs?
Actually running the code, can you in practice get near the theoretical TFLOP limit in JAX? (A timing sketch follows below.)
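
Here is a minimal timing sketch for trying question (1). It is not part of the original
exercise, and N, the dtype (bf16), and the iteration count are assumptions. It warms up
jit(f) so compile time is excluded, and calls block_until_ready() because JAX dispatches
work asynchronously.

    import time
    import jax
    import jax.numpy as jnp

    def f(A, B):
        return jax.nn.relu(A @ B)

    N = 4096  # try several sizes to answer the "does it depend on size" part
    kA, kB = jax.random.split(jax.random.PRNGKey(0))
    A = jax.random.normal(kA, (N, N), dtype=jnp.bfloat16)
    B = jax.random.normal(kB, (N, N), dtype=jnp.bfloat16)

    f_jit = jax.jit(f)
    f_jit(A, B).block_until_ready()  # warm up so compilation is not measured

    def bench(fn, iters=10):
        start = time.perf_counter()
        for _ in range(iters):
            out = fn(A, B)
        out.block_until_ready()  # wait for the async work to actually finish
        return (time.perf_counter() - start) / iters

    print("eager f: ", bench(f))
    print("jit(f):  ", bench(f_jit))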
##################################
(1 - solution) The JIT will allow the relu to be fused into A@B. If the matrix dimension is N, the ideal HBM traffic might be
2 bytes per param (bf16) * (2*N^2) input params + 2 bytes per param * (N^2) output params = 6 * N^2 bytes.
Without the fusion, there will be an additional round trip for the intermediate A@B (an extra 4 * N^2 bytes: 2 * N^2 written, 2 * N^2 read back).
So if we're memory bound, the time for f will be proportional to 10 * N^2 bytes and the time for jit(f) will be proportional
to 6 * N^2 bytes, so jit(f) will take 40% less time. This does depend on size: as N grows, the matmul's O(N^3) FLOPs
come to dominate its O(N^2) bytes, the kernel becomes compute bound, and the relative saving from fusing the relu shrinks.
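
One way to check the fusion claim is to read the optimized HLO that XLA produces. A sketch,
assuming a JAX version with the ahead-of-time lower/compile API:

    import jax
    import jax.numpy as jnp

    def f(A, B):
        return jax.nn.relu(A @ B)

    A = jnp.zeros((1024, 1024), jnp.bfloat16)
    B = jnp.zeros((1024, 1024), jnp.bfloat16)

    # Optimized HLO after XLA's fusion passes; look for a fusion op that
    # contains both the dot (matmul) and the maximum (relu).
    print(jax.jit(f).lower(A, B).compile().as_text())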
(2 - solution) The input is 3 matrices, each of 2 * KV_HEAD * SEQUENCE bytes. The output is one of 2 * KV_HEAD * SEQUENCE bytes.
So the total input and output bandwidth use is 8 * KV_HEAD * SEQUENCE bytes = 8 * 128 * 10000 = 10.24 MB.
The two matrix multiplies are each 2 * KV_HEAD * SEQUENCE^2 FLOPs, so the total is 4 * KV_HEAD * SEQUENCE^2 = 51.2 GFLOPs.
So in theory our arithmetic intensity will be (4 * KV_HEAD * SEQUENCE^2)/(8 * KV_HEAD * SEQUENCE) = SEQUENCE/2 = 5000 FLOPs/byte.
In practice, this won't happen! A@B is [SEQUENCE, SEQUENCE] and I expect the XLA compiler will not find a strategy that avoids
writing it back to HBM! And if we need to write it to HBM and read it back, that adds 2*2*SEQUENCE^2 bytes of HBM traffic, much more
than we were using before! Now our bandwidth use is 4*SEQUENCE^2 + 8 * KV_HEAD * SEQUENCE, which is dominated by the 4*SEQUENCE^2 term.
And our arithmetic intensity is (4 * KV_HEAD * SEQUENCE^2) / (4 * SEQUENCE^2) = KV_HEAD = 128, which is less than the arithmetic intensity
of the reference hardware in class (v4 TPU: roughly 275e12 bf16 FLOP/s / 1.2e12 B/s, about 230 FLOPs/byte), so we stay memory bound.
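
A rough sketch for measuring the achieved FLOP/s empirically (not from the original
solution; PEAK is an assumed per-chip number, about 275e12 bf16 FLOP/s for a v4 chip,
so substitute your accelerator's):

    import time
    import jax
    import jax.numpy as jnp

    KV_HEAD, SEQUENCE = 128, 10000
    PEAK = 275e12  # assumed peak FLOP/s; adjust for your hardware

    ka, kb, kc = jax.random.split(jax.random.PRNGKey(0), 3)
    A = jax.random.normal(ka, (SEQUENCE, KV_HEAD), dtype=jnp.bfloat16)
    B = jax.random.normal(kb, (KV_HEAD, SEQUENCE), dtype=jnp.bfloat16)
    C = jax.random.normal(kc, (KV_HEAD, SEQUENCE), dtype=jnp.bfloat16)

    @jax.jit
    def g(A, B, C):
        return C @ jax.nn.relu(A @ B)

    g(A, B, C).block_until_ready()  # compile before timing

    iters = 10
    start = time.perf_counter()
    for _ in range(iters):
        out = g(A, B, C)
    out.block_until_ready()
    dt = (time.perf_counter() - start) / iters

    flops = 4 * KV_HEAD * SEQUENCE**2  # two matmuls, 2*KV_HEAD*SEQUENCE^2 each
    print(f"achieved: {flops / dt / 1e12:.1f} TFLOP/s "
          f"({100 * flops / dt / PEAK:.1f}% of assumed peak)")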
P.S. -- solving this problem in a slightly different setting is what FlashAttention does! Also see what happens if you remove the relu!
Without a JIT, is `C @ (A @ B)` or `(C @ A) @ B` faster? With a JIT, are you still memory bound? A sketch for trying this follows.
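
A sketch for the P.S. question, with the relu removed so the two orders compute the same
result (shapes as in exercise (2); which one wins is left for you to discover):

    import time
    import jax
    import jax.numpy as jnp

    KV_HEAD, SEQUENCE = 128, 10000
    ka, kb, kc = jax.random.split(jax.random.PRNGKey(0), 3)
    A = jax.random.normal(ka, (SEQUENCE, KV_HEAD), dtype=jnp.bfloat16)
    B = jax.random.normal(kb, (KV_HEAD, SEQUENCE), dtype=jnp.bfloat16)
    C = jax.random.normal(kc, (KV_HEAD, SEQUENCE), dtype=jnp.bfloat16)

    def right(A, B, C):
        return C @ (A @ B)   # [SEQUENCE, SEQUENCE] intermediate
    def left(A, B, C):
        return (C @ A) @ B   # [KV_HEAD, KV_HEAD] intermediate, tiny

    def bench(fn, iters=10):
        fn(A, B, C).block_until_ready()  # warm up (and compile, if jitted)
        start = time.perf_counter()
        for _ in range(iters):
            out = fn(A, B, C)
        out.block_until_ready()
        return (time.perf_counter() - start) / iters

    for name, fn in [("C@(A@B)", right), ("(C@A)@B", left)]:
        print(name, "eager:", bench(fn), "jit:", bench(jax.jit(fn)))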