diff --git a/.github/workflows/jekyll-gh-pages.yml b/.github/workflows/jekyll-gh-pages.yml new file mode 100644 index 000000000..559bddf57 --- /dev/null +++ b/.github/workflows/jekyll-gh-pages.yml @@ -0,0 +1,51 @@ +# Sample workflow for building and deploying a Jekyll site to GitHub Pages +name: Deploy Jekyll with GitHub Pages dependencies preinstalled + +on: + # Runs on pushes targeting the default branch + push: + branches: ["main"] + + # Allows you to run this workflow manually from the Actions tab + workflow_dispatch: + +# Sets permissions of the GITHUB_TOKEN to allow deployment to GitHub Pages +permissions: + contents: read + pages: write + id-token: write + +# Allow only one concurrent deployment, skipping runs queued between the run in-progress and latest queued. +# However, do NOT cancel in-progress runs as we want to allow these production deployments to complete. +concurrency: + group: "pages" + cancel-in-progress: false + +jobs: + # Build job + build: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v3 + - name: Setup Pages + uses: actions/configure-pages@v3 + - name: Build with Jekyll + uses: actions/jekyll-build-pages@v1 + with: + source: ./ + destination: ./_site + - name: Upload artifact + uses: actions/upload-pages-artifact@v2 + + # Deployment job + deploy: + environment: + name: github-pages + url: ${{ steps.deployment.outputs.page_url }} + runs-on: ubuntu-latest + needs: build + steps: + - name: Deploy to GitHub Pages + id: deployment + uses: actions/deploy-pages@v2 diff --git a/.idea/EasyDeL.iml b/.idea/EasyDeL.iml index b5ad51ab9..81c471349 100644 --- a/.idea/EasyDeL.iml +++ b/.idea/EasyDeL.iml @@ -2,7 +2,7 @@ - + diff --git a/.idea/misc.xml b/.idea/misc.xml index a971a2c93..d16004ac2 100644 --- a/.idea/misc.xml +++ b/.idea/misc.xml @@ -1,4 +1,4 @@ - + \ No newline at end of file diff --git a/EasyDel/modules/falcon/modelling_falcon_flax.py b/EasyDel/modules/falcon/modelling_falcon_flax.py index aa99a28b8..b515469ca 100644 --- a/EasyDel/modules/falcon/modelling_falcon_flax.py +++ b/EasyDel/modules/falcon/modelling_falcon_flax.py @@ -63,8 +63,8 @@ def __init__( hidden_size: int = 64, num_hidden_layers: int = 32, num_attention_heads: int = 71, - n_layers:int=32, - n_heads:int=71, + n_layers: int = 32, + n_heads: int = 71, layer_norm_epsilon: float = 1e-5, initializer_range: float = 0.02, use_cache: bool = True, @@ -520,9 +520,14 @@ class FlaxFalconPretrainedModel(FlaxPreTrainedModel): module_class: nn.Module = None config_class = FalconConfig - def __init__(self, config, _do_init=False, dtype: jnp.dtype = jnp.float32, param_dtype: jnp.dtype = jnp.float32, - input_shape: Tuple = (1, 12)): - module = self.module_class(config=config, dtype=dtype, param_dtype=param_dtype) + def __init__(self, config, + _do_init=False, + dtype: jnp.dtype = jnp.float32, + param_dtype: jnp.dtype = jnp.float32, + input_shape: Tuple = (1, 1024), + precision: Optional[Union[None, jax.lax.Precision]] = jax.lax.Precision('fastest') + ): + module = self.module_class(config=config, dtype=dtype, param_dtype=param_dtype, precision=precision) super().__init__(_do_init=_do_init, module=module, config=config, dtype=dtype, input_shape=input_shape) def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> Dict: diff --git a/EasyDel/modules/llama/modelling_llama_flax.py b/EasyDel/modules/llama/modelling_llama_flax.py index 4a60840b9..d418aaec7 100644 --- a/EasyDel/modules/llama/modelling_llama_flax.py +++ 
b/EasyDel/modules/llama/modelling_llama_flax.py @@ -64,7 +64,7 @@ def get_gradient_checkpoint_policy(name): class LlamaConfig(PretrainedConfig): - model_type = "Llama" + model_type = "Llama.md" def __init__( self, diff --git a/EasyDel/modules/mosaic_mpt/modelling_mpt_flax.py b/EasyDel/modules/mosaic_mpt/modelling_mpt_flax.py index 77e016d78..cf8efb271 100644 --- a/EasyDel/modules/mosaic_mpt/modelling_mpt_flax.py +++ b/EasyDel/modules/mosaic_mpt/modelling_mpt_flax.py @@ -68,7 +68,7 @@ def __init__(self, resid_prob_drop: float = 0.0, emb_prob_drop: float = 0.0, alibi: bool = True, - use_bias: bool = True, + use_bias: bool = False, learned_pos_emb: bool = True, act_fn: str = 'gelu', logit_scale: Optional[Union[float, str]] = None, @@ -76,7 +76,7 @@ def __init__(self, verbose: int = 0, embedding_fraction: float = 1.0, use_cache: bool = False, - qk_ln: bool = True, + qk_ln: bool = False, use_lm_head: bool = False, use_norm_bias: bool = False, gradient_checkpointing: str = 'nothing_saveable', @@ -207,6 +207,9 @@ def add_jax_args(self, flash_attn_key_chunk_size: int = 2048, **kwargs ): + if hasattr(self, 'attn_config'): + for k, v in self.attn_config.items(): + setattr(self, k, v) basics = dict( d_model=d_model, n_heads=n_heads, @@ -237,6 +240,7 @@ def add_jax_args(self, ) for k, v in basics.items(): if not hasattr(self, k): + print(f' Key {k} not found in loaded config setting that to default of {v}') setattr(self, k, v) self.from_pt = False @@ -282,8 +286,8 @@ def setup(self) -> None: dtype=self.dtype, param_dtype=self.param_dtype, precision=self.precision) self.act = ACT2FN[self.config.act_fn] - def __call__(self, x: jax.Array): - return self.down(self.act(self.up(x))) + def __call__(self, hidden_state: jax.Array): + return self.down(self.act(self.up(hidden_state))) class FlaxMptAttention(nn.Module): @@ -303,10 +307,16 @@ def setup(self) -> None: self.k_ln = nn.LayerNorm(use_bias=self.config.use_norm_bias) self.causal_mask = nn.make_causal_mask(jnp.ones((1, self.config.max_seq_len))) - def __call__(self, x, attn_bias=None, attention_mask=None): - inp_shape = x.shape + def __call__(self, hidden_state, attn_bias=None, attention_mask=None): + inp_shape = hidden_state.shape b, s, ds = inp_shape - qkv = self.w_qkv(x) + if attention_mask is not None: + _, s = attention_mask.shape + assert inp_shape[ + 1] == s, (f'hidden_state_size on hidden_state shape' + f' ({inp_shape[1]}) and attention_mask ({s}) miss match' + f' attention Shape : {attention_mask.shape} | hidden Shape : {hidden_state.shape}') + qkv = self.w_qkv(hidden_state) q, k, v = jnp.split(qkv, 3, -1) if self.config.qk_ln: q = self.q_ln(q) @@ -361,7 +371,11 @@ def __call__(self, x, attn_bias=None, attention_mask=None): atw += attn_bias mask = jnp.where(self.causal_mask == 1, 0, jnp.finfo(atw).min) if attention_mask is not None: - attention_mask = jnp.where(attention_mask.reshape(b, 1, 1, s) == 1, 0, jnp.finfo(atw).min) + attention_mask = jnp.where( + attention_mask.reshape(b, 1, 1, -1) == 1, + 0, + jnp.finfo(atw).min + ) atw += attention_mask atw += mask[:, :, :s, :s] atw = nn.softmax(atw, -1) @@ -383,10 +397,11 @@ def setup(self) -> None: self.ffn = FlaxMptMLP(config=self.config, dtype=self.dtype, param_dtype=self.param_dtype, precision=self.precision) - def __call__(self, x, attn_bias=None, attention_mask=None): - x = self.attn(self.norm_1(x), attn_bias=attn_bias, attention_mask=attention_mask) + x - x = self.ffn(self.norm_2(x)) + x - return x + def __call__(self, hidden_state, attn_bias=None, attention_mask=None): + hidden_state = 
(self.attn(self.norm_1(hidden_state), attn_bias=attn_bias, attention_mask=attention_mask) + + hidden_state) + hidden_state = self.ffn(self.norm_2(hidden_state)) + hidden_state + return hidden_state def get_gradient_checkpoint_policy(name): @@ -426,10 +441,10 @@ def setup(self) -> None: ) ] - def __call__(self, x, attn_bias=None, attention_mask=None): + def __call__(self, hidden_state, attn_bias=None, attention_mask=None): for block in self.blocks: - x = block(x=x, attn_bias=attn_bias, attention_mask=attention_mask) - return x + hidden_state = block(hidden_state=hidden_state, attn_bias=attn_bias, attention_mask=attention_mask) + return hidden_state def build_alibi(max_length, num_attention_heads, alibi_max: int = 8): @@ -516,15 +531,18 @@ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: Froz def __call__(self, input_ids, attention_mask=None, - params=None, + params: dict = None, add_params_field: bool = False, return_dict: bool = True, extra_embedding: Optional[Union[jnp.ndarray, None]] = None): params = {'params': params or self.params} if add_params_field else params or self.params + input_ids = jnp.asarray(input_ids, dtype='i4') + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids, dtype='i4') predict = self.module.apply( params, - input_ids=jnp.asarray(input_ids, dtype='i4'), - attention_mask=jnp.asarray(attention_mask, dtype='i4') if attention_mask is not None else attention_mask, + input_ids=input_ids, + attention_mask=jnp.asarray(attention_mask, dtype='i4'), return_dict=return_dict, extra_embedding=extra_embedding ) diff --git a/EasyDel/serve/serve_utils.py b/EasyDel/serve/serve_utils.py index b2e13a8dd..d413b212b 100644 --- a/EasyDel/serve/serve_utils.py +++ b/EasyDel/serve/serve_utils.py @@ -479,20 +479,22 @@ def forward_chat(self, data: ChatRequest): 'status': "down" } - history = self.chat_format( + string = self.chat_format( prompt=data.prompt, system=None, history=data.history ) - response, used_tokens = self.process( - string=history, - greedy=data.greedy, - max_new_tokens=None - ) + response, used_tokens = [None] * 2 + for response, used_tokens in self.process( + string=string, + greedy=data.greedy, + max_new_tokens=None + ): + ... self.number_of_served_request_until_last_up_time += 1 return { - 'input': f'{history}', + 'input': f'{string}', 'response': response, 'tokens_used': used_tokens, } @@ -504,12 +506,13 @@ def forward_instruct(self, data: InstructRequest): } string = self.config.instruct_format.format(instruct=data.prompt, system=data.system) - - response, used_tokens = self.process( - string=string, - greedy=data.greedy, - max_new_tokens=None - ) + response, used_tokens = [None] * 2 + for response, used_tokens in self.process( + string=string, + greedy=data.greedy, + max_new_tokens=None + ): + ... 
self.number_of_served_request_until_last_up_time += 1 return { 'input': f'{string}', @@ -534,7 +537,7 @@ def forward_chat_non_api(self, prompt, history, greedy): return self.forward_chat(data) def process(self, - string, + string: str, *, greedy: bool = False, max_new_tokens: int = None, diff --git a/EasyDel/transform/mpt.py b/EasyDel/transform/mpt.py index 18ebd6fe7..6768bfca6 100644 --- a/EasyDel/transform/mpt.py +++ b/EasyDel/transform/mpt.py @@ -3,6 +3,7 @@ import jax import torch import numpy as np +from transformers import AutoModelForCausalLM def mpt_convert_flax_to_pt_7b(state_dict_flax, n_layers: int, device=torch.device('cpu'), use_lm_head=False): @@ -44,22 +45,22 @@ def mpt_convert_pt_to_flax_7b(state_dict, n_layers: int, device=jax.devices('cpu state_dict_flax = {('transformer', 'wte', 'embedding'): state_dict[ 'transformer.wte.weight'].cpu().detach().numpy()} for i in range(n_layers): - state_dict_flax[('transformer'), ('h'), (f'{i}'), ('norm_1'), ('scale')] = state_dict[ + state_dict_flax[('transformer', 'h', f'{i}', 'norm_1', 'scale')] = state_dict[ f'transformer.blocks.{i}.norm_1.weight'].cpu().detach().numpy() - state_dict_flax[('transformer'), ('h'), (f'{i}'), ('norm_2'), ('scale')] = state_dict[ + state_dict_flax[('transformer', 'h', f'{i}', 'norm_2', 'scale')] = state_dict[ f'transformer.blocks.{i}.norm_2.weight'].cpu().detach().numpy() - state_dict_flax[('transformer'), ('h'), (f'{i}'), ('ffn'), ('down'), ('kernel')] = jnp.transpose( + state_dict_flax[('transformer', 'h', f'{i}', 'ffn', 'down', 'kernel')] = jnp.transpose( state_dict[f'transformer.blocks.{i}.ffn.down_proj.weight'].cpu().detach().numpy(), (1, 0)) - state_dict_flax[('transformer'), ('h'), (f'{i}'), ('ffn'), ('up'), ('kernel')] = jnp.transpose( + state_dict_flax[('transformer', 'h', f'{i}', 'ffn', 'up', 'kernel')] = jnp.transpose( state_dict[f'transformer.blocks.{i}.ffn.up_proj.weight'].cpu().detach().numpy(), (1, 0)) - state_dict_flax[('transformer'), ('h'), (f'{i}'), ('attn'), ('w_qkv'), ('kernel')] = jnp.transpose( + state_dict_flax[('transformer', 'h', f'{i}', 'attn', 'w_qkv', 'kernel')] = jnp.transpose( state_dict[f'transformer.blocks.{i}.attn.Wqkv.weight'].cpu().detach().numpy(), (1, 0)) - state_dict_flax[('transformer'), ('h'), (f'{i}'), ('attn'), ('wo'), ('kernel')] = jnp.transpose( + state_dict_flax[('transformer', 'h', f'{i}', 'attn', 'wo', 'kernel')] = jnp.transpose( state_dict[f'transformer.blocks.{i}.attn.out_proj.weight'].cpu().detach().numpy(), (1, 0)) - state_dict_flax[('transformer'), ('norm_f'), ('scale')] = state_dict[ + state_dict_flax[('transformer', 'norm_f', 'scale')] = state_dict[ f'transformer.norm_f.weight'].cpu().detach().numpy() if use_lm_head: - state_dict_flax[('lm_head'), ('kernel')] = jnp.transpose( + state_dict_flax[('lm_head', 'kernel')] = jnp.transpose( state_dict[f'lm_head.weight'].cpu().detach().numpy(), (1, 0)) return state_dict_flax @@ -67,29 +68,29 @@ def mpt_convert_pt_to_flax_7b(state_dict, n_layers: int, device=jax.devices('cpu def mpt_convert_pt_to_flax_1b(state_dict, n_layers: int, device=jax.devices('cpu')[0], use_lm_head=False, ): # CONVERTER MPT-1B with jax.default_device(device): - state_dict_flax = {(('transformer'), ('wte'), ('embedding')): state_dict[ + state_dict_flax = {(('transformer', 'wte', 'embedding')): state_dict[ 'transformer.wte.weight'].cpu().detach().numpy()} for i in range(n_layers): - state_dict_flax[('transformer'), ('h'), (f'{i}'), ('norm_1'), ('scale')] = state_dict[ + state_dict_flax[('transformer', 'h', f'{i}', 'norm_1', 'scale')] = 
state_dict[ f'transformer.blocks.{i}.ln_1.weight'].cpu().detach().numpy() - state_dict_flax[('transformer'), ('h'), (f'{i}'), ('norm_2'), ('scale')] = state_dict[ + state_dict_flax[('transformer', 'h', f'{i}', 'norm_2', 'scale')] = state_dict[ f'transformer.blocks.{i}.ln_2.weight'].cpu().detach().numpy() - state_dict_flax[('transformer'), ('h'), (f'{i}'), ('ffn'), ('down'), ('kernel')] = jnp.transpose( + state_dict_flax[('transformer', 'h', f'{i}', 'ffn', 'down', 'kernel')] = jnp.transpose( state_dict[f'transformer.blocks.{i}.mlp.mlp_down.weight'].cpu().detach().numpy(), (1, 0)) - state_dict_flax[('transformer'), ('h'), (f'{i}'), ('ffn'), ('up'), ('kernel')] = jnp.transpose( + state_dict_flax[('transformer', 'h', f'{i}', 'ffn', 'up', 'kernel')] = jnp.transpose( state_dict[f'transformer.blocks.{i}.mlp.mlp_up.weight'].cpu().detach().numpy(), (1, 0)) - state_dict_flax[('transformer'), ('h'), (f'{i}'), ('attn'), ('w_qkv'), ('kernel')] = jnp.transpose( + state_dict_flax[('transformer', 'h', f'{i}', 'attn', 'w_qkv', 'kernel')] = jnp.transpose( state_dict[f'transformer.blocks.{i}.attn.Wqkv.weight'].cpu().detach().numpy(), (1, 0)) - state_dict_flax[('transformer'), ('h'), (f'{i}'), ('attn'), ('wo'), ('kernel')] = jnp.transpose( + state_dict_flax[('transformer', 'h', f'{i}', 'attn', 'wo', 'kernel')] = jnp.transpose( state_dict[f'transformer.blocks.{i}.attn.out_proj.weight'].cpu().detach().numpy(), (1, 0)) - state_dict_flax[('transformer'), ('h'), (f'{i}'), ('attn'), ('q_ln'), ('scale')] = state_dict[ + state_dict_flax[('transformer', 'h', f'{i}', 'attn', 'q_ln', 'scale')] = state_dict[ f'transformer.blocks.{i}.attn.q_ln.weight'].cpu().detach().numpy() - state_dict_flax[('transformer'), ('h'), (f'{i}'), ('attn'), ('k_ln'), ('scale')] = state_dict[ + state_dict_flax[('transformer', 'h', f'{i}', 'attn', 'k_ln', 'scale')] = state_dict[ f'transformer.blocks.{i}.attn.k_ln.weight'].cpu().detach().numpy() - state_dict_flax[('transformer'), ('norm_f'), ('scale')] = state_dict[ + state_dict_flax[('transformer', 'norm_f', 'scale')] = state_dict[ f'transformer.ln_f.weight'].cpu().detach().numpy() if use_lm_head: - state_dict_flax[('lm_head'), ('kernel')] = jnp.transpose( + state_dict_flax[('lm_head', 'kernel')] = jnp.transpose( state_dict[f'lm_head.weight'].cpu().detach().numpy(), (1, 0)) return state_dict_flax @@ -140,15 +141,16 @@ def mpt_convert_flax_to_pt_1b(state_dict_flax, n_layers: int, device=torch.devic return state_dict -def mpt_from_pretrained(model_id, device=jax.devices('cpu')[0]): +def mpt_from_pretrained(model_id, device=jax.devices('cpu')[0], **kwargs): """ return: Weight or Params for EasyDel Model , Config """ config = MptConfig.from_pretrained(model_id) - model = FlaxMptForCausalLM.from_pretrained(model_id) + model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True, **kwargs) + easydel_wights = mpt_convert_pt_to_flax_7b( state_dict=model.state_dict(), - n_layers=config.num_hidden_layers, + n_layers=config.num_hidden_layers if hasattr(config, 'num_hidden_layers') else config.n_layers, device=device ) config.add_jax_args() diff --git a/LLAMA.md b/LLAMA.md index e9e5b1c8a..04b93b728 100644 --- a/LLAMA.md +++ b/LLAMA.md @@ -8,7 +8,7 @@ ```shell python -m examples.serving.causal-lm.llama-2-chat \ - --repo_id='meta-llama/Llama-2-7b-chat-hf' --max_length=4096 \ + --repo_id='meta-llama/Llama.md-2-7b-chat-hf' --max_length=4096 \ --max_new_tokens=2048 --max_stream_tokens=32 --temperature=0.6 \ --top_p=0.95 --top_k=50 \ --dtype='fp16' --use_prefix_tokenizer @@ -62,7 +62,7 @@ from 
transformers import AutoTokenizer, GenerationConfig from functools import partial # Let Use this model since runs fast and light, but you can use all the available llama models like llama, llama2, xgen... -model_id = 'meta-llama/Llama-2-7b-chat-hf' +model_id = 'meta-llama/Llama.md-2-7b-chat-hf' tokenizer = AutoTokenizer.from_pretrained( model_id, trust_remote_code=True # Optional ) diff --git a/README.md b/README.md index ac7be036c..d2a1fee38 100644 --- a/README.md +++ b/README.md @@ -6,39 +6,45 @@ train Flax/Jax Models on the `TPU/GPU` both for Serving and Training #### Note this Library needs golang to run (for some tracking stuff on TPU/GPU/CPU) -install go on ubuntu be like +#### Ubuntu GO installation ```shell sudo apt-get update && apt-get upgrade -y sudo apt-get install golang -y ``` -and you need Jax>=0.4.10 and FJutils>=0.0.15 +#### Manjaro/Arch GO installation -on TPUs be like +```shell +sudo pacman -Syyuu go +``` _you can install other version too but easydel required at least version of 0.4.10_ ```shell !pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html -q - ``` on GPUs be like ```shell pip install --upgrade pip - # CUDA 12 installation # Note: wheels only available on linux. pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html +``` +```shell +pip install --upgrade pip # CUDA 11 installation # Note: wheels only available on linux. pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html - ``` +## Documentation + +Tadadad (Magic Sound) 💫 finally documents are ready at [EasyDel/Docs](https://erfanzar.github.io/EasyDeL/docs) + ## Installation ### Available on PyPi @@ -86,7 +92,7 @@ chat model (70B model is supported too) ```shell python -m examples.serving.causal-lm.llama-2-chat \ - --repo_id='meta-llama/Llama-2-7b-chat-hf' --max_length=4096 \ + --repo_id='meta-llama/Llama.md-2-7b-chat-hf' --max_length=4096 \ --max_new_tokens=2048 --max_stream_tokens=32 --temperature=0.6 \ --top_p=0.95 --top_k=50 \ --dtype='fp16' --use_prefix_tokenizer @@ -103,8 +109,9 @@ and you will get links or api to use model from gradio app chat/instruct or Fast ## RLHF(Reinforcement Learning From Human Feedback) -RLHF or Reinforcement Learning From Human Feedback is going to be available in the next -versions of EasyDel +`RLHF` or Reinforcement Learning From Human Feedback is Available At the moment, but it's still +under heavy development , because i don't have enough experience with Reinforcement Learning at the moment so its still +in beta version but it's works and ill soon release a Tutorial For that ## FineTuning @@ -197,7 +204,7 @@ from transformers import AutoTokenizer import jax -model_id = 'meta-llama/Llama-2-7b-chat-hf' +model_id = 'meta-llama/Llama.md-2-7b-chat-hf' tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) diff --git a/docs/Falcon.md b/docs/Falcon.md new file mode 100644 index 000000000..22bfce4e6 --- /dev/null +++ b/docs/Falcon.md @@ -0,0 +1,127 @@ +# About Falcon Models + +Sure, here is a document about Falcon Models: + +**Falcon Models** + +Falcon Models is a family of large language models (LLMs) developed by the Technology Innovation Institute (TII) in Abu +Dhabi. 
The models are trained on a massive dataset of text and code, and can be used for a variety of tasks, including + +* Natural language understanding (NLU) +* Natural language generation (NLG) +* Machine translation +* Text summarization +* Question answering +* Code generation + +The Falcon models are available under the Apache 2.0 license, which means that they can be freely used, modified, and +redistributed. + +**Falcon-40B** + +The Falcon-40B is the largest model in the Falcon family. It has 40 billion parameters, and is trained on a dataset of +500 billion words. The model is capable of state-of-the-art performance on a variety of NLP tasks. + +**Falcon-7B** + +The Falcon-7B is a smaller version of the Falcon-40B. It has 7 billion parameters, and is trained on a dataset of 100 +billion words. The model is still capable of achieving strong performance on NLP tasks, but it is more efficient to +train and deploy. + +**Falcon-180B** + +The Falcon-180B is the newest model in the Falcon family. It has 180 billion parameters, and is trained on a dataset of +2 trillion words. The model is the largest openly available LLM, and it is capable of achieving state-of-the-art +performance on a variety of NLP tasks. + +**Use Cases** + +The Falcon models can be used for a variety of tasks, including: + +* Natural language understanding (NLU): The Falcon models can be used to understand the meaning of text, such as + identifying the entities and relationships in a sentence. +* Natural language generation (NLG): The Falcon models can be used to generate text, such as writing different kinds of + creative content, like poems, code, scripts, musical pieces, email, letters, etc. +* Machine translation: The Falcon models can be used to translate text from one language to another. +* Text summarization: The Falcon models can be used to summarize a text document into a shorter, more concise version. +* Question answering: The Falcon models can be used to answer questions about a text document. +* Code generation: The Falcon models can be used to generate code, such as Python scripts or Java classes. + +**Availability** + +The Falcon models are available through the Hugging Face Hub. The models are also available in the TensorFlow Hub and +the PyTorch Hub ( and EasyDel). + +**Conclusion** + +The Falcon models are a powerful family of LLMs that can be used for a variety of tasks. The models are open source and +available for free, making them a valuable resource for researchers and developers. + +## How to Use/Load Them in EasyDel + +```python +import jax +from EasyDel.transform import falcon_from_pretrained + +params, config = falcon_from_pretrained( + 'tiiuae/falcon-7b', + device=jax.devices('cpu')[0] # Offload on CPU +) +``` + +also keep that in mind that returned `config` includes `.get_partition_rules(fsdp=True)` + +#### Use With JaxServer + +```python +from EasyDel import JAXServer, FlaxFalconForCausalLM +import jax +from EasyDel.transform import falcon_from_pretrained +from transformers import AutoTokenizer + +params, config = falcon_from_pretrained( + 'tiiuae/falcon-7b', + device=jax.devices('cpu')[0] # Offload on CPU +) + + +class FalconJaxServer(JAXServer): + ... 
+ # You have to Custom this one yourself as you + # need read JaxServer Documents inorder to learn how + + +server = FalconJaxServer.load_from_params( + params=params, + model=FlaxFalconForCausalLM( + config=config, + dtype=jax.numpy.bfloat16, # Im on TPUs + param_dtype=jax.numpy.bfloat16, # Im on TPUs + precision=jax.lax.Precision('fastest'), + _do_init=False, + input_shape=(1, 1024) + ), + config_model=config, + add_param_field=True, + tokenizer=AutoTokenizer.from_pretrained('tiiuae/falcon-7b'), + verbose=False, + do_memory_log=True, + config={ + "max_length": 2048, + "max_new_tokens": 2048, + "max_stream_tokens": 64, + "dtype": 'bf16', + "use_prefix_tokenizer": True, + 'pre_compile': True + } +) + +server.fire() # Launch FastAPI functions + +shared_urls = server.launch( + share_chat=True, + share_inst=True +) +``` + +Done 😇 this method can be used for all the Falcon models \ No newline at end of file diff --git a/docs/JAXServer.md b/docs/JAXServer.md new file mode 100644 index 000000000..a560ff3c5 --- /dev/null +++ b/docs/JAXServer.md @@ -0,0 +1,320 @@ +## JAXServer 🧬 + +`JAXServer` is one of offered utilities by EasyDel, and it's help hosting using and doing process with LLMs +and its also hackable, so you can override your own method in it and use it support both mid-level and high-level apis +and also give you a Gradio Chat and Instruct Pre-build and ready to use page + +* Supported Models are: + * EveryModel that have `transformers.FlaxPretrainedModel` as their Parent :) + +### Input Configs + +The config input is a dictionary that contains the following keys: + +* `port`: The port number that the server will listen on. + * _Default Value has been set to `2059`_ +* `batch_size`: The batch size for training. + * _Default Value has been set to `1`_ +* `max_length`: The maximum length of a sequence. + * _Default Value has been set to `2048`_ +* `max_new_tokens`: The maximum number of new tokens generated by the model in a single step. + * _Default Value has been set to `2048`_ +* `max_stream_tokens`: The maximum number of tokens that can be streamed to the model in a single batch. + * _Default Value has been set to `32`_ +* `temperature`: The temperature parameter for sampling from the model's output distribution. + * _Default Value has been set to `0.1`_ +* `top_p`: The top-p parameter for sampling from the model's output distribution. + * _Default Value has been set to `0.95`_ +* `top_k`: The top-k parameter for sampling from the model's output distribution. + * _Default Value has been set to `50`_ +* `mesh_axes_shape`: The shape of the mesh axes for distributed training. + * _Default Value has been set to `(1, -1, 1)`_ +* `host`: The host address for the server. + * _Default Value has been set to `'0.0.0.0'`_ +* `dtype`: The data type for the model's parameters. + * _Default Value has been set to `'fp16'`_ +* `mesh_axes_names`: The names of the mesh axes for distributed training. + * _Default Value has been set to `('dp', 'fsdp', 'mp')`_ +* `system_prefix`: The prefix that will be prepended to system messages. + * _Default Value has been set to `''`_ +* `system`: The system message to be displayed. + * _Default Value has been set to `''`_ +* `prompt_prefix_instruct`: The prefix that will be prepended to instruction prompts. + * _Default Value has been set to `''`_ +* `prompt_postfix_instruct`: The postfix that will be appended to instruction prompts. + * _Default Value has been set to `''`_ +* `prompt_prefix_chat`: The prefix that will be prepended to chat prompts. 
+ * _Default Value has been set to `<|prompter|>`_ +* `prompt_postfix_chat`: The postfix that will be appended to chat prompts. + * _Default Value has been set to `<|assistant|>`_ +* `instruct_format`: The format string for instruction prompts. + * _Default Value has been set to `### SYSTEM:\n{system}\n### INSTRUCT: + \n{instruct}\n### ASSISTANT:\n`_ +* `chat_format`: The format string for chat prompts. + * _Default Value has been set to `'<|prompter|>{prompt}<|assistant|>{assistant}'`_ +* `chat_prefix`: The prefix that will be prepended to chat responses. + * _Default Value has been set to `''`_ +* `contains_auto_format`: Whether the model should automatically format instruction prompts. + * _Default Value has been set to `True`_ +* `logging`: Whether the model should log its training progress.: + * _Default Value has been set to `True`_ +* `stream_tokens_for_gradio`: Whether the model should stream tokens to Gradio. + * _Default Value has been set to `True`_ +* `use_prefix_tokenizer`: Whether the model should use a prefix tokenizer. + * _Default Value has been set to `True`_ +* `pre_compile`: Whether the model should be pre-compiled. + * _Default Value has been set to `True`_ + +## JAXServer Functions + +`JAXServer` Contains a method named `.process` and with using `process` method you can generate text from text + +what does this do and how this works ? here's the inputs that `process` function takes in + +```python +def process(self, + string, + *, + greedy: bool = False, + max_new_tokens: int = None, + **kwargs + ) -> [str, int]: + ... +``` + +* _Arguments_: + * string : String to be tokenized `(String)` + * Greedy : Use Greedy Search Method or NO `(Bool)` + * Max New Tokens : Number Of new Tokens to be Generated `(Int)` +* _Yields_: + * String : Next Tokens Predicted to String `(String)` + * Number of Used Tokens : Number of Used Tokens to generate answer `(Int)` + +you can use this function outside the class like this + +```python +for string, num_used_tokens in server.process( + 'im an string', + greedy=False, + max_new_tokens=256 # or None to use Maximum numbers passed in Config +): + print(f'\r{num_used_tokens}: {string}', end='') +``` + +### Gradio Functions 🤖 + +if you want to change gradio response functions you can override them like this + +#### Chat Gradio Function + +this is the default gradio functions and this is how it looks : + +```python +def process_gradio_chat(self, prompt, history, max_new_tokens, system, greedy): + string = self.chat_format(history=history, prompt=prompt, system=system) + + if not self.config.stream_tokens_for_gradio: + response = '' + for response, _ in self.process( + string=string, + greedy=greedy, + max_new_tokens=max_new_tokens, + ): + ... 
+ history.append([prompt, response]) + else: + history.append([prompt, '']) + for response, _ in self.process( + string=string, + greedy=greedy, + max_new_tokens=max_new_tokens, + ): + history[-1][-1] = response + yield '', history + return '', history +``` + +and here's a example of changing that in order to use Llama Models + +```python +def process_gradio_chat(self, prompt, history, max_new_tokens, system, greedy): + def prompt_llama2_model(message: str, chat_history, + system_prompt: str) -> str: + + do_strip = False + texts = [f'[INST] <>\n{system_prompt}\n<>\n\n'] + for user_input, response in chat_history: + user_input = user_input.strip() if do_strip else user_input + do_strip = True + texts.append(f'{user_input} [/INST] {response.strip()} [INST] ') + message = message.strip() if do_strip else message + texts.append(f'{message} [/INST]') + return ''.join(texts) + + string = prompt_llama2_model( + message=prompt, + chat_history=history or [], + system_prompt=system + ) + if not self.config.stream_tokens_for_gradio: + response = '' + for response, _ in self.process( + string=string, + greedy=greedy, + max_new_tokens=max_new_tokens, + ): + ... + history.append([prompt, response]) + else: + history.append([prompt, '']) + for response, _ in self.process( + string=string, + greedy=greedy, + max_new_tokens=max_new_tokens + ): + history[-1][-1] = response + yield '', history + + return '', history + +``` + +as you see you can easily override the functions just like how you want and use them with some simple changes, +and you can Also Use Their `Gradio Client` or use `JAXServer` `FastAPI` builtin methods + +### FastAPI 🌪 + +#### Instruct API + +to Override this api you have to code `forward_instruct` just like what you want the default implementation of this +function is + +```python +def forward_instruct(self, data: InstructRequest): + if not self._funcs_generated: + return { + 'status': "down" + } + + string = self.config.instruct_format.format(instruct=data.prompt, system=data.system) + response, used_tokens = [None] * 2 + for response, used_tokens in self.process( + string=string, + greedy=data.greedy, + max_new_tokens=None + ): + ... 
+ self.number_of_served_request_until_last_up_time += 1 + return { + 'input': f'{string}', + 'response': response, + 'tokens_used': used_tokens, + } +``` + +* BaseModel Class For PYData in FastAPI : + +```python +class InstructRequest(BaseModel): + prompt: str + system: Optional[str] = None + temperature: Optional[float] = None + greedy: Optional[bool] = False +``` + +* And here's an example of using this api via python and creating a simple client with using `requests` library in + python : + +```python +import requests + +content = { + 'prompt': 'can you code a simple neural network in c++ for me', + 'system': 'You are an AI assistant generate short and useful response', + 'temperature': 0.1, + 'greedy': False +} + +response = requests.post( + url='http://ip:port/instruct', + json=content +).json() + +print(response['response']) +# Response of model +print(response['input']) +# The input passed to the model + +``` + +#### Chat API + +to Override this api you have to code `forward_chat` just like what you want the default implementation of this function +is + +```python +def forward_chat(self, data: ChatRequest): + if not self._funcs_generated: + return { + 'status': "down" + } + + history = self.process_chat_history(data.history or []) + history += self.config.prompt_prefix_chat + data.prompt + self.config.prompt_postfix_chat + + response, used_tokens = [None] * 2 + for response, used_tokens in self.process( + string=history, + greedy=data.greedy, + max_new_tokens=None + ): + ... + self.number_of_served_request_until_last_up_time += 1 + return { + 'input': f'{history}', + 'response': response, + 'tokens_used': used_tokens, + } +``` + +* BaseModel Class For PYData in FastAPI : + +```python +class ChatRequest(BaseModel): + prompt: str + history: Union[List[List], None] = None + temperature: Optional[float] = None + greedy: Optional[bool] = False +``` + +* And here's an example of using this api via python and creating a simple client with using `requests` library in + python : + +```python +import requests + +content = { + 'prompt': 'can you code a simple neural network in c++ for me', + 'history': [ + ['hello how are you', 'Hello\nthanks, im here to assist you you have any question that i could help you with'] + ], + 'temperature': 0.1, + 'greedy': False +} + +response = requests.post( + url='http://ip:port/chat', + json=content +).json() + +print(response['response']) +# Response of model +print(response['input']) +# The input passed to the model + +``` + +#### Status 📣 + +Simply by sending a get API to `https://ip:port/status` you will receive base information about the server and +how it being run, num cores in use, number of generated prompt , number of request and ... diff --git a/docs/Llama.md b/docs/Llama.md new file mode 100644 index 000000000..8a9f47e9f --- /dev/null +++ b/docs/Llama.md @@ -0,0 +1,197 @@ +# About Llama Models + +* **Introduction** + +Llama models are a family of large language models (LLMs) developed by Meta AI. They are trained on a massive dataset of +text and code, and they can be used for a variety of tasks, such as text generation, translation, summarization, +question answering, code generation, and natural language inference. + +* **Model Architecture** + +Llama models are based on the Transformer architecture, which is a neural network architecture that has been shown to be +very effective for natural language processing tasks. The Transformer architecture uses self-attention to learn +long-range dependencies between words in a sentence. 
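+As a quick, generic illustration of the scaled dot-product self-attention mentioned above (a minimal sketch for intuition only, not EasyDeL's or Flax's actual attention implementation; every name here is made up for the example):
+
+```python
+# Minimal causal self-attention sketch in JAX -- illustration only.
+import jax
+import jax.numpy as jnp
+
+
+def self_attention(x, w_q, w_k, w_v):
+    # x: (seq_len, d_model); w_q / w_k / w_v: (d_model, d_head)
+    q, k, v = x @ w_q, x @ w_k, x @ w_v
+    scores = q @ k.T / jnp.sqrt(k.shape[-1])  # pairwise token affinities
+    # causal mask: each token attends only to itself and earlier positions
+    mask = jnp.tril(jnp.ones((x.shape[0], x.shape[0]), dtype=bool))
+    scores = jnp.where(mask, scores, jnp.finfo(scores.dtype).min)
+    weights = jax.nn.softmax(scores, axis=-1)
+    return weights @ v  # (seq_len, d_head)
+
+
+key = jax.random.PRNGKey(0)
+x = jax.random.normal(key, (8, 64))  # 8 tokens, d_model = 64
+w_q, w_k, w_v = (jax.random.normal(k, (64, 32)) for k in jax.random.split(key, 3))
+print(self_attention(x, w_q, w_k, w_v).shape)  # (8, 32)
+```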
+ +* **Training Data** + +Llama models are trained on a massive dataset of text and code. The text dataset includes text from a variety of +sources, such as books, articles, and websites. The code dataset includes code from a variety of programming languages, +such as Python, Java, and C++. + +* **Fine-tuning** + +After being pre-trained on a massive dataset, Llama models can be fine-tuned for specific tasks. Fine-tuning involves +training the model on a smaller dataset of data that is relevant to the specific task. + +* **Applications** + +Llama models can be used for a variety of tasks, such as: + + * Text generation: Llama models can be used to generate text, such as poems, code, scripts, and musical pieces. + * Translation: Llama models can be used to translate text from one language to another. + * Summarization: Llama models can be used to summarize text. + * Question answering: Llama models can be used to answer questions about text. + * Code generation: Llama models can be used to generate code. + * Natural language inference: Llama models can be used to determine the relationship between two sentences. + +* **Availability** + +Llama models are available for free for research and commercial use. They can be downloaded from the Hugging Face Hub. + +* **Limitations** + +Llama models are still under development, and they have some limitations. For example, they can sometimes generate +incorrect or misleading text. They can also be biased, reflecting the biases that are present in the training data. + +* **Future Work** + +Llama models are a promising new technology with the potential to be used for a variety of applications. Future work on +Llama models will focus on improving their accuracy, reducing their bias, and making them more robust to errors. + +* Text generation +* Translation +* Summarization +* Question answering +* Code generation +* Natural language inference + +Here is a table comparing the different sizes of Llama models: + +| Model | Parameters | +|-----------|------------| +| Llama 7B | 7 billion | +| Llama 13B | 13 billion | +| Llama 33B | 33 billion | +| Llama 65B | 65 billion | +| Llama 70B | 70 billion | + +## How to Use/Load Them in EasyDel + +```python +import jax +from EasyDel.transform import llama_from_pretrained + +params, config = llama_from_pretrained( + 'meta-llama/Llama-2-7b', + device=jax.devices('cpu')[0] # Offload on CPU +) +``` + +also keep that in mind that returned `config` includes `.get_partition_rules(fsdp=True)` + +#### Use With JaxServer + +```python +from EasyDel import JAXServer, FlaxLlamaForCausalLM +import jax +from EasyDel.transform import llama_from_pretrained +from transformers import AutoTokenizer + +params, config = llama_from_pretrained( + 'meta-llama/Llama-2-7b', + device=jax.devices('cpu')[0] # Offload on CPU +) + +DEFAULT_SYSTEM_PROMPT = "You are a helpful, respectful and honest assistant and act as wanted" + + +class Llama2JaxServer(JAXServer): + def process_gradio_chat(self, prompt, history, max_new_tokens, system, greedy): + + system = None if system == '' else system + string = self.prompt_llama2_model( + message=prompt, + chat_history=history or [], + system_prompt=system or DEFAULT_SYSTEM_PROMPT + ) + if not self.config.stream_tokens_for_gradio: + response = '' + for response, _ in self.process( + string=string, + greedy=greedy, + max_new_tokens=max_new_tokens, + ): + ... 
+ history.append([prompt, response]) + else: + history.append([prompt, '']) + for response, _ in self.process( + string=string, + greedy=greedy, + max_new_tokens=max_new_tokens + ): + history[-1][-1] = response + yield '', history + + return '', history + + def process_gradio_instruct(self, prompt, system, max_new_tokens, greedy): + string = self.prompt_llama2_model(system_prompt=DEFAULT_SYSTEM_PROMPT, message=prompt, chat_history=[]) + if not self.config.stream_tokens_for_gradio: + response = '' + for response, _ in self.process( + string=string, + greedy=greedy, + max_new_tokens=max_new_tokens, + ): + pass + else: + response = '' + for response, _ in self.process( + string=string, + greedy=greedy, + max_new_tokens=max_new_tokens, + stream=True + ): + yield '', response + return '', response + + @staticmethod + def prompt_llama2_model(message: str, chat_history, + system_prompt: str) -> str: + + do_strip = False + texts = [f'[INST] <>\n{system_prompt}\n<>\n\n'] + for user_input, response in chat_history: + user_input = user_input.strip() if do_strip else user_input + do_strip = True + texts.append(f'{user_input} [/INST] {response.strip()} [INST] ') + message = message.strip() if do_strip else message + texts.append(f'{message} [/INST]') + return ''.join(texts) + + +server = Llama2JaxServer.load_from_params( + params=params, + model=FlaxLlamaForCausalLM( + config=config, + dtype=jax.numpy.bfloat16, # Im on TPUs + param_dtype=jax.numpy.bfloat16, # Im on TPUs + precision=jax.lax.Precision('fastest'), + _do_init=False, + input_shape=(1, 1024) + ), + config_model=config, + add_param_field=True, + tokenizer=AutoTokenizer.from_pretrained('meta-llama/Llama-2-7b'), + verbose=False, + do_memory_log=True, + config={ + "max_length": 4096, + "max_new_tokens": 4096, + "max_stream_tokens": 64, + "dtype": 'bf16', + "use_prefix_tokenizer": True, + 'pre_compile': True + } +) + +server.fire() # Launch FastAPI functions + +shared_urls = server.launch( + share_chat=True, + share_inst=True +) +``` + +Done 😇 this method can be used for all the llama models \ No newline at end of file diff --git a/docs/Llama2.md b/docs/Llama2.md new file mode 100644 index 000000000..9661fa02b --- /dev/null +++ b/docs/Llama2.md @@ -0,0 +1,188 @@ +## About Llama2 Models + +**Llama2 Models** + +Llama2 Models is a family of pretrained and fine-tuned large language models (LLMs) developed by Meta AI. The models are +trained on a massive dataset of text and code, and can be used for a variety of tasks, including + +* Natural language understanding (NLU) +* Natural language generation (NLG) +* Machine translation +* Text summarization +* Question answering +* Code generation + +The Llama2 models are available under the Apache 2.0 license, which means that they can be freely used, modified, and +redistributed. + +**Model Architecture** + +The Llama2 models are based on the Transformer architecture, which is a neural network architecture that has been shown +to be very effective for NLP tasks. The models are trained using a technique called masked language modeling, which +involves predicting the missing words in a sequence of text. + +**Model Sizes** + +The Llama2 models come in a variety of sizes, ranging from 7 billion to 70 billion parameters. The larger models have +more capacity to learn complex patterns in language, but they are also more computationally expensive to train and +deploy. 
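+To make the size difference concrete, here is a rough back-of-the-envelope estimate of the memory needed just to hold the weights (a hypothetical helper for illustration only; it ignores optimizer state, activations and the KV cache, which add substantially more in practice):
+
+```python
+# Approximate parameter memory for different Llama2 sizes -- rough estimate only.
+def param_memory_gb(n_params: float, bytes_per_param: int) -> float:
+    return n_params * bytes_per_param / 1024 ** 3
+
+
+for n_params, name in [(7e9, "Llama2-7B"), (13e9, "Llama2-13B"), (70e9, "Llama2-70B")]:
+    print(f"{name}: ~{param_memory_gb(n_params, 2):.0f} GB in bf16/fp16, "
+          f"~{param_memory_gb(n_params, 4):.0f} GB in fp32")
+```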
+ +**Fine-tuning** + +The Llama2 models are pretrained on a massive dataset of text and code, but they can be further fine-tuned on a specific +task to improve their performance. Fine-tuning involves training the model on a dataset of labeled data for the specific +task. + +**Use Cases** + +The Llama2 models can be used for a variety of tasks, including: + +* Natural language understanding (NLU): The Llama2 models can be used to understand the meaning of text, such as + identifying the entities and relationships in a sentence. +* Natural language generation (NLG): The Llama2 models can be used to generate text, such as writing different kinds of + creative content, like poems, code, scripts, musical pieces, email, letters, etc. +* Machine translation: The Llama2 models can be used to translate text from one language to another. +* Text summarization: The Llama2 models can be used to summarize a text document into a shorter, more concise version. +* Question answering: The Llama2 models can be used to answer questions about a text document. +* Code generation: The Llama2 models can be used to generate code, such as Python scripts or Java classes. + +**Availability** + +The Llama2 models are available through the Hugging Face Hub. The models are also available in the TensorFlow Hub , the +PyTorch Hub and EasyDel. + +**Conclusion** + +The Llama2 models are a powerful family of LLMs that can be used for a variety of tasks. The models are open source and +available for free, making them a valuable resource for researchers and developers. + +## How to Use/Load Them in EasyDel + +```python +import jax +from EasyDel.transform import llama_from_pretrained + +params, config = llama_from_pretrained( + 'meta-llama/Llama-2-7b', + device=jax.devices('cpu')[0] # Offload on CPU +) +``` + +also keep that in mind that returned `config` includes `.get_partition_rules(fsdp=True)` + +#### Use With JaxServer + +```python +from EasyDel import JAXServer, FlaxLlamaForCausalLM +import jax +from EasyDel.transform import llama_from_pretrained +from transformers import AutoTokenizer + +params, config = llama_from_pretrained( + 'meta-llama/Llama-2-7b', + device=jax.devices('cpu')[0] # Offload on CPU +) + +DEFAULT_SYSTEM_PROMPT = "You are a helpful, respectful and honest assistant and act as wanted" + + +class Llama2JaxServer(JAXServer): + def process_gradio_chat(self, prompt, history, max_new_tokens, system, greedy): + + system = None if system == '' else system + string = self.prompt_llama2_model( + message=prompt, + chat_history=history or [], + system_prompt=system or DEFAULT_SYSTEM_PROMPT + ) + if not self.config.stream_tokens_for_gradio: + response = '' + for response, _ in self.process( + string=string, + greedy=greedy, + max_new_tokens=max_new_tokens, + ): + ... 
+ history.append([prompt, response]) + else: + history.append([prompt, '']) + for response, _ in self.process( + string=string, + greedy=greedy, + max_new_tokens=max_new_tokens + ): + history[-1][-1] = response + yield '', history + + return '', history + + def process_gradio_instruct(self, prompt, system, max_new_tokens, greedy): + string = self.prompt_llama2_model(system_prompt=DEFAULT_SYSTEM_PROMPT, message=prompt, chat_history=[]) + if not self.config.stream_tokens_for_gradio: + response = '' + for response, _ in self.process( + string=string, + greedy=greedy, + max_new_tokens=max_new_tokens, + ): + pass + else: + response = '' + for response, _ in self.process( + string=string, + greedy=greedy, + max_new_tokens=max_new_tokens, + stream=True + ): + yield '', response + return '', response + + @staticmethod + def prompt_llama2_model(message: str, chat_history, + system_prompt: str) -> str: + + do_strip = False + texts = [f'[INST] <>\n{system_prompt}\n<>\n\n'] + for user_input, response in chat_history: + user_input = user_input.strip() if do_strip else user_input + do_strip = True + texts.append(f'{user_input} [/INST] {response.strip()} [INST] ') + message = message.strip() if do_strip else message + texts.append(f'{message} [/INST]') + return ''.join(texts) + + +server = Llama2JaxServer.load_from_params( + params=params, + model=FlaxLlamaForCausalLM( + config=config, + dtype=jax.numpy.bfloat16, # Im on TPUs + param_dtype=jax.numpy.bfloat16, # Im on TPUs + precision=jax.lax.Precision('fastest'), + _do_init=False, + input_shape=(1, 1024) + ), + config_model=config, + add_param_field=True, + tokenizer=AutoTokenizer.from_pretrained('meta-llama/Llama-2-7b'), + verbose=False, + do_memory_log=True, + config={ + "max_length": 4096, + "max_new_tokens": 4096, + "max_stream_tokens": 64, + "dtype": 'bf16', + "use_prefix_tokenizer": True, + 'pre_compile': True + } +) + +server.fire() # Launch FastAPI functions + +shared_urls = server.launch( + share_chat=True, + share_inst=True +) +``` + +Done 😇 this method can be used for all the llama2 models \ No newline at end of file diff --git a/docs/Models.md b/docs/Models.md new file mode 100644 index 000000000..b742d0aad --- /dev/null +++ b/docs/Models.md @@ -0,0 +1,114 @@ +# Model 🤖💫 + +## Available Models Are + +1. **_[Llama](https://erfanzar.github.io/EasyDeL/docs/Llama)_**: + + * Supports: + * Fully Sharded Data Parallel `(FSDP)` + * MultiProcessing `(MP)` + * Data Parallel `(DP)` + * Distributed Data Parallel (DDP) `(DP)` + * Gradient CheckPointing + * Flash Attention + * BlockWise Attention + * Usage and Import from EasyDel Library + * [Usage](https://erfanzar.github.io/EasyDeL/docs/Llama) + + +2. **_[Llama2](https://erfanzar.github.io/EasyDeL/docs/Llama2)_**: + + * Supports: + * Fully Sharded Data Parallel `(FSDP)` + * MultiProcessing `(MP)` + * Data Parallel `(DP)` + * Distributed Data Parallel (DDP) `(DP)` + * Gradient CheckPointing + * Flash Attention + * BlockWise Attention + * [Usage](https://erfanzar.github.io/EasyDeL/docs/Llama2) + + +3. **_[Falcon](https://erfanzar.github.io/EasyDeL/docs/Falcon)_**: + + * Supports: + * Fully Sharded Data Parallel `(FSDP)` + * MultiProcessing `(MP)` + * Data Parallel `(DP)` + * Distributed Data Parallel (DDP) `(DP)` + * Gradient CheckPointing + * Flash Attention + * BlockWise Attention + * [Usage](https://erfanzar.github.io/EasyDeL/docs/Falcon) + + +4. 
**_[MosaicMPT](https://erfanzar.github.io/EasyDeL/docs/MosaicMPT)_**: + + * Supports: + * Fully Sharded Data Parallel `(FSDP)` + * MultiProcessing `(MP)` + * Data Parallel `(DP)` + * Distributed Data Parallel (DDP) `(DP)` + * Gradient CheckPointing + * Flash Attention + * BlockWise Attention + * [Usage](https://erfanzar.github.io/EasyDeL/docs/MosaicMPT) + + +5. **_GPTNeoX_** : + + * Supports: + * Fully Sharded Data Parallel `(FSDP)` + * MultiProcessing `(MP)` + * Data Parallel `(DP)` + * Distributed Data Parallel (DDP) `(DP)` + * Gradient CheckPointing + + +6. **_LT_** : + + * Supports: + * Fully Sharded Data Parallel `(FSDP)` + * MultiProcessing `(MP)` + * Data Parallel `(DP)` + * Distributed Data Parallel (DDP) `(DP)` + * Gradient CheckPointing + +7. **_Palm_**: + + * Supports: + * Fully Sharded Data Parallel `(FSDP)` + * MultiProcessing `(MP)` + * Data Parallel `(DP)` + * Distributed Data Parallel (DDP) `(DP)` + * Gradient CheckPointing + + +8. **_T5_**: + + * Supports: + * Fully Sharded Data Parallel `(FSDP)` + * MultiProcessing `(MP)` + * Data Parallel `(DP)` + * Distributed Data Parallel (DDP) `(DP)` + * Gradient CheckPointing + +9. **_GPT-J_** : + + * Supports: + * Fully Sharded Data Parallel `(FSDP)` + * MultiProcessing `(MP)` + * Data Parallel `(DP)` + * Distributed Data Parallel (DDP) `(DP)` + * Gradient CheckPointing + * Flash Attention + * BlockWise Attention + +10. **_OPT_**: + + * Supports: + * Fully Sharded Data Parallel `(FSDP)` + * MultiProcessing `(MP)` + * Data Parallel `(DP)` + * Distributed Data Parallel (DDP) `(DP)` + * Gradient CheckPointing diff --git a/docs/MosaicMPT.md b/docs/MosaicMPT.md new file mode 100644 index 000000000..52a5c2672 --- /dev/null +++ b/docs/MosaicMPT.md @@ -0,0 +1,133 @@ +# About MosaicMPT Models + +**MosaicMPT Models** + +MosaicMPT Models is a family of large language models (LLMs) developed by MosaicML. The models are trained on a massive +dataset of text and code, and can be used for a variety of tasks, including + +* Natural language understanding (NLU) +* Natural language generation (NLG) +* Machine translation +* Text summarization +* Question answering +* Code generation + +The MosaicMPT models are available under the Apache 2.0 license, which means that they can be freely used, modified, and +redistributed. + +**Model Architecture** + +The MosaicMPT models are based on the Transformer architecture, which is a neural network architecture that has been +shown to be very effective for NLP tasks. The models are trained using a technique called masked language modeling, +which involves predicting the missing words in a sequence of text. + +**Model Sizes** + +The MosaicMPT models come in a variety of sizes, ranging from 7 billion to 70 billion parameters. The larger models have +more capacity to learn complex patterns in language, but they are also more computationally expensive to train and +deploy. + +**MosaicPretrainedTransformer (MPT) Architecture** + +The MosaicPretrainedTransformer (MPT) architecture is a modified transformer architecture that is optimized for +efficient training and inference. The MPT architecture includes the following changes: + +* Performance-optimized layer implementations +* Architecture changes that provide greater training stability +* Elimination of context length limits by replacing positional embeddings with Attention with Linear Biases (ALiBi) + +Thanks to these modifications, MPT models can be trained with high throughput efficiency and stable convergence. 
MPT +models can also be served efficiently with both standard HuggingFace pipelines and NVIDIA's FasterTransformer. + +**Use Cases** + +The MosaicMPT models can be used for a variety of tasks, including: + +* Natural language understanding (NLU): The MosaicMPT models can be used to understand the meaning of text, such as + identifying the entities and relationships in a sentence. +* Natural language generation (NLG): The MosaicMPT models can be used to generate text, such as writing different kinds + of creative content, like poems, code, scripts, musical pieces, email, letters, etc. +* Machine translation: The MosaicMPT models can be used to translate text from one language to another. +* Text summarization: The MosaicMPT models can be used to summarize a text document into a shorter, more concise + version. +* Question answering: The MosaicMPT models can be used to answer questions about a text document. +* Code generation: The MosaicMPT models can be used to generate code, such as Python scripts or Java classes. + +**Availability** + +The MosaicMPT models are available through the Hugging Face Hub. The models are also available in the TensorFlow Hub, +the PyTorch Hub and EasyDel. + +**Conclusion** + +The MosaicMPT models are a powerful family of LLMs that can be used for a variety of tasks. The models are open source +and available for free, making them a valuable resource for researchers and developers. + +## How to Use/Load Them in EasyDel + +```python +import jax +from EasyDel.transform import mpt_from_pretrained + +params, config = mpt_from_pretrained( + 'mosaicml/mpt-7b', + device=jax.devices('cpu')[0] # Offload on CPU +) +``` + +also keep that in mind that returned `config` includes `.get_partition_rules(fsdp=True)` + +#### Use With JaxServer + +```python +from EasyDel import JAXServer, FlaxMptForCausalLM +import jax +from EasyDel.transform import mpt_from_pretrained +from transformers import AutoTokenizer + +params, config = mpt_from_pretrained( + 'mosaicml/mpt-7b', + device=jax.devices('cpu')[0] # Offload on CPU +) + + +class MPTJaxServer(JAXServer): + ... 
+ # You have to Custom this one yourself as you + # need read JaxServer Documents inorder to learn how + + +server = MPTJaxServer.load_from_params( + params=params, + model=FlaxMptForCausalLM( + config=config, + dtype=jax.numpy.bfloat16, # Im on TPUs + param_dtype=jax.numpy.bfloat16, # Im on TPUs + precision=jax.lax.Precision('fastest'), + _do_init=False, + input_shape=(1, 1024) + ), + config_model=config, + add_param_field=True, + tokenizer=AutoTokenizer.from_pretrained('mosaicml/mpt-7b'), + verbose=False, + do_memory_log=True, + config={ + "max_length": 2048, + "max_new_tokens": 2048, + "max_stream_tokens": 64, + "dtype": 'bf16', + "use_prefix_tokenizer": True, + 'pre_compile': True + } +) + +server.fire() # Launch FastAPI functions + +shared_urls = server.launch( + share_chat=True, + share_inst=True +) +``` + +Done 😇 this method can be used for all the MosaicMPT models \ No newline at end of file diff --git a/docs/PyTorchServer.md b/docs/PyTorchServer.md new file mode 100644 index 000000000..62988d37f --- /dev/null +++ b/docs/PyTorchServer.md @@ -0,0 +1,10 @@ +## PyTorchServer 🧬 + +`PyTorchServer` is one of offered utilities by EasyDel, and it's help hosting using and doing process with LLMs +and its also hackable, so you can override your own method in it and use it support both mid-level and high-level apis +and also give you a Gradio Chat and Instruct Pre-build and ready to use page + +* Supported Models are: + * EveryModel that have `transformers.PretrainedModel` as their Parent :) + +Documents are On The Way Amigos... \ No newline at end of file diff --git a/docs/README.md b/docs/README.md new file mode 100644 index 000000000..d6bf46a05 --- /dev/null +++ b/docs/README.md @@ -0,0 +1,56 @@ +# EasyDeL + +EasyDeL (Easy Deep Learning) is an open-source library designed to accelerate and optimize the training process of +machine learning models. This library is primarily focused on Jax/Flax and plans to offer easy and fine solutions to +train Flax/Jax Models on the `TPU/GPU` both for Serving and Training + +#### Note this Library needs golang to run (for some tracking stuff on TPU/GPU/CPU) + +#### Ubuntu GO installation + +```shell +sudo apt-get update && apt-get upgrade -y +sudo apt-get install golang -y +``` + +#### Manjaro/Arch GO installation + +```shell +sudo pacman -Syyuu go +``` + +_you can install other version too but easydel required at least version of 0.4.10_ + +```shell +!pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html -q +``` + +on GPUs be like + +```shell +pip install --upgrade pip +# CUDA 12 installation +# Note: wheels only available on linux. +pip install --upgrade "jax[cuda12_pip]" -f \ + https://storage.googleapis.com/jax-releases/jax_cuda_releases.html +``` + +```shell +pip install --upgrade pip +# CUDA 11 installation +# Note: wheels only available on linux. 
+pip install --upgrade "jax[cuda11_pip]" -f \ + https://storage.googleapis.com/jax-releases/jax_cuda_releases.html +``` + +## Documentations 🧭 + +* _EasyDel_: + * Configs + * [Modules](https://erfanzar.github.io/EasyDeL/docs/Models) + * RLHF + * [Serve](https://erfanzar.github.io/EasyDeL/docs/Serve) + * SMI + * Trainer + * Transform + * Utils diff --git a/docs/Serve.md b/docs/Serve.md new file mode 100644 index 000000000..9c0ebbdde --- /dev/null +++ b/docs/Serve.md @@ -0,0 +1,10 @@ +## EasyDel Serve 💫 + +`Serve` is one of offered utilities by EasyDel, and it's help hosting using and doing process with LLMs +and its also hackable, so you can override your own method in it and use it support both mid-level and high-level apis +and also give you a Gradio Chat and Instruct Pre-build and ready to use page + +and yes EasyDel supports Pytorch and jax Both for Serving LLMs 🧬 + +1. [JAXServer](https://erfanzar.github.io/EasyDeL/docs/JAXServer) For Jax 🤖 +2. [PyTorchServer](https://erfanzar.github.io/EasyDeL/docs/PyTorchServer) For PyTorch 🌪 \ No newline at end of file diff --git a/docs/ServinExample.md b/docs/ServinExample.md new file mode 100644 index 000000000..e69de29bb diff --git a/docs/TrainingExample.md b/docs/TrainingExample.md new file mode 100644 index 000000000..e69de29bb diff --git a/examples/serving/README.md b/examples/serving/README.md index abbf7680f..82ec1140b 100644 --- a/examples/serving/README.md +++ b/examples/serving/README.md @@ -54,7 +54,7 @@ you can also load model itself from parameters like ```python import EasyDel.transform -model_id = 'meta-llama/Llama-2-7b-chat-hf' +model_id = 'meta-llama/Llama.md-2-7b-chat-hf' tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) @@ -80,12 +80,13 @@ server = JAXServer.load_from_params( ) ``` -### API +### FastAPI 🌪 #### Instruct API to Override this api you have to code `forward_instruct` just like what you want the default implementation of this function is + ```python def forward_instruct(self, data: InstructRequest): if not self._funcs_generated: @@ -94,12 +95,13 @@ def forward_instruct(self, data: InstructRequest): } string = self.config.instruct_format.format(instruct=data.prompt, system=data.system) - - response, used_tokens = self.process( - string=string, - greedy=data.greedy, - max_new_tokens=None - ) + response, used_tokens = [None] * 2 + for response, used_tokens in self.process( + string=string, + greedy=data.greedy, + max_new_tokens=None + ): + ... self.number_of_served_request_until_last_up_time += 1 return { 'input': f'{string}', @@ -108,7 +110,7 @@ def forward_instruct(self, data: InstructRequest): } ``` -Base Class : +* BaseModel Class For PYData in FastAPI : ```python class InstructRequest(BaseModel): @@ -118,7 +120,8 @@ class InstructRequest(BaseModel): greedy: Optional[bool] = False ``` -Using Example : +* And here's an example of using this api via python and creating a simple client with using `requests` library in + python : ```python import requests @@ -157,11 +160,13 @@ def forward_chat(self, data: ChatRequest): history = self.process_chat_history(data.history or []) history += self.config.prompt_prefix_chat + data.prompt + self.config.prompt_postfix_chat - response, used_tokens = self.process( - string=history, - greedy=data.greedy, - max_new_tokens=None - ) + response, used_tokens = [None] * 2 + for response, used_tokens in self.process( + string=history, + greedy=data.greedy, + max_new_tokens=None + ): + ... 
self.number_of_served_request_until_last_up_time += 1 return { 'input': f'{history}', @@ -170,7 +175,7 @@ def forward_chat(self, data: ChatRequest): } ``` -Base Class : +* BaseModel Class For PYData in FastAPI : ```python class ChatRequest(BaseModel): @@ -180,7 +185,8 @@ class ChatRequest(BaseModel): greedy: Optional[bool] = False ``` -Using Example : +* And here's an example of using this api via python and creating a simple client with using `requests` library in + python : ```python import requests