From 143c27a2b2de79d7e13fbada14f8026392722950 Mon Sep 17 00:00:00 2001 From: Marc Scholten Date: Wed, 13 Sep 2023 16:03:13 -0700 Subject: [PATCH] Added non-streaming access to the OpenAI completion API via fetchCompletion --- ihp-openai/IHP/OpenAI.hs | 51 +++++++++++++++++++++++++++++++++++----- 1 file changed, 45 insertions(+), 6 deletions(-) diff --git a/ihp-openai/IHP/OpenAI.hs b/ihp-openai/IHP/OpenAI.hs index e66c07964..3991d43e5 100644 --- a/ihp-openai/IHP/OpenAI.hs +++ b/ihp-openai/IHP/OpenAI.hs @@ -15,6 +15,7 @@ import qualified OpenSSL.Session as SSL import qualified Data.Text as Text import qualified Control.Retry as Retry import qualified Control.Exception as Exception +import Control.Applicative ((<|>)) data CompletionRequest = CompletionRequest { messages :: ![Message] @@ -24,6 +25,7 @@ data CompletionRequest = CompletionRequest , presencePenalty :: !Double , frequencePenalty :: !Double , model :: !Text + , stream :: !Bool } data Message = Message @@ -34,12 +36,12 @@ data Message = Message data Role = UserRole | SystemRole | AssistantRole instance ToJSON CompletionRequest where - toJSON CompletionRequest { model, prompt, messages, maxTokens, temperature, presencePenalty, frequencePenalty } = + toJSON CompletionRequest { model, prompt, messages, maxTokens, temperature, presencePenalty, frequencePenalty, stream } = object [ "model" .= model , "messages" .= (messages <> [userMessage prompt]) , "max_tokens" .= maxTokens - , "stream" .= True + , "stream" .= stream , "temperature" .= temperature , "presence_penalty" .= presencePenalty , "frequency_penalty" .= frequencePenalty @@ -69,6 +71,7 @@ newCompletionRequest = CompletionRequest , presencePenalty = 2 , frequencePenalty = 0.2 , model = "gpt-3.5-turbo" + , stream = False } data CompletionResult = CompletionResult @@ -86,13 +89,14 @@ data Choice = Choice instance FromJSON Choice where parseJSON = withObject "Choice" $ \v -> do - delta <- v .: "delta" - content <- delta .: "content" + deltaOrMessage <- (v .: "message") <|> (v .: "delta") + content <- deltaOrMessage .: "content" pure Choice { text = content } streamCompletion :: ByteString -> CompletionRequest -> IO () -> (Text -> IO ()) -> IO Text -streamCompletion secretKey completionRequest onStart callback = do +streamCompletion secretKey completionRequest' onStart callback = do + let completionRequest = enableStream completionRequest' completionRequestRef <- newIORef completionRequest result <- Retry.retrying retryPolicyDefault shouldRetry (action completionRequestRef) case result of @@ -115,7 +119,8 @@ streamCompletion secretKey completionRequest onStart callback = do retryPolicyDefault = Retry.constantDelay 50000 <> Retry.limitRetries 10 streamCompletionWithoutRetry :: ByteString -> CompletionRequest -> IO () -> (Text -> IO ()) -> IO (Either Text Text) -streamCompletionWithoutRetry secretKey completionRequest onStart callback = do +streamCompletionWithoutRetry secretKey completionRequest' onStart callback = do + let completionRequest = enableStream completionRequest' modifyContextSSL (\context -> do SSL.contextSetVerificationMode context SSL.VerifyNone pure context @@ -154,3 +159,37 @@ streamCompletionWithoutRetry secretKey completionRequest onStart callback = do otherwise -> do pure (curBuffer <> json, chunk) Nothing -> pure (curBuffer <> input, chunk) + + +fetchCompletion :: ByteString -> CompletionRequest -> IO Text +fetchCompletion secretKey completionRequest = do + result <- Retry.retrying retryPolicyDefault shouldRetry action + case result of + Left (e :: SomeException) -> Exception.throwIO e + Right result -> pure result + where + shouldRetry retryStatus (Left _) = pure True + shouldRetry retryStatus (Right _) = pure False + action retryStatus = Exception.try (fetchCompletionWithoutRetry secretKey completionRequest) + + retryPolicyDefault = Retry.constantDelay 50000 <> Retry.limitRetries 10 + +fetchCompletionWithoutRetry :: ByteString -> CompletionRequest -> IO Text +fetchCompletionWithoutRetry secretKey completionRequest = do + modifyContextSSL (\context -> do + SSL.contextSetVerificationMode context SSL.VerifyNone + pure context + ) + withOpenSSL do + withConnection (establishConnection "https://api.openai.com/v1/chat/completions") \connection -> do + let q = buildRequest1 do + http POST "/v1/chat/completions" + setContentType "application/json" + Network.Http.Client.setHeader "Authorization" ("Bearer " <> secretKey) + + sendRequest connection q (jsonBody completionRequest) + completionResult :: CompletionResult <- receiveResponse connection jsonHandler + pure (mconcat $ map (.text) completionResult.choices) + +enableStream :: CompletionRequest -> CompletionRequest +enableStream completionRequest = completionRequest { stream = True } \ No newline at end of file