
Updating Beta Branch in order to fix MPT model generating bugs #19

Merged
14 commits, merged on Sep 12, 2023
51 changes: 51 additions & 0 deletions .github/workflows/jekyll-gh-pages.yml
@@ -0,0 +1,51 @@
# Sample workflow for building and deploying a Jekyll site to GitHub Pages
name: Deploy Jekyll with GitHub Pages dependencies preinstalled

on:
# Runs on pushes targeting the default branch
push:
branches: ["main"]

# Allows you to run this workflow manually from the Actions tab
workflow_dispatch:

# Sets permissions of the GITHUB_TOKEN to allow deployment to GitHub Pages
permissions:
contents: read
pages: write
id-token: write

# Allow only one concurrent deployment, skipping runs queued between the run in-progress and latest queued.
# However, do NOT cancel in-progress runs as we want to allow these production deployments to complete.
concurrency:
group: "pages"
cancel-in-progress: false

jobs:
# Build job
build:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v3
- name: Setup Pages
uses: actions/configure-pages@v3
- name: Build with Jekyll
uses: actions/jekyll-build-pages@v1
with:
source: ./
destination: ./_site
- name: Upload artifact
uses: actions/upload-pages-artifact@v2

# Deployment job
deploy:
environment:
name: github-pages
url: ${{ steps.deployment.outputs.page_url }}
runs-on: ubuntu-latest
needs: build
steps:
- name: Deploy to GitHub Pages
id: deployment
uses: actions/deploy-pages@v2
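Not part of the PR: one quick way to sanity-check the new workflow file locally is to parse it with PyYAML (assuming PyYAML is installed); note that YAML 1.1 parses an unquoted `on` key as the boolean `True`.

import yaml

with open('.github/workflows/jekyll-gh-pages.yml') as f:
    workflow = yaml.safe_load(f)

# 'on' may come back as the boolean True, so check both keys.
triggers = workflow.get('on', workflow.get(True))
assert 'push' in triggers and 'workflow_dispatch' in triggers
assert set(workflow['jobs']) == {'build', 'deploy'}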
2 changes: 1 addition & 1 deletion .idea/EasyDeL.iml


2 changes: 1 addition & 1 deletion .idea/misc.xml


15 changes: 10 additions & 5 deletions EasyDel/modules/falcon/modelling_falcon_flax.py
@@ -63,8 +63,8 @@ def __init__(
hidden_size: int = 64,
num_hidden_layers: int = 32,
num_attention_heads: int = 71,
n_layers:int=32,
n_heads:int=71,
n_layers: int = 32,
n_heads: int = 71,
layer_norm_epsilon: float = 1e-5,
initializer_range: float = 0.02,
use_cache: bool = True,
@@ -520,9 +520,14 @@ class FlaxFalconPretrainedModel(FlaxPreTrainedModel):
module_class: nn.Module = None
config_class = FalconConfig

def __init__(self, config, _do_init=False, dtype: jnp.dtype = jnp.float32, param_dtype: jnp.dtype = jnp.float32,
input_shape: Tuple = (1, 12)):
module = self.module_class(config=config, dtype=dtype, param_dtype=param_dtype)
def __init__(self, config,
_do_init=False,
dtype: jnp.dtype = jnp.float32,
param_dtype: jnp.dtype = jnp.float32,
input_shape: Tuple = (1, 1024),
precision: Optional[Union[None, jax.lax.Precision]] = jax.lax.Precision('fastest')
):
module = self.module_class(config=config, dtype=dtype, param_dtype=param_dtype, precision=precision)
super().__init__(_do_init=_do_init, module=module, config=config, dtype=dtype, input_shape=input_shape)

def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> Dict:
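For orientation only: the widened constructor now threads a `precision` argument down to the Flax modules and uses a longer default `input_shape`. A minimal, standalone sketch of what `jax.lax.Precision` controls (not code from this PR):

import jax
import jax.numpy as jnp
import flax.linen as nn

# 'fastest', 'high', and 'highest' are the accepted string aliases for jax.lax.Precision;
# 'fastest' corresponds to the default matmul precision on TPU/GPU backends.
dense = nn.Dense(features=64, precision=jax.lax.Precision('fastest'))
params = dense.init(jax.random.PRNGKey(0), jnp.ones((1, 32)))
out = dense.apply(params, jnp.ones((4, 32)))  # shape (4, 64)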
2 changes: 1 addition & 1 deletion EasyDel/modules/llama/modelling_llama_flax.py
@@ -64,7 +64,7 @@ def get_gradient_checkpoint_policy(name):


class LlamaConfig(PretrainedConfig):
model_type = "Llama"
model_type = "Llama.md"

def __init__(
self,
54 changes: 36 additions & 18 deletions EasyDel/modules/mosaic_mpt/modelling_mpt_flax.py
@@ -68,15 +68,15 @@ def __init__(self,
resid_prob_drop: float = 0.0,
emb_prob_drop: float = 0.0,
alibi: bool = True,
use_bias: bool = True,
use_bias: bool = False,
learned_pos_emb: bool = True,
act_fn: str = 'gelu',
logit_scale: Optional[Union[float, str]] = None,
no_bias: bool = False,
verbose: int = 0,
embedding_fraction: float = 1.0,
use_cache: bool = False,
qk_ln: bool = True,
qk_ln: bool = False,
use_lm_head: bool = False,
use_norm_bias: bool = False,
gradient_checkpointing: str = 'nothing_saveable',
@@ -207,6 +207,9 @@ def add_jax_args(self,
flash_attn_key_chunk_size: int = 2048,
**kwargs
):
if hasattr(self, 'attn_config'):
for k, v in self.attn_config.items():
setattr(self, k, v)
basics = dict(
d_model=d_model,
n_heads=n_heads,
@@ -237,6 +240,7 @@ )
)
for k, v in basics.items():
if not hasattr(self, k):
print(f'Key {k} not found in loaded config; setting it to the default of {v}')
setattr(self, k, v)

self.from_pt = False
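The two additions above first copy any keys found in a loaded `attn_config` dict onto the config object, then fill in whatever is still missing from the defaults while printing a notice. A minimal standalone sketch of that fallback pattern (names and defaults are illustrative):

class _LoadedConfig:
    pass

cfg = _LoadedConfig()
cfg.attn_config = {'alibi': True, 'qk_ln': False}

# Copy nested attention settings onto the top-level config, mirroring add_jax_args.
for k, v in cfg.attn_config.items():
    setattr(cfg, k, v)

# Fill any remaining gaps with defaults, announcing each one.
defaults = dict(d_model=768, n_heads=12, alibi=True)
for k, v in defaults.items():
    if not hasattr(cfg, k):
        print(f'Key {k} not found in loaded config; setting it to the default of {v}')
        setattr(cfg, k, v)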
@@ -282,8 +286,8 @@ def setup(self) -> None:
dtype=self.dtype, param_dtype=self.param_dtype, precision=self.precision)
self.act = ACT2FN[self.config.act_fn]

def __call__(self, x: jax.Array):
return self.down(self.act(self.up(x)))
def __call__(self, hidden_state: jax.Array):
return self.down(self.act(self.up(hidden_state)))


class FlaxMptAttention(nn.Module):
@@ -303,10 +307,16 @@ def setup(self) -> None:
self.k_ln = nn.LayerNorm(use_bias=self.config.use_norm_bias)
self.causal_mask = nn.make_causal_mask(jnp.ones((1, self.config.max_seq_len)))

def __call__(self, x, attn_bias=None, attention_mask=None):
inp_shape = x.shape
def __call__(self, hidden_state, attn_bias=None, attention_mask=None):
inp_shape = hidden_state.shape
b, s, ds = inp_shape
qkv = self.w_qkv(x)
if attention_mask is not None:
_, s = attention_mask.shape
assert inp_shape[1] == s, (
    f'sequence length mismatch between hidden_state ({inp_shape[1]}) and attention_mask ({s});'
    f' attention_mask shape: {attention_mask.shape} | hidden_state shape: {hidden_state.shape}'
)
qkv = self.w_qkv(hidden_state)
q, k, v = jnp.split(qkv, 3, -1)
if self.config.qk_ln:
q = self.q_ln(q)
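As background on the surrounding code: `w_qkv` is a single fused projection, and `jnp.split(qkv, 3, -1)` recovers the query, key, and value tensors before the optional `q_ln`/`k_ln` layer norms. A standalone sketch of that pattern (shapes are illustrative):

import jax.numpy as jnp

b, s, d, n_heads = 1, 4, 12, 3
qkv = jnp.zeros((b, s, 3 * d))              # output of a fused QKV projection
q, k, v = jnp.split(qkv, 3, axis=-1)        # each has shape (b, s, d)
q = q.reshape(b, s, n_heads, d // n_heads)  # split heads before computing attention weights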
@@ -361,7 +371,11 @@ def __call__(self, x, attn_bias=None, attention_mask=None):
atw += attn_bias
mask = jnp.where(self.causal_mask == 1, 0, jnp.finfo(atw).min)
if attention_mask is not None:
attention_mask = jnp.where(attention_mask.reshape(b, 1, 1, s) == 1, 0, jnp.finfo(atw).min)
attention_mask = jnp.where(
attention_mask.reshape(b, 1, 1, -1) == 1,
0,
jnp.finfo(atw).min
)
atw += attention_mask
atw += mask[:, :, :s, :s]
atw = nn.softmax(atw, -1)
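The reshape to `(b, 1, 1, -1)` lets a per-token padding mask broadcast across heads and query positions before the softmax. A standalone sketch of the masking arithmetic (values are illustrative, not from this PR):

import jax
import jax.numpy as jnp

b, h, s = 2, 4, 5
scores = jnp.zeros((b, h, s, s))                # unnormalised attention weights
attention_mask = jnp.array([[1, 1, 1, 0, 0],
                            [1, 1, 1, 1, 1]])   # 1 = real token, 0 = padding
bias = jnp.where(attention_mask.reshape(b, 1, 1, -1) == 1, 0, jnp.finfo(scores.dtype).min)
probs = jax.nn.softmax(scores + bias, axis=-1)  # padded key columns get ~0 probability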
@@ -383,10 +397,11 @@ def setup(self) -> None:
self.ffn = FlaxMptMLP(config=self.config, dtype=self.dtype, param_dtype=self.param_dtype,
precision=self.precision)

def __call__(self, x, attn_bias=None, attention_mask=None):
x = self.attn(self.norm_1(x), attn_bias=attn_bias, attention_mask=attention_mask) + x
x = self.ffn(self.norm_2(x)) + x
return x
def __call__(self, hidden_state, attn_bias=None, attention_mask=None):
hidden_state = (self.attn(self.norm_1(hidden_state), attn_bias=attn_bias, attention_mask=attention_mask) +
hidden_state)
hidden_state = self.ffn(self.norm_2(hidden_state)) + hidden_state
return hidden_state


def get_gradient_checkpoint_policy(name):
@@ -426,10 +441,10 @@ def setup(self) -> None:
)
]

def __call__(self, x, attn_bias=None, attention_mask=None):
def __call__(self, hidden_state, attn_bias=None, attention_mask=None):
for block in self.blocks:
x = block(x=x, attn_bias=attn_bias, attention_mask=attention_mask)
return x
hidden_state = block(hidden_state=hidden_state, attn_bias=attn_bias, attention_mask=attention_mask)
return hidden_state


def build_alibi(max_length, num_attention_heads, alibi_max: int = 8):
@@ -516,15 +531,18 @@ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: Froz
def __call__(self,
input_ids,
attention_mask=None,
params=None,
params: dict = None,
add_params_field: bool = False,
return_dict: bool = True,
extra_embedding: Optional[Union[jnp.ndarray, None]] = None):
params = {'params': params or self.params} if add_params_field else params or self.params
input_ids = jnp.asarray(input_ids, dtype='i4')
if attention_mask is None:
attention_mask = jnp.ones_like(input_ids, dtype='i4')
predict = self.module.apply(
params,
input_ids=jnp.asarray(input_ids, dtype='i4'),
attention_mask=jnp.asarray(attention_mask, dtype='i4') if attention_mask is not None else attention_mask,
input_ids=input_ids,
attention_mask=jnp.asarray(attention_mask, dtype='i4'),
return_dict=return_dict,
extra_embedding=extra_embedding
)
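The reworked parameter handling only wraps weights in a `{'params': ...}` dict when `add_params_field` is set and falls back to the module's own parameters otherwise, and the attention mask now defaults to all ones. A standalone sketch of the wrapping logic (names are illustrative):

import jax.numpy as jnp

def wrap_params(params, self_params, add_params_field: bool):
    # Mirrors the line above: wrap in {'params': ...} only when requested,
    # falling back to the module's own parameters when none are passed.
    return {'params': params or self_params} if add_params_field else params or self_params

self_params = {'w': jnp.zeros((2, 2))}
print(wrap_params(None, self_params, add_params_field=True))   # {'params': {'w': ...}}
print(wrap_params(None, self_params, add_params_field=False))  # {'w': ...}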
31 changes: 17 additions & 14 deletions EasyDel/serve/serve_utils.py
@@ -479,20 +479,22 @@ def forward_chat(self, data: ChatRequest):
'status': "down"
}

history = self.chat_format(
string = self.chat_format(
prompt=data.prompt,
system=None,
history=data.history
)

response, used_tokens = self.process(
string=history,
greedy=data.greedy,
max_new_tokens=None
)
response, used_tokens = [None] * 2
for response, used_tokens in self.process(
string=string,
greedy=data.greedy,
max_new_tokens=None
):
...
self.number_of_served_request_until_last_up_time += 1
return {
'input': f'{history}',
'input': f'{string}',
'response': response,
'tokens_used': used_tokens,
}
@@ -504,12 +506,13 @@ def forward_instruct(self, data: InstructRequest):
}

string = self.config.instruct_format.format(instruct=data.prompt, system=data.system)

response, used_tokens = self.process(
string=string,
greedy=data.greedy,
max_new_tokens=None
)
response, used_tokens = [None] * 2
for response, used_tokens in self.process(
string=string,
greedy=data.greedy,
max_new_tokens=None
):
...
self.number_of_served_request_until_last_up_time += 1
return {
'input': f'{string}',
@@ -534,7 +537,7 @@ def forward_chat_non_api(self, prompt, history, greedy):
return self.forward_chat(data)

def process(self,
string,
string: str,
*,
greedy: bool = False,
max_new_tokens: int = None,
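In both handlers, `process` is now consumed as a generator: the loop runs to exhaustion and the last yielded `(response, used_tokens)` pair is kept, which is the usual way to drain a streaming generator when only the final result matters. A standalone sketch of the pattern (the generator below is a stand-in, not the real `process`):

def stream_generation():
    # Stand-in for the serve-side process(), which yields partial generations.
    for step in range(3):
        yield f'partial generation {step}', step + 1

response, used_tokens = None, None
for response, used_tokens in stream_generation():
    ...  # intermediate yields could be streamed to a client here
print(response, used_tokens)  # 'partial generation 2' 3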
46 changes: 24 additions & 22 deletions EasyDel/transform/mpt.py
@@ -3,6 +3,7 @@
import jax
import torch
import numpy as np
from transformers import AutoModelForCausalLM


def mpt_convert_flax_to_pt_7b(state_dict_flax, n_layers: int, device=torch.device('cpu'), use_lm_head=False):
@@ -44,52 +45,52 @@ def mpt_convert_pt_to_flax_7b(state_dict, n_layers: int, device=jax.devices('cpu
state_dict_flax = {('transformer', 'wte', 'embedding'): state_dict[
'transformer.wte.weight'].cpu().detach().numpy()}
for i in range(n_layers):
state_dict_flax[('transformer'), ('h'), (f'{i}'), ('norm_1'), ('scale')] = state_dict[
state_dict_flax[('transformer', 'h', f'{i}', 'norm_1', 'scale')] = state_dict[
f'transformer.blocks.{i}.norm_1.weight'].cpu().detach().numpy()
state_dict_flax[('transformer'), ('h'), (f'{i}'), ('norm_2'), ('scale')] = state_dict[
state_dict_flax[('transformer', 'h', f'{i}', 'norm_2', 'scale')] = state_dict[
f'transformer.blocks.{i}.norm_2.weight'].cpu().detach().numpy()
state_dict_flax[('transformer'), ('h'), (f'{i}'), ('ffn'), ('down'), ('kernel')] = jnp.transpose(
state_dict_flax[('transformer', 'h', f'{i}', 'ffn', 'down', 'kernel')] = jnp.transpose(
state_dict[f'transformer.blocks.{i}.ffn.down_proj.weight'].cpu().detach().numpy(), (1, 0))
state_dict_flax[('transformer'), ('h'), (f'{i}'), ('ffn'), ('up'), ('kernel')] = jnp.transpose(
state_dict_flax[('transformer', 'h', f'{i}', 'ffn', 'up', 'kernel')] = jnp.transpose(
state_dict[f'transformer.blocks.{i}.ffn.up_proj.weight'].cpu().detach().numpy(), (1, 0))
state_dict_flax[('transformer'), ('h'), (f'{i}'), ('attn'), ('w_qkv'), ('kernel')] = jnp.transpose(
state_dict_flax[('transformer', 'h', f'{i}', 'attn', 'w_qkv', 'kernel')] = jnp.transpose(
state_dict[f'transformer.blocks.{i}.attn.Wqkv.weight'].cpu().detach().numpy(), (1, 0))
state_dict_flax[('transformer'), ('h'), (f'{i}'), ('attn'), ('wo'), ('kernel')] = jnp.transpose(
state_dict_flax[('transformer', 'h', f'{i}', 'attn', 'wo', 'kernel')] = jnp.transpose(
state_dict[f'transformer.blocks.{i}.attn.out_proj.weight'].cpu().detach().numpy(), (1, 0))
state_dict_flax[('transformer'), ('norm_f'), ('scale')] = state_dict[
state_dict_flax[('transformer', 'norm_f', 'scale')] = state_dict[
f'transformer.norm_f.weight'].cpu().detach().numpy()
if use_lm_head:
state_dict_flax[('lm_head'), ('kernel')] = jnp.transpose(
state_dict_flax[('lm_head', 'kernel')] = jnp.transpose(
state_dict[f'lm_head.weight'].cpu().detach().numpy(), (1, 0))
return state_dict_flax


def mpt_convert_pt_to_flax_1b(state_dict, n_layers: int, device=jax.devices('cpu')[0], use_lm_head=False, ):
# CONVERTER MPT-1B
with jax.default_device(device):
state_dict_flax = {(('transformer'), ('wte'), ('embedding')): state_dict[
state_dict_flax = {(('transformer', 'wte', 'embedding')): state_dict[
'transformer.wte.weight'].cpu().detach().numpy()}
for i in range(n_layers):
state_dict_flax[('transformer'), ('h'), (f'{i}'), ('norm_1'), ('scale')] = state_dict[
state_dict_flax[('transformer', 'h', f'{i}', 'norm_1', 'scale')] = state_dict[
f'transformer.blocks.{i}.ln_1.weight'].cpu().detach().numpy()
state_dict_flax[('transformer'), ('h'), (f'{i}'), ('norm_2'), ('scale')] = state_dict[
state_dict_flax[('transformer', 'h', f'{i}', 'norm_2', 'scale')] = state_dict[
f'transformer.blocks.{i}.ln_2.weight'].cpu().detach().numpy()
state_dict_flax[('transformer'), ('h'), (f'{i}'), ('ffn'), ('down'), ('kernel')] = jnp.transpose(
state_dict_flax[('transformer', 'h', f'{i}', 'ffn', 'down', 'kernel')] = jnp.transpose(
state_dict[f'transformer.blocks.{i}.mlp.mlp_down.weight'].cpu().detach().numpy(), (1, 0))
state_dict_flax[('transformer'), ('h'), (f'{i}'), ('ffn'), ('up'), ('kernel')] = jnp.transpose(
state_dict_flax[('transformer', 'h', f'{i}', 'ffn', 'up', 'kernel')] = jnp.transpose(
state_dict[f'transformer.blocks.{i}.mlp.mlp_up.weight'].cpu().detach().numpy(), (1, 0))
state_dict_flax[('transformer'), ('h'), (f'{i}'), ('attn'), ('w_qkv'), ('kernel')] = jnp.transpose(
state_dict_flax[('transformer', 'h', f'{i}', 'attn', 'w_qkv', 'kernel')] = jnp.transpose(
state_dict[f'transformer.blocks.{i}.attn.Wqkv.weight'].cpu().detach().numpy(), (1, 0))
state_dict_flax[('transformer'), ('h'), (f'{i}'), ('attn'), ('wo'), ('kernel')] = jnp.transpose(
state_dict_flax[('transformer', 'h', f'{i}', 'attn', 'wo', 'kernel')] = jnp.transpose(
state_dict[f'transformer.blocks.{i}.attn.out_proj.weight'].cpu().detach().numpy(), (1, 0))
state_dict_flax[('transformer'), ('h'), (f'{i}'), ('attn'), ('q_ln'), ('scale')] = state_dict[
state_dict_flax[('transformer', 'h', f'{i}', 'attn', 'q_ln', 'scale')] = state_dict[
f'transformer.blocks.{i}.attn.q_ln.weight'].cpu().detach().numpy()
state_dict_flax[('transformer'), ('h'), (f'{i}'), ('attn'), ('k_ln'), ('scale')] = state_dict[
state_dict_flax[('transformer', 'h', f'{i}', 'attn', 'k_ln', 'scale')] = state_dict[
f'transformer.blocks.{i}.attn.k_ln.weight'].cpu().detach().numpy()
state_dict_flax[('transformer'), ('norm_f'), ('scale')] = state_dict[
state_dict_flax[('transformer', 'norm_f', 'scale')] = state_dict[
f'transformer.ln_f.weight'].cpu().detach().numpy()
if use_lm_head:
state_dict_flax[('lm_head'), ('kernel')] = jnp.transpose(
state_dict_flax[('lm_head', 'kernel')] = jnp.transpose(
state_dict[f'lm_head.weight'].cpu().detach().numpy(),
(1, 0))
return state_dict_flax
@@ -140,15 +141,16 @@ def mpt_convert_flax_to_pt_1b(state_dict_flax, n_layers: int, device=torch.devic
return state_dict


def mpt_from_pretrained(model_id, device=jax.devices('cpu')[0]):
def mpt_from_pretrained(model_id, device=jax.devices('cpu')[0], **kwargs):
"""
return: Weight or Params for EasyDel Model , Config
"""
config = MptConfig.from_pretrained(model_id)
model = FlaxMptForCausalLM.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True, **kwargs)

easydel_wights = mpt_convert_pt_to_flax_7b(
state_dict=model.state_dict(),
n_layers=config.num_hidden_layers,
n_layers=config.num_hidden_layers if hasattr(config, 'num_hidden_layers') else config.n_layers,
device=device
)
config.add_jax_args()
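A hedged usage sketch of the updated loader, which now pulls the PyTorch weights through `AutoModelForCausalLM` (with `trust_remote_code=True`) and forwards extra keyword arguments to it; per its docstring it returns the converted EasyDel params together with the config. The import path, model id, and extra kwarg below are illustrative assumptions:

import torch
from EasyDel.transform.mpt import mpt_from_pretrained  # import path assumed from this file

params, config = mpt_from_pretrained(
    'mosaicml/mpt-7b',          # illustrative Hugging Face model id
    torch_dtype=torch.float32,  # forwarded to AutoModelForCausalLM.from_pretrained
)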