Commit

LLaVA_13b_4bit_vanilla_colab
camenduru authored Oct 14, 2023
1 parent c86c8d2 commit 25f16b5
Showing 3 changed files with 155 additions and 103 deletions: the commented-out inline chat code is removed from LLaVA_13b_4bit_colab.ipynb, moved (uncommented) into the new LLaVA_13b_4bit_vanilla_colab.ipynb, and the new notebook is linked from README.md.
103 changes: 0 additions & 103 deletions LLaVA_13b_4bit_colab.ipynb
@@ -24,109 +24,6 @@
"!pip install -e ."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# from transformers import AutoTokenizer, BitsAndBytesConfig\n",
"# from llava.model import LlavaLlamaForCausalLM\n",
"# import torch\n",
"# model_path = \"4bit/llava-v1.5-13b-3GB\"\n",
"# kwargs = {\"device_map\": \"auto\"}\n",
"# kwargs['load_in_4bit'] = True\n",
"# kwargs['quantization_config'] = BitsAndBytesConfig(\n",
"# load_in_4bit=True,\n",
"# bnb_4bit_compute_dtype=torch.float16,\n",
"# bnb_4bit_use_double_quant=True,\n",
"# bnb_4bit_quant_type='nf4'\n",
"# )\n",
"# model = LlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)\n",
"# tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)\n",
"\n",
"# vision_tower = model.get_vision_tower()\n",
"# if not vision_tower.is_loaded:\n",
"# vision_tower.load_model()\n",
"# vision_tower.to(device='cuda')\n",
"# image_processor = vision_tower.image_processor\n",
"\n",
"# import requests\n",
"# from PIL import Image\n",
"# from io import BytesIO\n",
"\n",
"# def load_image(image_file):\n",
"# if image_file.startswith('http') or image_file.startswith('https'):\n",
"# response = requests.get(image_file)\n",
"# image = Image.open(BytesIO(response.content)).convert('RGB')\n",
"# else:\n",
"# image = Image.open(image_file).convert('RGB')\n",
"# return image\n",
"\n",
"# from llava.conversation import conv_templates, SeparatorStyle\n",
"# from llava.utils import disable_torch_init\n",
"\n",
"# disable_torch_init()\n",
"# conv_mode = \"llava_v0\"\n",
"# conv = conv_templates[conv_mode].copy()\n",
"# roles = conv.roles\n",
"\n",
"# image = load_image(\"https://llava-vl.github.io/static/images/view.jpg\")\n",
"# image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().cuda()\n",
"\n",
"# from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN\n",
"# from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria\n",
"# from transformers import TextStreamer\n",
"# import torch\n",
"\n",
"# while True:\n",
"# try:\n",
"# inp = input(f\"{roles[0]}: \")\n",
"# except EOFError:\n",
"# inp = \"\"\n",
"# if not inp:\n",
"# print(\"exit...\")\n",
"# break\n",
"\n",
"# print(f\"{roles[1]}: \", end=\"\")\n",
"\n",
"# if image is not None:\n",
"# # first message\n",
"# if model.config.mm_use_im_start_end:\n",
"# inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\\n' + inp\n",
"# else:\n",
"# inp = DEFAULT_IMAGE_TOKEN + '\\n' + inp\n",
"# conv.append_message(conv.roles[0], inp)\n",
"# image = None\n",
"# else:\n",
"# # later messages\n",
"# conv.append_message(conv.roles[0], inp)\n",
"# conv.append_message(conv.roles[1], None)\n",
"# prompt = conv.get_prompt()\n",
"\n",
"# input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()\n",
"# stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2\n",
"# keywords = [stop_str]\n",
"# stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)\n",
"# streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)\n",
"\n",
"# with torch.inference_mode():\n",
"# output_ids = model.generate(\n",
"# input_ids,\n",
"# images=image_tensor,\n",
"# do_sample=True,\n",
"# temperature=0.2,\n",
"# max_new_tokens=1024,\n",
"# streamer=streamer,\n",
"# use_cache=True,\n",
"# stopping_criteria=[stopping_criteria])\n",
"\n",
"# outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()\n",
"# conv.messages[-1][-1] = outputs\n",
"\n",
"# print(\"\\n\", {\"prompt\": prompt, \"outputs\": outputs}, \"\\n\")"
]
},
{
"cell_type": "code",
"execution_count": null,
154 changes: 154 additions & 0 deletions LLaVA_13b_4bit_vanilla_colab.ipynb
@@ -0,0 +1,154 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github"
},
"source": [
"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/camenduru/LLaVA-colab/blob/main/LLaVA_13b_4bit_colab.ipynb)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "VjYy0F2gZIPR"
},
"outputs": [],
"source": [
"%cd /content\n",
"!git clone -b 5GB https://github.com/camenduru/LLaVA\n",
"%cd /content/LLaVA\n",
"\n",
"!pip install -e ."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoTokenizer, BitsAndBytesConfig\n",
"from llava.model import LlavaLlamaForCausalLM\n",
"import torch\n",
"model_path = \"4bit/llava-v1.5-13b-3GB\"\n",
"kwargs = {\"device_map\": \"auto\"}\n",
"kwargs['load_in_4bit'] = True\n",
"kwargs['quantization_config'] = BitsAndBytesConfig(\n",
" load_in_4bit=True,\n",
" bnb_4bit_compute_dtype=torch.float16,\n",
" bnb_4bit_use_double_quant=True,\n",
" bnb_4bit_quant_type='nf4'\n",
")\n",
"model = LlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)\n",
"tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)\n",
"\n",
"vision_tower = model.get_vision_tower()\n",
"if not vision_tower.is_loaded:\n",
" vision_tower.load_model()\n",
"vision_tower.to(device='cuda')\n",
"image_processor = vision_tower.image_processor"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import requests\n",
"from PIL import Image\n",
"from io import BytesIO\n",
"\n",
"def load_image(image_file):\n",
" if image_file.startswith('http') or image_file.startswith('https'):\n",
" response = requests.get(image_file)\n",
" image = Image.open(BytesIO(response.content)).convert('RGB')\n",
" else:\n",
" image = Image.open(image_file).convert('RGB')\n",
" return image\n",
"\n",
"from llava.conversation import conv_templates, SeparatorStyle\n",
"from llava.utils import disable_torch_init\n",
"\n",
"disable_torch_init()\n",
"conv_mode = \"llava_v0\"\n",
"conv = conv_templates[conv_mode].copy()\n",
"roles = conv.roles\n",
"\n",
"image = load_image(\"https://llava-vl.github.io/static/images/view.jpg\")\n",
"image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().cuda()\n",
"\n",
"from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN\n",
"from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria\n",
"from transformers import TextStreamer\n",
"import torch\n",
"\n",
"while True:\n",
" try:\n",
" inp = input(f\"{roles[0]}: \")\n",
" except EOFError:\n",
" inp = \"\"\n",
" if not inp:\n",
" print(\"exit...\")\n",
" break\n",
"\n",
" print(f\"{roles[1]}: \", end=\"\")\n",
"\n",
" if image is not None:\n",
" # first message\n",
" if model.config.mm_use_im_start_end:\n",
" inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\\n' + inp\n",
" else:\n",
" inp = DEFAULT_IMAGE_TOKEN + '\\n' + inp\n",
" conv.append_message(conv.roles[0], inp)\n",
" image = None\n",
" else:\n",
" # later messages\n",
" conv.append_message(conv.roles[0], inp)\n",
" conv.append_message(conv.roles[1], None)\n",
" prompt = conv.get_prompt()\n",
"\n",
" input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()\n",
" stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2\n",
" keywords = [stop_str]\n",
" stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)\n",
" streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)\n",
"\n",
" with torch.inference_mode():\n",
" output_ids = model.generate(\n",
" input_ids,\n",
" images=image_tensor,\n",
" do_sample=True,\n",
" temperature=0.2,\n",
" max_new_tokens=1024,\n",
" streamer=streamer,\n",
" use_cache=True,\n",
" stopping_criteria=[stopping_criteria])\n",
"\n",
" outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()\n",
" conv.messages[-1][-1] = outputs\n",
"\n",
" print(\"\\n\", {\"prompt\": prompt, \"outputs\": outputs}, \"\\n\")"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuType": "T4",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
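
For reference, a minimal non-interactive sketch of the same chat flow (not part of the commit): it assumes the model, tokenizer, image_processor, and load_image defined in the notebook cells above and reuses the same LLaVA helpers the notebook imports; the ask wrapper and the example question are hypothetical.

```python
# Minimal sketch (not part of the commit): one-shot query instead of the
# interactive input() loop. Assumes model, tokenizer, image_processor, and
# load_image are already defined as in the notebook cells above.
import torch
from llava.constants import (IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN,
                             DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN)
from llava.conversation import conv_templates, SeparatorStyle
from llava.mm_utils import tokenizer_image_token, KeywordsStoppingCriteria

def ask(image_file, question):  # hypothetical helper name
    # Preprocess the image exactly as the notebook does.
    image = load_image(image_file)
    image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().cuda()

    # Build a fresh llava_v0 conversation with the image token(s) prepended.
    conv = conv_templates["llava_v0"].copy()
    if model.config.mm_use_im_start_end:
        question = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + question
    else:
        question = DEFAULT_IMAGE_TOKEN + '\n' + question
    conv.append_message(conv.roles[0], question)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    # Tokenize, set up the stop criterion, and generate, as in the notebook.
    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    stopping_criteria = KeywordsStoppingCriteria([stop_str], tokenizer, input_ids)

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            do_sample=True,
            temperature=0.2,
            max_new_tokens=1024,
            use_cache=True,
            stopping_criteria=[stopping_criteria])
    return tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()

print(ask("https://llava-vl.github.io/static/images/view.jpg", "What is shown in this image?"))
```

The notebook's interactive version differs only in that it reads questions with input(), streams tokens through TextStreamer, and keeps the conversation state across turns.
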
1 change: 1 addition & 0 deletions README.md
@@ -7,6 +7,7 @@
| Colab | Info
| --- | --- |
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/camenduru/LLaVA-colab/blob/main/LLaVA_13b_4bit_colab.ipynb) | 🌋 LLaVA_13b_4bit_colab 13B (4bit)
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/camenduru/LLaVA-colab/blob/main/LLaVA_13b_4bit_vanilla_colab.ipynb) | 🌋 LLaVA_13b_4bit_vanilla_colab 13B (4bit) (no gradio)
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/camenduru/LLaVA-colab/blob/main/LLaVA_7b_8bit_colab.ipynb) | 🌋 LLaVA_7b_8bit_colab 7B (8bit)
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/camenduru/LLaVA-colab/blob/main/LLaVA_7b_colab.ipynb) | 🌋 LLaVA_7b_colab 7B (16bit) (Pro High-RAM 😐 22GB RAM 14GB VRAM)

