Serve trtllm upgrade model
darraghdog committed Jan 6, 2025
1 parent bbb8220 · commit 15b03cc
Showing 1 changed file with 10 additions and 1 deletion.
nemo_skills/inference/server/model.py: 10 additions & 1 deletion
@@ -112,7 +112,8 @@ def generate(
         tokens_to_generate: int | list[int] = 2048,
         temperature: float | list[float] = 0.0,
         top_p: float | list[float] = 0.95,
-        top_k: int | list[int] = 0,
+        top_k: int | list[int] = 0.0,
+        min_p: float | list[float] = 0.0,
         repetition_penalty: float | list[float] = 1.0,
         random_seed: int | list[int] = 0,
         stop_phrases: list[str] | list[list[str]] | None = None,
@@ -127,6 +128,7 @@ def generate(
             'temperature': temperature,
             'top_p': top_p,
             'top_k': top_k,
+            'min_p': min_p,
             'repetition_penalty': repetition_penalty,
             'random_seed': random_seed,
             'stop_phrases': stop_phrases,
@@ -173,20 +175,27 @@ def _generate_single(
         temperature: float = 0.0,
         top_p: float = 0.95,
         top_k: int = 0,
+        min_p: float = 0.0,
         repetition_penalty: float = 1.0,
         random_seed: int = 0,
         stop_phrases: list[str] | None = None,
     ) -> list[dict]:
         if isinstance(prompt, dict):
             raise NotImplementedError("trtllm server does not support OpenAI \"messages\" as prompt.")
+
         if stop_phrases is None:
             stop_phrases = []
+        top_p_min = None
+        if min_p > 0:
+            top_p_min = min_p
+
         request = {
             "prompt": prompt,
             "tokens_to_generate": tokens_to_generate,
             "temperature": temperature,
             "top_k": top_k,
             "top_p": top_p,
+            "top_p_min": top_p_min,
             "random_seed": random_seed,
             "repetition_penalty": repetition_penalty,
             "stop_words_list": stop_phrases,
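
Below is a minimal, self-contained sketch of the new min_p plumbing, restated outside the class for readers without the full file. The function name build_trtllm_request is a hypothetical stand-in invented for this illustration (the real logic lives in _generate_single above); only the mapping of min_p onto the request field top_p_min mirrors the commit.

def build_trtllm_request(
    prompt: str,
    tokens_to_generate: int = 2048,
    temperature: float = 0.0,
    top_p: float = 0.95,
    top_k: int = 0,
    min_p: float = 0.0,
    repetition_penalty: float = 1.0,
    random_seed: int = 0,
    stop_phrases: list[str] | None = None,
) -> dict:
    """Build the JSON payload for a trtllm generation request (illustrative only)."""
    if stop_phrases is None:
        stop_phrases = []
    # As in the diff: min_p is only forwarded when explicitly set (> 0),
    # and it travels to the server under the key "top_p_min".
    top_p_min = None
    if min_p > 0:
        top_p_min = min_p
    return {
        "prompt": prompt,
        "tokens_to_generate": tokens_to_generate,
        "temperature": temperature,
        "top_k": top_k,
        "top_p": top_p,
        "top_p_min": top_p_min,
        "random_seed": random_seed,
        "repetition_penalty": repetition_penalty,
        "stop_words_list": stop_phrases,
    }

# With min_p unset the payload carries top_p_min=None, so the server keeps
# its own default; a positive min_p is passed through unchanged.
assert build_trtllm_request("2 + 2 =")["top_p_min"] is None
assert build_trtllm_request("2 + 2 =", min_p=0.05)["top_p_min"] == 0.05

One plausible reading of the field renaming, offered as an assumption rather than something the diff states: TensorRT-LLM's sampling configuration exposes a top_p_min knob (the floor of its top-p decay schedule), so the client forwards min_p under that name and sends None when the feature is unused.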
