Commit
Merge pull request #178 from stochasticai/dev
Release 0.1.2
sarthaklangde authored May 1, 2023
2 parents 92b2117 + ab787fa commit 142954b
Showing 10 changed files with 82 additions and 56 deletions.
5 changes: 5 additions & 0 deletions examples/gptj/gptj_lora_int8.py
@@ -1,3 +1,5 @@
import gc

from xturing.datasets.instruction_dataset import InstructionDataset
from xturing.models import BaseModel

@@ -10,6 +12,9 @@
# Save the model
model.save("./gptj_weights")

del model
gc.collect()
model = BaseModel.load("./gptj_weights")
# Once the model has been finetuned, you can start doing inferences
output = model.generate(texts=["Why LLM models are becoming so important?"])
print("Generated output by the model: {}".format(output))
5 changes: 5 additions & 0 deletions examples/llama/llama_lora_int8.py
@@ -1,3 +1,5 @@
import gc

from xturing.datasets.instruction_dataset import InstructionDataset
from xturing.models import BaseModel

@@ -11,6 +13,9 @@
model.save("./llama_weights")

# Once the model has been finetuned, you can start doing inferences
del model
gc.collect()
model = BaseModel.load("./llama_weights")
output = model.generate(texts=["Why LLM models are becoming so important?"])
print("Generated output by the model: {}".format(output))

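Both int8 examples now release the fine-tuned model with del and gc.collect() before calling BaseModel.load, so the original model and the reloaded copy never sit in memory at the same time. Below is a minimal sketch of that release-then-reload pattern, assuming the gptj_lora_int8 model key used elsewhere in this release; the torch.cuda.empty_cache() call is an extra precaution, not something the diff adds.

import gc

import torch

from xturing.models import BaseModel

model = BaseModel.create("gptj_lora_int8")
# ... fine-tune as in the full example, then persist the weights ...
model.save("./gptj_weights")

# Drop the in-memory copy before loading the saved weights back.
del model
gc.collect()
torch.cuda.empty_cache()  # optional: also releases cached GPU memory

model = BaseModel.load("./gptj_weights")
output = model.generate(texts=["Why LLM models are becoming so important?"])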
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "xturing"
version = "0.1.1"
version = "0.1.2"
description = "Fine-tuning, evaluation and data generation for LLMs"

authors = [
@@ -43,7 +43,7 @@ keywords = [
dependencies = [
"torch >= 1.9.0",
"pytorch-lightning",
"transformers==4.27.3",
"transformers==4.28.1",
"datasets",
"evaluate",
"bitsandbytes==0.37.2",
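Alongside the package version bump to 0.1.2, the transformers pin moves from 4.27.3 to 4.28.1; the vendored fixed_pos_embedding in gptj.py further down appears to be a follow-on from that bump. A purely illustrative check that an environment matches the new pin:

import transformers

print(transformers.__version__)  # expected to print 4.28.1 under this release's pin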
2 changes: 1 addition & 1 deletion src/xturing/__about__.py
@@ -1 +1 @@
__version__ = "0.1.1"
__version__ = "0.1.2"
24 changes: 6 additions & 18 deletions src/xturing/config/generation_config.yaml
@@ -20,10 +20,8 @@ llama_lora:
max_new_tokens: 256
do_sample: false

# Contrastive search
# Greedy search
llama_lora_int8:
penalty_alpha: 0.6
top_k: 4
max_new_tokens: 256
do_sample: false

@@ -48,10 +46,8 @@ gptj_lora:
max_new_tokens: 256
do_sample: false

# Contrastive search
# Greedy search
gptj_lora_int8:
penalty_alpha: 0.6
top_k: 4
max_new_tokens: 256
do_sample: false

@@ -104,10 +100,8 @@ galactica_lora:
max_new_tokens: 256
do_sample: false

# Contrastive search
# Greedy search
galactica_lora_int8:
penalty_alpha: 0.6
top_k: 4
max_new_tokens: 256
do_sample: false

@@ -125,10 +119,8 @@ opt_lora:
max_new_tokens: 256
do_sample: false

# Contrastive search
# Greedy search
opt_lora_int8:
penalty_alpha: 0.6
top_k: 4
max_new_tokens: 256
do_sample: false

@@ -146,10 +138,8 @@ cerebras_lora:
max_new_tokens: 256
do_sample: false

# Contrastive search
# Greedy search
cerebras_lora_int8:
penalty_alpha: 0.6
top_k: 4
max_new_tokens: 256
do_sample: false

@@ -167,9 +157,7 @@ bloom_lora:
max_new_tokens: 256
do_sample: false

# Contrastive search
# Greedy search
bloom_lora_int8:
penalty_alpha: 0.6
top_k: 4
max_new_tokens: 256
do_sample: false
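Every *_lora_int8 profile drops penalty_alpha and top_k and keeps only max_new_tokens and do_sample: false, switching the int8 paths from contrastive search to plain greedy decoding. Here is a sketch of how the old and new settings map onto Hugging Face generate() keyword arguments; the gpt2 checkpoint is just a small stand-in, not one of the models configured above.

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("Why LLM models are becoming so important?", return_tensors="pt")

# Old int8 profiles: penalty_alpha plus top_k > 1 trigger contrastive search.
old_ids = model.generate(**inputs, penalty_alpha=0.6, top_k=4,
                         max_new_tokens=256, do_sample=False)

# New int8 profiles: without those, do_sample=False falls back to greedy search.
new_ids = model.generate(**inputs, max_new_tokens=256, do_sample=False)

print(tokenizer.decode(new_ids[0], skip_special_tokens=True))

Greedy decoding is also cheaper per generated token, since contrastive search scores top_k candidate continuations at every step.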
2 changes: 1 addition & 1 deletion src/xturing/engines/causal.py
@@ -166,7 +166,7 @@ def __init__(
model_weights_path = str(Path(weights_path).resolve() / "pytorch_model.bin")
self.model.load_state_dict(
torch.load(
model_weights_path # , map_location=torch.device(DEFAULT_DEVICE)
model_weights_path, map_location=torch.device(DEFAULT_DEVICE)
)
)
else:
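Re-enabling map_location maps the saved tensors straight onto DEFAULT_DEVICE while the checkpoint is read, rather than onto the device they were serialized from, so a checkpoint saved on a GPU can, for example, be loaded on a CPU-only machine. A minimal sketch of the idea, with "cpu" standing in for DEFAULT_DEVICE and a hypothetical checkpoint path:

import torch

state_dict = torch.load("./saved_weights/pytorch_model.bin",
                        map_location=torch.device("cpu"))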
21 changes: 16 additions & 5 deletions src/xturing/engines/gptj_utils/gptj.py
@@ -1,10 +1,21 @@
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
from typing import Optional, Union, Tuple
from transformers.models.gptj.modeling_gptj import (
apply_rotary_pos_emb,
fixed_pos_embedding,
)
from transformers.models.gptj.modeling_gptj import apply_rotary_pos_emb


def fixed_pos_embedding(x, seq_dim=1, seq_len=None):
dim = x.shape[-1]
if seq_len is None:
seq_len = x.shape[seq_dim]
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2) / dim))
sinusoid_inp = (
torch.einsum("i , j -> i j", torch.arange(seq_len), inv_freq)
.to(x.device)
.float()
)
return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)


class GPTJAttention(nn.Module):
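fixed_pos_embedding is now vendored rather than imported from transformers.models.gptj.modeling_gptj, presumably because the transformers 4.28.1 pinned above no longer exposes that helper, while apply_rotary_pos_emb is still imported from upstream. The function builds the sine and cosine tables used by the rotary position embedding. An illustrative call with made-up shapes, assuming the vendored definition above is in scope:

import torch

k_rot = torch.randn(1, 8, 16, 64)  # fake (batch, seq_len, num_heads, rotary_dim) tensor
sin, cos = fixed_pos_embedding(k_rot, seq_dim=1)
print(sin.shape, cos.shape)  # both (seq_len, rotary_dim // 2) == (8, 32)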
50 changes: 30 additions & 20 deletions src/xturing/engines/llama_engine.py
@@ -1,17 +1,18 @@
import os
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
import transformers

import torch
import transformers
from torch import nn

from xturing.engines.causal import CausalEngine, CausalLoraEngine
from xturing.engines.llama_utils import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
from xturing.engines.lora_engine import prepare_model_for_int8_training
from xturing.engines.quant_utils import make_quant, autotune_warmup
from xturing.engines.quant_utils import autotune_warmup, make_quant
from xturing.utils.hub import ModelHub


class LLamaEngine(CausalEngine):
config_name: str = "llama_engine"

@@ -102,24 +103,28 @@ def __init__(self, weights_path: Optional[Union[str, Path]] = None):
target_modules=["q_proj", "v_proj"],
)

def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''):

def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=""):
if type(module) in layers:
return {name: module}
res = {}
for name1, child in module.named_children():
res.update(find_layers(
child, layers=layers, name=name + '.' + name1 if name != '' else name1
))
res.update(
find_layers(
child, layers=layers, name=name + "." + name1 if name != "" else name1
)
)
return res


class LlamaLoraInt4Engine(CausalLoraEngine):
config_name: str = "llama_lora_int4_engine"

def __init__(self, weights_path: Optional[Union[str, Path]] = None):
model_name = "decapoda-research/llama-7b-hf"
model_name = "decapoda-research/llama-7b-hf"

if weights_path is None:
weights_path = ModelHub().load("x/llama_lora_int4")
weights_path = ModelHub().load("x/llama_lora_int4")

config = LlamaConfig.from_pretrained(model_name)

@@ -129,10 +134,10 @@ def __init__(self, weights_path: Optional[Union[str, Path]] = None):

def noop(*args, **kwargs):
pass
torch.nn.init.kaiming_uniform_ = noop
torch.nn.init.uniform_ = noop
torch.nn.init.normal_ = noop

torch.nn.init.kaiming_uniform_ = noop
torch.nn.init.uniform_ = noop
torch.nn.init.normal_ = noop

torch.set_default_dtype(torch.half)
transformers.modeling_utils._init_weights = False
@@ -143,18 +148,23 @@ def noop(*args, **kwargs):

layers = find_layers(model)

for name in ['lm_head']:
for name in ["lm_head"]:
if name in layers:
del layers[name]

wbits = 4
groupsize = 128
warmup_autotune=True
warmup_autotune = True

make_quant(model, layers, wbits, groupsize)


model.load_state_dict(torch.load(weights_path / Path("pytorch_model.bin")), strict=False)
state_dict = torch.load(
weights_path / Path("pytorch_model.bin"), map_location="cpu"
)
new_state_dict = {}
for key, value in state_dict.items():
new_state_dict[key[6:]] = value
model.load_state_dict(new_state_dict, strict=False)

if warmup_autotune:
autotune_warmup(model)
@@ -171,12 +181,12 @@ def noop(*args, **kwargs):
tokenizer.pad_token_id = tokenizer.eos_token_id

super().__init__(
model=model,
model=model,
tokenizer=tokenizer,
target_modules=[
"q_proj",
"v_proj",
]
],
)

torch.nn.init.kaiming_uniform_ = saved_kaiming_uniform_
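Most of this hunk is mechanical reformatting (double quotes, spacing), but the int4 path gains two substantive changes: the checkpoint is now loaded with map_location="cpu", and every key has its first six characters stripped before load_state_dict, presumably to drop a "model."-style prefix that the in-memory module does not expect. The noop overrides of torch.nn.init.* (restored at the end of __init__) skip random initialization of weights that the quantized checkpoint is about to overwrite, while find_layers selects the nn.Linear and nn.Conv2d modules that make_quant replaces with 4-bit equivalents, with lm_head explicitly excluded. A toy call, assuming the find_layers definition above is in scope; the tiny Sequential is a stand-in, not the LLaMA model.

import torch.nn as nn

toy = nn.Sequential(nn.Linear(8, 8), nn.Sequential(nn.Linear(8, 4)))
print(find_layers(toy))
# {'0': Linear(in_features=8, out_features=8, bias=True),
#  '1.0': Linear(in_features=8, out_features=4, bias=True)}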
4 changes: 3 additions & 1 deletion src/xturing/utils/hub.py
@@ -55,14 +55,16 @@ def bar_progress(current, total, width=80):

entries = list(model_dir.glob("*"))

if len(entries) == 1 and entries[0].is_dir():
while len(entries) == 1 and entries[0].is_dir():
single_folder = entries[0]

for item in single_folder.iterdir():
shutil.move(str(item), str(model_dir / item.name))

shutil.rmtree(single_folder)

entries = list(model_dir.glob("*"))

except Exception as e:
print(f"Error downloading model {model_name} from {url}: {e}")
raise e
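Switching the single-folder unwrap from if to while lets the hub flatten downloads whose payload is nested more than one directory deep; the previous code only peeled off one level. A toy reproduction with hypothetical paths, independent of xturing's hub:

import shutil
from pathlib import Path

model_dir = Path("model_dir")
(model_dir / "a" / "b").mkdir(parents=True, exist_ok=True)
(model_dir / "a" / "b" / "pytorch_model.bin").touch()

entries = list(model_dir.glob("*"))
while len(entries) == 1 and entries[0].is_dir():
    single_folder = entries[0]
    for item in single_folder.iterdir():
        shutil.move(str(item), str(model_dir / item.name))
    shutil.rmtree(single_folder)
    entries = list(model_dir.glob("*"))

# With the old `if`, only "a" would have been unwrapped, leaving model_dir/b/.
print(sorted(p.name for p in model_dir.glob("*")))  # ['pytorch_model.bin']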
21 changes: 13 additions & 8 deletions tests/xturing/models/test_gpt2_model.py
@@ -24,10 +24,7 @@ def test_text_gpt2():
generation_config.top_k = 50
generation_config.top_p = 1.0

assert (
model.generate(texts="I want to")[: len(EXAMPLE_BASE_MODEL)]
== EXAMPLE_BASE_MODEL
)
assert model.generate(texts="I want to") != ""


def test_text_dataset_gpt2():
@@ -44,10 +41,18 @@ def test_text_dataset_gpt2_lora():
generation_config.max_new_tokens = None
generation_config.top_k = 50
generation_config.top_p = 1.0
assert (
other_model.generate(texts="I want to")[: len(EXAMPLE_LORA_MODEL)]
== EXAMPLE_LORA_MODEL
)
assert other_model.generate(texts="I want to") != ""


def test_text_dataset_gpt2_lora():
# Greedy search. Parameters are set to default config of HF
other_model = BaseModel.create("gpt2_lora_int8")
generation_config = other_model.generation_config()
generation_config.do_sample = False
generation_config.max_new_tokens = None
generation_config.top_k = 50
generation_config.top_p = 1.0
assert other_model.generate(texts="I want to") != ""


def test_train_gpt2():
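The GPT-2 tests stop comparing generated text against the hard-coded EXAMPLE_BASE_MODEL and EXAMPLE_LORA_MODEL prefixes, a check that breaks whenever the transformers version or decoding defaults change (as both do in this release), and instead only assert that generation returns a non-empty string; an analogous non-empty check is added for gpt2_lora_int8.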
