SN1-373: Test Time Inference Token Endpoint #534

Open · wants to merge 1 commit into base: staging
2 changes: 1 addition & 1 deletion api_keys.json
@@ -1 +1 @@
{}
{"4421e9d3fb1003503e2664ea9bbadd48": {"rate_limit": 10, "usage": 0}, "063884eb7d5aefbc940de68187519983": {"rate_limit": 10, "usage": 0}}
2 changes: 1 addition & 1 deletion validator_api/chat_completion.py
@@ -97,7 +97,7 @@ async def chat_completion(
"""Handle chat completion with multiple miners in parallel."""
# Get multiple UIDs if none specified
if uids is None:
uids = list(get_uids(sampling_mode="top_incentive", k=100))
uids = list(get_uids(sampling_mode="random", k=100))
if uids is None or len(uids) == 0: # if not uids throws error, figure out how to fix
logger.error("No available miners found")
raise HTTPException(status_code=503, detail="No available miners found")
36 changes: 36 additions & 0 deletions validator_api/gpt_endpoints.py
@@ -13,6 +13,7 @@
from validator_api.chat_completion import chat_completion
from validator_api.mixture_of_miners import mixture_of_miners
from validator_api.utils import forward_response
from validator_api.test_time_inference import generate_response

router = APIRouter()

@@ -88,3 +89,38 @@ async def web_retrieval(search_query: str, n_miners: int = 10, uids: list[int] =
)
)
return loaded_results


@router.post("/test_time_inference")
async def test_time_inference(messages: list[dict]):
async def create_response_stream(user_query):
async for steps, total_thinking_time in generate_response(user_query):
if total_thinking_time is not None:
logger.info(f"**Total thinking time: {total_thinking_time:.2f} seconds**")
yield steps, total_thinking_time

# Create a streaming response that yields each step
async def stream_steps():
try:
query = messages[-1]["content"]
logger.info(f"Query: {query}")
async for steps, thinking_time in create_response_stream(query):
step_data = {
"steps": [{"title": step[0], "content": step[1], "thinking_time": step[2]} for step in steps],
"total_thinking_time": thinking_time,
}
yield f"data: {json.dumps(step_data)}\n\n"
except Exception as e:
logger.exception(f"Error during streaming: {e}")
            error_payload = json.dumps({"error": f"Internal Server Error: {e}"})
            yield f"data: {error_payload}\n\n"
finally:
yield "data: [DONE]\n\n"

return StreamingResponse(
stream_steps(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
},
)
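For a quick local check of the new endpoint, a client along these lines can consume the SSE stream. This is a sketch only: the base URL/port and the sample question are assumptions, and any API-key header the deployment enforces is omitted.

```python
# Hypothetical smoke test for POST /test_time_inference (not part of this PR).
# Assumes the validator API is reachable at http://localhost:8000; adjust as needed.
import asyncio
import json

import httpx


async def main():
    # FastAPI binds a single list-typed body parameter to the raw JSON list,
    # so the request body is the messages list itself.
    messages = [{"role": "user", "content": "What is 17 * 24?"}]
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream(
            "POST", "http://localhost:8000/test_time_inference", json=messages
        ) as resp:
            async for line in resp.aiter_lines():
                if not line.startswith("data: "):
                    continue
                payload = line[len("data: "):]
                if payload == "[DONE]":
                    break
                event = json.loads(payload)
                if "error" in event:
                    print("stream error:", event["error"])
                    break
                step = event["steps"][-1]  # each event carries the cumulative step list; show the newest
                print(f"{step['title']} ({step['thinking_time']:.2f}s)")
                if event["total_thinking_time"] is not None:
                    print(f"total thinking time: {event['total_thinking_time']:.2f}s")


if __name__ == "__main__":
    asyncio.run(main())
```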
167 changes: 167 additions & 0 deletions validator_api/test_time_inference.py
@@ -0,0 +1,167 @@
import asyncio
import json
import re
import time
from loguru import logger
from validator_api.chat_completion import chat_completion
from shared.timer import Timer

MAX_THINKING_STEPS = 10


def parse_multiple_json(api_response):
"""
Parses a string containing multiple JSON objects and returns a list of dictionaries.

Args:
api_response (str): The string returned by the API containing JSON objects.

Returns:
list: A list of dictionaries parsed from the JSON objects.
"""
# Regular expression pattern to match individual JSON objects
json_pattern = re.compile(r"\{.*?\}", re.DOTALL)

# Find all JSON object strings in the response
json_strings = json_pattern.findall(api_response)

parsed_objects = []
for json_str in json_strings:
try:
# Replace escaped single quotes with actual single quotes
json_str_clean = json_str.replace("\\'", "'")

# Parse the JSON string into a dictionary
obj = json.loads(json_str_clean)
parsed_objects.append(obj)
except json.JSONDecodeError as e:
print(f"Failed to parse JSON object: {e}")
continue

return parsed_objects


async def make_api_call(messages, max_tokens, model=None, is_final_answer=False):
logger.info(f"Making API call with messages: {messages}")
for attempt in range(3):
try:
response = await chat_completion(
body={
# "messages": messages,
"messages": [
{"role": "user", "content": "Remember the number 49"},
{"role": "user", "content": "Remember the number 32"},
{"role": "user", "content": "Recite back all numbers previously given to you."},
],
"max_tokens": max_tokens,
"model": model,
"stream": False,
"task": "InferenceTask",
}
)
# return response.choices[0].message.content
response_dict = parse_multiple_json(response)[0]
return response_dict
except Exception as e:
if attempt == 2:
if is_final_answer:
return {
"title": "Error",
"content": f"Failed to generate final answer after 3 attempts. Error: {str(e)}",
}
else:
return {
"title": "Error",
"content": f"Failed to generate step after 3 attempts. Error: {str(e)}",
"next_action": "final_answer",
}
            await asyncio.sleep(1)  # Wait 1 second before retrying without blocking the event loop


async def generate_response(prompt):
messages = [
{
"role": "system",
"content": """You are an expert AI assistant with advanced reasoning capabilities. Your task is to provide detailed, step-by-step explanations of your thought process. For each step:

1. Provide a clear, concise title describing the current reasoning phase.
2. Elaborate on your thought process in the content section.
3. Decide whether to continue reasoning or provide a final answer.

Response Format:
Use JSON with keys: 'title', 'content', 'next_action' (values: 'continue' or 'final_answer')

Key Instructions:
- Employ at least 5 distinct reasoning steps.
- Acknowledge your limitations as an AI and explicitly state what you can and cannot do.
- Actively explore and evaluate alternative answers or approaches.
- Critically assess your own reasoning; identify potential flaws or biases.
- When re-examining, employ a fundamentally different approach or perspective.
- Utilize at least 3 diverse methods to derive or verify your answer.
- Incorporate relevant domain knowledge and best practices in your reasoning.
- Quantify certainty levels for each step and the final conclusion when applicable.
- Consider potential edge cases or exceptions to your reasoning.
- Provide clear justifications for eliminating alternative hypotheses.
- Output only one step at a time to ensure a detailed and coherent explanation.


Example of a valid JSON response:
```json
{
"title": "Initial Problem Analysis",
"content": "To approach this problem effectively, I'll first break down the given information into key components. This involves identifying...[detailed explanation]... By structuring the problem this way, we can systematically address each aspect.",
"next_action": "continue"
}```
""",
}
]
messages += [{"role": "user", "content": prompt}]
messages += [
{
"role": "assistant",
"content": "Thank you! I will now think step by step following my instructions, starting at the beginning after decomposing the problem.",
}
]

steps = []
step_count = 1
total_thinking_time = 0

for _ in range(MAX_THINKING_STEPS):
with Timer() as timer:
step_data = await make_api_call(messages, 300)
thinking_time = timer.final_time
total_thinking_time += thinking_time

steps.append((f"Step {step_count}: {step_data['title']}", step_data["content"], thinking_time))

messages.append({"role": "assistant", "content": json.dumps(step_data)})

if step_data["next_action"] == "final_answer" or not step_data.get("next_action"):
break

step_count += 1

# Yield after each step
yield steps, None

# Generate final answer
messages.append(
{
"role": "user",
"content": "Please provide the final answer based on your reasoning above. You must return your answer in a valid json.",
}
)

start_time = time.time()
    final_data = await make_api_call(messages, 200, is_final_answer=True)
end_time = time.time()
thinking_time = end_time - start_time
total_thinking_time += thinking_time

if final_data["title"] == "Error":
steps.append(("Error", final_data["content"], thinking_time))
        raise ValueError(f"Failed to generate final answer: {final_data['content']}")

steps.append(("Final Answer", final_data["content"], thinking_time))

yield steps, total_thinking_time
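
As a side note on `parse_multiple_json`: it extracts every `{...}` span with a non-greedy regex and keeps whichever spans parse as JSON, so malformed fragments are silently skipped. A small illustrative run (input string invented):

```python
# Illustrative only; the helper comes from validator_api/test_time_inference.py above.
from validator_api.test_time_inference import parse_multiple_json

raw = (
    "Some reasoning text before the object "
    '{"title": "Initial Problem Analysis", "content": "Break the problem into parts.", "next_action": "continue"} '
    'and a truncated fragment {"title": "oops"'
)
print(parse_multiple_json(raw))
# -> [{'title': 'Initial Problem Analysis', 'content': 'Break the problem into parts.', 'next_action': 'continue'}]
```

Because the pattern is non-greedy and stops at the first closing brace, step objects whose content contains nested braces would be split incorrectly; worth keeping in mind if the model's output format changes.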