From b9b2196721199c4d9616f30cc919cd9f59f980ad Mon Sep 17 00:00:00 2001 From: Bidhan Roy Date: Thu, 2 May 2024 22:18:11 +0200 Subject: [PATCH 1/2] bagel mech tool --- .../valory/customs/bagel_request/__init__.py | 20 ++++ .../customs/bagel_request/bagel_request.py | 96 +++++++++++++++++++ .../customs/bagel_request/component.yaml | 17 ++++ 3 files changed, 133 insertions(+) create mode 100644 packages/valory/customs/bagel_request/__init__.py create mode 100644 packages/valory/customs/bagel_request/bagel_request.py create mode 100644 packages/valory/customs/bagel_request/component.yaml diff --git a/packages/valory/customs/bagel_request/__init__.py b/packages/valory/customs/bagel_request/__init__.py new file mode 100644 index 00000000..93e63121 --- /dev/null +++ b/packages/valory/customs/bagel_request/__init__.py @@ -0,0 +1,20 @@ +# -*- coding: utf-8 -*- +# ------------------------------------------------------------------------------ +# +# Copyright 2023-2024 Valory AG +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ------------------------------------------------------------------------------ + +"""This module contains the bet amount per threshold strategy.""" diff --git a/packages/valory/customs/bagel_request/bagel_request.py b/packages/valory/customs/bagel_request/bagel_request.py new file mode 100644 index 00000000..6509f56c --- /dev/null +++ b/packages/valory/customs/bagel_request/bagel_request.py @@ -0,0 +1,96 @@ +# -*- coding: utf-8 -*- +# ------------------------------------------------------------------------------ +# +# Copyright 2023-2024 Valory AG +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ------------------------------------------------------------------------------ +"""Contains the job definitions""" + +from typing import Any, Dict, Optional, Tuple + +import bagelml + +PREFIX = "bagel-" +ENGINES = { + "search": ["bagel-search"], + "write": ["bagel-write"], +} +ALLOWED_TOOLS = [PREFIX + value for values in ENGINES.values() for value in values] + +DEFAULT_BAGEL_SETTINGS = { + "top_k": 10, +} + + +def run(**kwargs) -> Tuple[Optional[str], Optional[Dict[str, Any]], Any, Any]: + """Run the task""" + api_key = kwargs["api_keys"]["bagel"] + tool = kwargs["tool"] + prompt = kwargs["prompt"] + top_k = kwargs.get("top_k", DEFAULT_BAGEL_SETTINGS["top_k"]) + + if tool not in ALLOWED_TOOLS: + return ( + f"Tool {tool} is not in the list of supported tools.", + None, + None, + None, + ) + + if not api_key: + return ( + "Missing Bagel API key.", + None, + None, + None, + ) + + client = bagelml.Client(api_key=api_key) + + engine = tool.replace(PREFIX, "") + + if engine == "search": + # Get or create a cluster + cluster = client.get_or_create_cluster(name="my-cluster", embedding_model="bagel-text") + + # Search the cluster for documents related to the prompt + response = cluster.find(query_texts=[prompt], n_results=top_k) + documents, distances, metadatas = map(lambda l: list([item for sublist in l for item in sublist]), + (response['documents'], response['distances'], response['metadatas'])) + + return { + "documents": documents, + "distances": distances, + "metadatas": metadatas + }, None, None, None + elif engine == "write": + # Get or create a cluster + cluster = client.get_or_create_cluster(name="my-cluster", embedding_model="bagel-text") + + # Add documents to the cluster + cluster.add( + documents=[prompt], + metadatas=[{"source": "user_input"}], + ids=[f"doc_{len(cluster)}"] + ) + + return "Document added successfully.", None, None, None + else: + return ( + f"Unsupported engine: {engine}", + None, + None, + None, + ) diff --git a/packages/valory/customs/bagel_request/component.yaml b/packages/valory/customs/bagel_request/component.yaml new file mode 100644 index 00000000..a149d381 --- /dev/null +++ b/packages/valory/customs/bagel_request/component.yaml @@ -0,0 +1,17 @@ +name: stability_ai_request +author: bagel +version: 0.1.0 +type: custom +description: A tool that runs a prompt against StabilityAI. +license: Apache-2.0 +aea_version: '>=1.0.0, <2.0.0' +fingerprint: + __init__.py: bafybeibbn67pnrrm4qm3n3kbelvbs3v7fjlrjniywmw2vbizarippidtvi + stabilityai_request.py: bafybeifi7i5syencul3nvplbnvporb4x2brr7ugosnvn6uyiaejsqetq7u +fingerprint_ignore_patterns: [] +entry_point: stabilityai_request.py +callable: run +dependencies: + requests: {} + tiktoken: + version: ==0.5.1 From 52c336b1fb2f203ea9dab4a0b4471a0d8241d83b Mon Sep 17 00:00:00 2001 From: Bidhan Roy Date: Wed, 8 May 2024 12:54:56 +0200 Subject: [PATCH 2/2] bagel component.yaml --- packages/valory/customs/bagel_request/component.yaml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/packages/valory/customs/bagel_request/component.yaml b/packages/valory/customs/bagel_request/component.yaml index a149d381..9d211966 100644 --- a/packages/valory/customs/bagel_request/component.yaml +++ b/packages/valory/customs/bagel_request/component.yaml @@ -1,15 +1,15 @@ -name: stability_ai_request +name: bagel_request author: bagel version: 0.1.0 type: custom -description: A tool that runs a prompt against StabilityAI. +description: A tool that uses bagel LLM finetuning. license: Apache-2.0 aea_version: '>=1.0.0, <2.0.0' fingerprint: - __init__.py: bafybeibbn67pnrrm4qm3n3kbelvbs3v7fjlrjniywmw2vbizarippidtvi - stabilityai_request.py: bafybeifi7i5syencul3nvplbnvporb4x2brr7ugosnvn6uyiaejsqetq7u + __init__.py: ab1cdb459bcaa25e6f1d12cd36e7bc2d70bb197f065ff7f662756dbbba23e97a + bagel_request.py: 728169fcf4596f5cb0d100593f503c9131858db58f8a5f8aea0c29d03763238b fingerprint_ignore_patterns: [] -entry_point: stabilityai_request.py +entry_point: bagel_request.py callable: run dependencies: requests: {}