-
Notifications
You must be signed in to change notification settings - Fork 42
/
urbangpt_train.sh
40 lines (38 loc) · 1.21 KB
/
urbangpt_train.sh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
# UrbanGPT training launcher.
#
# Fill in the paths in the "Paths to fill in" section below before running.
# Every path in this script (including the training entry point
# urbangpt/train/train_mem.py) is relative, so run it from the directory
# that contains the urbangpt/ package.
#
# Optional environment override:
#   NPROC_PER_NODE  value passed to --nproc_per_node (default: 8)
#
# The script uses only POSIX sh features, so it works under both sh and bash.

# Stop at the first failing command or unset variable, so training is never
# launched after a failed setup step (e.g. `wandb offline` not succeeding).
set -eu

# ---- Paths to fill in -------------------------------------------------------
model_path=./checkpoints/vicuna-7b-v1.5-16k       # passed as --model_name_or_path
instruct_ds=./data/train_data/multi_NYC.json      # passed as --data_path
st_data_path=./data/train_data/multi_NYC_pkl.pkl  # passed as --st_data_path
st_content=./TAXI.json                            # passed as --st_content
pretra_ste=ST_Encoder                             # passed as --st_tower
output_model=./checkpoints/UrbanGPT               # passed as --output_dir

# ---- Launch settings --------------------------------------------------------
nproc_per_node=${NPROC_PER_NODE:-8}

# Put Weights & Biases into offline mode before the run starts; the trainer
# is still told to log to it via --report_to wandb.
wandb offline

# All expansions are quoted so paths containing spaces or glob characters are
# passed to the trainer as single, literal arguments.
python -m torch.distributed.run \
    --nnodes=1 \
    --nproc_per_node="${nproc_per_node}" \
    --master_port=20001 \
    urbangpt/train/train_mem.py \
    --model_name_or_path "${model_path}" \
    --version v1 \
    --data_path "${instruct_ds}" \
    --st_content "${st_content}" \
    --st_data_path "${st_data_path}" \
    --st_tower "${pretra_ste}" \
    --tune_st_mlp_adapter True \
    --st_select_layer -2 \
    --use_st_start_end \
    --bf16 True \
    --output_dir "${output_model}" \
    --num_train_epochs 3 \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps 1 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 4800 \
    --save_total_limit 1 \
    --learning_rate 2e-3 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --tf32 True \
    --model_max_length 2048 \
    --gradient_checkpointing True \
    --lazy_preprocess True \
    --report_to wandb