Adding Bits
erfanzar committed Nov 22, 2023
1 parent 8748109 commit 561e59a
Showing 2 changed files with 125 additions and 0 deletions.
1 change: 1 addition & 0 deletions fjformer/bits/__init__.py
@@ -0,0 +1 @@
from .bits import matmul_true_int8, matmul, aqt_matmul_int8
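
The package re-exports the three entry points, so downstream code can import
them directly from fjformer.bits. A minimal import sketch (assuming fjformer
is installed; the shapes below are illustrative, not part of the commit):

import jax.numpy as jnp
from fjformer.bits import matmul_true_int8, matmul, aqt_matmul_int8

a = jnp.ones((8, 16), dtype=jnp.float32)
w = jnp.ones((16, 4), dtype=jnp.float32)
out = aqt_matmul_int8(a, w)  # simulated int8 matmul, float32 result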
124 changes: 124 additions & 0 deletions fjformer/bits/bits.py
@@ -0,0 +1,124 @@
import builtins

import jax
from jax import core
from jax import lax
import jax.numpy as jnp
import numpy as np

# Convenience aliases: np.shape/np.ndim handle concrete (static) shapes, and
# builtins.max keeps the Python scalar max distinct from array reductions.
shape = np.shape
ndim = np.ndim
_max = builtins.max


def matmul(
a: jnp.ndarray,
b: jnp.ndarray,
*,
precision=None,
dot_general=lax.dot_general
) -> jnp.ndarray:
"""Quantized jax.numpy.matmul.
Args:
a: Left-hand side of the matmul.
b: Right-hand side of the matmul.
precision: Indicates precision of a and b.
dot_general: lax.dot_general by default. To use quantized matmul, the
wrapper of aqt_dot_general in which TQs and `train` flag are provided
should be passed into this function.
Returns:
An array containing the result with the same dtype as 'a' and 'b'.
"""
arraylike = (jax.Array, np.ndarray)
if not isinstance(a, arraylike) or not isinstance(b, arraylike):
raise TypeError(f"matmul requires array-like arguments, got {a} and {b}")
for i, x in enumerate((a, b)):
if ndim(x) < 1:
msg = (f"matmul input operand {i} must have ndim at least 1, "
f"but it has ndim {ndim(x)}")
raise ValueError(msg)

dtype = jnp.result_type(a.dtype, b.dtype)
a = a.astype(dtype)
b = b.astype(dtype)

a_is_mat, b_is_mat = (ndim(a) > 1), (ndim(b) > 1)
a_batch_dims = shape(a)[:-2] if a_is_mat else ()
b_batch_dims = shape(b)[:-2] if b_is_mat else ()
num_batch_dims = _max(len(a_batch_dims), len(b_batch_dims))
a_batch_dims = (None,) * (num_batch_dims - len(a_batch_dims)) + a_batch_dims
b_batch_dims = (None,) * (num_batch_dims - len(b_batch_dims)) + b_batch_dims

# Dimensions to squeeze from the inputs.
a_squeeze = []
b_squeeze = []

# Positions of batch dimensions in squeezed inputs.
a_batch = []
b_batch = []

# Desired index in final output of each kind of dimension, in the order that
# aqt_dot_general will emit them.
idx_batch = []
idx_a_other = [] # other = non-batch, non-contracting.
idx_b_other = []
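  # Each aligned batch position falls into one of four cases: the dim is
  # missing on one side (the output dim comes from the other operand),
  # size-1 on one side (squeeze it and keep the other side's dim), equal on
  # both sides (a true batch dim of dot_general), or incompatible (error).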
for i, (ba, bb) in enumerate(zip(a_batch_dims, b_batch_dims)):
if ba is None:
idx_b_other.append(i)
elif bb is None:
idx_a_other.append(i)
elif core.symbolic_equal_dim(ba, 1):
idx_b_other.append(i)
a_squeeze.append(len(idx_batch) + len(idx_a_other) + len(a_squeeze))
elif core.symbolic_equal_dim(bb, 1):
idx_a_other.append(i)
b_squeeze.append(len(idx_batch) + len(idx_b_other) + len(b_squeeze))
elif core.symbolic_equal_dim(ba, bb):
a_batch.append(len(idx_batch) + len(idx_a_other))
b_batch.append(len(idx_batch) + len(idx_b_other))
idx_batch.append(i)
else:
raise ValueError("Incompatible shapes for matmul arguments: {} and {}"
.format(shape(a), shape(b)))

if a_is_mat:
idx_a_other.append(num_batch_dims)
if b_is_mat:
idx_b_other.append(num_batch_dims + a_is_mat)
perm = np.argsort(np.concatenate([idx_batch, idx_a_other, idx_b_other]))

a = lax.squeeze(a, tuple(a_squeeze))
b = lax.squeeze(b, tuple(b_squeeze))
out = dot_general(
a,
b, (((ndim(a) - 1,), (ndim(b) - 1 - b_is_mat,)), (a_batch, b_batch)),
precision=precision)
return lax.transpose(out, perm)
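
# Illustrative sanity check (a sketch, not part of the library's tests):
# with the default dot_general, this matmul reproduces jnp.matmul,
# including broadcasting of batch dimensions. Shapes below are assumptions.
#
#   x = jnp.ones((5, 1, 3, 4))
#   y = jnp.ones((7, 4, 2))
#   assert matmul(x, y).shape == (5, 7, 3, 2)
#   np.testing.assert_allclose(matmul(x, y), jnp.matmul(x, y))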


def matmul_true_int8(lhs, rhs):
  """True int8 matmul: int8 inputs, int32 accumulator, int32 output."""
  assert lhs.dtype == jnp.int8
  assert rhs.dtype == jnp.int8
  result = jnp.matmul(lhs, rhs, preferred_element_type=jnp.int32)
  assert result.dtype == jnp.int32
  return result
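
# Example (illustrative): int8 inputs accumulate in int32, so products that
# would overflow the int8 range are still exact.
#
#   lhs = jnp.array([[100, 100]], dtype=jnp.int8)
#   rhs = jnp.array([[100], [100]], dtype=jnp.int8)
#   matmul_true_int8(lhs, rhs)  # [[20000]] as int32; far outside int8 range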


def aqt_matmul_int8(a, w):
  """Simulated int8 matmul: calibrate per-row/per-column scales, quantize to
  int8, multiply with an int32 accumulator, then rescale back to float."""
  max_int8 = 127

  # This function is customizable and injectable, i.e. users can inject
  # custom quantization code into an AQT config.
def quant_int8(x):
return jnp.clip(jnp.round(x), -max_int8, max_int8).astype(jnp.int8)

  # Calibration: per-row scales for `a` (axis=1) and per-column scales for
  # `w` (axis=0). The calibration function is also customizable and injectable.
  a_s = max_int8 / jnp.max(jnp.abs(a), axis=1, keepdims=True)
  w_s = max_int8 / jnp.max(jnp.abs(w), axis=0, keepdims=True)

  # int8 matmul with an int32 accumulator; dividing by the outer product of
  # the scales (a_s * w_s) dequantizes the result back to float.
  result = matmul_true_int8(quant_int8(a * a_s), quant_int8(w * w_s)) / (a_s * w_s)

return result
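
A quick usage sketch of the whole pipeline (the shapes, seed, and the final
comparison are illustrative assumptions, not part of the commit): quantize
both operands, multiply in int8 with an int32 accumulator, and check that
the rescaled result stays close to the float matmul.

import jax
import jax.numpy as jnp
from fjformer.bits import aqt_matmul_int8

key_a, key_w = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.normal(key_a, (32, 64))
w = jax.random.normal(key_w, (64, 16))

exact = jnp.matmul(a, w)
approx = aqt_matmul_int8(a, w)
# 8-bit quantization keeps the result close to the float32 reference.
print(jnp.max(jnp.abs(exact - approx)))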
