diff --git a/containers/bundled_querybook_config.yaml b/containers/bundled_querybook_config.yaml index 487738c3d..d5698614e 100644 --- a/containers/bundled_querybook_config.yaml +++ b/containers/bundled_querybook_config.yaml @@ -13,25 +13,21 @@ ELASTICSEARCH_HOST: http://elasticsearch:9200 # AI_ASSISTANT_CONFIG: # default: # model_args: -# model_name: gpt-3.5-turbo-16k +# model_name: gpt-3.5-turbo # temperature: 0 -# context_length: 16384 -# reserved_tokens: 2048 +# reserved_tokens: 1024 # table_summary: # model_args: -# model_name: gpt-3.5-turbo-16k +# model_name: gpt-3.5-turbo # temperature: 0 -# context_length: 16384 # sql_summary: # model_args: -# model_name: gpt-3.5-turbo-16k +# model_name: gpt-3.5-turbo # temperature: 0 -# context_length: 16384 # table_select: # model_args: -# model_name: gpt-3.5-turbo-16k +# model_name: gpt-3.5-turbo # temperature: 0 -# context_length: 16384 # Uncomment below to enable vector store to support embedding based table search. # Please check langchain doc for the configs of each provider. diff --git a/docs_website/docs/integrations/add_ai_assistant.md b/docs_website/docs/integrations/add_ai_assistant.md index bb83d9580..e3ed46d22 100644 --- a/docs_website/docs/integrations/add_ai_assistant.md +++ b/docs_website/docs/integrations/add_ai_assistant.md @@ -16,15 +16,17 @@ The AI Assistant plugin will allow users to do title generation, text to sql and Please follow below steps to enable AI assistant plugin: -1. [Optional] Create your own AI assistant provider if needed. Please refer to `querybook/server/lib/ai_assistant/openai_assistant.py` as an example. +1. Add `langchain` package dependency by adding `-r ai/langchain.txt` to `requirements/local.txt`. -2. Add your provider in `plugins/ai_assistant_plugin/__init__.py` +2. [Optional] Create your own AI assistant provider if needed. Please refer to `querybook/server/lib/ai_assistant/openai_assistant.py` as an example. -3. Add configs in the `querybook_config.yaml`. Please refer to `containers/bundled_querybook_config.yaml` as an example. Please also check the model's official doc for all avaialbe model args. +3. Add your provider in `plugins/ai_assistant_plugin/__init__.py` + +4. Add configs in the `querybook_config.yaml`. Please refer to `containers/bundled_querybook_config.yaml` as an example. Please also check the model's official doc for all avaialbe model args. - Dont forget to set proper environment variables for your provider. e.g. for openai, you'll need `OPENAI_API_KEY`. -4. Enable it in `querybook/config/querybook_public_config.yaml` +5. Enable it in `querybook/config/querybook_public_config.yaml` ## Vector Store Plugin diff --git a/package.json b/package.json index 553866e73..de66d712e 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "querybook", - "version": "3.28.0", + "version": "3.28.1", "description": "A Big Data Webapp", "private": true, "scripts": { diff --git a/querybook/config/querybook_default_config.yaml b/querybook/config/querybook_default_config.yaml index fed3948f5..052bdea82 100644 --- a/querybook/config/querybook_default_config.yaml +++ b/querybook/config/querybook_default_config.yaml @@ -92,11 +92,7 @@ AI_ASSISTANT_CONFIG: model_args: model_name: ~ temperature: ~ - context_length: ~ reserved_tokens: ~ - table_select: - fetch_k: ~ - top_n: ~ EMBEDDINGS_PROVIDER: ~ EMBEDDINGS_CONFIG: ~ diff --git a/querybook/server/datasources/search.py b/querybook/server/datasources/search.py index 6a1bcd20b..08251c86d 100644 --- a/querybook/server/datasources/search.py +++ b/querybook/server/datasources/search.py @@ -17,7 +17,6 @@ from lib.elasticsearch.suggest_table import construct_suggest_table_query from lib.elasticsearch.suggest_user import construct_suggest_user_query from lib.elasticsearch.search_utils import ES_CONFIG -from logic import vector_store as vs_logic LOG = get_logger(__file__) @@ -123,6 +122,9 @@ def vector_search_tables( keywords, filters=None, ): + # delayed import only if vector search is enabled + from logic import vector_store as vs_logic + verify_metastore_permission(metastore_id) return vs_logic.search_tables(metastore_id, keywords, filters) diff --git a/querybook/server/lib/ai_assistant/assistants/openai_assistant.py b/querybook/server/lib/ai_assistant/assistants/openai_assistant.py index 25fb25a00..c629d1f23 100644 --- a/querybook/server/lib/ai_assistant/assistants/openai_assistant.py +++ b/querybook/server/lib/ai_assistant/assistants/openai_assistant.py @@ -9,6 +9,15 @@ LOG = get_logger(__file__) +OPENAI_MODEL_CONTEXT_WINDOW_SIZE = { + "gpt-3.5-turbo": 4097, + "gpt-3.5-turbo-16k": 16385, + "gpt-4": 8192, + "gpt-4-32k": 32768, +} +DEFAULT_MODEL_NAME = "gpt-3.5-turbo" + + class OpenAIAssistant(BaseAIAssistant): """To use it, please set the following environment variable: OPENAI_API_KEY: OpenAI API key @@ -18,10 +27,16 @@ class OpenAIAssistant(BaseAIAssistant): def name(self) -> str: return "openai" + def _get_context_length_by_model(self, model_name: str) -> int: + return ( + OPENAI_MODEL_CONTEXT_WINDOW_SIZE.get(model_name) + or OPENAI_MODEL_CONTEXT_WINDOW_SIZE[DEFAULT_MODEL_NAME] + ) + def _get_default_llm_config(self): default_config = super()._get_default_llm_config() if not default_config.get("model_name"): - default_config["model_name"] = "gpt-3.5-turbo" + default_config["model_name"] = DEFAULT_MODEL_NAME return default_config @@ -36,7 +51,7 @@ def _get_error_msg(self, error) -> str: return super()._get_error_msg(error) - def _get_llm(self, ai_command: str, callback_handler=None): + def _get_llm(self, ai_command: str, prompt_length: int, callback_handler=None): config = self._get_llm_config(ai_command) if not callback_handler: # non-streaming diff --git a/querybook/server/lib/ai_assistant/base_ai_assistant.py b/querybook/server/lib/ai_assistant/base_ai_assistant.py index 18a812b2c..22078efd3 100644 --- a/querybook/server/lib/ai_assistant/base_ai_assistant.py +++ b/querybook/server/lib/ai_assistant/base_ai_assistant.py @@ -4,11 +4,10 @@ from app.db import with_session from const.ai_assistant import ( - AICommandType, DEFAUTL_TABLE_SELECT_LIMIT, MAX_SAMPLE_QUERY_COUNT_FOR_TABLE_SUMMARY, + AICommandType, ) -from langchain.chains import LLMChain from lib.logger import get_logger from lib.query_analysis.lineage import process_query from lib.vector_store import get_vector_store @@ -28,7 +27,11 @@ from .prompts.table_summary_prompt import TABLE_SUMMARY_PROMPT from .prompts.text_to_sql_prompt import TEXT_TO_SQL_PROMPT from .streaming_web_socket_callback_handler import StreamingWebsocketCallbackHandler -from .tools.table_schema import get_table_schema_by_name, get_table_schemas_by_names +from .tools.table_schema import ( + get_slimmed_table_schemas, + get_table_schema_by_name, + get_table_schemas_by_names, +) LOG = get_logger(__file__) @@ -71,13 +74,17 @@ def _get_llm_config(self, ai_command: str): **self._config.get(ai_command, {}).get("model_args", {}), } + @abstractmethod + def _get_context_length_by_model(self, model_name: str) -> int: + """Get the context window size of the model.""" + raise NotImplementedError() + def _get_usable_token_count(self, ai_command: str) -> int: ai_command_config = self._config.get(ai_command, {}) default_config = self._config.get("default", {}) - max_context_length = ai_command_config.get( - "context_length" - ) or default_config.get("context_length", 0) + model_name = self._get_llm_config(ai_command)["model_name"] + max_context_length = self._get_context_length_by_model(model_name) reserved_tokens = ai_command_config.get( "reserved_tokens" ) or default_config.get("reserved_tokens", 0) @@ -86,21 +93,44 @@ def _get_usable_token_count(self, ai_command: str) -> int: @abstractmethod def _get_llm( - self, ai_command, callback_handler: StreamingWebsocketCallbackHandler = None + self, + ai_command: str, + prompt_length: int, + callback_handler: StreamingWebsocketCallbackHandler = None, ): - """return the large language model to use""" + """return the large language model to use. + + Args: + ai_command (str): AI command type + prompt_length (str): The number of tokens in the prompt. Can be used to decide which model to use. + callback_handler (StreamingWebsocketCallbackHandler, optional): Callback handler to handle the straming result. + """ raise NotImplementedError() def _get_sql_title_prompt(self, query): return SQL_TITLE_PROMPT.format(query=query) def _get_text_to_sql_prompt(self, dialect, question, table_schemas, original_query): - return TEXT_TO_SQL_PROMPT.format( + context_limit = self._get_usable_token_count(AICommandType.TEXT_TO_SQL.value) + prompt = TEXT_TO_SQL_PROMPT.format( dialect=dialect, question=question, table_schemas=table_schemas, original_query=original_query, ) + token_count = self._get_token_count(AICommandType.TEXT_TO_SQL.value, prompt) + + if token_count > context_limit: + # if the prompt is too long, use slimmed table schemas + prompt = TEXT_TO_SQL_PROMPT.format( + dialect=dialect, + question=question, + table_schemas=get_slimmed_table_schemas(table_schemas), + original_query=original_query, + ) + + # TODO: need a better way to handle it if the prompt is still too long + return prompt def _get_sql_fix_prompt(self, dialect, query, error, table_schemas): return SQL_FIX_PROMPT.format( @@ -133,11 +163,6 @@ def _get_table_select_prompt(self, top_n, question, table_schemas): table_schemas=table_schemas, ) - def _get_llm_chain(self, prompt, socket): - callback_handler = StreamingWebsocketCallbackHandler(socket) - llm = self._get_llm(callback_handler=callback_handler) - return LLMChain(llm=llm, prompt=prompt) - def _get_error_msg(self, error) -> str: """Override this method to return specific error messages for your own assistant.""" if isinstance(error, ValidationError): @@ -207,6 +232,9 @@ def generate_sql_query( ) llm = self._get_llm( ai_command=AICommandType.TEXT_TO_SQL.value, + prompt_length=self._get_token_count( + AICommandType.TEXT_TO_SQL.value, prompt + ), callback_handler=StreamingWebsocketCallbackHandler(socket), ) return llm.predict(text=prompt) @@ -224,6 +252,7 @@ def generate_title_from_query(self, query, socket=None): prompt = self._get_sql_title_prompt(query=query) llm = self._get_llm( ai_command=AICommandType.SQL_TITLE.value, + prompt_length=self._get_token_count(AICommandType.SQL_TITLE.value, prompt), callback_handler=StreamingWebsocketCallbackHandler(socket), ) return llm.predict(text=prompt) @@ -269,6 +298,7 @@ def query_auto_fix( ) llm = self._get_llm( ai_command=AICommandType.SQL_FIX.value, + prompt_length=self._get_token_count(AICommandType.SQL_FIX.value, prompt), callback_handler=StreamingWebsocketCallbackHandler(socket), ) return llm.predict(text=prompt) @@ -301,7 +331,11 @@ def summarize_table( table_schema=table_schema, sample_queries=sample_queries ) llm = self._get_llm( - ai_command=AICommandType.TABLE_SUMMARY.value, callback_handler=None + ai_command=AICommandType.TABLE_SUMMARY.value, + prompt_length=self._get_token_count( + AICommandType.TABLE_SUMMARY.value, prompt + ), + callback_handler=None, ) return llm.predict(text=prompt) @@ -325,7 +359,11 @@ def summarize_query( prompt = self._get_sql_summary_prompt(table_schemas=table_schemas, query=query) llm = self._get_llm( - ai_command=AICommandType.SQL_SUMMARY.value, callback_handler=None + ai_command=AICommandType.SQL_SUMMARY.value, + prompt_length=self._get_token_count( + AICommandType.SQL_SUMMARY.value, prompt + ), + callback_handler=None, ) return llm.predict(text=prompt) @@ -349,7 +387,11 @@ def find_tables(self, metastore_id, question, session=None): AICommandType.TABLE_SELECT.value ) for full_table_name in table_names: - table_schema, table_name = full_table_name.split(".") + full_table_name_parts = full_table_name.split(".") + if len(full_table_name_parts) != 2: + continue + + table_schema, table_name = full_table_name_parts table = get_table_by_name( schema_name=table_schema, name=table_name, @@ -374,7 +416,11 @@ def find_tables(self, metastore_id, question, session=None): question=question, ) llm = self._get_llm( - ai_command=AICommandType.TABLE_SELECT.value, callback_handler=None + ai_command=AICommandType.TABLE_SELECT.value, + prompt_length=self._get_token_count( + AICommandType.TABLE_SELECT.value, prompt + ), + callback_handler=None, ) return json.loads(llm.predict(text=prompt)) except Exception as e: diff --git a/querybook/server/lib/ai_assistant/tools/table_schema.py b/querybook/server/lib/ai_assistant/tools/table_schema.py index 2c2e55657..f3d2cf218 100644 --- a/querybook/server/lib/ai_assistant/tools/table_schema.py +++ b/querybook/server/lib/ai_assistant/tools/table_schema.py @@ -15,12 +15,12 @@ def get_table_documentation(table: DataTable) -> str: vs = get_vector_store() if vs: - vs.get_table_summary(table.id) + return vs.get_table_summary(table.id) return "" -def _get_column(column: DataTableColumn) -> str: +def _get_column(column: DataTableColumn) -> dict[str, str]: column_json = {} column_json["name"] = column.name @@ -39,7 +39,7 @@ def _get_column(column: DataTableColumn) -> str: def _get_table_schema( table: DataTable, should_skip_column: Callable[[DataTableColumn], bool] = None, -) -> str: +) -> dict: """Generate table schema prompt. The format will be like: Table Name: [Name_of_table_1] @@ -77,7 +77,7 @@ def get_table_schema_by_id( table_id: int, should_skip_column: Callable[[DataTableColumn], bool] = None, session=None, -) -> str: +) -> dict: """Generate table schema prompt by table id""" table = m_logic.get_table_by_id(table_id=table_id, session=session) return _get_table_schema(table, should_skip_column) @@ -88,7 +88,7 @@ def get_table_schemas_by_ids( table_ids: list[int], should_skip_column: Callable[[DataTableColumn], bool] = None, session=None, -) -> str: +) -> list[dict]: """Generate table schemas prompt by table ids""" return [ get_table_schema_by_id( @@ -106,9 +106,13 @@ def get_table_schema_by_name( full_table_name: str, should_skip_column: Callable[[DataTableColumn], bool] = None, session=None, -) -> str: +) -> dict: """Generate table schema prompt by full table name""" - table_schema, table_name = full_table_name.split(".") + full_table_name_parts = full_table_name.split(".") + if len(full_table_name_parts) != 2: + return None + + table_schema, table_name = full_table_name_parts table = m_logic.get_table_by_name( schema_name=table_schema, name=table_name, @@ -124,7 +128,7 @@ def get_table_schemas_by_names( full_table_names: list[str], should_skip_column: Callable[[DataTableColumn], bool] = None, session=None, -) -> str: +) -> list[dict]: """Generate table schemas prompt by table names""" return [ get_table_schema_by_name( @@ -135,3 +139,24 @@ def get_table_schemas_by_names( ) for table_name in full_table_names ] + + +def get_slimmed_table_schemas(table_schemas: list[dict]) -> list[dict]: + """Get a slimmed version of the table schemas, which will only keep below fields: + - table_name + - columns: + name + type + """ + column_keys_to_keep = ["name", "type"] + + return [ + { + "table_name": schema["table_name"], + "columns": [ + {k: c[k] for k in column_keys_to_keep if k in c} + for c in schema["columns"] + ], + } + for schema in table_schemas + ] diff --git a/querybook/webapp/components/AIAssistant/AutoFixButton.tsx b/querybook/webapp/components/AIAssistant/AutoFixButton.tsx index dd0f363e0..aeb4f7726 100644 --- a/querybook/webapp/components/AIAssistant/AutoFixButton.tsx +++ b/querybook/webapp/components/AIAssistant/AutoFixButton.tsx @@ -27,6 +27,7 @@ const useSQLFix = () => { }); const { + data: unformattedData, explanation, fix_suggestion: suggestion, fixed_query: rawFixedQuery, @@ -37,7 +38,7 @@ const useSQLFix = () => { return { socket, fixed: Object.keys(data).length > 0, // If has data, then it has been fixed - explanation, + explanation: explanation || unformattedData, suggestion, fixedQuery, };