diff --git a/README.md b/README.md index c6a1cd25..343ce2be 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,6 @@ [![PyPI - Version](https://img.shields.io/pypi/v/motleycrew)](https://pypi.org/project/motleycrew/) [![CI](https://github.com/ShoggothAI/motleycrew/actions/workflows/build.yml/badge.svg)](https://github.com/ShoggothAI/motleycrew/actions/workflows/build.yml) -[![GitHub commit activity (branch)](https://img.shields.io/github/commit-activity/w/ShoggothAI/motleycrew)](https://github.com/ShoggothAI/motleycrew/commits/main/) [Website](https://motleycrew.ai) •︎ [Documentation](https://motleycrew.readthedocs.io) diff --git a/motleycrew/agents/__init__.py b/motleycrew/agents/__init__.py index e240d90a..b5d8f69f 100644 --- a/motleycrew/agents/__init__.py +++ b/motleycrew/agents/__init__.py @@ -3,9 +3,11 @@ from .abstract_parent import MotleyAgentAbstractParent from .output_handler import MotleyOutputHandler from .parent import MotleyAgentParent +from .langchain import LangchainMotleyAgent __all__ = [ "MotleyAgentAbstractParent", "MotleyAgentParent", "MotleyOutputHandler", + "LangchainMotleyAgent", ] diff --git a/motleycrew/agents/langchain/langchain.py b/motleycrew/agents/langchain/langchain.py index cfd394e1..08d37ccd 100644 --- a/motleycrew/agents/langchain/langchain.py +++ b/motleycrew/agents/langchain/langchain.py @@ -6,6 +6,7 @@ from langchain_core.chat_history import InMemoryChatMessageHistory from langchain_core.runnables import RunnableConfig from langchain_core.runnables.history import RunnableWithMessageHistory, GetSessionHistoryCallable +from langchain_core.prompts.chat import ChatPromptTemplate from motleycrew.agents.mixins import LangchainOutputHandlingAgentMixin from motleycrew.agents.parent import MotleyAgentParent @@ -21,7 +22,7 @@ def __init__( self, description: str | None = None, name: str | None = None, - prompt_prefix: str | None = None, + prompt_prefix: str | ChatPromptTemplate | None = None, agent_factory: MotleyAgentFactory[AgentExecutor] | None = None, tools: Sequence[MotleySupportedTool] | None = None, output_handler: MotleySupportedTool | None = None, diff --git a/motleycrew/agents/langchain/tool_calling_react.py b/motleycrew/agents/langchain/tool_calling_react.py index 579cc415..02e5655f 100644 --- a/motleycrew/agents/langchain/tool_calling_react.py +++ b/motleycrew/agents/langchain/tool_calling_react.py @@ -110,7 +110,7 @@ def __init__( tools: Sequence[MotleySupportedTool], description: str | None = None, name: str | None = None, - prompt_prefix: str | None = None, + prompt_prefix: str | ChatPromptTemplate | None = None, prompt: ChatPromptTemplate | None = None, chat_history: bool | GetSessionHistoryCallable = True, output_handler: MotleySupportedTool | None = None, diff --git a/motleycrew/agents/parent.py b/motleycrew/agents/parent.py index 5b7d4ced..4532866c 100644 --- a/motleycrew/agents/parent.py +++ b/motleycrew/agents/parent.py @@ -56,7 +56,7 @@ class MotleyAgentParent(MotleyAgentAbstractParent, ABC): def __init__( self, - prompt_prefix: str | None = None, + prompt_prefix: str | ChatPromptTemplate | None = None, description: str | None = None, name: str | None = None, agent_factory: MotleyAgentFactory | None = None, diff --git a/motleycrew/tools/image/dall_e.py b/motleycrew/tools/image/dall_e.py index 7ef09b6f..7cd4f41d 100644 --- a/motleycrew/tools/image/dall_e.py +++ b/motleycrew/tools/image/dall_e.py @@ -1,8 +1,5 @@ -import mimetypes -import os from typing import Optional -import requests from langchain.agents import Tool from langchain.prompts import PromptTemplate from langchain_community.utilities.dalle_image_generator import DallEAPIWrapper @@ -12,34 +9,9 @@ from motleycrew.common import LLMFramework from motleycrew.common import logger from motleycrew.common.llms import init_llm +from motleycrew.tools.image.download_image import download_url_to_directory from motleycrew.tools.tool import MotleyTool - -def download_image(url: str, file_path: str) -> Optional[str]: - response = requests.get(url, stream=True) - if response.status_code == requests.codes.ok: - try: - content_type = response.headers.get("content-type") - extension = mimetypes.guess_extension(content_type) - except Exception as e: - logger.error("Failed to guess content type: %s", e) - extension = None - - if not extension: - extension = ".png" - - file_path_with_extension = file_path + extension - logger.info("Downloading image %s to %s", url, file_path_with_extension) - - with open(file_path_with_extension, "wb") as f: - for chunk in response: - f.write(chunk) - - return file_path_with_extension - else: - logger.error("Failed to download image. Status code: %s", response.status_code) - - DEFAULT_REFINE_PROMPT = """Generate a detailed DALL-E prompt to generate an image based on the following description: ```{text}``` @@ -133,17 +105,7 @@ def run_dalle_and_save_images( return if images_directory: - os.makedirs(images_directory, exist_ok=True) - file_paths = [] - for url in urls: - file_name = motley_utils.generate_hex_hash(url, length=file_name_length) - file_path = os.path.join(images_directory, file_name) - - file_path_with_extension = download_image(url=url, file_path=file_path).replace( - os.sep, "/" - ) - file_paths.append(file_path_with_extension) - return file_paths + return [download_url_to_directory(url, images_directory, file_name_length) for url in urls] else: logger.info("Images directory is not provided, returning URLs") return urls diff --git a/motleycrew/tools/image/download_image.py b/motleycrew/tools/image/download_image.py new file mode 100644 index 00000000..4c0f6629 --- /dev/null +++ b/motleycrew/tools/image/download_image.py @@ -0,0 +1,41 @@ +import mimetypes +import os +from typing import Optional + +import requests + +from motleycrew.common import logger, utils as motley_utils + + +def download_image(url: str, file_path: str) -> Optional[str]: + response = requests.get(url, stream=True) + if response.status_code == requests.codes.ok: + try: + content_type = response.headers.get("content-type") + extension = mimetypes.guess_extension(content_type) + except Exception as e: + logger.error("Failed to guess content type: %s", e) + extension = None + + if not extension: + extension = ".png" + + file_path_with_extension = file_path + extension + logger.info("Downloading image %s to %s", url, file_path_with_extension) + + with open(file_path_with_extension, "wb") as f: + for chunk in response: + f.write(chunk) + + return file_path_with_extension + else: + logger.error("Failed to download image. Status code: %s", response.status_code) + + +def download_url_to_directory(url: str, images_directory: str, file_name_length: int = 8) -> str: + os.makedirs(images_directory, exist_ok=True) + file_name = motley_utils.generate_hex_hash(url, length=file_name_length) + file_path = os.path.join(images_directory, file_name) + + file_path_with_extension = download_image(url=url, file_path=file_path).replace(os.sep, "/") + return file_path_with_extension diff --git a/motleycrew/tools/image/replicate_tool.py b/motleycrew/tools/image/replicate_tool.py new file mode 100644 index 00000000..81f9fc46 --- /dev/null +++ b/motleycrew/tools/image/replicate_tool.py @@ -0,0 +1,97 @@ +from typing import Optional, List + +import replicate + +from langchain.agents import Tool +from langchain_core.pydantic_v1 import BaseModel, Field + +import motleycrew.common.utils as motley_utils +from motleycrew.tools.image.download_image import download_url_to_directory +from motleycrew.tools.tool import MotleyTool +from motleycrew.common import logger + +model_map = { + "sdxl": "stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b", + "flux-pro": "black-forest-labs/flux-pro", + "flux-dev": "black-forest-labs/flux-dev", + "flux-schnell": "black-forest-labs/flux-schnell", +} + +# Each model has a different set of extra parameters, documented at pages like +# https://replicate.com/black-forest-labs/flux-dev/api/schema + + +def run_model_in_replicate(model_name: str, prompt: str, **kwargs) -> str | List[str]: + if model_name in model_map: + model_name = model_map[model_name] + output = replicate.run(model_name, input={"prompt": prompt, **kwargs}) + return output + + +def run_model_in_replicate_and_save_images( + model_name: str, prompt: str, directory_name: Optional[str] = None, **kwargs +) -> List[str]: + download_urls = run_model_in_replicate(model_name, prompt, **kwargs) + if isinstance(download_urls, str): + download_urls = [download_urls] + if directory_name is None: + logger.info("Images directory is not provided, returning URLs") + return download_urls + out_files = [] + for url in download_urls: + if motley_utils.is_http_url(url): + out_files.append(download_url_to_directory(url, directory_name)) + return out_files + + +class ImageToolInput(BaseModel): + """Input for the Dall-E tool.""" + + description: str = Field(description="image description") + + +class ReplicateImageGeneratorTool(MotleyTool): + def __init__(self, model_name: str, images_directory: Optional[str] = None, **kwargs): + """ + A tool for generating images from text descriptions using the Replicate API. + :param model_name: one of "sdxl", "flux-pro", "flux-dev", "flux-schnell", or a full model name supported by replicate + :param images_directory: the directory to save the images to + :param kwargs: model-specific parameters, from pages such as https://replicate.com/black-forest-labs/flux-dev/api/schema + """ + self.model_name = model_name + self.kwargs = kwargs + langchain_tool = create_replicate_image_generator_langchain_tool( + model_name=model_name, images_directory=images_directory, **kwargs + ) + + super().__init__(langchain_tool) + + +def create_replicate_image_generator_langchain_tool( + model_name: str, images_directory: Optional[str] = None, **kwargs +): + def run_replicate_image_generator(description: str): + return run_model_in_replicate_and_save_images( + model_name=model_name, + prompt=description, + directory_name=images_directory, + **kwargs, + ) + + return Tool( + name=f"{model_name}_image_generator", + func=run_replicate_image_generator, + description=f"A wrapper around the {model_name} image generation model. Useful for when you need to generate images from a text description. " + "Input should be an image description.", + args_schema=ImageToolInput, + ) + + +if __name__ == "__main__": + image_dir = os.path.join(os.path.expanduser("~"), "images") + tool = ReplicateImageGeneratorTool("flux-pro", image_dir, aspect_ratio="3:2") + output = tool.invoke( + "A beautiful sunset over the mountains, with a dragon flying into the sunset, photorealistic style." + ) + print(output) + print("yay!") diff --git a/poetry.lock b/poetry.lock index f88697d5..700ca882 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1260,7 +1260,7 @@ files = [ name = "distro" version = "1.9.0" description = "Distro - an OS platform information API" -optional = true +optional = false python-versions = ">=3.6" files = [ {file = "distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2"}, @@ -1652,12 +1652,12 @@ files = [ google-auth = ">=2.14.1,<3.0.dev0" googleapis-common-protos = ">=1.56.2,<2.0.dev0" grpcio = [ - {version = ">=1.33.2,<2.0dev", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}, {version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, + {version = ">=1.33.2,<2.0dev", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}, ] grpcio-status = [ - {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}, {version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, + {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}, ] proto-plus = ">=1.22.3,<2.0.0dev" protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0.dev0" @@ -2616,7 +2616,7 @@ i18n = ["Babel (>=2.7)"] name = "jiter" version = "0.4.2" description = "Fast iterable JSON parser." -optional = true +optional = false python-versions = ">=3.8" files = [ {file = "jiter-0.4.2-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:c2b003ff58d14f5e182b875acd5177b2367245c19a03be9a2230535d296f7550"}, @@ -2938,8 +2938,8 @@ langchain-core = ">=0.2.32,<0.3.0" langchain-text-splitters = ">=0.2.0,<0.3.0" langsmith = ">=0.1.17,<0.2.0" numpy = [ - {version = ">=1,<2", markers = "python_version < \"3.12\""}, {version = ">=1.26.0,<2.0.0", markers = "python_version >= \"3.12\""}, + {version = ">=1,<2", markers = "python_version < \"3.12\""}, ] pydantic = ">=1,<3" PyYAML = ">=5.3" @@ -2986,8 +2986,8 @@ langchain = ">=0.2.13,<0.3.0" langchain-core = ">=0.2.30,<0.3.0" langsmith = ">=0.1.0,<0.2.0" numpy = [ - {version = ">=1,<2", markers = "python_version < \"3.12\""}, {version = ">=1.26.0,<2.0.0", markers = "python_version >= \"3.12\""}, + {version = ">=1,<2", markers = "python_version < \"3.12\""}, ] PyYAML = ">=5.3" requests = ">=2,<3" @@ -3010,8 +3010,8 @@ jsonpatch = ">=1.33,<2.0" langsmith = ">=0.1.75,<0.2.0" packaging = ">=23.2,<25" pydantic = [ - {version = ">=1,<3", markers = "python_full_version < \"3.12.4\""}, {version = ">=2.7.4,<3.0.0", markers = "python_full_version >= \"3.12.4\""}, + {version = ">=1,<3", markers = "python_full_version < \"3.12.4\""}, ] PyYAML = ">=5.3" tenacity = ">=8.1.0,<8.4.0 || >8.4.0,<9.0.0" @@ -3036,7 +3036,7 @@ langchain-core = ">=0.2.10,<0.3.0" name = "langchain-openai" version = "0.1.22" description = "An integration package connecting OpenAI and LangChain" -optional = true +optional = false python-versions = "<4.0,>=3.8.1" files = [ {file = "langchain_openai-0.1.22-py3-none-any.whl", hash = "sha256:e184ab867a30f803dc210a388537186b1b670a33d910a7e0fa4e0329d3b6c654"}, @@ -3093,8 +3093,8 @@ files = [ httpx = ">=0.23.0,<1" orjson = ">=3.9.14,<4.0.0" pydantic = [ - {version = ">=1,<3", markers = "python_full_version < \"3.12.4\""}, {version = ">=2.7.4,<3.0.0", markers = "python_full_version >= \"3.12.4\""}, + {version = ">=1,<3", markers = "python_full_version < \"3.12.4\""}, ] requests = ">=2,<3" @@ -4177,7 +4177,7 @@ sympy = "*" name = "openai" version = "1.42.0" description = "The official Python library for the openai API" -optional = true +optional = false python-versions = ">=3.7.1" files = [ {file = "openai-1.42.0-py3-none-any.whl", hash = "sha256:dc91e0307033a4f94931e5d03cc3b29b9717014ad5e73f9f2051b6cb5eda4d80"}, @@ -4508,9 +4508,9 @@ files = [ [package.dependencies] numpy = [ + {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, {version = ">=1.22.4", markers = "python_version < \"3.11\""}, {version = ">=1.23.2", markers = "python_version == \"3.11\""}, - {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, ] python-dateutil = ">=2.8.2" pytz = ">=2020.1" @@ -5056,8 +5056,8 @@ files = [ annotated-types = ">=0.4.0" pydantic-core = "2.20.1" typing-extensions = [ - {version = ">=4.6.1", markers = "python_version < \"3.13\""}, {version = ">=4.12.2", markers = "python_version >= \"3.13\""}, + {version = ">=4.6.1", markers = "python_version < \"3.13\""}, ] [package.extras] @@ -5580,8 +5580,8 @@ grpcio = ">=1.41.0" grpcio-tools = ">=1.41.0" httpx = {version = ">=0.20.0", extras = ["http2"]} numpy = [ - {version = ">=1.21", markers = "python_version >= \"3.8\" and python_version < \"3.12\""}, {version = ">=1.26", markers = "python_version >= \"3.12\""}, + {version = ">=1.21", markers = "python_version >= \"3.8\" and python_version < \"3.12\""}, ] portalocker = ">=2.7.0,<3.0.0" pydantic = ">=1.10.8" @@ -5610,7 +5610,7 @@ rpds-py = ">=0.7.0" name = "regex" version = "2023.12.25" description = "Alternative regular expression module, to replace re." -optional = true +optional = false python-versions = ">=3.7" files = [ {file = "regex-2023.12.25-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:0694219a1d54336fd0445ea382d49d36882415c0134ee1e8332afd1529f0baa5"}, @@ -6405,7 +6405,7 @@ test = ["pytest", "tornado (>=4.5)", "typeguard"] name = "tiktoken" version = "0.7.0" description = "tiktoken is a fast BPE tokeniser for use with OpenAI's models" -optional = true +optional = false python-versions = ">=3.8" files = [ {file = "tiktoken-0.7.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:485f3cc6aba7c6b6ce388ba634fbba656d9ee27f766216f45146beb4ac18b25f"}, @@ -6623,7 +6623,7 @@ files = [ name = "tqdm" version = "4.66.5" description = "Fast, Extensible Progress Meter" -optional = true +optional = false python-versions = ">=3.7" files = [ {file = "tqdm-4.66.5-py3-none-any.whl", hash = "sha256:90279a3770753eafc9194a0364852159802111925aa30eb3f9d85b0e805ac7cd"}, @@ -7247,4 +7247,4 @@ pglast = ["pglast"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<=3.13" -content-hash = "af4acdd5127059f3337c8bef03cc9dbd1b1913ec68e506c8be027cb7e0466cc9" +content-hash = "67dac7ce7e5c4e8d74269db4fbdccf30433ad0f824d1d9b5de31f45be7256dca" diff --git a/pyproject.toml b/pyproject.toml index 5e642446..3d69d3ef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ curl-cffi = "^0.6.4" httpx = "^0.27.0" motleycache = "^0.0.4" pglast = {version = "^6.2", optional = true} +langchain-openai = "^0.1.22" [tool.poetry.group.dev.dependencies] black = "^24.2.0"