-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing
7 changed files
with
197 additions
and
31 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
from flask import Flask, request, jsonify, Response
from flask_cors import CORS
from ..deploy.ollama_demo import initialize_model, process_input
from datetime import datetime, timezone
import json
|
||
app = Flask(__name__)
CORS(app)  # allow cross-origin requests from browser-based UIs

# Initialize the model once when the server starts
model = initialize_model()
library = "scanpy"  # Set default library or fetch dynamically based on your requirements
|
||
def _ollama_chunk(content, done):
    """Build one Ollama-compatible response chunk.

    Args:
        content: Assistant message text for this chunk ("" for the final one).
        done: Whether this is the terminating chunk; if so, the Ollama
            protocol fields ``done_reason`` and ``context`` are included.

    Returns:
        dict ready to be JSON-serialized as one NDJSON line.
    """
    chunk = {
        "model": "biomania",
        # datetime.utcnow() is deprecated (naive); use an aware UTC time and
        # keep the original trailing-"Z" wire format.
        "created_at": datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"),
        "message": {
            "role": "assistant",
            "content": content,
        },
        "done": done,
    }
    if done:
        chunk["done_reason"] = "stop"
        chunk["context"] = []
    return chunk


def generate_stream(user_input, session_id, conversation_started):
    """Yield newline-delimited JSON chunks for a single generation request.

    Runs the model via process_input, emits one chunk per response piece,
    then a final empty chunk with done=True so Ollama-style clients know
    the stream is complete.
    """
    responses = process_input(model, user_input, library, session_id, conversation_started)
    for response in responses:
        yield json.dumps(_ollama_chunk(response, done=False)) + "\n"
    # Ensure the last response indicates completion
    yield json.dumps(_ollama_chunk("", done=True)) + "\n"
|
||
@app.route('/api/generate', methods=['POST'])
def generate():
    """Ollama-style /api/generate endpoint.

    Expects a JSON body with keys: ``input`` (required), ``session_id``
    (optional, defaults to a timestamp), ``conversation_started`` (optional,
    defaults to True). Streams newline-delimited JSON chunks.

    Returns:
        200 streaming Response, or 400 JSON error when input is missing.
    """
    # silent=True: a missing/malformed JSON body yields None instead of
    # raising, so we can return the intended 400 rather than a 500/415.
    data = request.get_json(silent=True) or {}
    user_input = data.get('input')
    if not user_input:
        return jsonify({"error": "No input provided"}), 400
    # Timestamp default gives each anonymous request its own session.
    session_id = data.get('session_id', datetime.now().strftime("%Y%m%d%H%M%S"))
    conversation_started = data.get('conversation_started', True)
    return Response(generate_stream(user_input, session_id, conversation_started), content_type='application/json')
|
||
@app.route('/api/tags', methods=['GET'])
def get_tags():
    """Placeholder model listing shaped like Ollama's /api/tags response.

    Returns a single static "biomania" entry so Ollama-compatible clients
    can discover the model. Replace with actual data fetching logic.
    """
    details = {
        "parent_model": "",
        "format": "python-stream",
        "family": "biomania",
        "families": None,
        "parameter_size": "None",
        "quantization_level": "None",
    }
    biomania_entry = {
        "name": "biomania",
        "model": "biomania",
        "modified_at": "2024-06-18T18:37:34.916232101-04:00",
        "size": 1,
        "digest": "None",
        "details": details,
    }
    return jsonify({"models": [biomania_entry]})
