Skip to content

Commit

Permalink
Code format optimization
Browse files Browse the repository at this point in the history
  • Loading branch information
panregedit committed Sep 11, 2024
1 parent aad7024 commit 8127b61
Show file tree
Hide file tree
Showing 18 changed files with 130 additions and 72 deletions.
17 changes: 12 additions & 5 deletions engine/node/video_qa/conqueror/conqueror.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,12 @@
from omagent_core.utils.env import EnvVar
from omagent_core.utils.registry import registry
from pydantic import Field
from tenacity import (retry, retry_if_exception_message, stop_after_attempt,
stop_after_delay)
from tenacity import (
retry,
retry_if_exception_message,
stop_after_attempt,
stop_after_delay,
)

CURRENT_PATH = Path(__file__).parents[0]

Expand Down Expand Up @@ -66,7 +70,10 @@ def _run(self, args: BaseInterface, ltm: LTM) -> Tuple[BaseInterface, str]:
}
self.callback.send_block(chat_structure)
chat_complete_body = {
"video_meta": {"video_duration_seconds(s)": self.stm.video.stream.duration.get_seconds(), "frame_rate": self.stm.video.stream.frame_rate},
"video_meta": {
"video_duration_seconds(s)": self.stm.video.stream.duration.get_seconds(),
"frame_rate": self.stm.video.stream.frame_rate,
},
"video_summary": self.stm.video_summary,
"task": task.task,
"tools": self.tool_manager.generate_prompt(),
Expand Down Expand Up @@ -293,8 +300,8 @@ async def _arun(self, args: BaseInterface, ltm: LTM) -> Tuple[BaseInterface, str

def _extract_from_result(self, result: str) -> dict:
try:
pattern = r'```json\s*(\{(?:.|\s)*?\})\s*```'
result = result.replace('\n', '')
pattern = r"```json\s*(\{(?:.|\s)*?\})\s*```"
result = result.replace("\n", "")
match = re.search(pattern, result, re.DOTALL)
if match:
return json.loads(match.group(1))
Expand Down
12 changes: 8 additions & 4 deletions engine/node/video_qa/divider/divider.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,12 @@
from omagent_core.utils.env import EnvVar
from omagent_core.utils.registry import registry
from pydantic import Field
from tenacity import (retry, retry_if_exception_message, stop_after_attempt,
stop_after_delay)
from tenacity import (
retry,
retry_if_exception_message,
stop_after_attempt,
stop_after_delay,
)

CURRENT_PATH = Path(__file__).parents[0]

Expand Down Expand Up @@ -121,8 +125,8 @@ async def _arun(self, args: DnCInterface, ltm: LTM) -> Tuple[DnCInterface, str]:

def _extract_from_result(self, result: str) -> dict:
try:
pattern = r'```json\s*(\{(?:.|\s)*?\})\s*```'
result = result.replace('\n', '')
pattern = r"```json\s*(\{(?:.|\s)*?\})\s*```"
result = result.replace("\n", "")
match = re.search(pattern, result, re.DOTALL)
if match:
return json.loads(match.group(1))
Expand Down
4 changes: 2 additions & 2 deletions engine/tools/face_rec.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,6 @@ def _run(self):
names.update([item.label for item in anno])

return f"Recognized {len(names)} faces: {', '.join(names)}"

async def _arun(self):
return self._run()
return self._run()
90 changes: 57 additions & 33 deletions engine/tools/ovd_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@
from typing import List

import requests
from PIL import Image

from omagent_core.core.tool_system.base import ArgSchema, BaseModelTool
from omagent_core.schemas.dev import Target
from omagent_core.utils.general import encode_image
from omagent_core.utils.registry import registry
from PIL import Image
from scenedetect import FrameTimecode

CURRENT_PATH = Path(__file__).parents[0]
Expand All @@ -21,78 +20,103 @@
"labels": {
"type": "string",
"description": "Labels the object detection tool will use to detect objects in the image, split by comma.",
"required": True
"required": True,
},
}


@registry.register_tool()
class ObjectDetection(BaseModelTool):
args_schema: ArgSchema = ArgSchema(**ARGSCHEMA)
description: str = ("Object detection tool, which can detect any objects and add visual prompting(bounding box and label) to the image."
"Tasks like object counting, specific object detection, etc. must use this tool.")
description: str = (
"Object detection tool, which can detect any objects and add visual prompting(bounding box and label) to the image."
"Tasks like object counting, specific object detection, etc. must use this tool."
)
ovd_endpoint: str
model_id: str = 'OmDet-Turbo_tiny_SWIN_T'
model_id: str = "OmDet-Turbo_tiny_SWIN_T"

class Config:
"""Configuration for this pydantic object."""

protected_namespaces = ()

def _run(
self, timestamps: str, labels: str
) -> str:
def _run(self, timestamps: str, labels: str) -> str:
if self.ovd_endpoint is None or self.ovd_endpoint == "":
raise ValueError("ovd_endpoint is required.")
timestamps = timestamps.split(',')
timestamps = timestamps.split(",")
imgs_pil = []
for each_time_stamp in timestamps:
if self.stm.image_cache.get(f'<image_timestamp-{float(each_time_stamp)}>', None) is None:
if (
self.stm.image_cache.get(
f"<image_timestamp-{float(each_time_stamp)}>", None
)
is None
):
frames, time_stamps = self.stm.video.get_video_frames(
(FrameTimecode(timecode=float(each_time_stamp), fps=self.stm.video.stream.frame_rate),
FrameTimecode(timecode=float(each_time_stamp) + 1, fps=self.stm.video.stream.frame_rate)), self.stm.video.stream.frame_rate)
[self.stm.image_cache.update({f'<image_timestamp-{each_img_name}>': each_frame}) for each_frame, each_img_name in zip(frames, time_stamps)]
(
FrameTimecode(
timecode=float(each_time_stamp),
fps=self.stm.video.stream.frame_rate,
),
FrameTimecode(
timecode=float(each_time_stamp) + 1,
fps=self.stm.video.stream.frame_rate,
),
),
self.stm.video.stream.frame_rate,
)
[
self.stm.image_cache.update(
{f"<image_timestamp-{each_img_name}>": each_frame}
)
for each_frame, each_img_name in zip(frames, time_stamps)
]
# timestamps = [f'<image_timestamp-{each_img_name}' for each_img_name in time_stamps]
imgs_pil = [each_frame for each_frame in frames]
else:
imgs_pil.append(self.stm.image_cache[f'<image_timestamp-{float(each_time_stamp)}>'])
imgs_pil.append(
self.stm.image_cache[f"<image_timestamp-{float(each_time_stamp)}>"]
)

