Skip to content

Commit

Permalink
LLaVA_13b_4bit_caption_colab
Browse files Browse the repository at this point in the history
  • Loading branch information
camenduru authored Oct 14, 2023
1 parent ebe5416 commit 2fb7928
Showing 1 changed file with 134 additions and 0 deletions.
134 changes: 134 additions & 0 deletions LLaVA_13b_4bit_caption_colab.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github"
},
"source": [
"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/camenduru/LLaVA-colab/blob/main/LLaVA_13b_4bit_vanilla_colab.ipynb)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "VjYy0F2gZIPR"
},
"outputs": [],
"source": [
"%cd /content\n",
"!git clone -b v1.0 https://github.com/camenduru/LLaVA\n",
"%cd /content/LLaVA\n",
"!pip install -q gradio .\n",
"\n",
"from transformers import AutoTokenizer, BitsAndBytesConfig\n",
"from llava.model import LlavaLlamaForCausalLM\n",
"import torch\n",
"model_path = \"4bit/llava-v1.5-13b-3GB\"\n",
"kwargs = {\"device_map\": \"auto\"}\n",
"kwargs['load_in_4bit'] = True\n",
"kwargs['quantization_config'] = BitsAndBytesConfig(\n",
" load_in_4bit=True,\n",
" bnb_4bit_compute_dtype=torch.float16,\n",
" bnb_4bit_use_double_quant=True,\n",
" bnb_4bit_quant_type='nf4'\n",
")\n",
"model = LlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)\n",
"tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)\n",
"\n",
"vision_tower = model.get_vision_tower()\n",
"if not vision_tower.is_loaded:\n",
" vision_tower.load_model()\n",
"vision_tower.to(device='cuda')\n",
"image_processor = vision_tower.image_processor\n",
"\n",
"import os\n",
"import requests\n",
"from PIL import Image\n",
"from io import BytesIO\n",
"from llava.conversation import conv_templates, SeparatorStyle\n",
"from llava.utils import disable_torch_init\n",
"from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN\n",
"from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria\n",
"from transformers import TextStreamer\n",
"import torch\n",
"\n",
"def caption_image(image_file, prompt):\n",
" if image_file.startswith('http') or image_file.startswith('https'):\n",
" response = requests.get(image_file)\n",
" image = Image.open(BytesIO(response.content)).convert('RGB')\n",
" else:\n",
" image = Image.open(image_file).convert('RGB')\n",
" disable_torch_init()\n",
" conv_mode = \"llava_v0\"\n",
" conv = conv_templates[conv_mode].copy()\n",
" roles = conv.roles\n",
" image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().cuda()\n",
" inp = f\"{roles[0]}: {prompt}\"\n",
" inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\\n' + inp\n",
" conv.append_message(conv.roles[0], inp)\n",
" conv.append_message(conv.roles[1], None)\n",
" raw_prompt = conv.get_prompt()\n",
" input_ids = tokenizer_image_token(raw_prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()\n",
" stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2\n",
" keywords = [stop_str]\n",
" stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)\n",
" with torch.inference_mode():\n",
" output_ids = model.generate(input_ids, images=image_tensor, do_sample=True, temperature=0.2, \n",
" max_new_tokens=1024, use_cache=True, stopping_criteria=[stopping_criteria])\n",
" outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()\n",
" conv.messages[-1][-1] = outputs\n",
" output = outputs.rsplit('</s>', 1)[0]\n",
" return image, output"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import locale\n",
"locale.getpreferredencoding = lambda: \"UTF-8\"\n",
"!mkdir /content/images\n",
"!wget --header 'Authorization: Bearer TOKEN_HERE' https://huggingface.co/camenduru/polaroid/resolve/main/style_name_fix.zip\n",
"!unzip style_name_fix.zip -d /content/images"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"file_names = os.listdir('/content/images')\n",
"sorted_file_names = sorted(file_names)\n",
"for file_name in sorted_file_names:\n",
" try:\n",
" image, output = caption_image(f'/content/images/{file_name}', 'Describe the image and color details.')\n",
" print(output)\n",
" # image\n",
" except Exception as e:\n",
" print(f\"Error processing {file_name}: {str(e)}\")\n",
" continue"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuType": "T4",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}

0 comments on commit 2fb7928

Please sign in to comment.