diff --git a/docs/concepts/models/llms.md b/docs/concepts/models/llms.md
new file mode 100644
index 0000000..e4c27fe
--- /dev/null
+++ b/docs/concepts/models/llms.md
@@ -0,0 +1,58 @@
+# LLMs
+LLMs are the core components of OmAgent. They are responsible for generating text via Large Language Models.
+
+The LLM module is composed of the following parts:
+- `BaseLLM`: The base class for all LLMs; it defines the basic properties and methods shared by every LLM.
+- `BaseLLMBackend`: The enhanced class for working with LLMs; it lets you assemble a specific LLM with different prompt templates and output parsers.
+- `BasePromptTemplate`: The base class for all prompt templates; it defines the input variables and output parser of a prompt template.
+- `BaseOutputParser`: The base class for all output parsers; it defines how to parse the result returned by an LLM.
+
+## Prompt Template
+This is a simple way to define a prompt template.
+```python
+from omagent_core.models.llms.prompt.prompt import PromptTemplate
+
+# Define a system prompt template
+system_prompt = PromptTemplate.from_template("You are a helpful assistant.", role="system")
+# Define a user prompt template
+user_prompt = PromptTemplate.from_template("Tell me a joke about {{topic}}", role="user")
+```
+`topic` is a variable in the user prompt template; it will be replaced by the actual input value at inference time.
+
+## Output Parser
+This is a simple way to define an output parser.
+```python
+from omagent_core.models.llms.prompt.parser import StrParser
+
+output_parser = StrParser()
+```
+`StrParser` is a simple output parser that returns the output as a string.
+
+## Get LLM Result
+This is a simple way to define an LLM request and get its result.
+1. The worker class should inherit from `BaseWorker` and `BaseLLMBackend`, and define the prompt templates and the LLM model in the `prompts` and `llm` fields. The output parser is optional; if it is not defined, the default `StrParser` is used (a full worker sketch is shown at the end of this page).
+2. Override the `_run` method to define the workflow logic.
+```python
+def _run(self, *args, **kwargs):
+    payload = {
+        "topic": "weather"
+    }
+    # 1. Use the `infer` method to get the LLM result
+    chat_complete_res = self.infer(input_list=[payload])[0]["choices"][0]["message"].get("content")
+    # 2. Use the `simple_infer` method, a shortcut for `infer` that takes the prompt variables as keyword arguments
+    simple_infer_res = self.simple_infer(topic="weather")["choices"][0]["message"].get("content")
+    print(chat_complete_res)
+    return {'output': chat_complete_res}
+```
+
+For multi-modal LLMs, it is also simple and intuitive.
+```python
+def _run(self, *args, **kwargs):
+    # Prompt variables can mix text with PIL.Image.Image objects
+    payload = {
+        "topic": ["this image", pil_image]  # pil_image is a PIL.Image.Image instance
+    }
+    chat_complete_res = self.infer(input_list=[payload])[0]["choices"][0]["message"].get("content")
+    return {'output': chat_complete_res}
+```
+The content sent to the LLM follows the order of the elements in the variable list, resulting in an alternating pattern of text and images.
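+
+To make step 1 concrete, here is a minimal worker sketch. It is a sketch only: the class name, the `OpenaiGPTLLM` type, and the exact `BaseWorker` import path are illustrative and may need adjusting to your project layout.
+```python
+from typing import List
+
+from pydantic import Field
+
+# BaseWorker lives in the engine package; adjust the import to your project layout.
+from omagent_core.engine.worker.base import BaseWorker
+from omagent_core.models.llms.base import BaseLLMBackend
+from omagent_core.models.llms.openai_gpt import OpenaiGPTLLM
+from omagent_core.models.llms.prompt.prompt import PromptTemplate
+
+
+class JokeTeller(BaseLLMBackend, BaseWorker):
+    """Toy worker that asks the LLM for a joke about a given topic."""
+
+    # Prompts are rendered in the order given; the user template declares the {{topic}} variable.
+    prompts: List[PromptTemplate] = Field(
+        default=[
+            PromptTemplate.from_template("You are a helpful assistant.", role="system"),
+            PromptTemplate.from_template("Tell me a joke about {{topic}}", role="user"),
+        ]
+    )
+    llm: OpenaiGPTLLM
+
+    def _run(self, *args, **kwargs):
+        # `simple_infer` fills {{topic}} and returns the raw chat-completion response.
+        content = self.simple_infer(topic="weather")["choices"][0]["message"].get("content")
+        return {"output": content}
+```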
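+
+Images can also be attached through the `images` keyword argument of `infer`; `prep_prompt` appends them to the request as an extra user message after the rendered prompts. A minimal sketch (the image path is illustrative):
+```python
+from PIL import Image
+
+def _run(self, *args, **kwargs):
+    pil_image = Image.open("example.jpg")  # illustrative path
+    # `images` is forwarded to `prep_prompt`, which appends the images to the
+    # request as an extra user message after the rendered prompts.
+    chat_complete_res = self.infer(
+        input_list=[{"topic": "this picture"}],
+        images=[pil_image],
+    )[0]["choices"][0]["message"].get("content")
+    return {'output': chat_complete_res}
+```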
\ No newline at end of file diff --git a/examples/step2_outfit_with_switch/agent/outfit_recommendation/outfit_recommendation.py b/examples/step2_outfit_with_switch/agent/outfit_recommendation/outfit_recommendation.py index 1bcb2ad..3030e57 100644 --- a/examples/step2_outfit_with_switch/agent/outfit_recommendation/outfit_recommendation.py +++ b/examples/step2_outfit_with_switch/agent/outfit_recommendation/outfit_recommendation.py @@ -59,7 +59,8 @@ def _run(self, *args, **kwargs): search_info = self.stm(self.workflow_instance_id)["search_info"] if "search_info" in self.stm(self.workflow_instance_id) else None # Generate outfit recommendations using LLM with weather and user input - chat_complete_res = self.simple_infer(weather=str(search_info), instruction=user_instruct, image='') + image_cache = self.stm(self.workflow_instance_id)['image_cache'] + chat_complete_res = self.simple_infer(weather=str(search_info), instruction=user_instruct, image=image_cache.get('')) # Extract recommendations from LLM response outfit_recommendation = chat_complete_res["choices"][0]["message"]["content"] diff --git a/examples/video_understanding/agent/tools/video_rewinder/rewinder.py b/examples/video_understanding/agent/tools/video_rewinder/rewinder.py index c970eed..ddc2212 100755 --- a/examples/video_understanding/agent/tools/video_rewinder/rewinder.py +++ b/examples/video_understanding/agent/tools/video_rewinder/rewinder.py @@ -81,22 +81,14 @@ def _run( frames, time_stamps = video.get_video_frames((start, end), interval) # self.stm.image_cache.clear() - extracted_frames = [] + payload = [] for i, (frame, time_stamp) in enumerate(zip(frames, time_stamps)): - img_index = f"image_timestamp-{time_stamp}" - extracted_frames.append(time_stamp) - image_cache = self.stm(self.workflow_instance_id).get("image_cache", {}) - if image_cache.get(f"<{img_index}>", None) is None: - image_cache[f"<{img_index}>"] = frame - self.stm(self.workflow_instance_id)["image_cache"] = image_cache - res = self.simple_infer( - image_placeholders="".join( - [f"" for each in extracted_frames] - ) - )["choices"][0]["message"]["content"] + payload.append(f"timestamp_{time_stamp}") + payload.append(frame) + res = self.infer(input_list=[{"timestamp_with_images": payload}])[0]["choices"][0]["message"]["content"] image_contents = json_repair.loads(res) self.stm(self.workflow_instance_id)['image_cache'] = {} - return f"{extracted_frames} described as: {image_contents}." + return f"extracted_frames described as: {image_contents}." 
async def _arun( self, start_time: float = 0.0, end_time: float = None, number: int = 1 diff --git a/examples/video_understanding/agent/tools/video_rewinder/rewinder_user_prompt.prompt b/examples/video_understanding/agent/tools/video_rewinder/rewinder_user_prompt.prompt index f1e841d..99e9b8a 100644 --- a/examples/video_understanding/agent/tools/video_rewinder/rewinder_user_prompt.prompt +++ b/examples/video_understanding/agent/tools/video_rewinder/rewinder_user_prompt.prompt @@ -1 +1 @@ -Image: {{image_placeholders}} \ No newline at end of file +{{timestamp_with_images}} \ No newline at end of file diff --git a/omagent-core/src/omagent_core/advanced_components/node/conclude/conclude.py b/omagent-core/src/omagent_core/advanced_components/node/conclude/conclude.py index acfd58c..9368df4 100644 --- a/omagent-core/src/omagent_core/advanced_components/node/conclude/conclude.py +++ b/omagent-core/src/omagent_core/advanced_components/node/conclude/conclude.py @@ -52,7 +52,7 @@ def _run(self, agent_task: dict, last_output: str, *args, **kwargs): self.callback.info(agent_id=self.workflow_instance_id, progress=f'Conclude', message=f'{task.get_current_node().task}') chat_complete_res = self.simple_infer( task=task.get_root().task, - result=last_output, + result=str(last_output), img_placeholders="".join(list(self.stm(self.workflow_instance_id).get('image_cache', {}).keys())), ) self.callback.send_answer(agent_id=self.workflow_instance_id, msg=f'Answer: {chat_complete_res["choices"][0]["message"]["content"]}') diff --git a/omagent-core/src/omagent_core/models/llms/azure_gpt.py b/omagent-core/src/omagent_core/models/llms/azure_gpt.py index 1cd8ab1..f1ae79d 100644 --- a/omagent-core/src/omagent_core/models/llms/azure_gpt.py +++ b/omagent-core/src/omagent_core/models/llms/azure_gpt.py @@ -56,24 +56,6 @@ def _call(self, records: List[Message], **kwargs) -> Dict: if self.api_key is None or self.api_key == "": raise ValueError("api_key is required") - if self.stm(self.workflow_instance_id).get('image_cache') is not None and len(self.stm(self.workflow_instance_id)['image_cache']): - for record in records: - record.combine_image_message( - image_cache={ - key: encode_image(value) - for key, value in self.stm.image_cache.items() - } - ) - elif len(kwargs.get("images", [])): - image_cache = {} - for index, each in enumerate(kwargs["images"]): - image_cache[f""] = each - for record in records: - record.combine_image_message( - image_cache={ - key: encode_image(value) for key, value in image_cache.items() - } - ) body = self._msg2req(records) if kwargs.get("tool_choice"): body["tool_choice"] = kwargs["tool_choice"] @@ -104,24 +86,6 @@ async def _acall(self, records: List[Message], **kwargs) -> Dict: if self.api_key is None or self.api_key == "": raise ValueError("api_key is required") - if self.stm(self.workflow_instance_id).get('image_cache') is not None and len(self.stm(self.workflow_instance_id)['image_cache']): - for record in records: - record.combine_image_message( - image_cache={ - key: encode_image(value) - for key, value in self.stm.image_cache.items() - } - ) - elif len(kwargs.get("images", [])): - image_cache = {} - for index, each in enumerate(kwargs["images"]): - image_cache[f""] = each - for record in records: - record.combine_image_message( - image_cache={ - key: encode_image(value) for key, value in image_cache.items() - } - ) body = self._msg2req(records) if kwargs.get("tool_choice"): body["tool_choice"] = kwargs["tool_choice"] diff --git a/omagent-core/src/omagent_core/models/llms/base.py 
b/omagent-core/src/omagent_core/models/llms/base.py index 0d2abf9..1b3b79b 100644 --- a/omagent-core/src/omagent_core/models/llms/base.py +++ b/omagent-core/src/omagent_core/models/llms/base.py @@ -16,6 +16,9 @@ from .prompt.base import _OUTPUT_PARSER, StrParser from .prompt.parser import BaseOutputParser from .prompt.prompt import PromptTemplate +from collections.abc import Hashable +from PIL import Image +import re T = TypeVar("T", str, dict, list) @@ -149,23 +152,44 @@ def set_llm(cls, llm: Union[BaseLLM, Dict]): raise ValueError("LLM only support dict and BaseLLM object") def prep_prompt( - self, input_list: List[Dict[str, Any]], prompts=None + self, input_list: List[Dict[str, Any]], prompts=None, **kwargs ) -> List[Message]: """Prepare prompts from inputs.""" if prompts is None: prompts = self.prompts + images = [] + if len(kwargs_images:=kwargs.get("images", [])): + images = kwargs_images processed_prompts = [] for inputs in input_list: records = [] for prompt in prompts: - selected_inputs = {k: inputs[k] for k in prompt.input_variables} - prompt_str = prompt.format(**selected_inputs) - records.append(Message(content=prompt_str, role=prompt.role)) + selected_inputs = {k: inputs.get(k, '') for k in prompt.input_variables} + prompt_str = prompt.template + parts = re.split(r"(\{\{.*?\}\})", prompt_str) + formatted_parts = [] + for part in parts: + if part.startswith("{{") and part.endswith("}}"): + part = part[2:-2].strip() + value = selected_inputs[part] + if isinstance(value, (Image.Image, list)): + formatted_parts.extend([value] if isinstance(value, Image.Image) else value) + else: + formatted_parts.append(str(value)) + else: + formatted_parts.append(str(part)) + formatted_parts = formatted_parts[0] if len(formatted_parts) == 1 else formatted_parts + if prompt.role == "system": + records.append(Message.system(formatted_parts)) + elif prompt.role == "user": + records.append(Message.user(formatted_parts)) + if len(images): + records.append(Message.user(images)) processed_prompts.append(records) return processed_prompts def infer(self, input_list: List[Dict[str, Any]], **kwargs) -> List[T]: - prompts = self.prep_prompt(input_list) + prompts = self.prep_prompt(input_list, **kwargs) res = [] for prompt in prompts: output = self.llm.generate(prompt, **kwargs) diff --git a/omagent-core/src/omagent_core/models/llms/openai_gpt.py b/omagent-core/src/omagent_core/models/llms/openai_gpt.py index 7391bff..fa91d8e 100644 --- a/omagent-core/src/omagent_core/models/llms/openai_gpt.py +++ b/omagent-core/src/omagent_core/models/llms/openai_gpt.py @@ -48,24 +48,6 @@ def _call(self, records: List[Message], **kwargs) -> Dict: if self.api_key is None or self.api_key == "": raise ValueError("api_key is required") - if self.stm(self.workflow_instance_id).get('image_cache') is not None and len(self.stm(self.workflow_instance_id)['image_cache']): - for record in records: - record.combine_image_message( - image_cache={ - key: encode_image(value) - for key, value in self.stm(self.workflow_instance_id)['image_cache'].items() - } - ) - elif len(kwargs.get("images", [])): - image_cache = {} - for index, each in enumerate(kwargs["images"]): - image_cache[f""] = each - for record in records: - record.combine_image_message( - image_cache={ - key: encode_image(value) for key, value in image_cache.items() - } - ) body = self._msg2req(records) if kwargs.get("tool_choice"): body["tool_choice"] = kwargs["tool_choice"] @@ -97,24 +79,6 @@ async def _acall(self, records: List[Message], **kwargs) -> Dict: if self.api_key is 
None or self.api_key == "": raise ValueError("api_key is required") - if self.stm(self.workflow_instance_id).get('image_cache') is not None and len(self.stm(self.workflow_instance_id)['image_cache']): - for record in records: - record.combine_image_message( - image_cache={ - key: encode_image(value) - for key, value in self.stm.image_cache.items() - } - ) - elif len(kwargs.get("images", [])): - image_cache = {} - for index, each in enumerate(kwargs["images"]): - image_cache[f""] = each - for record in records: - record.combine_image_message( - image_cache={ - key: encode_image(value) for key, value in image_cache.items() - } - ) body = self._msg2req(records) if kwargs.get("tool_choice"): body["tool_choice"] = kwargs["tool_choice"] diff --git a/omagent-core/src/omagent_core/models/llms/schemas.py b/omagent-core/src/omagent_core/models/llms/schemas.py index 1d2720c..7260549 100644 --- a/omagent-core/src/omagent_core/models/llms/schemas.py +++ b/omagent-core/src/omagent_core/models/llms/schemas.py @@ -1,8 +1,12 @@ import re -from typing import Dict, List, Optional +from typing import Dict, List, Optional, ClassVar from ..od.schemas import Target - +from itertools import groupby from pydantic import BaseModel, field_validator, model_validator +from PIL import Image +import datetime +import time +from ...utils.general import encode_image from enum import Enum @@ -33,7 +37,6 @@ def validate_detail(cls, detail: str) -> str: ) return detail - class Content(BaseModel): type: str = "text" text: Optional[str] = None @@ -74,26 +77,55 @@ class Message(BaseModel): content: List[Content | Dict] | Content | str objects: List[Target] = [] kwargs: dict = {} + basic_data_types: ClassVar[List[type]] = [str, list, tuple, int, float, bool, datetime.datetime, datetime.time] + + @classmethod + def merge_consecutive_text(cls, content) -> List: + result = [] + current_str = "" + + for part in content: + if isinstance(part, str): + current_str += part + else: + if current_str: + result.append(current_str) + current_str = "" + result.append(part) + + if current_str: # 处理最后的字符串 + result.append(current_str) + + return result @field_validator("content", mode="before") @classmethod def content_validator( cls, content: List[Content | Dict] | Content | str - ) -> List[Content] | Content: + ) -> List[Content] | Content: if isinstance(content, str): return Content(type="text", text=content) elif isinstance(content, list): + # combine str elements in list + content = cls.merge_consecutive_text(content) formatted = [] for c in content: + if not c: + continue if isinstance(c, Content): formatted.append(c) elif isinstance(c, dict): - formatted.append(Content(**c)) - elif isinstance(c, str): - formatted.append(Content(type="text", text=c)) + try: + formatted.append(Content(**c)) + except Exception as e: + formatted.append(Content(type="text", text=str(c))) + elif isinstance(c, Image.Image): + formatted.append(Content(type="image_url", image_url={"url": f"data:image/jpeg;base64,{encode_image(c)}"})) + elif isinstance(c, tuple(cls.basic_data_types)): + formatted.append(Content(type="text", text=str(c))) else: raise ValueError( - "Content list must contain Content objects, strings or dicts." 
+ f"Content list must contain [Content, str, list, dict, PIL.Image], got {type(c)}" ) else: raise ValueError( @@ -101,76 +133,14 @@ def content_validator( ) return formatted - def combine_image_message(self, **kwargs): - if isinstance(self.content, list): - for index, each_content in enumerate(self.content): - if each_content.text is not None: - image_patterns = [ - f"" - for each in re.findall(r"", each_content.text) - ] - # set max_num to 20 - image_patterns = image_patterns[-min(20, len(image_patterns)) :] - else: - image_patterns = [] - if image_patterns: - image_cache = kwargs.get("image_cache") - if len(image_cache) < len(image_patterns): - raise ValueError("Image number is not enough. Please check.") - segments = re.split( - "({})".format("|".join(map(re.escape, image_patterns))), - each_content.text, - ) - segments = [each for each in segments if each.strip()] - modified_content = [] - for segment in segments: - if segment in image_patterns: - modified_content.append( - Content( - type="image_url", - image_url={ - "url": f"data:image/jpeg;base64,{image_cache[segment]}" - }, - ) - ) - else: - modified_content.append(Content(type="text", text=segment)) - self.content[index] = modified_content - self.message_type = MessageType.MIXED - else: - image_patterns = [ - f"" - for each in re.findall(r"", self.content.text) - ] - # set max_num to 20 - image_patterns = image_patterns[-min(20, len(image_patterns)) :] - if image_patterns: - image_cache = kwargs.get("image_cache") - if len(image_cache) < len(image_patterns): - raise ValueError("Image number is not enough. Please check.") - segments = re.split( - "({})".format("|".join(map(re.escape, image_patterns))), - self.content.text, - ) - segments = [each for each in segments if each.strip()] - modified_content = [] - for segment in segments: - if segment in image_patterns: - modified_content.append( - Content( - type="text", - text=f"Image of {re.match(r'', segment).group(1)}", - ) - ) - modified_content.append( - Content( - type="image_url", - image_url={ - "url": f"data:image/jpeg;base64,{image_cache[segment]}" - }, - ) - ) - else: - modified_content.append(Content(type="text", text=segment)) - self.content = modified_content - self.message_type = MessageType.MIXED \ No newline at end of file + @classmethod + def system(cls, content: str | List[str | Dict | Content]) -> "Message": + return cls(role=Role.SYSTEM, content=content) + + @classmethod + def user(cls, content: str | List[str | Dict | Content]) -> "Message": + return cls(role=Role.USER, content=content) + + @classmethod + def assistant(cls, content: str | List[str | Dict | Content]) -> "Message": + return cls(role=Role.ASSISTANT, content=content) \ No newline at end of file