From a4eedf16c63e680380f35c450deab7fb082aaf49 Mon Sep 17 00:00:00 2001 From: Yuvraj Sharma <48665385+yvrjsharma@users.noreply.github.com> Date: Tue, 10 Oct 2023 12:02:02 +0530 Subject: [PATCH 01/12] =?UTF-8?q?Community=20Build=20Demo=20on=20?= =?UTF-8?q?=F0=9F=A4=97Spaces?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit added a link for a community contributed demo hosted on Spaces --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 48e70fd25..1c12391e2 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ *Visual instruction tuning towards large language and vision models with GPT-4 level capabilities.* -[[Project Page](https://llava-vl.github.io/)] [[Demo](https://llava.hliu.cc/)] [[Data](https://github.com/haotian-liu/LLaVA/blob/main/docs/Data.md)] [[Model Zoo](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md)] +[[Project Page](https://llava-vl.github.io/)] [[Demo](https://llava.hliu.cc/)] [[Data](https://github.com/haotian-liu/LLaVA/blob/main/docs/Data.md)] [[Model Zoo](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md)] [[Community Build Demo on 🤗Spaces](https://huggingface.co/spaces/badayvedat/LLaVA)] **Improved Baselines with Visual Instruction Tuning** [[Paper](https://arxiv.org/abs/2310.03744)]
[Haotian Liu](https://hliu.cc), [Chunyuan Li](https://chunyuan.li/), [Yuheng Li](https://yuheng-li.github.io/), [Yong Jae Lee](https://pages.cs.wisc.edu/~yongjaelee/) From 66cdaabff75db20f68c838f21521bb2cfc0dca3d Mon Sep 17 00:00:00 2001 From: Ikko Eltociear Ashimine Date: Wed, 11 Oct 2023 01:36:57 +0900 Subject: [PATCH 02/12] Fix typo in attention.py implemetation -> implementation --- llava/model/language_model/mpt/attention.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/llava/model/language_model/mpt/attention.py b/llava/model/language_model/mpt/attention.py index e5c758afa..b5543ef21 100644 --- a/llava/model/language_model/mpt/attention.py +++ b/llava/model/language_model/mpt/attention.py @@ -151,7 +151,7 @@ def triton_flash_attn_fn(query, key, value, n_heads, past_key_value=None, softma class MultiheadAttention(nn.Module): """Multi-head self attention. - Using torch or triton attention implemetation enables user to also use + Using torch or triton attention implementation enables user to also use additive bias. """ @@ -204,7 +204,7 @@ def forward(self, x, past_key_value=None, attn_bias=None, attention_mask=None, i class MultiQueryAttention(nn.Module): """Multi-Query self attention. - Using torch or triton attention implemetation enables user to also use + Using torch or triton attention implementation enables user to also use additive bias. """ @@ -297,4 +297,4 @@ def build_alibi_bias(n_heads, seq_len, full=False, alibi_bias_max=8, device=None slopes = gen_slopes(n_heads, alibi_bias_max, device=device) alibi_bias = alibi_bias * slopes return alibi_bias.to(dtype=dtype) -ATTN_CLASS_REGISTRY = {'multihead_attention': MultiheadAttention, 'multiquery_attention': MultiQueryAttention} \ No newline at end of file +ATTN_CLASS_REGISTRY = {'multihead_attention': MultiheadAttention, 'multiquery_attention': MultiQueryAttention} From 148b72fc241e369c01b3b71f454fd0ec50f58b16 Mon Sep 17 00:00:00 2001 From: Didier Durand Date: Wed, 11 Oct 2023 20:18:30 +0200 Subject: [PATCH 03/12] Fixing typos in README.md Hi, as title says. Didier --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 75a8ebfc2..392d7eaf3 100644 --- a/README.md +++ b/README.md @@ -20,8 +20,8 @@ - [10/11] The training data and scripts of LLaVA-1.5 are released [here](https://github.com/haotian-liu/LLaVA#train), with evaluation scripts coming this week! - [10/5] 🔥 LLaVA-1.5 is out! Achieving SoTA on 11 benchmarks, with just simple modifications to the original LLaVA, utilizes all public data, completes training in ~1 day on a single 8-A100 node, and surpasses methods like Qwen-VL-Chat that use billion-scale data. Check out the [technical report](https://arxiv.org/abs/2310.03744), and explore the [demo](https://llava.hliu.cc/)! Models are available in [Model Zoo](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md). - [9/26] LLaVA is improved with reinforcement learning from human feedback (RLHF) to improve fact grounding and reduce hallucination. Check out the new SFT and RLHF checkpoints at project [[LLavA-RLHF]](https://llava-rlhf.github.io/) -- [9/22] [LLaVA](https://arxiv.org/abs/2304.08485) is accpeted by NeurIPS 2023 as **oral presentation**, and [LLaVA-Med](https://arxiv.org/abs/2306.00890) is accpeted by NeurIPS 2023 Datasets and Benchmarks Track as **spotlight presentation**. -- [9/20] We summarize our emprical study of training 33B and 65B LLaVA mdoels in a [note](https://arxiv.org/abs/2309.09958). Further, if you are interested in the comprehensive review, evolution and trend of multimodal foundation models, please check out our recent survey paper [``Multimodal Foundation Models: From Specialists to General-Purpose Assistants''.](https://arxiv.org/abs/2309.10020) +- [9/22] [LLaVA](https://arxiv.org/abs/2304.08485) is accepted by NeurIPS 2023 as **oral presentation**, and [LLaVA-Med](https://arxiv.org/abs/2306.00890) is accepted by NeurIPS 2023 Datasets and Benchmarks Track as **spotlight presentation**. +- [9/20] We summarize our empirical study of training 33B and 65B LLaVA mdoels in a [note](https://arxiv.org/abs/2309.09958). Further, if you are interested in the comprehensive review, evolution and trend of multimodal foundation models, please check out our recent survey paper [``Multimodal Foundation Models: From Specialists to General-Purpose Assistants''.](https://arxiv.org/abs/2309.10020)

From c24aba9038adc871a97a4009d5a875294f12469b Mon Sep 17 00:00:00 2001 From: Didier Durand Date: Wed, 11 Oct 2023 20:25:03 +0200 Subject: [PATCH 04/12] Update LLaVA_Bench.md Double negation "less unexplored" doesn't seem appropriate for the intended meaning --- docs/LLaVA_Bench.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/LLaVA_Bench.md b/docs/LLaVA_Bench.md index 5921964c4..643fee99c 100644 --- a/docs/LLaVA_Bench.md +++ b/docs/LLaVA_Bench.md @@ -4,7 +4,7 @@ - [Multimodal Bing-Chat by Microsoft](https://blogs.bing.com/search/july-2023/Bing-Chat-Enterprise-announced,-multimodal-Visual-Search-rolling-out-to-Bing-Chat) (July 18, 2023) - [Multimodal Bard by Google](https://bard.google.com/). -These chatbots are presumably supported by proprietary large multimodal models (LMM). Compared with the open-source LMM such as LLaVA, proprietary LMM represent the scaling success upperbound of the current SoTA techniques. They share the goal of developing multimodal chatbots that follow human intents to complete various daily-life visual tasks in the wild. While it remains less unexplored how to evaluate multimodal chat ability, it provides useful feedback to study open-source LMMs against the commercial multimodal chatbots. In addition to the *LLaVA-Bench (COCO)* dataset we used to develop the early versions of LLaVA, we are releasing [*LLaVA-Bench (In-the-Wild)*](https://huggingface.co/datasets/liuhaotian/llava-bench-in-the-wild) to the community for the public use. +These chatbots are presumably supported by proprietary large multimodal models (LMM). Compared with the open-source LMM such as LLaVA, proprietary LMM represent the scaling success upperbound of the current SoTA techniques. They share the goal of developing multimodal chatbots that follow human intents to complete various daily-life visual tasks in the wild. While it remains less explored how to evaluate multimodal chat ability, it provides useful feedback to study open-source LMMs against the commercial multimodal chatbots. In addition to the *LLaVA-Bench (COCO)* dataset we used to develop the early versions of LLaVA, we are releasing [*LLaVA-Bench (In-the-Wild)*](https://huggingface.co/datasets/liuhaotian/llava-bench-in-the-wild) to the community for the public use. ## LLaVA-Bench (In-the-Wild *[Ongoing work]*) From 01cb8e8ca43b204186cf2ea46ba6c80f5cec1d0e Mon Sep 17 00:00:00 2001 From: Didier Durand Date: Wed, 11 Oct 2023 20:28:26 +0200 Subject: [PATCH 05/12] fixing typo in LoRA.md --- docs/LoRA.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/LoRA.md b/docs/LoRA.md index 369fe9257..bed25f57d 100644 --- a/docs/LoRA.md +++ b/docs/LoRA.md @@ -6,7 +6,7 @@ You need latest code base for LoRA support (instructions [here](https://github.c ## Demo (Web UI) -Please execute each of the command below one by one (after the previous one has finished). The commands are the same as launching other demos except for an additional `--model-base` flag to specify the base model to use. Please make sure the base model corresponds to the LoRA checkpoint that you are using. For this technical preview, you need Vicuna v1.1 (7B) checkpoint (if you do not have that already, follow the instructions [here](https://github.com/lm-sys/FastChat#vicuna-weights)). +Please execute each of the commands below one by one (after the previous one has finished). The commands are the same as launching other demos except for an additional `--model-base` flag to specify the base model to use. Please make sure the base model corresponds to the LoRA checkpoint that you are using. For this technical preview, you need Vicuna v1.1 (7B) checkpoint (if you do not have that already, follow the instructions [here](https://github.com/lm-sys/FastChat#vicuna-weights)). #### Launch a controller ```Shell From ce1aa08d129bd7df931a7067bcb79cf3cb3a1af5 Mon Sep 17 00:00:00 2001 From: Haotian Liu Date: Wed, 11 Oct 2023 16:58:25 -0700 Subject: [PATCH 06/12] Release evaluation scripts. --- README.md | 2 +- docs/Evaluation.md | 142 +++++++++ llava/eval/eval_mmbench.py | 226 +++++++++++++++ llava/eval/eval_pope.py | 81 ++++++ llava/eval/eval_science_qa.py | 35 ++- llava/eval/eval_textvqa.py | 65 +++++ llava/eval/m4c_evaluator.py | 334 ++++++++++++++++++++++ llava/eval/model_vqa.py | 2 +- llava/eval/model_vqa_loader.py | 144 ++++++++++ llava/eval/model_vqa_mmbench.py | 170 +++++++++++ llava/eval/model_vqa_science.py | 14 +- llava/eval/summarize_gpt_review.py | 22 +- scripts/convert_gqa_for_eval.py | 18 ++ scripts/convert_mmbench_for_submission.py | 27 ++ scripts/convert_mmvet_for_eval.py | 18 ++ scripts/convert_seed_for_submission.py | 74 +++++ scripts/convert_vizwiz_for_submission.py | 47 +++ scripts/convert_vqav2_for_submission.py | 56 ++++ scripts/v1_5/eval/gqa.sh | 39 +++ scripts/v1_5/eval/llavabench.sh | 23 ++ scripts/v1_5/eval/mmbench.sh | 19 ++ scripts/v1_5/eval/mmbench_cn.sh | 20 ++ scripts/v1_5/eval/mme.sh | 17 ++ scripts/v1_5/eval/mmvet.sh | 16 ++ scripts/v1_5/eval/pope.sh | 14 + scripts/v1_5/eval/seed.sh | 39 +++ scripts/v1_5/eval/sqa.sh | 16 ++ scripts/v1_5/eval/textvqa.sh | 13 + scripts/v1_5/eval/vizwiz.sh | 14 + scripts/v1_5/eval/vqav2.sh | 36 +++ 30 files changed, 1721 insertions(+), 22 deletions(-) create mode 100644 docs/Evaluation.md create mode 100644 llava/eval/eval_mmbench.py create mode 100644 llava/eval/eval_pope.py create mode 100644 llava/eval/eval_textvqa.py create mode 100644 llava/eval/m4c_evaluator.py create mode 100644 llava/eval/model_vqa_loader.py create mode 100644 llava/eval/model_vqa_mmbench.py create mode 100644 scripts/convert_gqa_for_eval.py create mode 100644 scripts/convert_mmbench_for_submission.py create mode 100644 scripts/convert_mmvet_for_eval.py create mode 100644 scripts/convert_seed_for_submission.py create mode 100644 scripts/convert_vizwiz_for_submission.py create mode 100644 scripts/convert_vqav2_for_submission.py create mode 100644 scripts/v1_5/eval/gqa.sh create mode 100644 scripts/v1_5/eval/llavabench.sh create mode 100644 scripts/v1_5/eval/mmbench.sh create mode 100644 scripts/v1_5/eval/mmbench_cn.sh create mode 100644 scripts/v1_5/eval/mme.sh create mode 100644 scripts/v1_5/eval/mmvet.sh create mode 100644 scripts/v1_5/eval/pope.sh create mode 100644 scripts/v1_5/eval/seed.sh create mode 100644 scripts/v1_5/eval/sqa.sh create mode 100644 scripts/v1_5/eval/textvqa.sh create mode 100644 scripts/v1_5/eval/vizwiz.sh create mode 100644 scripts/v1_5/eval/vqav2.sh diff --git a/README.md b/README.md index 75a8ebfc2..805376773 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,7 @@ ## Release -- [10/11] The training data and scripts of LLaVA-1.5 are released [here](https://github.com/haotian-liu/LLaVA#train), with evaluation scripts coming this week! +- [10/11] The training data and scripts of LLaVA-1.5 are released [here](https://github.com/haotian-liu/LLaVA#train), and evaluation scripts are released [here](https://github.com/haotian-liu/LLaVA/blob/main/docs/Evaluation.md)! - [10/5] 🔥 LLaVA-1.5 is out! Achieving SoTA on 11 benchmarks, with just simple modifications to the original LLaVA, utilizes all public data, completes training in ~1 day on a single 8-A100 node, and surpasses methods like Qwen-VL-Chat that use billion-scale data. Check out the [technical report](https://arxiv.org/abs/2310.03744), and explore the [demo](https://llava.hliu.cc/)! Models are available in [Model Zoo](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md). - [9/26] LLaVA is improved with reinforcement learning from human feedback (RLHF) to improve fact grounding and reduce hallucination. Check out the new SFT and RLHF checkpoints at project [[LLavA-RLHF]](https://llava-rlhf.github.io/) - [9/22] [LLaVA](https://arxiv.org/abs/2304.08485) is accpeted by NeurIPS 2023 as **oral presentation**, and [LLaVA-Med](https://arxiv.org/abs/2306.00890) is accpeted by NeurIPS 2023 Datasets and Benchmarks Track as **spotlight presentation**. diff --git a/docs/Evaluation.md b/docs/Evaluation.md new file mode 100644 index 000000000..899dfe8e9 --- /dev/null +++ b/docs/Evaluation.md @@ -0,0 +1,142 @@ +# Evaluation + +In LLaVA-1.5, we evaluate models on a diverse set of 12 benchmarks. To ensure the reproducibility, we evaluate the models with greedy decoding. We do not evaluate using beam search to make the inference process consistent with the chat demo of real-time outputs. + +Currently, we mostly utilize the official toolkit or server for the evaluation. + +## Evaluate on Custom Datasets + +You can evaluate LLaVA on your custom datasets by converting your dataset to LLaVA's jsonl format, and evaluate using [`model_vqa.py`](https://github.com/haotian-liu/LLaVA/blob/main/llava/eval/model_vqa.py). + +Below we provide a general guideline for evaluating datasets with some common formats. + +1. Short-answer (e.g. VQAv2, MME). + +``` + +Answer the question using a single word or phrase. +``` + +2. Option-only for multiple-choice (e.g. MMBench, SEED-Bench). + +``` + +A. +B. +C. +D. +Answer with the option's letter from the given choices directly. +``` + +3. Natural QA (e.g. LLaVA-Bench, MM-Vet). + +No postprocessing is needed. + +## Scripts + +Before preparing task-specific data, download [eval.zip](https://drive.google.com/file/d/1atZSBBrAX54yYpxtVVW33zFvcnaHeFPy/view?usp=sharing). It contains custom annotations, scripts, and the prediction files with LLaVA v1.5. Extract to `./playground/data/eval`. This also provides a general structure for all datasets. + +### VQAv2 + +1. Download [`test2015`](http://images.cocodataset.org/zips/test2015.zip) and put it under `./playground/data/eval/vqav2`. +2. Multi-GPU inference. +```Shell +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 bash scripts/v1_5/eval/vqav2.sh +``` +3. Submit the results to the evaluation server: `./playground/data/eval/vqav2/answers_upload`. + +### GQA + +1. Download the data following the official instructions [here](https://cs.stanford.edu/people/dorarad/gqa/download.html) and put under `./playground/data/eval/gqa/data`. +2. Multi-GPU inference. +```Shell +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 bash scripts/v1_5/eval/gqa.sh +``` + +### VisWiz + +1. Download [`test.json`](https://vizwiz.cs.colorado.edu/VizWiz_final/vqa_data/Annotations.zip) and extract [`test.zip`](https://vizwiz.cs.colorado.edu/VizWiz_final/images/test.zip) to `test`. Put them under `./playground/data/eval/vizwiz`. +2. Single-GPU inference. +```Shell +CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/vizwiz.sh +``` +3. Submit the results to the evaluation server: `./playground/data/eval/vizwiz/answers_upload`. + +### ScienceQA + +1. Under `./playground/data/eval/scienceqa`, download `images`, `pid_splits.json`, `problems.json` from the `data/scienceqa` folder of the ScienceQA [repo](https://github.com/lupantech/ScienceQA). +2. Single-GPU inference and evaluate. +```Shell +CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/sqa.sh +``` + +### TextVQA + +1. Download [`TextVQA_0.5.1_val.json](https://dl.fbaipublicfiles.com/textvqa/data/TextVQA_0.5.1_val.json) and [images](https://dl.fbaipublicfiles.com/textvqa/images/train_val_images.zip) and extract to `./playground/data/eval/textvqa`. +2. Single-GPU inference and evaluate. +```Shell +CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/textvqa.sh +``` + +### POPE + +1. Download `coco` from [POPE](https://github.com/AoiDragon/POPE/tree/e3e39262c85a6a83f26cf5094022a782cb0df58d/output/coco) and put under `./playground/data/eval/pope`. +2. Single-GPU inference and evaluate. +```Shell +CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/pope.sh +``` + +### MME + +1. Download the data following the official instructions [here](https://github.com/BradyFU/Awesome-Multimodal-Large-Language-Models/tree/Evaluation). +2. Downloaded images to `MME_Benchmark_release_version`. +3. put the official `eval_tool` and `MME_Benchmark_release_version` under `./playground/data/eval/MME`. +4. Single-GPU inference and evaluate. +```Shell +CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/mme.sh +``` + +### MMBench + +1. Download `mmbench_dev_20230712.tsv` from the official [website](https://github.com/open-compass/MMBench) and put under `./playground/data/eval/mmbench`. +2. Single-GPU inference. +```Shell +CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/mmbench.sh +``` +3. Submit the results to the evaluation server: `./playground/data/eval/mmbench/answers_upload/mmbench_dev_20230712`. + +### MMBench-CN + +1. Download `mmbench_dev_cn_20231003.tsv` from the official [website](https://github.com/open-compass/MMBench) and put under `./playground/data/eval/mmbench`. +2. Single-GPU inference. +```Shell +CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/mmbench_cn.sh +``` +3. Submit the results to the evaluation server: `./playground/data/eval/mmbench/answers_upload/mmbench_dev_cn_20231003`. + +### SEED-Bench + +1. Following the official [instructions](https://github.com/AILab-CVC/SEED-Bench/blob/main/DATASET.md) to download the images and the videos. Put images under `./playground/data/eval/seed_bench/SEED-Bench-image`. +2. Extract the video frame in the middle from the downloaded videos, and put them under `./playground/data/eval/seed_bench/SEED-Bench-video-image`. We provide our script `extract_video_frames.py` modified from the official one. +3. Multiple-GPU inference and evaluate. +```Shell +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 bash scripts/v1_5/eval/seed.sh +``` +4. Optionally, submit the results to the leaderboard: `./playground/data/eval/seed_bench/answers_upload` using the official jupyter notebook. + +### LLaVA-Bench-in-the-Wild + +1. Extract contents of [`llava-bench-in-the-wild`](https://huggingface.co/datasets/liuhaotian/llava-bench-in-the-wild) to `./playground/data/eval/llava-bench-in-the-wild`. +2. Single-GPU inference and evaluate. +```Shell +CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/llavabench.sh +``` + +### MM-Vet + +1. Extract [`mm-vet.zip`](https://github.com/yuweihao/MM-Vet/releases/download/v1/mm-vet.zip) to `./playground/data/eval/mmvet`. +2. Single-GPU inference. +```Shell +CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/mmvet.sh +``` +3. Evaluate the predictions in `./playground/data/eval/mmvet/results` using the official jupyter notebook. diff --git a/llava/eval/eval_mmbench.py b/llava/eval/eval_mmbench.py new file mode 100644 index 000000000..c4205d61b --- /dev/null +++ b/llava/eval/eval_mmbench.py @@ -0,0 +1,226 @@ +import argparse +import os +import json +import pandas as pd +from tqdm import tqdm +import openai +from concurrent.futures import ThreadPoolExecutor, as_completed +import math +import time + + +all_options = ['A', 'B', 'C', 'D'] + + +def split_list(lst, n): + """Split a list into n (roughly) equal-sized chunks""" + chunk_size = math.ceil(len(lst) / n) # integer division + return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)] + + +def get_chunk(lst, n, k): + chunks = split_list(lst, n) + return chunks[k] + + +def get_row(df, colname, value): + assert (df[colname] == value).sum() == 1 + return df[df[colname] == value].iloc[0] + + +def encode_query(question, options, answer): + query = "" + query += "Question: " + question + "\n" + query += "Options: " + "\n".join([f"{option_char}. {option}" for option_char, option in zip(all_options[:len(options)], options)]) + "\n" + query += "Answer: " + answer + "\n" + return query + + +def get_openai_api(): + api_type = os.environ.get('API_TYPE', 'azure') + + if api_type == 'azure': + api_key = os.environ.get('API_KEY', 'sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx') + engine = os.environ.get('ENGINE', 'chatgpt-turbo') + api_host = os.environ.get('API_BASE') + return { + 'api_type': 'azure', + 'api_version': '2023-06-01-preview', + 'engine': engine, + 'api_key': api_key, + 'api_base': f'https://{api_host}.openai.azure.com', + } + else: + api_key = os.environ.get('API_KEY', 'sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx') + model = os.environ.get('MODEL', 'gpt-3.5-turbo-0301') + + return { + 'model': model, + 'api_key': api_key, + } + + +def chatgpt_extract_answer( + question, options, answer, max_tokens=64, temperature=0.2, top_p=0.9, frequency_penalty=0, presence_penalty=0, + request_timeout=None, num_retry=1): + api_kwargs = get_openai_api() + + system_message = """You are an AI assistant to help me matching an answer with several options of a multiple choice question. +You are provided with a question, several options, and an answer, and you need to find which option is most similar to the answer. +If the meaning of all options are significantly different from the answer, output X. +You should output a single uppercase character in A, B, C, D, if they are valid options, and X otherwise.""" + exemplers = [ + { + "question": "What is the main object in image?", + "options": ["teddy bear", "rabbit", "cat", "dog"], + "answer": "a cute teddy bear", + "output": "A", + }, + { + "question": "What is the main object in image?", + "options": ["teddy bear", "rabbit", "cat", "dog"], + "answer": "Spider", + "output": "X", + }, + ] + + messages = [ + {"role": "system", "content": system_message}, + ] + for exempler in exemplers: + messages.append({"role": "user", "content": encode_query(exempler['question'], exempler['options'], exempler['answer'])}) + messages.append({"role": "assistant", "content": exempler['output']}) + messages.append({"role": "user", "content": encode_query(question, options, answer)}) + + response = None + attempts = [] + for i in range(num_retry): + try: + response = openai.ChatCompletion.create( + messages = messages, + max_tokens = max_tokens, + temperature = temperature, + top_p = top_p, + frequency_penalty = frequency_penalty, + presence_penalty = presence_penalty, + request_timeout = request_timeout, + **api_kwargs + ) + except Exception as e: + if type(e) in [openai.error.RateLimitError, openai.error.APIError, openai.error.APIConnectionError, openai.error.Timeout]: + pass + elif type(e) in [openai.error.AuthenticationError, openai.error.InvalidRequestError]: + print(e) + return None + else: + print(type(e), e) + attempts.append(e.__class__.__name__) + time.sleep(1) + else: + time.sleep(1) + break + + if response is None: + print(f'All {num_retry} attempts failed: {attempts}. Returning None.') + return None + + content = response['choices'][0]['message']['content'] + content = content.strip() + return content + +def is_none(value): + if value is None: + return True + if type(value) is float and math.isnan(value): + return True + if type(value) is str and value.lower() == 'nan': + return True + if type(value) is str and value.lower() == 'none': + return True + return False + +def get_options(row, options): + parsed_options = [] + for option in options: + option_value = row[option] + if is_none(option_value): + break + parsed_options.append(option_value) + return parsed_options + +def auto_parse_answer(question, options, answer): + if answer.strip('.').strip().upper() in all_options[:len(options)]: + return answer.strip('.').strip().upper() + expand_option_valid = [f'The answer is {option}.'.lower() in answer.lower() for option in all_options[:len(options)]] + if any(expand_option_valid): + return all_options[expand_option_valid.index(True)] + + matched_ops = [all_options[_i] for _i, option in enumerate(options) if answer.lower() in option.lower()] + if len(matched_ops) == 1: + return matched_ops[0] + return None + +def eval_results(args): + questions = pd.read_table(os.path.expanduser(args.question_file)) + answers = [json.loads(line) for line in open(os.path.expanduser(args.answers_file))] + answers = {(row['question_id'], row.get('round_id', 0)): row for row in answers} + results_file = os.path.expanduser(args.results_file) + if os.path.exists(results_file): + results = [json.loads(line) for line in open(results_file)] + results = {(row['question_id'], row.get('round_id', 0)): row for row in results} + else: + results = {} + results_writer = open(results_file, 'a') + + def process_answer(idx, answer): + if idx in results: + return None + question_id, round_id = idx + question_data = get_row(questions, 'index', question_id) + if 'options' in answer: + options = answer['options'] + option_char = answer['option_char'] + else: + assert round_id == 0, "round_id must be 0 when options are not provided" + options = get_options(question_data, all_options) + option_char = all_options[:len(options)] + option_map = {all_options[i]: option_char[i] for i in range(len(options))} + option_map['X'] = 'X' + parsed_answer = auto_parse_answer(question_data['question'], options, answer['text']) + if parsed_answer is None: + parsed_answer = chatgpt_extract_answer( + question_data['question'], options, answer['text'], + request_timeout=args.request_timeout, num_retry=args.num_retry) + if parsed_answer is None: + return None + if parsed_answer not in option_map: + print(f'Invalid parsed answer: {parsed_answer}') + return None + answer['parsed_answer'] = option_map[parsed_answer] + return answer + + with ThreadPoolExecutor(max_workers=args.max_workers) as executor: + # Submit all tasks to the executor + futures = {executor.submit(process_answer, key, value): key for key, value in answers.items()} + + # Process results as they become available + for future in tqdm(as_completed(futures), total=len(answers)): + answer = future.result() + if answer is not None: + results_writer.write(json.dumps(answer) + '\n') + results_writer.flush() + + results_writer.close() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--question-file", type=str, default="tables/question.jsonl") + parser.add_argument("--answers-file", type=str, default="answer.jsonl") + parser.add_argument("--results-file", type=str, default="results.jsonl") + parser.add_argument("--max-workers", type=int, default=1) + parser.add_argument("--num-retry", type=int, default=3) + parser.add_argument("--request-timeout", type=int, default=None) + args = parser.parse_args() + + eval_results(args) diff --git a/llava/eval/eval_pope.py b/llava/eval/eval_pope.py new file mode 100644 index 000000000..b115b8f23 --- /dev/null +++ b/llava/eval/eval_pope.py @@ -0,0 +1,81 @@ +import os +import json +import argparse + +def eval_pope(answers, label_file): + label_list = [json.loads(q)['label'] for q in open(label_file, 'r')] + + for answer in answers: + text = answer['text'] + + # Only keep the first sentence + if text.find('.') != -1: + text = text.split('.')[0] + + text = text.replace(',', '') + words = text.split(' ') + if 'No' in words or 'not' in words or 'no' in words: + answer['text'] = 'no' + else: + answer['text'] = 'yes' + + for i in range(len(label_list)): + if label_list[i] == 'no': + label_list[i] = 0 + else: + label_list[i] = 1 + + pred_list = [] + for answer in answers: + if answer['text'] == 'no': + pred_list.append(0) + else: + pred_list.append(1) + + pos = 1 + neg = 0 + yes_ratio = pred_list.count(1) / len(pred_list) + + TP, TN, FP, FN = 0, 0, 0, 0 + for pred, label in zip(pred_list, label_list): + if pred == pos and label == pos: + TP += 1 + elif pred == pos and label == neg: + FP += 1 + elif pred == neg and label == neg: + TN += 1 + elif pred == neg and label == pos: + FN += 1 + + print('TP\tFP\tTN\tFN\t') + print('{}\t{}\t{}\t{}'.format(TP, FP, TN, FN)) + + precision = float(TP) / float(TP + FP) + recall = float(TP) / float(TP + FN) + f1 = 2*precision*recall / (precision + recall) + acc = (TP + TN) / (TP + TN + FP + FN) + print('Accuracy: {}'.format(acc)) + print('Precision: {}'.format(precision)) + print('Recall: {}'.format(recall)) + print('F1 score: {}'.format(f1)) + print('Yes ratio: {}'.format(yes_ratio)) + print('%.3f, %.3f, %.3f, %.3f, %.3f' % (f1, acc, precision, recall, yes_ratio) ) + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--annotation-dir", type=str) + parser.add_argument("--question-file", type=str) + parser.add_argument("--result-file", type=str) + args = parser.parse_args() + + questions = [json.loads(line) for line in open(args.question_file)] + questions = {question['question_id']: question for question in questions} + answers = [json.loads(q) for q in open(args.result_file)] + for file in os.listdir(args.annotation_dir): + assert file.startswith('coco_pope_') + assert file.endswith('.json') + category = file[10:-5] + cur_answers = [x for x in answers if questions[x['question_id']]['category'] == category] + print('Category: {}, # samples: {}'.format(category, len(cur_answers))) + eval_pope(cur_answers, os.path.join(args.annotation_dir, file)) + print("====================================") diff --git a/llava/eval/eval_science_qa.py b/llava/eval/eval_science_qa.py index e1b3ce52f..ccf206bbd 100644 --- a/llava/eval/eval_science_qa.py +++ b/llava/eval/eval_science_qa.py @@ -32,6 +32,7 @@ def get_pred_idx(prediction, choices, options): if prediction in options[:len(choices)]: return options.index(prediction) else: + return -1 return random.choice(range(len(choices))) @@ -55,16 +56,23 @@ def get_pred_idx(prediction, choices, options): for prob_id, prob in split_problems.items(): if prob_id not in predictions: - continue - pred = predictions[prob_id] - pred_text = pred['text'] - - pattern = re.compile(r'The answer is ([A-Z]).') - res = pattern.findall(pred_text) - if len(res) == 1: - answer = res[0] # 'A', 'B', ... + pred = {'text': 'FAILED', 'prompt': 'Unknown'} + pred_text = 'FAILED' else: - answer = "FAILED" + pred = predictions[prob_id] + pred_text = pred['text'] + + if pred_text in args.options: + answer = pred_text + elif len(pred_text) >= 3 and pred_text[0] in args.options and pred_text[1:3] == ". ": + answer = pred_text[0] + else: + pattern = re.compile(r'The answer is ([A-Z]).') + res = pattern.findall(pred_text) + if len(res) == 1: + answer = res[0] # 'A', 'B', ... + else: + answer = "FAILED" pred_idx = get_pred_idx(answer, prob['choices'], args.options) @@ -87,7 +95,14 @@ def get_pred_idx(prediction, choices, options): correct = len(results['correct']) total = len(results['correct']) + len(results['incorrect']) - print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%') + + ###### IMG ###### + multimodal_correct = len([x for x in results['correct'] if x['is_multimodal']]) + multimodal_incorrect = len([x for x in results['incorrect'] if x['is_multimodal']]) + multimodal_total = multimodal_correct + multimodal_incorrect + ###### IMG ###### + + print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%, IMG-Accuracy: {multimodal_correct / multimodal_total * 100:.2f}%') sqa_results['acc'] = correct / total * 100 sqa_results['correct'] = correct diff --git a/llava/eval/eval_textvqa.py b/llava/eval/eval_textvqa.py new file mode 100644 index 000000000..468f4bb12 --- /dev/null +++ b/llava/eval/eval_textvqa.py @@ -0,0 +1,65 @@ +import os +import argparse +import json +import re + +from llava.eval.m4c_evaluator import TextVQAAccuracyEvaluator + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--annotation-file', type=str) + parser.add_argument('--result-file', type=str) + parser.add_argument('--result-dir', type=str) + return parser.parse_args() + + +def prompt_processor(prompt): + if prompt.startswith('OCR tokens: '): + pattern = r"Question: (.*?) Short answer:" + match = re.search(pattern, prompt, re.DOTALL) + question = match.group(1) + elif 'Reference OCR token: ' in prompt and len(prompt.split('\n')) == 3: + if prompt.startswith('Reference OCR token:'): + question = prompt.split('\n')[1] + else: + question = prompt.split('\n')[0] + elif len(prompt.split('\n')) == 2: + question = prompt.split('\n')[0] + else: + assert False + + return question.lower() + + +def eval_single(annotation_file, result_file): + experiment_name = os.path.splitext(os.path.basename(result_file))[0] + print(experiment_name) + annotations = json.load(open(annotation_file))['data'] + annotations = {(annotation['image_id'], annotation['question'].lower()): annotation for annotation in annotations} + results = [json.loads(line) for line in open(result_file)] + + pred_list = [] + for result in results: + annotation = annotations[(result['question_id'], prompt_processor(result['prompt']))] + pred_list.append({ + "pred_answer": result['text'], + "gt_answers": annotation['answers'], + }) + + evaluator = TextVQAAccuracyEvaluator() + print('Samples: {}\nAccuracy: {:.2f}%\n'.format(len(pred_list), 100. * evaluator.eval_pred_list(pred_list))) + + +if __name__ == "__main__": + args = get_args() + + if args.result_file is not None: + eval_single(args.annotation_file, args.result_file) + + if args.result_dir is not None: + for result_file in sorted(os.listdir(args.result_dir)): + if not result_file.endswith('.jsonl'): + print(f'Skipping {result_file}') + continue + eval_single(args.annotation_file, os.path.join(args.result_dir, result_file)) diff --git a/llava/eval/m4c_evaluator.py b/llava/eval/m4c_evaluator.py new file mode 100644 index 000000000..e30e958da --- /dev/null +++ b/llava/eval/m4c_evaluator.py @@ -0,0 +1,334 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import re + +from tqdm import tqdm + + +class EvalAIAnswerProcessor: + """ + Processes an answer similar to Eval AI + copied from + https://github.com/facebookresearch/mmf/blob/c46b3b3391275b4181567db80943473a89ab98ab/pythia/tasks/processors.py#L897 + """ + + CONTRACTIONS = { + "aint": "ain't", + "arent": "aren't", + "cant": "can't", + "couldve": "could've", + "couldnt": "couldn't", + "couldn'tve": "couldn't've", + "couldnt've": "couldn't've", + "didnt": "didn't", + "doesnt": "doesn't", + "dont": "don't", + "hadnt": "hadn't", + "hadnt've": "hadn't've", + "hadn'tve": "hadn't've", + "hasnt": "hasn't", + "havent": "haven't", + "hed": "he'd", + "hed've": "he'd've", + "he'dve": "he'd've", + "hes": "he's", + "howd": "how'd", + "howll": "how'll", + "hows": "how's", + "Id've": "I'd've", + "I'dve": "I'd've", + "Im": "I'm", + "Ive": "I've", + "isnt": "isn't", + "itd": "it'd", + "itd've": "it'd've", + "it'dve": "it'd've", + "itll": "it'll", + "let's": "let's", + "maam": "ma'am", + "mightnt": "mightn't", + "mightnt've": "mightn't've", + "mightn'tve": "mightn't've", + "mightve": "might've", + "mustnt": "mustn't", + "mustve": "must've", + "neednt": "needn't", + "notve": "not've", + "oclock": "o'clock", + "oughtnt": "oughtn't", + "ow's'at": "'ow's'at", + "'ows'at": "'ow's'at", + "'ow'sat": "'ow's'at", + "shant": "shan't", + "shed've": "she'd've", + "she'dve": "she'd've", + "she's": "she's", + "shouldve": "should've", + "shouldnt": "shouldn't", + "shouldnt've": "shouldn't've", + "shouldn'tve": "shouldn't've", + "somebody'd": "somebodyd", + "somebodyd've": "somebody'd've", + "somebody'dve": "somebody'd've", + "somebodyll": "somebody'll", + "somebodys": "somebody's", + "someoned": "someone'd", + "someoned've": "someone'd've", + "someone'dve": "someone'd've", + "someonell": "someone'll", + "someones": "someone's", + "somethingd": "something'd", + "somethingd've": "something'd've", + "something'dve": "something'd've", + "somethingll": "something'll", + "thats": "that's", + "thered": "there'd", + "thered've": "there'd've", + "there'dve": "there'd've", + "therere": "there're", + "theres": "there's", + "theyd": "they'd", + "theyd've": "they'd've", + "they'dve": "they'd've", + "theyll": "they'll", + "theyre": "they're", + "theyve": "they've", + "twas": "'twas", + "wasnt": "wasn't", + "wed've": "we'd've", + "we'dve": "we'd've", + "weve": "we've", + "werent": "weren't", + "whatll": "what'll", + "whatre": "what're", + "whats": "what's", + "whatve": "what've", + "whens": "when's", + "whered": "where'd", + "wheres": "where's", + "whereve": "where've", + "whod": "who'd", + "whod've": "who'd've", + "who'dve": "who'd've", + "wholl": "who'll", + "whos": "who's", + "whove": "who've", + "whyll": "why'll", + "whyre": "why're", + "whys": "why's", + "wont": "won't", + "wouldve": "would've", + "wouldnt": "wouldn't", + "wouldnt've": "wouldn't've", + "wouldn'tve": "wouldn't've", + "yall": "y'all", + "yall'll": "y'all'll", + "y'allll": "y'all'll", + "yall'd've": "y'all'd've", + "y'alld've": "y'all'd've", + "y'all'dve": "y'all'd've", + "youd": "you'd", + "youd've": "you'd've", + "you'dve": "you'd've", + "youll": "you'll", + "youre": "you're", + "youve": "you've", + } + + NUMBER_MAP = { + "none": "0", + "zero": "0", + "one": "1", + "two": "2", + "three": "3", + "four": "4", + "five": "5", + "six": "6", + "seven": "7", + "eight": "8", + "nine": "9", + "ten": "10", + } + ARTICLES = ["a", "an", "the"] + PERIOD_STRIP = re.compile(r"(?!<=\d)(\.)(?!\d)") + COMMA_STRIP = re.compile(r"(?<=\d)(\,)+(?=\d)") + PUNCTUATIONS = [ + ";", + r"/", + "[", + "]", + '"', + "{", + "}", + "(", + ")", + "=", + "+", + "\\", + "_", + "-", + ">", + "<", + "@", + "`", + ",", + "?", + "!", + ] + + def __init__(self, *args, **kwargs): + pass + + def word_tokenize(self, word): + word = word.lower() + word = word.replace(",", "").replace("?", "").replace("'s", " 's") + return word.strip() + + def process_punctuation(self, in_text): + out_text = in_text + for p in self.PUNCTUATIONS: + if (p + " " in in_text or " " + p in in_text) or ( + re.search(self.COMMA_STRIP, in_text) is not None + ): + out_text = out_text.replace(p, "") + else: + out_text = out_text.replace(p, " ") + out_text = self.PERIOD_STRIP.sub("", out_text, re.UNICODE) + return out_text + + def process_digit_article(self, in_text): + out_text = [] + temp_text = in_text.lower().split() + for word in temp_text: + word = self.NUMBER_MAP.setdefault(word, word) + if word not in self.ARTICLES: + out_text.append(word) + else: + pass + for word_id, word in enumerate(out_text): + if word in self.CONTRACTIONS: + out_text[word_id] = self.CONTRACTIONS[word] + out_text = " ".join(out_text) + return out_text + + def __call__(self, item): + item = self.word_tokenize(item) + item = item.replace("\n", " ").replace("\t", " ").strip() + item = self.process_punctuation(item) + item = self.process_digit_article(item) + return item + + +class TextVQAAccuracyEvaluator: + def __init__(self): + self.answer_processor = EvalAIAnswerProcessor() + + def _compute_answer_scores(self, raw_answers): + """ + compute the accuracy (soft score) of human answers + """ + answers = [self.answer_processor(a) for a in raw_answers] + assert len(answers) == 10 + gt_answers = list(enumerate(answers)) + unique_answers = set(answers) + unique_answer_scores = {} + + for unique_answer in unique_answers: + accs = [] + for gt_answer in gt_answers: + other_answers = [item for item in gt_answers if item != gt_answer] + matching_answers = [ + item for item in other_answers if item[1] == unique_answer + ] + acc = min(1, float(len(matching_answers)) / 3) + accs.append(acc) + unique_answer_scores[unique_answer] = sum(accs) / len(accs) + + return unique_answer_scores + + def eval_pred_list(self, pred_list): + pred_scores = [] + for entry in tqdm(pred_list): + pred_answer = self.answer_processor(entry["pred_answer"]) + unique_answer_scores = self._compute_answer_scores(entry["gt_answers"]) + score = unique_answer_scores.get(pred_answer, 0.0) + pred_scores.append(score) + + accuracy = sum(pred_scores) / len(pred_scores) + return accuracy + + +class STVQAAccuracyEvaluator: + def __init__(self): + self.answer_processor = EvalAIAnswerProcessor() + + def eval_pred_list(self, pred_list): + pred_scores = [] + for entry in pred_list: + pred_answer = self.answer_processor(entry["pred_answer"]) + gts = [self.answer_processor(a) for a in entry["gt_answers"]] + score = 1.0 if pred_answer in gts else 0.0 + pred_scores.append(score) + + accuracy = sum(pred_scores) / len(pred_scores) + return accuracy + + +class STVQAANLSEvaluator: + def __init__(self): + import editdistance # install with `pip install editdistance` + + self.get_edit_distance = editdistance.eval + + def get_anls(self, s1, s2): + s1 = s1.lower().strip() + s2 = s2.lower().strip() + iou = 1 - self.get_edit_distance(s1, s2) / max(len(s1), len(s2)) + anls = iou if iou >= 0.5 else 0.0 + return anls + + def eval_pred_list(self, pred_list): + pred_scores = [] + for entry in pred_list: + anls = max( + self.get_anls(entry["pred_answer"], gt) for gt in entry["gt_answers"] + ) + pred_scores.append(anls) + + accuracy = sum(pred_scores) / len(pred_scores) + return accuracy + + +class TextCapsBleu4Evaluator: + def __init__(self): + # The following script requires Java 1.8.0 and pycocotools installed. + # The pycocoevalcap can be installed with pip as + # pip install git+https://github.com/ronghanghu/coco-caption.git@python23 + # Original pycocoevalcap code is at https://github.com/tylin/coco-caption + # but has no python3 support yet. + try: + from pycocoevalcap.bleu.bleu import Bleu + from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer + except ModuleNotFoundError: + print( + "Please install pycocoevalcap module using " + "pip install git+https://github.com/ronghanghu/coco-caption.git@python23" # noqa + ) + raise + + self.tokenizer = PTBTokenizer() + self.scorer = Bleu(4) + + def eval_pred_list(self, pred_list): + # Create reference and hypotheses captions. + gts = {} + res = {} + for idx, entry in enumerate(pred_list): + gts[idx] = [{"caption": a} for a in entry["gt_answers"]] + res[idx] = [{"caption": entry["pred_answer"]}] + + gts = self.tokenizer.tokenize(gts) + res = self.tokenizer.tokenize(res) + score, _ = self.scorer.compute_score(gts, res) + + bleu4 = score[3] # score is (Bleu-1, Bleu-2, Bleu-3, Bleu-4) + return bleu4 diff --git a/llava/eval/model_vqa.py b/llava/eval/model_vqa.py index 6c02a617c..59dca734c 100644 --- a/llava/eval/model_vqa.py +++ b/llava/eval/model_vqa.py @@ -66,7 +66,7 @@ def eval_model(args): output_ids = model.generate( input_ids, images=image_tensor.unsqueeze(0).half().cuda(), - do_sample=True, + do_sample=True if args.temperature > 0 else False, temperature=args.temperature, top_p=args.top_p, num_beams=args.num_beams, diff --git a/llava/eval/model_vqa_loader.py b/llava/eval/model_vqa_loader.py new file mode 100644 index 000000000..6e28f17cd --- /dev/null +++ b/llava/eval/model_vqa_loader.py @@ -0,0 +1,144 @@ +import argparse +import torch +import os +import json +from tqdm import tqdm +import shortuuid + +from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN +from llava.conversation import conv_templates, SeparatorStyle +from llava.model.builder import load_pretrained_model +from llava.utils import disable_torch_init +from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path +from torch.utils.data import Dataset, DataLoader + +from PIL import Image +import math + + +def split_list(lst, n): + """Split a list into n (roughly) equal-sized chunks""" + chunk_size = math.ceil(len(lst) / n) # integer division + return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)] + + +def get_chunk(lst, n, k): + chunks = split_list(lst, n) + return chunks[k] + + +# Custom dataset class +class CustomDataset(Dataset): + def __init__(self, questions, image_folder, tokenizer, image_processor, model_config): + self.questions = questions + self.image_folder = image_folder + self.tokenizer = tokenizer + self.image_processor = image_processor + self.model_config = model_config + + def __getitem__(self, index): + line = self.questions[index] + image_file = line["image"] + qs = line["text"] + if self.model_config.mm_use_im_start_end: + qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs + else: + qs = DEFAULT_IMAGE_TOKEN + '\n' + qs + + conv = conv_templates[args.conv_mode].copy() + conv.append_message(conv.roles[0], qs) + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + + image = Image.open(os.path.join(self.image_folder, image_file)).convert('RGB') + image_tensor = process_images([image], self.image_processor, self.model_config)[0] + + input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt') + + return input_ids, image_tensor + + def __len__(self): + return len(self.questions) + + +# DataLoader +def create_data_loader(questions, image_folder, tokenizer, image_processor, model_config, batch_size=1, num_workers=4): + assert batch_size == 1, "batch_size must be 1" + dataset = CustomDataset(questions, image_folder, tokenizer, image_processor, model_config) + data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False) + return data_loader + + +def eval_model(args): + # Model + disable_torch_init() + model_path = os.path.expanduser(args.model_path) + model_name = get_model_name_from_path(model_path) + tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name) + + questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")] + questions = get_chunk(questions, args.num_chunks, args.chunk_idx) + answers_file = os.path.expanduser(args.answers_file) + os.makedirs(os.path.dirname(answers_file), exist_ok=True) + ans_file = open(answers_file, "w") + + if 'plain' in model_name and 'finetune' not in model_name.lower() and 'mmtag' not in args.conv_mode: + args.conv_mode = args.conv_mode + '_mmtag' + print(f'It seems that this is a plain model, but it is not using a mmtag prompt, auto switching to {args.conv_mode}.') + + data_loader = create_data_loader(questions, args.image_folder, tokenizer, image_processor, model.config) + + for (input_ids, image_tensor), line in tqdm(zip(data_loader, questions), total=len(questions)): + idx = line["question_id"] + cur_prompt = line["text"] + + stop_str = conv_templates[args.conv_mode].sep if conv_templates[args.conv_mode].sep_style != SeparatorStyle.TWO else conv_templates[args.conv_mode].sep2 + input_ids = input_ids.to(device='cuda', non_blocking=True) + + with torch.inference_mode(): + output_ids = model.generate( + input_ids, + images=image_tensor.to(dtype=torch.float16, device='cuda', non_blocking=True), + do_sample=True if args.temperature > 0 else False, + temperature=args.temperature, + top_p=args.top_p, + num_beams=args.num_beams, + max_new_tokens=128, + use_cache=True) + + input_token_len = input_ids.shape[1] + n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() + if n_diff_input_output > 0: + print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids') + outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0] + outputs = outputs.strip() + if outputs.endswith(stop_str): + outputs = outputs[:-len(stop_str)] + outputs = outputs.strip() + + ans_id = shortuuid.uuid() + ans_file.write(json.dumps({"question_id": idx, + "prompt": cur_prompt, + "text": outputs, + "answer_id": ans_id, + "model_id": model_name, + "metadata": {}}) + "\n") + # ans_file.flush() + ans_file.close() + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model-path", type=str, default="facebook/opt-350m") + parser.add_argument("--model-base", type=str, default=None) + parser.add_argument("--image-folder", type=str, default="") + parser.add_argument("--question-file", type=str, default="tables/question.jsonl") + parser.add_argument("--answers-file", type=str, default="answer.jsonl") + parser.add_argument("--conv-mode", type=str, default="llava_v1") + parser.add_argument("--num-chunks", type=int, default=1) + parser.add_argument("--chunk-idx", type=int, default=0) + parser.add_argument("--temperature", type=float, default=0.2) + parser.add_argument("--top_p", type=float, default=None) + parser.add_argument("--num_beams", type=int, default=1) + args = parser.parse_args() + + eval_model(args) diff --git a/llava/eval/model_vqa_mmbench.py b/llava/eval/model_vqa_mmbench.py new file mode 100644 index 000000000..2ffec1b59 --- /dev/null +++ b/llava/eval/model_vqa_mmbench.py @@ -0,0 +1,170 @@ +import argparse +import torch +import os +import json +import pandas as pd +from tqdm import tqdm +import shortuuid + +from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN +from llava.conversation import conv_templates, SeparatorStyle +from llava.model.builder import load_pretrained_model +from llava.utils import disable_torch_init +from llava.mm_utils import tokenizer_image_token, process_images, load_image_from_base64, get_model_name_from_path + +from PIL import Image +import math + + +all_options = ['A', 'B', 'C', 'D'] + + +def split_list(lst, n): + """Split a list into n (roughly) equal-sized chunks""" + chunk_size = math.ceil(len(lst) / n) # integer division + return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)] + + +def get_chunk(lst, n, k): + chunks = split_list(lst, n) + return chunks[k] + + +def is_none(value): + if value is None: + return True + if type(value) is float and math.isnan(value): + return True + if type(value) is str and value.lower() == 'nan': + return True + if type(value) is str and value.lower() == 'none': + return True + return False + +def get_options(row, options): + parsed_options = [] + for option in options: + option_value = row[option] + if is_none(option_value): + break + parsed_options.append(option_value) + return parsed_options + + +def eval_model(args): + # Model + disable_torch_init() + model_path = os.path.expanduser(args.model_path) + model_name = get_model_name_from_path(model_path) + tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name) + + questions = pd.read_table(os.path.expanduser(args.question_file)) + questions = get_chunk(questions, args.num_chunks, args.chunk_idx) + answers_file = os.path.expanduser(args.answers_file) + os.makedirs(os.path.dirname(answers_file), exist_ok=True) + ans_file = open(answers_file, "w") + + if 'plain' in model_name and 'finetune' not in model_name.lower() and 'mmtag' not in args.conv_mode: + args.conv_mode = args.conv_mode + '_mmtag' + print(f'It seems that this is a plain model, but it is not using a mmtag prompt, auto switching to {args.conv_mode}.') + + for index, row in tqdm(questions.iterrows(), total=len(questions)): + options = get_options(row, all_options) + cur_option_char = all_options[:len(options)] + + if args.all_rounds: + num_rounds = len(options) + else: + num_rounds = 1 + + for round_idx in range(num_rounds): + idx = row['index'] + question = row['question'] + hint = row['hint'] + image = load_image_from_base64(row['image']) + if not is_none(hint): + question = hint + '\n' + question + for option_char, option in zip(all_options[:len(options)], options): + question = question + '\n' + option_char + '. ' + option + qs = cur_prompt = question + if model.config.mm_use_im_start_end: + qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs + else: + qs = DEFAULT_IMAGE_TOKEN + '\n' + qs + + if args.single_pred_prompt: + if args.lang == 'cn': + qs = qs + '\n' + "请直接回答选项字母。" + else: + qs = qs + '\n' + "Answer with the option's letter from the given choices directly." + + conv = conv_templates[args.conv_mode].copy() + conv.append_message(conv.roles[0], qs) + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + + input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() + + image_tensor = process_images([image], image_processor, model.config)[0] + # image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] + + stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 + + with torch.inference_mode(): + output_ids = model.generate( + input_ids, + images=image_tensor.unsqueeze(0).half().cuda(), + do_sample=True if args.temperature > 0 else False, + temperature=args.temperature, + top_p=args.top_p, + num_beams=args.num_beams, + # no_repeat_ngram_size=3, + max_new_tokens=1024, + use_cache=True) + + input_token_len = input_ids.shape[1] + n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() + if n_diff_input_output > 0: + print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids') + outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0] + outputs = outputs.strip() + if outputs.endswith(stop_str): + outputs = outputs[:-len(stop_str)] + outputs = outputs.strip() + + ans_id = shortuuid.uuid() + ans_file.write(json.dumps({"question_id": idx, + "round_id": round_idx, + "prompt": cur_prompt, + "text": outputs, + "options": options, + "option_char": cur_option_char, + "answer_id": ans_id, + "model_id": model_name, + "metadata": {}}) + "\n") + ans_file.flush() + + # rotate options + options = options[1:] + options[:1] + cur_option_char = cur_option_char[1:] + cur_option_char[:1] + ans_file.close() + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model-path", type=str, default="facebook/opt-350m") + parser.add_argument("--model-base", type=str, default=None) + parser.add_argument("--image-folder", type=str, default="") + parser.add_argument("--question-file", type=str, default="tables/question.jsonl") + parser.add_argument("--answers-file", type=str, default="answer.jsonl") + parser.add_argument("--conv-mode", type=str, default="llava_v1") + parser.add_argument("--num-chunks", type=int, default=1) + parser.add_argument("--chunk-idx", type=int, default=0) + parser.add_argument("--temperature", type=float, default=0.2) + parser.add_argument("--top_p", type=float, default=None) + parser.add_argument("--num_beams", type=int, default=1) + parser.add_argument("--all-rounds", action="store_true") + parser.add_argument("--single-pred-prompt", action="store_true") + parser.add_argument("--lang", type=str, default="en") + args = parser.parse_args() + + eval_model(args) diff --git a/llava/eval/model_vqa_science.py b/llava/eval/model_vqa_science.py index aa77b39c0..e99501f2b 100644 --- a/llava/eval/model_vqa_science.py +++ b/llava/eval/model_vqa_science.py @@ -57,6 +57,10 @@ def eval_model(args): else: images = None + if args.single_pred_prompt: + qs = qs + '\n' + "Answer with the option's letter from the given choices directly." + cur_prompt = cur_prompt + '\n' + "Answer with the option's letter from the given choices directly." + conv = conv_templates[args.conv_mode].copy() conv.append_message(conv.roles[0], qs) conv.append_message(conv.roles[1], None) @@ -72,8 +76,8 @@ def eval_model(args): output_ids = model.generate( input_ids, images=images, - do_sample=True, - temperature=0.2, + do_sample=True if args.temperature > 0 else False, + temperature=args.temperature, max_new_tokens=1024, use_cache=True, stopping_criteria=stopping_criteria, @@ -98,8 +102,8 @@ def eval_model(args): output_ids = model.generate( input_ids, images=images, - do_sample=True, - temperature=0.2, + do_sample=True if args.temperature > 0 else False, + temperature=args.temperature, max_new_tokens=64, use_cache=True, stopping_criteria=[stopping_criteria]) @@ -135,7 +139,9 @@ def eval_model(args): parser.add_argument("--conv-mode", type=str, default="llava_v0") parser.add_argument("--num-chunks", type=int, default=1) parser.add_argument("--chunk-idx", type=int, default=0) + parser.add_argument("--temperature", type=float, default=0.2) parser.add_argument("--answer-prompter", action="store_true") + parser.add_argument("--single-pred-prompt", action="store_true") args = parser.parse_args() eval_model(args) diff --git a/llava/eval/summarize_gpt_review.py b/llava/eval/summarize_gpt_review.py index ee26f0a51..0f796a388 100644 --- a/llava/eval/summarize_gpt_review.py +++ b/llava/eval/summarize_gpt_review.py @@ -9,8 +9,10 @@ def parse_args(): parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.') parser.add_argument('-d', '--dir', default=None) - parser.add_argument('-f', '--files', nargs='*', default=None) - parser.add_argument('-i', '--ignore', nargs='*', default=None) + parser.add_argument('-v', '--version', default=None) + parser.add_argument('-s', '--select', nargs='*', default=None) + parser.add_argument('-f', '--files', nargs='*', default=[]) + parser.add_argument('-i', '--ignore', nargs='*', default=[]) return parser.parse_args() @@ -20,19 +22,27 @@ def parse_args(): if args.ignore is not None: args.ignore = [int(x) for x in args.ignore] - if args.files is not None and len(args.files) > 0: + if len(args.files) > 0: review_files = args.files else: - review_files = [x for x in os.listdir(args.dir) if x.endswith('.jsonl') and (x.startswith('gpt4_text') or x.startswith('reviews_') or x.startswith('review_'))] + review_files = [x for x in os.listdir(args.dir) if x.endswith('.jsonl') and (x.startswith('gpt4_text') or x.startswith('reviews_') or x.startswith('review_') or 'review' in args.dir)] for review_file in sorted(review_files): config = os.path.basename(review_file).replace('gpt4_text_', '').replace('.jsonl', '') + if args.select is not None and any(x not in config for x in args.select): + continue + if '0613' in config: + version = '0613' + else: + version = '0314' + if args.version is not None and args.version != version: + continue scores = defaultdict(list) print(config) with open(os.path.join(args.dir, review_file) if args.dir is not None else review_file) as f: for review_str in f: review = json.loads(review_str) - if args.ignore is not None and review['question_id'] in args.ignore: + if review['question_id'] in args.ignore: continue if 'category' in review: scores[review['category']].append(review['tuple']) @@ -46,5 +56,5 @@ def parse_args(): stats = np.asarray(v).mean(0).tolist() stats = [round(x, 3) for x in stats] # print(k, stats, round(stats[1]/stats[0]*100, 1)) - print(k, round(stats[1]/stats[0]*100, 1)) + print(k, round(stats[1]/stats[0]*100, 1), round(stats[0] * 10, 1), round(stats[1] * 10, 1)) print('=================================') diff --git a/scripts/convert_gqa_for_eval.py b/scripts/convert_gqa_for_eval.py new file mode 100644 index 000000000..4d46c8b87 --- /dev/null +++ b/scripts/convert_gqa_for_eval.py @@ -0,0 +1,18 @@ +import os +import json +import argparse + +parser = argparse.ArgumentParser() +parser.add_argument("--src", type=str) +parser.add_argument("--dst", type=str) +args = parser.parse_args() + +all_answers = [] +for line_idx, line in enumerate(open(args.src)): + res = json.loads(line) + question_id = res['question_id'] + text = res['text'].rstrip('.').lower() + all_answers.append({"questionId": question_id, "prediction": text}) + +with open(args.dst, 'w') as f: + json.dump(all_answers, f) diff --git a/scripts/convert_mmbench_for_submission.py b/scripts/convert_mmbench_for_submission.py new file mode 100644 index 000000000..27baec12f --- /dev/null +++ b/scripts/convert_mmbench_for_submission.py @@ -0,0 +1,27 @@ +import os +import json +import argparse +import pandas as pd + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--annotation-file", type=str, required=True) + parser.add_argument("--result-dir", type=str, required=True) + parser.add_argument("--upload-dir", type=str, required=True) + parser.add_argument("--experiment", type=str, required=True) + + return parser.parse_args() + +if __name__ == "__main__": + args = get_args() + + df = pd.read_table(args.annotation_file) + + cur_df = df.copy() + cur_df = cur_df.drop(columns=['hint', 'category', 'source', 'image', 'comment', 'l2-category']) + cur_df.insert(6, 'prediction', None) + for pred in open(os.path.join(args.result_dir, f"{args.experiment}.jsonl")): + pred = json.loads(pred) + cur_df.loc[df['index'] == pred['question_id'], 'prediction'] = pred['text'] + + cur_df.to_excel(os.path.join(args.upload_dir, f"{args.experiment}.xlsx"), index=False, engine='openpyxl') diff --git a/scripts/convert_mmvet_for_eval.py b/scripts/convert_mmvet_for_eval.py new file mode 100644 index 000000000..97f5cfb7f --- /dev/null +++ b/scripts/convert_mmvet_for_eval.py @@ -0,0 +1,18 @@ +import os +import json +import argparse + +parser = argparse.ArgumentParser() +parser.add_argument("--src", type=str) +parser.add_argument("--dst", type=str) +args = parser.parse_args() + +cur_result = {} + +for line in open(args.src): + data = json.loads(line) + qid = data['question_id'] + cur_result[f'v1_{qid}'] = data['text'] + +with open(args.dst, 'w') as f: + json.dump(cur_result, f, indent=2) diff --git a/scripts/convert_seed_for_submission.py b/scripts/convert_seed_for_submission.py new file mode 100644 index 000000000..ae903e630 --- /dev/null +++ b/scripts/convert_seed_for_submission.py @@ -0,0 +1,74 @@ +import os +import json +import argparse + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--annotation-file", type=str) + parser.add_argument("--result-file", type=str) + parser.add_argument("--result-upload-file", type=str) + return parser.parse_args() + + +def eval_single(result_file, eval_only_type=None): + results = {} + for line in open(result_file): + row = json.loads(line) + results[row['question_id']] = row + + type_counts = {} + correct_counts = {} + for question_data in data['questions']: + if eval_only_type is not None and question_data['data_type'] != eval_only_type: continue + data_type = question_data['question_type_id'] + type_counts[data_type] = type_counts.get(data_type, 0) + 1 + try: + question_id = int(question_data['question_id']) + except: + question_id = question_data['question_id'] + if question_id not in results: + correct_counts[data_type] = correct_counts.get(data_type, 0) + continue + row = results[question_id] + if row['text'] == question_data['answer']: + correct_counts[data_type] = correct_counts.get(data_type, 0) + 1 + + total_count = 0 + total_correct = 0 + for data_type in sorted(type_counts.keys()): + accuracy = correct_counts[data_type] / type_counts[data_type] * 100 + if eval_only_type is None: + print(f"{ques_type_id_to_name[data_type]}: {accuracy:.2f}%") + + total_count += type_counts[data_type] + total_correct += correct_counts[data_type] + + total_accuracy = total_correct / total_count * 100 + if eval_only_type is None: + print(f"Total accuracy: {total_accuracy:.2f}%") + else: + print(f"{eval_only_type} accuracy: {total_accuracy:.2f}%") + + return results + +if __name__ == "__main__": + args = get_args() + data = json.load(open(args.annotation_file)) + ques_type_id_to_name = {id:n for n,id in data['question_type'].items()} + + results = eval_single(args.result_file) + eval_single(args.result_file, eval_only_type='image') + eval_single(args.result_file, eval_only_type='video') + + with open(args.result_upload_file, 'w') as fp: + for question in data['questions']: + qid = question['question_id'] + if qid in results: + result = results[qid] + else: + result = results[int(qid)] + fp.write(json.dumps({ + 'question_id': qid, + 'prediction': result['text'] + }) + '\n') diff --git a/scripts/convert_vizwiz_for_submission.py b/scripts/convert_vizwiz_for_submission.py new file mode 100644 index 000000000..7836d19f5 --- /dev/null +++ b/scripts/convert_vizwiz_for_submission.py @@ -0,0 +1,47 @@ +import os +import argparse +import json + +from llava.eval.m4c_evaluator import EvalAIAnswerProcessor + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--annotation-file', type=str, required=True) + parser.add_argument('--result-file', type=str, required=True) + parser.add_argument('--result-upload-file', type=str, required=True) + return parser.parse_args() + + +if __name__ == '__main__': + + args = parse_args() + + os.makedirs(os.path.dirname(args.result_upload_file), exist_ok=True) + + results = [] + error_line = 0 + for line_idx, line in enumerate(open(args.result_file)): + try: + results.append(json.loads(line)) + except: + error_line += 1 + results = {x['question_id']: x['text'] for x in results} + test_split = [json.loads(line) for line in open(args.annotation_file)] + split_ids = set([x['question_id'] for x in test_split]) + + print(f'total results: {len(results)}, total split: {len(test_split)}, error_line: {error_line}') + + all_answers = [] + + answer_processor = EvalAIAnswerProcessor() + + for x in test_split: + assert x['question_id'] in results + all_answers.append({ + 'image': x['image'], + 'answer': answer_processor(results[x['question_id']]) + }) + + with open(args.result_upload_file, 'w') as f: + json.dump(all_answers, f) diff --git a/scripts/convert_vqav2_for_submission.py b/scripts/convert_vqav2_for_submission.py new file mode 100644 index 000000000..05f67b33a --- /dev/null +++ b/scripts/convert_vqav2_for_submission.py @@ -0,0 +1,56 @@ +import os +import argparse +import json + +from llava.eval.m4c_evaluator import EvalAIAnswerProcessor + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--dir', type=str, default="./playground/data/eval/vqav2") + parser.add_argument('--ckpt', type=str, required=True) + parser.add_argument('--split', type=str, required=True) + return parser.parse_args() + + +if __name__ == '__main__': + + args = parse_args() + + src = os.path.join(args.dir, 'answers', args.split, args.ckpt, 'merge.jsonl') + test_split = os.path.join(args.dir, 'llava_vqav2_mscoco_test2015.jsonl') + dst = os.path.join(args.dir, 'answers_upload', args.split, f'{args.ckpt}.json') + os.makedirs(os.path.dirname(dst), exist_ok=True) + + results = [] + error_line = 0 + for line_idx, line in enumerate(open(src)): + try: + results.append(json.loads(line)) + except: + error_line += 1 + + results = {x['question_id']: x['text'] for x in results} + test_split = [json.loads(line) for line in open(test_split)] + split_ids = set([x['question_id'] for x in test_split]) + + print(f'total results: {len(results)}, total split: {len(test_split)}, error_line: {error_line}') + + all_answers = [] + + answer_processor = EvalAIAnswerProcessor() + + for x in test_split: + if x['question_id'] not in results: + all_answers.append({ + 'question_id': x['question_id'], + 'answer': '' + }) + else: + all_answers.append({ + 'question_id': x['question_id'], + 'answer': answer_processor(results[x['question_id']]) + }) + + with open(dst, 'w') as f: + json.dump(all_answers, open(dst, 'w')) diff --git a/scripts/v1_5/eval/gqa.sh b/scripts/v1_5/eval/gqa.sh new file mode 100644 index 000000000..5c3c2c31f --- /dev/null +++ b/scripts/v1_5/eval/gqa.sh @@ -0,0 +1,39 @@ +#!/bin/bash + +gpu_list="${CUDA_VISIBLE_DEVICES:-0}" +IFS=',' read -ra GPULIST <<< "$gpu_list" + +CHUNKS=${#GPULIST[@]} + +CKPT="llava-v1.5-13b" +SPLIT="llava_gqa_testdev_balanced" +GQADIR="./playground/data/eval/gqa/data" + +for IDX in $(seq 0 $((CHUNKS-1))); do + CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python -m llava.eval.model_vqa_loader \ + --model-path liuhaotian/llava-v1.5-13b \ + --question-file ./playground/data/eval/gqa/$SPLIT.jsonl \ + --image-folder ./playground/data/eval/gqa/data/images \ + --answers-file ./playground/data/eval/gqa/answers/$SPLIT/$CKPT/${CHUNKS}_${IDX}.jsonl \ + --num-chunks $CHUNKS \ + --chunk-idx $IDX \ + --temperature 0 \ + --conv-mode vicuna_v1 & +done + +wait + +output_file=./playground/data/eval/gqa/answers/$SPLIT/$CKPT/merge.jsonl + +# Clear out the output file if it exists. +> "$output_file" + +# Loop through the indices and concatenate each file. +for IDX in $(seq 0 $((CHUNKS-1))); do + cat ./playground/data/eval/gqa/answers/$SPLIT/$CKPT/${CHUNKS}_${IDX}.jsonl >> "$output_file" +done + +python scripts/convert_gqa_for_eval.py --src $output_file --dst $GQADIR/testdev_balanced_predictions.json + +cd $GQADIR +python eval/eval.py --tier testdev_balanced diff --git a/scripts/v1_5/eval/llavabench.sh b/scripts/v1_5/eval/llavabench.sh new file mode 100644 index 000000000..ed236e4e3 --- /dev/null +++ b/scripts/v1_5/eval/llavabench.sh @@ -0,0 +1,23 @@ +#!/bin/bash + +python -m llava.eval.model_vqa \ + --model-path liuhaotian/llava-v1.5-13b \ + --question-file ./playground/data/eval/llava-bench-in-the-wild/questions.jsonl \ + --image-folder ./playground/data/eval/llava-bench-in-the-wild/images \ + --answers-file ./playground/data/eval/llava-bench-in-the-wild/answers/llava-v1.5-13b.jsonl \ + --temperature 0 \ + --conv-mode vicuna_v1 + +mkdir -p playground/data/eval/llava-bench-in-the-wild/reviews + +python llava/eval/eval_gpt_review_bench.py \ + --question playground/data/eval/llava-bench-in-the-wild/questions.jsonl \ + --context playground/data/eval/llava-bench-in-the-wild/context.jsonl \ + --rule llava/eval/table/rule.json \ + --answer-list \ + playground/data/eval/llava-bench-in-the-wild/answers_gpt4.jsonl \ + playground/data/eval/llava-bench-in-the-wild/answers/llava-v1.5-13b.jsonl \ + --output \ + playground/data/eval/llava-bench-in-the-wild/reviews/llava-v1.5-13b.jsonl + +python llava/eval/summarize_gpt_review.py -f playground/data/eval/llava-bench-in-the-wild/reviews/llava-v1.5-13b.jsonl diff --git a/scripts/v1_5/eval/mmbench.sh b/scripts/v1_5/eval/mmbench.sh new file mode 100644 index 000000000..d0b3a5c63 --- /dev/null +++ b/scripts/v1_5/eval/mmbench.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +SPLIT="mmbench_dev_20230712" + +python -m llava.eval.model_vqa_mmbench \ + --model-path liuhaotian/llava-v1.5-13b \ + --question-file ./playground/data/eval/mmbench/$SPLIT.tsv \ + --answers-file ./playground/data/eval/mmbench/answers/$SPLIT/llava-v1.5-13b.jsonl \ + --single-pred-prompt \ + --temperature 0 \ + --conv-mode vicuna_v1 + +mkdir -p playground/data/eval/mmbench/answers_upload/$SPLIT + +python scripts/convert_mmbench_for_submission.py \ + --annotation-file ./playground/data/eval/mmbench/$SPLIT.tsv \ + --result-dir ./playground/data/eval/mmbench/answers/$SPLIT \ + --upload-dir ./playground/data/eval/mmbench/answers_upload/$SPLIT \ + --experiment llava-v1.5-13b diff --git a/scripts/v1_5/eval/mmbench_cn.sh b/scripts/v1_5/eval/mmbench_cn.sh new file mode 100644 index 000000000..ce27c93aa --- /dev/null +++ b/scripts/v1_5/eval/mmbench_cn.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +SPLIT="mmbench_dev_cn_20231003" + +python -m llava.eval.model_vqa_mmbench \ + --model-path liuhaotian/llava-v1.5-13b \ + --question-file ./playground/data/eval/mmbench_cn/$SPLIT.tsv \ + --answers-file ./playground/data/eval/mmbench_cn/answers/$SPLIT/llava-v1.5-13b.jsonl \ + --lang cn \ + --single-pred-prompt \ + --temperature 0 \ + --conv-mode vicuna_v1 + +mkdir -p playground/data/eval/mmbench/answers_upload/$SPLIT + +python scripts/convert_mmbench_for_submission.py \ + --annotation-file ./playground/data/eval/mmbench_cn/$SPLIT.tsv \ + --result-dir ./playground/data/eval/mmbench_cn/answers/$SPLIT \ + --upload-dir ./playground/data/eval/mmbench_cn/answers_upload/$SPLIT \ + --experiment llava-v1.5-13b diff --git a/scripts/v1_5/eval/mme.sh b/scripts/v1_5/eval/mme.sh new file mode 100644 index 000000000..9b0f8ca65 --- /dev/null +++ b/scripts/v1_5/eval/mme.sh @@ -0,0 +1,17 @@ +#!/bin/bash + +python -m llava.eval.model_vqa_loader \ + --model-path liuhaotian/llava-v1.5-13b \ + --question-file ./playground/data/eval/MME/llava_mme.jsonl \ + --image-folder ./playground/data/eval/MME/MME_Benchmark_release_version \ + --answers-file ./playground/data/eval/MME/answers/llava-v1.5-13b.jsonl \ + --temperature 0 \ + --conv-mode vicuna_v1 + +cd ./playground/data/eval/MME + +python convert_answer_to_mme.py --experiment llava-v1.5-13b + +cd eval_tool + +python calculation.py --results_dir answers/llava-v1.5-13b diff --git a/scripts/v1_5/eval/mmvet.sh b/scripts/v1_5/eval/mmvet.sh new file mode 100644 index 000000000..9ff31ed46 --- /dev/null +++ b/scripts/v1_5/eval/mmvet.sh @@ -0,0 +1,16 @@ +#!/bin/bash + +python -m llava.eval.model_vqa \ + --model-path liuhaotian/llava-v1.5-13b \ + --question-file ./playground/data/eval/mm-vet/llava-mm-vet.jsonl \ + --image-folder ./playground/data/eval/mm-vet/images \ + --answers-file ./playground/data/eval/mm-vet/answers/llava-v1.5-13b.jsonl \ + --temperature 0 \ + --conv-mode vicuna_v1 + +mkdir -p ./playground/data/eval/mm-vet/results + +python scripts/convert_mmvet_for_eval.py \ + --src ./playground/data/eval/mm-vet/answers/llava-v1.5-13b.jsonl \ + --dst ./playground/data/eval/mm-vet/results/llava-v1.5-13b.json + diff --git a/scripts/v1_5/eval/pope.sh b/scripts/v1_5/eval/pope.sh new file mode 100644 index 000000000..93fe449d9 --- /dev/null +++ b/scripts/v1_5/eval/pope.sh @@ -0,0 +1,14 @@ +#!/bin/bash + +python -m llava.eval.model_vqa_loader \ + --model-path liuhaotian/llava-v1.5-13b \ + --question-file ./playground/data/eval/pope/llava_pope_test.jsonl \ + --image-folder ./playground/data/eval/pope/val2014 \ + --answers-file ./playground/data/eval/pope/answers/llava-v1.5-13b.jsonl \ + --temperature 0 \ + --conv-mode vicuna_v1 + +python llava/eval/eval_pope.py \ + --annotation-dir ./playground/data/eval/pope/coco \ + --question-file ./playground/data/eval/pope/llava_pope_test.jsonl \ + --result-file ./playground/data/eval/pope/answers/llava-v1.5-13b.jsonl diff --git a/scripts/v1_5/eval/seed.sh b/scripts/v1_5/eval/seed.sh new file mode 100644 index 000000000..565e54d1d --- /dev/null +++ b/scripts/v1_5/eval/seed.sh @@ -0,0 +1,39 @@ +#!/bin/bash + +gpu_list="${CUDA_VISIBLE_DEVICES:-0}" +IFS=',' read -ra GPULIST <<< "$gpu_list" + +CHUNKS=${#GPULIST[@]} + +CKPT="llava-v1.5-13b" + +for IDX in $(seq 0 $((CHUNKS-1))); do + CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python -m llava.eval.model_vqa_loader \ + --model-path liuhaotian/llava-v1.5-13b \ + --question-file ./playground/data/eval/seed_bench/llava-seed-bench.jsonl \ + --image-folder ./playground/data/eval/seed_bench \ + --answers-file ./playground/data/eval/seed_bench/answers/$CKPT/${CHUNKS}_${IDX}.jsonl \ + --num-chunks $CHUNKS \ + --chunk-idx $IDX \ + --temperature 0 \ + --conv-mode vicuna_v1 & +done + +wait + +output_file=./playground/data/eval/seed_bench/answers/$CKPT/merge.jsonl + +# Clear out the output file if it exists. +> "$output_file" + +# Loop through the indices and concatenate each file. +for IDX in $(seq 0 $((CHUNKS-1))); do + cat ./playground/data/eval/seed_bench/answers/$CKPT/${CHUNKS}_${IDX}.jsonl >> "$output_file" +done + +# Evaluate +python scripts/convert_seed_for_submission.py \ + --annotation-file ./playground/data/eval/seed_bench/SEED-Bench.json \ + --result-file $output_file \ + --result-upload-file ./playground/data/eval/seed_bench/answers_upload/llava-v1.5-13b.jsonl + diff --git a/scripts/v1_5/eval/sqa.sh b/scripts/v1_5/eval/sqa.sh new file mode 100644 index 000000000..8c82dbc25 --- /dev/null +++ b/scripts/v1_5/eval/sqa.sh @@ -0,0 +1,16 @@ +#!/bin/bash + +python -m llava.eval.model_vqa_science \ + --model-path liuhaotian/llava-v1.5-13b \ + --question-file ./playground/data/eval/scienceqa/llava_test_CQM-A.json \ + --image-folder ./playground/data/eval/scienceqa/images/test \ + --answers-file ./playground/data/eval/scienceqa/answers/llava-v1.5-13b.jsonl \ + --single-pred-prompt \ + --temperature 0 \ + --conv-mode vicuna_v1 + +python llava/eval/eval_science_qa.py \ + --base-dir ./playground/data/eval/scienceqa \ + --result-file ./playground/data/eval/scienceqa/answers/llava-v1.5-13b.jsonl \ + --output-file ./playground/data/eval/scienceqa/answers/llava-v1.5-13b_output.jsonl \ + --output-result ./playground/data/eval/scienceqa/answers/llava-v1.5-13b_result.json diff --git a/scripts/v1_5/eval/textvqa.sh b/scripts/v1_5/eval/textvqa.sh new file mode 100644 index 000000000..12311c3cc --- /dev/null +++ b/scripts/v1_5/eval/textvqa.sh @@ -0,0 +1,13 @@ +#!/bin/bash + +python -m llava.eval.model_vqa_loader \ + --model-path liuhaotian/llava-v1.5-13b \ + --question-file ./playground/data/eval/textvqa/llava_textvqa_val_v051_ocr.jsonl \ + --image-folder ./playground/data/eval/textvqa/train_images \ + --answers-file ./playground/data/eval/textvqa/answers/llava-v1.5-13b.jsonl \ + --temperature 0 \ + --conv-mode vicuna_v1 + +python -m llava.eval.eval_textvqa \ + --annotation-file ./playground/data/eval/textvqa/TextVQA_0.5.1_val.json \ + --result-file ./playground/data/eval/textvqa/answers/llava-v1.5-13b.jsonl diff --git a/scripts/v1_5/eval/vizwiz.sh b/scripts/v1_5/eval/vizwiz.sh new file mode 100644 index 000000000..16cf35ce1 --- /dev/null +++ b/scripts/v1_5/eval/vizwiz.sh @@ -0,0 +1,14 @@ +#!/bin/bash + +python -m llava.eval.model_vqa_loader \ + --model-path liuhaotian/llava-v1.5-13b \ + --question-file ./playground/data/eval/vizwiz/llava_test.jsonl \ + --image-folder ./playground/data/eval/vizwiz/test \ + --answers-file ./playground/data/eval/vizwiz/answers/llava-v1.5-13b.jsonl \ + --temperature 0 \ + --conv-mode vicuna_v1 + +python scripts/convert_vizwiz_for_submission.py \ + --annotation-file ./playground/data/eval/vizwiz/llava_test.jsonl \ + --result-file ./playground/data/eval/vizwiz/answers/llava-v1.5-13b.jsonl \ + --result-upload-file ./playground/data/eval/vizwiz/answers_upload/llava-v1.5-13b.json diff --git a/scripts/v1_5/eval/vqav2.sh b/scripts/v1_5/eval/vqav2.sh new file mode 100644 index 000000000..696efe533 --- /dev/null +++ b/scripts/v1_5/eval/vqav2.sh @@ -0,0 +1,36 @@ +#!/bin/bash + +gpu_list="${CUDA_VISIBLE_DEVICES:-0}" +IFS=',' read -ra GPULIST <<< "$gpu_list" + +CHUNKS=${#GPULIST[@]} + +CKPT="llava-v1.5-13b" +SPLIT="llava_vqav2_mscoco_test-dev2015" + +for IDX in $(seq 0 $((CHUNKS-1))); do + CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python -m llava.eval.model_vqa_loader \ + --model-path liuhaotian/llava-v1.5-13b \ + --question-file ./playground/data/eval/vqav2/$SPLIT.jsonl \ + --image-folder ./playground/data/eval/vqav2/test2015 \ + --answers-file ./playground/data/eval/vqav2/answers/$SPLIT/$CKPT/${CHUNKS}_${IDX}.jsonl \ + --num-chunks $CHUNKS \ + --chunk-idx $IDX \ + --temperature 0 \ + --conv-mode vicuna_v1 & +done + +wait + +output_file=./playground/data/eval/vqav2/answers/$SPLIT/$CKPT/merge.jsonl + +# Clear out the output file if it exists. +> "$output_file" + +# Loop through the indices and concatenate each file. +for IDX in $(seq 0 $((CHUNKS-1))); do + cat ./playground/data/eval/vqav2/answers/$SPLIT/$CKPT/${CHUNKS}_${IDX}.jsonl >> "$output_file" +done + +python scripts/convert_vqav2_for_submission.py --split $SPLIT --ckpt $CKPT + From 1c8418076bd3f3182c73297ae7d39d0c8bda210c Mon Sep 17 00:00:00 2001 From: Haotian Liu Date: Wed, 11 Oct 2023 17:00:09 -0700 Subject: [PATCH 07/12] Release evaluation scripts. --- docs/Evaluation.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/Evaluation.md b/docs/Evaluation.md index 899dfe8e9..cb14133fb 100644 --- a/docs/Evaluation.md +++ b/docs/Evaluation.md @@ -72,7 +72,7 @@ CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/sqa.sh ### TextVQA -1. Download [`TextVQA_0.5.1_val.json](https://dl.fbaipublicfiles.com/textvqa/data/TextVQA_0.5.1_val.json) and [images](https://dl.fbaipublicfiles.com/textvqa/images/train_val_images.zip) and extract to `./playground/data/eval/textvqa`. +1. Download [`TextVQA_0.5.1_val.json`](https://dl.fbaipublicfiles.com/textvqa/data/TextVQA_0.5.1_val.json) and [images](https://dl.fbaipublicfiles.com/textvqa/images/train_val_images.zip) and extract to `./playground/data/eval/textvqa`. 2. Single-GPU inference and evaluate. ```Shell CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/textvqa.sh From bb4b22748618fb6cbec35d8d4642fbbd32dec76a Mon Sep 17 00:00:00 2001 From: Haotian Liu Date: Wed, 11 Oct 2023 17:00:29 -0700 Subject: [PATCH 08/12] Bump version to v1.1.1 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 740f0dd41..4d2f34775 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "llava" -version = "1.1.0" +version = "1.1.1" description = "Towards GPT-4 like large language and visual assistant." readme = "README.md" requires-python = ">=3.8" From bc78774e50852718a782ead510116c6d7f432357 Mon Sep 17 00:00:00 2001 From: Haotian Liu Date: Wed, 11 Oct 2023 17:04:40 -0700 Subject: [PATCH 09/12] Release evaluation scripts. --- llava/eval/eval_mmbench.py | 226 ------------------------------------- 1 file changed, 226 deletions(-) delete mode 100644 llava/eval/eval_mmbench.py diff --git a/llava/eval/eval_mmbench.py b/llava/eval/eval_mmbench.py deleted file mode 100644 index c4205d61b..000000000 --- a/llava/eval/eval_mmbench.py +++ /dev/null @@ -1,226 +0,0 @@ -import argparse -import os -import json -import pandas as pd -from tqdm import tqdm -import openai -from concurrent.futures import ThreadPoolExecutor, as_completed -import math -import time - - -all_options = ['A', 'B', 'C', 'D'] - - -def split_list(lst, n): - """Split a list into n (roughly) equal-sized chunks""" - chunk_size = math.ceil(len(lst) / n) # integer division - return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)] - - -def get_chunk(lst, n, k): - chunks = split_list(lst, n) - return chunks[k] - - -def get_row(df, colname, value): - assert (df[colname] == value).sum() == 1 - return df[df[colname] == value].iloc[0] - - -def encode_query(question, options, answer): - query = "" - query += "Question: " + question + "\n" - query += "Options: " + "\n".join([f"{option_char}. {option}" for option_char, option in zip(all_options[:len(options)], options)]) + "\n" - query += "Answer: " + answer + "\n" - return query - - -def get_openai_api(): - api_type = os.environ.get('API_TYPE', 'azure') - - if api_type == 'azure': - api_key = os.environ.get('API_KEY', 'sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx') - engine = os.environ.get('ENGINE', 'chatgpt-turbo') - api_host = os.environ.get('API_BASE') - return { - 'api_type': 'azure', - 'api_version': '2023-06-01-preview', - 'engine': engine, - 'api_key': api_key, - 'api_base': f'https://{api_host}.openai.azure.com', - } - else: - api_key = os.environ.get('API_KEY', 'sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx') - model = os.environ.get('MODEL', 'gpt-3.5-turbo-0301') - - return { - 'model': model, - 'api_key': api_key, - } - - -def chatgpt_extract_answer( - question, options, answer, max_tokens=64, temperature=0.2, top_p=0.9, frequency_penalty=0, presence_penalty=0, - request_timeout=None, num_retry=1): - api_kwargs = get_openai_api() - - system_message = """You are an AI assistant to help me matching an answer with several options of a multiple choice question. -You are provided with a question, several options, and an answer, and you need to find which option is most similar to the answer. -If the meaning of all options are significantly different from the answer, output X. -You should output a single uppercase character in A, B, C, D, if they are valid options, and X otherwise.""" - exemplers = [ - { - "question": "What is the main object in image?", - "options": ["teddy bear", "rabbit", "cat", "dog"], - "answer": "a cute teddy bear", - "output": "A", - }, - { - "question": "What is the main object in image?", - "options": ["teddy bear", "rabbit", "cat", "dog"], - "answer": "Spider", - "output": "X", - }, - ] - - messages = [ - {"role": "system", "content": system_message}, - ] - for exempler in exemplers: - messages.append({"role": "user", "content": encode_query(exempler['question'], exempler['options'], exempler['answer'])}) - messages.append({"role": "assistant", "content": exempler['output']}) - messages.append({"role": "user", "content": encode_query(question, options, answer)}) - - response = None - attempts = [] - for i in range(num_retry): - try: - response = openai.ChatCompletion.create( - messages = messages, - max_tokens = max_tokens, - temperature = temperature, - top_p = top_p, - frequency_penalty = frequency_penalty, - presence_penalty = presence_penalty, - request_timeout = request_timeout, - **api_kwargs - ) - except Exception as e: - if type(e) in [openai.error.RateLimitError, openai.error.APIError, openai.error.APIConnectionError, openai.error.Timeout]: - pass - elif type(e) in [openai.error.AuthenticationError, openai.error.InvalidRequestError]: - print(e) - return None - else: - print(type(e), e) - attempts.append(e.__class__.__name__) - time.sleep(1) - else: - time.sleep(1) - break - - if response is None: - print(f'All {num_retry} attempts failed: {attempts}. Returning None.') - return None - - content = response['choices'][0]['message']['content'] - content = content.strip() - return content - -def is_none(value): - if value is None: - return True - if type(value) is float and math.isnan(value): - return True - if type(value) is str and value.lower() == 'nan': - return True - if type(value) is str and value.lower() == 'none': - return True - return False - -def get_options(row, options): - parsed_options = [] - for option in options: - option_value = row[option] - if is_none(option_value): - break - parsed_options.append(option_value) - return parsed_options - -def auto_parse_answer(question, options, answer): - if answer.strip('.').strip().upper() in all_options[:len(options)]: - return answer.strip('.').strip().upper() - expand_option_valid = [f'The answer is {option}.'.lower() in answer.lower() for option in all_options[:len(options)]] - if any(expand_option_valid): - return all_options[expand_option_valid.index(True)] - - matched_ops = [all_options[_i] for _i, option in enumerate(options) if answer.lower() in option.lower()] - if len(matched_ops) == 1: - return matched_ops[0] - return None - -def eval_results(args): - questions = pd.read_table(os.path.expanduser(args.question_file)) - answers = [json.loads(line) for line in open(os.path.expanduser(args.answers_file))] - answers = {(row['question_id'], row.get('round_id', 0)): row for row in answers} - results_file = os.path.expanduser(args.results_file) - if os.path.exists(results_file): - results = [json.loads(line) for line in open(results_file)] - results = {(row['question_id'], row.get('round_id', 0)): row for row in results} - else: - results = {} - results_writer = open(results_file, 'a') - - def process_answer(idx, answer): - if idx in results: - return None - question_id, round_id = idx - question_data = get_row(questions, 'index', question_id) - if 'options' in answer: - options = answer['options'] - option_char = answer['option_char'] - else: - assert round_id == 0, "round_id must be 0 when options are not provided" - options = get_options(question_data, all_options) - option_char = all_options[:len(options)] - option_map = {all_options[i]: option_char[i] for i in range(len(options))} - option_map['X'] = 'X' - parsed_answer = auto_parse_answer(question_data['question'], options, answer['text']) - if parsed_answer is None: - parsed_answer = chatgpt_extract_answer( - question_data['question'], options, answer['text'], - request_timeout=args.request_timeout, num_retry=args.num_retry) - if parsed_answer is None: - return None - if parsed_answer not in option_map: - print(f'Invalid parsed answer: {parsed_answer}') - return None - answer['parsed_answer'] = option_map[parsed_answer] - return answer - - with ThreadPoolExecutor(max_workers=args.max_workers) as executor: - # Submit all tasks to the executor - futures = {executor.submit(process_answer, key, value): key for key, value in answers.items()} - - # Process results as they become available - for future in tqdm(as_completed(futures), total=len(answers)): - answer = future.result() - if answer is not None: - results_writer.write(json.dumps(answer) + '\n') - results_writer.flush() - - results_writer.close() - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--question-file", type=str, default="tables/question.jsonl") - parser.add_argument("--answers-file", type=str, default="answer.jsonl") - parser.add_argument("--results-file", type=str, default="results.jsonl") - parser.add_argument("--max-workers", type=int, default=1) - parser.add_argument("--num-retry", type=int, default=3) - parser.add_argument("--request-timeout", type=int, default=None) - args = parser.parse_args() - - eval_results(args) From b084941c46482ff1c4f559871ff23f1ada9cd560 Mon Sep 17 00:00:00 2001 From: Haotian Liu Date: Wed, 11 Oct 2023 17:09:56 -0700 Subject: [PATCH 10/12] Release evaluation scripts. --- docs/Evaluation.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/Evaluation.md b/docs/Evaluation.md index cb14133fb..bbbb5e2e8 100644 --- a/docs/Evaluation.md +++ b/docs/Evaluation.md @@ -98,7 +98,7 @@ CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/mme.sh ### MMBench -1. Download `mmbench_dev_20230712.tsv` from the official [website](https://github.com/open-compass/MMBench) and put under `./playground/data/eval/mmbench`. +1. Download [`mmbench_dev_20230712.tsv`](https://download.openmmlab.com/mmclassification/datasets/mmbench/mmbench_dev_20230712.tsv) and put under `./playground/data/eval/mmbench`. 2. Single-GPU inference. ```Shell CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/mmbench.sh @@ -107,7 +107,7 @@ CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/mmbench.sh ### MMBench-CN -1. Download `mmbench_dev_cn_20231003.tsv` from the official [website](https://github.com/open-compass/MMBench) and put under `./playground/data/eval/mmbench`. +1. Download [`mmbench_dev_cn_20231003.tsv`](https://download.openmmlab.com/mmclassification/datasets/mmbench/mmbench_dev_en_20231003.tsv) and put under `./playground/data/eval/mmbench`. 2. Single-GPU inference. ```Shell CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/mmbench_cn.sh From 1619889c712e347be1cb4f78ec66e7cf414ac1a6 Mon Sep 17 00:00:00 2001 From: Haotian Liu Date: Wed, 11 Oct 2023 17:15:55 -0700 Subject: [PATCH 11/12] Release evaluation scripts. --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 805376773..d5dada153 100644 --- a/README.md +++ b/README.md @@ -243,7 +243,7 @@ New options to note: In LLaVA-1.5, we evaluate models on a diverse set of 12 benchmarks. To ensure the reproducibility, we evaluate the models with greedy decoding. We do not evaluate using beam search to make the inference process consistent with the chat demo of real-time outputs. -Detailed evaluation scripts coming soon. +See [Evaluation.md](https://github.com/haotian-liu/LLaVA/blob/main/docs/Evaluation.md). ### GPT-assisted Evaluation From 08492620c9d041f32d1eba6fc86f9b703ce9ad39 Mon Sep 17 00:00:00 2001 From: Haotian Liu Date: Wed, 11 Oct 2023 17:24:12 -0700 Subject: [PATCH 12/12] Add additional clarification about projector weights --- docs/MODEL_ZOO.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/MODEL_ZOO.md b/docs/MODEL_ZOO.md index 3faf9fab3..d4de226ec 100644 --- a/docs/MODEL_ZOO.md +++ b/docs/MODEL_ZOO.md @@ -34,7 +34,7 @@ The model weights below are *merged* weights. You do not need to apply delta. Th ## Projector weights -The model weights below are projector weights we have pretrained. You can use these projector weights for visual instruction tuning. We'll add more projector weights into model zoo very soon. +These are projector weights we have pretrained. You can use these projector weights for visual instruction tuning. They are just pretrained on image-text pairs, and are **NOT** instruction tuned, which means they do **NOT** follow instructions as good as our official models, and can output repetitive, lengthy, and garbled outputs. If you want to have nice conversations with LLaVA, use the checkpoints above (LLaVA v1.5). **NOTE**: These projector weights are only compatible with the `llava>=1.0.0`, please check out the latest code base if your local code version is below `v1.0.0`.