diff --git a/src/leapfrogai_api/backend/grpc_client.py b/src/leapfrogai_api/backend/grpc_client.py
index e2f088989..6a3fa2e4b 100644
--- a/src/leapfrogai_api/backend/grpc_client.py
+++ b/src/leapfrogai_api/backend/grpc_client.py
@@ -1,7 +1,6 @@
 """gRPC client for OpenAI models."""
 
 from typing import Iterator, AsyncGenerator, Any, List
-from fastapi import HTTPException, status
 import grpc
 from fastapi.responses import StreamingResponse
 import leapfrogai_sdk as lfai
@@ -25,16 +24,6 @@ from leapfrogai_api.utils.config import Model
 
 
 
-def to_string_finish_reason(finish_reason: int) -> str:
-    try:
-        return FinishReason(finish_reason).to_string()
-    except ValueError as e:
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST,
-            detail=str(e),
-        )
-
-
 async def stream_completion(model: Model, request: lfai.CompletionRequest):
     """Stream completion using the specified model."""
     async with grpc.aio.insecure_channel(model.backend) as channel:
@@ -60,7 +49,7 @@ async def completion(model: Model, request: lfai.CompletionRequest):
                 CompletionChoice(
                     index=0,
                     text=response.choices[0].text,
-                    finish_reason=to_string_finish_reason(
+                    finish_reason=FinishReason.to_string(
                         response.choices[0].finish_reason
                     ),
                     logprobs=None,
@@ -119,7 +108,7 @@ async def chat_completion(model: Model, request: lfai.ChatCompletionRequest):
                     ).lower(),
                     content=response.choices[0].chat_item.content,
                 ),
-                finish_reason=to_string_finish_reason(
+                finish_reason=FinishReason.to_string(
                     response.choices[0].finish_reason
                 ),
             )
diff --git a/src/leapfrogai_api/backend/helpers.py b/src/leapfrogai_api/backend/helpers.py
index 2fec1f6eb..d9baddec8 100644
--- a/src/leapfrogai_api/backend/helpers.py
+++ b/src/leapfrogai_api/backend/helpers.py
@@ -11,6 +11,7 @@
     ChatStreamChoice,
     CompletionChoice,
     CompletionResponse,
+    FinishReason,
     Usage,
 )
 
@@ -32,7 +33,9 @@ async def recv_completion(
                     index=0,
                     text=c.choices[0].text,
                     logprobs=None,
-                    finish_reason=c.choices[0].finish_reason,
+                    finish_reason=FinishReason.to_string(
+                        c.choices[0].finish_reason
+                    ),
                 )
             ],
             usage=Usage(
@@ -68,7 +71,9 @@ async def recv_chat(
                     delta=ChatDelta(
                         role="assistant", content=c.choices[0].chat_item.content
                     ),
-                    finish_reason=c.choices[0].finish_reason,
+                    finish_reason=FinishReason.to_string(
+                        c.choices[0].finish_reason
+                    ),
                 )
             ],
             usage=Usage(