Skip to content

Commit

Permalink
DH-4905 Add flag to create a csv file when the response has many rows (
Browse files Browse the repository at this point in the history
…#230)

* DH-4905 Add flag to create a csv file when the response has many rows

* db_connections endpoints support new file storages

* Fix test

* DH-4931/avoid running nl response generators for csv

* DH-4905 bug fix on csv file generation endpoint (#236)

* Remove URL logic and rename flag to create CSV file

* Fix generate_csv flag

* Send db_connection object as parameter

* Remove temp file once it is downloaded

* Only return sql_query_result as null when the generate_csv flag is set and the result has more than 50 rows

---------

Co-authored-by: mohammadrezapourreza <[email protected]>
Co-authored-by: Dishen <[email protected]>
  • Loading branch information
3 people authored Nov 10, 2023
1 parent 2c27880 commit 6b019cb
Show file tree
Hide file tree
Showing 15 changed files with 291 additions and 33 deletions.
20 changes: 17 additions & 3 deletions dataherald/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import List

from fastapi import BackgroundTasks
from fastapi.responses import FileResponse

from dataherald.api.types import Query
from dataherald.config import Component
Expand Down Expand Up @@ -37,16 +38,26 @@ def scan_db(

@abstractmethod
def answer_question(
self, run_evaluator: bool = True, question_request: QuestionRequest = None
self,
run_evaluator: bool = True,
generate_csv: bool = False,
question_request: QuestionRequest = None,
) -> Response:
pass

@abstractmethod
def answer_question_with_timeout(
self, run_evaluator: bool = True, question_request: QuestionRequest = None
self,
run_evaluator: bool = True,
generate_csv: bool = False,
question_request: QuestionRequest = None,
) -> Response:
pass

@abstractmethod
def update_response(self, response_id: str) -> Response:
    """Update the response identified by ``response_id`` and return it.

    Abstract hook: this declaration has no behavior of its own; concrete
    API implementations provide it.
    """
    pass

@abstractmethod
def get_questions(self, db_connection_id: str | None = None) -> list[Question]:
pass
Expand Down Expand Up @@ -106,6 +117,7 @@ def create_response(
self,
run_evaluator: bool = True,
sql_response_only: bool = False,
generate_csv: bool = False,
query_request: CreateResponseRequest = None,
) -> Response:
pass
Expand All @@ -119,7 +131,9 @@ def get_response(self, response_id: str) -> Response:
pass

@abstractmethod
def update_response(self, response_id: str) -> Response:
def get_response_file(
self, response_id: str, background_tasks: BackgroundTasks
) -> FileResponse:
pass

@abstractmethod
Expand Down
82 changes: 75 additions & 7 deletions dataherald/api/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from bson import json_util
from bson.objectid import InvalidId, ObjectId
from fastapi import BackgroundTasks, HTTPException
from fastapi.responses import JSONResponse
from fastapi.responses import FileResponse, JSONResponse
from overrides import override

from dataherald.api import API
Expand Down Expand Up @@ -51,9 +51,12 @@
TableDescriptionRequest,
UpdateInstruction,
)
from dataherald.utils.s3 import S3

logger = logging.getLogger(__name__)

MAX_ROWS_TO_CREATE_CSV_FILE = 50


def async_scanning(scanner, database, scanner_request, storage):
scanner.scan(
Expand All @@ -64,6 +67,10 @@ def async_scanning(scanner, database, scanner_request, storage):
)


def delete_file(file_location: str):
    """Remove the file at ``file_location`` from the local filesystem.

    Intended for use as a cleanup step (for example a background task that
    runs after a downloaded file has been served). A file that no longer
    exists is treated as already cleaned up, so calling this more than once
    for the same path is safe.

    Args:
        file_location: Path of the file to delete.

    Raises:
        OSError: For failures other than the file being absent (for example
            a permission error or the path being a directory).
    """
    try:
        os.remove(file_location)
    except FileNotFoundError:
        # Someone else (or an earlier run of this task) already removed the
        # file; the goal of this function is already achieved.
        pass


class FastAPI(API):
def __init__(self, system: System):
super().__init__(system)
Expand Down Expand Up @@ -120,7 +127,10 @@ def scan_db(

@override
def answer_question(
self, run_evaluator: bool = True, question_request: QuestionRequest = None
self,
run_evaluator: bool = True,
generate_csv: bool = False,
question_request: QuestionRequest = None,
) -> Response:
"""Takes in an English question and answers it based on content from the registered databases"""
logger.info(f"Answer question: {question_request.question}")
Expand Down Expand Up @@ -151,7 +161,10 @@ def answer_question(
context = context_store.retrieve_context_for_question(user_question)
start_generated_answer = time.time()
generated_answer = sql_generation.generate_response(
user_question, database_connection, context[0]
user_question,
database_connection,
context[0],
generate_csv,
)
logger.info("Starts evaluator...")
if run_evaluator:
Expand All @@ -165,13 +178,22 @@ def answer_question(
status_code=400,
content={"question_id": user_question.id, "error_message": str(e)},
)
if (
generate_csv
and len(generated_answer.sql_query_result.rows)
> MAX_ROWS_TO_CREATE_CSV_FILE
):
generated_answer.sql_query_result = None
generated_answer.exec_time = time.time() - start_generated_answer
response_repository = ResponseRepository(self.storage)
return response_repository.insert(generated_answer)

@override
def answer_question_with_timeout(
self, run_evaluator: bool = True, question_request: QuestionRequest = None
self,
run_evaluator: bool = True,
generate_csv: bool = False,
question_request: QuestionRequest = None,
) -> Response:
result = None
exception = None
Expand All @@ -186,7 +208,9 @@ def answer_question_with_timeout(
def run_and_catch_exceptions():
nonlocal result, exception
if not stop_event.is_set():
result = self.answer_question(run_evaluator, question_request)
result = self.answer_question(
run_evaluator, generate_csv, question_request
)

thread = threading.Thread(target=run_and_catch_exceptions)
thread.start()
Expand Down Expand Up @@ -214,6 +238,7 @@ def create_database_connection(
llm_api_key=database_connection_request.llm_api_key,
use_ssh=database_connection_request.use_ssh,
ssh_settings=database_connection_request.ssh_settings,
file_storage=database_connection_request.file_storage,
)

SQLDatabase.get_sql_engine(db_connection, True)
Expand Down Expand Up @@ -248,6 +273,7 @@ def update_database_connection(
llm_api_key=database_connection_request.llm_api_key,
use_ssh=database_connection_request.use_ssh,
ssh_settings=database_connection_request.ssh_settings,
file_storage=database_connection_request.file_storage,
)

SQLDatabase.get_sql_engine(db_connection, True)
Expand Down Expand Up @@ -352,6 +378,37 @@ def get_response(self, response_id: str) -> Response:

return result

@override
def get_response_file(
    self, response_id: str, background_tasks: BackgroundTasks
) -> FileResponse:
    """Serve the CSV file attached to a stored response.

    Looks up the response, its question and the question's database
    connection, downloads the response's CSV file through ``S3`` using the
    connection's ``file_storage`` settings, and returns the downloaded file.
    The local copy is scheduled for deletion via ``background_tasks``.

    Args:
        response_id: Identifier of the response whose file is requested.
        background_tasks: Task queue used to delete the downloaded file.

    Returns:
        A ``FileResponse`` with media type ``text/csv``.

    Raises:
        HTTPException: 400 when an identifier is not a valid id; 404 when
            the response, its question or its database connection cannot
            be found, or when the response has no CSV file.
    """
    response_repository = ResponseRepository(self.storage)
    question_repository = QuestionRepository(self.storage)
    db_connection_repository = DatabaseConnectionRepository(self.storage)
    try:
        result = response_repository.find_by_id(response_id)
        # Only follow each reference when the previous lookup succeeded;
        # dereferencing a missing record would raise AttributeError and
        # surface as a 500 instead of the 404 below.
        question = (
            question_repository.find_by_id(result.question_id) if result else None
        )
        db_connection = (
            db_connection_repository.find_by_id(question.db_connection_id)
            if question
            else None
        )
    except InvalidId as e:
        raise HTTPException(status_code=400, detail=str(e)) from e

    if not (result and question and db_connection):
        raise HTTPException(
            status_code=404, detail="Question, response, or db_connection not found"
        )

    if not result.csv_file_path:
        # Nothing to download for this response.
        raise HTTPException(
            status_code=404, detail="The response does not have a CSV file"
        )

    s3 = S3()

    file_location = s3.download(result.csv_file_path, db_connection.file_storage)
    # Remove the local copy after the response has been sent.
    background_tasks.add_task(delete_file, file_location)

    return FileResponse(
        file_location,
        media_type="text/csv",
    )

@override
def update_response(self, response_id: str) -> Response:
response_repository = ResponseRepository(self.storage)
Expand Down Expand Up @@ -429,6 +486,7 @@ def create_response(
self,
run_evaluator: bool = True,
sql_response_only: bool = False,
generate_csv: bool = False,
query_request: CreateResponseRequest = None, # noqa: ARG002
) -> Response:
question_repository = QuestionRepository(self.storage)
Expand All @@ -448,7 +506,10 @@ def create_response(
context = context_store.retrieve_context_for_question(user_question)
start_generated_answer = time.time()
response = sql_generation.generate_response(
user_question, database_connection, context[0]
user_question,
database_connection,
context[0],
generate_csv,
)
else:
response = Response(
Expand All @@ -458,7 +519,9 @@ def create_response(
start_generated_answer = time.time()

generates_nl_answer = GeneratesNlAnswer(self.system, self.storage)
response = generates_nl_answer.execute(response, sql_response_only)
response = generates_nl_answer.execute(
response, sql_response_only, generate_csv
)
except openai.error.AuthenticationError as e:
raise HTTPException(status_code=400, detail=str(e)) from e
except ValueError as e:
Expand All @@ -472,6 +535,11 @@ def create_response(
user_question, response, database_connection
)
response.confidence_score = confidence_score
if (
generate_csv
and len(response.sql_query_result.rows) > MAX_ROWS_TO_CREATE_CSV_FILE
):
response.sql_query_result = None
response.exec_time = time.time() - start_generated_answer
response_repository.insert(response)
return response
Expand Down
4 changes: 2 additions & 2 deletions dataherald/repositories/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def __init__(self, storage):
self.storage = storage

def insert(self, response: Response) -> Response:
response_dict = response.dict(exclude={"id"})
response_dict = response.dict(exclude={"id", "sql_query_result"})
response_dict["question_id"] = ObjectId(response.question_id)
response.id = str(self.storage.insert_one(DB_COLLECTION, response_dict))
return response
Expand All @@ -25,7 +25,7 @@ def find_one(self, query: dict) -> Response | None:
return Response(**row)

def update(self, response: Response) -> Response:
response_dict = response.dict(exclude={"id"})
response_dict = response.dict(exclude={"id", "sql_query_result"})
response_dict["question_id"] = ObjectId(response.question_id)

self.storage.update_or_create(
Expand Down
27 changes: 22 additions & 5 deletions dataherald/server/fastapi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import fastapi
from fastapi import BackgroundTasks, status
from fastapi import FastAPI as _FastAPI
from fastapi.responses import JSONResponse
from fastapi.responses import FileResponse, JSONResponse
from fastapi.routing import APIRoute

import dataherald
Expand Down Expand Up @@ -164,6 +164,13 @@ def __init__(self, settings: Settings):
tags=["Responses"],
)

self.router.add_api_route(
"/api/v1/responses/{response_id}/file",
self.get_response_file,
methods=["GET"],
tags=["Responses"],
)

self.router.add_api_route(
"/api/v1/responses/{response_id}",
self.update_response,
Expand Down Expand Up @@ -224,13 +231,16 @@ def scan_db(
return self._api.scan_db(scanner_request, background_tasks)

def answer_question(
self, run_evaluator: bool = True, question_request: QuestionRequest = None
self,
run_evaluator: bool = True,
generate_csv: bool = False,
question_request: QuestionRequest = None,
) -> Response:
if os.getenv("DH_ENGINE_TIMEOUT", None):
return self._api.answer_question_with_timeout(
run_evaluator, question_request
run_evaluator, generate_csv, question_request
)
return self._api.answer_question(run_evaluator, question_request)
return self._api.answer_question(run_evaluator, generate_csv, question_request)

def get_questions(self, db_connection_id: str | None = None) -> list[Question]:
return self._api.get_questions(db_connection_id)
Expand Down Expand Up @@ -297,6 +307,12 @@ def update_response(self, response_id: str) -> Response:
"""Update a response"""
return self._api.update_response(response_id)

def get_response_file(
    self, response_id: str, background_tasks: BackgroundTasks
) -> FileResponse:
    """Get a response file"""
    # Pure delegation: the API implementation does the lookup and download;
    # this handler only forwards the route arguments.
    file_response = self._api.get_response_file(
        response_id=response_id,
        background_tasks=background_tasks,
    )
    return file_response

def execute_sql_query(self, query: Query) -> tuple[str, dict]:
"""Executes a query on the given db_connection_id"""
return self._api.execute_sql_query(query)
Expand All @@ -305,11 +321,12 @@ def create_response(
self,
run_evaluator: bool = True,
sql_response_only: bool = False,
generate_csv: bool = False,
query_request: CreateResponseRequest = None,
) -> Response:
"""Executes a query on the given db_connection_id"""
return self._api.create_response(
run_evaluator, sql_response_only, query_request
run_evaluator, sql_response_only, generate_csv, query_request
)

def delete_golden_record(self, golden_record_id: str) -> dict:
Expand Down
24 changes: 24 additions & 0 deletions dataherald/sql_database/models/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,29 @@ def __getitem__(self, key: str) -> Any:
return getattr(self, key)


class FileStorage(BaseModel):
    """Settings for the file storage attached to a database connection.

    The two key fields are passed through the ``encrypt`` validator, which
    leaves a value untouched when it can already be decrypted and otherwise
    replaces it with its encrypted form.
    """

    # Name given to this storage configuration.
    name: str
    # Credentials; run through the ``encrypt`` validator below.
    access_key_id: str
    secret_access_key: str
    # Storage region; may be None.
    region: str | None
    # Bucket name.
    bucket: str

    class Config:
        extra = Extra.ignore

    @validator("access_key_id", "secret_access_key", pre=True, always=True)
    def encrypt(cls, value: str):
        """Return ``value`` unchanged if it decrypts, else its encrypted form."""
        cipher = FernetEncrypt()
        try:
            # A successful decrypt means the value is already encrypted.
            cipher.decrypt(value)
        except Exception:
            return cipher.encrypt(value)
        return value

    def __getitem__(self, key: str) -> Any:
        """Allow ``storage["field"]`` access as an alias for attribute lookup."""
        return getattr(self, key)


class SSHSettings(BaseSettings):
db_name: str | None
host: str | None
Expand Down Expand Up @@ -60,6 +83,7 @@ class DatabaseConnection(BaseModel):
path_to_credentials_file: str | None
llm_api_key: str | None = None
ssh_settings: SSHSettings | None = None
file_storage: FileStorage | None = None

@validator("uri", pre=True, always=True)
def set_uri_without_ssh(cls, v, values):
Expand Down
Loading

0 comments on commit 6b019cb

Please sign in to comment.