config.py
import argparse


def parse_args():
    parser = argparse.ArgumentParser(description="Config of Llama2-lora")
    parser.add_argument("--MICRO_BATCH_SIZE", type=int, default=4, help="Per-device train batch size")
    parser.add_argument("--BATCH_SIZE", type=int, default=8, help="Batch size")
    parser.add_argument('--EPOCHS', type=int, default=100, help='Training epochs')
    parser.add_argument('--WARMUP_STEPS', type=int, default=100, help='Warmup steps')
    parser.add_argument('--LEARNING_RATE', type=float, default=2e-5, help='Training learning rate')
    parser.add_argument('--CONTEXT_LEN', type=int, default=256, help='Truncation length of context (json input)')
    parser.add_argument('--TARGET_LEN', type=int, default=256, help='Truncation length of target (json input)')
    parser.add_argument('--TEXT_LEN', type=int, default=256, help='Truncation length of text (txt input)')
    parser.add_argument('--LORA_R', type=int, default=16, help='LoRA rank')
    parser.add_argument('--LORA_ALPHA', type=int, default=16, help='LoRA alpha')
    parser.add_argument('--LORA_DROPOUT', type=float, default=0.05, help='LoRA dropout')
    parser.add_argument('--MODEL_NAME', type=str, default="THUDM/chatglm2-6b", help='Model name')
    parser.add_argument('--LOGGING_STEPS', type=int, default=100, help='Logging steps during training')
    parser.add_argument('--OUTPUT_DIR', type=str, default="./output_model", help='Output directory')
    parser.add_argument('--DATA_PATH', type=str, default="./train_data.json", help='Input data path')
    parser.add_argument('--DATA_TYPE', type=str, choices=["json", "txt"], default="json", help='Input file type')
    parser.add_argument('--SAVE_STEPS', type=int, default=1000, help='Save a checkpoint every this many steps')
    parser.add_argument('--SAVE_TOTAL_LIMIT', type=int, default=3, help='Maximum number of checkpoints to keep (excluding the final one)')
    parser.add_argument('--BIT_8', default=False, action="store_true", help='Use 8-bit quantization')
    parser.add_argument('--BIT_4', default=False, action="store_true", help='Use 4-bit quantization')
    parser.add_argument('--PROMPT', type=str, default="Input your prompt", help='Prompt used at inference')
    parser.add_argument('--TEMPERATURE', type=float, default=0.0, help='Sampling temperature at inference')
    parser.add_argument('--LORA_CHECKPOINT_DIR', type=str, default="./output_model/checkpoint-1000/", help='Path to the LoRA checkpoint')
    return parser.parse_args()
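
Below is a minimal usage sketch of how these arguments could feed a LoRA fine-tuning script (a hypothetical train.py). The peft/transformers calls and the mapping of BATCH_SIZE to gradient accumulation are assumptions, not part of this file.

# Hypothetical train.py: build PEFT and Trainer configs from parse_args().
# Assumes the peft and transformers libraries are installed; treating BATCH_SIZE
# as the effective batch size (and deriving gradient accumulation from it) is an
# assumption, not taken from config.py.
from peft import LoraConfig
from transformers import TrainingArguments

from config import parse_args

args = parse_args()

# LoRA adapter settings mapped directly from the command-line arguments.
lora_config = LoraConfig(
    r=args.LORA_R,
    lora_alpha=args.LORA_ALPHA,
    lora_dropout=args.LORA_DROPOUT,
    task_type="CAUSAL_LM",
)

# Trainer settings; gradient accumulation bridges the gap between the
# per-device micro batch and the assumed effective batch size.
training_args = TrainingArguments(
    output_dir=args.OUTPUT_DIR,
    per_device_train_batch_size=args.MICRO_BATCH_SIZE,
    gradient_accumulation_steps=args.BATCH_SIZE // args.MICRO_BATCH_SIZE,
    num_train_epochs=args.EPOCHS,
    warmup_steps=args.WARMUP_STEPS,
    learning_rate=args.LEARNING_RATE,
    logging_steps=args.LOGGING_STEPS,
    save_steps=args.SAVE_STEPS,
    save_total_limit=args.SAVE_TOTAL_LIMIT,
)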