diff --git a/src/leapfrogai_api/backend/grpc_client.py b/src/leapfrogai_api/backend/grpc_client.py index 9dbe782ded..7967e5b4bc 100644 --- a/src/leapfrogai_api/backend/grpc_client.py +++ b/src/leapfrogai_api/backend/grpc_client.py @@ -48,7 +48,7 @@ async def completion(model: Model, request: lfai.CompletionRequest): CompletionChoice( index=0, text=response.choices[0].text, - finish_reason=str(response.choices[0].finish_reason), + finish_reason=response.choices[0].finish_reason, logprobs=None, ) ], diff --git a/src/leapfrogai_api/backend/types.py b/src/leapfrogai_api/backend/types.py index 59011003c9..7424632482 100644 --- a/src/leapfrogai_api/backend/types.py +++ b/src/leapfrogai_api/backend/types.py @@ -131,8 +131,8 @@ class CompletionChoice(BaseModel): None, description="Log probabilities for the generated tokens. Only returned if requested.", ) - finish_reason: str = Field( - "", description="The reason why the completion finished.", example="length" + finish_reason: str | None = Field( + None, description="The reason why the completion finished.", example="length" ) @@ -252,7 +252,7 @@ class ChatChoice(BaseModel): default=ChatMessage(), description="The message content for this choice." ) finish_reason: str | None = Field( - default="", + default=None, description="The reason why the model stopped generating tokens.", examples=["stop", "length"], ) diff --git a/src/leapfrogai_sdk/llm.py b/src/leapfrogai_sdk/llm.py index cd2240923a..2209432214 100644 --- a/src/leapfrogai_sdk/llm.py +++ b/src/leapfrogai_sdk/llm.py @@ -21,7 +21,7 @@ class FinishReason(Enum): - NONE = "" + NONE = None STOP = "stop" LENGTH = "length" diff --git a/tests/pytest/leapfrogai_api/test_api.py b/tests/pytest/leapfrogai_api/test_api.py index 2941a83d82..3a5270beec 100644 --- a/tests/pytest/leapfrogai_api/test_api.py +++ b/tests/pytest/leapfrogai_api/test_api.py @@ -260,7 +260,7 @@ def test_chat_completion(dummy_auth_middleware): # parse finish reason assert "finish_reason" in response_choices[0] - assert "STOP" == response_choices[0].get("finish_reason") + assert "stop" == response_choices[0].get("finish_reason") # parse usage data response_usage = response_obj.get("usage") @@ -324,9 +324,9 @@ def test_stream_chat_completion(dummy_auth_middleware): assert "finish_reason" in choices[0] # in streaming responses, the stop reason is not STOP until the last iteration (token) is sent back if iter_length == input_length: - assert "STOP" == choices[0].get("finish_reason") + assert "stop" == choices[0].get("finish_reason") else: - assert "NONE" == choices[0].get("finish_reason") + assert None is choices[0].get("finish_reason") # parse usage data response_usage = stream_response.get("usage") prompt_tokens = response_usage.get("prompt_tokens")