diff --git a/drivellava/scripts/train.py b/drivellava/scripts/train.py index ac4f451..1499af8 100644 --- a/drivellava/scripts/train.py +++ b/drivellava/scripts/train.py @@ -145,7 +145,7 @@ def main(): --group_by_modality_length True \ --bf16 True \ --output_dir {OUTPUT_DIR} \ - --num_train_epochs 4 \ + --num_train_epochs 1 \ --per_device_train_batch_size 16 \ --per_device_eval_batch_size 4 \ --gradient_accumulation_steps 1 \ @@ -154,7 +154,7 @@ def main(): --save_strategy "steps" \ --save_steps 50 \ --save_total_limit 1 \ - --learning_rate 2e-3 \ + --learning_rate 2e-5 \ --weight_decay 0. \ --warmup_ratio 0.03 \ --lr_scheduler_type "cosine" \