Skip to content

Commit

Permalink
fix nits in oai client for vllm (#821)
Browse files (browse the repository at this point in the history)
  • Loading branch information
charlesfrye authored Jul 24, 2024
1 parent 92828bb commit 82a35b6
Showing 1 changed file with 89 additions and 31 deletions.
120 changes: 89 additions & 31 deletions 06_gpu_and_ml/llm-serving/vllm_oai_compatible/client.py
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,10 @@
"""This simple script shows how to interact with an OpenAI-compatible server from a client."""
import argparse

import modal
from openai import OpenAI


class Colors:
"""ANSI color codes"""

Expand All @@ -13,6 +15,7 @@ class Colors:
BOLD = "\033[1m"
END = "\033[0m"


def get_completion(client, model_id, messages, args):
completion_args = {
"model": model_id,
Expand All @@ -28,7 +31,9 @@ def get_completion(client, model_id, messages, args):
"top_p": args.top_p,
}

completion_args = {k: v for k, v in completion_args.items() if v is not None}
completion_args = {
k: v for k, v in completion_args.items() if v is not None
}

try:
response = client.chat.completions.create(**completion_args)
Expand All @@ -37,30 +42,63 @@ def get_completion(client, model_id, messages, args):
print(Colors.RED, f"Error during API call: {e}", Colors.END, sep="")
return None


def main():
parser = argparse.ArgumentParser(description="OpenAI Client CLI")

parser.add_argument('--model', type=str, default=None, help='The model to use for completion, defaults to the first available model')
parser.add_argument('--api-key', type=str, default="super-secret-token", help='The API key to use for authentication, set in your api.py')

parser.add_argument(
"--model",
type=str,
default=None,
help="The model to use for completion, defaults to the first available model",
)
parser.add_argument(
"--api-key",
type=str,
default="super-secret-token",
help="The API key to use for authentication, set in your api.py",
)

# Completion parameters
parser.add_argument('--max-tokens', type=int, default=None)
parser.add_argument('--temperature', type=float, default=0.7)
parser.add_argument('--top-p', type=float, default=0.9)
parser.add_argument('--top-k', type=int, default=0)
parser.add_argument('--frequency-penalty', type=float, default=0)
parser.add_argument('--presence-penalty', type=float, default=0)
parser.add_argument('--n', type=int, default=1, help='Number of completions to generate. Streaming and chat mode only support n=1.')
parser.add_argument('--stop', type=str, default=None)
parser.add_argument('--seed', type=int, default=None)
parser.add_argument("--max-tokens", type=int, default=None)
parser.add_argument("--temperature", type=float, default=0.7)
parser.add_argument("--top-p", type=float, default=0.9)
parser.add_argument("--top-k", type=int, default=0)
parser.add_argument("--frequency-penalty", type=float, default=0)
parser.add_argument("--presence-penalty", type=float, default=0)
parser.add_argument(
"--n",
type=int,
default=1,
help="Number of completions to generate. Streaming and chat mode only support n=1.",
)
parser.add_argument("--stop", type=str, default=None)
parser.add_argument("--seed", type=int, default=None)

# Prompting
parser.add_argument('--prompt', type=str, default="Compose a limerick about baboons and racoons.", help='The user prompt for the chat completion')
parser.add_argument('--system-prompt', type=str, default="You are a poetic assistant, skilled in writing satirical doggerel with creative flair.", help='The system prompt for the chat completion')

parser.add_argument(
"--prompt",
type=str,
default="Compose a limerick about baboons and racoons.",
help="The user prompt for the chat completion",
)
parser.add_argument(
"--system-prompt",
type=str,
default="You are a poetic assistant, skilled in writing satirical doggerel with creative flair.",
help="The system prompt for the chat completion",
)

# UI options
parser.add_argument('--no-stream', dest='stream', action='store_false', help='Disable streaming of response chunks')
parser.add_argument('--chat', action='store_true', help='Enable interactive chat mode')
parser.add_argument(
"--no-stream",
dest="stream",
action="store_false",
help="Disable streaming of response chunks",
)
parser.add_argument(
"--chat", action="store_true", help="Enable interactive chat mode"
)

args = parser.parse_args()

Expand All @@ -72,7 +110,6 @@ def main():
f"https://{WORKSPACE}--vllm-openai-compatible-serve.modal.run/v1"
)


if args.model:
model_id = args.model
print(
Expand Down Expand Up @@ -104,23 +141,33 @@ def main():
}
]

print(Colors.BOLD + "🧠: Using system prompt: " + args.system_prompt + Colors.END)
print(
Colors.BOLD
+ "🧠: Using system prompt: "
+ args.system_prompt
+ Colors.END
)

if args.chat:
print(Colors.GREEN + Colors.BOLD + "\nEntering chat mode. Type 'bye' to end the conversation." + Colors.END)
print(
Colors.GREEN
+ Colors.BOLD
+ "\nEntering chat mode. Type 'bye' to end the conversation."
+ Colors.END
)
while True:
user_input = input("\nYou: ")
if user_input.lower() in ['bye']:
if user_input.lower() in ["bye"]:
break

MAX_HISTORY = 10
if len(messages) > MAX_HISTORY:
messages = messages[:1] + messages[-MAX_HISTORY+1:]
messages = messages[:1] + messages[-MAX_HISTORY + 1 :]

messages.append({"role": "user", "content": user_input})

response = get_completion(client, model_id, messages, args)

if response:
if args.stream:
# only stream assuming n=1
Expand All @@ -134,9 +181,14 @@ def main():
print(Colors.END)
else:
assistant_message = response.choices[0].message.content
print(Colors.BLUE + "\n🤖:" + assistant_message + Colors.END, sep="")

messages.append({"role": "assistant", "content": assistant_message})
print(
Colors.BLUE + "\n🤖:" + assistant_message + Colors.END,
sep="",
)

messages.append(
{"role": "assistant", "content": assistant_message}
)
else:
messages.append({"role": "user", "content": args.prompt})
print(Colors.GREEN + f"\nYou: {args.prompt}" + Colors.END)
Expand All @@ -151,7 +203,13 @@ def main():
else:
# only case where multiple completions are returned
for i, response in enumerate(response.choices):
print(Colors.BLUE + f"\n🤖 Choice {i+1}:{response.message.content}" + Colors.END, sep="")
print(
Colors.BLUE
+ f"\n🤖 Choice {i+1}:{response.message.content}"
+ Colors.END,
sep="",
)


if __name__ == "__main__":
main()
main()

0 comments on commit 82a35b6

Please sign in to comment.