infer_targets = self.infer(imgs_pil, {"labels": labels.split(',')})
infer_targets = self.infer(imgs_pil, {"labels": labels.split(",")})
for img_name, img, infer_target in zip(timestamps, imgs_pil, infer_targets):
self.stm.image_cache[f'<image_timestamp-{img_name}>'] = self.visual_prompting(img, infer_target)
self.stm.image_cache[f"<image_timestamp-{img_name}>"] = (
self.visual_prompting(img, infer_target)
)

return f"OVD tool has detected objects in timestamps of {timestamps} and update image."

async def _arun(
self, timestamps: str, labels: str
) -> str:
async def _arun(self, timestamps: str, labels: str) -> str:
return self._run(timestamps, labels)

def infer(self, images: List[Image.Image], kwargs) -> List[List[Target]]:
labels = kwargs.get("labels", [])
ovd_payload = {
"model_id": self.model_id,
"data": [
encode_image(img) for img in images
],
"data": [encode_image(img) for img in images],
"src_type": "base64",
"task": f"Detect {','.join(labels)}.",
"labels": labels,
"threshold": 0.3,
}
res = requests.post(self.ovd_endpoint, json=ovd_payload)
if res.status_code != 200:
raise ValueError(f"OVD tool failed to detect objects in the images. {res.text}")
raise ValueError(
f"OVD tool failed to detect objects in the images. {res.text}"
)
res = res.json()
targets = []
for img in res['objects']:
for img in res["objects"]:
current_img_targets = []
for bbox in img:
x0 = bbox['xmin']
y0 = bbox['ymin']
x1 = bbox['xmax']
y1 = bbox['ymax']
conf = bbox['conf']
label = bbox['label']
current_img_targets.append(Target(bbox=[x0, y0, x1, y1], conf=conf, label=label))
x0 = bbox["xmin"]
y0 = bbox["ymin"]
x1 = bbox["xmax"]
y1 = bbox["ymax"]
conf = bbox["conf"]
label = bbox["label"]
current_img_targets.append(
Target(bbox=[x0, y0, x1, y1], conf=conf, label=label)
)
targets.append(current_img_targets)
return targets
10 changes: 8 additions & 2 deletions engine/tools/rewinder.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,9 @@ def _run(
end = FrameTimecode(timecode=end_time, fps=video.stream.frame_rate)

if start_time == end_time:
frames, time_stamps = video.get_video_frames((start, end + 1), video.stream.frame_rate)
frames, time_stamps = video.get_video_frames(
(start, end + 1), video.stream.frame_rate
)
else:
interval = int((end.get_frames() - start.get_frames()) / number)
frames, time_stamps = video.get_video_frames((start, end), interval)
Expand All @@ -95,7 +97,11 @@ def _run(
extracted_frames.append(time_stamp)
if self.stm.image_cache.get(f"<{img_index}>", None) is None:
self.stm.image_cache[f"<{img_index}>"] = frame
res = self.simple_infer(image_placeholders=''.join([f'<image_timestamp-{each}>' for each in extracted_frames]))['choices'][0]['message']['content']
res = self.simple_infer(
image_placeholders="".join(
[f"<image_timestamp-{each}>" for each in extracted_frames]
)
)["choices"][0]["message"]["content"]
image_contents = self._extract_from_result(res)
return f"{extracted_frames} described as: {image_contents}."

Expand Down
2 changes: 1 addition & 1 deletion engine/video_process/stt.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,4 @@ async def ainfer(self, audio: AudioSegment) -> dict:
response_format=self.response_format,
language=NOT_GIVEN if self.lang is None else self.lang,
)
return trans.to_dict()
return trans.to_dict()
6 changes: 3 additions & 3 deletions omagent-core/src/omagent_core/core/llm/azure_gpt.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import sysconfig
from datetime import datetime
from typing import Dict, List, Any
from typing import Any, Dict, List

import geocoder
from openai import AsyncAzureOpenAI, AzureOpenAI
Expand Down Expand Up @@ -32,13 +32,13 @@ class AzureGPTLLM(BaseLLM):
max_tokens: int = 2048
use_default_sys_prompt: bool = True
response_format: str = "text"

class Config:
"""Configuration for this pydantic object."""

protected_namespaces = ()
extra = "allow"

def __init__(self, /, **data: Any) -> None:
super().__init__(**data)
self.client = AzureOpenAI(
Expand Down
6 changes: 3 additions & 3 deletions omagent-core/src/omagent_core/core/llm/openai_gpt.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import sysconfig
from datetime import datetime
from typing import Dict, List, Any
from typing import Any, Dict, List

import geocoder
from openai import AsyncOpenAI, OpenAI
Expand Down Expand Up @@ -31,13 +31,13 @@ class OpenaiGPTLLM(BaseLLM):
max_tokens: int = 2048
use_default_sys_prompt: bool = True
response_format: str = "text"

class Config:
"""Configuration for this pydantic object."""

protected_namespaces = ()
extra = "allow"

def __init__(self, /, **data: Any) -> None:
super().__init__(**data)
self.client = OpenAI(api_key=self.api_key, base_url=self.endpoint)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,12 @@

from colorama import Fore, Style
from pydantic import Field
from tenacity import (retry, retry_if_exception_message, stop_after_attempt,
stop_after_delay)
from tenacity import (
retry,
retry_if_exception_message,
stop_after_attempt,
stop_after_delay,
)

from .....handlers.data_handler.ltm import LTM
from .....schemas.base import BaseInterface
Expand Down
10 changes: 7 additions & 3 deletions omagent-core/src/omagent_core/core/node/dnc/divider/divider.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,12 @@
from typing import List, Tuple

from pydantic import Field
from tenacity import (retry, retry_if_exception_message, stop_after_attempt,
stop_after_delay)
from tenacity import (
retry,
retry_if_exception_message,
stop_after_attempt,
stop_after_delay,
)

from .....handlers.data_handler.ltm import LTM
from .....utils.env import EnvVar
Expand Down Expand Up @@ -57,7 +61,7 @@ def _run(self, args: DnCInterface, ltm: LTM) -> Tuple[DnCInterface, str]:
parent_task=task.task,
uplevel_tasks=task.parent.sibling_info() if task.parent else [],
former_results=args.last_output,
tools=self.tool_manager.generate_prompt()
tools=self.tool_manager.generate_prompt(),
)
chat_complete_res = self._extract_from_result(chat_complete_res)
if chat_complete_res.get("tasks"):
Expand Down
10 changes: 7 additions & 3 deletions omagent-core/src/omagent_core/core/node/misc/rescue.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,12 @@
from colorama import Fore, Style
from omagent_core.core.node.dnc.schemas import TaskStatus
from pydantic import Field
from tenacity import (retry, retry_if_exception_message, stop_after_attempt,
stop_after_delay)
from tenacity import (
retry,
retry_if_exception_message,
stop_after_attempt,
stop_after_delay,
)

from ....core.llm.base import BaseLLMBackend
from ....core.node.base import BaseDecider
Expand Down Expand Up @@ -112,4 +116,4 @@ async def _arun(self, args: DnCInterface, ltm: LTM) -> Tuple[DnCInterface, str]:
args.task.status = TaskStatus.RUNNING
return args, "failure"
else:
return args, "failure"
return args, "failure"
8 changes: 6 additions & 2 deletions omagent-core/src/omagent_core/core/prompt/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,12 @@
from pydantic import model_validator

from ...utils.registry import registry
from .base import (DEFAULT_FORMATTER_MAPPING, BasePromptTemplate,
_get_jinja2_variables_from_template, check_valid_template)
from .base import (
DEFAULT_FORMATTER_MAPPING,
BasePromptTemplate,
_get_jinja2_variables_from_template,
check_valid_template,
)


@registry.register_prompt()
Expand Down
2 changes: 1 addition & 1 deletion omagent-core/src/omagent_core/core/tool_system/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,4 +294,4 @@ def _run(self):
self.memory_handler.execute_sql()

async def _arun(self):
self.memory_handler.execute_sql()
self.memory_handler.execute_sql()
1 change: 0 additions & 1 deletion omagent-core/src/omagent_core/core/tool_system/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,4 +356,3 @@ async def aexecute_task(self, task, related_info=None, function=None):
}
self.callback.send_block(toolcall_failed_structure)
return "failed", str(error)

1 change: 1 addition & 0 deletions omagent-core/src/omagent_core/schemas/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from abc import ABC

# from sqlalchemy import Column, Text
from datetime import datetime
from typing import Any, Dict, Optional
Expand Down
Loading

0 comments on commit 8127b61

Please sign in to comment.