Commit

making imports more efficient
erfanzar committed Nov 4, 2023
1 parent 4fe8033 commit f0b641f
Showing 5 changed files with 67 additions and 39 deletions.
28 changes: 10 additions & 18 deletions fjformer/__init__.py
@@ -1,35 +1,27 @@
-from .attention import (dot_product_attention_multiquery, dot_product_attention_multihead,
-                        dot_product_attention_queries_per_head, efficient_attention)
-from .load import (
-    load_and_convert_checkpoint_to_torch, float_tensor_to_dtype, read_ckpt, save_ckpt, StreamingCheckpointer
+from fjformer.load import (
+    load_and_convert_checkpoint_to_torch, float_tensor_to_dtype, read_ckpt, save_ckpt, StreamingCheckpointer,
+    get_float_dtype_by_name
 )
 
-from .optimizers import (
-    get_adamw_with_cosine_scheduler, get_adamw_with_warm_up_cosine_scheduler,
-    get_adamw_with_warmup_linear_scheduler, get_adamw_with_linear_scheduler,
-    get_lion_with_cosine_scheduler, get_lion_with_with_warmup_linear_scheduler,
-    get_lion_with_warm_up_cosine_scheduler, get_lion_with_linear_scheduler,
-    get_adafactor_with_cosine_scheduler, get_adafactor_with_warm_up_cosine_scheduler,
-    get_adafactor_with_warmup_linear_scheduler, get_adafactor_with_linear_scheduler,
-    optax_add_scheduled_weight_decay
-
-)
-
-from .partition_utils import (
+from fjformer.partition_utils import (
     get_jax_mesh, names_in_current_mesh, get_names_from_partition_spec, match_partition_rules,
     flatten_tree, get_metrics, tree_apply, get_weight_decay_mask, named_tree_map,
     tree_path_to_string, make_shard_and_gather_fns, with_sharding_constraint,
     wrap_function_with_rng
 )
 
-from .monitor import (
+from fjformer.monitor import (
     run, get_mem, is_notebook, threaded_log, initialise_tracking
 )
 
-from .datasets import (
+from fjformer.datasets import (
     get_dataloader
 )
 
 from .func import (
     transpose, global_norm, average_metrics
 )
+
+from .utils import (
+    JaxRNG, GenerateRNG, init_rng, next_rng, count_num_params
+)
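
After this change only the checkpoint, partitioning, monitoring, dataset, and utility helpers are re-exported from the package root; the attention and optimizer imports are dropped from __init__.py, so importing fjformer no longer pulls them in eagerly. A minimal sketch of the resulting import surface; it assumes the fjformer.attention and fjformer.optimizers submodules keep the names listed in the removed lines:

# names still re-exported at the package root
from fjformer import StreamingCheckpointer, get_float_dtype_by_name, get_jax_mesh, init_rng, next_rng

# symbols removed from __init__.py now come from their submodules
from fjformer.attention import efficient_attention
from fjformer.optimizers import get_adamw_with_cosine_scheduler
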
3 changes: 2 additions & 1 deletion fjformer/load/__init__.py
@@ -1,2 +1,3 @@
 from .streamer import StreamingCheckpointer
-from ._load import float_tensor_to_dtype, read_ckpt, save_ckpt, load_and_convert_checkpoint_to_torch
+from ._load import (float_tensor_to_dtype, read_ckpt, save_ckpt, load_and_convert_checkpoint_to_torch,
+                    get_float_dtype_by_name)
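
get_float_dtype_by_name is now part of the public surface of fjformer.load alongside the checkpoint helpers. A short sketch of how it can be combined with float_tensor_to_dtype, assuming it maps short dtype names such as 'bf16' or 'fp32' to the matching jax.numpy dtypes (the exact accepted names live in fjformer/load/_load.py):

import jax.numpy as jnp
from fjformer.load import get_float_dtype_by_name, float_tensor_to_dtype

param = jnp.ones((1024, 1024), dtype=jnp.float32)

# resolve a human-readable name to a JAX dtype (assumed mapping, e.g. 'bf16' -> jnp.bfloat16)
target_dtype = get_float_dtype_by_name('bf16')

# cast a floating-point tensor to that dtype, e.g. before writing a checkpoint
param = float_tensor_to_dtype(param, target_dtype)
print(param.dtype)  # bfloat16
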
43 changes: 43 additions & 0 deletions fjformer/utils.py
@@ -13,3 +13,46 @@ def count_num_params(_p):
 
 def count_params(_p):
     print('\033[1;31mModel Contain : ', count_num_params(_p) / 1e9, ' Billion Parameters')
+
+
+class JaxRNG(object):
+    @classmethod
+    def from_seed(cls, seed):
+        return cls(jax.random.PRNGKey(seed))
+
+    def __init__(self, rng):
+        self.rng = rng
+
+    def __call__(self, keys=None):
+        if keys is None:
+            self.rng, split_rng = jax.random.split(self.rng)
+            return split_rng
+        elif isinstance(keys, int):
+            split_rngs = jax.random.split(self.rng, num=keys + 1)
+            self.rng = split_rngs[0]
+            return tuple(split_rngs[1:])
+        else:
+            split_rngs = jax.random.split(self.rng, num=len(keys) + 1)
+            self.rng = split_rngs[0]
+            return {key: val for key, val in zip(keys, split_rngs[1:])}
+
+
+def init_rng(seed):
+    global jax_utils_rng
+    jax_utils_rng = JaxRNG.from_seed(seed)
+
+
+def next_rng(*args, **kwargs):
+    global jax_utils_rng
+    return jax_utils_rng(*args, **kwargs)
+
+
+class GenerateRNG:
+    def __init__(self, seed: int = 0):
+        self.seed = seed
+        self.rng = jax.random.PRNGKey(seed)
+
+    def __next__(self):
+        while True:
+            self.rng, ke = jax.random.split(self.rng, 2)
+            return ke
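
The helpers added above give two ways of handing out PRNG keys: a module-level JaxRNG seeded once via init_rng and queried with next_rng, and a small stateful GenerateRNG wrapper driven by next(). A usage sketch based only on the code in this hunk:

from fjformer.utils import GenerateRNG, init_rng, next_rng

# module-level helper: seed once, then pull fresh keys anywhere
init_rng(42)
key = next_rng()                          # a single key
k1, k2 = next_rng(2)                      # an int yields a tuple of that many keys
keys = next_rng(['dropout', 'params'])    # a sequence yields a dict keyed by name

# generator-style helper: every next() splits a new key off the stored state
rng = GenerateRNG(seed=0)
dropout_key = next(rng)
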
15 changes: 5 additions & 10 deletions requirements.txt
@@ -1,17 +1,12 @@
-numpy~=1.25.2
-chex~=0.1.7
-typing
 jax>=0.4.10
-transformers>=4.34.0
+typing~=3.7.4.3
+numpy
 flax>=0.7.1
 optax>=0.1.7
-einops>=0.6.1
-msgpack>=1.0.5
+einops
+msgpack
 ml_collections
-ipython~=8.14.0
 chex~=0.1.82
-setuptools~=68.0.0
 torch>=2.0.0
-einops
 datasets
-IPython>=8.17.2
+IPython
17 changes: 7 additions & 10 deletions setup.py
@@ -7,28 +7,26 @@
 
 setuptools.setup(
     name="fjformer",
-    version='0.0.1',
+    version='0.0.6',
     author="Erfan Zare Chavoshi",
     author_email="[email protected]",
     long_description=long_description,
     long_description_content_type="text/markdown",
     url="https://github.com/erfanzar/",
     packages=setuptools.find_packages(),
     install_requires=[
-        "numpy",
-        "chex~=0.1.7",
-        "typing",
         "jax>=0.4.10",
-        "transformers>=4.34.0",
+        "typing~=3.7.4.3",
+        "numpy",
         "flax>=0.7.1",
         "optax>=0.1.7",
-        "einops>=0.6.1",
-        "msgpack>=1.0.5",
+        "einops",
+        "msgpack",
         "ml_collections",
         "torch>=2.0.0",
         "einops",
         "datasets",
-        "IPython>=8.17.2"
+        "IPython"
     ],
     python_requires=">=3.7",
     license='Apache License 2.0',
@@ -37,10 +35,9 @@
         "License :: OSI Approved :: MIT License",
         "Operating System :: OS Independent",
         'Programming Language :: Python :: 3',
-        'Programming Language :: Python :: 3.7',
         'Programming Language :: Python :: 3.8',
         'Programming Language :: Python :: 3.9',
         'Programming Language :: Python :: 3.10',
         'Programming Language :: Python :: 3.11',
     ],
-)
+)
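
With the version bumped to 0.0.6 and the pins relaxed to match requirements.txt, a typical install sketch (assuming the new version is published to PyPI like earlier releases; otherwise install from a clone of the repository):

pip install fjformer==0.0.6
# or, from a checkout:
pip install -r requirements.txt
pip install -e .
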
