Skip to content

Commit

Permalink
Merge branch 'main' into README-fix
Browse files Browse the repository at this point in the history
  • Loading branch information
haotian-liu authored Oct 12, 2023
2 parents d4fe0d1 + 8cb0e4d commit 4241f08
Show file tree
Hide file tree
Showing 34 changed files with 1,506 additions and 33 deletions.
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

*Visual instruction tuning towards large language and vision models with GPT-4 level capabilities.*

[[Project Page](https://llava-vl.github.io/)] [[Demo](https://llava.hliu.cc/)] [[Data](https://github.com/haotian-liu/LLaVA/blob/main/docs/Data.md)] [[Model Zoo](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md)]
[[Project Page](https://llava-vl.github.io/)] [[Demo](https://llava.hliu.cc/)] [[Data](https://github.com/haotian-liu/LLaVA/blob/main/docs/Data.md)] [[Model Zoo](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md)] [[Community Build Demo on 🤗Spaces](https://huggingface.co/spaces/badayvedat/LLaVA)]

**Improved Baselines with Visual Instruction Tuning** [[Paper](https://arxiv.org/abs/2310.03744)] <br>
[Haotian Liu](https://hliu.cc), [Chunyuan Li](https://chunyuan.li/), [Yuheng Li](https://yuheng-li.github.io/), [Yong Jae Lee](https://pages.cs.wisc.edu/~yongjaelee/)
Expand All @@ -17,11 +17,11 @@


## Release
- [10/11] The training data and scripts of LLaVA-1.5 are released [here](https://github.com/haotian-liu/LLaVA#train), with evaluation scripts coming this week!
- [10/11] The training data and scripts of LLaVA-1.5 are released [here](https://github.com/haotian-liu/LLaVA#train), and evaluation scripts are released [here](https://github.com/haotian-liu/LLaVA/blob/main/docs/Evaluation.md)!
- [10/5] 🔥 LLaVA-1.5 is out! Achieving SoTA on 11 benchmarks, with just simple modifications to the original LLaVA, utilizes all public data, completes training in ~1 day on a single 8-A100 node, and surpasses methods like Qwen-VL-Chat that use billion-scale data. Check out the [technical report](https://arxiv.org/abs/2310.03744), and explore the [demo](https://llava.hliu.cc/)! Models are available in [Model Zoo](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md).
- [9/26] LLaVA is improved with reinforcement learning from human feedback (RLHF) to improve fact grounding and reduce hallucination. Check out the new SFT and RLHF checkpoints at project [[LLavA-RLHF]](https://llava-rlhf.github.io/)
- [9/22] [LLaVA](https://arxiv.org/abs/2304.08485) is accpeted by NeurIPS 2023 as **oral presentation**, and [LLaVA-Med](https://arxiv.org/abs/2306.00890) is accpeted by NeurIPS 2023 Datasets and Benchmarks Track as **spotlight presentation**.
- [9/20] We summarize our emprical study of training 33B and 65B LLaVA models in a [note](https://arxiv.org/abs/2309.09958). Further, if you are interested in the comprehensive review, evolution and trend of multimodal foundation models, please check out our recent survey paper [``Multimodal Foundation Models: From Specialists to General-Purpose Assistants''.](https://arxiv.org/abs/2309.10020)
- [9/22] [LLaVA](https://arxiv.org/abs/2304.08485) is accepted by NeurIPS 2023 as **oral presentation**, and [LLaVA-Med](https://arxiv.org/abs/2306.00890) is accepted by NeurIPS 2023 Datasets and Benchmarks Track as **spotlight presentation**.
- [9/20] We summarize our empirical study of training 33B and 65B LLaVA models in a [note](https://arxiv.org/abs/2309.09958). Further, if you are interested in the comprehensive review, evolution and trend of multimodal foundation models, please check out our recent survey paper [``Multimodal Foundation Models: From Specialists to General-Purpose Assistants''.](https://arxiv.org/abs/2309.10020)
<p align="center">
<img src="https://github.com/Computer-Vision-in-the-Wild/CVinW_Readings/blob/main/images/mfm_evolution.jpeg?raw=true" width=50%/>
</p>
Expand Down Expand Up @@ -243,7 +243,7 @@ New options to note:

In LLaVA-1.5, we evaluate models on a diverse set of 12 benchmarks. To ensure the reproducibility, we evaluate the models with greedy decoding. We do not evaluate using beam search to make the inference process consistent with the chat demo of real-time outputs.

Detailed evaluation scripts coming soon.
See [Evaluation.md](https://github.com/haotian-liu/LLaVA/blob/main/docs/Evaluation.md).

### GPT-assisted Evaluation

Expand Down
142 changes: 142 additions & 0 deletions docs/Evaluation.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
# Evaluation

In LLaVA-1.5, we evaluate models on a diverse set of 12 benchmarks. To ensure the reproducibility, we evaluate the models with greedy decoding. We do not evaluate using beam search to make the inference process consistent with the chat demo of real-time outputs.

Currently, we mostly utilize the official toolkit or server for the evaluation.

## Evaluate on Custom Datasets

You can evaluate LLaVA on your custom datasets by converting your dataset to LLaVA's jsonl format, and evaluate using [`model_vqa.py`](https://github.com/haotian-liu/LLaVA/blob/main/llava/eval/model_vqa.py).

Below we provide a general guideline for evaluating datasets with some common formats.

1. Short-answer (e.g. VQAv2, MME).

```
<question>
Answer the question using a single word or phrase.
```

2. Option-only for multiple-choice (e.g. MMBench, SEED-Bench).

```
<question>
A. <option_1>
B. <option_2>
C. <option_3>
D. <option_4>
Answer with the option's letter from the given choices directly.
```

3. Natural QA (e.g. LLaVA-Bench, MM-Vet).

No postprocessing is needed.

## Scripts

Before preparing task-specific data, download [eval.zip](https://drive.google.com/file/d/1atZSBBrAX54yYpxtVVW33zFvcnaHeFPy/view?usp=sharing). It contains custom annotations, scripts, and the prediction files with LLaVA v1.5. Extract to `./playground/data/eval`. This also provides a general structure for all datasets.

### VQAv2

1. Download [`test2015`](http://images.cocodataset.org/zips/test2015.zip) and put it under `./playground/data/eval/vqav2`.
2. Multi-GPU inference.
```Shell
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 bash scripts/v1_5/eval/vqav2.sh
```
3. Submit the results to the evaluation server: `./playground/data/eval/vqav2/answers_upload`.

### GQA

1. Download the data following the official instructions [here](https://cs.stanford.edu/people/dorarad/gqa/download.html) and put under `./playground/data/eval/gqa/data`.
2. Multi-GPU inference.
```Shell
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 bash scripts/v1_5/eval/gqa.sh
```

### VisWiz

1. Download [`test.json`](https://vizwiz.cs.colorado.edu/VizWiz_final/vqa_data/Annotations.zip) and extract [`test.zip`](https://vizwiz.cs.colorado.edu/VizWiz_final/images/test.zip) to `test`. Put them under `./playground/data/eval/vizwiz`.
2. Single-GPU inference.
```Shell
CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/vizwiz.sh
```
3. Submit the results to the evaluation server: `./playground/data/eval/vizwiz/answers_upload`.

### ScienceQA

1. Under `./playground/data/eval/scienceqa`, download `images`, `pid_splits.json`, `problems.json` from the `data/scienceqa` folder of the ScienceQA [repo](https://github.com/lupantech/ScienceQA).
2. Single-GPU inference and evaluate.
```Shell
CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/sqa.sh
```

### TextVQA

1. Download [`TextVQA_0.5.1_val.json`](https://dl.fbaipublicfiles.com/textvqa/data/TextVQA_0.5.1_val.json) and [images](https://dl.fbaipublicfiles.com/textvqa/images/train_val_images.zip) and extract to `./playground/data/eval/textvqa`.
2. Single-GPU inference and evaluate.
```Shell
CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/textvqa.sh
```

### POPE

1. Download `coco` from [POPE](https://github.com/AoiDragon/POPE/tree/e3e39262c85a6a83f26cf5094022a782cb0df58d/output/coco) and put under `./playground/data/eval/pope`.
2. Single-GPU inference and evaluate.
```Shell
CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/pope.sh
```

### MME

1. Download the data following the official instructions [here](https://github.com/BradyFU/Awesome-Multimodal-Large-Language-Models/tree/Evaluation).
2. Downloaded images to `MME_Benchmark_release_version`.
3. put the official `eval_tool` and `MME_Benchmark_release_version` under `./playground/data/eval/MME`.
4. Single-GPU inference and evaluate.
```Shell
CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/mme.sh
```

### MMBench

1. Download [`mmbench_dev_20230712.tsv`](https://download.openmmlab.com/mmclassification/datasets/mmbench/mmbench_dev_20230712.tsv) and put under `./playground/data/eval/mmbench`.
2. Single-GPU inference.
```Shell
CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/mmbench.sh
```
3. Submit the results to the evaluation server: `./playground/data/eval/mmbench/answers_upload/mmbench_dev_20230712`.

### MMBench-CN

1. Download [`mmbench_dev_cn_20231003.tsv`](https://download.openmmlab.com/mmclassification/datasets/mmbench/mmbench_dev_en_20231003.tsv) and put under `./playground/data/eval/mmbench`.
2. Single-GPU inference.
```Shell
CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/mmbench_cn.sh
```
3. Submit the results to the evaluation server: `./playground/data/eval/mmbench/answers_upload/mmbench_dev_cn_20231003`.

### SEED-Bench

1. Following the official [instructions](https://github.com/AILab-CVC/SEED-Bench/blob/main/DATASET.md) to download the images and the videos. Put images under `./playground/data/eval/seed_bench/SEED-Bench-image`.
2. Extract the video frame in the middle from the downloaded videos, and put them under `./playground/data/eval/seed_bench/SEED-Bench-video-image`. We provide our script `extract_video_frames.py` modified from the official one.
3. Multiple-GPU inference and evaluate.
```Shell
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 bash scripts/v1_5/eval/seed.sh
```
4. Optionally, submit the results to the leaderboard: `./playground/data/eval/seed_bench/answers_upload` using the official jupyter notebook.

### LLaVA-Bench-in-the-Wild

1. Extract contents of [`llava-bench-in-the-wild`](https://huggingface.co/datasets/liuhaotian/llava-bench-in-the-wild) to `./playground/data/eval/llava-bench-in-the-wild`.
2. Single-GPU inference and evaluate.
```Shell
CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/llavabench.sh
```

### MM-Vet

1. Extract [`mm-vet.zip`](https://github.com/yuweihao/MM-Vet/releases/download/v1/mm-vet.zip) to `./playground/data/eval/mmvet`.
2. Single-GPU inference.
```Shell
CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/mmvet.sh
```
3. Evaluate the predictions in `./playground/data/eval/mmvet/results` using the official jupyter notebook.
2 changes: 1 addition & 1 deletion docs/LLaVA_Bench.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
- [Multimodal Bing-Chat by Microsoft](https://blogs.bing.com/search/july-2023/Bing-Chat-Enterprise-announced,-multimodal-Visual-Search-rolling-out-to-Bing-Chat) (July 18, 2023)
- [Multimodal Bard by Google](https://bard.google.com/).

These chatbots are presumably supported by proprietary large multimodal models (LMM). Compared with the open-source LMM such as LLaVA, proprietary LMM represent the scaling success upperbound of the current SoTA techniques. They share the goal of developing multimodal chatbots that follow human intents to complete various daily-life visual tasks in the wild. While it remains less unexplored how to evaluate multimodal chat ability, it provides useful feedback to study open-source LMMs against the commercial multimodal chatbots. In addition to the *LLaVA-Bench (COCO)* dataset we used to develop the early versions of LLaVA, we are releasing [*LLaVA-Bench (In-the-Wild)*](https://huggingface.co/datasets/liuhaotian/llava-bench-in-the-wild) to the community for the public use.
These chatbots are presumably supported by proprietary large multimodal models (LMM). Compared with the open-source LMM such as LLaVA, proprietary LMM represent the scaling success upperbound of the current SoTA techniques. They share the goal of developing multimodal chatbots that follow human intents to complete various daily-life visual tasks in the wild. While it remains less explored how to evaluate multimodal chat ability, it provides useful feedback to study open-source LMMs against the commercial multimodal chatbots. In addition to the *LLaVA-Bench (COCO)* dataset we used to develop the early versions of LLaVA, we are releasing [*LLaVA-Bench (In-the-Wild)*](https://huggingface.co/datasets/liuhaotian/llava-bench-in-the-wild) to the community for the public use.

## LLaVA-Bench (In-the-Wild *[Ongoing work]*)

Expand Down
2 changes: 1 addition & 1 deletion docs/LoRA.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ You need latest code base for LoRA support (instructions [here](https://github.c

## Demo (Web UI)

Please execute each of the command below one by one (after the previous one has finished). The commands are the same as launching other demos except for an additional `--model-base` flag to specify the base model to use. Please make sure the base model corresponds to the LoRA checkpoint that you are using. For this technical preview, you need Vicuna v1.1 (7B) checkpoint (if you do not have that already, follow the instructions [here](https://github.com/lm-sys/FastChat#vicuna-weights)).
Please execute each of the commands below one by one (after the previous one has finished). The commands are the same as launching other demos except for an additional `--model-base` flag to specify the base model to use. Please make sure the base model corresponds to the LoRA checkpoint that you are using. For this technical preview, you need Vicuna v1.1 (7B) checkpoint (if you do not have that already, follow the instructions [here](https://github.com/lm-sys/FastChat#vicuna-weights)).

#### Launch a controller
```Shell
Expand Down
2 changes: 1 addition & 1 deletion docs/MODEL_ZOO.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ The model weights below are *merged* weights. You do not need to apply delta. Th

## Projector weights

The model weights below are projector weights we have pretrained. You can use these projector weights for visual instruction tuning. We'll add more projector weights into model zoo very soon.
These are projector weights we have pretrained. You can use these projector weights for visual instruction tuning. They are just pretrained on image-text pairs, and are **NOT** instruction tuned, which means they do **NOT** follow instructions as good as our official models, and can output repetitive, lengthy, and garbled outputs. If you want to have nice conversations with LLaVA, use the checkpoints above (LLaVA v1.5).

**NOTE**: These projector weights are only compatible with the `llava>=1.0.0`, please check out the latest code base if your local code version is below `v1.0.0`.

Expand Down
81 changes: 81 additions & 0 deletions llava/eval/eval_pope.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import os
import json
import argparse

def eval_pope(answers, label_file):
label_list = [json.loads(q)['label'] for q in open(label_file, 'r')]

for answer in answers:
text = answer['text']

# Only keep the first sentence
if text.find('.') != -1:
text = text.split('.')[0]

text = text.replace(',', '')
words = text.split(' ')
if 'No' in words or 'not' in words or 'no' in words:
answer['text'] = 'no'
else:
answer['text'] = 'yes'

for i in range(len(label_list)):
if label_list[i] == 'no':
label_list[i] = 0
else:
label_list[i] = 1

pred_list = []
for answer in answers:
if answer['text'] == 'no':
pred_list.append(0)
else:
pred_list.append(1)

pos = 1
neg = 0
yes_ratio = pred_list.count(1) / len(pred_list)

TP, TN, FP, FN = 0, 0, 0, 0
for pred, label in zip(pred_list, label_list):
if pred == pos and label == pos:
TP += 1
elif pred == pos and label == neg:
FP += 1
elif pred == neg and label == neg:
TN += 1
elif pred == neg and label == pos:
FN += 1

print('TP\tFP\tTN\tFN\t')
print('{}\t{}\t{}\t{}'.format(TP, FP, TN, FN))

precision = float(TP) / float(TP + FP)
recall = float(TP) / float(TP + FN)
f1 = 2*precision*recall / (precision + recall)
acc = (TP + TN) / (TP + TN + FP + FN)
print('Accuracy: {}'.format(acc))
print('Precision: {}'.format(precision))
print('Recall: {}'.format(recall))
print('F1 score: {}'.format(f1))
print('Yes ratio: {}'.format(yes_ratio))
print('%.3f, %.3f, %.3f, %.3f, %.3f' % (f1, acc, precision, recall, yes_ratio) )

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--annotation-dir", type=str)
parser.add_argument("--question-file", type=str)
parser.add_argument("--result-file", type=str)
args = parser.parse_args()

questions = [json.loads(line) for line in open(args.question_file)]
questions = {question['question_id']: question for question in questions}
answers = [json.loads(q) for q in open(args.result_file)]
for file in os.listdir(args.annotation_dir):
assert file.startswith('coco_pope_')
assert file.endswith('.json')
category = file[10:-5]
cur_answers = [x for x in answers if questions[x['question_id']]['category'] == category]
print('Category: {}, # samples: {}'.format(category, len(cur_answers)))
eval_pope(cur_answers, os.path.join(args.annotation_dir, file))
print("====================================")
35 changes: 25 additions & 10 deletions llava/eval/eval_science_qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def get_pred_idx(prediction, choices, options):
if prediction in options[:len(choices)]:
return options.index(prediction)
else:
return -1
return random.choice(range(len(choices)))


Expand All @@ -55,16 +56,23 @@ def get_pred_idx(prediction, choices, options):

for prob_id, prob in split_problems.items():
if prob_id not in predictions:
continue
pred = predictions[prob_id]
pred_text = pred['text']

pattern = re.compile(r'The answer is ([A-Z]).')
res = pattern.findall(pred_text)
if len(res) == 1:
answer = res[0] # 'A', 'B', ...
pred = {'text': 'FAILED', 'prompt': 'Unknown'}
pred_text = 'FAILED'
else:
answer = "FAILED"
pred = predictions[prob_id]
pred_text = pred['text']

if pred_text in args.options:
answer = pred_text
elif len(pred_text) >= 3 and pred_text[0] in args.options and pred_text[1:3] == ". ":
answer = pred_text[0]
else:
pattern = re.compile(r'The answer is ([A-Z]).')
res = pattern.findall(pred_text)
if len(res) == 1:
answer = res[0] # 'A', 'B', ...
else:
answer = "FAILED"

pred_idx = get_pred_idx(answer, prob['choices'], args.options)

Expand All @@ -87,7 +95,14 @@ def get_pred_idx(prediction, choices, options):

correct = len(results['correct'])
total = len(results['correct']) + len(results['incorrect'])
print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%')

###### IMG ######
multimodal_correct = len([x for x in results['correct'] if x['is_multimodal']])
multimodal_incorrect = len([x for x in results['incorrect'] if x['is_multimodal']])
multimodal_total = multimodal_correct + multimodal_incorrect
###### IMG ######

print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%, IMG-Accuracy: {multimodal_correct / multimodal_total * 100:.2f}%')

sqa_results['acc'] = correct / total * 100
sqa_results['correct'] = correct
Expand Down
Loading

0 comments on commit 4241f08

Please sign in to comment.