Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cepo #1

Draft
wants to merge 9 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions configs/cepo_config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
bestofn_n: 3
bestofn_temperature: 0.1
bestofn_max_tokens: 4096
bestofn_rating_type: "absolute" # or "pairwise"
planning_n: 3
planning_m: 6
planning_temperature_step1: 0.55
planning_temperature_step2: 0.25
planning_temperature_step3: 0.1
planning_temperature_step4: 0
planning_max_tokens_step1: 4096
planning_max_tokens_step2: 4096
planning_max_tokens_step3: 4096
planning_max_tokens_step4: 4096
30 changes: 28 additions & 2 deletions optillm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import secrets
from flask import Flask, request, jsonify
from cerebras.cloud.sdk import Cerebras
from openai import AzureOpenAI, OpenAI
from flask import Response
import json
Expand All @@ -27,6 +28,7 @@
from optillm.plansearch import plansearch
from optillm.leap import leap
from optillm.reread import re2_approach
from optillm.cepo import cepo, CepoConfig, init_cepo_config

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
Expand All @@ -50,7 +52,14 @@ def get_config():
from optillm.inference import create_inference_client
API_KEY = os.environ.get("OPTILLM_API_KEY")
default_client = create_inference_client()
# OpenAI, Azure, or LiteLLM API configuration
# Cerebras, OpenAI, Azure, or LiteLLM API configuration
elif os.environ.get("CEREBRAS_API_KEY"):
API_KEY = os.environ.get("CEREBRAS_API_KEY")
base_url = server_config['base_url']
if base_url != "":
default_client = Cerebras(api_key=API_KEY, base_url=base_url)
else:
default_client = Cerebras(api_key=API_KEY)
elif os.environ.get("OPENAI_API_KEY"):
API_KEY = os.environ.get("OPENAI_API_KEY")
base_url = server_config['base_url']
Expand Down Expand Up @@ -104,7 +113,7 @@ def get_config():

# List of known approaches
known_approaches = ["none", "mcts", "bon", "moa", "rto", "z3", "self_consistency",
"pvg", "rstar", "cot_reflection", "plansearch", "leap", "re2"]
"pvg", "rstar", "cot_reflection", "plansearch", "leap", "re2", "cepo"]

plugin_approaches = {}

Expand Down Expand Up @@ -283,6 +292,10 @@ def execute_single_approach(approach, system_prompt, initial_query, client, mode
return leap(system_prompt, initial_query, client, model)
elif approach == 're2':
return re2_approach(system_prompt, initial_query, client, model, n=server_config['n'])
elif approach == 'cepo':
# build the cepo config based on the cmd line arguments and the
logger.debug(f"Calling with {cepo_config}")
return cepo(system_prompt, initial_query, client, model, cepo_config)
elif approach in plugin_approaches:
return plugin_approaches[approach](system_prompt, initial_query, client, model)
else:
Expand Down Expand Up @@ -609,6 +622,13 @@ def parse_args():
parser.add_argument("--base-url", "--base_url", dest="base_url", type=str, default=base_url_default,
help="Base url for OpenAI compatible endpoint")

# Special handling of all the Cepo Configurations
for key, value in CepoConfig.__dict__.items():
if not key.startswith('__'):
parser.add_argument(f"--cepo_{key}", dest=f"cepo_{key}", type=type(value), default=None, help=f"CePO configuration for {key}")

parser.add_argument(f"--cepo_config_file", dest=f"cepo_config_file", type=str, default=None, help="Path to CePO configuration file")

args = parser.parse_args()

# Convert argument names to match server_config keys
Expand All @@ -622,6 +642,7 @@ def parse_args():

def main():
global server_config
global cepo_config
# Call this function at the start of main()
load_plugins()
args = parse_args()
Expand All @@ -636,6 +657,11 @@ def main():
if logging_level in logging_levels.keys():
logger.setLevel(logging_levels[logging_level])

# set and log the cepo configs
cepo_config = init_cepo_config(server_config)
if args.approach == 'cepo':
logger.info(f"CePO Config: {cepo_config}")

logger.info(f"Starting server with approach: {server_config['approach']}")
server_config_clean = server_config.copy()
if server_config_clean['optillm_api_key']:
Expand Down
Loading