From ded75cd7545bc0c1e35aeb4cfceaeaa82b9f5446 Mon Sep 17 00:00:00 2001 From: Haotian Liu Date: Mon, 23 Oct 2023 15:30:25 -0500 Subject: [PATCH 1/4] Add new lora schedule. --- llava/train/llava_trainer.py | 94 ++++++++++++++++++++++++++++++ llava/train/train.py | 2 + scripts/v1_5/finetune_lora.sh | 38 ++++++++++++ scripts/v1_5/finetune_task_lora.sh | 37 ++++++++++++ 4 files changed, 171 insertions(+) create mode 100644 scripts/v1_5/finetune_lora.sh create mode 100644 scripts/v1_5/finetune_task_lora.sh diff --git a/llava/train/llava_trainer.py b/llava/train/llava_trainer.py index d78c00f02..8adca6b26 100644 --- a/llava/train/llava_trainer.py +++ b/llava/train/llava_trainer.py @@ -5,7 +5,12 @@ from transformers import Trainer from transformers.trainer import ( + is_sagemaker_mp_enabled, + get_parameter_names, has_length, + ALL_LAYERNORM_LAYERS, + ShardedDDPOption, + logger, ) from typing import List, Optional @@ -146,6 +151,95 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: else: return super()._get_train_sampler() + def create_optimizer(self): + """ + Setup the optimizer. + + We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the + Trainer's init through `optimizers`, or subclass and override this method in a subclass. + """ + if is_sagemaker_mp_enabled(): + return super().create_optimizer() + if self.sharded_ddp == ShardedDDPOption.SIMPLE: + return super().create_optimizer() + + opt_model = self.model + + if self.optimizer is None: + decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS) + decay_parameters = [name for name in decay_parameters if "bias" not in name] + if self.args.mm_projector_lr is not None: + projector_parameters = [name for name, _ in opt_model.named_parameters() if "mm_projector" in name] + optimizer_grouped_parameters = [ + { + "params": [ + p for n, p in opt_model.named_parameters() if (n in decay_parameters and n not in projector_parameters and p.requires_grad) + ], + "weight_decay": self.args.weight_decay, + }, + { + "params": [ + p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n not in projector_parameters and p.requires_grad) + ], + "weight_decay": 0.0, + }, + { + "params": [ + p for n, p in opt_model.named_parameters() if (n in decay_parameters and n in projector_parameters and p.requires_grad) + ], + "weight_decay": self.args.weight_decay, + "lr": self.args.mm_projector_lr, + }, + { + "params": [ + p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n in projector_parameters and p.requires_grad) + ], + "weight_decay": 0.0, + "lr": self.args.mm_projector_lr, + }, + ] + else: + optimizer_grouped_parameters = [ + { + "params": [ + p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad) + ], + "weight_decay": self.args.weight_decay, + }, + { + "params": [ + p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad) + ], + "weight_decay": 0.0, + }, + ] + + optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args) + + if self.sharded_ddp == ShardedDDPOption.SIMPLE: + self.optimizer = OSS( + params=optimizer_grouped_parameters, + optim=optimizer_cls, + **optimizer_kwargs, + ) + else: + self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) + if optimizer_cls.__name__ == "Adam8bit": + import bitsandbytes + + manager = bitsandbytes.optim.GlobalOptimManager.get_instance() + + skipped = 0 + for module in opt_model.modules(): + if isinstance(module, nn.Embedding): + skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values()) + logger.info(f"skipped {module}: {skipped/2**20}M params") + manager.register_module_override(module, "weight", {"optim_bits": 32}) + logger.debug(f"bitsandbytes: will optimize {module} in fp32") + logger.info(f"skipped: {skipped/2**20}M params") + + return self.optimizer + def _save_checkpoint(self, model, trial, metrics=None): if getattr(self.args, 'tune_mm_mlp_adapter', False): from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR diff --git a/llava/train/train.py b/llava/train/train.py index bfffca705..cbfcc1bb4 100644 --- a/llava/train/train.py +++ b/llava/train/train.py @@ -103,6 +103,7 @@ class TrainingArguments(transformers.TrainingArguments): lora_dropout: float = 0.05 lora_weight_path: str = "" lora_bias: str = "none" + mm_projector_lr: Optional[float] = None group_by_modality_length: bool = field(default=False) @@ -900,6 +901,7 @@ def make_inputs_require_grad(module, input, output): model.get_model().mm_projector.to(dtype=compute_dtype, device=training_args.device) model.config.mm_use_im_start_end = data_args.mm_use_im_start_end = model_args.mm_use_im_start_end + model.config.mm_projector_lr = training_args.mm_projector_lr training_args.use_im_start_end = model_args.mm_use_im_start_end model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer) diff --git a/scripts/v1_5/finetune_lora.sh b/scripts/v1_5/finetune_lora.sh new file mode 100644 index 000000000..90f00707c --- /dev/null +++ b/scripts/v1_5/finetune_lora.sh @@ -0,0 +1,38 @@ +#!/bin/bash + +deepspeed llava/train/train_mem.py \ + --lora_enable True --lora_r 128 --lora_alpha 256 --mm_projector_lr 2e-5 \ + --deepspeed ./scripts/zero3.json \ + --model_name_or_path lmsys/vicuna-13b-v1.5 \ + --version v1 \ + --data_path ./playground/data/llava_v1_5_mix665k.json \ + --image_folder ./playground/data \ + --vision_tower openai/clip-vit-large-patch14-336 \ + --pretrain_mm_mlp_adapter ./checkpoints/llava-v1.5-13b-pretrain/mm_projector.bin \ + --mm_projector_type mlp2x_gelu \ + --mm_vision_select_layer -2 \ + --mm_use_im_start_end False \ + --mm_use_im_patch_token False \ + --image_aspect_ratio pad \ + --group_by_modality_length True \ + --bf16 True \ + --output_dir ./checkpoints/llava-v1.5-13b-lora \ + --num_train_epochs 1 \ + --per_device_train_batch_size 16 \ + --per_device_eval_batch_size 4 \ + --gradient_accumulation_steps 1 \ + --evaluation_strategy "no" \ + --save_strategy "steps" \ + --save_steps 50000 \ + --save_total_limit 1 \ + --learning_rate 2e-4 \ + --weight_decay 0. \ + --warmup_ratio 0.03 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --tf32 True \ + --model_max_length 2048 \ + --gradient_checkpointing True \ + --dataloader_num_workers 4 \ + --lazy_preprocess True \ + --report_to wandb diff --git a/scripts/v1_5/finetune_task_lora.sh b/scripts/v1_5/finetune_task_lora.sh new file mode 100644 index 000000000..f11303f29 --- /dev/null +++ b/scripts/v1_5/finetune_task_lora.sh @@ -0,0 +1,37 @@ +#!/bin/bash + +deepspeed llava/train/train_mem.py \ + --lora_enable True --lora_r 128 --lora_alpha 256 --mm_projector_lr 2e-5 \ + --deepspeed ./scripts/zero3.json \ + --model_name_or_path liuhaotian/llava-v1.5-13b \ + --version v1 \ + --data_path ./playground/data/llava_v1_5_mix665k.json \ + --image_folder ./playground/data \ + --vision_tower openai/clip-vit-large-patch14-336 \ + --mm_projector_type mlp2x_gelu \ + --mm_vision_select_layer -2 \ + --mm_use_im_start_end False \ + --mm_use_im_patch_token False \ + --image_aspect_ratio pad \ + --group_by_modality_length True \ + --bf16 True \ + --output_dir ./checkpoints/llava-v1.5-13b-task-lora \ + --num_train_epochs 1 \ + --per_device_train_batch_size 16 \ + --per_device_eval_batch_size 4 \ + --gradient_accumulation_steps 1 \ + --evaluation_strategy "no" \ + --save_strategy "steps" \ + --save_steps 50000 \ + --save_total_limit 1 \ + --learning_rate 2e-4 \ + --weight_decay 0. \ + --warmup_ratio 0.03 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --tf32 True \ + --model_max_length 2048 \ + --gradient_checkpointing True \ + --dataloader_num_workers 4 \ + --lazy_preprocess True \ + --report_to wandb From f8046a8079fbb3a5bdf24e95749f3f073f5ef2be Mon Sep 17 00:00:00 2001 From: Haotian Liu Date: Mon, 23 Oct 2023 15:32:50 -0500 Subject: [PATCH 2/4] Improve behavior match for grad_accu --- llava/train/llava_trainer.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/llava/train/llava_trainer.py b/llava/train/llava_trainer.py index 8adca6b26..46a5d806e 100644 --- a/llava/train/llava_trainer.py +++ b/llava/train/llava_trainer.py @@ -79,12 +79,8 @@ def get_modality_length_grouped_indices(lengths, batch_size, world_size, generat megabatch_indices = torch.randperm(len(megabatches), generator=generator) megabatches = [megabatches[i] for i in megabatch_indices] - if len(additional_batch) >= megabatch_size: - megabatches = [additional_batch[:megabatch_size]] + megabatches - additional_batch = additional_batch[megabatch_size:] - if len(additional_batch) > 0: - megabatches.append(additional_batch) + megabatches.append(sorted(additional_batch)) return [i for megabatch in megabatches for i in megabatch] From 9ad726577ceabb54fc641d75e139aa5a266147d1 Mon Sep 17 00:00:00 2001 From: Haotian Liu Date: Thu, 26 Oct 2023 15:03:27 -0500 Subject: [PATCH 3/4] Update LoRA and task-finetuning docs. --- README.md | 10 +++++++++- docs/Finetune_Custom_Data.md | 35 ++++++++++++++++++++++++++++++++++ docs/MODEL_ZOO.md | 10 ++++++---- scripts/v1_5/finetune_task.sh | 36 +++++++++++++++++++++++++++++++++++ 4 files changed, 86 insertions(+), 5 deletions(-) create mode 100644 docs/Finetune_Custom_Data.md create mode 100644 scripts/v1_5/finetune_task.sh diff --git a/README.md b/README.md index c60e1bddc..fb301d5dc 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,8 @@ ## Release -- [10/12] 🔥 Check out the Korean LLaVA (Ko-LLaVA), created by ETRI, who has generously supported our research! [[🤗 Demo](https://huggingface.co/spaces/etri-vilab/Ko-LLaVA)] +- [10/26] 🔥 LLaVA-1.5 with LoRA achieves comparable performance as full-model finetuning, with a reduced GPU RAM requirement ([ckpts](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md#llava-v15), [script](https://github.com/haotian-liu/LLaVA#train)). We also provide a [doc](https://github.com/haotian-liu/LLaVA/blob/main/docs/Finetune_Custom_Data.md) on how to finetune LLaVA-1.5 on your own dataset with LoRA. +- [10/12] Check out the Korean LLaVA (Ko-LLaVA), created by ETRI, who has generously supported our research! [[🤗 Demo](https://huggingface.co/spaces/etri-vilab/Ko-LLaVA)] - [10/12] LLaVA is now supported in [llama.cpp](https://github.com/ggerganov/llama.cpp/pull/3436) with 4-bit / 5-bit quantization support! - [10/11] The training data and scripts of LLaVA-1.5 are released [here](https://github.com/haotian-liu/LLaVA#train), and evaluation scripts are released [here](https://github.com/haotian-liu/LLaVA/blob/main/docs/Evaluation.md)! - [10/5] 🔥 LLaVA-1.5 is out! Achieving SoTA on 11 benchmarks, with just simple modifications to the original LLaVA, utilizes all public data, completes training in ~1 day on a single 8-A100 node, and surpasses methods like Qwen-VL-Chat that use billion-scale data. Check out the [technical report](https://arxiv.org/abs/2310.03744), and explore the [demo](https://llava.hliu.cc/)! Models are available in [Model Zoo](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md). @@ -235,6 +236,13 @@ Visual instruction tuning takes around 20 hours for LLaVA-v1.5-13B on 8x A100 (8 Training script with DeepSpeed ZeRO-3: [`finetune.sh`](https://github.com/haotian-liu/LLaVA/blob/main/scripts/v1_5/finetune.sh). +If you are do not have enough GPU memory: + +- Use LoRA: [`finetune_lora.sh`](https://github.com/haotian-liu/LLaVA/blob/main/scripts/v1_5/finetune_lora.sh). We are able to fit 13B training in 8-A100-40G/8-A6000, and 7B training in 8-RTX3090. Make sure `per_device_train_batch_size*gradient_accumulation_steps` is the same as the provided script for best reproducibility. +- Replace `zero3.json` with `zero3_offload.json` which offloads some parameters to CPU RAM. This slows down the training speed. + +If you are interested in finetuning LLaVA model to your own task/data, please check out [`Finetune_Custom_Data.md`](https://github.com/haotian-liu/LLaVA/blob/main/docs/Finetune_Custom_Data.md)。 + New options to note: - `--mm_projector_type mlp2x_gelu`: the two-layer MLP vision-language connector. diff --git a/docs/Finetune_Custom_Data.md b/docs/Finetune_Custom_Data.md new file mode 100644 index 000000000..0d965c5d2 --- /dev/null +++ b/docs/Finetune_Custom_Data.md @@ -0,0 +1,35 @@ +# Finetune LLaVA on Custom Datasets + +## Dataset Format + +Convert your data to a JSON file of a List of all samples. Sample metadata should contain `id` (a unique identifier), `image` (the path to the image), and `conversations` (the conversation data between human and AI). + +A sample JSON for finetuning LLaVA for generating tag-style captions for Stable Diffusion: + +```json +[ + { + "id": "997bb945-628d-4724-b370-b84de974a19f", + "image": "part-000001/997bb945-628d-4724-b370-b84de974a19f.jpg", + "conversations": [ + { + "from": "human", + "value": "\nWrite a prompt for Stable Diffusion to generate this image." + }, + { + "from": "gpt", + "value": "a beautiful painting of chernobyl by nekro, pascal blanche, john harris, greg rutkowski, sin jong hun, moebius, simon stalenhag. in style of cg art. ray tracing. cel shading. hyper detailed. realistic. ue 5. maya. octane render. " + }, + ] + }, + ... +] +``` + +## Command + +If you have a limited task-specific data, we recommend finetuning from LLaVA checkpoints with LoRA following this [script](https://github.com/haotian-liu/LLaVA/blob/main/scripts/v1_5/finetune_task_lora.sh). + +You may need to adjust the hyperparameters to fit each specific dataset and your hardware constraint. + + diff --git a/docs/MODEL_ZOO.md b/docs/MODEL_ZOO.md index f58531465..67ae9a399 100644 --- a/docs/MODEL_ZOO.md +++ b/docs/MODEL_ZOO.md @@ -10,10 +10,12 @@ The model weights below are *merged* weights. You do not need to apply delta. Th | Version | Size | Schedule | Checkpoint | VQAv2 | GQA | VizWiz | SQA | T-VQA | POPE | MME | MM-Bench | MM-Bench-CN | SEED | LLaVA-Bench-Wild | MM-Vet | |----------|----------|-----------|-----------|---|---|---|---|---|---|---|---|---|---|---|---| -| LLaVA-1.5 | 7B | full_ft-1e | [liuhaotian/llava-v1.5-7b](https://huggingface.co/liuhaotian/llava-v1.5-7b), [logs](https://api.wandb.ai/links/lht/6orh56wc) | 78.5 | 62.0 | 50.0 | 66.8 | 58.2 | 85.9 | 1510.7 | 64.3 | 58.3 | 58.6 | 65.4 | 31.1 | -| LLaVA-1.5 | 13B | full_ft-1e | [liuhaotian/llava-v1.5-13b](https://huggingface.co/liuhaotian/llava-v1.5-13b), [logs](https://api.wandb.ai/links/lht/6orh56wc) | 80.0 | 63.3 | 53.6 | 71.6 | 61.3 | 85.9 | 1531.3 | 67.7 | 63.6 | 61.6 | 72.5 | 36.1 | -| LLaVA-1.5 | 7B | lora-1e | coming soon | -| LLaVA-1.5 | 13B | lora-1e | coming soon | +| LLaVA-1.5 | 7B | full_ft-1e | [liuhaotian/llava-v1.5-7b](https://huggingface.co/liuhaotian/llava-v1.5-7b) | 78.5 | 62.0 | 50.0 | 66.8 | 58.2 | 85.9 | 1510.7 | 64.3 | 58.3 | 58.6 | 65.4 | 31.1 | +| LLaVA-1.5 | 13B | full_ft-1e | [liuhaotian/llava-v1.5-13b](https://huggingface.co/liuhaotian/llava-v1.5-13b) | 80.0 | 63.3 | 53.6 | 71.6 | 61.3 | 85.9 | 1531.3 | 67.7 | 63.6 | 61.6 | 72.5 | 36.1 | +| LLaVA-1.5 | 7B | lora-1e | [liuhaotian/llava-v1.5-7b-lora](https://huggingface.co/liuhaotian/llava-v1.5-7b-lora) | 79.1 | 63.0 | 47.8 | 68.4 | 58.2 | 86.4 | 1476.9 | 66.1 | 58.9 | 60.1 | 67.9 | 30.2 | +| LLaVA-1.5 | 13B | lora-1e | [liuhaotian/llava-v1.5-13b-lora](https://huggingface.co/liuhaotian/llava-v1.5-13b-lora) | 80.0 | 63.3 | 58.9 | 71.2 | 60.2 | 86.7 | 1541.7 | 68.5 | 61.5 | 61.3 | 69.5 | 38.3 | + +Training logs: [wandb](https://api.wandb.ai/links/lht/6orh56wc).


