-
Notifications
You must be signed in to change notification settings - Fork 10
/
main.py
49 lines (35 loc) · 1.9 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
from pb import create_population, init_run, run_for_n
from pb.mutation_prompts import mutation_prompts
from pb.thinking_styles import thinking_styles
import os
import logging
import argparse
import asyncio
from dotenv import load_dotenv
from rich import print
import cohere
logger = logging.getLogger(__name__)


def parse_args(argv=None):
    """Parse the command-line options for a PromptBreeder run.

    Args:
        argv: Optional list of argument strings. When ``None`` argparse
            falls back to ``sys.argv[1:]``.

    Returns:
        A dict with the keys ``num_mutation_prompts``, ``num_thinking_styles``,
        ``num_evals``, ``simulations`` (all ``int``) and ``problem`` (``str``).
    """
    parser = argparse.ArgumentParser(description='Run the PromptBreeder Algorithm. Number of units is mp * ts.')
    # type=int matters: without it, values supplied on the command line stay
    # strings and the worker-count product below becomes string repetition
    # (one override) or a TypeError (two or more overrides).
    parser.add_argument('-mp', '--num_mutation_prompts', type=int, default=2)
    parser.add_argument('-ts', '--num_thinking_styles', type=int, default=4)
    parser.add_argument('-e', '--num_evals', type=int, default=10)
    parser.add_argument('-n', '--simulations', type=int, default=10)
    parser.add_argument('-p', '--problem', default="Solve the math word problem, giving your answer as an arabic numeral.")
    return vars(parser.parse_args(argv))


def main(argv=None):
    """Run the full PromptBreeder pipeline and print the final population.

    Steps, in order: load environment variables, configure logging, parse the
    CLI options, build the Cohere client, create the population, generate the
    initial prompts, then run the genetic algorithm for the requested number
    of simulations.

    Args:
        argv: Optional list of argument strings forwarded to ``parse_args``.

    Returns:
        The population object returned by ``create_population`` after the run.

    Raises:
        KeyError: If the ``COHERE_API_KEY`` environment variable is not set.
    """
    load_dotenv()  # load environment variables
    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)

    args = parse_args(argv)
    num_mutation_prompts = args['num_mutation_prompts']
    num_thinking_styles = args['num_thinking_styles']
    num_evals = args['num_evals']

    total_evaluations = num_mutation_prompts * num_thinking_styles * num_evals

    # set num_workers to total_evaluations so we always have a thread
    co = cohere.Client(
        api_key=os.environ['COHERE_API_KEY'],
        num_workers=total_evaluations,
        max_retries=5,
        timeout=30,  # override the 2 min timeout with 30s.
    )

    # Take the first N entries of each catalogue. The keyword names
    # (tp_set / mutator_set) are those expected by create_population.
    tp_set = mutation_prompts[:num_mutation_prompts]
    mutator_set = thinking_styles[:num_thinking_styles]

    logger.info('You are prompt-optimizing for the problem: %s', args['problem'])

    logger.info('Creating the population...')
    population = create_population(tp_set=tp_set, mutator_set=mutator_set, problem_description=args['problem'])

    logger.info('Generating the initial prompts...')
    init_run(population, co, num_evals)

    logger.info('Starting the genetic algorithm...')
    run_for_n(n=args['simulations'], population=population, model=co, num_evals=num_evals)

    print("%"*80)
    print("done processing! final gen:")
    print(population.units)
    return population


if __name__ == "__main__":
    main()