diff --git a/containers/bundled_querybook_config.yaml b/containers/bundled_querybook_config.yaml index d5698614e..923c55b6e 100644 --- a/containers/bundled_querybook_config.yaml +++ b/containers/bundled_querybook_config.yaml @@ -15,6 +15,7 @@ ELASTICSEARCH_HOST: http://elasticsearch:9200 # model_args: # model_name: gpt-3.5-turbo # temperature: 0 +# streaming: true # reserved_tokens: 1024 # table_summary: # model_args: diff --git a/querybook/config/querybook_default_config.yaml b/querybook/config/querybook_default_config.yaml index 052bdea82..265e388e2 100644 --- a/querybook/config/querybook_default_config.yaml +++ b/querybook/config/querybook_default_config.yaml @@ -92,6 +92,7 @@ AI_ASSISTANT_CONFIG: model_args: model_name: ~ temperature: ~ + streaming: ~ reserved_tokens: ~ EMBEDDINGS_PROVIDER: ~ diff --git a/querybook/server/lib/ai_assistant/assistants/openai_assistant.py b/querybook/server/lib/ai_assistant/assistants/openai_assistant.py index c629d1f23..73ae55617 100644 --- a/querybook/server/lib/ai_assistant/assistants/openai_assistant.py +++ b/querybook/server/lib/ai_assistant/assistants/openai_assistant.py @@ -53,12 +53,14 @@ def _get_error_msg(self, error) -> str: def _get_llm(self, ai_command: str, prompt_length: int, callback_handler=None): config = self._get_llm_config(ai_command) - if not callback_handler: + + if not callback_handler or config.get("streaming") is False: # non-streaming - return ChatOpenAI(**config) + return ChatOpenAI( + **{**config, "streaming": False}, + ) return ChatOpenAI( - **config, - streaming=True, - callback_manager=CallbackManager([callback_handler]) + **{**config, "streaming": True}, + callback_manager=CallbackManager([callback_handler]), ) diff --git a/querybook/server/lib/ai_assistant/base_ai_assistant.py b/querybook/server/lib/ai_assistant/base_ai_assistant.py index fae8af071..889765788 100644 --- a/querybook/server/lib/ai_assistant/base_ai_assistant.py +++ b/querybook/server/lib/ai_assistant/base_ai_assistant.py @@ -196,6 +196,10 @@ def generate_sql_query( socket=None, session=None, ): + streaming = self._get_llm_config(AICommandType.TEXT_TO_SQL.value).get( + "streaming", True + ) + query_engine = admin_logic.get_query_engine_by_id( query_engine_id, session=session ) @@ -232,6 +236,7 @@ def generate_sql_query( table_schemas=table_schemas, original_query=original_query, ) + llm = self._get_llm( ai_command=AICommandType.TEXT_TO_SQL.value, prompt_length=self._get_token_count( @@ -239,7 +244,14 @@ def generate_sql_query( ), callback_handler=StreamingWebsocketCallbackHandler(socket), ) - return llm.predict(text=prompt) + response = llm.predict(text=prompt) + + if not streaming: + socket.send_delta_data(response) + socket.send_delta_end() + socket.close() + + return response @catch_error @with_ai_socket(command_type=AICommandType.SQL_TITLE) @@ -251,13 +263,24 @@ def generate_title_from_query(self, query, socket=None): stream (bool, optional): Whether to stream the result. Defaults to True. callback_handler (CallbackHandler, optional): Callback handler to handle the straming result. Required if stream is True. """ + streaming = self._get_llm_config(AICommandType.SQL_TITLE.value).get( + "streaming", True + ) + prompt = self._get_sql_title_prompt(query=query) llm = self._get_llm( ai_command=AICommandType.SQL_TITLE.value, prompt_length=self._get_token_count(AICommandType.SQL_TITLE.value, prompt), callback_handler=StreamingWebsocketCallbackHandler(socket), ) - return llm.predict(text=prompt) + response = llm.predict(text=prompt) + + if not streaming: + socket.send_delta_data(response) + socket.send_delta_end() + socket.close() + + return response @catch_error @with_session @@ -273,6 +296,9 @@ def query_auto_fix( Args: query_execution_id (int): The failed query execution id """ + streaming = self._get_llm_config(AICommandType.SQL_FIX.value).get( + "streaming", True + ) query_execution = qe_logic.get_query_execution_by_id( query_execution_id, session=session ) @@ -303,7 +329,14 @@ def query_auto_fix( prompt_length=self._get_token_count(AICommandType.SQL_FIX.value, prompt), callback_handler=StreamingWebsocketCallbackHandler(socket), ) - return llm.predict(text=prompt) + response = llm.predict(text=prompt) + + if not streaming: + socket.send_delta_data(response) + socket.send_delta_end() + socket.close() + + return response @catch_error @with_session