Skip to content

Commit

Permalink
Replicate image api (#69)
Browse files Browse the repository at this point in the history
* minor tweaks

* Add a tool to call image models like flux and stable diffusion, through replicate.com

* Fix directory name in example code snippet
  • Loading branch information
ZmeiGorynych authored Aug 23, 2024
1 parent 087b94a commit 48c983c
Show file tree
Hide file tree
Showing 7 changed files with 146 additions and 43 deletions.
2 changes: 2 additions & 0 deletions motleycrew/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
from .abstract_parent import MotleyAgentAbstractParent
from .output_handler import MotleyOutputHandler
from .parent import MotleyAgentParent
from .langchain import LangchainMotleyAgent

__all__ = [
"MotleyAgentAbstractParent",
"MotleyAgentParent",
"MotleyOutputHandler",
"LangchainMotleyAgent",
]
3 changes: 2 additions & 1 deletion motleycrew/agents/langchain/langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from langchain_core.chat_history import InMemoryChatMessageHistory
from langchain_core.runnables import RunnableConfig
from langchain_core.runnables.history import RunnableWithMessageHistory, GetSessionHistoryCallable
from langchain_core.prompts.chat import ChatPromptTemplate

from motleycrew.agents.mixins import LangchainOutputHandlingAgentMixin
from motleycrew.agents.parent import MotleyAgentParent
Expand All @@ -21,7 +22,7 @@ def __init__(
self,
description: str | None = None,
name: str | None = None,
prompt_prefix: str | None = None,
prompt_prefix: str | ChatPromptTemplate | None = None,
agent_factory: MotleyAgentFactory[AgentExecutor] | None = None,
tools: Sequence[MotleySupportedTool] | None = None,
output_handler: MotleySupportedTool | None = None,
Expand Down
2 changes: 1 addition & 1 deletion motleycrew/agents/langchain/tool_calling_react.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def __init__(
tools: Sequence[MotleySupportedTool],
description: str | None = None,
name: str | None = None,
prompt_prefix: str | None = None,
prompt_prefix: str | ChatPromptTemplate | None = None,
prompt: ChatPromptTemplate | None = None,
chat_history: bool | GetSessionHistoryCallable = True,
output_handler: MotleySupportedTool | None = None,
Expand Down
2 changes: 1 addition & 1 deletion motleycrew/agents/parent.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class MotleyAgentParent(MotleyAgentAbstractParent, ABC):

def __init__(
self,
prompt_prefix: str | None = None,
prompt_prefix: str | ChatPromptTemplate | None = None,
description: str | None = None,
name: str | None = None,
agent_factory: MotleyAgentFactory | None = None,
Expand Down
42 changes: 2 additions & 40 deletions motleycrew/tools/image/dall_e.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
import mimetypes
import os
from typing import Optional

import requests
from langchain.agents import Tool
from langchain.prompts import PromptTemplate
from langchain_community.utilities.dalle_image_generator import DallEAPIWrapper
Expand All @@ -12,34 +9,9 @@
from motleycrew.common import LLMFramework
from motleycrew.common import logger
from motleycrew.common.llms import init_llm
from motleycrew.tools.image.download_image import download_url_to_directory
from motleycrew.tools.tool import MotleyTool


def download_image(url: str, file_path: str) -> Optional[str]:
    """Download the image at ``url`` and write it to ``file_path`` plus an extension.

    The extension is guessed from the response's ``Content-Type`` header;
    ``.png`` is used when nothing can be guessed.

    :param url: URL of the image to download
    :param file_path: destination path, without a file extension
    :return: path of the written file (including the extension), or ``None``
        when the server did not answer with an OK status code
    """
    # The context manager releases the streamed connection on every path.
    with requests.get(url, stream=True) as response:
        if response.status_code != requests.codes.ok:
            logger.error("Failed to download image. Status code: %s", response.status_code)
            return None

        # mimetypes only recognises bare types, so drop parameters such as
        # "; charset=binary" before guessing.
        raw_content_type = response.headers.get("content-type") or ""
        content_type = raw_content_type.split(";", 1)[0].strip()
        extension = mimetypes.guess_extension(content_type) or ".png"

        file_path_with_extension = file_path + extension
        logger.info("Downloading image %s to %s", url, file_path_with_extension)

        with open(file_path_with_extension, "wb") as f:
            # Iterating the response directly yields 128-byte chunks; ask for
            # larger ones to avoid thousands of tiny writes.
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)

        return file_path_with_extension


DEFAULT_REFINE_PROMPT = """Generate a detailed DALL-E prompt to generate an image
based on the following description:
```{text}```
Expand Down Expand Up @@ -133,17 +105,7 @@ def run_dalle_and_save_images(
return

if images_directory:
os.makedirs(images_directory, exist_ok=True)
file_paths = []
for url in urls:
file_name = motley_utils.generate_hex_hash(url, length=file_name_length)
file_path = os.path.join(images_directory, file_name)

file_path_with_extension = download_image(url=url, file_path=file_path).replace(
os.sep, "/"
)
file_paths.append(file_path_with_extension)
return file_paths
return [download_url_to_directory(url, images_directory, file_name_length) for url in urls]
else:
logger.info("Images directory is not provided, returning URLs")
return urls
Expand Down
41 changes: 41 additions & 0 deletions motleycrew/tools/image/download_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import mimetypes
import os
from typing import Optional

import requests

from motleycrew.common import logger, utils as motley_utils


def download_image(url: str, file_path: str) -> Optional[str]:
    """Download the image at ``url`` and write it to ``file_path`` plus an extension.

    The extension is guessed from the response's ``Content-Type`` header;
    ``.png`` is used when nothing can be guessed.

    :param url: URL of the image to download
    :param file_path: destination path, without a file extension
    :return: path of the written file (including the extension), or ``None``
        when the server did not answer with an OK status code
    """
    # The context manager releases the streamed connection on every path.
    with requests.get(url, stream=True) as response:
        if response.status_code != requests.codes.ok:
            logger.error("Failed to download image. Status code: %s", response.status_code)
            return None

        # mimetypes only recognises bare types, so drop parameters such as
        # "; charset=binary" before guessing.
        raw_content_type = response.headers.get("content-type") or ""
        content_type = raw_content_type.split(";", 1)[0].strip()
        extension = mimetypes.guess_extension(content_type) or ".png"

        file_path_with_extension = file_path + extension
        logger.info("Downloading image %s to %s", url, file_path_with_extension)

        with open(file_path_with_extension, "wb") as f:
            # Iterating the response directly yields 128-byte chunks; ask for
            # larger ones to avoid thousands of tiny writes.
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)

        return file_path_with_extension


def download_url_to_directory(url: str, images_directory: str, file_name_length: int = 8) -> str:
    """Download ``url`` into ``images_directory`` and return the saved file's path.

    The directory is created if it does not exist. The file name is produced
    by ``motley_utils.generate_hex_hash`` from the URL, and the extension is
    chosen by ``download_image``.

    :param url: URL of the image to download
    :param images_directory: directory to store the image in
    :param file_name_length: ``length`` argument passed to ``generate_hex_hash``
    :return: path of the saved file, with OS path separators replaced by ``/``
    :raises RuntimeError: if the download failed
    """
    os.makedirs(images_directory, exist_ok=True)
    file_name = motley_utils.generate_hex_hash(url, length=file_name_length)
    file_path = os.path.join(images_directory, file_name)

    saved_path = download_image(url=url, file_path=file_path)
    if saved_path is None:
        # download_image signals failure with None; fail with a clear message
        # instead of an AttributeError from calling .replace on None.
        raise RuntimeError(f"Failed to download image from {url}")
    return saved_path.replace(os.sep, "/")
97 changes: 97 additions & 0 deletions motleycrew/tools/image/replicate_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import re
from typing import Optional, List

import replicate

from langchain.agents import Tool
from langchain_core.pydantic_v1 import BaseModel, Field

import motleycrew.common.utils as motley_utils
from motleycrew.tools.image.download_image import download_url_to_directory
from motleycrew.tools.tool import MotleyTool
from motleycrew.common import logger

# Short aliases mapped to the model identifiers passed to replicate.run().
# "sdxl" carries an explicit ":<version hash>" suffix; the flux entries do not.
# Names missing from this map are passed to replicate.run() unchanged.
model_map = {
"sdxl": "stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b",
"flux-pro": "black-forest-labs/flux-pro",
"flux-dev": "black-forest-labs/flux-dev",
"flux-schnell": "black-forest-labs/flux-schnell",
}

# Each model has a different set of extra parameters, documented at pages like
# https://replicate.com/black-forest-labs/flux-dev/api/schema


def run_model_in_replicate(model_name: str, prompt: str, **kwargs) -> str | List[str]:
    """Run a model on Replicate and return whatever ``replicate.run`` returns.

    :param model_name: an alias from ``model_map`` or a model name passed through as-is
    :param prompt: text prompt, sent as the ``prompt`` input field
    :param kwargs: extra model-specific input fields, merged into the input dict
    :return: the output of ``replicate.run``
    """
    # Aliases are translated; anything else is used verbatim.
    resolved_model = model_map.get(model_name, model_name)
    model_input = {"prompt": prompt, **kwargs}
    return replicate.run(resolved_model, input=model_input)


def run_model_in_replicate_and_save_images(
    model_name: str, prompt: str, directory_name: Optional[str] = None, **kwargs
) -> List[str]:
    """Run a Replicate model and optionally download the resulting images.

    :param model_name: an alias from ``model_map`` or a model name passed through as-is
    :param prompt: text prompt for the model
    :param directory_name: directory to download the images into; when ``None``
        the model output is returned without downloading anything
    :param kwargs: extra model-specific input fields
    :return: the model output (wrapped in a list if it was a single string) when
        ``directory_name`` is ``None``, otherwise the paths of the downloaded files
    """
    download_urls = run_model_in_replicate(model_name, prompt, **kwargs)
    if isinstance(download_urls, str):
        # Single-output models return one string; normalise to a list.
        download_urls = [download_urls]

    if directory_name is None:
        logger.info("Images directory is not provided, returning URLs")
        return download_urls

    out_files = []
    for url in download_urls:
        if motley_utils.is_http_url(url):
            out_files.append(download_url_to_directory(url, directory_name))
        else:
            # Previously such outputs vanished without a trace, leaving the
            # caller with a shorter (possibly empty) list and no explanation.
            logger.warning("Skipping output of %s that is not an HTTP URL: %r", model_name, url)
    return out_files


class ImageToolInput(BaseModel):
    """Input schema for the Replicate image generator tool."""

    description: str = Field(description="image description")


class ReplicateImageGeneratorTool(MotleyTool):
    def __init__(self, model_name: str, images_directory: Optional[str] = None, **kwargs):
        """
        Create a tool that turns text descriptions into images through the Replicate API.

        :param model_name: "sdxl", "flux-pro", "flux-dev", "flux-schnell", or a full
            model name supported by replicate
        :param images_directory: directory the generated images are saved to
        :param kwargs: model-specific parameters, see pages such as
            https://replicate.com/black-forest-labs/flux-dev/api/schema
        """
        # Keep the configuration on the instance before delegating to the parent.
        self.model_name = model_name
        self.kwargs = kwargs

        super().__init__(
            create_replicate_image_generator_langchain_tool(
                model_name=model_name, images_directory=images_directory, **kwargs
            )
        )


def _replicate_tool_name(model_name: str, suffix: str = "_image_generator") -> str:
    """Derive a tool name from ``model_name`` that function-calling APIs accept."""
    # Full Replicate identifiers look like "owner/model:<version hash>"; the hash
    # adds nothing for the LLM and would blow the length limit.
    base = model_name.split(":", 1)[0]
    # Tool names may only contain letters, digits, underscores and dashes.
    base = re.sub(r"[^a-zA-Z0-9_-]", "_", base)
    # Keep the whole name within the 64-character limit.
    return base[: 64 - len(suffix)] + suffix


def create_replicate_image_generator_langchain_tool(
    model_name: str, images_directory: Optional[str] = None, **kwargs
):
    """Build a langchain ``Tool`` that generates images with a Replicate model.

    :param model_name: an alias from ``model_map`` or a full Replicate model name
    :param images_directory: directory to save generated images to; passed on as
        ``directory_name`` to ``run_model_in_replicate_and_save_images``
    :param kwargs: model-specific parameters forwarded on every call
    :return: a ``Tool`` whose single input is an image description
    """

    def run_replicate_image_generator(description: str):
        return run_model_in_replicate_and_save_images(
            model_name=model_name,
            prompt=description,
            directory_name=images_directory,
            **kwargs,
        )

    return Tool(
        # Short aliases keep their previous name ("flux-pro_image_generator");
        # full identifiers are sanitised so the name stays valid.
        name=_replicate_tool_name(model_name),
        func=run_replicate_image_generator,
        description=f"A wrapper around the {model_name} image generation model. Useful for when you need to generate images from a text description. "
        "Input should be an image description.",
        args_schema=ImageToolInput,
    )


if __name__ == "__main__":
    # Only this demo needs `os`; without this import the demo fails with
    # NameError on the first line.
    import os

    # Demo: generate an image with flux-pro and save it under ~/images.
    image_dir = os.path.join(os.path.expanduser("~"), "images")
    tool = ReplicateImageGeneratorTool("flux-pro", image_dir, aspect_ratio="3:2")
    output = tool.invoke(
        "A beautiful sunset over the mountains, with a dragon flying into the sunset, photorealistic style."
    )
    print(output)
    print("yay!")

0 comments on commit 48c983c

Please sign in to comment.