adding quantizers
erfanzar committed Apr 10, 2024
1 parent 9fe85f4 commit e33d320
Showing 4 changed files with 133 additions and 13 deletions.
2 changes: 1 addition & 1 deletion .idea/misc.xml

Some generated files are not rendered by default.

24 changes: 23 additions & 1 deletion src/fjformer/linen/__init__.py
@@ -1 +1,23 @@
-from .linear import Linear, LinearBitKernel, quantize, de_quantize
+from .linear import (
+    Linear as Linear,
+    LinearBitKernel as LinearBitKernel,
+    quantize as quantize,
+    de_quantize as de_quantize,
+    quantize_params as quantize_params,
+    de_quantize_params as de_quantize_params,
+    Conv as Conv,
+    Embed as Embed,
+    promote_dtype as promote_dtype
+)
+
+__all__ = (
+    "Linear",
+    "LinearBitKernel",
+    "quantize",
+    "de_quantize",
+    "quantize_params",
+    "de_quantize_params",
+    "Conv",
+    "Embed",
+    "promote_dtype"
+)
91 changes: 80 additions & 11 deletions src/fjformer/linen/linear.py
@@ -25,8 +25,12 @@
     Sequence,
     Tuple,
     Union,
+    Mapping
 )
 
+import chex
+import flax.traverse_util
+import jax.tree_util
 from flax.linen import initializers
 from flax.linen.dtypes import promote_dtype
 from flax.linen.module import compact
@@ -36,6 +40,7 @@
 from jax.core import ShapedArray
 import jax.numpy as jnp
 import numpy as np
+from ..partition_utils import with_sharding_constraint
 
 PRNGKey = Any
 Shape = Tuple[int, ...]
@@ -58,12 +63,9 @@ def quantize(
         array: jnp.ndarray,
         int_dtype: jnp.dtype = jnp.int8,
 ) -> Tuple[jnp.ndarray, jnp.ndarray]:
-    scale = (
-        jnp.max(array) - jnp.min(array)
-    ) / (
-        jnp.iinfo(int_dtype).max - jnp.iinfo(int_dtype).min
-    )
-    return round(array / scale), scale
+    max_scale = (jnp.iinfo(int_dtype).max + abs(jnp.iinfo(int_dtype).min)) / 2
+    scale = jnp.max(jnp.abs(array), axis=-1, keepdims=True)
+    return jnp.int8(jnp.rint(array * (max_scale / scale))), scale


 def de_quantize(
@@ -72,14 +74,15 @@ def de_quantize(
         float_dtype: jnp.dtype = jnp.float16,
         threshold: float = 1e-6
 ):
-    return (quantized.astype(float_dtype) * scale) + threshold
+    max_scale = (jnp.iinfo(quantized.dtype).max + abs(jnp.iinfo(quantized.dtype).min)) / 2
+    return ((quantized.astype(float_dtype) * scale) / max_scale) + threshold


 @dataclasses.dataclass
 class LinearBitKernel:
     kernel: Array
-    scale: Optional[float] = .0
-    _is_quantized: bool = False
+    scale: Array
+    _is_quantized: bool = True

     @property
     def shape(self):
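
The two helpers above form a symmetric absmax scheme: for int8 the midpoint of the code range is (127 + 128) / 2 = 127.5, each row is scaled by its absolute maximum before rounding, and de_quantize divides that factor back out. A minimal round-trip sketch (not part of this commit; it assumes the functions are importable from fjformer.linen as re-exported in the __init__.py change above):

import jax
from jax import numpy as jnp
from fjformer.linen import quantize, de_quantize

x = jax.random.normal(jax.random.PRNGKey(0), (2, 8), dtype=jnp.float32)

q, scale = quantize(x, jnp.int8)           # int8 codes plus a per-row absmax scale
x_hat = de_quantize(q, scale, jnp.float32, 0.0)

print(q.dtype, scale.shape)                # int8, (2, 1)
print(jnp.max(jnp.abs(x - x_hat)))         # rounding error, at most about scale / 255 per row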
@@ -90,6 +93,50 @@ def quantized(self):
         return self._is_quantized


+jax.tree_util.register_pytree_node(
+    LinearBitKernel,
+    lambda x: ([x.kernel, x.scale, x.quantized], ()),
+    lambda _, children: LinearBitKernel(children[0], children[1], children[2])
+)
+
+
+def quantize_params(
+        params: jax.tree_util.PyTreeDef,
+        quantize_dtype: jnp.dtype = jnp.int8
+):
+    return jax.tree_util.tree_map(
+        lambda prm: LinearBitKernel(
+            *quantize(
+                prm, quantize_dtype
+            )
+        ),
+        params
+    )
+
+
+def de_quantize_params(
+        params: jax.tree_util.PyTreeDef,
+        dtype: jnp.dtype = jnp.float32,
+        shard_funcs: Optional[Mapping[str, Callable[[chex.Array], chex.Array]]] = None
+):
+    def _q(prm):
+        if isinstance(prm, LinearBitKernel):
+            return jnp.array(
+                de_quantize(
+                    prm.kernel, prm.scale, dtype, 0
+                )
+            )
+        return prm
+
+    prm = flax.traverse_util.flatten_dict(params)
+    for key in list(prm.keys()):
+        value = _q(prm[key])
+        if shard_funcs is not None:
+            value = shard_funcs[key](value)
+        prm[key] = value
+    return flax.traverse_util.unflatten_dict(prm)
+
+
 def _normalize_axes(axes: Tuple[int, ...], ndim: int) -> Tuple[int, ...]:
     # A tuple by convention. len(axes_tuple) then also gives the rank efficiently.
     return tuple(sorted(ax if ax >= 0 else ndim + ax for ax in axes))
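
quantize_params maps LinearBitKernel over every array leaf of a params tree, while de_quantize_params walks the flax.traverse_util flattened dict, so the optional shard_funcs mapping has to be keyed by the same path tuples. A rough usage sketch (not from the commit; the layer size and the identity shard functions are illustrative only):

import jax
from jax import numpy as jnp
from flax import traverse_util
from fjformer.linen import Linear, quantize_params, de_quantize_params

model = Linear(16, use_bias=True)
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 8)))

# every array leaf (kernel and bias alike) becomes a LinearBitKernel(int8 codes, scale)
q_params = quantize_params(params, jnp.int8)

# shard_funcs, if supplied, is keyed by the flattened path tuples, e.g. ("params", "kernel")
shard_funcs = {k: (lambda x: x) for k in traverse_util.flatten_dict(params)}
f_params = de_quantize_params(q_params, jnp.float32, shard_funcs)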
@@ -143,14 +190,36 @@ def __call__(self, inputs: Array) -> Array:
             (jnp.shape(inputs)[-1], self.features),
             self.param_dtype,
         )
+        if isinstance(kernel, LinearBitKernel):
+            org_sharding = kernel.kernel.sharding
+            kernel = de_quantize(
+                kernel.kernel,
+                kernel.scale,
+                self.param_dtype,
+                .0
+            )

+            kernel = jax.device_put(kernel, org_sharding)

         if self.use_bias:
             bias = self.param(
-                "bias", self.bias_init, (self.features,), self.param_dtype
+                "bias",
+                self.bias_init,
+                (self.features,),
+                self.param_dtype
             )
+            if isinstance(bias, LinearBitKernel):
+                org_sharding = bias.kernel.sharding
+                bias = de_quantize(
+                    bias.kernel,
+                    bias.scale,
+                    self.param_dtype,
+                    .0
+                )
+                bias = jax.device_put(bias, org_sharding)
         else:
             bias = None
         inputs, kernel, bias = promote_dtype(inputs, kernel, bias, dtype=self.dtype)

         if self.dot_general_cls is not None:
             dot_general = self.dot_general_cls()
         elif self.dot_general is not None:
29 changes: 29 additions & 0 deletions test/linear_bit_kernel_test.py
@@ -0,0 +1,29 @@
+import jax.random
+from jax import numpy as jnp
+import src.fjformer.linen as nn
+from src.fjformer import GenerateRNG
+
+
+def main():
+    rng_gen = GenerateRNG(42)
+    neuron = nn.Linear(
+        4,
+        use_bias=True
+    )
+
+    params = neuron.init(
+        rng_gen.rng,
+        jax.random.normal(rng_gen.rng, (1, 68, 4))
+    )
+
+    quantized_params = nn.quantize_params(params)
+
+    inputs = jax.random.normal(rng_gen.rng, (1, 1, 4))
+
+    org_pred = neuron.apply(params, inputs)
+    qun_pred = neuron.apply(quantized_params, inputs)
+    print(jnp.allclose(org_pred, qun_pred, rtol=1e-2, atol=1e-8))
+
+
+if __name__ == '__main__':
+    main()
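
As a hypothetical extension of main() above, the quantized tree can also be converted back to floats with de_quantize_params and compared leaf by leaf; the tolerances here are illustrative:

    recovered = nn.de_quantize_params(quantized_params, jnp.float32)
    org_kernel = params["params"]["kernel"]
    rec_kernel = recovered["params"]["kernel"]
    print(jnp.allclose(org_kernel, rec_kernel, rtol=1e-2, atol=1e-2))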