|
||
@app.route('/api/chat', methods=['POST'])
def chat():
    """Ollama-style /api/chat endpoint.

    Accepts a JSON body (parsed leniently even without a JSON content type)
    with ``messages`` (list of {role, content} dicts), plus optional
    ``session_id`` and ``conversation_started``. Responds with
    newline-delimited Ollama-format chunks followed by a final done chunk.

    Returns:
        200 streaming Response, or 400 JSON error for invalid input.
    """
    if request.is_json:
        data = request.json
    else:
        # Some clients post JSON without the application/json content type.
        try:
            data = json.loads(request.get_data(as_text=True))
        except json.JSONDecodeError:
            return jsonify({"error": "Invalid JSON"}), 400

    messages = data.get('messages')
    if not messages or not isinstance(messages, list):
        return jsonify({"error": "No messages provided"}), 400

    # Ollama-style clients send the full history; the newest user prompt is
    # the LAST entry (messages[0] would replay the oldest turn forever).
    user_input = messages[-1].get('content')
    if not user_input:
        return jsonify({"error": "No content provided in the messages"}), 400

    session_id = data.get('session_id', datetime.now().strftime("%Y%m%d%H%M%S"))
    conversation_started = data.get('conversation_started', True)

    def _chunk(content, done):
        # One Ollama-format chunk; deprecated utcnow() replaced by aware UTC
        # time while preserving the trailing-"Z" wire format.
        c = {
            "model": "biomania",
            "created_at": datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"),
            "message": {"role": "assistant", "content": content},
            "done": done,
        }
        if done:
            c["done_reason"] = "stop"
            c["context"] = []
        return c

    responses = process_input(model, user_input, library, session_id, conversation_started)
    output = [_chunk(response, False) for response in responses]
    # Ensure the last response indicates completion
    output.append(_chunk("", True))
    return Response((json.dumps(c) + "\n" for c in output), content_type='application/json')
|
||
@app.route('/api/chat/biomania', methods=['POST'])
def chat_biomania():
    # Model-specific alias route; delegates entirely to the generic chat handler.
    return chat()
|
||
if __name__ == '__main__':
    # Development entry point; serves on http://localhost:5000.
    app.run(port=5000)
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
from ..deploy.model import Model | ||
import os, torch | ||
from datetime import datetime | ||
from colorama import Fore, Style | ||
from ..deploy.cli import encode_file_to_base64, parse_backend_response | ||
from ..deploy.cli_demo import parse_backend_queue | ||
|
||
def initialize_model():
    """Set up logging and working directories, then construct the Model.

    Creates ./logs with a timestamped rotating loguru sink, ensures a tmp/
    working directory exists, picks CUDA when available, and returns a
    ready Model instance.
    """
    from loguru import logger
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    # Plain string: the original f"./logs" had no placeholders.
    os.makedirs("./logs", exist_ok=True)
    logger.remove()  # drop loguru's default stderr sink before adding ours
    logger.add(f"./logs/BioMANIA_log_{timestamp}.log", rotation="500 MB", retention="7 days", level="INFO")
    logger.info("Loguru initialized successfully.")
    print("Logging setup complete.")
    # makedirs(exist_ok=True) replaces the racy exists()+mkdir pair.
    os.makedirs('tmp', exist_ok=True)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = Model(logger, device)
    return model
|
||
def process_input(model, user_input, library, session_id, conversation_started):
    """Run one user turn through the model pipeline and collect responses.

    If the input embeds ``<file>path</file>``, the tag is stripped from the
    prompt, the file is base64-encoded, and the encoding is forwarded to the
    pipeline.

    Returns:
        list of parsed backend response strings.
    """
    files = []
    # Require BOTH tags: with only "<file>", find("</file>") returns -1 and
    # the slicing below would silently garble the input.
    if "<file>" in user_input and "</file>" in user_input:
        tag_start = user_input.find("<file>")
        tag_end = user_input.find("</file>")
        filepath = user_input[tag_start + 6:tag_end]
        user_input = user_input[:tag_start] + user_input[tag_end + 7:]
        file_content = encode_file_to_base64(filepath)
        print(Fore.YELLOW + "File encoded to base64 for processing: " + file_content[:30] + "..." + Style.RESET_ALL)
        # BUG FIX: the encoded file was previously computed but never passed
        # to run_pipeline (files was always []), so uploads were dropped.
        files = [file_content]
    model.run_pipeline(user_input, library, top_k=1, files=files, conversation_started=conversation_started, session_id=session_id)
    messages = parse_backend_queue(model.queue)
    responses = []
    for msg in messages:
        responses.extend(parse_backend_response([msg], yield_load=False))
    return responses
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.