-
Notifications
You must be signed in to change notification settings - Fork 35
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'prompt_prefix_as_messages' of github.com:ShoggothAI/mot…
…leycrew into prompt_prefix_as_messages
- Loading branch information
Showing
10 changed files
with
164 additions
and
61 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
import mimetypes | ||
import os | ||
from typing import Optional | ||
|
||
import requests | ||
|
||
from motleycrew.common import logger, utils as motley_utils | ||
|
||
|
||
def download_image(url: str, file_path: str) -> Optional[str]: | ||
response = requests.get(url, stream=True) | ||
if response.status_code == requests.codes.ok: | ||
try: | ||
content_type = response.headers.get("content-type") | ||
extension = mimetypes.guess_extension(content_type) | ||
except Exception as e: | ||
logger.error("Failed to guess content type: %s", e) | ||
extension = None | ||
|
||
if not extension: | ||
extension = ".png" | ||
|
||
file_path_with_extension = file_path + extension | ||
logger.info("Downloading image %s to %s", url, file_path_with_extension) | ||
|
||
with open(file_path_with_extension, "wb") as f: | ||
for chunk in response: | ||
f.write(chunk) | ||
|
||
return file_path_with_extension | ||
else: | ||
logger.error("Failed to download image. Status code: %s", response.status_code) | ||
|
||
|
||
def download_url_to_directory(url: str, images_directory: str, file_name_length: int = 8) -> str: | ||
os.makedirs(images_directory, exist_ok=True) | ||
file_name = motley_utils.generate_hex_hash(url, length=file_name_length) | ||
file_path = os.path.join(images_directory, file_name) | ||
|
||
file_path_with_extension = download_image(url=url, file_path=file_path).replace(os.sep, "/") | ||
return file_path_with_extension |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
from typing import Optional, List | ||
|
||
import replicate | ||
|
||
from langchain.agents import Tool | ||
from langchain_core.pydantic_v1 import BaseModel, Field | ||
|
||
import motleycrew.common.utils as motley_utils | ||
from motleycrew.tools.image.download_image import download_url_to_directory | ||
from motleycrew.tools.tool import MotleyTool | ||
from motleycrew.common import logger | ||
|
||
model_map = { | ||
"sdxl": "stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b", | ||
"flux-pro": "black-forest-labs/flux-pro", | ||
"flux-dev": "black-forest-labs/flux-dev", | ||
"flux-schnell": "black-forest-labs/flux-schnell", | ||
} | ||
|
||
# Each model has a different set of extra parameters, documented at pages like | ||
# https://replicate.com/black-forest-labs/flux-dev/api/schema | ||
|
||
|
||
def run_model_in_replicate(model_name: str, prompt: str, **kwargs) -> str | List[str]: | ||
if model_name in model_map: | ||
model_name = model_map[model_name] | ||
output = replicate.run(model_name, input={"prompt": prompt, **kwargs}) | ||
return output | ||
|
||
|
||
def run_model_in_replicate_and_save_images( | ||
model_name: str, prompt: str, directory_name: Optional[str] = None, **kwargs | ||
) -> List[str]: | ||
download_urls = run_model_in_replicate(model_name, prompt, **kwargs) | ||
if isinstance(download_urls, str): | ||
download_urls = [download_urls] | ||
if directory_name is None: | ||
logger.info("Images directory is not provided, returning URLs") | ||
return download_urls | ||
out_files = [] | ||
for url in download_urls: | ||
if motley_utils.is_http_url(url): | ||
out_files.append(download_url_to_directory(url, directory_name)) | ||
return out_files | ||
|
||
|
||
class ImageToolInput(BaseModel): | ||
"""Input for the Dall-E tool.""" | ||
|
||
description: str = Field(description="image description") | ||
|
||
|
||
class ReplicateImageGeneratorTool(MotleyTool): | ||
def __init__(self, model_name: str, images_directory: Optional[str] = None, **kwargs): | ||
""" | ||
A tool for generating images from text descriptions using the Replicate API. | ||
:param model_name: one of "sdxl", "flux-pro", "flux-dev", "flux-schnell", or a full model name supported by replicate | ||
:param images_directory: the directory to save the images to | ||
:param kwargs: model-specific parameters, from pages such as https://replicate.com/black-forest-labs/flux-dev/api/schema | ||
""" | ||
self.model_name = model_name | ||
self.kwargs = kwargs | ||
langchain_tool = create_replicate_image_generator_langchain_tool( | ||
model_name=model_name, images_directory=images_directory, **kwargs | ||
) | ||
|
||
super().__init__(langchain_tool) | ||
|
||
|
||
def create_replicate_image_generator_langchain_tool( | ||
model_name: str, images_directory: Optional[str] = None, **kwargs | ||
): | ||
def run_replicate_image_generator(description: str): | ||
return run_model_in_replicate_and_save_images( | ||
model_name=model_name, | ||
prompt=description, | ||
directory_name=images_directory, | ||
**kwargs, | ||
) | ||
|
||
return Tool( | ||
name=f"{model_name}_image_generator", | ||
func=run_replicate_image_generator, | ||
description=f"A wrapper around the {model_name} image generation model. Useful for when you need to generate images from a text description. " | ||
"Input should be an image description.", | ||
args_schema=ImageToolInput, | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
image_dir = os.path.join(os.path.expanduser("~"), "images") | ||
tool = ReplicateImageGeneratorTool("flux-pro", image_dir, aspect_ratio="3:2") | ||
output = tool.invoke( | ||
"A beautiful sunset over the mountains, with a dragon flying into the sunset, photorealistic style." | ||
) | ||
print(output) | ||
print("yay!") |
Oops, something went wrong.