format
aspfohl committed Nov 28, 2023
1 parent 98a5cfe commit 0f783f1
Showing 1 changed file with 23 additions and 23 deletions.
46 changes: 23 additions & 23 deletions scripts/inference/endpoint_generate.py
@@ -14,13 +14,13 @@
 from composer.utils import (ObjectStore, maybe_create_object_store_from_uri,
                             parse_uri)
 
-logging.basicConfig(level=logging.INFO, format="%(asctime)s %(message)s")
+logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s')
 log = logging.getLogger(__name__)
 
 ENDPOINT_API_KEY_ENV: str = 'ENDPOINT_API_KEY'
 ENDPOINT_URL_ENV: str = 'ENDPOINT_URL'
 
-PROMPT_DELIMITER = "\n"
+PROMPT_DELIMITER = '\n'
 
 
 def parse_args() -> Namespace:
@@ -38,30 +38,30 @@ def parse_args() -> Namespace:
         'prompt contained in a txt file.'
     )
 
-    now = time.strftime("%Y%m%d-%H%M%S")
-    default_local_folder = f"/tmp/output/{now}"
+    now = time.strftime('%Y%m%d-%H%M%S')
+    default_local_folder = f'/tmp/output/{now}'
     parser.add_argument('-l',
                         '--local-folder',
                         type=str,
                         default=default_local_folder,
-                        help="Local folder to save the output")
+                        help='Local folder to save the output')
 
     parser.add_argument('-o',
                         '--output-folder',
-                        help="Remote folder to save the output")
+                        help='Remote folder to save the output')
 
     #####
     # Generation Parameters
     parser.add_argument(
         '--rate-limit',
         type=int,
         default=5,
-        help="Max number of calls to make to the endpoint in a second")
+        help='Max number of calls to make to the endpoint in a second')
     parser.add_argument(
         '--batch-size',
         type=int,
         default=5,
-        help="Max number of calls to make to the endpoint in a single request")
+        help='Max number of calls to make to the endpoint in a single request')
 
     #####
     # Endpoint Parameters
@@ -124,12 +124,12 @@ async def main(args: Namespace) -> None:
         prompt = load_prompt_string_from_file(prompt)
         prompt_strings.append(prompt)
 
-    cols = ["batch", "prompt", "output"]
+    cols = ['batch', 'prompt', 'output']
     param_data = {
-        "max_tokens": args.max_tokens,
-        "temperature": args.temperature,
-        "top_k": args.top_k,
-        "top_p": args.top_p,
+        'max_tokens': args.max_tokens,
+        'temperature': args.temperature,
+        'top_k': args.top_k,
+        'top_p': args.top_p,
     }
 
     @sleep_and_retry
@@ -141,7 +141,7 @@ async def generate(session: aiohttp.ClientSession, batch: int,
         api_key = os.environ.get(ENDPOINT_API_KEY_ENV, '')
         if not api_key:
             log.warning(f'API key not set in {ENDPOINT_API_KEY_ENV}')
-        headers = {"Authorization": api_key, "Content-Type": "application/json"}
+        headers = {'Authorization': api_key, 'Content-Type': 'application/json'}
 
         req_start = time.time()
         async with session.post(url, headers=headers, json=data) as resp:
@@ -158,25 +158,25 @@
                 )
 
         req_end = time.time()
-        n_compl = response["usage"]["completion_tokens"]
-        n_prompt = response["usage"]["prompt_tokens"]
+        n_compl = response['usage']['completion_tokens']
+        n_prompt = response['usage']['prompt_tokens']
         req_latency = (req_end - req_start)
         log.info(f'Completed batch {batch}: {n_compl:,} completion' +
                  f' tokens using {n_prompt:,} prompt tokens in {req_latency}s')
 
         res = pd.DataFrame(columns=cols)
 
-        for r in response["choices"]:
-            index = r["index"]
-            res.loc[len(res)] = [batch, prompts[index], r["text"]]
+        for r in response['choices']:
+            index = r['index']
+            res.loc[len(res)] = [batch, prompts[index], r['text']]
         return res
 
     res = pd.DataFrame(columns=cols)
     batch = 0
 
     total_batches = math.ceil(len(prompt_strings) / args.batch_size)
     log.info(
-        f"Generating {len(prompt_strings)} prompts in {total_batches} batches")
+        f'Generating {len(prompt_strings)} prompts in {total_batches} batches')
 
     async with aiohttp.ClientSession() as session:
         tasks = []
@@ -192,13 +192,13 @@ async def generate(session: aiohttp.ClientSession, batch: int,
     res = pd.concat(results)
 
     res.reset_index(drop=True, inplace=True)
-    log.info(f"Generated {len(res)} prompts, example data:")
+    log.info(f'Generated {len(res)} prompts, example data:')
     log.info(res.head())
 
     # save res to local output folder
     os.makedirs(args.local_folder, exist_ok=True)
-    local_path = os.path.join(args.local_folder, "output.csv")
-    res.to_csv(os.path.join(args.local_folder, "output.csv"), index=False)
+    local_path = os.path.join(args.local_folder, 'output.csv')
+    res.to_csv(os.path.join(args.local_folder, 'output.csv'), index=False)
     log.info(f'Saved results in {local_path}')
 
     if args.output_folder:
