diff --git a/fjformer/__init__.py b/fjformer/__init__.py
index 6900051..283a32f 100644
--- a/fjformer/__init__.py
+++ b/fjformer/__init__.py
@@ -26,4 +26,4 @@
     JaxRNG, GenerateRNG, init_rng, next_rng, count_num_params
 )
 
-__version__ = '0.0.9'
+__version__ = '0.0.10'
diff --git a/fjformer/bits/q_flax.py b/fjformer/bits/q_flax.py
index d3df438..4bedabb 100644
--- a/fjformer/bits/q_flax.py
+++ b/fjformer/bits/q_flax.py
@@ -55,7 +55,7 @@ class QDotGeneral(nn.Module):
     """A layer that can be injected into flax.nn.Dense, etc."""
 
     cfg: Optional[Union[config.DotGeneral, None]] = None
-    prng_name: Optional[Union[str, None]] = None
+    prng_name: Optional[Union[str, None]] = 'params'
 
     @nn.compact
     def __call__(
@@ -83,7 +83,7 @@ class QEinsum(nn.Module):
     """Quantized Einsum class for model injection."""
 
     cfg: Optional[Union[config.DotGeneral, None]] = None
-    prng_name: Optional[Union[str, None]] = None
+    prng_name: Optional[Union[str, None]] = 'params'
 
     @nn.compact
     def __call__(self, eqn, lhs, rhs):
diff --git a/setup.py b/setup.py
index c026279..fb8e461 100644
--- a/setup.py
+++ b/setup.py
@@ -7,7 +7,7 @@
 
 setuptools.setup(
     name="fjformer",
-    version='0.0.9',
+    version='0.0.10',
     author="Erfan Zare Chavoshi",
     author_email="erfanzare82@yahoo.com",
     long_description=long_description,