fix: some small changes and fixes for AI assistant (pinterest#1333)
jczhong84 authored and aidenprice committed Jan 3, 2024
1 parent f9fb350 commit 8f08570
Showing 9 changed files with 131 additions and 48 deletions.
14 changes: 5 additions & 9 deletions containers/bundled_querybook_config.yaml
@@ -13,25 +13,21 @@ ELASTICSEARCH_HOST: http://elasticsearch:9200
 # AI_ASSISTANT_CONFIG:
 #     default:
 #         model_args:
-#             model_name: gpt-3.5-turbo-16k
+#             model_name: gpt-3.5-turbo
 #             temperature: 0
-#             context_length: 16384
-#         reserved_tokens: 2048
+#         reserved_tokens: 1024
 #     table_summary:
 #         model_args:
-#             model_name: gpt-3.5-turbo-16k
+#             model_name: gpt-3.5-turbo
 #             temperature: 0
-#             context_length: 16384
 #     sql_summary:
 #         model_args:
-#             model_name: gpt-3.5-turbo-16k
+#             model_name: gpt-3.5-turbo
 #             temperature: 0
-#             context_length: 16384
 #     table_select:
 #         model_args:
-#             model_name: gpt-3.5-turbo-16k
+#             model_name: gpt-3.5-turbo
 #             temperature: 0
-#             context_length: 16384
 
 # Uncomment below to enable vector store to support embedding based table search.
 # Please check langchain doc for the configs of each provider.
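With `context_length` dropped from this config, the usable prompt budget is now derived in code from the model name. A minimal sketch of the arithmetic, using the window sizes this commit adds in `openai_assistant.py` below together with the `reserved_tokens: 1024` value above:

```python
# Sketch of the budget math implied by this config change: the context window
# now comes from a per-model lookup instead of a configured context_length.
OPENAI_MODEL_CONTEXT_WINDOW_SIZE = {
    "gpt-3.5-turbo": 4097,
    "gpt-3.5-turbo-16k": 16385,
}

def usable_token_count(model_name: str, reserved_tokens: int) -> int:
    # tokens available for the prompt = context window
    # minus tokens reserved for the model's completion
    return OPENAI_MODEL_CONTEXT_WINDOW_SIZE[model_name] - reserved_tokens

# gpt-3.5-turbo with reserved_tokens: 1024 leaves 3073 tokens for the prompt
assert usable_token_count("gpt-3.5-turbo", 1024) == 3073
```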
10 changes: 6 additions & 4 deletions docs_website/docs/integrations/add_ai_assistant.md
@@ -16,15 +16,17 @@ The AI Assistant plugin will allow users to do title generation, text to sql and
 
 Please follow the steps below to enable the AI assistant plugin:
 
-1. [Optional] Create your own AI assistant provider if needed. Please refer to `querybook/server/lib/ai_assistant/openai_assistant.py` as an example.
+1. Add the `langchain` package dependency by adding `-r ai/langchain.txt` to `requirements/local.txt`.
 
-2. Add your provider in `plugins/ai_assistant_plugin/__init__.py`
+2. [Optional] Create your own AI assistant provider if needed. Please refer to `querybook/server/lib/ai_assistant/openai_assistant.py` as an example.
 
-3. Add configs in the `querybook_config.yaml`. Please refer to `containers/bundled_querybook_config.yaml` as an example. Please also check the model's official doc for all available model args.
+3. Add your provider in `plugins/ai_assistant_plugin/__init__.py`
+
+4. Add configs in the `querybook_config.yaml`. Please refer to `containers/bundled_querybook_config.yaml` as an example. Please also check the model's official doc for all available model args.
 
     - Don't forget to set proper environment variables for your provider, e.g. for openai you'll need `OPENAI_API_KEY`.
 
-4. Enable it in `querybook/config/querybook_public_config.yaml`
+5. Enable it in `querybook/config/querybook_public_config.yaml`
 
 ## Vector Store Plugin
 
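For step 3 in the diff above, a hypothetical sketch of what registering a provider could look like; the `ALL_PLUGIN_AI_ASSISTANTS` hook name is an assumption for illustration, not something this commit shows:

```python
# plugins/ai_assistant_plugin/__init__.py -- hypothetical registration sketch.
# The hook name ALL_PLUGIN_AI_ASSISTANTS is assumed; check the plugin guide
# for the export your Querybook version actually expects.
from lib.ai_assistant.assistants.openai_assistant import OpenAIAssistant


class MyCompanyAssistant(OpenAIAssistant):
    """Custom provider built on top of the bundled OpenAI assistant."""

    def name(self) -> str:
        return "my_company"


ALL_PLUGIN_AI_ASSISTANTS = [MyCompanyAssistant]
```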
2 changes: 1 addition & 1 deletion package.json
@@ -1,6 +1,6 @@
 {
     "name": "querybook",
-    "version": "3.28.0",
+    "version": "3.28.1",
     "description": "A Big Data Webapp",
     "private": true,
     "scripts": {
4 changes: 0 additions & 4 deletions querybook/config/querybook_default_config.yaml
@@ -92,11 +92,7 @@ AI_ASSISTANT_CONFIG:
         model_args:
             model_name: ~
             temperature: ~
-            context_length: ~
         reserved_tokens: ~
-    table_select:
-        fetch_k: ~
-        top_n: ~
 
 EMBEDDINGS_PROVIDER: ~
 EMBEDDINGS_CONFIG: ~
4 changes: 3 additions & 1 deletion querybook/server/datasources/search.py
@@ -17,7 +17,6 @@
 from lib.elasticsearch.suggest_table import construct_suggest_table_query
 from lib.elasticsearch.suggest_user import construct_suggest_user_query
 from lib.elasticsearch.search_utils import ES_CONFIG
-from logic import vector_store as vs_logic
 
 LOG = get_logger(__file__)
 
@@ -123,6 +122,9 @@ def vector_search_tables(
     keywords,
     filters=None,
 ):
+    # delayed import only if vector search is enabled
+    from logic import vector_store as vs_logic
+
     verify_metastore_permission(metastore_id)
     return vs_logic.search_tables(metastore_id, keywords, filters)
 
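The point of the `vector_search_tables` change is that `logic.vector_store` (and the langchain stack behind it) is imported only when a vector search actually runs, so deployments without the AI assistant extras still import this module cleanly. A generic sketch of the pattern, with a hypothetical module name:

```python
# Generic deferred-import sketch: the heavy/optional dependency is imported
# inside the handler, not at module import time.
def vector_search(keywords: str):
    # hypothetical optional module; an ImportError can only happen when this
    # endpoint is actually hit without the optional extras installed
    from optional_vector_backend import search_tables

    return search_tables(keywords)
```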
19 changes: 17 additions & 2 deletions querybook/server/lib/ai_assistant/assistants/openai_assistant.py
@@ -9,6 +9,15 @@
 LOG = get_logger(__file__)
 
 
+OPENAI_MODEL_CONTEXT_WINDOW_SIZE = {
+    "gpt-3.5-turbo": 4097,
+    "gpt-3.5-turbo-16k": 16385,
+    "gpt-4": 8192,
+    "gpt-4-32k": 32768,
+}
+DEFAULT_MODEL_NAME = "gpt-3.5-turbo"
+
+
 class OpenAIAssistant(BaseAIAssistant):
     """To use it, please set the following environment variable:
     OPENAI_API_KEY: OpenAI API key
@@ -18,10 +27,16 @@ class OpenAIAssistant(BaseAIAssistant):
     def name(self) -> str:
         return "openai"
 
+    def _get_context_length_by_model(self, model_name: str) -> int:
+        return (
+            OPENAI_MODEL_CONTEXT_WINDOW_SIZE.get(model_name)
+            or OPENAI_MODEL_CONTEXT_WINDOW_SIZE[DEFAULT_MODEL_NAME]
+        )
+
     def _get_default_llm_config(self):
         default_config = super()._get_default_llm_config()
         if not default_config.get("model_name"):
-            default_config["model_name"] = "gpt-3.5-turbo"
+            default_config["model_name"] = DEFAULT_MODEL_NAME
 
         return default_config
 
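Note the `dict.get(...) or ...` shape in `_get_context_length_by_model`: unknown model names (for example a fine-tuned variant) fall back to the default model's window instead of raising `KeyError`. A self-contained restatement:

```python
WINDOWS = {"gpt-3.5-turbo": 4097, "gpt-4": 8192}  # trimmed copy of the map above
DEFAULT = "gpt-3.5-turbo"

def context_length(model_name: str) -> int:
    # .get() yields None for unknown names; `or` then picks the default's window
    return WINDOWS.get(model_name) or WINDOWS[DEFAULT]

assert context_length("gpt-4") == 8192
assert context_length("ft:gpt-3.5-turbo:acme") == 4097  # unknown name -> fallback
```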
@@ -36,7 +51,7 @@ def _get_error_msg(self, error) -> str:
 
         return super()._get_error_msg(error)
 
-    def _get_llm(self, ai_command: str, callback_handler=None):
+    def _get_llm(self, ai_command: str, prompt_length: int, callback_handler=None):
         config = self._get_llm_config(ai_command)
         if not callback_handler:
             # non-streaming
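The widened `_get_llm` signature exists so an assistant can pick a model based on prompt size. The commit itself does not route by length; below is a hedged sketch of one possible override, assuming the `ChatOpenAI` wrapper from the langchain version in use here:

```python
# Hypothetical subclass, not part of this commit: upgrade to the 16k model
# when the prompt would overflow gpt-3.5-turbo's 4097-token window.
from langchain.chat_models import ChatOpenAI

from lib.ai_assistant.assistants.openai_assistant import OpenAIAssistant


class SizeAwareOpenAIAssistant(OpenAIAssistant):
    def _get_llm(self, ai_command: str, prompt_length: int, callback_handler=None):
        config = self._get_llm_config(ai_command)
        model_name = config.get("model_name", "gpt-3.5-turbo")
        if model_name == "gpt-3.5-turbo" and prompt_length > 4097 - 1024:
            model_name = "gpt-3.5-turbo-16k"  # larger window for long prompts
        # streaming callback wiring omitted for brevity
        return ChatOpenAI(
            model_name=model_name, temperature=config.get("temperature", 0)
        )
```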
82 changes: 64 additions & 18 deletions querybook/server/lib/ai_assistant/base_ai_assistant.py
@@ -4,11 +4,10 @@
 
 from app.db import with_session
 from const.ai_assistant import (
-    AICommandType,
     DEFAUTL_TABLE_SELECT_LIMIT,
     MAX_SAMPLE_QUERY_COUNT_FOR_TABLE_SUMMARY,
+    AICommandType,
 )
-from langchain.chains import LLMChain
 from lib.logger import get_logger
 from lib.query_analysis.lineage import process_query
 from lib.vector_store import get_vector_store
@@ -28,7 +27,11 @@
 from .prompts.table_summary_prompt import TABLE_SUMMARY_PROMPT
 from .prompts.text_to_sql_prompt import TEXT_TO_SQL_PROMPT
 from .streaming_web_socket_callback_handler import StreamingWebsocketCallbackHandler
-from .tools.table_schema import get_table_schema_by_name, get_table_schemas_by_names
+from .tools.table_schema import (
+    get_slimmed_table_schemas,
+    get_table_schema_by_name,
+    get_table_schemas_by_names,
+)
 
 LOG = get_logger(__file__)
 
@@ -71,13 +74,17 @@ def _get_llm_config(self, ai_command: str):
             **self._config.get(ai_command, {}).get("model_args", {}),
         }
 
+    @abstractmethod
+    def _get_context_length_by_model(self, model_name: str) -> int:
+        """Get the context window size of the model."""
+        raise NotImplementedError()
+
     def _get_usable_token_count(self, ai_command: str) -> int:
         ai_command_config = self._config.get(ai_command, {})
         default_config = self._config.get("default", {})
 
-        max_context_length = ai_command_config.get(
-            "context_length"
-        ) or default_config.get("context_length", 0)
+        model_name = self._get_llm_config(ai_command)["model_name"]
+        max_context_length = self._get_context_length_by_model(model_name)
         reserved_tokens = ai_command_config.get(
             "reserved_tokens"
         ) or default_config.get("reserved_tokens", 0)
@@ -86,21 +93,44 @@ def _get_usable_token_count(self, ai_command: str) -> int:
 
     @abstractmethod
     def _get_llm(
-        self, ai_command, callback_handler: StreamingWebsocketCallbackHandler = None
+        self,
+        ai_command: str,
+        prompt_length: int,
+        callback_handler: StreamingWebsocketCallbackHandler = None,
     ):
-        """return the large language model to use"""
+        """Return the large language model to use.
+
+        Args:
+            ai_command (str): AI command type
+            prompt_length (int): The number of tokens in the prompt. Can be used to decide which model to use.
+            callback_handler (StreamingWebsocketCallbackHandler, optional): Callback handler to handle the streaming result.
+        """
         raise NotImplementedError()
 
     def _get_sql_title_prompt(self, query):
         return SQL_TITLE_PROMPT.format(query=query)
 
     def _get_text_to_sql_prompt(self, dialect, question, table_schemas, original_query):
-        return TEXT_TO_SQL_PROMPT.format(
+        context_limit = self._get_usable_token_count(AICommandType.TEXT_TO_SQL.value)
+        prompt = TEXT_TO_SQL_PROMPT.format(
             dialect=dialect,
             question=question,
             table_schemas=table_schemas,
             original_query=original_query,
         )
+        token_count = self._get_token_count(AICommandType.TEXT_TO_SQL.value, prompt)
+
+        if token_count > context_limit:
+            # if the prompt is too long, use slimmed table schemas
+            prompt = TEXT_TO_SQL_PROMPT.format(
+                dialect=dialect,
+                question=question,
+                table_schemas=get_slimmed_table_schemas(table_schemas),
+                original_query=original_query,
+            )
+
+        # TODO: need a better way to handle it if the prompt is still too long
+        return prompt
 
     def _get_sql_fix_prompt(self, dialect, query, error, table_schemas):
         return SQL_FIX_PROMPT.format(
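`get_slimmed_table_schemas` is imported at the top of this file but its body is not part of this diff; the sketch below is a guess at the kind of reduction involved (keep what the model needs to write SQL, drop token-heavy metadata) and is illustrative only:

```python
# Illustrative only: one plausible "slimming" -- keep table and column names
# plus types, drop descriptions and other verbose fields. The real
# get_slimmed_table_schemas in .tools.table_schema may differ.
def slim_table_schema(schema: dict) -> dict:
    return {
        "table": schema["table"],
        "columns": [
            {"name": col["name"], "type": col["type"]}
            for col in schema["columns"]
        ],
    }
```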
@@ -133,11 +163,6 @@ def _get_table_select_prompt(self, top_n, question, table_schemas):
             table_schemas=table_schemas,
         )
 
-    def _get_llm_chain(self, prompt, socket):
-        callback_handler = StreamingWebsocketCallbackHandler(socket)
-        llm = self._get_llm(callback_handler=callback_handler)
-        return LLMChain(llm=llm, prompt=prompt)
-
     def _get_error_msg(self, error) -> str:
         """Override this method to return specific error messages for your own assistant."""
         if isinstance(error, ValidationError):
@@ -207,6 +232,9 @@ def generate_sql_query(
         )
         llm = self._get_llm(
             ai_command=AICommandType.TEXT_TO_SQL.value,
+            prompt_length=self._get_token_count(
+                AICommandType.TEXT_TO_SQL.value, prompt
+            ),
             callback_handler=StreamingWebsocketCallbackHandler(socket),
         )
         return llm.predict(text=prompt)
@@ -224,6 +252,7 @@ def generate_title_from_query(self, query, socket=None):
         prompt = self._get_sql_title_prompt(query=query)
         llm = self._get_llm(
             ai_command=AICommandType.SQL_TITLE.value,
+            prompt_length=self._get_token_count(AICommandType.SQL_TITLE.value, prompt),
             callback_handler=StreamingWebsocketCallbackHandler(socket),
         )
         return llm.predict(text=prompt)
@@ -269,6 +298,7 @@ def query_auto_fix(
         )
         llm = self._get_llm(
             ai_command=AICommandType.SQL_FIX.value,
+            prompt_length=self._get_token_count(AICommandType.SQL_FIX.value, prompt),
             callback_handler=StreamingWebsocketCallbackHandler(socket),
         )
         return llm.predict(text=prompt)
@@ -301,7 +331,11 @@ def summarize_table(
             table_schema=table_schema, sample_queries=sample_queries
         )
         llm = self._get_llm(
-            ai_command=AICommandType.TABLE_SUMMARY.value, callback_handler=None
+            ai_command=AICommandType.TABLE_SUMMARY.value,
+            prompt_length=self._get_token_count(
+                AICommandType.TABLE_SUMMARY.value, prompt
+            ),
+            callback_handler=None,
         )
         return llm.predict(text=prompt)
 
@@ -325,7 +359,11 @@ def summarize_query(
 
         prompt = self._get_sql_summary_prompt(table_schemas=table_schemas, query=query)
         llm = self._get_llm(
-            ai_command=AICommandType.SQL_SUMMARY.value, callback_handler=None
+            ai_command=AICommandType.SQL_SUMMARY.value,
+            prompt_length=self._get_token_count(
+                AICommandType.SQL_SUMMARY.value, prompt
+            ),
+            callback_handler=None,
         )
         return llm.predict(text=prompt)
 
@@ -349,7 +387,11 @@ def find_tables(self, metastore_id, question, session=None):
                 AICommandType.TABLE_SELECT.value
             )
             for full_table_name in table_names:
-                table_schema, table_name = full_table_name.split(".")
+                full_table_name_parts = full_table_name.split(".")
+                if len(full_table_name_parts) != 2:
+                    continue
+
+                table_schema, table_name = full_table_name_parts
                 table = get_table_by_name(
                     schema_name=table_schema,
                     name=table_name,
@@ -374,7 +416,11 @@
                 question=question,
             )
             llm = self._get_llm(
-                ai_command=AICommandType.TABLE_SELECT.value, callback_handler=None
+                ai_command=AICommandType.TABLE_SELECT.value,
+                prompt_length=self._get_token_count(
+                    AICommandType.TABLE_SELECT.value, prompt
+                ),
+                callback_handler=None,
             )
             return json.loads(llm.predict(text=prompt))
         except Exception as e:
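On the `full_table_name.split(".")` guard added earlier in this file: the old tuple unpacking raised `ValueError` whenever the model returned a name without exactly one dot, aborting the whole loop; such names are now skipped. A standalone illustration:

```python
# Why the length check matters: unpacking a split() result raises ValueError
# unless it has exactly two parts; the guard skips malformed names instead.
for full_table_name in ["main.orders", "orders", "prod.main.orders"]:
    parts = full_table_name.split(".")
    if len(parts) != 2:
        continue  # old code: `schema, table = name.split(".")` raised here
    table_schema, table_name = parts
    print(table_schema, table_name)  # only "main orders" prints
```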