Merge pull request #100 from XeonHis/develop/v0.2.1
Develop/v0.2.1: Optimize LLM use experience
panregedit authored Nov 29, 2024
2 parents c13262e + add544c commit fa90053
Showing 9 changed files with 147 additions and 174 deletions.
58 changes: 58 additions & 0 deletions docs/concepts/models/llms.md
@@ -0,0 +1,58 @@
# LLMs
LLMs are the core components of Omagent; they are responsible for generating text with Large Language Models.

The LLM module is composed of the following parts:
- ```BaseLLM```: the base class for all LLMs; it defines the basic properties and methods shared by every LLM.
- ```BaseLLMBackend```: an enhanced class for working with LLMs; it lets you assemble a specific LLM with prompt templates and output parsers.
- ```BasePromptTemplate```: the base class for all prompt templates; it defines the input variables and output parser for a prompt template.
- ```BaseOutputParser```: the base class for all output parsers; it defines how to parse the result returned by an LLM.

## Prompt Template
This is a simple way to define a prompt template.
```python
from omagent_core.models.llms.prompt.prompt import PromptTemplate

# Define a system prompt template
system_prompt = PromptTemplate.from_template("You are a helpful assistant.", role="system")
# Define a user prompt template
user_prompt = PromptTemplate.from_template("Tell me a joke about {{topic}}", role="user")
```
`topic` is a variable in the user prompt template; it will be replaced by the actual input value when the prompt is rendered.
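You can render a template directly to check the result. A minimal sketch; `input_variables` and `format` are the same attributes `BaseLLMBackend` uses internally when preparing prompts, and the printed values are what the template above should produce:
```python
# Inspect the declared variables and render the template with a concrete value
print(user_prompt.input_variables)       # ['topic']
print(user_prompt.format(topic="cats"))  # "Tell me a joke about cats"
```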

## Output Parser
This is a simple way to define an output parser.
```python
from omagent_core.models.llms.prompt.parser import StrParser

output_parser = StrParser()
```
`StrParser` is a simple output parser that returns the output as a string.
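If a plain string is not enough, you can define your own parser by subclassing `BaseOutputParser`. A minimal sketch, assuming the parser exposes a `parse` method that receives the raw completion text (as `StrParser` does); `UpperParser` is a hypothetical name:
```python
from omagent_core.models.llms.prompt.parser import BaseOutputParser


class UpperParser(BaseOutputParser):
    """Hypothetical parser that upper-cases the LLM output."""

    def parse(self, text: str) -> str:
        # Assumed hook: called with the raw text returned by the LLM
        return text.upper()
```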

## Get LLM Result
This is a simple way to define an LLM request and get the result from an LLM.
1. The worker class should inherit from `BaseWorker` and `BaseLLMBackend`, and define the LLM model in the `prompts` and `llm` fields. `OutputParser` is optional; if not defined, the default `StrParser` will be used. A minimal sketch of such a worker class is shown after the code block below.
2. Override the `_run` method to define the workflow logic.
```python
def _run(self, *args, **kwargs):
    payload = {
        "topic": "weather"
    }
    # 1. Use the `infer` method to get the LLM result
    chat_complete_res = self.infer(input_list=[payload])[0]
    content = chat_complete_res["choices"][0]["message"].get("content")
    # 2. Or use the `simple_infer` method, a shortcut for `infer` that takes keyword arguments
    simple_infer_res = self.simple_infer(topic="weather")["choices"][0]["message"].get("content")
    print(content)
    return {'output': content}
```
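For reference, a minimal sketch of the worker class from step 1. The import paths for `BaseWorker` and `OpenaiGPTLLM`, the `pydantic` `Field` default pattern, and the class name `JokeTeller` are assumptions; adapt them to your project:
```python
from typing import List

from pydantic import Field

from omagent_core.engine.worker.base import BaseWorker            # assumed import path
from omagent_core.models.llms.base import BaseLLMBackend
from omagent_core.models.llms.openai_gpt import OpenaiGPTLLM      # assumed import path
from omagent_core.models.llms.prompt.prompt import PromptTemplate


class JokeTeller(BaseLLMBackend, BaseWorker):
    # Prompts are rendered in order; variables such as {{topic}} are filled
    # from the payload passed to `infer` / `simple_infer`
    prompts: List[PromptTemplate] = Field(
        default=[
            PromptTemplate.from_template("You are a helpful assistant.", role="system"),
            PromptTemplate.from_template("Tell me a joke about {{topic}}", role="user"),
        ]
    )
    llm: OpenaiGPTLLM
```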

For Multi-Modal LLMs, it's also simple and intuitive.
```python
def _run(self, *args, **kwargs):
    # `pil_image` stands for a PIL.Image.Image object; text and images can be
    # mixed in the same list
    payload = {
        "topic": ["this image", pil_image]
    }
    chat_complete_res = self.infer(input_list=[payload])[0]["choices"][0]["message"].get("content")
    return {'output': chat_complete_res}
```
The prompt segments are passed to the LLM in the same order as the elements of the variable list, so text and images alternate exactly as specified.
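To make the ordering concrete, here is how the user template above expands with the multi-modal payload (an illustration of the splitting done in `BaseLLMBackend.prep_prompt`, not an exact dump of the internal message format):
```python
# The {{topic}} placeholder in "Tell me a joke about {{topic}}" is spliced in
# element by element, so the message content sent to the LLM is roughly:
#   ["Tell me a joke about ", "this image", <PIL.Image.Image>]
# i.e. the text fragments and images keep the order given in the payload list.
```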
@@ -59,7 +59,8 @@ def _run(self, *args, **kwargs):
search_info = self.stm(self.workflow_instance_id)["search_info"] if "search_info" in self.stm(self.workflow_instance_id) else None

# Generate outfit recommendations using LLM with weather and user input
chat_complete_res = self.simple_infer(weather=str(search_info), instruction=user_instruct, image='<image_0>')
image_cache = self.stm(self.workflow_instance_id)['image_cache']
chat_complete_res = self.simple_infer(weather=str(search_info), instruction=user_instruct, image=image_cache.get('<image_0>'))

# Extract recommendations from LLM response
outfit_recommendation = chat_complete_res["choices"][0]["message"]["content"]
@@ -81,22 +81,14 @@ def _run(
frames, time_stamps = video.get_video_frames((start, end), interval)

# self.stm.image_cache.clear()
extracted_frames = []
payload = []
for i, (frame, time_stamp) in enumerate(zip(frames, time_stamps)):
img_index = f"image_timestamp-{time_stamp}"
extracted_frames.append(time_stamp)
image_cache = self.stm(self.workflow_instance_id).get("image_cache", {})
if image_cache.get(f"<{img_index}>", None) is None:
image_cache[f"<{img_index}>"] = frame
self.stm(self.workflow_instance_id)["image_cache"] = image_cache
res = self.simple_infer(
image_placeholders="".join(
[f"<image_timestamp-{each}>" for each in extracted_frames]
)
)["choices"][0]["message"]["content"]
payload.append(f"timestamp_{time_stamp}")
payload.append(frame)
res = self.infer(input_list=[{"timestamp_with_images": payload}])[0]["choices"][0]["message"]["content"]
image_contents = json_repair.loads(res)
self.stm(self.workflow_instance_id)['image_cache'] = {}
return f"{extracted_frames} described as: {image_contents}."
return f"extracted_frames described as: {image_contents}."

async def _arun(
self, start_time: float = 0.0, end_time: float = None, number: int = 1
@@ -1 +1 @@
Image: {{image_placeholders}}
{{timestamp_with_images}}
@@ -52,7 +52,7 @@ def _run(self, agent_task: dict, last_output: str, *args, **kwargs):
self.callback.info(agent_id=self.workflow_instance_id, progress=f'Conclude', message=f'{task.get_current_node().task}')
chat_complete_res = self.simple_infer(
task=task.get_root().task,
result=last_output,
result=str(last_output),
img_placeholders="".join(list(self.stm(self.workflow_instance_id).get('image_cache', {}).keys())),
)
self.callback.send_answer(agent_id=self.workflow_instance_id, msg=f'Answer: {chat_complete_res["choices"][0]["message"]["content"]}')
36 changes: 0 additions & 36 deletions omagent-core/src/omagent_core/models/llms/azure_gpt.py
@@ -56,24 +56,6 @@ def _call(self, records: List[Message], **kwargs) -> Dict:
if self.api_key is None or self.api_key == "":
raise ValueError("api_key is required")

if self.stm(self.workflow_instance_id).get('image_cache') is not None and len(self.stm(self.workflow_instance_id)['image_cache']):
for record in records:
record.combine_image_message(
image_cache={
key: encode_image(value)
for key, value in self.stm.image_cache.items()
}
)
elif len(kwargs.get("images", [])):
image_cache = {}
for index, each in enumerate(kwargs["images"]):
image_cache[f"<image_{index}>"] = each
for record in records:
record.combine_image_message(
image_cache={
key: encode_image(value) for key, value in image_cache.items()
}
)
body = self._msg2req(records)
if kwargs.get("tool_choice"):
body["tool_choice"] = kwargs["tool_choice"]
@@ -104,24 +86,6 @@ async def _acall(self, records: List[Message], **kwargs) -> Dict:
if self.api_key is None or self.api_key == "":
raise ValueError("api_key is required")

if self.stm(self.workflow_instance_id).get('image_cache') is not None and len(self.stm(self.workflow_instance_id)['image_cache']):
for record in records:
record.combine_image_message(
image_cache={
key: encode_image(value)
for key, value in self.stm.image_cache.items()
}
)
elif len(kwargs.get("images", [])):
image_cache = {}
for index, each in enumerate(kwargs["images"]):
image_cache[f"<image_{index}>"] = each
for record in records:
record.combine_image_message(
image_cache={
key: encode_image(value) for key, value in image_cache.items()
}
)
body = self._msg2req(records)
if kwargs.get("tool_choice"):
body["tool_choice"] = kwargs["tool_choice"]
34 changes: 29 additions & 5 deletions omagent-core/src/omagent_core/models/llms/base.py
@@ -16,6 +16,9 @@
from .prompt.base import _OUTPUT_PARSER, StrParser
from .prompt.parser import BaseOutputParser
from .prompt.prompt import PromptTemplate
from collections.abc import Hashable
from PIL import Image
import re

T = TypeVar("T", str, dict, list)

@@ -149,23 +152,44 @@ def set_llm(cls, llm: Union[BaseLLM, Dict]):
raise ValueError("LLM only support dict and BaseLLM object")

def prep_prompt(
self, input_list: List[Dict[str, Any]], prompts=None
self, input_list: List[Dict[str, Any]], prompts=None, **kwargs
) -> List[Message]:
"""Prepare prompts from inputs."""
if prompts is None:
prompts = self.prompts
images = []
if len(kwargs_images:=kwargs.get("images", [])):
images = kwargs_images
processed_prompts = []
for inputs in input_list:
records = []
for prompt in prompts:
selected_inputs = {k: inputs[k] for k in prompt.input_variables}
prompt_str = prompt.format(**selected_inputs)
records.append(Message(content=prompt_str, role=prompt.role))
selected_inputs = {k: inputs.get(k, '') for k in prompt.input_variables}
prompt_str = prompt.template
parts = re.split(r"(\{\{.*?\}\})", prompt_str)
formatted_parts = []
for part in parts:
if part.startswith("{{") and part.endswith("}}"):
part = part[2:-2].strip()
value = selected_inputs[part]
if isinstance(value, (Image.Image, list)):
formatted_parts.extend([value] if isinstance(value, Image.Image) else value)
else:
formatted_parts.append(str(value))
else:
formatted_parts.append(str(part))
formatted_parts = formatted_parts[0] if len(formatted_parts) == 1 else formatted_parts
if prompt.role == "system":
records.append(Message.system(formatted_parts))
elif prompt.role == "user":
records.append(Message.user(formatted_parts))
if len(images):
records.append(Message.user(images))
processed_prompts.append(records)
return processed_prompts

def infer(self, input_list: List[Dict[str, Any]], **kwargs) -> List[T]:
prompts = self.prep_prompt(input_list)
prompts = self.prep_prompt(input_list, **kwargs)
res = []
for prompt in prompts:
output = self.llm.generate(prompt, **kwargs)
36 changes: 0 additions & 36 deletions omagent-core/src/omagent_core/models/llms/openai_gpt.py
@@ -48,24 +48,6 @@ def _call(self, records: List[Message], **kwargs) -> Dict:
if self.api_key is None or self.api_key == "":
raise ValueError("api_key is required")

if self.stm(self.workflow_instance_id).get('image_cache') is not None and len(self.stm(self.workflow_instance_id)['image_cache']):
for record in records:
record.combine_image_message(
image_cache={
key: encode_image(value)
for key, value in self.stm(self.workflow_instance_id)['image_cache'].items()
}
)
elif len(kwargs.get("images", [])):
image_cache = {}
for index, each in enumerate(kwargs["images"]):
image_cache[f"<image_{index}>"] = each
for record in records:
record.combine_image_message(
image_cache={
key: encode_image(value) for key, value in image_cache.items()
}
)
body = self._msg2req(records)
if kwargs.get("tool_choice"):
body["tool_choice"] = kwargs["tool_choice"]
@@ -97,24 +79,6 @@ async def _acall(self, records: List[Message], **kwargs) -> Dict:
if self.api_key is None or self.api_key == "":
raise ValueError("api_key is required")

if self.stm(self.workflow_instance_id).get('image_cache') is not None and len(self.stm(self.workflow_instance_id)['image_cache']):
for record in records:
record.combine_image_message(
image_cache={
key: encode_image(value)
for key, value in self.stm.image_cache.items()
}
)
elif len(kwargs.get("images", [])):
image_cache = {}
for index, each in enumerate(kwargs["images"]):
image_cache[f"<image_{index}>"] = each
for record in records:
record.combine_image_message(
image_cache={
key: encode_image(value) for key, value in image_cache.items()
}
)
body = self._msg2req(records)
if kwargs.get("tool_choice"):
body["tool_choice"] = kwargs["tool_choice"]