From c0e2100bb45cbc8982d0acc63a4290102ce6d766 Mon Sep 17 00:00:00 2001 From: Erfan Zare Chavoshi <59269023+erfanzar@users.noreply.github.com> Date: Mon, 2 Oct 2023 13:07:16 +0330 Subject: [PATCH 1/2] TODO Adding Mistral --- README.md | 3 ++- docs/Python/Models.md | 13 +++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index cf59c1648..884c7df60 100644 --- a/README.md +++ b/README.md @@ -48,7 +48,8 @@ pip install --upgrade pip # Note: wheels only available on linux. pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html ``` - +## TODO +#### Mistral Model will be supported as soon as possible ## Documentation Tadadad (Magic Sound) 💫 finally documents are ready at [EasyDel/Docs](https://erfanzar.github.io/EasyDeL/docs) diff --git a/docs/Python/Models.md b/docs/Python/Models.md index 0b015c48e..356b9a7b6 100644 --- a/docs/Python/Models.md +++ b/docs/Python/Models.md @@ -2,6 +2,19 @@ ## Available Models Are +0. **_Mistral_**: + + * Supports: + * Fully Sharded Data Parallel `(FSDP)` + * MultiProcessing `(MP)` + * Data Parallel `(DP)` + * Distributed Data Parallel (DDP) `(DP)` + * Gradient CheckPointing + * Flash Attention + * BlockWise Attention + * Usage and Import from EasyDel Library + + 1. **_[Llama](https://erfanzar.github.io/EasyDeL/docs/Python/Llama)_**: * Supports: From cc83767e4fd0f4a784e181307298a97f88cf214b Mon Sep 17 00:00:00 2001 From: Erfan Zare Chavoshi <59269023+erfanzar@users.noreply.github.com> Date: Mon, 2 Oct 2023 17:38:14 +0330 Subject: [PATCH 2/2] Preparing For Version 0.31.0 , `Mistral` Models Added --- README.md | 19 +- docs/Python/Models.md | 29 +- lib/python/EasyDel/__init__.py | 4 +- lib/python/EasyDel/modules/__init__.py | 4 +- .../modules/falcon/modelling_falcon_flax.py | 42 +- .../EasyDel/modules/flax_modelling_utils.py | 49 ++ .../modules/gpt_j/modelling_gpt_j_flax.py | 40 +- .../gpt_neo_x/modelling_gpt_neo_x_flax.py | 38 +- .../modules/llama/modelling_llama_flax.py | 48 +- .../EasyDel/modules/mistral/__init__.py | 1 + .../modules/mistral/modelling_mistral_flax.py | 781 ++++++++++++++++++ .../modules/mosaic_mpt/modelling_mpt_flax.py | 41 +- .../EasyDel/modules/opt/modelling_opt_flax.py | 40 +- .../modules/palm/modelling_palm_flax.py | 32 +- .../EasyDel/modules/t5/modelling_t5_flax.py | 27 +- requirements.txt | 2 +- 16 files changed, 890 insertions(+), 307 deletions(-) create mode 100644 lib/python/EasyDel/modules/flax_modelling_utils.py create mode 100644 lib/python/EasyDel/modules/mistral/__init__.py create mode 100644 lib/python/EasyDel/modules/mistral/modelling_mistral_flax.py diff --git a/README.md b/README.md index 884c7df60..f2cc79b97 100644 --- a/README.md +++ b/README.md @@ -2,13 +2,22 @@ EasyDeL (Easy Deep Learning) is an open-source library designed to accelerate and optimize the training process of machine learning models. This library is primarily focused on Jax/Flax and plans to offer easy and fine solutions to -train Flax/Jax Models on the `TPU/GPU` both for Serving and Training (EasyDel will support mojo and be rewriten for mojo too) +train Flax/Jax Models on the `TPU/GPU` both for Serving and Training (EasyDel will support mojo and be rewriten for mojo +too) ### EasyDel Mojo -EasyDel Mojo differs from EasyDel in Python in significant ways. In Python, you can leverage a vast array of packages to create a mid or high-level API in no time. However, when working with Mojo, it's a different story. 
Here, you have to build some of the features that other Python libraries provide, such as Jax for arrays and computations. But why not import numpy, Jax, and other similar packages to Mojo and use them? +EasyDel Mojo differs from EasyDel in Python in significant ways. In Python, you can leverage a vast array of packages to +create a mid or high-level API in no time. However, when working with Mojo, it's a different story. Here, you have to +build some of the features that other Python libraries provide, such as Jax for arrays and computations. But why not +import numpy, Jax, and other similar packages to Mojo and use them? -There are several reasons why building packages in Mojo is more efficient than importing them from Python. Firstly, when you import packages from Python, you incur the overhead of translating and processing the Python code into Mojo code, which takes time. Secondly, the Python code may not be optimized for the Mojo runtime environment, leading to slower performance. Lastly, building packages directly in Mojo allows you to design and optimize them explicitly for the Mojo runtime environment, resulting in faster and more efficient code. With Mojo's built-in array capabilities that are 35000x faster than Python, it's time to take your coding to the next level. +There are several reasons why building packages in Mojo is more efficient than importing them from Python. Firstly, when +you import packages from Python, you incur the overhead of translating and processing the Python code into Mojo code, +which takes time. Secondly, the Python code may not be optimized for the Mojo runtime environment, leading to slower +performance. Lastly, building packages directly in Mojo allows you to design and optimize them explicitly for the Mojo +runtime environment, resulting in faster and more efficient code. With Mojo's built-in array capabilities that are +35000x faster than Python, it's time to take your coding to the next level. [Read More ...](https://github.com/erfanzar/EasyDeL/blob/main/lib/mojo/README.md) @@ -48,8 +57,7 @@ pip install --upgrade pip # Note: wheels only available on linux. pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html ``` -## TODO -#### Mistral Model will be supported as soon as possible + ## Documentation Tadadad (Magic Sound) 💫 finally documents are ready at [EasyDel/Docs](https://erfanzar.github.io/EasyDeL/docs) @@ -76,6 +84,7 @@ _Tutorials on how to use and train or serve your models with EasyDel is availabl ## Available Models Are +- **_Mistral_** (Support `FSDP`, `MP`,` DP`)(_Supports gradient checkpointing_) - **_Llama_** (Support `FSDP`, `MP`,` DP`)(_Supports gradient checkpointing_) - **_Llama2_** (Support `FSDP`, `MP`,` DP`)(_Supports gradient checkpointing_) - **_GPT-J_** (Support `FSDP`, `MP`,` DP`)(_Supports gradient checkpointing_) diff --git a/docs/Python/Models.md b/docs/Python/Models.md index 356b9a7b6..d265b8fd2 100644 --- a/docs/Python/Models.md +++ b/docs/Python/Models.md @@ -2,7 +2,7 @@ ## Available Models Are -0. **_Mistral_**: +1. **_Mistral_**: * Supports: * Fully Sharded Data Parallel `(FSDP)` @@ -10,12 +10,13 @@ * Data Parallel `(DP)` * Distributed Data Parallel (DDP) `(DP)` * Gradient CheckPointing - * Flash Attention - * BlockWise Attention * Usage and Import from EasyDel Library - -1. **_[Llama](https://erfanzar.github.io/EasyDeL/docs/Python/Llama)_**: +[//]: # ( * Flash Attention) + +[//]: # ( * BlockWise Attention) + +2. 
**_[Llama](https://erfanzar.github.io/EasyDeL/docs/Python/Llama)_**: * Supports: * Fully Sharded Data Parallel `(FSDP)` @@ -29,7 +30,7 @@ * [Usage](https://erfanzar.github.io/EasyDeL/docs/Python/Llama) -2. **_[Llama2](https://erfanzar.github.io/EasyDeL/docs/Python/Llama2)_**: +3. **_[Llama2](https://erfanzar.github.io/EasyDeL/docs/Python/Llama2)_**: * Supports: * Fully Sharded Data Parallel `(FSDP)` @@ -42,7 +43,7 @@ * [Usage](https://erfanzar.github.io/EasyDeL/docs/Python/Llama2) -3. **_[Falcon](https://erfanzar.github.io/EasyDeL/docs/Python/Falcon)_**: +4. **_[Falcon](https://erfanzar.github.io/EasyDeL/docs/Python/Falcon)_**: * Supports: * Fully Sharded Data Parallel `(FSDP)` @@ -55,7 +56,7 @@ * [Usage](https://erfanzar.github.io/EasyDeL/docs/Python/Falcon) -4. **_[MosaicMPT](https://erfanzar.github.io/EasyDeL/docs/Python/MosaicMPT)_**: +5. **_[MosaicMPT](https://erfanzar.github.io/EasyDeL/docs/Python/MosaicMPT)_**: * Supports: * Fully Sharded Data Parallel `(FSDP)` @@ -68,7 +69,7 @@ * [Usage](https://erfanzar.github.io/EasyDeL/docs/Python/MosaicMPT) -5. **_GPTNeoX_** : +6. **_GPTNeoX_** : * Supports: * Fully Sharded Data Parallel `(FSDP)` @@ -78,7 +79,7 @@ * Gradient CheckPointing -6. **_LT_** : +7. **_LT_** : * Supports: * Fully Sharded Data Parallel `(FSDP)` @@ -87,7 +88,7 @@ * Distributed Data Parallel (DDP) `(DP)` * Gradient CheckPointing -7. **_Palm_**: +8. **_Palm_**: * Supports: * Fully Sharded Data Parallel `(FSDP)` @@ -97,7 +98,7 @@ * Gradient CheckPointing -8. **_T5_**: +9. **_T5_**: * Supports: * Fully Sharded Data Parallel `(FSDP)` @@ -106,7 +107,7 @@ * Distributed Data Parallel (DDP) `(DP)` * Gradient CheckPointing -9. **_GPT-J_** : +10. **_GPT-J_** : * Supports: * Fully Sharded Data Parallel `(FSDP)` @@ -117,7 +118,7 @@ * Flash Attention * BlockWise Attention -10. **_OPT_**: +11. 
**_OPT_**: * Supports: * Fully Sharded Data Parallel `(FSDP)` diff --git a/lib/python/EasyDel/__init__.py b/lib/python/EasyDel/__init__.py index 10bbf1562..e31ac4a05 100644 --- a/lib/python/EasyDel/__init__.py +++ b/lib/python/EasyDel/__init__.py @@ -4,7 +4,7 @@ FlaxGPTJForCausalLMModule, FlaxGPTJModel, FlaxGPTJForCausalLM, FlaxMptForCausalLM, MptConfig, FlaxMptModel, \ FlaxFalconForCausalLM, FlaxFalconModel, FalconConfig, FlaxGPTNeoXForCausalLM, GPTNeoXConfig, FlaxGPTNeoXModel, \ FlaxT5ForConditionalGeneration, FlaxT5Model, FlaxPalmForCausalLM, PalmModel, PalmConfig, T5Config, \ - FlaxOPTForCausalLM, FlaxOPTModel, OPTConfig + FlaxOPTForCausalLM, FlaxOPTModel, OPTConfig, FlaxMistralModule, FlaxMistralForCausalLM, MistralConfig from .trainer import TrainArguments, fsdp_train_step, get_training_modules, CausalLMTrainer try: @@ -29,4 +29,4 @@ "FlaxT5ForConditionalGeneration", "FlaxT5Model", \ "FlaxPalmForCausalLM", "PalmModel", "PalmConfig", "T5Config", \ "FlaxOPTForCausalLM", "FlaxOPTModel", "OPTConfig", "CausalLMTrainer", "LlamaConfig", "__version__", "JAXServer", \ - "get_mesh", "PyTorchServer", "JaxServerConfig" + "get_mesh", "PyTorchServer", "JaxServerConfig", "FlaxMistralModule", "FlaxMistralForCausalLM", "MistralConfig" diff --git a/lib/python/EasyDel/modules/__init__.py b/lib/python/EasyDel/modules/__init__.py index 3691caf62..fdbc8c4bf 100644 --- a/lib/python/EasyDel/modules/__init__.py +++ b/lib/python/EasyDel/modules/__init__.py @@ -7,6 +7,7 @@ from .palm import PalmConfig, PalmModel, FlaxPalmForCausalLM from .t5 import FlaxT5ForConditionalGeneration, FlaxT5Model, T5Config from .opt import FlaxOPTForCausalLM, FlaxOPTModel, OPTConfig +from .mistral import FlaxMistralModule, FlaxMistralForCausalLM, MistralConfig __all__ = ['FlaxLlamaForCausalLM', 'FlaxLlamaModel', 'FlaxGPTJModule', 'FlaxGPTJForCausalLMModule', 'FlaxGPTJModel', 'FlaxGPTJForCausalLM', 'GPTJConfig', @@ -16,5 +17,6 @@ "FlaxGPTNeoXForCausalLM", "GPTNeoXConfig", "FlaxGPTNeoXModel", "FlaxT5ForConditionalGeneration", "FlaxT5Model", "PalmConfig", "PalmModel", "FlaxPalmForCausalLM", 'T5Config', - "FlaxOPTForCausalLM", "FlaxOPTModel", "OPTConfig", "LlamaConfig" + "FlaxOPTForCausalLM", "FlaxOPTModel", "OPTConfig", "LlamaConfig", + "FlaxMistralModule", "FlaxMistralForCausalLM", "MistralConfig" ] diff --git a/lib/python/EasyDel/modules/falcon/modelling_falcon_flax.py b/lib/python/EasyDel/modules/falcon/modelling_falcon_flax.py index b515469ca..f2076cd17 100644 --- a/lib/python/EasyDel/modules/falcon/modelling_falcon_flax.py +++ b/lib/python/EasyDel/modules/falcon/modelling_falcon_flax.py @@ -1,53 +1,15 @@ import math - from flax import linen as nn from flax.core import FrozenDict from typing import Optional, Dict, Union, Tuple from transformers import FlaxPreTrainedModel, PretrainedConfig from jax import numpy as jnp import jax -from jax.interpreters import pxla -from jax.experimental.pjit import pjit, with_sharding_constraint as wsc from jax.sharding import PartitionSpec from transformers.modeling_flax_outputs import FlaxCausalLMOutput, FlaxBaseModelOutput -from jax.random import split, PRNGKey -from functools import partial from einops import rearrange - -ACT2FN = { - "gelu": partial(nn.gelu, approximate=False), - "relu": nn.relu, - "silu": nn.swish, - "swish": nn.swish, - "gelu_new": partial(nn.gelu, approximate=True), - -} - - -def get_names_from_parition_spec(partition_specs): - names = set() - if isinstance(partition_specs, dict): - partition_specs = partition_specs.values() - for item in partition_specs: - if item is None: 
- continue - elif isinstance(item, str): - names.add(item) - else: - names.update(get_names_from_parition_spec(item)) - - return list(names) - - -def names_in_mesh(*names): - return set(names) <= set(pxla.thread_resources.env.physical_mesh.axis_names) - - -def with_sharding_constraint(x, partition_specs): - axis_names = get_names_from_parition_spec(partition_specs) - if names_in_mesh(*axis_names): - x = wsc(x, partition_specs) - return x +from ..flax_modelling_utils import get_gradient_checkpoint_policy, \ + with_sharding_constraint class FalconConfig(PretrainedConfig): diff --git a/lib/python/EasyDel/modules/flax_modelling_utils.py b/lib/python/EasyDel/modules/flax_modelling_utils.py new file mode 100644 index 000000000..412139da1 --- /dev/null +++ b/lib/python/EasyDel/modules/flax_modelling_utils.py @@ -0,0 +1,49 @@ +from jax.interpreters import pxla +from jax.experimental.pjit import with_sharding_constraint as wsc +import jax +from flax import linen as nn +from functools import partial + +ACT2FN = { + "gelu": partial(nn.gelu, approximate=False), + "relu": nn.relu, + "silu": nn.swish, + "swish": nn.swish, + "gelu_new": partial(nn.gelu, approximate=True), + +} + + +def get_names_from_partition_spec(partition_specs): + names = set() + if isinstance(partition_specs, dict): + partition_specs = partition_specs.values() + for item in partition_specs: + if item is None: + continue + elif isinstance(item, str): + names.add(item) + else: + names.update(get_names_from_partition_spec(item)) + + return list(names) + + +def names_in_mesh(*names): + return set(names) <= set(pxla.thread_resources.env.physical_mesh.axis_names) + + +def with_sharding_constraint(x, partition_specs): + axis_names = get_names_from_partition_spec(partition_specs) + if names_in_mesh(*axis_names): + x = wsc(x, partition_specs) + return x + + +def get_gradient_checkpoint_policy(name): + return { + 'everything_saveable': jax.checkpoint_policies.everything_saveable, + 'nothing_saveable': jax.checkpoint_policies.nothing_saveable, + 'checkpoint_dots': jax.checkpoint_policies.checkpoint_dots, + 'checkpoint_dots_with_no_batch_dims': jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims, + }[name] diff --git a/lib/python/EasyDel/modules/gpt_j/modelling_gpt_j_flax.py b/lib/python/EasyDel/modules/gpt_j/modelling_gpt_j_flax.py index c05313308..63d920363 100644 --- a/lib/python/EasyDel/modules/gpt_j/modelling_gpt_j_flax.py +++ b/lib/python/EasyDel/modules/gpt_j/modelling_gpt_j_flax.py @@ -34,51 +34,25 @@ import jax import jax.numpy as jnp import numpy as np -from flax.core.frozen_dict import FrozenDict, freeze, unfreeze +from flax.core.frozen_dict import FrozenDict, unfreeze from flax.linen import combine_masks, make_causal_mask from flax.linen.attention import dot_product_attention_weights -from flax.traverse_util import flatten_dict, unflatten_dict from jax import lax from transformers.modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput -from transformers.modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring -from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging +from transformers.modeling_flax_utils import ACT2FN, FlaxPreTrainedModel +from transformers.utils import logging from transformers import PreTrainedTokenizer, TensorType, is_torch_available from transformers.configuration_utils import PretrainedConfig from transformers.onnx import OnnxConfigWithPast, PatchingSpec from jax.interpreters import pxla from fjutils.flash_attention 
import dot_product_attention_multihead +from ..flax_modelling_utils import with_sharding_constraint logger = logging.get_logger(__name__) -def get_names_from_parition_spec(partition_specs): - names = set() - if isinstance(partition_specs, dict): - partition_specs = partition_specs.values() - for item in partition_specs: - if item is None: - continue - elif isinstance(item, str): - names.add(item) - else: - names.update(get_names_from_parition_spec(item)) - - return list(names) - - -def names_in_mesh(*names): - return set(names) <= set(pxla.thread_resources.env.physical_mesh.axis_names) - - -def with_sharding_constraint_(x, partition_specs): - axis_names = get_names_from_parition_spec(partition_specs) - if names_in_mesh(*axis_names): - x = wsc(x, partition_specs) - return x - - class GPTJConfig(PretrainedConfig): model_type = "gptj" attribute_map = { @@ -450,9 +424,9 @@ def __call__( # Force A local Sharding if self.config.use_pjit_attention_force: - query = with_sharding_constraint_(query, PartitionSpec(('dp', 'fsdp'), None, 'mp')) - key = with_sharding_constraint_(key, PartitionSpec(('dp', 'fsdp'), None, 'mp')) - value = with_sharding_constraint_(value, PartitionSpec(('dp', 'fsdp'), None, 'mp')) + query = with_sharding_constraint(query, PartitionSpec(('dp', 'fsdp'), None, 'mp')) + key = with_sharding_constraint(key, PartitionSpec(('dp', 'fsdp'), None, 'mp')) + value = with_sharding_constraint(value, PartitionSpec(('dp', 'fsdp'), None, 'mp')) query = self._split_heads(query) key = self._split_heads(key) diff --git a/lib/python/EasyDel/modules/gpt_neo_x/modelling_gpt_neo_x_flax.py b/lib/python/EasyDel/modules/gpt_neo_x/modelling_gpt_neo_x_flax.py index ffc8c27b4..8b1a045f6 100644 --- a/lib/python/EasyDel/modules/gpt_neo_x/modelling_gpt_neo_x_flax.py +++ b/lib/python/EasyDel/modules/gpt_neo_x/modelling_gpt_neo_x_flax.py @@ -13,40 +13,8 @@ from jax.random import split, PRNGKey from functools import partial from einops import rearrange - -ACT2FN = { - "gelu": partial(nn.gelu, approximate=False), - "relu": nn.relu, - "silu": nn.swish, - "swish": nn.swish, - "gelu_new": partial(nn.gelu, approximate=True), - -} - - -def get_names_from_parition_spec(partition_specs): - names = set() - if isinstance(partition_specs, dict): - partition_specs = partition_specs.values() - for item in partition_specs: - if item is None: - continue - elif isinstance(item, str): - names.add(item) - else: - names.update(get_names_from_parition_spec(item)) - - return list(names) - - -def with_sharding_constraint(x, partition_specs): - def names_in_mesh(*names): - return set(names) <= set(pxla.thread_resources.env.physical_mesh.axis_names) - - axis_names = get_names_from_parition_spec(partition_specs) - if names_in_mesh(*axis_names): - x = wsc(x, partition_specs) - return x +from ..flax_modelling_utils import get_gradient_checkpoint_policy, \ + with_sharding_constraint, ACT2FN class GPTNeoXConfig(PretrainedConfig): @@ -128,10 +96,12 @@ def get_partition_rules(fully_fsdp: bool = False): @staticmethod def get_mesh_names(): return 'dp', 'fsdp', 'mp' + def add_jax_args(self): self.from_pt = False ... 
+ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, dtype: jnp.dtype = jnp.bfloat16) -> jnp.ndarray: freqs = 1.0 / (theta ** (jnp.arange(0, dim, 2)[: (dim // 2)].astype(dtype) / dim)) diff --git a/lib/python/EasyDel/modules/llama/modelling_llama_flax.py b/lib/python/EasyDel/modules/llama/modelling_llama_flax.py index d418aaec7..c180923df 100644 --- a/lib/python/EasyDel/modules/llama/modelling_llama_flax.py +++ b/lib/python/EasyDel/modules/llama/modelling_llama_flax.py @@ -1,10 +1,6 @@ -import math from typing import Dict, Optional, Tuple, Union - -import fjutils.easylm from einops import einops from flax.linen import remat -from einops import einsum import jax import jax.numpy as jnp from jax import lax @@ -13,54 +9,14 @@ from flax.linen.attention import dot_product_attention_weights from flax.traverse_util import flatten_dict, unflatten_dict from flax.linen import partitioning as nn_partitioning - from flax.core.frozen_dict import FrozenDict, freeze, unfreeze from flax.linen import combine_masks, make_causal_mask - from transformers.configuration_utils import PretrainedConfig from transformers.modeling_flax_utils import FlaxPreTrainedModel - -from jax.interpreters import pxla from transformers.modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput, FlaxSequenceClassifierOutput - -from jax.experimental.pjit import with_sharding_constraint as wsc -from fjutils import dot_product_attention_multihead from fjutils.easylm import blockwise_dot_product_attention - - -def get_names_from_partition_spec(partition_specs): - names = set() - if isinstance(partition_specs, dict): - partition_specs = partition_specs.values() - for item in partition_specs: - if item is None: - continue - elif isinstance(item, str): - names.add(item) - else: - names.update(get_names_from_partition_spec(item)) - - return list(names) - - -def names_in_mesh(*names): - return set(names) <= set(pxla.thread_resources.env.physical_mesh.axis_names) - - -def with_sharding_constraint(x, partition_specs): - axis_names = get_names_from_partition_spec(partition_specs) - if names_in_mesh(*axis_names): - x = wsc(x, partition_specs) - return x - - -def get_gradient_checkpoint_policy(name): - return { - 'everything_saveable': jax.checkpoint_policies.everything_saveable, - 'nothing_saveable': jax.checkpoint_policies.nothing_saveable, - 'checkpoint_dots': jax.checkpoint_policies.checkpoint_dots, - 'checkpoint_dots_with_no_batch_dims': jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims, - }[name] +from ..flax_modelling_utils import with_sharding_constraint, \ + get_gradient_checkpoint_policy class LlamaConfig(PretrainedConfig): diff --git a/lib/python/EasyDel/modules/mistral/__init__.py b/lib/python/EasyDel/modules/mistral/__init__.py new file mode 100644 index 000000000..9332f25df --- /dev/null +++ b/lib/python/EasyDel/modules/mistral/__init__.py @@ -0,0 +1 @@ +from .modelling_mistral_flax import FlaxMistralModule, FlaxMistralForCausalLM, MistralConfig diff --git a/lib/python/EasyDel/modules/mistral/modelling_mistral_flax.py b/lib/python/EasyDel/modules/mistral/modelling_mistral_flax.py new file mode 100644 index 000000000..7a111e6dc --- /dev/null +++ b/lib/python/EasyDel/modules/mistral/modelling_mistral_flax.py @@ -0,0 +1,781 @@ +import functools + +import flax.core +from jax import jit, random, grad, numpy as jnp, Array +from jax.sharding import PartitionSpec as PS +import jax +from flax import linen as nn +from flax.traverse_util import unflatten_dict, flatten_dict +from flax.core import freeze, 
unfreeze +from typing import Union, Optional, Tuple, Dict, List +from transformers import PretrainedConfig, FlaxPreTrainedModel +from flax.linen import partitioning as nn_partitioning +from transformers.modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput + +from ..flax_modelling_utils import ACT2FN, with_sharding_constraint + + +class MistralConfig(PretrainedConfig): + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=14336, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=8, + hidden_act="silu", + max_position_embeddings=4096 * 32, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=False, + rope_theta=10000.0, + sliding_window=4096, + gradient_checkpointing: str = 'nothing_saveable', + use_pjit_attention_force: bool = True, + use_flash_attention: bool = False, + use_sacn_mlp: bool = False, + flash_attn_query_chunk_size: int = 1024, + flash_attn_key_chunk_size: int = 1024, + scan_mlp_chunk_size: int = 1024, + number_rep_kv: int = 1, + attn_pdrop: float = 0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.sliding_window = sliding_window + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.use_flash_attention = use_flash_attention + self.number_rep_kv = number_rep_kv + self.gradient_checkpointing = gradient_checkpointing + self.use_pjit_attention_force = use_pjit_attention_force + self.use_sacn_mlp = use_sacn_mlp + self.flash_attn_query_chunk_size = flash_attn_query_chunk_size + self.flash_attn_key_chunk_size = flash_attn_key_chunk_size + self.scan_mlp_chunk_size = scan_mlp_chunk_size + self.attn_pdrop = attn_pdrop + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + def get_partition_rules(self, fully_fsdp: bool = True): + ... 
+ + def add_jax_args(self, + gradient_checkpointing: str = 'nothing_saveable', + use_pjit_attention_force: bool = True, + use_flash_attention: bool = False, + use_sacn_mlp: bool = False, + flash_attn_query_chunk_size: int = 1024, + flash_attn_key_chunk_size: int = 1024, + scan_mlp_chunk_size: int = 1024, + number_rep_kv: int = 1, + attn_pdrop: float = 0.0, + ): + self.use_flash_attention = use_flash_attention + self.number_rep_kv = number_rep_kv + self.gradient_checkpointing = gradient_checkpointing + self.use_pjit_attention_force = use_pjit_attention_force + self.use_sacn_mlp = use_sacn_mlp + self.flash_attn_query_chunk_size = flash_attn_query_chunk_size + self.flash_attn_key_chunk_size = flash_attn_key_chunk_size + self.scan_mlp_chunk_size = scan_mlp_chunk_size + self.attn_pdrop = attn_pdrop + + @staticmethod + def get_weight_decay_exclusions(): + return tuple() + + @staticmethod + def rng_keys(): + return ('params', 'dropout', 'fcm') + + +remat = nn_partitioning.remat + + +def repeat_kv(x: jax.Array, n_rep: int) -> jax.Array: + bs, s, n_kv_heads, head_dim = x.shape + if n_rep == 1: + return x + + return jnp.expand_dims(x[:, :, :, None, :], (bs, s, n_kv_heads, n_rep, head_dim)).reshape(bs, s, n_kv_heads * n_rep, + head_dim) + + +class MistralRMSNorm(nn.Module): + dim: int + eps: float = 1e-6 + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + + def setup(self) -> None: + self.weight = self.param( + 'kernel', + nn.initializers.ones, + (self.dim,), + self.param_dtype, + ) + + def _norm(self, x: jnp.ndarray) -> jnp.ndarray: + return x * jax.lax.rsqrt(jnp.square(x).mean(-1, keepdims=True) + self.eps) + + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + x = x.astype(jnp.promote_types(self.dtype, jnp.float32)) + output = self._norm(x).astype(self.dtype) + weight = jnp.asarray(self.weight, self.dtype) + return output * weight + + +def precompute_freq_cis( + method: Union[str, None], + dim: int, end: int, theta: float = 10000.0, + scaling_factor: float = 8., **kwargs) -> jnp.ndarray: + freq = 1.0 / (theta ** (jnp.arange(0, dim, 2)[: (dim // 2)].astype(jnp.float32) / dim)) + t = jnp.arange(end) + + if method is not None: + if method == 'linear': + t = t / scaling_factor + elif method == 'dynamic': + base = theta * ( + (scaling_factor * end / end) - (scaling_factor - 1) + ) ** (dim / (dim - 2)) + freq = 1.0 / (base ** (jnp.arange(0, dim, 2) / dim)) + else: + raise ValueError(f'unknown {method} method for precompute_freq_cis') + + freq = jnp.outer(t, freq).astype(jnp.float32) + sin, cos = jnp.sin(freq).astype(jnp.float32), jnp.cos(freq).astype(jnp.float32) + freq_cis = jnp.complex64(cos + 1j * sin) + return jnp.asarray(freq_cis) + + +def apply_rotary_emb( + xq: jnp.ndarray, + xk: jnp.ndarray, + freq_cis: jnp.ndarray, + dtype: jnp.dtype = jnp.float32, +) -> Tuple[jnp.ndarray, jnp.ndarray]: + reshape_xq = xq.astype(jnp.float32).reshape(*xq.shape[:-1], -1, 2) + reshape_xk = xk.astype(jnp.float32).reshape(*xk.shape[:-1], -1, 2) + + xq_ = jax.lax.complex(reshape_xq[..., 0], reshape_xq[..., 1]) + xk_ = jax.lax.complex(reshape_xk[..., 0], reshape_xk[..., 1]) + + # add head dim + freq_cis = jnp.reshape(freq_cis, (*freq_cis.shape[:2], 1, *freq_cis.shape[2:])) + + xq_out = xq_ * freq_cis + xq_out = jnp.stack((jnp.real(xq_out), jnp.imag(xq_out)), axis=-1).reshape(*xq_out.shape[:-1], -1) + + xk_out = xk_ * freq_cis + xk_out = jnp.stack((jnp.real(xk_out), jnp.imag(xk_out)), axis=-1).reshape(*xk_out.shape[:-1], -1) + + return xq_out.astype(dtype), xk_out.astype(dtype) + + +class 
FlaxMistralMLP(nn.Module): + config: MistralConfig + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[Union[None, jax.lax.Precision]] = jax.lax.Precision('fastest') + + def setup(self) -> None: + dense = functools.partial( + nn.Dense, + use_bias=False, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + kernel_init=nn.initializers.normal() + ) + self.gate_proj = dense(self.config.intermediate_size) + self.up_proj = dense(self.config.intermediate_size) + self.down_proj = dense(self.config.hidden_size) + self.act_fn = ACT2FN[self.config.hidden_act] + + def __call__(self, x: jax.Array): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +class FlaxMistralAttention(nn.Module): + config: MistralConfig + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[Union[None, jax.lax.Precision]] = jax.lax.Precision('fastest') + + def setup(self) -> None: + config = self.config + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + + dense = functools.partial( + nn.Dense, + use_bias=False, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + kernel_init=nn.initializers.normal() + ) + + self.q_proj = dense(self.num_heads * self.head_dim) + self.k_proj = dense(self.num_key_value_heads * self.head_dim) + self.v_proj = dense(self.num_key_value_heads * self.head_dim) + self.o_proj = dense(self.hidden_size) + + @nn.compact + def concatenate_to_cache_(self, q: jax.Array, k: jax.Array, v: jax.Array, attention_mask: jax.Array): + is_cache_available = self.has_variable('cache', 'key') + key_cache = self.variable('cache', 'key', jnp.zeros, k.shape, k.dtype) + value_cache = self.variable('cache', 'value', jnp.zeros, k.shape, v.dtype) + index_cache = self.variable('cache', 'index', lambda: jnp.array(0, dtype=jnp.int32)) + if is_cache_available: + *bd, ml, nh, dph = key_cache.value.shape + indices = (0,) * len(bd) + (index_cache.value, 0, 0) + k = jax.lax.dynamic_update_slice(key_cache.value, k, indices) + v = jax.lax.dynamic_update_slice(value_cache.value, v, indices) + key_cache.value = k + value_cache.value = v + num_updated_cache_vector = q.shape[1] + index_cache.value = index_cache.value + num_updated_cache_vector + pad_mask = jnp.broadcast_to( + jnp.arange(ml) < index_cache.value, + tuple(bd) + (1, num_updated_cache_vector, ml) + ) + attention_mask = nn.combine_masks(pad_mask, attention_mask) + return q, k, v, attention_mask + + def __call__( + self, + hidden_state: jax.Array, + freq_cis: jax.Array, + attention_mask: jax.Array, + causal_mask: jax.Array, + position_ids: jax.Array, + deterministic: bool = True, + init_cache: bool = False, + output_attentions: bool = True + ): + batch_size, max_sequence_length = hidden_state.shape[:2] + q, k, v = self.q_proj(hidden_state), self.k_proj(hidden_state), self.v_proj(hidden_state) + + if self.config.use_pjit_attention_force: + q = with_sharding_constraint(q, PS('fsdp', 'mp', None)) + k = with_sharding_constraint(k, PS('fsdp', 'mp', None)) + v = with_sharding_constraint(v, PS('fsdp', 'mp', None)) + + q = q.reshape(batch_size, max_sequence_length, -1, self.head_dim) + k = k.reshape(batch_size, max_sequence_length, -1, self.head_dim) + 
v = v.reshape(batch_size, max_sequence_length, -1, self.head_dim) + + freq_cis = jnp.take(freq_cis, position_ids, axis=0) + q, k = apply_rotary_emb(q, k, freq_cis=freq_cis, dtype=self.dtype) + k = repeat_kv(k, self.num_key_value_groups) + v = repeat_kv(v, self.num_key_value_groups) + if self.has_variable('cache', 'key') or init_cache: + q, k, v, attention_mask = self.concatenate_to_cache_(q, k, v, attention_mask) + + q_l, k_l = q.shape[1], k.shape[1] + + if self.has_variable('cache', 'key'): + mask_shift: int = self.variables['cache']['index'] + dl = self.variables['cache']['key'].shape[1] + causal_mask = jax.lax.dynamic_slice( + causal_mask, (0, 0, mask_shift, 0), (1, 1, q_l, dl) + ) + else: + causal_mask = causal_mask[:, :, :q_l, :k_l] + + causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:]) + attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) + attention_mask = nn.combine_masks(attention_mask, causal_mask) + attention_bias = jax.lax.select( + attention_mask > 0, + jnp.full(attention_mask.shape, 0.0).astype(self.dtype), + jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype), + ) + + attn_weight = nn.dot_product_attention_weights( + query=q, + key=k, + bias=attention_bias, + dtype=jnp.promote_types(self.dtype, jnp.float32), + deterministic=deterministic, + dropout_rate=self.config.attn_pdrop, + precision=self.precision + ) + if self.config.use_pjit_attention_force: + attn_weight = with_sharding_constraint(attn_weight, PS(("dp", "fsdp"), "mp", None, None)) + attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weight, v) + attn_output = attn_output.reshape(attn_output.shape[:2] + (self.hidden_size,)) + outputs = (attn_output, attn_weight) if output_attentions else (attn_output,) + return outputs + + +class FlaxMistralDecoderLayer(nn.Module): + config: MistralConfig + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[Union[None, jax.lax.Precision]] = jax.lax.Precision('fastest') + + def setup(self) -> None: + self.self_attn = FlaxMistralAttention( + config=self.config, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision + ) + self.mlp = FlaxMistralMLP( + config=self.config, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision + ) + self.input_layernorm = MistralRMSNorm( + dim=self.config.hidden_size, + eps=self.config.rms_norm_eps, + dtype=self.dtype, + param_dtype=self.param_dtype + ) + self.post_attention_layernorm = MistralRMSNorm( + dim=self.config.hidden_size, + eps=self.config.rms_norm_eps, + dtype=self.dtype, + param_dtype=self.param_dtype + ) + + def __call__( + self, + hidden_state: jax.Array, + freq_cis: jax.Array, + attention_mask: jax.Array, + causal_mask: jax.Array, + position_ids: jax.Array, + deterministic: bool = True, + init_cache: bool = False, + output_attentions: bool = True + ): + residual = hidden_state + attention_output = self.self_attn( + hidden_state=self.input_layernorm(hidden_state), + freq_cis=freq_cis, + attention_mask=attention_mask, + causal_mask=causal_mask, + position_ids=position_ids, + deterministic=deterministic, + init_cache=init_cache, + output_attentions=output_attentions + ) + hidden_state = attention_output[0] + residual + + hidden_state = self.mlp(self.post_attention_layernorm(hidden_state)) + hidden_state + outputs = (hidden_state,) + if output_attentions: + outputs += attention_output[1] + return outputs + + +class 
FlaxMistralPretrainedModel(FlaxPreTrainedModel): + config_class = MistralConfig + base_model_prefix = 'mistral' + module_class: nn.Module = None + + def __init__(self, + config: MistralConfig, + input_shape: Tuple = (1, 1), + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, + **kwargs + ): + module = self.module_class(config=config, dtype=dtype, **kwargs) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + + def init_weights( + self, + rng: jax.random.PRNGKey, + input_shape: Tuple, + params: flax.core.FrozenDict = None + ) -> flax.core.FrozenDict: + + input_ids = jnp.zeros(input_shape, dtype="i4") + attention_mask = jnp.ones_like(input_ids) + position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape) + params_rng, dropout_rng = jax.random.split(rng) + rng_s = {"params": params_rng, "dropout": dropout_rng} + + if self.config.add_cross_attention: + encoder_hidden_states = jnp.zeros(input_shape + (self.config.hidden_size,)) + encoder_attention_mask = attention_mask + module_init_outputs = self.module.init( + rng_s, + input_ids, + attention_mask, + position_ids, + encoder_hidden_states, + encoder_attention_mask, + return_dict=False, + ) + else: + module_init_outputs = self.module.init(rng_s, input_ids, attention_mask, position_ids, return_dict=False) + + random_params = module_init_outputs["params"] + + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + + def init_cache(self, batch_size, max_length): + + input_ids = jnp.ones((batch_size, max_length)) + attention_mask = jnp.ones_like(input_ids) + position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) + + init_variables = self.module.init( + jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True + ) + return init_variables["cache"] + + def __call__( + self, + input_ids, + attention_mask=None, + position_ids=None, + params: dict = None, + past_key_values: dict = None, + dropout_rng: jax.random.PRNGKey = None, + train: bool = False, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + add_params_field: bool = False + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + batch_size, sequence_length = input_ids.shape + + if position_ids is None: + if past_key_values is not None: + raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.") + + position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) + + if attention_mask is None: + attention_mask = jnp.ones((batch_size, sequence_length)) + + rng_s = {} + if dropout_rng is not None: + rng_s["dropout"] = dropout_rng + + inputs = {"params": params or self.params} if add_params_field else params or self.params + + if past_key_values: + inputs["cache"] = past_key_values + mutable = ["cache"] + else: + mutable = False + + # input_ids: 
jax.Array + # attention_mask: jax.Array + # position_ids: jax.Array + # deterministic: bool = True + # input_embeds: jax.Array = None + # init_cache: bool = False + # output_attentions: bool = False + # output_hidden_states: bool = False + # return_dict: bool = True + outputs = self.module.apply( + inputs, + jnp.array(input_ids, dtype="i4"), + jnp.array(attention_mask, dtype="i4"), + jnp.array(position_ids, dtype="i4"), + not train, + None, + False, + output_attentions, + output_hidden_states, + return_dict, + rngs=rng_s, + mutable=mutable, + ) + + if past_key_values is not None and return_dict: + outputs, past_key_values = outputs + outputs["past_key_values"] = unfreeze(past_key_values["cache"]) + return outputs + elif past_key_values is not None and not return_dict: + outputs, past_key_values = outputs + outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:] + + return outputs + + +class FlaxMistralDecoratorCollection(nn.Module): + config: MistralConfig + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[Union[None, jax.lax.Precision]] = jax.lax.Precision('fastest') + + def setup(self) -> None: + self.layers = [ + FlaxMistralDecoderLayer( + config=self.config, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + name=str(i) + ) for i in range(self.config.num_hidden_layers) + ] + + def __call__( + self, + hidden_state: jax.Array, + freq_cis: jax.Array, + attention_mask: jax.Array, + causal_mask: jax.Array, + position_ids: jax.Array, + deterministic: bool = True, + init_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False + ): + all_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + for layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_state,) + output = layer( + hidden_state=hidden_state, + freq_cis=freq_cis, + attention_mask=attention_mask, + causal_mask=causal_mask, + position_ids=position_ids, + deterministic=deterministic, + init_cache=init_cache, + output_attentions=output_attentions + ) + hidden_state = output[0] + + if output_attentions: + output_attentions += (output[1],) + + return hidden_state, all_hidden_states, all_attentions + + +class FlaxMistralModule(nn.Module): + config: MistralConfig + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[Union[jax.lax.Precision, str]] = None + + def setup(self): + + self.embed_tokens = nn.Embed( + self.config.vocab_size, + self.config.hidden_size, + embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), + dtype=self.dtype, + param_dtype=self.param_dtype, + ) + + self.layers = FlaxMistralDecoratorCollection( + self.config, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision + ) + self.norm = MistralRMSNorm( + self.config.hidden_size, + eps=self.config.rms_norm_eps, + dtype=self.dtype, + param_dtype=self.param_dtype + ) + self.freq_cis = precompute_freq_cis( + method=None, + scaling_factor=1.0, + dim=self.config.hidden_size // self.config.num_attention_heads, + end=self.config.max_position_embeddings * 2 + ) + self.causal_mask = nn.make_causal_mask(jnp.ones(1, self.config.max_position_embeddings)) + + def __call__( + self, + input_ids: jax.Array, + attention_mask: jax.Array, + position_ids: jax.Array, + deterministic: bool = True, + input_embeds: jax.Array = None, + init_cache: bool = False, + output_attentions: bool = False, + 
output_hidden_states: bool = False, + return_dict: bool = True, + + ) -> Tuple[Array, ...] | FlaxBaseModelOutput: + if input_embeds is None: + input_embeds = self.embed_tokens(input_ids.astype("i4")) + + outputs = self.layers( + hidden_state=input_embeds, + attention_mask=attention_mask, + position_ids=position_ids, + freq_cis=self.freq_cis, + init_cache=init_cache, + output_attentions=output_attentions, + deterministic=deterministic, + causal_mask=self.causal_mask + ) + + hidden_states = outputs[0] + hidden_states = self.norm(hidden_states) + + if output_hidden_states: + all_hidden_states = outputs[1] + (hidden_states,) + outputs = (hidden_states, all_hidden_states) + outputs[2:] + else: + outputs = (hidden_states,) + outputs[1:] + + if not return_dict: + return tuple(v for v in outputs if v is not None) + + return FlaxBaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=outputs[1], + attentions=outputs[-1], + ) + + +class FlaxMistralModel(FlaxMistralPretrainedModel): + module_class = FlaxMistralModule + + +class FlaxMistralForCausalLMModule(nn.Module): + config: MistralConfig + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[Union[jax.lax.Precision, str]] = None + + def setup(self): + self.model: FlaxMistralModule = FlaxMistralModule( + self.config, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + ) + self.lm_head = nn.Dense( + self.config.vocab_size, + dtype=self.dtype, + param_dtype=self.param_dtype, + use_bias=False, + kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), + precision=self.precision, + ) + + def __call__( + self, + input_ids: jax.Array, + attention_mask: jax.Array = None, + position_ids: jax.Array = None, + deterministic: bool = True, + init_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + batch_size, seq_length = input_ids.shape + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids) + if position_ids is None: + position_ids = jnp.broadcast_to( + jnp.clip(jnp.cumsum(attention_mask, axis=-1) - 1, a_min=0), + (batch_size, seq_length) + ) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + deterministic=deterministic, + init_cache=init_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict + ) + + hidden_states = outputs[0] + + if self.config.tie_word_embeddings: + shared_kernel = self.transformer.variables["params"]["wte"]["embedding"].T + lm_logits = self.lm_head.apply({"params": {"kernel": shared_kernel}}, hidden_states) + else: + lm_logits = self.lm_head(hidden_states) + + lm_logits = lm_logits.astype(jnp.float32) + + if not return_dict: + return (lm_logits,) + outputs[1:] + + return FlaxCausalLMOutput(logits=lm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions) + + +class FlaxMistralForCausalLM(FlaxMistralPretrainedModel): + module_class = FlaxMistralForCausalLMModule + + def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None): + batch_size, seq_length = input_ids.shape + + past_key_values = self.init_cache(batch_size, max_length) + extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") + if attention_mask is not None: + position_ids = attention_mask.cumsum(axis=-1) - 1 + extended_attention_mask = 
jax.lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0)) + else: + position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) + + return { + "past_key_values": past_key_values, + "attention_mask": extended_attention_mask, + "position_ids": position_ids, + } + + @staticmethod + def update_inputs_for_generation(model_outputs, model_kwargs): + model_kwargs["past_key_values"] = model_outputs.past_key_values + model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1 + return model_kwargs diff --git a/lib/python/EasyDel/modules/mosaic_mpt/modelling_mpt_flax.py b/lib/python/EasyDel/modules/mosaic_mpt/modelling_mpt_flax.py index c9e40d21a..b7702aac6 100644 --- a/lib/python/EasyDel/modules/mosaic_mpt/modelling_mpt_flax.py +++ b/lib/python/EasyDel/modules/mosaic_mpt/modelling_mpt_flax.py @@ -9,50 +9,13 @@ from transformers import FlaxPreTrainedModel, PretrainedConfig from jax import numpy as jnp import jax -from jax.interpreters import pxla -from jax.experimental.pjit import pjit, with_sharding_constraint as wsc from jax.sharding import PartitionSpec from transformers.modeling_flax_outputs import FlaxCausalLMOutput, FlaxBaseModelOutput -from jax.random import split, PRNGKey -from functools import partial import flax from einops import rearrange from fjutils.flash_attention import dot_product_attention_multihead - -ACT2FN = { - "gelu": partial(nn.gelu, approximate=False), - "relu": nn.relu, - "silu": nn.swish, - "swish": nn.swish, - "gelu_new": partial(nn.gelu, approximate=True), - -} - - -def get_names_from_parition_spec(partition_specs): - names = set() - if isinstance(partition_specs, dict): - partition_specs = partition_specs.values() - for item in partition_specs: - if item is None: - continue - elif isinstance(item, str): - names.add(item) - else: - names.update(get_names_from_parition_spec(item)) - - return list(names) - - -def names_in_mesh(*names): - return set(names) <= set(pxla.thread_resources.env.physical_mesh.axis_names) - - -def with_sharding_constraint(x, partition_specs): - axis_names = get_names_from_parition_spec(partition_specs) - if names_in_mesh(*axis_names): - x = wsc(x, partition_specs) - return x +from ..flax_modelling_utils import get_gradient_checkpoint_policy, \ + with_sharding_constraint class MptConfig(PretrainedConfig): diff --git a/lib/python/EasyDel/modules/opt/modelling_opt_flax.py b/lib/python/EasyDel/modules/opt/modelling_opt_flax.py index 8fa2faa94..721f6ae35 100644 --- a/lib/python/EasyDel/modules/opt/modelling_opt_flax.py +++ b/lib/python/EasyDel/modules/opt/modelling_opt_flax.py @@ -29,46 +29,12 @@ from jax.random import PRNGKey from transformers import PretrainedConfig from transformers.modeling_flax_outputs import FlaxBaseModelOutput, FlaxMaskedLMOutput -from transformers.modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring -from jax.experimental.pjit import with_sharding_constraint as wsc +from transformers.modeling_flax_utils import ACT2FN, FlaxPreTrainedModel from jax.sharding import PartitionSpec -from jax.interpreters import pxla from transformers import logging - -def get_gradient_checkpoint_policy(name): - return { - 'everything_saveable': jax.checkpoint_policies.everything_saveable, - 'nothing_saveable': jax.checkpoint_policies.nothing_saveable, - 'checkpoint_dots': jax.checkpoint_policies.checkpoint_dots, - 'checkpoint_dots_with_no_batch_dims': jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims, - }[name] - - -def 
get_names_from_parition_spec(partition_specs): - names = set() - if isinstance(partition_specs, dict): - partition_specs = partition_specs.values() - for item in partition_specs: - if item is None: - continue - elif isinstance(item, str): - names.add(item) - else: - names.update(get_names_from_parition_spec(item)) - - return list(names) - - -def names_in_mesh(*names): - return set(names) <= set(pxla.thread_resources.env.physical_mesh.axis_names) - - -def with_sharding_constraint(x, partition_specs): - axis_names = get_names_from_parition_spec(partition_specs) - if names_in_mesh(*axis_names): - x = wsc(x, partition_specs) - return x +from ..flax_modelling_utils import get_gradient_checkpoint_policy, \ + with_sharding_constraint class OPTConfig(PretrainedConfig): diff --git a/lib/python/EasyDel/modules/palm/modelling_palm_flax.py b/lib/python/EasyDel/modules/palm/modelling_palm_flax.py index fec9170f7..54dd1af79 100644 --- a/lib/python/EasyDel/modules/palm/modelling_palm_flax.py +++ b/lib/python/EasyDel/modules/palm/modelling_palm_flax.py @@ -1,5 +1,4 @@ from typing import Union, Optional, Tuple, Any, Mapping - import jax import jax.numpy as jnp import numpy as onp @@ -7,39 +6,12 @@ from einops import rearrange import flax.linen as nn from flax.core import FrozenDict - from jax import numpy as np from transformers.modeling_flax_outputs import FlaxCausalLMOutput from transformers import PretrainedConfig -from jax.experimental.pjit import with_sharding_constraint as wsc from jax.sharding import PartitionSpec -from jax.interpreters import pxla - - -def get_names_from_parition_spec(partition_specs): - names = set() - if isinstance(partition_specs, dict): - partition_specs = partition_specs.values() - for item in partition_specs: - if item is None: - continue - elif isinstance(item, str): - names.add(item) - else: - names.update(get_names_from_parition_spec(item)) - - return list(names) - - -def names_in_mesh(*names): - return set(names) <= set(pxla.thread_resources.env.physical_mesh.axis_names) - - -def with_sharding_constraint(x, partition_specs): - axis_names = get_names_from_parition_spec(partition_specs) - if names_in_mesh(*axis_names): - x = wsc(x, partition_specs) - return x +from ..flax_modelling_utils import get_gradient_checkpoint_policy, \ + with_sharding_constraint class PalmConfig(PretrainedConfig): diff --git a/lib/python/EasyDel/modules/t5/modelling_t5_flax.py b/lib/python/EasyDel/modules/t5/modelling_t5_flax.py index 701c227d3..5bf302a32 100644 --- a/lib/python/EasyDel/modules/t5/modelling_t5_flax.py +++ b/lib/python/EasyDel/modules/t5/modelling_t5_flax.py @@ -45,31 +45,8 @@ from jax.experimental.pjit import with_sharding_constraint as wsc from jax.sharding import PartitionSpec - -def get_names_from_parition_spec(partition_specs): - names = set() - if isinstance(partition_specs, dict): - partition_specs = partition_specs.values() - for item in partition_specs: - if item is None: - continue - elif isinstance(item, str): - names.add(item) - else: - names.update(get_names_from_parition_spec(item)) - - return list(names) - - -def names_in_mesh(*names): - return set(names) <= set(pxla.thread_resources.env.physical_mesh.axis_names) - - -def with_sharding_constraint(x, partition_specs): - axis_names = get_names_from_parition_spec(partition_specs) - if names_in_mesh(*axis_names): - x = wsc(x, partition_specs) - return x +from ..flax_modelling_utils import get_gradient_checkpoint_policy, \ + with_sharding_constraint class T5Config(PretrainedConfig): diff --git a/requirements.txt 
b/requirements.txt index 1ef42f866..3abd2c46d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,7 @@ flax~=0.7.1 fjutils~=0.0.16 numpy~=1.25.2 typing~=3.7.4.3 -transformers>=4.31.0 +transformers>=4.33.0 einops~=0.6.1 optax~=0.1.7 msgpack~=1.0.5
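
Note on the shared helpers introduced in `lib/python/EasyDel/modules/flax_modelling_utils.py`: every model file touched in this series now imports `with_sharding_constraint` and `get_gradient_checkpoint_policy` from one place instead of redefining them per model. Below is a minimal sketch of how they behave, assuming EasyDeL is installed so the new module resolves at `EasyDel.modules.flax_modelling_utils`; the mesh layout and array shape are illustrative assumptions, not part of the patch.

```python
import jax
import numpy as np
from jax.sharding import Mesh, PartitionSpec

from EasyDel.modules.flax_modelling_utils import (
    with_sharding_constraint,
    get_gradient_checkpoint_policy,
)

# with_sharding_constraint is a guarded wrapper: the constraint is applied only when
# every axis named in the PartitionSpec exists in the active mesh, so the same model
# code runs unchanged when no mesh is set up.
devices = np.array(jax.devices()).reshape(1, -1, 1)              # ('dp', 'fsdp', 'mp')
x = np.arange(jax.device_count() * 4.0).reshape(jax.device_count(), 4)

@jax.jit
def shard_rows(a):
    # same spec the model files use for activations
    return with_sharding_constraint(a, PartitionSpec(('dp', 'fsdp'), 'mp'))

with Mesh(devices, axis_names=('dp', 'fsdp', 'mp')):
    y = shard_rows(x)                                             # constraint applied inside the mesh

y_no_mesh = with_sharding_constraint(x, PartitionSpec(('dp', 'fsdp'), 'mp'))  # no-op: axes not in any mesh

# get_gradient_checkpoint_policy maps the string stored in the configs
# (e.g. 'nothing_saveable') to a jax.checkpoint_policies function for nn.remat(policy=...).
policy = get_gradient_checkpoint_policy('nothing_saveable')
```

Because the constraint silently no-ops when the named axes are absent, the modelling code can be traced both on a single device and under a `('dp', 'fsdp', 'mp')` mesh without branching.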
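The new Mistral attention layer applies rotary position embeddings through the standalone `precompute_freq_cis` and `apply_rotary_emb` helpers added in `modelling_mistral_flax.py`. The sketch below is a quick shape check under assumed toy dimensions; the deep import path is an assumption, since only `FlaxMistralModule`, `FlaxMistralForCausalLM`, and `MistralConfig` are re-exported from the package root.

```python
import jax.numpy as jnp

# assumed import path for helpers defined in the new Mistral module (not re-exported)
from EasyDel.modules.mistral.modelling_mistral_flax import precompute_freq_cis, apply_rotary_emb

batch, seq_len, n_heads, head_dim = 1, 16, 4, 64
xq = jnp.ones((batch, seq_len, n_heads, head_dim))
xk = jnp.ones((batch, seq_len, n_heads, head_dim))

# table of complex rotations, one row per absolute position: (end, head_dim // 2)
freq_cis = precompute_freq_cis(method=None, dim=head_dim, end=seq_len * 2)

# gather the rows for the positions actually used, as FlaxMistralAttention does
position_ids = jnp.arange(seq_len)[None, :]                       # (batch, seq_len)
rotations = jnp.take(freq_cis, position_ids, axis=0)              # (batch, seq_len, head_dim // 2)

xq_rot, xk_rot = apply_rotary_emb(xq, xk, rotations, dtype=jnp.float32)
print(xq_rot.shape, xk_rot.shape)                                 # (1, 16, 4, 64) (1, 16, 4, 64)
```

Inside `FlaxMistralAttention.__call__`, the same gather by `position_ids` feeds these per-position rotations to the projected queries and keys before `repeat_kv` and the cached decoding path.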