diff --git a/scripts/v1_5/finetune_task.sh b/scripts/v1_5/finetune_task.sh new file mode 100644 index 000000000..063f3f13e --- /dev/null +++ b/scripts/v1_5/finetune_task.sh @@ -0,0 +1,36 @@ +#!/bin/bash + +deepspeed llava/train/train_mem.py \ + --deepspeed ./scripts/zero3.json \ + --model_name_or_path liuhaotian/llava-v1.5-13b \ + --version v1 \ + --data_path ./playground/data/llava_v1_5_mix665k.json \ + --image_folder ./playground/data \ + --vision_tower openai/clip-vit-large-patch14-336 \ + --mm_projector_type mlp2x_gelu \ + --mm_vision_select_layer -2 \ + --mm_use_im_start_end False \ + --mm_use_im_patch_token False \ + --image_aspect_ratio pad \ + --group_by_modality_length True \ + --bf16 True \ + --output_dir ./checkpoints/llava-v1.5-13b-task \ + --num_train_epochs 1 \ + --per_device_train_batch_size 16 \ + --per_device_eval_batch_size 4 \ + --gradient_accumulation_steps 1 \ + --evaluation_strategy "no" \ + --save_strategy "steps" \ + --save_steps 50000 \ + --save_total_limit 1 \ + --learning_rate 2e-5 \ + --weight_decay 0. \ + --warmup_ratio 0.03 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --tf32 True \ + --model_max_length 2048 \ + --gradient_checkpointing True \ + --dataloader_num_workers 4 \ + --lazy_preprocess True \ + --report_to wandb From 9e3f3d08457a3668c42c77eb295edf7083ce84fb Mon Sep 17 00:00:00 2001 From: Haotian Liu Date: Thu, 26 Oct 2023 15:28:20 -0500 Subject: [PATCH 4/4] Unfreeze projector when finetuning with LoRA. --- llava/model/llava_arch.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/llava/model/llava_arch.py b/llava/model/llava_arch.py index 58b2b37df..7fadac793 100644 --- a/llava/model/llava_arch.py +++ b/llava/model/llava_arch.py @@ -69,6 +69,10 @@ def initialize_vision_modules(self, model_args, fsdp=None): if getattr(self, 'mm_projector', None) is None: self.mm_projector = build_vision_projector(self.config) + else: + # In case it is frozen by LoRA + for p in self.mm_projector.parameters(): + p.requires_grad = True if pretrain_mm_mlp_adapter is not None: mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')