
Commit

add ollama type response
add ollama supported formatted response
DoraDong-2023 committed Jun 22, 2024
1 parent 54561a1 commit 43787f0
Showing 7 changed files with 197 additions and 31 deletions.
1 change: 1 addition & 0 deletions src/deploy/cli_demo.py
@@ -24,6 +24,7 @@ def cli_demo():
os.mkdir('tmp')
device= 'cuda' if torch.cuda.is_available() else 'cpu'
model = Model(logger,device)
print(model.user_states)
print(Fore.GREEN + "Welcome to BioMANIA CLI Demo!" + Style.RESET_ALL)
print(Fore.BLUE + "[Would you like to see some examples to learn how to interact with the bot?](https://github.com/batmen-lab/BioMANIA/tree/main/examples)" + Style.RESET_ALL)
libs = ["scanpy", "squidpy", "ehrapy", "snapatac2"]
48 changes: 23 additions & 25 deletions src/deploy/model.py
@@ -33,7 +33,7 @@ def color_text(text, color):
def label_sentence(sentence, parameters_dict):
import re

colors = ['red', 'purple', 'blue', 'green', 'orange']
colors = ['red'] # , 'purple', 'blue', 'green', 'orange'
color_map = {}
color_index = 0

@@ -91,7 +91,6 @@ def __init__(self, logger, device, model_llm_type="gpt-3.5-turbo-0125"): # llama
self.LIB = "scanpy"
with open(f'./data/standard_process/{self.LIB}/centroids.pkl', 'rb') as f:
self.centroids = pickle.load(f)

self.user_query_list = []
self.prompt_factory = PromptFactory()
self.model_llm_type = model_llm_type
@@ -113,6 +112,14 @@ def __init__(self, logger, device, model_llm_type="gpt-3.5-turbo-0125"): # llama
self.predict_api_llm_retry = 3
self.enable_multi_task = True
self.session_id = ""
self.last_user_states = ""
self.user_states = "run_pipeline"
self.retrieve_query_mode = "similar"
self.parameters_info_list = None
self.initial_goal_description = ""
self.new_task_planning = True # decide whether re-plan the task
self.retry_modify_count=0
self.loaded_files = False
#load_dotenv()
os.environ["OPENAI_API_KEY"] = os.getenv('OPENAI_API_KEY', 'sk-test')
os.environ["GITHUB_TOKEN"] = os.getenv('GITHUB_TOKEN', '')
@@ -121,19 +128,12 @@ def __init__(self, logger, device, model_llm_type="gpt-3.5-turbo-0125"): # llama
if reset_result=='Fail':
self.logger.error('Reset lib fail! Exit the dialog!')
return
self.last_user_states = ""
self.user_states = "run_pipeline"
self.parameters_info_list = None
self.initial_goal_description = ""
self.image_file_list = []
self.image_file_list = self.update_image_file_list()
self.retrieve_query_mode = "similar"
#self.get_all_api_json_cache(f"./data/standard_process/{self.LIB}/API_init.json", mode='single')
self.all_apis, self.all_apis_json = get_all_api_json(f"./data/standard_process/{self.LIB}/API_init.json", mode='single')
self.new_task_planning = True # decide whether re-plan the task
self.retry_modify_count=0
self.loaded_files = False
self.logger.info("Server ready")
self.save_state_enviro()
async def predict_all_params(self, api_name_tmp, boolean_params, literal_params, int_params, boolean_document, literal_document, int_document):
predicted_params = {}
if boolean_params:
@@ -389,25 +389,23 @@ def loading_data(self, files, verbose=False):
asyncio.run(self.loading_data_async(files, verbose))
def save_state(self):
a = str(self.session_id)
file_name = f"./tmp/states/{a}_state.pkl"
state = {k: v for k, v in self.__dict__.copy().items() if self.executor.is_picklable(v) and k != 'executor'}
with open(file_name, 'wb') as file:
with open(f"./tmp/states/{a}_state.pkl", 'wb') as file:
pickle.dump(state, file)
self.logger.info("State saved to {}", file_name)
self.logger.info("State saved to {}", f"./tmp/states/{a}_state.pkl")
#@lru_cache(maxsize=10)
def load_state(self, session_id):
a = str(session_id)
file_name = f"./tmp/states/{a}_state.pkl"
with open(file_name, 'rb') as file:
with open(f"./tmp/states/{a}_state.pkl", 'rb') as file:
state = pickle.load(file)
self.__dict__.update(state)
self.logger.info("State loaded from {}", file_name)
self.logger.info("State loaded from {}", f"./tmp/states/{a}_state.pkl")
def run_pipeline_without_files(self, user_input):
self.initialize_tool()
#self.logger.info('==> run_pipeline_without_files')
# if check, back to the last iteration and status
if user_input in ['y', 'n']:
if user_input == 'n':
if user_input in ['y', 'n', 'Y', 'N']:
if user_input in ['n', 'N']:
self.update_user_state("run_pipeline")
self.callback_func('log', "We will start another round. Could you re-enter your inquiry?", "Start another round")
self.save_state_enviro()
@@ -843,7 +841,7 @@ def run_pipeline_after_doublechecking_API_selection(self, user_input):
self.initialize_tool()
#self.logger.info('==>run_pipeline_after_doublechecking_API_selection')
user_input = str(user_input)
if user_input == 'n':
if user_input in ['n', 'N']:
if self.new_task_planning or self.retry_modify_count>=3: # if there is no task planning
self.update_user_state("run_pipeline")
self.callback_func('log', "We will start another round. Could you re-enter your inquiry?", "Start another round")
@@ -868,7 +866,7 @@ def run_pipeline_after_doublechecking_API_selection(self, user_input):
self.save_state_enviro()
self.run_pipeline(self.user_query, self.LIB, top_k=3, files=[],conversation_started=False,session_id=self.session_id)
return
elif user_input == 'y':
elif user_input in ['y', 'Y']:
pass
else:
self.callback_func('log', "The input was not y or n, please enter the correct value.", "Index Error")
@@ -961,7 +959,7 @@ def run_pipeline_after_doublechecking_API_selection(self, user_input):
predicted_parameters = {key: value for key, value in predicted_parameters.items() if value not in [None, "None", "null"] or key in required_param_list}
self.logger.info('after filtering, predicted_parameters: {}', predicted_parameters)
colored_sentence = label_sentence(self.user_query, predicted_parameters)
self.callback_func('log', 'Polished Subtask: ' + colored_sentence, 'Highlight parameters value in subtask description')
self.callback_func('log', colored_sentence, 'Highlight parameters value in polished subtask description')
#self.logger.info('colored_sentence: {}', colored_sentence)
# generate api_calling
self.predicted_api_name, api_calling, self.parameters_info_list = generate_api_calling(self.predicted_api_name, self.API_composite[self.predicted_api_name], predicted_parameters)
@@ -1248,8 +1246,8 @@ def run_pipeline_after_doublechecking_execution_code(self, user_input):
self.initialize_tool()
#self.logger.info('==> run_pipeline_after_doublechecking_execution_code')
# if check, back to the last iteration and status
if user_input in ['y', 'n', 'r']:
if user_input == 'n':
if user_input in ['y', 'n', 'r', 'Y', 'N', 'R']:
if user_input in ['n', 'N']:
if self.last_user_states=='run_pipeline_asking_GPT':
self.update_user_state("run_pipeline_asking_GPT")
self.callback_func('log', "We will redirect to the LLM model to re-generate the code", "Re-generate the code")
@@ -1262,7 +1260,7 @@ def run_pipeline_after_doublechecking_execution_code(self, user_input):
self.save_state_enviro()
self.run_pipeline_after_doublechecking_API_selection('y')
return
elif user_input == 'r':
elif user_input in ['r', 'R']:
self.update_user_state("run_pipeline")
self.callback_func('log', "We will start another round. Could you re-enter your inquiry?", "Start another round")
self.save_state_enviro()
@@ -1438,7 +1436,7 @@ def run_pipeline_after_doublechecking_execution_code(self, user_input):
self.logger.info('relevant_API: {}, execution_prompt: {}', relevant_API, execution_prompt)
#prompt = self.prompt_factory.create_prompt('subtask_code', [], self.user_query, whole_code, True, execution_prompt)
response, _ = LLM_response(execution_prompt, self.model_llm_type, history=[], kwargs={}) # llm
self.logger.info('prompt: {}, response: {}', prompt, response)
self.logger.info('prompt: {}, response: {}', execution_prompt, response)
tmp_retry_count = 0
while tmp_retry_count<5:
tmp_retry_count+=1
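
The double-checking handlers above accept confirmations in either case by enumerating every variant ('y', 'n', 'r', 'Y', 'N', 'R'). As a sketch of an equivalent, slightly more compact pattern (not part of this commit; the helper name is illustrative), the input could be normalized once before comparison:

def normalize_choice(user_input, allowed=('y', 'n', 'r')):
    """Return the lower-cased choice if it is one of the allowed options, else None."""
    choice = str(user_input).strip().lower()
    return choice if choice in allowed else None

# e.g. normalize_choice('N') returns 'n'; normalize_choice('maybe') returns None
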
131 changes: 131 additions & 0 deletions src/deploy/ollama_app.py
@@ -0,0 +1,131 @@
from flask import Flask, request, jsonify, Response
from flask_cors import CORS
from ..deploy.ollama_demo import initialize_model, process_input
from datetime import datetime
import json

app = Flask(__name__)
CORS(app)

# Initialize the model once when the server starts
model = initialize_model()
library = "scanpy" # Set default library or fetch dynamically based on your requirements

def generate_stream(user_input, session_id, conversation_started):
responses = process_input(model, user_input, library, session_id, conversation_started)
for response in responses:
chunk = {
"model": "biomania",
"created_at": datetime.utcnow().isoformat() + "Z",
"message": {
"role": "assistant",
"content": response
},
"done": False
}
yield f"{json.dumps(chunk)}\n"
# Ensure the last response indicates completion
final_chunk = {
"model": "biomania",
"created_at": datetime.utcnow().isoformat() + "Z",
"message": {
"role": "assistant",
"content": ""
},
"done": True,
"done_reason": "stop",
"context": []
}
yield f"{json.dumps(final_chunk)}\n"

@app.route('/api/generate', methods=['POST'])
def generate():
data = request.json
user_input = data.get('input')
session_id = data.get('session_id', datetime.now().strftime("%Y%m%d%H%M%S"))
conversation_started = data.get('conversation_started', True)
if not user_input:
return jsonify({"error": "No input provided"}), 400
return Response(generate_stream(user_input, session_id, conversation_started), content_type='application/json')

@app.route('/api/tags', methods=['GET'])
def get_tags(): # placeholder to be compatible with ollama format
tags = {"models": [{"name":"biomania",
"model":"biomania",
"modified_at":"2024-06-18T18:37:34.916232101-04:00",
"size":1,
"digest":"None",
"details":
{
"parent_model":"",
"format":"python-stream",
"family":"biomania",
"families": None,
"parameter_size":"None",
"quantization_level":"None"
}
}
]
} # Replace with actual data fetching logic
return jsonify(tags)

@app.route('/api/chat', methods=['POST'])
def chat():
if request.is_json:
data = request.json
else:
data = request.get_data(as_text=True)
try:
data = json.loads(data)
except json.JSONDecodeError:
return jsonify({"error": "Invalid JSON"}), 400

messages = data.get('messages')
print(data)
print(messages)
if not messages or not isinstance(messages, list) or len(messages) == 0:
return jsonify({"error": "No messages provided"}), 400

user_input = messages[0].get('content')
if not user_input:
return jsonify({"error": "No content provided in the messages"}), 400

session_id = data.get('session_id', datetime.now().strftime("%Y%m%d%H%M%S"))
conversation_started = data.get('conversation_started', True)

responses = process_input(model, user_input, library, session_id, conversation_started)
output = []
for response in responses:
chunk = {
"model": "biomania",
"created_at": datetime.utcnow().isoformat() + "Z",
"message": {
"role": "assistant",
"content": response
},
"done": False
}
output.append(chunk)

# Ensure the last response indicates completion
final_chunk = {
"model": "biomania",
"created_at": datetime.utcnow().isoformat() + "Z",
"message": {
"role": "assistant",
"content": ""
},
"done": True,
"done_reason": "stop",
"context": []
}
output.append(final_chunk)
return Response((f"{json.dumps(chunk)}\n" for chunk in output), content_type='application/json')

@app.route('/api/chat/biomania', methods=['POST'])
def chat_biomania():
return chat()

if __name__ == '__main__':
app.run(port=5000)
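
Since the app above mirrors Ollama's newline-delimited JSON chat format, a streaming client can consume it much like it would an Ollama server. A minimal client sketch, assuming the server is running locally on port 5000 and using the requests library (the query text is only a placeholder):

import json
import requests

payload = {"messages": [{"role": "user", "content": "Use scanpy to normalize the loaded data"}]}
with requests.post("http://localhost:5000/api/chat", json=payload, stream=True) as resp:
    for line in resp.iter_lines():
        if not line:
            continue  # skip empty keep-alive lines
        chunk = json.loads(line)  # each line is one JSON chunk in Ollama's chat format
        print(chunk["message"]["content"], end="", flush=True)
        if chunk.get("done"):
            break

Each chunk carries the assistant text in message.content, and the final chunk sets done to true, matching the shape emitted by generate_stream and chat above.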

36 changes: 36 additions & 0 deletions src/deploy/ollama_demo.py
@@ -0,0 +1,36 @@
from ..deploy.model import Model
import os, torch
from datetime import datetime
from colorama import Fore, Style
from ..deploy.cli import encode_file_to_base64, parse_backend_response
from ..deploy.cli_demo import parse_backend_queue

def initialize_model():
from loguru import logger
timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
os.makedirs(f"./logs", exist_ok=True)
logger.remove()
logger.add(f"./logs/BioMANIA_log_{timestamp}.log", rotation="500 MB", retention="7 days", level="INFO")
logger.info("Loguru initialized successfully.")
print("Logging setup complete.")
if not os.path.exists('tmp'):
os.mkdir('tmp')
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = Model(logger, device)
return model

def process_input(model, user_input, library, session_id, conversation_started):
if "<file>" in user_input:
path_start = user_input.find("<file>") + 6
path_end = user_input.find("</file>")
filepath = user_input[path_start:path_end]
user_input = user_input[:path_start-6] + user_input[path_end+7:]
file_content = encode_file_to_base64(filepath)
print(Fore.YELLOW + "File encoded to base64 for processing: " + file_content[:30] + "..." + Style.RESET_ALL)
model.run_pipeline(user_input, library, top_k=1, files=[], conversation_started=conversation_started, session_id=session_id)
messages = parse_backend_queue(model.queue)
responses = []
for msg in messages:
output = parse_backend_response([msg], yield_load=False)
responses.extend(output)
return responses
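
For reference, a rough sketch of driving these helpers directly, outside the Flask app; the inquiry text is a placeholder and the import path depends on how the package is laid out:

from datetime import datetime
from src.deploy.ollama_demo import initialize_model, process_input  # adjust to your package layout

model = initialize_model()
session_id = datetime.now().strftime("%Y%m%d%H%M%S")
responses = process_input(
    model,
    user_input="Which API should I use to filter low-quality cells?",
    library="scanpy",
    session_id=session_id,
    conversation_started=True,
)
for text in responses:
    print(text)
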
10 changes: 5 additions & 5 deletions src/inference/execution_UI.py
@@ -23,13 +23,13 @@ def find_matching_instance(api_string, executor_variables):
return instance_name, True
return None, False
except (ImportError, AttributeError) as e:
print(f"Error: {e}")
self.logger.info(f"Error: {e}")
return None, False

class FakeLogger:
def info(self, *messages):
combined_message = " ".join(str(message) for message in messages)
print("Logged info:", combined_message)
self.logger.info("Logged info:", combined_message)

class CodeExecutor:
def __init__(self, logger=None):
@@ -118,9 +118,9 @@ def load_environment(self, file_name):
if k.endswith("_AnnData"): # Assuming you have a way to recognize AnnData objects
self.variables[k] = read_h5ad(f"{file_name}_{k}.h5ad") # Load AnnData object from file
self.execute_code = loaded_data["execute_code"]
print('before loading environment:', self.counter)
#print('before loading environment:', self.counter)
self.counter = loaded_data["counter"]
print('after loading environment:', self.counter)
#print('after loading environment:', self.counter)
tmp_variables = {k: self.variables[k]['value'] for k in self.variables if not k.endswith("_AnnData")}
globals().update(tmp_variables)
locals().update(tmp_variables)
@@ -331,7 +331,7 @@ def generate_execution_code_for_one_api(self, api_name, selected_params, return_
import_code, type_api = self.get_import_code(api_name)
except:
error = traceback.format_exc()
print('generate_execution_code_for_one_api: {}', error)
self.logger.info('generate_execution_code_for_one_api: {}', error)
self.logger.info(f'==>import_code, type_api, {import_code, type_api}')
if import_code in [i['code'] for i in self.execute_code if i['success']=='True']:
self.logger.info('==>api already imported!')
2 changes: 1 addition & 1 deletion src/prompt/promptgenerator.py
Expand Up @@ -221,7 +221,7 @@ def build_prompt(self, current_subtask, execution_history, namespace_variables,
Never include data description in other subtasks except for the data loading subtask. Ensure Goal-Oriented Task Structuring, place the goal description at the beginning of each subtask.
Ensure to check docstring requirements for API dependencies, required optional parameters, parameter conflicts, and deprecations.
If there are obvious parameter values in the current subtask, retain them in the polished subtask description and condense the parameter assignments in 1-2 sentences.
Just response with the modified subtask description directly. DO NOT add additional explanations or introducement.
Only respond the modified subtask description with assigned parameter values. DO NOT add additional explanations or introducement. DO NOT return any previous subtask.
'''
return query_prompt

Binary file added tmp/sessions/20240622135509_environment.pkl
Binary file not shown.
