-
Notifications
You must be signed in to change notification settings - Fork 34
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Bagel mech tool #226
Bagel mech tool #226
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
# -*- coding: utf-8 -*- | ||
# ------------------------------------------------------------------------------ | ||
# | ||
# Copyright 2023-2024 Valory AG | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
# ------------------------------------------------------------------------------ | ||
|
||
"""This module contains the bet amount per threshold strategy.""" |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
# -*- coding: utf-8 -*- | ||
# ------------------------------------------------------------------------------ | ||
# | ||
# Copyright 2023-2024 Valory AG | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
# ------------------------------------------------------------------------------ | ||
"""Contains the job definitions""" | ||
|
||
from typing import Any, Dict, Optional, Tuple | ||
|
||
import bagelml | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. needs to be added to project deps |
||
|
||
PREFIX = "bagel-" | ||
ENGINES = { | ||
"search": ["bagel-search"], | ||
"write": ["bagel-write"], | ||
} | ||
ALLOWED_TOOLS = [PREFIX + value for values in ENGINES.values() for value in values] | ||
|
||
DEFAULT_BAGEL_SETTINGS = { | ||
"top_k": 10, | ||
} | ||
|
||
|
||
def run(**kwargs) -> Tuple[Optional[str], Optional[Dict[str, Any]], Any, Any]: | ||
"""Run the task""" | ||
api_key = kwargs["api_keys"]["bagel"] | ||
tool = kwargs["tool"] | ||
prompt = kwargs["prompt"] | ||
top_k = kwargs.get("top_k", DEFAULT_BAGEL_SETTINGS["top_k"]) | ||
|
||
if tool not in ALLOWED_TOOLS: | ||
return ( | ||
f"Tool {tool} is not in the list of supported tools.", | ||
None, | ||
None, | ||
None, | ||
) | ||
|
||
if not api_key: | ||
return ( | ||
"Missing Bagel API key.", | ||
None, | ||
None, | ||
None, | ||
) | ||
|
||
client = bagelml.Client(api_key=api_key) | ||
|
||
engine = tool.replace(PREFIX, "") | ||
|
||
if engine == "search": | ||
# Get or create a cluster | ||
cluster = client.get_or_create_cluster(name="my-cluster", embedding_model="bagel-text") | ||
|
||
# Search the cluster for documents related to the prompt | ||
response = cluster.find(query_texts=[prompt], n_results=top_k) | ||
documents, distances, metadatas = map(lambda l: list([item for sublist in l for item in sublist]), | ||
(response['documents'], response['distances'], response['metadatas'])) | ||
|
||
return { | ||
"documents": documents, | ||
"distances": distances, | ||
"metadatas": metadatas | ||
}, None, None, None | ||
Comment on lines
+73
to
+77
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. First element in the returned value is expected to be string. |
||
elif engine == "write": | ||
# Get or create a cluster | ||
cluster = client.get_or_create_cluster(name="my-cluster", embedding_model="bagel-text") | ||
|
||
# Add documents to the cluster | ||
cluster.add( | ||
documents=[prompt], | ||
metadatas=[{"source": "user_input"}], | ||
ids=[f"doc_{len(cluster)}"] | ||
) | ||
|
||
return "Document added successfully.", None, None, None | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remember that this is what will get returned as a response to this request. Is this what was intended? |
||
else: | ||
return ( | ||
f"Unsupported engine: {engine}", | ||
None, | ||
None, | ||
None, | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
name: bagel_request | ||
author: bagel | ||
version: 0.1.0 | ||
type: custom | ||
description: A tool that uses bagel LLM finetuning. | ||
license: Apache-2.0 | ||
aea_version: '>=1.0.0, <2.0.0' | ||
fingerprint: | ||
__init__.py: ab1cdb459bcaa25e6f1d12cd36e7bc2d70bb197f065ff7f662756dbbba23e97a | ||
bagel_request.py: 728169fcf4596f5cb0d100593f503c9131858db58f8a5f8aea0c29d03763238b | ||
fingerprint_ignore_patterns: [] | ||
entry_point: bagel_request.py | ||
callable: run | ||
dependencies: | ||
requests: {} | ||
tiktoken: | ||
version: ==0.5.1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Small issue, but this is technically expected to be 2024 only. Same for the file above.