Commit 09a7fcc: Version 0.0.17
erfanzar committed Oct 5, 2023 (1 parent: 73d3ea1)
Showing 6 changed files with 28 additions and 17 deletions.
12 changes: 7 additions & 5 deletions fjutils/__init__.py
@@ -1,10 +1,12 @@
-from fjutils.checkpointing import StreamingCheckpointer
-from fjutils.easylm import with_sharding_constraint, match_partition_rules, wrap_function_with_rng, tree_apply, \
+from .checkpointing import StreamingCheckpointer
+from .easylm import with_sharding_constraint, match_partition_rules, wrap_function_with_rng, tree_apply, \
     names_in_current_mesh, flatten_tree, get_jax_mesh, cross_entropy_loss_and_accuracy, tree_path_to_string, \
     average_metrics, float_tensor_to_dtype, get_float_dtype_by_name, float_to_dtype, get_metrics, \
     get_gradient_checkpoint_policy, get_names_from_partition_spec, global_norm, get_weight_decay_mask, mse_loss, \
     named_tree_map, make_shard_and_gather_fns, blockwise_cross_entropy, blockwise_dot_product_attention
-from fjutils.load import load_pretrained_model
-from fjutils.utils import change_to_fp16, change_to_fp32, change_to_bf16, change, count_params, get_names, get_devices
-from fjutils.flash_attention import dot_product_attention_queries_per_head, dot_product_attention_multihead, \
+from .load import load_pretrained_model
+from .utils import change_to_fp16, change_to_fp32, change_to_bf16, change, count_params, get_names, get_devices
+from .flash_attention import dot_product_attention_queries_per_head, dot_product_attention_multihead, \
     dot_product_attention_multiquery, _memory_efficient_attention
+
+__version__ = '0.0.17'
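
Every hunk in this commit makes the same mechanical change: absolute fjutils.* imports become relative imports, so submodules resolve within the package itself rather than against whatever fjutils is first on sys.path. A quick sanity check of the public surface, assuming the package installs as before (a sketch, not part of the commit):

import fjutils

# Re-exports should be unchanged by the switch to relative imports.
from fjutils import StreamingCheckpointer, load_pretrained_model

print(fjutils.__version__)  # expected: 0.0.17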
2 changes: 1 addition & 1 deletion fjutils/checkpointing.py
@@ -8,7 +8,7 @@
 )
 from flax.traverse_util import flatten_dict, unflatten_dict, empty_node
 import msgpack
-from fjutils.easylm import float_tensor_to_dtype, tree_apply
+from .easylm import float_tensor_to_dtype, tree_apply
 
 
 class StreamingCheckpointer(object):
2 changes: 1 addition & 1 deletion fjutils/load.py
@@ -3,7 +3,7 @@
 from typing import Union
 import os
 from huggingface_hub import snapshot_download
-from fjutils.checkpointing import StreamingCheckpointer
+from .checkpointing import StreamingCheckpointer
 
 TRANSFORMERS_CLS: Union = [
     FlaxAutoModel,
15 changes: 12 additions & 3 deletions fjutils/utils.py
@@ -5,18 +5,18 @@
 import msgpack
 from typing import List, Optional, Callable, Any
 
-from fjutils.checkpointing import StreamingCheckpointer
+from .checkpointing import StreamingCheckpointer
 from jax import numpy as jnp
 import numpy as np
 import json
 import re
 from jax.sharding import PartitionSpec as PS
 import flax
 from jax.interpreters import pxla
-from fjutils.easylm import with_sharding_constraint
+from .easylm import with_sharding_constraint
 from flax.serialization import from_bytes, to_bytes, to_state_dict
 from flax.traverse_util import flatten_dict
-from fjutils.easylm import float_tensor_to_dtype
+from .easylm import float_tensor_to_dtype
 
 
 def is_torch_available():
@@ -211,3 +211,12 @@ def collate_fn(batch):
     )
     max_steps = num_epochs * len(dataloader) if max_steps is None else max_steps
     return dataloader, max_steps
+
+
+def inverse_permute(tensor, head, dim_in, dim_out):
+    return tensor.reshape(head, 2, dim_in // head // 2, dim_out).transpose(0, 2, 1, 3).reshape(
+        dim_in, dim_out)
+
+
+def permute(tensor, head, dim_in, dim_out):
+    return tensor.view(head, dim_in // head // 2, 2, dim_out).transpose(1, 2).reshape(dim_in, dim_out)
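
The two functions added at the end of utils.py look like the LLaMA-style attention-weight permutation used when converting checkpoints to and from the Hugging Face layout, though the commit itself gives no context. Note they target different array types: permute uses torch calling conventions (.view, two-argument .transpose), while inverse_permute uses the numpy/JAX ones. A minimal round-trip check, assuming numpy arrays; permute_np is a hypothetical numpy port of permute, written here only for the test:

import numpy as np

def inverse_permute(tensor, head, dim_in, dim_out):
    # As added in this commit (numpy/JAX convention).
    return tensor.reshape(head, 2, dim_in // head // 2, dim_out).transpose(0, 2, 1, 3).reshape(
        dim_in, dim_out)

def permute_np(tensor, head, dim_in, dim_out):
    # Hypothetical numpy port of permute: .view -> .reshape,
    # torch transpose(1, 2) -> numpy axes (0, 2, 1, 3).
    return tensor.reshape(head, dim_in // head // 2, 2, dim_out).transpose(0, 2, 1, 3).reshape(dim_in, dim_out)

head, dim_in, dim_out = 4, 16, 16
w = np.arange(dim_in * dim_out, dtype=np.float32).reshape(dim_in, dim_out)
restored = inverse_permute(permute_np(w, head, dim_in, dim_out), head, dim_in, dim_out)
assert np.array_equal(restored, w)  # the two permutations undo each other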
12 changes: 6 additions & 6 deletions requirements.txt
@@ -1,9 +1,9 @@
-jax~=0.4.10
-transformers~=4.31.0
+jax>=0.4.17
+transformers>=4.34.0
 typing~=3.7.4.3
 numpy
-flax==0.7.1
-optax~=0.1.7
-einops~=0.6.1
-msgpack~=1.0.5
+flax>=0.7.3
+optax>=0.1.7
+einops>=0.6.1
+msgpack>=1.0.5
 ml_collections
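
Beyond the version bumps, the operator changes meaning: ~= is a compatible-release pin (jax~=0.4.10 allows only 0.4.x releases at or above 0.4.10), while >= is a plain floor. A small illustration using the packaging library, which pip itself vendors (a sketch, not part of the commit):

from packaging.specifiers import SpecifierSet

old_pin = SpecifierSet("~=0.4.10")  # compatible release: >=0.4.10 and ==0.4.*
new_pin = SpecifierSet(">=0.4.17")  # plain floor: anything at or above 0.4.17

print("0.4.20" in old_pin, "0.4.20" in new_pin)  # True True
print("0.5.0" in old_pin, "0.5.0" in new_pin)    # False True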
2 changes: 1 addition & 1 deletion setup.py
@@ -7,7 +7,7 @@
 
 setuptools.setup(
     name="FJUtils",
-    version='0.0.16',
+    version='0.0.17',
     author="Erfan Zare Chavoshi",
     author_email="[email protected]",
     long_description=long_description,
