Skip to content

Commit

Permalink
gracefully handle enum reason, pt.2
Browse files Browse the repository at this point in the history
  • Loading branch information
justinthelaw committed Sep 17, 2024
1 parent 9897a67 commit 0eb1ead
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 15 deletions.
15 changes: 2 additions & 13 deletions src/leapfrogai_api/backend/grpc_client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""gRPC client for OpenAI models."""

from typing import Iterator, AsyncGenerator, Any, List
from fastapi import HTTPException, status
import grpc
from fastapi.responses import StreamingResponse
import leapfrogai_sdk as lfai
Expand All @@ -25,16 +24,6 @@
from leapfrogai_api.utils.config import Model


def to_string_finish_reason(finish_reason: int) -> str:
try:
return FinishReason(finish_reason).to_string()
except ValueError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e),
)


async def stream_completion(model: Model, request: lfai.CompletionRequest):
"""Stream completion using the specified model."""
async with grpc.aio.insecure_channel(model.backend) as channel:
Expand All @@ -60,7 +49,7 @@ async def completion(model: Model, request: lfai.CompletionRequest):
CompletionChoice(
index=0,
text=response.choices[0].text,
finish_reason=to_string_finish_reason(
finish_reason=FinishReason.to_string(
response.choices[0].finish_reason
),
logprobs=None,
Expand Down Expand Up @@ -119,7 +108,7 @@ async def chat_completion(model: Model, request: lfai.ChatCompletionRequest):
).lower(),
content=response.choices[0].chat_item.content,
),
finish_reason=to_string_finish_reason(
finish_reason=FinishReason.to_string(
response.choices[0].finish_reason
),
)
Expand Down
9 changes: 7 additions & 2 deletions src/leapfrogai_api/backend/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
ChatStreamChoice,
CompletionChoice,
CompletionResponse,
FinishReason,
Usage,
)

Expand All @@ -32,7 +33,9 @@ async def recv_completion(
index=0,
text=c.choices[0].text,
logprobs=None,
finish_reason=c.choices[0].finish_reason,
finish_reason=FinishReason.to_string(
c.choices[0].finish_reason
),
)
],
usage=Usage(
Expand Down Expand Up @@ -68,7 +71,9 @@ async def recv_chat(
delta=ChatDelta(
role="assistant", content=c.choices[0].chat_item.content
),
finish_reason=c.choices[0].finish_reason,
finish_reason=FinishReason.to_string(
c.choices[0].finish_reason
),
)
],
usage=Usage(
Expand Down

0 comments on commit 0eb1ead

Please sign in to comment.