Adding Q Bits
erfanzar committed Nov 23, 2023
1 parent 561e59a commit ffb7ce4
Showing 10 changed files with 1,354 additions and 11 deletions.
2 changes: 2 additions & 0 deletions fjformer/__init__.py
@@ -25,3 +25,5 @@
 from .utils import (
     JaxRNG, GenerateRNG, init_rng, next_rng, count_num_params
 )
+
+__version__ = '0.0.8'
12 changes: 11 additions & 1 deletion fjformer/bits/__init__.py
@@ -1 +1,11 @@
-from .bits import matmul_true_int8, matmul, aqt_matmul_int8
+# Why do we have Bits here?
+
+# Bits is a clone of https://github.com/google/aqt
+
+# AQT is a quantization library by Google, but I needed to change many parts of the code, which I
+# couldn't do in the Google repo, so I copied the parts of the library that I wanted (AQT is an
+# Apache 2.0 licensed project). I'll focus on loading LLMs in 8-bit, so the name AQT (Accurate Quantized Training) isn't needed.
+
+from .bits import matmul_true_int8, matmul, q_matmul_int8
+from .q_dot_general import make_dot_general, make_fake_quant, DotGeneralRes
+from .q_flax import QuantMode, Freezer, QDotGeneral, QEinsum
16 changes: 7 additions & 9 deletions fjformer/bits/bits.py
@@ -25,7 +25,7 @@ def matmul(
       b: Right-hand side of the matmul.
       precision: Indicates precision of a and b.
       dot_general: lax.dot_general by default. To use quantized matmul, the
-        wrapper of aqt_dot_general in which TQs and `train` flag are provided
+        wrapper of q_dot_general in which TQs and `train` flag are provided
         should be passed into this function.
 
     Returns:
@@ -106,17 +106,15 @@ def matmul_true_int8(lhs, rhs):
     return result
 
 
-def aqt_matmul_int8(a, w):
-    max_int8 = 127
+def quant_int8(x):
+    return jnp.clip(jnp.round(x), -127, 127).astype(jnp.int8)
 
-    # This function is customizable and injectable, i.e:
-    # users can inject custom quant code into an AQT config.
-    def quant_int8(x):
-        return jnp.clip(jnp.round(x), -max_int8, max_int8).astype(jnp.int8)
+
+def q_matmul_int8(a, w):
     # Calibration. Calibration function is also customizable and injectable.
-    a_s = max_int8 / jnp.max(jnp.abs(a), axis=1, keepdims=True)
-    w_s = max_int8 / jnp.max(jnp.abs(w), axis=0, keepdims=True)
+    a_s = 127 / jnp.max(jnp.abs(a), axis=1, keepdims=True)
+    w_s = 127 / jnp.max(jnp.abs(w), axis=0, keepdims=True)
 
     # int8 matmul with int32 accumulator
     result = matmul_true_int8(quant_int8(a * a_s), quant_int8(w * w_s)) / (a_s * w_s)
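For context: scaling each row of a by a_s and each column of w by w_s commutes with the matmul, so dividing the int32 result by the broadcasted product a_s * w_s recovers an approximation of a @ w. Below is a minimal sketch of exercising q_matmul_int8 against a float32 reference; the array shapes, PRNG seed, and tolerance remark are illustrative assumptions, not part of this commit.

import jax
import jax.numpy as jnp
from fjformer.bits import q_matmul_int8

key_a, key_w = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.normal(key_a, (16, 64), dtype=jnp.float32)  # activations
w = jax.random.normal(key_w, (64, 32), dtype=jnp.float32)  # weights

exact = jnp.matmul(a, w)      # float32 reference
approx = q_matmul_int8(a, w)  # per-row / per-column scaled int8 matmul

# int8 quantization error should be small relative to the result's scale.
rel_err = jnp.max(jnp.abs(exact - approx)) / jnp.max(jnp.abs(exact))
print(rel_err)  # typically on the order of 1e-2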
52 changes: 52 additions & 0 deletions fjformer/bits/calibration.py
@@ -0,0 +1,52 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Quantization calibration methods."""
+
+import abc
+import flax.struct
+import jax.numpy as jnp
+
+
+class Calibration(abc.ABC):
+
+    @abc.abstractmethod
+    def get_bound(self, x, shared_axes) -> jnp.ndarray:
+        pass
+
+
+@flax.struct.dataclass
+class ConstantCalibration(Calibration):
+    bound: jnp.ndarray | float
+
+    def get_bound(self, x, shared_axes) -> jnp.ndarray:
+        """Calibration."""
+        del shared_axes
+        assert self.bound > 0, 'Bound should be positive.'
+        return jnp.asarray(self.bound).reshape((1,) * len(x.shape))
+
+
+@flax.struct.dataclass
+class AbsMaxCalibration(Calibration):
+    """Simple max(abs(x)) calibration."""
+
+    def get_bound(self, x, shared_axes) -> jnp.ndarray:
+        """Calibration."""
+        msg = 'Perhaps you are using fake_quant and forgot to set them.'
+        assert shared_axes is not None, msg
+
+        # NOTE: If you want to clip, consider using clip and clip_gradient in
+        # int_numerics.IntNumerics.
+        abs_max = jnp.max(jnp.abs(x), axis=shared_axes, keepdims=True)
+        abs_max = jnp.where(abs_max == 0.0, jnp.ones_like(abs_max), abs_max)
+        return abs_max
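A minimal usage sketch for the new calibration classes; the shapes, seed, and the 127 scale constant mirror q_matmul_int8 above and are illustrative assumptions, not part of this commit.

import jax
import jax.numpy as jnp
from fjformer.bits.calibration import AbsMaxCalibration, ConstantCalibration

x = jax.random.normal(jax.random.PRNGKey(0), (8, 128))

# shared_axes lists the axes reduced away when searching for max |x|;
# reducing over axis 1 gives each row its own bound.
bound = AbsMaxCalibration().get_bound(x, shared_axes=(1,))  # shape (8, 1)
scale = 127.0 / bound                                       # per-row int8 scale

x_int8 = jnp.clip(jnp.round(x * scale), -127, 127).astype(jnp.int8)

# ConstantCalibration ignores the data and broadcasts a fixed bound instead.
fixed = ConstantCalibration(bound=3.0).get_bound(x, shared_axes=None)  # shape (1, 1)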