Skip to content

Commit

Permalink
Fix bug: `numpy.float32` has no attribute `client` in the attention module
Browse files Browse the repository at this point in the history
  • Loading branch information
erfanzar committed Nov 25, 2023
1 parent fc39926 commit 07d2274
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 14 deletions.
42 changes: 33 additions & 9 deletions fjformer/__init__.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,53 @@
from fjformer.load import (
load_and_convert_checkpoint_to_torch, float_tensor_to_dtype, read_ckpt, save_ckpt, StreamingCheckpointer,
load_and_convert_checkpoint_to_torch,
float_tensor_to_dtype,
read_ckpt,
save_ckpt,
StreamingCheckpointer,
get_float_dtype_by_name
)

from fjformer.partition_utils import (
get_jax_mesh, names_in_current_mesh, get_names_from_partition_spec, match_partition_rules,
flatten_tree, get_metrics, tree_apply, get_weight_decay_mask, named_tree_map,
tree_path_to_string, make_shard_and_gather_fns, with_sharding_constraint,
wrap_function_with_rng
get_jax_mesh,
names_in_current_mesh,
get_names_from_partition_spec,
match_partition_rules,
flatten_tree,
get_metrics,
tree_apply,
get_weight_decay_mask,
named_tree_map,
tree_path_to_string,
make_shard_and_gather_fns,
with_sharding_constraint,
wrap_function_with_rng,
create_mesh
)

from fjformer.monitor import (
run, get_mem, is_notebook, threaded_log, initialise_tracking
run,
get_mem,
is_notebook,
threaded_log,
initialise_tracking
)

from fjformer.datasets import (
get_dataloader
)

from .func import (
transpose, global_norm, average_metrics
transpose,
global_norm,
average_metrics
)

from .utils import (
JaxRNG, GenerateRNG, init_rng, next_rng, count_num_params
JaxRNG,
GenerateRNG,
init_rng,
next_rng,
count_num_params
)

__version__ = '0.0.11'
__version__ = '0.0.12'
20 changes: 16 additions & 4 deletions fjformer/partition_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,16 @@
from .mesh_utils import (get_jax_mesh, names_in_current_mesh, get_names_from_partition_spec, match_partition_rules,
flatten_tree, get_metrics, tree_apply, get_weight_decay_mask, named_tree_map,
tree_path_to_string, make_shard_and_gather_fns, with_sharding_constraint,
wrap_function_with_rng)
from .mesh_utils import (
get_jax_mesh,
names_in_current_mesh,
get_names_from_partition_spec,
match_partition_rules,
flatten_tree,
get_metrics,
tree_apply,
get_weight_decay_mask,
named_tree_map,
tree_path_to_string,
make_shard_and_gather_fns,
with_sharding_constraint,
wrap_function_with_rng,
create_mesh
)
14 changes: 14 additions & 0 deletions fjformer/partition_utils/mesh_utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import jax
import jax.numpy as jnp
import re

from jax.experimental.mesh_utils import create_device_mesh
from jax.experimental.pjit import pjit, with_sharding_constraint as _with_sharding_constraint
import numpy as np
from jax.sharding import PartitionSpec as PS
from jax.experimental import mesh_utils
from jax.interpreters import pxla
import flax
from jax.sharding import Mesh
from typing import Sequence


def make_shard_and_gather_fns(partition_specs, dtype_specs=None):
Expand Down Expand Up @@ -220,3 +223,14 @@ def weight_decay_mask(params):
def tree_apply(fns, tree):
""" Apply a pytree of functions to the pytree. """
return jax.tree_util.tree_map(lambda fn, x: fn(x), fns, tree)


def create_mesh(
        axis_dims: Sequence[int] = (1, -1, 1, 1),
        axis_names: Sequence[str] = ("dp", "fsdp", "tp", "mp"),
        backend=''
):
    """Create a ``jax.sharding.Mesh`` over all available devices.

    Args:
        axis_dims: Size of each mesh axis. A ``-1`` entry is resolved so the
            product of all axes equals the total device count (standard
            reshape semantics).
        axis_names: Name for each mesh axis; must have the same length as
            ``axis_dims``.
        backend: JAX backend to take devices from (e.g. ``'cpu'``, ``'gpu'``,
            ``'tpu'``). An empty string uses the default backend.

    Returns:
        A ``jax.sharding.Mesh`` of shape ``axis_dims`` with axes ``axis_names``.
    """
    devices = jax.devices() if backend == '' else jax.devices(backend)
    # Resolve any -1 entries in axis_dims against the device count using a
    # host-side NumPy reshape — no need to allocate a JAX device array
    # (the original jax.numpy.ones call did device work just to get a shape).
    mesh_shape = np.ones((len(devices), 1)).reshape(axis_dims).shape
    return jax.sharding.Mesh(create_device_mesh(mesh_shape), axis_names)
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

setuptools.setup(
name="fjformer",
version='0.0.11',
version='0.0.12',
author="Erfan Zare Chavoshi",
author_email="[email protected]",
long_description=long_description,
Expand Down

0 comments on commit 07d2274

Please sign in to comment.