Commit

LLaVA_13b_4bit_vanilla_colab
camenduru authored Oct 14, 2023
1 parent b082fcc commit ba59cb1
Showing 1 changed file with 35 additions and 77 deletions.
112 changes: 35 additions & 77 deletions LLaVA_13b_4bit_vanilla_colab.ipynb
@@ -12,24 +12,14 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "VjYy0F2gZIPR"
},
"metadata": {},
"outputs": [],
"source": [
"%cd /content\n",
"!git clone -b v1.0 https://github.com/camenduru/LLaVA\n",
"%cd /content/LLaVA\n",
"!pip install -q gradio .\n",
"\n",
"!pip install -e ."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoTokenizer, BitsAndBytesConfig\n",
"from llava.model import LlavaLlamaForCausalLM\n",
"import torch\n",
@@ -49,89 +39,57 @@
"if not vision_tower.is_loaded:\n",
" vision_tower.load_model()\n",
"vision_tower.to(device='cuda')\n",
"image_processor = vision_tower.image_processor"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"image_processor = vision_tower.image_processor\n",
"\n",
"import os\n",
"import requests\n",
"from PIL import Image\n",
"from io import BytesIO\n",
"\n",
"def load_image(image_file):\n",
" if image_file.startswith('http') or image_file.startswith('https'):\n",
" response = requests.get(image_file)\n",
" image = Image.open(BytesIO(response.content)).convert('RGB')\n",
" else:\n",
" image = Image.open(image_file).convert('RGB')\n",
" return image\n",
"\n",
"from llava.conversation import conv_templates, SeparatorStyle\n",
"from llava.utils import disable_torch_init\n",
"\n",
"disable_torch_init()\n",
"conv_mode = \"llava_v0\"\n",
"conv = conv_templates[conv_mode].copy()\n",
"roles = conv.roles\n",
"\n",
"image = load_image(\"https://llava-vl.github.io/static/images/view.jpg\")\n",
"image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().cuda()\n",
"\n",
"from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN\n",
"from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria\n",
"from transformers import TextStreamer\n",
"import torch\n",
"\n",
"while True:\n",
" try:\n",
" inp = input(f\"{roles[0]}: \")\n",
" except EOFError:\n",
" inp = \"\"\n",
" if not inp:\n",
" print(\"exit...\")\n",
" break\n",
"\n",
" print(f\"{roles[1]}: \", end=\"\")\n",
"\n",
" if image is not None:\n",
" # first message\n",
" if model.config.mm_use_im_start_end:\n",
" inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\\n' + inp\n",
" else:\n",
" inp = DEFAULT_IMAGE_TOKEN + '\\n' + inp\n",
" conv.append_message(conv.roles[0], inp)\n",
" image = None\n",
"def caption_image(image_file, prompt):\n",
" if image_file.startswith('http') or image_file.startswith('https'):\n",
" response = requests.get(image_file)\n",
" image = Image.open(BytesIO(response.content)).convert('RGB')\n",
" else:\n",
" # later messages\n",
" conv.append_message(conv.roles[0], inp)\n",
" image = Image.open(image_file).convert('RGB')\n",
" disable_torch_init()\n",
" conv_mode = \"llava_v0\"\n",
" conv = conv_templates[conv_mode].copy()\n",
" roles = conv.roles\n",
" image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().cuda()\n",
" inp = f\"{roles[0]}: {prompt}\"\n",
" inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\\n' + inp\n",
" conv.append_message(conv.roles[0], inp)\n",
" conv.append_message(conv.roles[1], None)\n",
" prompt = conv.get_prompt()\n",
"\n",
" input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()\n",
" raw_prompt = conv.get_prompt()\n",
" input_ids = tokenizer_image_token(raw_prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()\n",
" stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2\n",
" keywords = [stop_str]\n",
" stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)\n",
" streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)\n",
"\n",
" with torch.inference_mode():\n",
" output_ids = model.generate(\n",
" input_ids,\n",
" images=image_tensor,\n",
" do_sample=True,\n",
" temperature=0.2,\n",
" max_new_tokens=1024,\n",
" streamer=streamer,\n",
" use_cache=True,\n",
" stopping_criteria=[stopping_criteria])\n",
"\n",
" output_ids = model.generate(input_ids, images=image_tensor, do_sample=True, temperature=0.2, \n",
" max_new_tokens=1024, use_cache=True, stopping_criteria=[stopping_criteria])\n",
" outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()\n",
" conv.messages[-1][-1] = outputs\n",
"\n",
" print(\"\\n\", {\"prompt\": prompt, \"outputs\": outputs}, \"\\n\")"
" output = outputs.rsplit('</s>', 1)[0]\n",
" return image, output"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"image, output = caption_image(f'https://llava-vl.github.io/static/images/view.jpg', 'Describe the image and color details.')\n",
"print(output)\n",
"image"
]
}
],
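A side note on the refactor: the first cell installs gradio, and the interactive `while True` input loop has been replaced by a plain `caption_image(image_file, prompt)` function, which is straightforward to put a UI on top of. An illustrative sketch only (this commit adds no UI itself); the widget choices and labels are mine, not the notebook's:

```python
# Illustrative only: the commit installs gradio but does not add a UI itself.
# caption_image is the function defined in the notebook cell above; it takes
# an image URL or path plus a prompt, and returns (PIL image, output string).
import gradio as gr

demo = gr.Interface(
    fn=caption_image,
    inputs=[
        gr.Textbox(label="Image URL or local path"),
        gr.Textbox(label="Prompt", value="Describe the image and color details."),
    ],
    outputs=[
        gr.Image(type="pil", label="Image"),
        gr.Textbox(label="Caption"),
    ],
)
demo.launch(share=True)  # share=True exposes a public URL from Colab
```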
