# config.py — command-line configuration for the Simple LLM Finetuner.
import argparse
import warnings

import torch
# Default torch device for this process: CUDA when PyTorch can see a GPU,
# otherwise the CPU. HAS_CUDA records the result of that check.
HAS_CUDA = torch.cuda.is_available()
if HAS_CUDA:
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')
# Command-line options. Every option has a default, so the script also runs
# with no arguments at all.
parser = argparse.ArgumentParser(description='Simple LLM Finetuner')
parser.add_argument(
    '--models',
    nargs='+',
    default=[
        'decapoda-research/llama-7b-hf',
        'cerebras/Cerebras-GPT-2.7B',
        'cerebras/Cerebras-GPT-1.3B',
        'EleutherAI/gpt-neo-2.7B',
    ],
    help='List of models to use',
)
parser.add_argument('--device-map', type=str, default='', help='Device map to use')
parser.add_argument('--model', type=str, default='cerebras/Cerebras-GPT-2.7B', help='Model to use')

# Training hyperparameters.
parser.add_argument('--max-seq-length', type=int, default=256, help='Max sequence length')
parser.add_argument('--micro-batch-size', type=int, default=12, help='Micro batch size')
parser.add_argument('--gradient-accumulation-steps', type=int, default=8, help='Gradient accumulation steps')
parser.add_argument('--epochs', type=int, default=3, help='Number of epochs')
parser.add_argument('--learning-rate', type=float, default=3e-4, help='Learning rate')

# LoRA adapter hyperparameters.
parser.add_argument('--lora-r', type=int, default=8, help='LORA r')
parser.add_argument('--lora-alpha', type=int, default=32, help='LORA alpha')
parser.add_argument('--lora-dropout', type=float, default=0.01, help='LORA dropout')

# Text-generation settings.
parser.add_argument('--max-new-tokens', type=int, default=80, help='Max new tokens')
parser.add_argument('--temperature', type=float, default=0.1, help='Temperature')
parser.add_argument('--top-k', type=int, default=40, help='Top k')
parser.add_argument('--top-p', type=float, default=0.3, help='Top p')
parser.add_argument('--repetition-penalty', type=float, default=1.5, help='Repetition penalty')
parser.add_argument('--do-sample', action='store_true', help='Enable sampling')
parser.add_argument('--num-beams', type=int, default=1, help='Number of beams')

# Gradio web-server settings.
# store_true already defaults to False. NOTE(review): this value is presumably
# forwarded to Gradio's launch(share=...), which adds a public share link on top
# of the local server; confirm at the call site.
parser.add_argument('--share', action='store_true', help='Create a public Gradio share link for the interface')
parser.add_argument('--host', type=str, default='127.0.0.1', help='Host name or IP to launch Gradio webserver on')
parser.add_argument('--port', type=int, default=7860, help='Host port to launch Gradio webserver on')

# The command line is parsed at import time. parse_known_args (rather than
# parse_args) keeps the import working inside host processes that carry
# their own arguments, such as a notebook kernel's "-f <file>" or a test
# runner's paths. Unrecognized arguments are reported, not silently dropped.
args, _unknown_args = parser.parse_known_args()
if _unknown_args:
    warnings.warn(
        'Ignoring unrecognized command-line arguments: ' + ' '.join(_unknown_args)
    )
# Module-level settings derived from the parsed command line.
MODELS = args.models

# Device map for model loading. An explicit --device-map value is used as-is.
# Otherwise the whole model is pinned to a single device (assuming Hugging Face
# device_map semantics, where the '' key covers the entire model). {'': 0}
# names CUDA device 0, which does not exist when HAS_CUDA is False, so fall
# back to the CPU in that case instead of failing at load time.
if args.device_map:
    DEVICE_MAP = args.device_map
elif HAS_CUDA:
    DEVICE_MAP = {'': 0}
else:
    DEVICE_MAP = {'': 'cpu'}

MODEL = args.model

# Training hyperparameters taken from the CLI.
TRAINING_PARAMS = {
    'max_seq_length': args.max_seq_length,
    'micro_batch_size': args.micro_batch_size,
    'gradient_accumulation_steps': args.gradient_accumulation_steps,
    'epochs': args.epochs,
    'learning_rate': args.learning_rate,
}

# LoRA adapter hyperparameters taken from the CLI.
LORA_TRAINING_PARAMS = {
    'lora_r': args.lora_r,
    'lora_alpha': args.lora_alpha,
    'lora_dropout': args.lora_dropout,
}

# Text-generation settings. NOTE(review): the keys match Hugging Face
# generate() keyword names, so this dict is presumably splatted into that
# call; confirm at the call site.
GENERATION_PARAMS = {
    'max_new_tokens': args.max_new_tokens,
    'temperature': args.temperature,
    'top_k': args.top_k,
    'top_p': args.top_p,
    'repetition_penalty': args.repetition_penalty,
    'do_sample': args.do_sample,
    'num_beams': args.num_beams,
}

# Gradio web-server settings.
SHARE = args.share
SERVER_HOST = args.host
SERVER_PORT = args.port