alttexter.py
import logging
import mimetypes
import os
import time
from typing import List, Optional, Tuple

from langchain import callbacks
from langchain.callbacks.tracers.langchain import wait_for_all_tracers
from langchain.output_parsers import PydanticOutputParser
from langchain.prompts import ChatPromptTemplate
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_openai import AzureChatOpenAI, ChatOpenAI
from langsmith import Client

from schema import AlttexterResponse, ImageAltText


def determine_llm() -> ChatOpenAI:
"""Determine which LLM to use based on environment variables."""
model_env = os.getenv("ALTTEXTER_MODEL")
if model_env == 'openai':
return ChatOpenAI(verbose=True, temperature=0, model="gpt-4-vision-preview", max_tokens=4096)
elif model_env == 'openai_azure':
return AzureChatOpenAI(verbose=True, temperature=0, openai_api_version="2024-02-15-preview",
azure_deployment=os.getenv("AZURE_OPENAI_DEPLOYMENT"),
model="vision-preview", max_tokens=4096)
else:
raise ValueError(f"Unsupported model specified: {model_env}")


def alttexter(input_text: str, images: dict, image_urls: List[str]) -> Tuple[Optional[AlttexterResponse], Optional[str]]:
    """
    Processes input text and images to generate alt text and title attributes.

    Args:
        input_text (str): Article text.
        images (dict): Base64-encoded images keyed by file name.
        image_urls (List[str]): Image URLs.

    Returns:
        Tuple[Optional[AlttexterResponse], Optional[str]]: Generated alt texts and an optional tracing run URL.
    """
    llm = determine_llm()

    content = [
        {
            "type": "text",
            "text": f"""ARTICLE: {input_text}"""
        }
    ]

    # Process images and add to content
    for image_name, base64_string in images.items():
        mime_type, _ = mimetypes.guess_type(image_name)
        if not mime_type:
            logging.warning(f"Could not determine MIME type for image: {image_name}")
            continue

        image_entry = {
            "type": "image_url",
            "image_url": {
                "url": f"data:{mime_type};base64,{base64_string}",
                "detail": "auto",
            }
        }
        content.append(image_entry)

    # Add image URLs to content
    for url in image_urls:
        image_entry = {
            "type": "image_url",
            "image_url": {
                "url": url,
                "detail": "auto",
            }
        }
        content.append(image_entry)

    parser = PydanticOutputParser(pydantic_object=AlttexterResponse)

    all_image_identifiers = list(images.keys()) + image_urls

    messages = ChatPromptTemplate.from_messages(
        [
            SystemMessage(
                content='''You are a world-class expert at generating concise alternative text and title attributes for images defined in technical articles written in markdown format.\nFor each image in the article use a contextual understanding of the article text and the image itself to generate a concise alternative text and title attribute.\n{format_instructions}'''.format(format_instructions=parser.get_format_instructions())),
            HumanMessage(content=content),
            HumanMessage(
                content=f"Tip: List of file names of images including their paths or URLs: {str(all_image_identifiers)}"
            ),
        ]
    )

    alttexts = None
    run_url = None

    tracing_enabled = os.getenv("LANGCHAIN_TRACING_V2", "").lower() == "true"
    if tracing_enabled:
        client = Client()
        try:
            with callbacks.collect_runs() as cb:
                alttexts = llm.invoke(messages.format_messages())

                # Ensure that all tracers complete their execution
                wait_for_all_tracers()

                if alttexts:
                    # Get public URL for run
                    run_id = cb.traced_runs[0].id
                    time.sleep(2)
                    client.share_run(run_id)
                    run_url = client.read_run_shared_link(run_id)
        except Exception as e:
            logging.error(f"Error during LLM invocation with tracing: {str(e)}")
    else:
        try:
            alttexts = llm.invoke(messages.format_messages())
        except Exception as e:
            logging.error(f"Error during LLM invocation without tracing: {str(e)}")

    if alttexts:
        try:
            alttexts_parsed = parser.parse(alttexts.content)
            return alttexts_parsed, run_url
        except Exception as e:
            logging.error(f"Error parsing LLM response: {str(e)}")

    return None, run_url
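

# Illustrative usage sketch (commented out): how alttexter() might be called with
# one local image and one remote image. The paths and the field names accessed on
# the parsed AlttexterResponse (`images`, `name`, `alt_text`, `title`) are
# assumptions about schema.py, not confirmed by this file.
#
# import base64
#
# with open("docs/img/diagram.png", "rb") as f:
#     encoded = base64.b64encode(f.read()).decode("utf-8")
#
# with open("docs/article.md", encoding="utf-8") as f:
#     article = f.read()
#
# response, trace_url = alttexter(
#     input_text=article,
#     images={"docs/img/diagram.png": encoded},
#     image_urls=["https://example.com/screenshot.png"],
# )
# if response:
#     for image in response.images:
#         print(image.name, image.alt_text, image.title)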