Commit
Showing 10 changed files with 1,354 additions and 11 deletions.
@@ -25,3 +25,5 @@
 from .utils import (
     JaxRNG, GenerateRNG, init_rng, next_rng, count_num_params
 )
+
+__version__ = '0.0.8'
@@ -1 +1,11 @@
-from .bits import matmul_true_int8, matmul, aqt_matmul_int8
+# Why do we have bits here?
+
+# bits is a clone of https://github.com/google/aqt
+
+# AQT is a quantization library by Google, but I needed to change many parts of the code, and I
+# couldn't do that in the Google repo, so I made a copy of the parts of the library that I wanted
+# (AQT is an Apache 2.0-licensed project). The focus here is loading LLMs in 8-bit,
+# so the project name doesn't need to stay AQT (Accurate Quantized Training).
+
+from .bits import matmul_true_int8, matmul, q_matmul_int8
+from .q_dot_general import make_dot_general, make_fake_quant, DotGeneralRes
+from .q_flax import QuantMode, Freezer, QDotGeneral, QEinsum
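
The bodies of these functions aren't shown in this diff. As a rough, hypothetical sketch of what a "true int8" matmul means in JAX (int8 inputs with int32 accumulation, so the int8*int8 products don't overflow) — not the actual implementation in .bits:

import jax
import jax.numpy as jnp

# Hypothetical illustration only; the real matmul_true_int8 / q_matmul_int8
# live in .bits and may have different signatures.
def true_int8_matmul_sketch(a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
    assert a.dtype == jnp.int8 and b.dtype == jnp.int8
    # Ask XLA to accumulate the int8 x int8 products in int32.
    return jax.lax.dot(a, b, preferred_element_type=jnp.int32)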
@@ -0,0 +1,52 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Quantization calibration methods."""
+
+import abc
+import flax.struct
+import jax.numpy as jnp
+
+
+class Calibration(abc.ABC):
+
+  @abc.abstractmethod
+  def get_bound(self, x, shared_axes) -> jnp.ndarray:
+    pass
+
+
+@flax.struct.dataclass
+class ConstantCalibration(Calibration):
+  bound: jnp.ndarray | float
+
+  def get_bound(self, x, shared_axes) -> jnp.ndarray:
+    """Calibration."""
+    del shared_axes
+    assert self.bound > 0, 'Bound should be positive.'
+    return jnp.asarray(self.bound).reshape((1,) * len(x.shape))
+
+
+@flax.struct.dataclass
+class AbsMaxCalibration(Calibration):
+  """Simple max(abs(x)) calibration."""
+
+  def get_bound(self, x, shared_axes) -> jnp.ndarray:
+    """Calibration."""
+    msg = 'Perhaps you are using fake_quant and forgot to set them.'
+    assert shared_axes is not None, msg
+
+    # NOTE: If you want to clip, consider using clip and clip_gradient in
+    # int_numerics.IntNumerics.
+    abs_max = jnp.max(jnp.abs(x), axis=shared_axes, keepdims=True)
+    abs_max = jnp.where(abs_max == 0.0, jnp.ones_like(abs_max), abs_max)
+    return abs_max
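
As a quick usage sketch of the classes above (illustrative values, not part of the commit): AbsMaxCalibration reduces |x| over the shared axes to get a per-slice bound, while ConstantCalibration broadcasts a fixed bound to x's rank.

import jax.numpy as jnp

# Assumes AbsMaxCalibration / ConstantCalibration from the file above are in scope.
x = jnp.array([[0.5, -2.0],
               [1.5,  0.0]])

# Per-row bound: max(|x|) over axis 1, keepdims=True -> shape (2, 1).
AbsMaxCalibration().get_bound(x, shared_axes=(1,))            # [[2.0], [1.5]]

# Fixed bound reshaped to the rank of x -> shape (1, 1).
ConstantCalibration(bound=3.0).get_bound(x, shared_axes=None)  # [[3.0]]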