# test_classifier.py
import os

# Allow PyTorch to fall back to the CPU for ops not implemented on Apple's MPS backend.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch
import outlines
from outlines import samplers
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, set_seed

# Fix the RNG seed so runs are reproducible.
set_seed(1234)
# Prefer bfloat16 where the GPU supports it; otherwise fall back to float16.
if torch.cuda.is_bf16_supported():
    compute_dtype = torch.bfloat16
else:
    compute_dtype = torch.float16

# Quantization settings: 4-bit NF4 weights with double quantization
# (quantizing the quantization constants) to save memory.
use_4bit = True
bnb_4bit_quant_type = "nf4"
use_double_quant = True
# Requires the flash-attn package and a supported GPU.
attn_implementation = "flash_attention_2"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=use_4bit,
    bnb_4bit_quant_type=bnb_4bit_quant_type,
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=use_double_quant,
)
# Base model (local path) and fine-tuned LoRA adapter checkpoint.
MODEL_ID = "/home/stefanwebb/models/llm/meta_llama3-8b-instruct"

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype="auto",
    quantization_config=bnb_config,
    attn_implementation=attn_implementation,
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Llama 3 ships without a pad token; reuse the EOS token for padding.
tokenizer.pad_token = tokenizer.eos_token

# Load the LoRA adapter on top of the quantized base model.
peft_model_id = "/home/stefanwebb/code/python/train-sentence-classifier/stefans-debug-llama3-sentence-classifier/checkpoint-864"
peft_model = PeftModel.from_pretrained(model, peft_model_id)

# Wrap the PEFT model so outlines can drive constrained generation.
model = outlines.models.Transformers(peft_model, tokenizer)
# Earlier experiment, kept for reference: binary sentiment labelling.
# prompt = """You are a sentiment-labelling assistant.
# Is the following review positive or negative?
# Review: This restaurant is so-so!
# """
# generator = outlines.generate.choice(model, ["Positive", "Negative"])
# answer = generator(prompt)
# Sentence to classify.
sample = "?"

# Build a chat-formatted prompt; the instruction appears exactly once, in the user turn.
chat = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": f"Classify the following sentence as imperative, declarative, interrogative, or exclamative:\n\n{sample}"},
]
# With tokenize=False this returns the templated prompt as a plain string
# (return_tensors does not apply); add_generation_prompt appends the
# assistant header so the model generates the answer next.
formatted_prompt = tokenizer.apply_chat_template(
    chat, tokenize=False, add_generation_prompt=True
)
# Greedy decoding is deterministic, so the 7 repeated calls below should all
# return the same label.
sampler = samplers.greedy()

# Constrain generation to exactly one of the four sentence types.
generator = outlines.generate.choice(
    model, ["imperative", "declarative", "interrogative", "exclamative"], sampler
)

for _ in range(7):
    answer = generator(formatted_prompt)
    print(f"Answer: {answer}")
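
# A minimal usage sketch: the same generator can label a batch of sentences by
# re-templating each one. The sentences below are made-up illustrations, not
# taken from the training data.
examples = [
    "Close the door.",
    "The meeting starts at noon.",
    "Where did you park?",
    "What a fantastic goal!",
]
for sentence in examples:
    chat = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": f"Classify the following sentence as imperative, declarative, interrogative, or exclamative:\n\n{sentence}"},
    ]
    prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    print(f"{sentence} -> {generator(prompt)}")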