Commit ff262ef
Merge pull request #25 from erfanzar/mojo-beta
`Mistral` Models Added
erfanzar authored Oct 2, 2023
2 parents 3a82cb1 + cc83767 commit ff262ef
Showing 16 changed files with 898 additions and 301 deletions.
16 changes: 13 additions & 3 deletions README.md
@@ -2,13 +2,22 @@

EasyDeL (Easy Deep Learning) is an open-source library designed to accelerate and optimize the training process of
machine learning models. This library is primarily focused on Jax/Flax and plans to offer easy and fine solutions to
-train Flax/Jax Models on the `TPU/GPU` both for Serving and Training (EasyDel will support mojo and be rewriten for mojo too)
+train Flax/Jax Models on the `TPU/GPU` both for Serving and Training (EasyDel will support mojo and be rewritten for mojo
+too)

### EasyDel Mojo

-EasyDel Mojo differs from EasyDel in Python in significant ways. In Python, you can leverage a vast array of packages to create a mid or high-level API in no time. However, when working with Mojo, it's a different story. Here, you have to build some of the features that other Python libraries provide, such as Jax for arrays and computations. But why not import numpy, Jax, and other similar packages to Mojo and use them?
+EasyDel Mojo differs from EasyDel in Python in significant ways. In Python, you can leverage a vast array of packages to
+create a mid or high-level API in no time. However, when working with Mojo, it's a different story. Here, you have to
+build some of the features that other Python libraries provide, such as Jax for arrays and computations. But why not
+import numpy, Jax, and other similar packages to Mojo and use them?

-There are several reasons why building packages in Mojo is more efficient than importing them from Python. Firstly, when you import packages from Python, you incur the overhead of translating and processing the Python code into Mojo code, which takes time. Secondly, the Python code may not be optimized for the Mojo runtime environment, leading to slower performance. Lastly, building packages directly in Mojo allows you to design and optimize them explicitly for the Mojo runtime environment, resulting in faster and more efficient code. With Mojo's built-in array capabilities that are 35000x faster than Python, it's time to take your coding to the next level.
+There are several reasons why building packages in Mojo is more efficient than importing them from Python. Firstly, when
+you import packages from Python, you incur the overhead of translating and processing the Python code into Mojo code,
+which takes time. Secondly, the Python code may not be optimized for the Mojo runtime environment, leading to slower
+performance. Lastly, building packages directly in Mojo allows you to design and optimize them explicitly for the Mojo
+runtime environment, resulting in faster and more efficient code. With Mojo's built-in array capabilities, reported to
+be up to 35,000x faster than pure Python, it's time to take your coding to the next level.

[Read More ...](https://github.com/erfanzar/EasyDeL/blob/main/lib/mojo/README.md)

@@ -75,6 +84,7 @@ _Tutorials on how to use and train or serve your models with EasyDel is availabl

## Available Models Are

+- **_Mistral_** (Support `FSDP`, `MP`, `DP`) (_Supports gradient checkpointing_)
- **_Llama_** (Support `FSDP`, `MP`, `DP`) (_Supports gradient checkpointing_)
- **_Llama2_** (Support `FSDP`, `MP`, `DP`) (_Supports gradient checkpointing_)
- **_GPT-J_** (Support `FSDP`, `MP`, `DP`) (_Supports gradient checkpointing_)
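Since Mistral now appears in the model list above, a construction sketch may help. It assumes `FlaxMistralForCausalLM` follows the same `FlaxPreTrainedModel` pattern as the library's other models; the default-argument `MistralConfig()` call is an assumption, not something this diff confirms.

```python
# Hedged sketch, not a documented API: build the new Mistral model the way
# the other EasyDel Flax models are built.
from EasyDel import FlaxMistralForCausalLM, MistralConfig

config = MistralConfig()  # assumption: the config ships usable defaults
model = FlaxMistralForCausalLM(config=config)
params = model.params  # randomly initialized Flax parameters
```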
34 changes: 24 additions & 10 deletions docs/Python/Models.md
@@ -2,7 +2,21 @@

## Available Models Are

-1. **_[Llama](https://erfanzar.github.io/EasyDeL/docs/Python/Llama)_**:
+1. **_Mistral_**:
+
+    * Supports:
+        * Fully Sharded Data Parallel `(FSDP)`
+        * MultiProcessing `(MP)`
+        * Data Parallel `(DP)`
+        * Distributed Data Parallel (DDP) `(DP)`
+        * Gradient CheckPointing
+    * Usage and Import from EasyDel Library
+
+[//]: # ( * Flash Attention)
+
+[//]: # ( * BlockWise Attention)
+
+2. **_[Llama](https://erfanzar.github.io/EasyDeL/docs/Python/Llama)_**:

    * Supports:
        * Fully Sharded Data Parallel `(FSDP)`
@@ -16,7 +30,7 @@
    * [Usage](https://erfanzar.github.io/EasyDeL/docs/Python/Llama)


-2. **_[Llama2](https://erfanzar.github.io/EasyDeL/docs/Python/Llama2)_**:
+3. **_[Llama2](https://erfanzar.github.io/EasyDeL/docs/Python/Llama2)_**:

    * Supports:
        * Fully Sharded Data Parallel `(FSDP)`
@@ -29,7 +43,7 @@
    * [Usage](https://erfanzar.github.io/EasyDeL/docs/Python/Llama2)


-3. **_[Falcon](https://erfanzar.github.io/EasyDeL/docs/Python/Falcon)_**:
+4. **_[Falcon](https://erfanzar.github.io/EasyDeL/docs/Python/Falcon)_**:

    * Supports:
        * Fully Sharded Data Parallel `(FSDP)`
@@ -42,7 +56,7 @@
    * [Usage](https://erfanzar.github.io/EasyDeL/docs/Python/Falcon)


-4. **_[MosaicMPT](https://erfanzar.github.io/EasyDeL/docs/Python/MosaicMPT)_**:
+5. **_[MosaicMPT](https://erfanzar.github.io/EasyDeL/docs/Python/MosaicMPT)_**:

    * Supports:
        * Fully Sharded Data Parallel `(FSDP)`
@@ -55,7 +69,7 @@
    * [Usage](https://erfanzar.github.io/EasyDeL/docs/Python/MosaicMPT)


-5. **_GPTNeoX_** :
+6. **_GPTNeoX_**:

    * Supports:
        * Fully Sharded Data Parallel `(FSDP)`
@@ -65,7 +79,7 @@
        * Gradient CheckPointing


-6. **_LT_** :
+7. **_LT_**:

    * Supports:
        * Fully Sharded Data Parallel `(FSDP)`
@@ -74,7 +88,7 @@
        * Distributed Data Parallel (DDP) `(DP)`
        * Gradient CheckPointing

-7. **_Palm_**:
+8. **_Palm_**:

    * Supports:
        * Fully Sharded Data Parallel `(FSDP)`
@@ -84,7 +98,7 @@
        * Gradient CheckPointing


-8. **_T5_**:
+9. **_T5_**:

    * Supports:
        * Fully Sharded Data Parallel `(FSDP)`
@@ -93,7 +107,7 @@
        * Distributed Data Parallel (DDP) `(DP)`
        * Gradient CheckPointing

-9. **_GPT-J_** :
+10. **_GPT-J_**:

    * Supports:
        * Fully Sharded Data Parallel `(FSDP)`
@@ -104,7 +118,7 @@
        * Flash Attention
        * BlockWise Attention

-10. **_OPT_**:
+11. **_OPT_**:

    * Supports:
        * Fully Sharded Data Parallel `(FSDP)`
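The `FSDP`, `MP`, and `DP` modes listed above map onto the axes of a three-axis JAX device mesh. A minimal sketch of building one follows; the axis names match `get_mesh_names()` elsewhere in this commit, while the shape that places every device on the `fsdp` axis is purely illustrative.

```python
# Sketch: a ('dp', 'fsdp', 'mp') mesh; here all devices sit on 'fsdp'.
import jax
import numpy as np
from jax.sharding import Mesh

devices = np.array(jax.devices()).reshape(1, -1, 1)  # (dp, fsdp, mp)
mesh = Mesh(devices, axis_names=('dp', 'fsdp', 'mp'))
with mesh:
    pass  # pjit-compiled training or serving steps would run here
```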
4 changes: 2 additions & 2 deletions lib/python/EasyDel/__init__.py
@@ -4,7 +4,7 @@
FlaxGPTJForCausalLMModule, FlaxGPTJModel, FlaxGPTJForCausalLM, FlaxMptForCausalLM, MptConfig, FlaxMptModel, \
FlaxFalconForCausalLM, FlaxFalconModel, FalconConfig, FlaxGPTNeoXForCausalLM, GPTNeoXConfig, FlaxGPTNeoXModel, \
FlaxT5ForConditionalGeneration, FlaxT5Model, FlaxPalmForCausalLM, PalmModel, PalmConfig, T5Config, \
-FlaxOPTForCausalLM, FlaxOPTModel, OPTConfig
+FlaxOPTForCausalLM, FlaxOPTModel, OPTConfig, FlaxMistralModule, FlaxMistralForCausalLM, MistralConfig
from .trainer import TrainArguments, fsdp_train_step, get_training_modules, CausalLMTrainer

try:
@@ -29,4 +29,4 @@
"FlaxT5ForConditionalGeneration", "FlaxT5Model", \
"FlaxPalmForCausalLM", "PalmModel", "PalmConfig", "T5Config", \
"FlaxOPTForCausalLM", "FlaxOPTModel", "OPTConfig", "CausalLMTrainer", "LlamaConfig", "__version__", "JAXServer", \
"get_mesh", "PyTorchServer", "JaxServerConfig"
"get_mesh", "PyTorchServer", "JaxServerConfig", "FlaxMistralModule", "FlaxMistralForCausalLM", "MistralConfig"
4 changes: 3 additions & 1 deletion lib/python/EasyDel/modules/__init__.py
@@ -7,6 +7,7 @@
from .palm import PalmConfig, PalmModel, FlaxPalmForCausalLM
from .t5 import FlaxT5ForConditionalGeneration, FlaxT5Model, T5Config
from .opt import FlaxOPTForCausalLM, FlaxOPTModel, OPTConfig
+from .mistral import FlaxMistralModule, FlaxMistralForCausalLM, MistralConfig

__all__ = ['FlaxLlamaForCausalLM', 'FlaxLlamaModel',
'FlaxGPTJModule', 'FlaxGPTJForCausalLMModule', 'FlaxGPTJModel', 'FlaxGPTJForCausalLM', 'GPTJConfig',
@@ -16,5 +17,6 @@
"FlaxGPTNeoXForCausalLM", "GPTNeoXConfig", "FlaxGPTNeoXModel",
"FlaxT5ForConditionalGeneration", "FlaxT5Model",
"PalmConfig", "PalmModel", "FlaxPalmForCausalLM", 'T5Config',
"FlaxOPTForCausalLM", "FlaxOPTModel", "OPTConfig", "LlamaConfig"
"FlaxOPTForCausalLM", "FlaxOPTModel", "OPTConfig", "LlamaConfig",
"FlaxMistralModule", "FlaxMistralForCausalLM", "MistralConfig"
]
42 changes: 2 additions & 40 deletions lib/python/EasyDel/modules/falcon/modelling_falcon_flax.py
@@ -1,53 +1,15 @@
import math

from flax import linen as nn
from flax.core import FrozenDict
from typing import Optional, Dict, Union, Tuple
from transformers import FlaxPreTrainedModel, PretrainedConfig
from jax import numpy as jnp
import jax
from jax.interpreters import pxla
from jax.experimental.pjit import pjit, with_sharding_constraint as wsc
from jax.sharding import PartitionSpec
from transformers.modeling_flax_outputs import FlaxCausalLMOutput, FlaxBaseModelOutput
from jax.random import split, PRNGKey
from functools import partial
from einops import rearrange

-ACT2FN = {
-    "gelu": partial(nn.gelu, approximate=False),
-    "relu": nn.relu,
-    "silu": nn.swish,
-    "swish": nn.swish,
-    "gelu_new": partial(nn.gelu, approximate=True),
-
-}


-def get_names_from_parition_spec(partition_specs):
-    names = set()
-    if isinstance(partition_specs, dict):
-        partition_specs = partition_specs.values()
-    for item in partition_specs:
-        if item is None:
-            continue
-        elif isinstance(item, str):
-            names.add(item)
-        else:
-            names.update(get_names_from_parition_spec(item))
-
-    return list(names)
-
-
-def names_in_mesh(*names):
-    return set(names) <= set(pxla.thread_resources.env.physical_mesh.axis_names)
-
-
-def with_sharding_constraint(x, partition_specs):
-    axis_names = get_names_from_parition_spec(partition_specs)
-    if names_in_mesh(*axis_names):
-        x = wsc(x, partition_specs)
-    return x
+from ..flax_modelling_utils import get_gradient_checkpoint_policy, \
+    with_sharding_constraint


class FalconConfig(PretrainedConfig):
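Falcon now pulls these helpers from the shared module shown next. As a quick illustration, the shared `ACT2FN` table defined there can be exercised on its own; the absolute import path below mirrors the new file's location and is otherwise an assumption.

```python
# Sketch: resolve an activation function by name from the shared table.
import jax.numpy as jnp
from EasyDel.modules.flax_modelling_utils import ACT2FN

gelu_new = ACT2FN["gelu_new"]  # tanh-approximated GELU
y = gelu_new(jnp.linspace(-2.0, 2.0, 5))
```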
49 changes: 49 additions & 0 deletions lib/python/EasyDel/modules/flax_modelling_utils.py
@@ -0,0 +1,49 @@
from jax.interpreters import pxla
from jax.experimental.pjit import with_sharding_constraint as wsc
import jax
from flax import linen as nn
from functools import partial

ACT2FN = {
    "gelu": partial(nn.gelu, approximate=False),
    "relu": nn.relu,
    "silu": nn.swish,
    "swish": nn.swish,
    "gelu_new": partial(nn.gelu, approximate=True),
}


def get_names_from_partition_spec(partition_specs):
    # Collect every mesh-axis name mentioned anywhere in a (possibly nested)
    # PartitionSpec or dict of PartitionSpecs.
    names = set()
    if isinstance(partition_specs, dict):
        partition_specs = partition_specs.values()
    for item in partition_specs:
        if item is None:
            continue
        elif isinstance(item, str):
            names.add(item)
        else:
            names.update(get_names_from_partition_spec(item))

    return list(names)


def names_in_mesh(*names):
    # True when every given axis name exists in the currently active mesh.
    return set(names) <= set(pxla.thread_resources.env.physical_mesh.axis_names)


def with_sharding_constraint(x, partition_specs):
    # Apply the constraint only when the named axes are available, so the
    # call is a safe no-op outside a matching mesh.
    axis_names = get_names_from_partition_spec(partition_specs)
    if names_in_mesh(*axis_names):
        x = wsc(x, partition_specs)
    return x


def get_gradient_checkpoint_policy(name):
    # Map a string name onto the corresponding jax.checkpoint_policies entry.
    return {
        'everything_saveable': jax.checkpoint_policies.everything_saveable,
        'nothing_saveable': jax.checkpoint_policies.nothing_saveable,
        'checkpoint_dots': jax.checkpoint_policies.checkpoint_dots,
        'checkpoint_dots_with_no_batch_dims': jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims,
    }[name]
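Two short usage sketches for the helpers above. The first leans on the `names_in_mesh` guard: under a `('dp', 'fsdp', 'mp')` mesh the call constrains the activation's layout, and with no such mesh active it degrades to a no-op. The second wires a looked-up policy into `flax.linen.remat`; `MyBlock` is a hypothetical module, not part of this commit.

```python
import jax.numpy as jnp
from jax.sharding import PartitionSpec

# Constrain a [batch, seq, hidden] activation; silently skipped when the
# named axes are absent from the active mesh.
hidden = jnp.zeros((8, 128, 512))
hidden = with_sharding_constraint(hidden, PartitionSpec(('dp', 'fsdp'), None, 'mp'))

# Hypothetical module wrapped with a gradient-checkpointing policy.
RematBlock = nn.remat(MyBlock, policy=get_gradient_checkpoint_policy('nothing_saveable'))
```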
40 changes: 7 additions & 33 deletions lib/python/EasyDel/modules/gpt_j/modelling_gpt_j_flax.py
@@ -34,51 +34,25 @@
import jax
import jax.numpy as jnp
import numpy as np
-from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
+from flax.core.frozen_dict import FrozenDict, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax

from transformers.modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput
-from transformers.modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring
-from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging
+from transformers.modeling_flax_utils import ACT2FN, FlaxPreTrainedModel
+from transformers.utils import logging

-from transformers import PreTrainedTokenizer, TensorType, is_torch_available
from transformers.configuration_utils import PretrainedConfig
-from transformers.onnx import OnnxConfigWithPast, PatchingSpec
-from jax.interpreters import pxla
from fjutils.flash_attention import dot_product_attention_multihead
+from ..flax_modelling_utils import with_sharding_constraint

logger = logging.get_logger(__name__)


-def get_names_from_parition_spec(partition_specs):
-    names = set()
-    if isinstance(partition_specs, dict):
-        partition_specs = partition_specs.values()
-    for item in partition_specs:
-        if item is None:
-            continue
-        elif isinstance(item, str):
-            names.add(item)
-        else:
-            names.update(get_names_from_parition_spec(item))
-
-    return list(names)
-
-
-def names_in_mesh(*names):
-    return set(names) <= set(pxla.thread_resources.env.physical_mesh.axis_names)
-
-
-def with_sharding_constraint_(x, partition_specs):
-    axis_names = get_names_from_parition_spec(partition_specs)
-    if names_in_mesh(*axis_names):
-        x = wsc(x, partition_specs)
-    return x


class GPTJConfig(PretrainedConfig):
    model_type = "gptj"
    attribute_map = {
@@ -450,9 +424,9 @@ def __call__(

        # Force A local Sharding
        if self.config.use_pjit_attention_force:
-            query = with_sharding_constraint_(query, PartitionSpec(('dp', 'fsdp'), None, 'mp'))
-            key = with_sharding_constraint_(key, PartitionSpec(('dp', 'fsdp'), None, 'mp'))
-            value = with_sharding_constraint_(value, PartitionSpec(('dp', 'fsdp'), None, 'mp'))
+            query = with_sharding_constraint(query, PartitionSpec(('dp', 'fsdp'), None, 'mp'))
+            key = with_sharding_constraint(key, PartitionSpec(('dp', 'fsdp'), None, 'mp'))
+            value = with_sharding_constraint(value, PartitionSpec(('dp', 'fsdp'), None, 'mp'))

        query = self._split_heads(query)
        key = self._split_heads(key)
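Reading the constraint used on the query, key, and value tensors above: for a `[batch, seq, hidden]` array, `PartitionSpec(('dp', 'fsdp'), None, 'mp')` shards the batch dimension jointly across the `dp` and `fsdp` mesh axes, replicates the sequence dimension, and shards the hidden dimension across `mp`.

```python
from jax.sharding import PartitionSpec

# For a [batch, seq, hidden] tensor:
#   dim 0 -> split over 'dp' and 'fsdp' together
#   dim 1 -> replicated (None)
#   dim 2 -> split over 'mp'
qkv_spec = PartitionSpec(('dp', 'fsdp'), None, 'mp')
```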
38 changes: 4 additions & 34 deletions lib/python/EasyDel/modules/gpt_neo_x/modelling_gpt_neo_x_flax.py
@@ -13,40 +13,8 @@
from jax.random import split, PRNGKey
from functools import partial
from einops import rearrange

-ACT2FN = {
-    "gelu": partial(nn.gelu, approximate=False),
-    "relu": nn.relu,
-    "silu": nn.swish,
-    "swish": nn.swish,
-    "gelu_new": partial(nn.gelu, approximate=True),
-
-}
-
-
-def get_names_from_parition_spec(partition_specs):
-    names = set()
-    if isinstance(partition_specs, dict):
-        partition_specs = partition_specs.values()
-    for item in partition_specs:
-        if item is None:
-            continue
-        elif isinstance(item, str):
-            names.add(item)
-        else:
-            names.update(get_names_from_parition_spec(item))
-
-    return list(names)
-
-
-def with_sharding_constraint(x, partition_specs):
-    def names_in_mesh(*names):
-        return set(names) <= set(pxla.thread_resources.env.physical_mesh.axis_names)
-
-    axis_names = get_names_from_parition_spec(partition_specs)
-    if names_in_mesh(*axis_names):
-        x = wsc(x, partition_specs)
-    return x
+from ..flax_modelling_utils import get_gradient_checkpoint_policy, \
+    with_sharding_constraint, ACT2FN


class GPTNeoXConfig(PretrainedConfig):
@@ -128,10 +96,12 @@ def get_partition_rules(fully_fsdp: bool = False):
    @staticmethod
    def get_mesh_names():
        return 'dp', 'fsdp', 'mp'

    def add_jax_args(self):
        self.from_pt = False
        ...


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0,
                         dtype: jnp.dtype = jnp.bfloat16) -> jnp.ndarray:
    freqs = 1.0 / (theta ** (jnp.arange(0, dim, 2)[: (dim // 2)].astype(dtype) / dim))
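The hunk is cut off after the inverse-frequency line. In Llama-style rotary-embedding code this function typically continues by taking an outer product with the position indices and precomputing sine and cosine tables; the reconstruction below is a hedged sketch of that usual continuation, and its return packing is one common choice rather than necessarily this commit's.

```python
# Hedged sketch of how precompute_freqs_cis usually continues; the actual
# body is not shown in the diff above.
import jax.numpy as jnp

def precompute_freqs_cis_sketch(dim: int, end: int, theta: float = 10000.0,
                                dtype: jnp.dtype = jnp.bfloat16) -> jnp.ndarray:
    # Inverse frequency for each pair of embedding dimensions.
    freqs = 1.0 / (theta ** (jnp.arange(0, dim, 2)[: (dim // 2)].astype(dtype) / dim))
    t = jnp.arange(end)  # positions 0 .. end-1
    angles = jnp.outer(t, freqs).astype(dtype)  # [end, dim // 2]
    # Stack cos/sin so callers can rotate (x1, x2) feature pairs per position.
    return jnp.stack([jnp.cos(angles), jnp.sin(angles)], axis=-1)
```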
(The remaining changed files are not shown in this view.)