Skip to content

Commit

Permalink
Fixing Bugs on python 3.8
Browse files Browse the repository at this point in the history
  • Loading branch information
erfanzar committed Nov 23, 2023
1 parent ffb7ce4 commit 2725e2b
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 18 deletions.
2 changes: 1 addition & 1 deletion fjformer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,4 @@
JaxRNG, GenerateRNG, init_rng, next_rng, count_num_params
)

__version__ = '0.0.8'
__version__ = '0.0.9'
4 changes: 3 additions & 1 deletion fjformer/bits/calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
"""Quantization calibration methods."""

import abc
from typing import Union

import flax.struct
import jax.numpy as jnp

Expand All @@ -27,7 +29,7 @@ def get_bound(self, x, shared_axes) -> jnp.ndarray:

@flax.struct.dataclass
class ConstantCalibration(Calibration):
bound: jnp.ndarray | float
bound: Union[jnp.ndarray, float]

def get_bound(self, x, shared_axes) -> jnp.ndarray:
"""Calibration."""
Expand Down
10 changes: 5 additions & 5 deletions fjformer/bits/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,8 +271,8 @@ def dot_general_make(

def fully_quantized(
*,
fwd_bits: int | None = 8,
bwd_bits: int | None = 8,
fwd_bits: Union[int, None] = 8,
bwd_bits: Union[int, None] = 8,
use_fwd_quant: bool = True,
use_stochastic_rounding: Optional[bool] = True,
# Typically we have (but it's a caller's responsibility to check):
Expand Down Expand Up @@ -332,9 +332,9 @@ def fully_quantized(

def config_v3(
*,
fwd_bits: int | None,
dlhs_bits: int | None,
drhs_bits: int | None,
fwd_bits: Union[int, None],
dlhs_bits: Union[int, None],
drhs_bits: Union[int, None],
use_dummy_static_bound: bool = False,
rng_type: str = 'jax.uniform', # 'custom-1'
dlhs_local_q: Optional[LocalQ] = None,
Expand Down
21 changes: 11 additions & 10 deletions fjformer/bits/q_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from . import int_numerics
import flax.linen as nn
import jax.numpy as jnp
from typing import Optional, Union


class Freezer(nn.Module, config.Preprocess):
Expand Down Expand Up @@ -53,8 +54,8 @@ def __call__(self, inputs):
class QDotGeneral(nn.Module):
"""A layer that can be injected into flax.nn.Dense, etc."""

cfg: config.DotGeneral | None = None
prng_name: str | None = 'params'
cfg: Optional[Union[config.DotGeneral, None]] = None
prng_name: Optional[Union[str, None]] = None

@nn.compact
def __call__(
Expand All @@ -81,8 +82,8 @@ def __call__(
class QEinsum(nn.Module):
"""Quantized Einsum class for model injection."""

cfg: config.DotGeneral | None = None
prng_name: str | None = 'params'
cfg: Optional[Union[config.DotGeneral, None]] = None
prng_name: Optional[Union[str, None]] = None

@nn.compact
def __call__(self, eqn, lhs, rhs):
Expand Down Expand Up @@ -128,14 +129,14 @@ def set_lhs_quant_mode(

def config_v4(
*,
fwd_bits: int | None,
dlhs_bits: int | None,
drhs_bits: int | None,
fwd_bits: Union[int, None],
dlhs_bits: Union[int, None],
drhs_bits: Union[int, None],
# The dummy static bound flag is for performance benchmarking.
use_dummy_static_bound: bool = False,
rng_type: str = 'jax.uniform', # 'custom-1'
dlhs_local_q: config.LocalQ | None = None,
drhs_local_q: config.LocalQ | None = None,
dlhs_local_q: Union[config.LocalQ, None] = None,
drhs_local_q: Union[config.LocalQ, None] = None,
fwd_accumulator_dtype: ... = jnp.int32,
dlhs_accumulator_dtype: ... = jnp.int32,
drhs_accumulator_dtype: ... = jnp.int32,
Expand All @@ -145,7 +146,7 @@ def config_v4(
) -> config.DotGeneral:
"""Version 4 of user-visible AQT config."""

def tensor_config(bits: int | None) -> config.Tensor:
def tensor_config(bits: Union[int, None]) -> config.Tensor:
assert bits is None or bits >= 2, 'Need at least 2 bits.'
if bits is None:
numerics = config.NoNumerics()
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

setuptools.setup(
name="fjformer",
version='0.0.8',
version='0.0.9',
author="Erfan Zare Chavoshi",
author_email="[email protected]",
long_description=long_description,
Expand Down

0 comments on commit 2725e2b

Please sign in to comment.