Skip to content

Commit

Permalink
[Client] Add guided_grammar and other missing fields (#532)
Browse files Browse the repository at this point in the history
Add guided_grammar to the client, and pass fields that were missing from some code paths (include_stop_str_in_output, guided_json, guided_regex, guided_choice).
  • Loading branch information
seanshi-scale authored Jun 4, 2024
1 parent bd192cb commit ad24f65
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 3 deletions.
2 changes: 1 addition & 1 deletion clients/python/llmengine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__version__ = "0.0.0b33"
__version__ = "0.0.0b34"

import os
from typing import Sequence
Expand Down
20 changes: 20 additions & 0 deletions clients/python/llmengine/completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ async def acreate(
guided_json: Optional[Dict[str, Any]] = None,
guided_regex: Optional[str] = None,
guided_choice: Optional[List[str]] = None,
guided_grammar: Optional[str] = None,
timeout: int = COMPLETION_TIMEOUT,
stream: bool = False,
) -> Union[CompletionSyncResponse, AsyncIterable[CompletionStreamResponse]]:
Expand Down Expand Up @@ -118,6 +119,9 @@ async def acreate(
guided_choice (Optional[List[str]]):
If specified, the output will be exactly one of the choices.
guided_grammar (Optional[str]):
If specified, the output will follow the context-free grammar provided.
timeout (int):
Timeout in seconds. This is the maximum amount of time you are willing to wait for a response.
Expand Down Expand Up @@ -218,6 +222,7 @@ async def _acreate_stream(
guided_json=guided_json,
guided_regex=guided_regex,
guided_choice=guided_choice,
guided_grammar=guided_grammar,
timeout=timeout,
)

Expand All @@ -242,6 +247,11 @@ async def _acreate_sync(**kwargs) -> CompletionSyncResponse:
frequency_penalty=frequency_penalty,
top_k=top_k,
top_p=top_p,
include_stop_str_in_output=include_stop_str_in_output,
guided_json=guided_json,
guided_regex=guided_regex,
guided_choice=guided_choice,
guided_grammar=guided_grammar,
)

@classmethod
Expand All @@ -261,6 +271,7 @@ def create(
guided_json: Optional[Dict[str, Any]] = None,
guided_regex: Optional[str] = None,
guided_choice: Optional[List[str]] = None,
guided_grammar: Optional[str] = None,
timeout: int = COMPLETION_TIMEOUT,
stream: bool = False,
) -> Union[CompletionSyncResponse, Iterator[CompletionStreamResponse]]:
Expand Down Expand Up @@ -333,6 +344,9 @@ def create(
guided_choice (Optional[List[str]]):
If specified, the output will be exactly one of the choices.
guided_grammar (Optional[str]):
If specified, the output will follow the context-free grammar provided.
timeout (int):
Timeout in seconds. This is the maximum amount of time you are willing to wait for a response.
Expand Down Expand Up @@ -419,6 +433,11 @@ def _create_stream(**kwargs):
frequency_penalty=frequency_penalty,
top_k=top_k,
top_p=top_p,
include_stop_str_in_output=include_stop_str_in_output,
guided_json=guided_json,
guided_regex=guided_regex,
guided_choice=guided_choice,
guided_grammar=guided_grammar,
)

else:
Expand All @@ -436,6 +455,7 @@ def _create_stream(**kwargs):
guided_json=guided_json,
guided_regex=guided_regex,
guided_choice=guided_choice,
guided_grammar=guided_grammar,
).dict()
response = cls.post_sync(
resource_name=f"v1/llm/completions-sync?model_endpoint_name={model}",
Expand Down
2 changes: 2 additions & 0 deletions clients/python/llmengine/data_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,7 @@ class CompletionSyncV1Request(BaseModel):
guided_json: Optional[Dict[str, Any]] = Field(default=None)
guided_regex: Optional[str] = Field(default=None)
guided_choice: Optional[List[str]] = Field(default=None)
guided_grammar: Optional[str] = Field(default=None)


class TokenOutput(BaseModel):
Expand Down Expand Up @@ -405,6 +406,7 @@ class CompletionStreamV1Request(BaseModel):
guided_json: Optional[Dict[str, Any]] = Field(default=None)
guided_regex: Optional[str] = Field(default=None)
guided_choice: Optional[List[str]] = Field(default=None)
guided_grammar: Optional[str] = Field(default=None)


class CompletionStreamOutput(BaseModel):
Expand Down
2 changes: 1 addition & 1 deletion clients/python/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "scale-llm-engine"
version = "0.0.0.beta33"
version = "0.0.0.beta34"
description = "Scale LLM Engine Python client"
license = "Apache-2.0"
authors = ["Phil Chen <[email protected]>"]
Expand Down
2 changes: 1 addition & 1 deletion clients/python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name="scale-llm-engine",
python_requires=">=3.7",
version="0.0.0.beta33",
version="0.0.0.beta34",
packages=find_packages(),
package_data={"llmengine": ["py.typed"]},
)

0 comments on commit ad24f65

Please sign in to comment.