-
Notifications
You must be signed in to change notification settings - Fork 18
/
Copy pathgit_captioner.py
101 lines (79 loc) · 3.15 KB
/
git_captioner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
# Copyright (c) Hello Robot, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in the root directory
# of this source tree.
#
# Some code may be adapted from other open-source works with their respective licenses. Original
# license information may be found below, if so.
from typing import Optional, Union
import click
import torch
from numpy import ndarray
from PIL import Image
from torch import Tensor
from transformers import AutoModelForCausalLM, AutoProcessor
from .base_captioner import BaseCaptioner
class GitCaptioner(BaseCaptioner):
    """Image captioner using a GIT (Generative Image-to-text Transformer) model.

    The processor and the model are loaded once in ``__init__`` and reused by every call to
    :meth:`caption_image`.
    """

    def __init__(
        self,
        max_length: int = 30,
        num_beams: int = 4,
        device: Optional[str] = None,
        model_name: str = "microsoft/git-base-coco",
    ):
        """Initialize the GIT image captioner.

        Args:
            max_length (int, optional): Value forwarded as ``max_length`` to ``generate``.
                Defaults to 30.
            num_beams (int, optional): Number of beams for beam search. Defaults to 4.
            device (str, optional): Device to run the model on. Defaults to None, in which case
                CUDA is used when ``torch.cuda.is_available()`` and the CPU otherwise.
            model_name (str, optional): Checkpoint identifier handed to both
                ``AutoProcessor.from_pretrained`` and ``AutoModelForCausalLM.from_pretrained``.
                Defaults to "microsoft/git-base-coco".
        """
        super().__init__()
        self.max_length = max_length
        self.num_beams = num_beams

        # Resolve the device once; the same device is used for the model and for its inputs.
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self._device = torch.device(device)

        # Create models. Processor and model must come from the same checkpoint.
        self.processor = AutoProcessor.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name).to(self._device)

    def caption_image(self, image: Union[ndarray, Tensor, Image.Image]) -> str:
        """Generate a caption for the given image.

        Args:
            image (Union[ndarray, Tensor, Image.Image]): The input image. PIL images are used
                directly. Arrays and tensors are handed to ``Image.fromarray`` unchanged, so they
                must already be in a layout and dtype that PIL accepts (presumably H x W x C
                uint8 -- confirm against callers).

        Returns:
            str: The generated caption, stripped of surrounding whitespace.
        """
        if isinstance(image, Image.Image):
            pil_image = image
        else:
            if isinstance(image, Tensor):
                # detach() first: Tensor.numpy() raises if the tensor requires grad.
                image = image.detach().cpu().numpy()
            pil_image = Image.fromarray(image)

        # Preprocess the image and move the resulting tensors next to the model.
        inputs = self.processor(images=pil_image, return_tensors="pt").to(self._device)

        # Generate caption with deterministic beam search; plain token ids are returned.
        generated_ids = self.model.generate(
            pixel_values=inputs.pixel_values,
            max_length=self.max_length,
            num_beams=self.num_beams,
            do_sample=False,
            output_attentions=False,
            output_hidden_states=False,
            return_dict_in_generate=False,
        )

        # Decode the output ids to text; a single image yields a batch of one.
        caption = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
        return caption
@click.command()
@click.option("--image_path", default="object.png", help="Path to image file")
def main(image_path: str):
    """Caption the image given by --image_path and print the caption."""
    captioner = GitCaptioner()

    # Load image from file. The context manager closes the underlying file handle; the caption
    # is generated inside the block because PIL reads pixel data lazily.
    with Image.open(image_path) as image:
        # Generate caption
        caption = captioner.caption_image(image)

    # Print caption
    print(caption)


# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()