diff --git a/Install.sh b/Install.sh index dd1164a5..d3cde67a 100755 --- a/Install.sh +++ b/Install.sh @@ -94,6 +94,7 @@ sec_key=$(openssl rand -hex 16) env_content=$(echo "$env_content" | sed "s/SEC_KEY/$sec_key/g") # save .env file echo "$env_content" > .env +echo "DATA_SOURCE_FILE_DIR=/app/user_upload_files" >> .env cp Dockerfile.template Dockerfile # save env file over echo "You setting as fellows:" diff --git a/Install_CN.sh b/Install_CN.sh index 94d73b26..d72824c4 100755 --- a/Install_CN.sh +++ b/Install_CN.sh @@ -86,6 +86,7 @@ sec_key=$(openssl rand -hex 16) env_content=$(echo "$env_content" | sed "s/SEC_KEY/$sec_key/g") # save .env file,保存文件 echo "$env_content" > .env +echo "DATA_SOURCE_FILE_DIR=/app/user_upload_files" >> .env # 修改配置 pip 为国内清华源 sed 's/#CN#//g' Dockerfile.template > Dockerfile # 输出说明: diff --git a/README_CN.md b/README_CN.md index 0bb77e9b..d7c96ea6 100644 --- a/README_CN.md +++ b/README_CN.md @@ -118,7 +118,7 @@ git clone http://github.com/DeepInsight-AI/DeepBI.git OpenAI gpt-4o 支持 - 不稳定 + 支持 支持 支持 价格更便宜 @@ -128,7 +128,7 @@ git clone http://github.com/DeepInsight-AI/DeepBI.git DeepInsight gpt-4o 支持 - 不稳定 + 支持 支持 支持 目前只支持gpt-4o @@ -136,7 +136,7 @@ git clone http://github.com/DeepInsight-AI/DeepBI.git Microsoft Azure - gpt-4 + gpt-4 (自定义的名称) 支持 支持 支持 @@ -146,9 +146,9 @@ git clone http://github.com/DeepInsight-AI/DeepBI.git Microsoft Azure - gpt-4o + gpt-4o(自定义的名称) + 支持 支持 - 不稳定 支持 支持 价格更便宜 diff --git a/ai/agents/agent_instance_util.py b/ai/agents/agent_instance_util.py index 60c471d6..345506b0 100644 --- a/ai/agents/agent_instance_util.py +++ b/ai/agents/agent_instance_util.py @@ -4,8 +4,8 @@ from ai.backend.util.write_log import logger import traceback from ai.backend.util.token_util import num_tokens_from_messages -from ai.agents.prompt import CSV_ECHART_TIPS_MESS, \ - MYSQL_ECHART_TIPS_MESS, MYSQL_MATPLOTLIB_TIPS_MESS, POSTGRESQL_ECHART_TIPS_MESS, MONGODB_ECHART_TIPS_MESS +from ai.agents.prompt import EXCEL_ECHART_TIPS_MESS, \ + 
MYSQL_ECHART_TIPS_MESS, MYSQL_MATPLOTLIB_TIPS_MESS, POSTGRESQL_ECHART_TIPS_MESS, MONGODB_ECHART_TIPS_MESS, CSV_ECHART_TIPS_MESS from ai.agents.agentchat import (UserProxyAgent, GroupChat, AssistantAgent, GroupChatManager, PythonProxyAgent, BIProxyAgent, TaskPlannerAgent, TaskSelectorAgent, CheckAgent, ChartPresenterAgent) @@ -202,8 +202,10 @@ def get_agent_mysql_engineer(self): Hand over your code to the Executor for execution. Don’t query too much data, Try to merge query data as simply as possible. Be careful to avoid using mysql special keywords in mysql code. + If function call is needed, the function name mast be 'run_mysql_code', be sure contains no other characters. Reply "TERMINATE" in the end when everything is done. ''', + function_map={"bi_run_chart_code": BIProxyAgent.run_chart_code}, websocket=self.websocket, is_log_out=self.is_log_out, user_name=self.user_name, @@ -895,7 +897,7 @@ def get_agent_chart_planner(self): ) return chart_planner - def get_agent_python_executor(self, report_file_name=None): + def get_agent_python_executor(self, report_file_name=None, is_auto_pilot=False): python_executor = PythonProxyAgent( name="python_executor", system_message="python executor. Execute the python code and report the result.", @@ -908,6 +910,7 @@ def get_agent_python_executor(self, report_file_name=None): # incoming=self.incoming, db_id=self.db_id, report_file_name=report_file_name, + is_auto_pilot=is_auto_pilot ) return python_executor @@ -927,6 +930,10 @@ def get_agent_csv_echart_assistant(self, use_cache=True): When you find an answer, verify the answer carefully. Include verifiable evidence in your response if possible. Reply "TERMINATE" in the end when everything is done. When you find an answer, You are a report analysis, you have the knowledge and skills to turn raw data into information and insight, which can be used to make business decisions.include your analysis in your reply. 
+ It involves data queries that truncate the data if it exceeds 1000 rows, or reduce the number of rows by summing and other means. + It involves data queries that truncate the data if it exceeds 1000 rows, or reduce the number of rows by summing and other means. + It involves data queries that truncate the data if it exceeds 1000 rows, or reduce the number of rows by summing and other means. + It involves data queries that truncate the data if it exceeds 1000 rows, or reduce the number of rows by summing and other means. """ + '\n' + self.base_csv_info + '\n' + python_base_dependency + '\n' + CSV_ECHART_TIPS_MESS, human_input_mode="NEVER", user_name=self.user_name, diff --git a/ai/agents/agentchat/bi_proxy_agent.py b/ai/agents/agentchat/bi_proxy_agent.py index 3fc5afd0..5523d36a 100644 --- a/ai/agents/agentchat/bi_proxy_agent.py +++ b/ai/agents/agentchat/bi_proxy_agent.py @@ -1197,7 +1197,7 @@ async def run_chart_code(self, chart_code_str: str): for config in str_obj: if 'columnMapping' in config and isinstance(config['columnMapping'], dict) and config[ - 'columnMapping']: + 'columnMapping']: for variable, axis in config['columnMapping'].items(): print('axis :', axis) if axis in ["x"]: @@ -1404,7 +1404,6 @@ async def tell_logger(self, log_str): "from user:[{}".format(self.user_name) + "] , " + self.name + " send a message:{}".format( send_json_str)) - except Exception as e: traceback.print_exc() logger.error("from user:[{}".format(self.user_name) + "] , " + str(e)) @@ -1512,7 +1511,8 @@ async def run_echart_code(self, chart_code_str: str, name: str): websocket = self.websocket send_json_str = json.dumps(result_message) - await websocket.send(send_json_str) + if websocket and send_json_str: + await websocket.send(send_json_str) print(str(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())) + ' ---- ' + " send a message:{}".format( send_json_str)) logger.info( diff --git a/ai/agents/agentchat/groupchat.py b/ai/agents/agentchat/groupchat.py index feb98679..25c27748 
100644 --- a/ai/agents/agentchat/groupchat.py +++ b/ai/agents/agentchat/groupchat.py @@ -109,7 +109,7 @@ async def run_chat( sender: Optional[Agent] = None, config: Optional[GroupChat] = None, ) -> Union[str, Dict, None]: - """Run a group chat.""" + """Run a group chat. mysql_engineer, bi_proxy, chart_presenter""" if messages is None: messages = self._oai_messages[sender] message = messages[-1] diff --git a/ai/agents/agentchat/python_proxy_agent.py b/ai/agents/agentchat/python_proxy_agent.py index 13cf8b9c..b20bd65d 100644 --- a/ai/agents/agentchat/python_proxy_agent.py +++ b/ai/agents/agentchat/python_proxy_agent.py @@ -38,6 +38,8 @@ def format_decimal(value): elif isinstance(value, int): return value return value + + def calculate_dispersion(data): x_values = np.array([point[0] for point in data]) y_values = np.array([point[1] for point in data]) @@ -51,6 +53,7 @@ def calculate_dispersion(data): ave_y = format_decimal(sum(y_values) / len(y_values)) return dispersion, correlation, (x_min, x_max), (y_min, y_max), (ave_x, ave_y) + def calculate_trendline(data): x_values = np.array([point[0] for point in data]).reshape(-1, 1) y_values = np.array([point[1] for point in data]) @@ -61,6 +64,7 @@ def calculate_trendline(data): return slope, intercept + def calculate_distance(point1, point2): return np.sqrt((point1[0] - point2[0]) ** 2 + (point1[1] - point2[1]) ** 2) @@ -77,13 +81,13 @@ def count_outliers(data): median_point = np.median(data, axis=0) # 计算距离中值点的距离,并检查是否大于平均距离的200% outliers_count = sum([1 for point in data if calculate_distance(point, median_point) > avg_distance * 2]) - if outliers_count<5: + if outliers_count < 5: outliers = [format_decimal(point) for point in data if any(calculate_distance(point, median_point) > avg_distance * 2 for p in data)] else: outliers = [format_decimal(point) for point in data if any(calculate_distance(point, median_point) > avg_distance * 2 for p in data)] outliers.sort(key=lambda point: abs(calculate_distance(point, 
median_point)), reverse=True) outliers = outliers[:5] - return outliers_count,outliers + return outliers_count, outliers class PythonProxyAgent(Agent): @@ -125,6 +129,7 @@ def __init__( db_id: Optional = None, is_log_out: Optional[bool] = True, report_file_name: Optional[str] = None, + is_auto_pilot: Optional[bool] = False ): """ Args: @@ -206,6 +211,7 @@ def __init__( self.db_id = db_id self.is_log_out = is_log_out self.report_file_name = report_file_name + self.is_auto_pilot = is_auto_pilot delay_messages = self.delay_messages def register_reply( @@ -758,8 +764,12 @@ async def generate_code_execution_reply( if len(code_blocks) == 1 and code_blocks[0][0] != 'python': # continue - return True, f"exitcode:exitcode failed\nCode output: Please give me executable python code.\n" - if self.db_id is not None: + if self.is_auto_pilot: + # if is auto pilot, no code TERMINATE + return True, "TERMINATE" + else: + return True, f"exitcode:exitcode failed\nCode output: Please give me executable python code.\n" + if self.db_id is not None and int(self.db_id) > 0: obj = database_util.Main(self.db_id) if_suss, db_info = obj.run_decode() if if_suss: @@ -789,7 +799,7 @@ async def generate_code_execution_reply( logs = json.loads(str(logs)) except Exception as e: return True, f"exitcode:exitcode failed\nCode output: There is an error in the JSON code causing parsing errors,Please modify the JSON code for me:{traceback.format_exc()}" - for entry in logs: + for index, entry in enumerate(logs): if 'echart_name' in entry and 'echart_code' in entry: if isinstance(entry['echart_code'], str): entry['echart_code'] = json.loads(entry['entry']['echart_code']) @@ -800,14 +810,12 @@ async def generate_code_execution_reply( if series_data['type'] in ["bar", "line"]: formatted_series_data = [format_decimal(value) for value in series_data['data']] elif series_data['type'] in ["pie", "gauge", "funnel"]: - formatted_series_data = [{"name": d["name"], "value": format_decimal(d["value"])} for - d in 
series_data['data']] + formatted_series_data = [{"name": d["name"], "value": format_decimal(d["value"])} for d in series_data['data']] elif series_data['type'] in ['graph']: formatted_series_data = [ {'name': data_point['name'], 'symbolSize': format_decimal(data_point['symbolSize'])} for data_point in series_data['data']] - elif series_data['type'] in ["Kline", "radar", "heatmap", "scatter", "themeRiver", - 'parallel', 'effectScatter']: + elif series_data['type'] in ["Kline", "radar", "heatmap", "scatter", "themeRiver", 'parallel', 'effectScatter']: formatted_series_data = [[format_decimal(value) for value in sublist] for sublist in series_data['data']] else: @@ -816,27 +824,33 @@ async def generate_code_execution_reply( formatted_series_list.append(series_data) entry['echart_code']['series'] = formatted_series_list base_content.append(entry) - for echart in base_content: - echart = json.dumps(echart, indent=2) - agent_instance_util = AgentInstanceUtil(user_name=str(self.user_name), - delay_messages=self.delay_messages, - outgoing=self.outgoing, - incoming=self.incoming, - websocket=self.websocket - ) - bi_proxy = agent_instance_util.get_agent_bi_proxy() - # Call the interface to generate pictures - for img_str in base_content: - echart_name = img_str.get('echart_name') - echart_code = img_str.get('echart_code') - - if len(echart_code) > 0 and str(echart_code).__contains__('x'): - is_chart = True - print("echart_name : ", echart_name) - re_str = await bi_proxy.run_echart_code(str(echart_code), echart_name) + # this is autopilot + if self.is_auto_pilot: + # if have multy echarts + if index != len(logs) - 1: + continue + return True, f"exitcode: {exitcode} ({exitcode2str})\nCode output: 图像已生成,任务执行成功!图表数据:{base_content} \nTERMINATE" + + # not autopilot + if not self.is_auto_pilot: + agent_instance_util = AgentInstanceUtil(user_name=str(self.user_name), + delay_messages=self.delay_messages, + outgoing=self.outgoing, + incoming=self.incoming, + websocket=self.websocket 
+ ) + bi_proxy = agent_instance_util.get_agent_bi_proxy() + # Call the interface to generate pictures + for img_str in base_content: + echart_name = img_str.get('echart_name') + echart_code = img_str.get('echart_code') + + if len(echart_code) > 0 and str(echart_code).__contains__('x'): + print("echart_name : ", echart_name) + await bi_proxy.run_echart_code(str(echart_code), echart_name) # 初始化一个空列表来保存每个echart的信息 - echarts_data,series_data = [],[] - xAxis_data_tag=0 + echarts_data, series_data = [], [] + xAxis_data_tag = 0 # 遍历echarts_code列表,提取数据并构造字典 for echart in base_content: echart_name = echart['echart_name'] @@ -858,7 +872,8 @@ async def generate_code_execution_reply( xAxis_data_tag = 1 if "%Y-%m" in xAxis_data: return True, f"exitcode:exitcode failed\nCode output:The SQL code query is incorrect. The query date should be %, not %%. Just for example: SELECT DATE_FORMAT(event_time, '%Y-%m-%d') is correct, but SELECT DATE_FORMAT(event_time, '%%Y -%%m-%%d') is wrong!" - if(xAxis_data_tag): + if xAxis_data_tag and not self.is_auto_pilot: + echart_dict = { 'echart_name': echart_name, 'series': series_data, @@ -868,8 +883,8 @@ async def generate_code_execution_reply( if (len(echart_dict['series']) == 1): data = echart_dict['series'][0]['data'] if (len(data) < 10): - return True,f"exitcode: {exitcode} ({exitcode2str})\nCode output: 图像已生成,任务执行成功!请直接分析图表数据:{echart_dict}" - data_set,count_tag = set(),0 + return True, f"exitcode: {exitcode} ({exitcode2str})\nCode output: 图像已生成,任务执行成功!请直接分析图表数据:{echart_dict} \nTERMINATE" + data_set, count_tag = set(), 0 for i in data: data_set.add(i[1]) if (len(data_set) == 4): @@ -879,21 +894,21 @@ async def generate_code_execution_reply( data_dict = {item: 0 for item in data_set} for i in data: data_dict[i[1]] += 1 - return True,f"exitcode: {exitcode} ({exitcode2str})\ncode output:图像已生成,生成的是散点图,该图的标题为:{echart_dict['echart_name']},描述了其相应的关系." 
\ - f"它的数据一共有{len(data)}个,但是它的取值集合个数较少,每一种取值对应的数量关系为{data_dict}" + return True, f"exitcode: {exitcode} ({exitcode2str})\ncode output:图像已生成,生成的是散点图,该图的标题为:{echart_dict['echart_name']},描述了其相应的关系." \ + f"它的数据一共有{len(data)}个,但是它的取值集合个数较少,每一种取值对应的数量关系为{data_dict} \nTERMINATE" correlation, dispersion, (x_min, x_max), (y_min, y_max), ( - ave_x, ave_y) = calculate_dispersion(data) + ave_x, ave_y) = calculate_dispersion(data) outliers_count, outliers = count_outliers(data) threshold = 0.6 # 相关性阈值 if correlation >= threshold: slope, intercept = calculate_trendline(data) - return True,f"exitcode: {exitcode} ({exitcode2str})\nCode output:图像已生成,生成的是散点图,该图的标题为:{echart_dict['echart_name']},描述了其相应的关系.如下是它的评价指标:\n离散程度为(标准差):{correlation},相关性评价指标为{dispersion}" \ + return True, f"exitcode: {exitcode} ({exitcode2str})\nCode output:图像已生成,生成的是散点图,该图的标题为:{echart_dict['echart_name']},描述了其相应的关系.如下是它的评价指标:\n离散程度为(标准差):{correlation},相关性评价指标为{dispersion}" \ f"散点图的x区间范围从{(x_min, x_max)},y区间范围从{(y_min, y_max)},中值点为{(ave_x, ave_y)}。定义一个数据点到中值点的距离大于其所有点至中值点平均距离的2倍作为离群点,则离群点的个数为{outliers_count},其中,最大的是几个离群点为{outliers}" \ - "其相关性达到预设的阈值,计算得到其趋势线方程为: y = {:.2f}x+{:.2f}。(注意:这里所有的计算数据都约束到了两位有效小数)".format(slope, intercept) + "其相关性达到预设的阈值,计算得到其趋势线方程为: y = {:.2f}x+{:.2f}。(注意:这里所有的计算数据都约束到了两位有效小数)\nTERMINATE".format(slope, intercept) else: - return True,f"Code output:图像已生成,生成的是散点图,该图的标题为:{echart_dict['echart_name']},描述了其相应的关系.如下是它的评价指标:\n离散程度为(标准差):{correlation},相关性评价指标为{dispersion}" \ + return True, f"Code output:图像已生成,生成的是散点图,该图的标题为:{echart_dict['echart_name']},描述了其相应的关系.如下是它的评价指标:\n离散程度为(标准差):{correlation},相关性评价指标为{dispersion}" \ f"散点图的x区间范围从{(x_min, x_max)},y区间范围从{(y_min, y_max)},中值点为{(ave_x, ave_y)}。定义一个数据点到中值点的距离大于其所有点至中值点平均距离的2倍作为离群点,则离群点的个数为{outliers_count},其中,最大的是几个离群点为{outliers}" \ - f"但由于数据的相关性未达到阈值,即该散点图数据并没有明显的线性关系,无法计算趋势线。(注意:这里所有的计算数据都约束到了两位有效小数)" + f"但由于数据的相关性未达到阈值,即该散点图数据并没有明显的线性关系,无法计算趋势线。(注意:这里所有的计算数据都约束到了两位有效小数)\nTERMINATE" else: message = f"exitcode: {exitcode} ({exitcode2str})\nCode 
output:图像已生成,生成的是散点图,该图的标题为:{echart_dict['echart_name']},描述了其相应的关系,一共有{len(echart_dict['series'])}类散点数据:\n" count_class = 0 @@ -928,7 +943,7 @@ async def generate_code_execution_reply( message_data = f"名为{echart_data['name']}的数据,数据的离散程度为(标准差):{correlation},相关性评价指标为{dispersion}" \ f"散点图的x区间范围从{(x_min, x_max)},y区间范围从{(y_min, y_max)},中值点为{(ave_x, ave_y)}。定义一个数据点到中值点的距离大于其所有点至中值点平均距离的2倍作为离群点,则离群点的个数为{outliers_count},其中,最大的是几个离群点为{outliers}" \ "其相关性达到预设的阈值,计算得到其趋势线方程为: y = {:.2f}x+{:.2f}。\n".format( - slope, intercept) + slope, intercept) message += message_data continue else: @@ -937,22 +952,23 @@ async def generate_code_execution_reply( f"但由于数据的相关性未达到阈值,即该散点图数据并没有明显的线性关系,无法计算趋势线。\n" message += message_data continue - message=message+"请对每一类的数据性质都进行详细的分析" - return True,f"exitcode: {exitcode} ({exitcode2str})\n{message}" + message = message + "请对每一类的数据性质都进行详细的分析" + return True, f"exitcode: {exitcode} ({exitcode2str})\n{message}\nTERMINATE" else: echart_dict = { - 'echart_name': echart_name, - 'series': series_data, - } - count_max=1000 - echart_dict['series'][0]['data']=echart_dict['series'][0]['data'][:count_max] - if(xAxis_data_tag): - echart_dict['xAxis_data']=echart_dict['xAxis_data'][:count_max] + 'echart_name': echart_name, + 'series': series_data, + } + count_max = 1000 + echart_dict['series'][0]['data'] = echart_dict['series'][0]['data'][:count_max] + if (xAxis_data_tag): + echart_dict['xAxis_data'] = echart_dict['xAxis_data'][:count_max] + echarts_data.append(echart_dict) if (len(echart_dict['series'][0]['data']) > 999): - return True, f"exitcode: {exitcode} ({exitcode2str})\nCode output: 图像已生成,任务执行成功!但由于数据量过大,仅截取了1000条,请直接分析图表这些数据:{echarts_data}" + return True, f"exitcode: {exitcode} ({exitcode2str})\nCode output: 图像已生成,任务执行成功!但由于数据量过大,仅截取了1000条,请直接分析图表这些数据:{echarts_data}\nTERMINATE" else: - return True, f"exitcode: {exitcode} ({exitcode2str})\nCode output: 图像已生成,任务执行成功!请直接分析图表数据:{echarts_data}" + return True, f"exitcode: {exitcode} ({exitcode2str})\nCode 
output: 图像已生成,任务执行成功!请直接分析图表数据:{echarts_data}\nTERMINATE" code_execution_config["last_n_messages"] = last_n_messages return False, None diff --git a/ai/agents/prompt/__init__.py b/ai/agents/prompt/__init__.py index d3944d18..3b6031b7 100644 --- a/ai/agents/prompt/__init__.py +++ b/ai/agents/prompt/__init__.py @@ -1,10 +1,11 @@ from .prompt_matplotlib import MYSQL_MATPLOTLIB_TIPS_MESS -from .prompt_echarts import CSV_ECHART_TIPS_MESS, MYSQL_ECHART_TIPS_MESS, POSTGRESQL_ECHART_TIPS_MESS, MONGODB_ECHART_TIPS_MESS +from .prompt_echarts import EXCEL_ECHART_TIPS_MESS, MYSQL_ECHART_TIPS_MESS, POSTGRESQL_ECHART_TIPS_MESS, MONGODB_ECHART_TIPS_MESS, CSV_ECHART_TIPS_MESS __all__ = [ - "CSV_ECHART_TIPS_MESS", + "EXCEL_ECHART_TIPS_MESS", "MYSQL_ECHART_TIPS_MESS", "MYSQL_MATPLOTLIB_TIPS_MESS", "POSTGRESQL_ECHART_TIPS_MESS", - "MONGODB_ECHART_TIPS_MESS" + "MONGODB_ECHART_TIPS_MESS", + "CSV_ECHART_TIPS_MESS" ] diff --git a/ai/agents/prompt/prompt_echarts.py b/ai/agents/prompt/prompt_echarts.py index 1d20cee2..ca1d4a99 100644 --- a/ai/agents/prompt/prompt_echarts.py +++ b/ai/agents/prompt/prompt_echarts.py @@ -1,4 +1,4 @@ -CSV_ECHART_TIPS_MESS = """Here are some examples of generating mysql and pyecharts Code based on the given question. +EXCEL_ECHART_TIPS_MESS = """Here are some examples of generating mysql and pyecharts Code based on the given question. Please generate new one based on the data and question human asks you, import the neccessary libraries and make sure the code is correct. IMPORTANT: You need to follow the coding style, and the type of the x, y axis.Title and label are not displayed under any circumstances. In either case, the datazoom and scroll legend must be displayed. The datazoom of the x-axis must be left=1, horizontal located below the x-axis, and the datazoom of the y-axis must be right=1, vertical located on the far right side of the container. The toolbox is only shown in line charts and bar charts. 
The five function buttons must be located on the left side of the line chart and bar chart according to pop_left=1, pop_top=15%, and vertical. Scroll legends for line and bar charts must be placed above the chart with pop_top=1 and horizontal. The scrolling legends of other charts must be placed vertically on the right side of the chart according to pop_right=1, pop_top=15%, and avoidLabelOverlap should be turned on as much as possible. If the x-axis can be sorted according to certain rules (such as date and time size or value size), please sort by the x-axis, otherwise sort by size.But also need to focus on the column name of the uploaded tables(if exists). Generally, PyEcharts does not accept numpy.int or numpy.float, etc. It only supports built-in data type like int, float, and str. @@ -36,28 +36,14 @@ title_opts=opts.TitleOpts(title="Sales and Profit over Time",is_show=false), datazoom_opts=[ opts.DataZoomOpts( - # 设置 x 轴 dataZoom - id_="dataZoomX", - type_="slider", - xAxisIndex=[0], # 控制 x 轴 - ), - opts.DataZoomOpts( - # 设置 y 轴 dataZoom - id_="dataZoomY", - type_="slider", - yAxisIndex=[0], # 控制 y 轴 - ), - opts.DataZoomOpts( - # 设置 x 轴 dataZoom - id_="dataZoomX", - type_="inside", - xAxisIndex=[0], # 控制 x 轴 + is_show=True, type_="slider", + xaxis_index=[0], range_start=0, range_end=100, orient="horizontal", + pos_bottom="0px", pos_left="1%", pos_right="1%" ), opts.DataZoomOpts( - # 设置 y 轴 dataZoom - id_="dataZoomY", - type_="inside", - yAxisIndex=[0], # 控制 y 轴 + is_show=True, type_="slider", + yaxis_index=[0], range_start=0, range_end=100, orient="vertical", + pos_top="0px", pos_right="1%", pos_bottom="3%" ), ], legend_opts=opts.LegendOpts( @@ -117,28 +103,163 @@ title_opts=opts.TitleOpts(title="Sales over Years",is_show=false), datazoom_opts=[ opts.DataZoomOpts( - # 设置 x 轴 dataZoom - id_="dataZoomX", - type_="slider", - xAxisIndex=[0], # 控制 x 轴 + is_show=True, type_="slider", + xaxis_index=[0], range_start=0, range_end=100, orient="horizontal", + 
pos_bottom="0px", pos_left="1%", pos_right="1%" ), opts.DataZoomOpts( - # 设置 y 轴 dataZoom - id_="dataZoomY", - type_="slider", - yAxisIndex=[0], # 控制 y 轴 + is_show=True, type_="slider", + yaxis_index=[0], range_start=0, range_end=100, orient="vertical", + pos_top="0px", pos_right="1%", pos_bottom="3%" ), + ], + legend_opts=opts.LegendOpts( + type_="scroll", # 设置图例类型为滚动 + ), + toolbox_opts=opts.ToolboxOpts( + is_show=true, + feature={ + "dataZoom": opts.ToolBoxFeatureDataZoomOpts(), + "dataView": opts.ToolBoxFeatureDataViewOpts(), + "magicType": opts.ToolBoxFeatureMagicTypeOpts(type_=['line', 'bar','stack']), + "restore": opts.ToolBoxFeatureRestoreOpts(), + "saveAsImage": opts.ToolBoxFeatureSaveAsImageOpts(), + }, + ), + ) + # Render the chart + ret_json = bar.dump_options() + echart_code = json.loads(ret_json) + + out_put = [{"echart_name": "Sales over Years", "echart_code": echart_code}] + print(out_put) + + When using pie charts, there must be no parameter x + X axis dataZoom is set to orient: horizontal + Y-axis dataZoom is set to orient: vertical" + The output should be formatted as a JSON instance that conforms to the JSON schema below, the JSON is a list of dict, + [ + {"echart_name": "Sales over Years", "echart_code": ret_json} + {}, + {}, + ]. + """ + +CSV_ECHART_TIPS_MESS = """Here are some examples of generating pyecharts Code based on the given question. + Please generate new one based on the data and question human asks you, import the neccessary libraries and make sure the code is correct. + +IMPORTANT: You need to follow the coding style, and the type of the x, y axis.Title and label are not displayed under any circumstances. In either case, the datazoom and scroll legend must be displayed. The datazoom of the x-axis must be left=1, horizontal located below the x-axis, and the datazoom of the y-axis must be right=1, vertical located on the far right side of the container. The toolbox is only shown in line charts and bar charts. 
The five function buttons must be located on the left side of the line chart and bar chart according to pop_left=1, pop_top=15%, and vertical. Scroll legends for line and bar charts must be placed above the chart with pop_top=1 and horizontal. The scrolling legends of other charts must be placed vertically on the right side of the chart according to pop_right=1, pop_top=15%, and avoidLabelOverlap should be turned on as much as possible. If the x-axis can be sorted according to certain rules (such as date and time size or value size), please sort by the x-axis, otherwise sort by size.But also need to focus on the column name of the uploaded tables(if exists). Generally, PyEcharts does not accept numpy.int or numpy.float, etc. It only supports built-in data type like int, float, and str. +Pay attention to check whether the query statement in the execution code block can correctly query the data. + + + Given the same `company_sales.csv`. + Q: A line chart comparing sales and profit over time would be useful. Could you help plot it? 
+ + import pandas as pd + from pyecharts.charts import Line + from pyecharts import options as opts + import json + + df = pd.read_csv('company_sales.csv') + year = [str(_) for _ in df["year"].to_list()] + sales = [float(_) for _ in df["sales"].to_list()] + profit = [float(_) for _ in df["profit"].to_list()] + line = Line() + # Add x-axis and y-axis data + line.add_xaxis(year) + line.add_yaxis("Sales", sales) + line.add_yaxis("Profit", profit) + line.set_global_opts( + xaxis_opts=opts.AxisOpts( + type_="category", # better use category rather than value + name="year", + min_=min(year), + max_=max(year), + ), + yaxis_opts=opts.AxisOpts( + type_="value", + name="price", + ), + title_opts=opts.TitleOpts(title="Sales and Profit over Time",is_show=false), + datazoom_opts=[ + opts.DataZoomOpts( + is_show=True, type_="slider", + xaxis_index=[0], range_start=0, range_end=100, orient="horizontal", + pos_bottom="0px", pos_left="1%", pos_right="1%" + ), + opts.DataZoomOpts( + is_show=True, type_="slider", + yaxis_index=[0], range_start=0, range_end=100, orient="vertical", + pos_top="0px", pos_right="1%", pos_bottom="3%" + ), + ], + legend_opts=opts.LegendOpts( + type_="scroll", # 设置图例类型为滚动 + ), + toolbox_opts=opts.ToolboxOpts( + is_show=true, + feature={ + "dataZoom": opts.ToolBoxFeatureDataZoomOpts(), + "dataView": opts.ToolBoxFeatureDataViewOpts(), + "magicType": opts.ToolBoxFeatureMagicTypeOpts(type_=['line', 'bar','stack']), + "restore": opts.ToolBoxFeatureRestoreOpts(), + "saveAsImage": opts.ToolBoxFeatureSaveAsImageOpts(), + }, + ), + ) + line.set_series_opts( + areastyle_opts=opts.AreaStyleOpts(opacity=0.5), + ) + ret_json = line.dump_options() + echart_code = json.loads(ret_json) + + out_put = [{"echart_name": "Sales and Profit over Time", "echart_code": echart_code}] + print(out_put) + + When using pie charts, there must be no parameter x + X axis dataZoom is set to orient: horizontal + Y-axis dataZoom is set to orient: vertical" + + Given the following data: + 
company_sales.csv + year sales profit expenses employees + 0 2010 100 60 40 10 + 1 2011 120 80 50 12 + 2 2012 150 90 60 14 + 3 2013 170 120 70 16 + [too long to show] + + Q: Could you help plot a bar chart with the year on the x-axis and the sales on the y-axis? + + import pandas as pd + from pyecharts.charts import Bar + from pyecharts import options as opts + df = pd.read_csv('company_sales.csv') + years = [str(_) for _ in df['year'].tolist()] + sales = [float(_) for _ in df['sales'].tolist()] + bar = Bar() + bar.add_xaxis(years) + bar.add_yaxis("Sales", sales) + bar.set_global_opts( + xaxis_opts=opts.AxisOpts( + type_="category", + name="Year", + ), + yaxis_opts=opts.AxisOpts( + type_="value", + name="Sales", + ), + title_opts=opts.TitleOpts(title="Sales over Years",is_show=false), + datazoom_opts=[ opts.DataZoomOpts( - # 设置 x 轴 dataZoom - id_="dataZoomX", - type_="inside", - xAxisIndex=[0], # 控制 x 轴 + is_show=True, type_="slider", + xaxis_index=[0], range_start=0, range_end=100, orient="horizontal", + pos_bottom="0px", pos_left="1%", pos_right="1%" ), opts.DataZoomOpts( - # 设置 y 轴 dataZoom - id_="dataZoomY", - type_="inside", - yAxisIndex=[0], # 控制 y 轴 + is_show=True, type_="slider", + yaxis_index=[0], range_start=0, range_end=100, orient="vertical", + pos_top="0px", pos_right="1%", pos_bottom="3%" ), ], legend_opts=opts.LegendOpts( @@ -162,7 +283,11 @@ out_put = [{"echart_name": "Sales over Years", "echart_code": echart_code}] print(out_put) - + When using pie charts, there must be no parameter x + X axis dataZoom is set to orient: horizontal + Y-axis dataZoom is set to orient: vertical" + Set one or more dataZoom rooms based on site requirements + Do not have any output or debug messages in the middle of the code, only output content at the end of the code The output should be formatted as a JSON instance that conforms to the JSON schema below, the JSON is a list of dict, [ {"echart_name": "Sales over Years", "echart_code": ret_json} @@ -219,12 +344,12 @@ 
legend_opts=opts.LegendOpts(is_show=True, type_="scroll"), # 显示滚动图例 datazoom_opts=[ opts.DataZoomOpts( - is_show=True, id_="dataZoomX", type_="slider", + is_show=True, type_="slider", xaxis_index=[0], range_start=0, range_end=100, orient="horizontal", pos_bottom="0px", pos_left="1%", pos_right="1%" ), opts.DataZoomOpts( - is_show=True, id_="dataZoomY", type_="slider", + is_show=True, type_="slider", yaxis_index=[0], range_start=0, range_end=100, orient="vertical", pos_top="0px", pos_right="1%", pos_bottom="3%" ), @@ -251,7 +376,10 @@ out_put = [{"echart_name": "Sales and Profit over Time", "echart_code": echart_code}] print(out_put) - + When using pie charts, there must be no parameter x + X axis dataZoom is set to orient: horizontal + Y-axis dataZoom is set to orient: vertical" + Q: Could you help plot a bar chart with the year on the x-axis and the sales on the y-axis? import pymysql @@ -294,10 +422,14 @@ legend_opts=opts.LegendOpts(is_show=True, type_="scroll", pos_top="1%"), datazoom_opts=[ opts.DataZoomOpts( - is_show=True, type_="slider", xaxis_index=[0], pos_left="1%", pos_bottom="0px" + is_show=True, type_="slider", + xaxis_index=[0], range_start=0, range_end=100, orient="horizontal", + pos_bottom="0px", pos_left="1%", pos_right="1%" ), opts.DataZoomOpts( - is_show=True, type_="slider", yaxis_index=[0], pos_right="1%", pos_top="0px" + is_show=True, type_="slider", + yaxis_index=[0], range_start=0, range_end=100, orient="vertical", + pos_top="0px", pos_right="1%", pos_bottom="3%" ), ], toolbox_opts=opts.ToolboxOpts( @@ -319,7 +451,10 @@ output = [{"echart_name": "Sales over Years", "echart_code": echart_code}] print(output) - + When using pie charts, there must be no parameter x + X axis dataZoom is set to orient: horizontal + Y-axis dataZoom is set to orient: vertical" + Q:Create a machine learning model to predict future sales and plot historical and forecasted sales figures note:For such prediction problems based on machine learning, the front-end page 
can only be displayed based on json code. If visualization is required, be sure to package the data into json code and return it! @@ -382,7 +517,10 @@ output = [{"echart_name": "Sales forecast chart","plot_data": plot_json}] print(output) - + When using pie charts, there must be no parameter x + X axis dataZoom is set to orient: horizontal + Y-axis dataZoom is set to orient: vertical" + The output should be formatted as a JSON instance that conforms to the JSON schema below, the JSON is a list of dict, [ {"echart_name": "Sales over Years", "echart_code": ret_json} @@ -446,28 +584,14 @@ title_opts=opts.TitleOpts(title="Sales and Profit over Time",is_show=false), datazoom_opts=[ opts.DataZoomOpts( - # 设置 x 轴 dataZoom - id_="dataZoomX", - type_="slider", - xAxisIndex=[0], # 控制 x 轴 - ), - opts.DataZoomOpts( - # 设置 y 轴 dataZoom - id_="dataZoomY", - type_="slider", - yAxisIndex=[0], # 控制 y 轴 - ), - opts.DataZoomOpts( - # 设置 x 轴 dataZoom - id_="dataZoomX", - type_="inside", - xAxisIndex=[0], # 控制 x 轴 + is_show=True, type_="slider", + xaxis_index=[0], range_start=0, range_end=100, orient="horizontal", + pos_bottom="0px", pos_left="1%", pos_right="1%" ), opts.DataZoomOpts( - # 设置 y 轴 dataZoom - id_="dataZoomY", - type_="inside", - yAxisIndex=[0], # 控制 y 轴 + is_show=True, type_="slider", + yaxis_index=[0], range_start=0, range_end=100, orient="vertical", + pos_top="0px", pos_right="1%", pos_bottom="3%" ), ], legend_opts=opts.LegendOpts( @@ -536,28 +660,14 @@ title_opts=opts.TitleOpts(title="Sales over Years",is_show=false), datazoom_opts=[ opts.DataZoomOpts( - # 设置 x 轴 dataZoom - id_="dataZoomX", - type_="slider", - xAxisIndex=[0], # 控制 x 轴 - ), - opts.DataZoomOpts( - # 设置 y 轴 dataZoom - id_="dataZoomY", - type_="slider", - yAxisIndex=[0], # 控制 y 轴 - ), - opts.DataZoomOpts( - # 设置 x 轴 dataZoom - id_="dataZoomX", - type_="inside", - xAxisIndex=[0], # 控制 x 轴 + is_show=True, type_="slider", + xaxis_index=[0], range_start=0, range_end=100, orient="horizontal", + 
pos_bottom="0px", pos_left="1%", pos_right="1%" ), opts.DataZoomOpts( - # 设置 y 轴 dataZoom - id_="dataZoomY", - type_="inside", - yAxisIndex=[0], # 控制 y 轴 + is_show=True, type_="slider", + yaxis_index=[0], range_start=0, range_end=100, orient="vertical", + pos_top="0px", pos_right="1%", pos_bottom="3%" ), ], legend_opts=opts.LegendOpts( @@ -638,28 +748,14 @@ title_opts=opts.TitleOpts(title="Sales over Years",is_show=false), datazoom_opts=[ opts.DataZoomOpts( - # 设置 x 轴 dataZoom - id_="dataZoomX", - type_="slider", - xAxisIndex=[0], # 控制 x 轴 - ), - opts.DataZoomOpts( - # 设置 y 轴 dataZoom - id_="dataZoomY", - type_="slider", - yAxisIndex=[0], # 控制 y 轴 - ), - opts.DataZoomOpts( - # 设置 x 轴 dataZoom - id_="dataZoomX", - type_="inside", - xAxisIndex=[0], # 控制 x 轴 + is_show=True, type_="slider", + xaxis_index=[0], range_start=0, range_end=100, orient="horizontal", + pos_bottom="0px", pos_left="1%", pos_right="1%" ), opts.DataZoomOpts( - # 设置 y 轴 dataZoom - id_="dataZoomY", - type_="inside", - yAxisIndex=[0], # 控制 y 轴 + is_show=True, type_="slider", + yaxis_index=[0], range_start=0, range_end=100, orient="vertical", + pos_top="0px", pos_right="1%", pos_bottom="3%" ), ], legend_opts=opts.LegendOpts( diff --git a/ai/backend/aidb/aidb.py b/ai/backend/aidb/aidb.py index 2a70166e..adc794a5 100644 --- a/ai/backend/aidb/aidb.py +++ b/ai/backend/aidb/aidb.py @@ -267,7 +267,7 @@ async def put_message(self, state=200, receiver='log', data_type=None, content=N print(send_mess) logger.info(send_mess) - async def check_api_key(self): + async def check_api_key(self, is_auto_pilot=False): # self.agent_instance_util.api_key_use = True # .token_[uid].json @@ -276,7 +276,8 @@ async def check_api_key(self): try: ApiKey, HttpProxyHost, HttpProxyPort, ApiHost, ApiType, ApiModel, LlmSetting = self.load_api_key(token_path) if ApiKey is None or len(ApiKey) == 0: - await self.put_message(500, CONFIG.talker_log, CONFIG.type_log_data, self.error_miss_key) + if not is_auto_pilot: + await 
self.put_message(500, CONFIG.talker_log, CONFIG.type_log_data, self.error_miss_key) return False self.agent_instance_util.set_api_key(ApiKey, ApiType, ApiHost, ApiModel, LlmSetting) @@ -300,12 +301,14 @@ async def check_api_key(self): traceback.print_exc() error_miss_key = self.generate_error_message(http_err, error_message=LanguageInfo.api_key_fail) - await self.put_message(500, CONFIG.talker_log, CONFIG.type_log_data, error_miss_key) + if not is_auto_pilot: + await self.put_message(500, CONFIG.talker_log, CONFIG.type_log_data, error_miss_key) return False except Exception as e: traceback.print_exc() logger.error("from user:[{}".format(self.user_name) + "] , " + "error: " + str(e)) - await self.put_message(500, CONFIG.talker_log, CONFIG.type_log_data, self.error_miss_key) + if not is_auto_pilot: + await self.put_message(500, CONFIG.talker_log, CONFIG.type_log_data, self.error_miss_key) return False else: diff --git a/ai/backend/aidb/autopilot/__init__.py b/ai/backend/aidb/autopilot/__init__.py index 53c4990e..b9337279 100644 --- a/ai/backend/aidb/autopilot/__init__.py +++ b/ai/backend/aidb/autopilot/__init__.py @@ -2,4 +2,4 @@ from .autopilot import Autopilot from .autopilot_starrocks_api import AutopilotStarrocks from .autopilot_mongodb import AutopilotMongoDB - +from .autopilot_csv import AutopilotCSV diff --git a/ai/backend/aidb/autopilot/autopilot_csv.py b/ai/backend/aidb/autopilot/autopilot_csv.py new file mode 100644 index 00000000..d8c283a2 --- /dev/null +++ b/ai/backend/aidb/autopilot/autopilot_csv.py @@ -0,0 +1,503 @@ +# coding:utf-8 +import traceback +import json +from ai.backend.util.write_log import logger +from ai.backend.base_config import CONFIG +from ai.backend.util import database_util +from .autopilot import Autopilot +import re +import ast +import pandas as pd +import chardet +from ai.agents.agentchat import AssistantAgent +from ai.backend.util import base_util + + +max_retry_times = CONFIG.max_retry_times + + +class AutopilotCSV(Autopilot): + + 
async def deal_question(self, json_str, message): + """ + Process mysql data source and select the corresponding workflow + """ + result = {'state': 200, 'data': {}, 'receiver': ''} + q_sender = json_str['sender'] + q_data_type = json_str['data']['data_type'] + print('q_data_type : ', q_data_type) + q_str = json_str['data']['content'] + print('q_str: ', q_str) + + print("self.agent_instance_util.api_key_use :", self.agent_instance_util.api_key_use) + + if not self.agent_instance_util.api_key_use: + re_check = await self.check_api_key() + if not re_check: + return + + if q_sender == 'user': + if q_data_type == 'question': + # print("agent_instance_util.base_message :", self.agent_instance_util.base_message) + if self.agent_instance_util.base_message is not None: + try: + await self.start_chatgroup(q_str) + except Exception as e: + traceback.print_exc() + logger.error("from user:[{}".format(self.user_name) + "] , " + str(e)) + + result['receiver'] = 'user' + result['data']['data_type'] = 'answer' + result['data']['content'] = self.error_message_timeout + consume_output = json.dumps(result) + await self.outgoing.put(consume_output) + else: + await self.put_message(500, receiver=CONFIG.talker_user, data_type=CONFIG.type_answer, + content=self.error_miss_data) + elif q_sender == CONFIG.talker_bi: + if q_data_type == CONFIG.type_comment: + await self.check_data_csv(q_str) + elif q_data_type == CONFIG.type_comment_first: + if json_str.get('data').get('language_mode'): + q_language_mode = json_str['data']['language_mode'] + if q_language_mode == CONFIG.language_chinese or q_language_mode == CONFIG.language_english: + self.set_language_mode(q_language_mode) + self.agent_instance_util.set_language_mode(q_language_mode) + + if CONFIG.database_model == 'online': + # Set csv basic information + self.agent_instance_util.set_base_csv_info(q_str) + self.agent_instance_util.base_message = str(q_str) + else: + self.agent_instance_util.set_base_csv_info(q_str) + 
self.agent_instance_util.base_message = str(q_str) + + await self.get_data_desc(q_str) + elif q_data_type == CONFIG.type_comment_second: + print(CONFIG.type_comment_second) + if json_str.get('data').get('language_mode'): + q_language_mode = json_str['data']['language_mode'] + if q_language_mode == CONFIG.language_chinese or q_language_mode == CONFIG.language_english: + self.set_language_mode(q_language_mode) + self.agent_instance_util.set_language_mode(q_language_mode) + + if CONFIG.database_model == 'online': + databases_id = json_str['data']['databases_id'] + db_id = str(databases_id) + print("db_id:", db_id) + obj = database_util.Main(db_id) + if_suss, db_info = obj.run() + if if_suss: + self.agent_instance_util.base_mysql_info = ' When connecting to the database, be sure to bring the port. This is mysql database info :' + '\n' + str( + db_info) + self.agent_instance_util.base_message = str(q_str) + self.agent_instance_util.db_id = db_id + else: + self.agent_instance_util.base_message = str(q_str) + + await self.put_message(200, receiver=CONFIG.talker_bi, data_type=CONFIG.type_comment_second, + content='') + elif q_data_type == 'mysql_code' or q_data_type == 'chart_code' or q_data_type == 'delete_chart' or q_data_type == 'ask_data': + self.delay_messages['bi'][q_data_type].append(message) + print("delay_messages : ", self.delay_messages) + return + else: + print("q_sender is not right") + + async def check_data_csv(self, q_str): + """Check the data description to see if it meets the requirements""" + print("CONFIG.up_file_path : " + CONFIG.up_file_path) + if q_str.get('table_desc'): + for tb in q_str.get('table_desc'): + if len(tb.get('field_desc')) == 0: + table_name = tb.get('table_name') + + # Read file and detect encoding + csv_file = CONFIG.up_file_path + table_name + f = open(csv_file, 'rb') + # Read the file using the detected encoding + encoding = chardet.detect(f.read())['encoding'] + f.close() + if str(table_name).endswith('.csv'): + data = 
pd.read_csv(open(csv_file, encoding=encoding, errors='ignore')) + else: + data = pd.read_excel(csv_file) + + # Get column headers (first row of data) + column_titles = list(data.columns) + # print("column_titles : ", column_titles) + + for i in range(len(column_titles)): + tb['field_desc'].append({ + "name": column_titles[i], + "comment": '', + "in_use": 1 + }) + + await self.check_data_base(q_str) + + async def task_base(self, qustion_message): + """ Task type: mysql data analysis""" + try: + error_times = 0 + for i in range(max_retry_times): + try: + base_mysql_assistant = self.get_agent_base_mysql_assistant() + python_executor = self.agent_instance_util.get_agent_python_executor() + + await python_executor.initiate_chat( + base_mysql_assistant, + message=self.agent_instance_util.base_message + '\n' + self.question_ask + '\n' + str( + qustion_message), + ) + + answer_message = python_executor.chat_messages[base_mysql_assistant] + print("answer_message: ", answer_message) + + for i in range(len(answer_message)): + answer_mess = answer_message[len(answer_message) - 1 - i] + # print("answer_mess :", answer_mess) + if answer_mess['content'] and answer_mess['content'] != 'TERMINATE': + print("answer_mess['content'] ", answer_mess['content']) + return answer_mess['content'] + + except Exception as e: + traceback.print_exc() + logger.error("from user:[{}".format(self.user_name) + "] , " + "error: " + str(e)) + error_times = error_times + 1 + + if error_times >= max_retry_times: + return self.error_message_timeout + + except Exception as e: + traceback.print_exc() + logger.error("from user:[{}".format(self.user_name) + "] , " + "error: " + str(e)) + + return self.agent_instance_util.data_analysis_error + + def get_agent_base_mysql_assistant(self): + """ Basic Agent, processing mysql data source """ + base_mysql_assistant = AssistantAgent( + name="base_mysql_assistant", + system_message="""You are a helpful AI assistant. 
+ Solve tasks using your coding and language skills. + In the following cases, suggest python code (in a python coding block) for the user to execute. + 1. When you need to collect info, use the code to output the info you need, for example, browse or search the web, download/read a file, print the content of a webpage or a file, get the current date/time, check the operating system. After sufficient info is printed and the task is ready to be solved based on your language skill, you can solve the task by yourself. + 2. When you need to perform some task with code, use the code to perform the task and output the result. Finish the task smartly. + Solve the task step by step if you need to. If a plan is not provided, explain your plan first. Be clear which step uses code, and which step uses your language skill. + When using code, you must indicate the script type in the code block. The user cannot provide any other feedback or perform any other action beyond executing the code you suggest. The user can't modify your code. So do not suggest incomplete code which requires users to modify. Don't use a code block if it's not intended to be executed by the user. + If you want the user to save the code in a file before executing it, put # filename: inside the code block as the first line. Don't include multiple code blocks in one response. Do not ask users to copy and paste the result. Instead, use 'print' function for the output when relevant. Check the execution result returned by the user. + If the result indicates there is an error, fix the error and output the code again. Suggest the full code instead of partial code or code changes. If the error can't be fixed or if the task is not solved even after the code is executed successfully, analyze the problem, revisit your assumption, collect additional info you need, and think of a different approach to try. + When you find an answer, verify the answer carefully. Include verifiable evidence in your response if possible. 
+ Reply "TERMINATE" in the end when everything is done. + When you find an answer, You are a report analysis, you have the knowledge and skills to turn raw data into information and insight, which can be used to make business decisions.include your analysis in your reply. + Be careful to avoid using mysql special keywords in mysql code. + """ + '\n' + self.agent_instance_util.base_mysql_info + '\n' + CONFIG.python_base_dependency + '\n' + self.agent_instance_util.quesion_answer_language, + human_input_mode="NEVER", + user_name=self.user_name, + websocket=self.websocket, + llm_config={ + "config_list": self.agent_instance_util.config_list_gpt4_turbo, + "request_timeout": CONFIG.request_timeout, + }, + openai_proxy=self.agent_instance_util.openai_proxy, + ) + return base_mysql_assistant + + async def start_chatgroup(self, q_str): + + report_html_code = {} + report_html_code['report_name'] = '电商销售报告' + + report_html_code['report_question'] = [] + + question_message = await self.generate_quesiton(q_str) + print('question_message :', question_message) + + report_html_code['report_thought'] = question_message + + question_list = [] + que_num = 1 + for ques in question_message: + print('ques :', ques) + report_demand = 'i need a echart report , ' + ques['report_name'] + ':' + ques['description'] + print("report_demand: ", report_demand) + + question = {} + question['question'] = ques + que_num = que_num + 1 + if que_num > 5: + break + + answer_message, echart_code = await self.task_generate_echart(str(report_demand)) + question['answer'] = answer_message + question['echart_code'] = echart_code + report_html_code['report_question'].append(question) + + question_obj = {'question': report_demand, 'answer': answer_message, 'echart_code': ""} + question_list.append(question_obj) + + print('question_list: ', question_list) + + planner_user = self.agent_instance_util.get_agent_planner_user() + analyst = self.get_agent_analyst() + + question_supplement = 'Please make an analysis 
and summary in English, including which charts were generated, and briefly introduce the contents of these charts.' + if self.language_mode == CONFIG.language_chinese: + question_supplement = " 请用中文帮我对报告做最终总结,给我有价值的结论" + + await planner_user.initiate_chat( + analyst, + message=str( + question_list) + '\n' + "这是本次报告的目标: " + '\n' + q_str + '\n' + self.question_ask + '\n' + question_supplement, + ) + + last_analyst = planner_user.last_message()["content"] + + print('last_analyst : ', last_analyst) + + match = re.search( + r"\[.*\]", last_analyst.strip(), re.MULTILINE | re.IGNORECASE | re.DOTALL + ) + + if match: + json_str = match.group() + print("json_str : ", json_str) + # report_demand_list = json.loads(json_str) + + chart_code_str = str(json_str).replace("\n", "") + if len(chart_code_str) > 0: + print("chart_code_str: ", chart_code_str) + if base_util.is_json(chart_code_str): + report_demand_list = json.loads(chart_code_str) + report_html_code['report_analyst'] = report_demand_list + else: + report_demand_list = ast.literal_eval(chart_code_str) + report_html_code['report_analyst'] = report_demand_list + + print('report_html_code +++++++++++++++++ :', report_html_code) + + rendered_html = self.generate_report_template(report_html_code) + + result_message = { + 'state': 200, + 'receiver': 'autopilot', + 'data': { + 'data_type': 'autopilot_code', + 'content': rendered_html + } + } + + send_json_str = json.dumps(result_message) + await self.websocket.send(send_json_str) + + async def generate_quesiton(self, q_str): + questioner = self.get_agent_questioner() + ai_analyst = self.get_agent_ai_analyst() + + message = self.agent_instance_util.base_message + '\n' + self.question_ask + '\n\n' + q_str + print(' generate_quesiton message: ', message) + + await questioner.initiate_chat( + ai_analyst, + message=self.agent_instance_util.base_message + '\n' + self.question_ask + '\n\n' + q_str, + ) + + base_content = [] + question_message = ai_analyst.chat_messages[questioner] + 
print('question_message : ', question_message) + for answer_mess in question_message: + # print("answer_mess :", answer_mess) + if answer_mess['content']: + if str(answer_mess['role']) == 'assistant': + + answer_mess_content = str(answer_mess['content']).replace('\n', '') + + print("answer_mess: ", answer_mess) + match = re.search( + r"\[.*\]", answer_mess_content.strip(), re.MULTILINE | re.IGNORECASE | re.DOTALL + ) + json_str = '' + if match: + json_str = match.group() + print("json_str : ", json_str) + # report_demand_list = json.loads(json_str) + + chart_code_str = str(json_str).replace("\n", "") + if len(chart_code_str) > 0: + print("chart_code_str: ", chart_code_str) + if base_util.is_json(chart_code_str): + report_demand_list = json.loads(chart_code_str) + print("report_demand_list: ", report_demand_list) + for jstr in report_demand_list: + + # 检查列表中是否存在相同名字的对象 + name_exists = any(item['report_name'] == jstr['report_name'] for item in base_content) + + if not name_exists: + base_content.append(jstr) + print("插入成功") + else: + print("对象已存在,不重复插入") + + else: + # String instantiated as object + report_demand_list = ast.literal_eval(chart_code_str) + print("report_demand_list: ", report_demand_list) + for jstr in report_demand_list: + + # 检查列表中是否存在相同名字的对象 + name_exists = any(item['report_name'] == jstr['report_name'] for item in base_content) + + if not name_exists: + base_content.append(jstr) + print("插入成功") + else: + print("对象已存在,不重复插入") + + return base_content + + async def task_generate_echart(self, qustion_message): + try: + base_content = [] + base_mess = [] + report_demand_list = [] + json_str = "" + error_times = 0 + use_cache = True + for i in range(max_retry_times): + try: + mysql_echart_assistant = self.agent_instance_util.get_agent_mysql_echart_assistant( + use_cache=use_cache) + python_executor = self.agent_instance_util.get_agent_python_executor() + + await python_executor.initiate_chat( + mysql_echart_assistant, + 
message=self.agent_instance_util.base_message + '\n' + self.question_ask + '\n' + str( + qustion_message), + ) + + answer_message = mysql_echart_assistant.chat_messages[python_executor] + + for answer_mess in answer_message: + # print("answer_mess :", answer_mess) + if answer_mess['content']: + if str(answer_mess['content']).__contains__('execution succeeded'): + + answer_mess_content = str(answer_mess['content']).replace('\n', '') + + print("answer_mess: ", answer_mess) + match = re.search( + r"\[.*\]", answer_mess_content.strip(), re.MULTILINE | re.IGNORECASE | re.DOTALL + ) + + if match: + json_str = match.group() + print("json_str : ", json_str) + # report_demand_list = json.loads(json_str) + + chart_code_str = str(json_str).replace("\n", "") + if len(chart_code_str) > 0: + print("chart_code_str: ", chart_code_str) + if base_util.is_json(chart_code_str): + report_demand_list = json.loads(chart_code_str) + + print("report_demand_list: ", report_demand_list) + + for jstr in report_demand_list: + if str(jstr).__contains__('echart_name') and str(jstr).__contains__( + 'echart_code'): + base_content.append(jstr) + else: + # String instantiated as object + report_demand_list = ast.literal_eval(chart_code_str) + print("report_demand_list: ", report_demand_list) + for jstr in report_demand_list: + if str(jstr).__contains__('echart_name') and str(jstr).__contains__( + 'echart_code'): + base_content.append(jstr) + + print("base_content: ", base_content) + base_mess = [] + base_mess.append(answer_message) + break + + except Exception as e: + traceback.print_exc() + logger.error("from user:[{}".format(self.user_name) + "] , " + "error: " + str(e)) + error_times = error_times + 1 + use_cache = False + + if error_times >= max_retry_times: + return self.error_message_timeout + + logger.info( + "from user:[{}".format(self.user_name) + "] , " + ",report_demand_list" + str(report_demand_list)) + + # bi_proxy = self.agent_instance_util.get_agent_bi_proxy() + is_chart = False + # 
Call the interface to generate pictures + last_echart_code = None + for img_str in base_content: + echart_name = img_str.get('echart_name') + echart_code = img_str.get('echart_code') + + if len(echart_code) > 0 and str(echart_code).__contains__('x'): + is_chart = True + print("echart_name : ", echart_name) + last_echart_code = json.dumps(echart_code) + # re_str = await bi_proxy.run_echart_code(str(echart_code), echart_name) + # base_mess.append(re_str) + base_mess = [] + base_mess.append(img_str) + + error_times = 0 + for i in range(max_retry_times): + try: + planner_user = self.agent_instance_util.get_agent_planner_user() + analyst = self.get_agent_analyst() + + question_supplement = 'Please make an analysis and summary in English, including which charts were generated, and briefly introduce the contents of these charts.' + if self.language_mode == CONFIG.language_chinese: + question_supplement = qustion_message + ". 请用中文帮我分析以上的报表数据,给我有价值的结论" + print("question_supplement : ", question_supplement) + + await planner_user.initiate_chat( + analyst, + message=str( + base_mess) + '\n' + self.question_ask + '\n' + question_supplement, + ) + + answer_message = planner_user.last_message()["content"] + + match = re.search( + r"\[.*\]", answer_message.strip(), re.MULTILINE | re.IGNORECASE | re.DOTALL + ) + + if match: + json_str = match.group() + print("json_str : ", json_str) + # report_demand_list = json.loads(json_str) + + chart_code_str = str(json_str).replace("\n", "") + if len(chart_code_str) > 0: + print("chart_code_str: ", chart_code_str) + if base_util.is_json(chart_code_str): + report_demand_list = json.loads(chart_code_str) + return report_demand_list, last_echart_code + else: + report_demand_list = ast.literal_eval(chart_code_str) + return report_demand_list, last_echart_code + + except Exception as e: + traceback.print_exc() + logger.error("from user:[{}".format(self.user_name) + "] , " + "error: " + str(e)) + error_times = error_times + 1 + + if error_times == 
max_retry_times: + return self.error_message_timeout + + except Exception as e: + traceback.print_exc() + logger.error("from user:[{}".format(self.user_name) + "] , " + "error: " + str(e)) + return self.agent_instance_util.data_analysis_error diff --git a/ai/backend/aidb/autopilot/autopilot_csv_api.py b/ai/backend/aidb/autopilot/autopilot_csv_api.py new file mode 100644 index 00000000..0ccf51ff --- /dev/null +++ b/ai/backend/aidb/autopilot/autopilot_csv_api.py @@ -0,0 +1,424 @@ +# coding:utf-8 +import traceback +import json +from ai.backend.util.write_log import logger +from ai.backend.base_config import CONFIG +from ai.backend.util import database_util +from .autopilot import Autopilot +import re +import os +import ast +from ai.backend.util import base_util +from ai.backend.util.db.postgresql_report import PsgReport +from ai.agents.agentchat import Questioner, AssistantAgent +from ai.backend.language_info import LanguageInfo + +max_retry_times = CONFIG.max_retry_times +max_report_question = 5 + + +class AutopilotCSV(Autopilot): + + async def deal_question(self, json_str): + """ + Process mysql data source and select the corresponding workflow + """ + + report_file_name = CONFIG.up_file_path + json_str['file_name'] + report_id = json_str['report_id'] + + with open(report_file_name, 'r') as file: + data = json.load(file) + db_comment = data['db_comment'] + db_id = str(data['databases_id']) + q_str = data['report_desc'] + q_name = data['report_name'] + + csv_local_paths = [] + for item in db_comment['table_desc']: + csv_file = item['table_name'] + csv_local_path = CONFIG.up_file_path + csv_file + if os.path.exists(csv_local_path): + table = { + "table_name": item['table_name'], + item['table_name'] + "_file_path": csv_local_path + } + csv_local_paths.append(csv_local_path) + + print("self.agent_instance_util.api_key_use :", self.agent_instance_util.api_key_use) + + if not self.agent_instance_util.api_key_use: + re_check = await self.check_api_key(is_auto_pilot=True) 
+ if not re_check: + return + + if len(csv_local_paths) > 0: + self.agent_instance_util.base_csv_info = ' csv file path is :' + str( + csv_local_path) + self.agent_instance_util.set_base_message(db_comment) + self.agent_instance_util.db_id = db_id + # start chat + try: + psg = PsgReport() + re = psg.select_data(report_id) + if re is not None and len(re) > 0: + print('need deal task') + data_to_update = (1, report_id) + update_state = psg.update_data(data_to_update) + if update_state: + await self.start_chatgroup(q_str, report_file_name, report_id, q_name) + else: + print('no task') + + except Exception as e: + traceback.print_exc() + # update report status + data_to_update = (-1, report_id) + PsgReport().update_data(data_to_update) + + async def task_base(self, qustion_message): + return qustion_message + + def get_agent_base_mysql_assistant(self): + """ Basic Agent, processing mysql data source """ + base_mysql_assistant = AssistantAgent( + name="base_mysql_assistant", + system_message="""You are a helpful AI assistant. + Solve tasks using your coding and language skills. + In the following cases, suggest python code (in a python coding block) for the user to execute. + 1. When you need to collect info, use the code to output the info you need, for example, browse or search the web, download/read a file, print the content of a webpage or a file, get the current date/time, check the operating system. After sufficient info is printed and the task is ready to be solved based on your language skill, you can solve the task by yourself. + 2. When you need to perform some task with code, use the code to perform the task and output the result. Finish the task smartly. + Solve the task step by step if you need to. If a plan is not provided, explain your plan first. Be clear which step uses code, and which step uses your language skill. + When using code, you must indicate the script type in the code block. 
The user cannot provide any other feedback or perform any other action beyond executing the code you suggest. The user can't modify your code. So do not suggest incomplete code which requires users to modify. Don't use a code block if it's not intended to be executed by the user. + If you want the user to save the code in a file before executing it, put # filename: inside the code block as the first line. Don't include multiple code blocks in one response. Do not ask users to copy and paste the result. Instead, use 'print' function for the output when relevant. Check the execution result returned by the user. + If the result indicates there is an error, fix the error and output the code again. Suggest the full code instead of partial code or code changes. If the error can't be fixed or if the task is not solved even after the code is executed successfully, analyze the problem, revisit your assumption, collect additional info you need, and think of a different approach to try. + When you find an answer, verify the answer carefully. Include verifiable evidence in your response if possible. + Reply "TERMINATE" in the end when everything is done. + When you find an answer, You are a report analysis, you have the knowledge and skills to turn raw data into information and insight, which can be used to make business decisions.include your analysis in your reply. + Be careful to avoid using mysql special keywords in mysql code. 
+ """ + '\n' + self.agent_instance_util.base_mysql_info + '\n' + CONFIG.python_base_dependency + '\n' + self.agent_instance_util.quesion_answer_language, + human_input_mode="NEVER", + user_name=self.user_name, + websocket=self.websocket, + llm_config={ + "config_list": self.agent_instance_util.config_list_gpt4_turbo, + "request_timeout": CONFIG.request_timeout, + }, + openai_proxy=self.agent_instance_util.openai_proxy, + ) + return base_mysql_assistant + + async def start_chatgroup(self, q_str, report_file_name, report_id, q_name): + + report_html_code = {} + try: + report_html_code['report_name'] = q_name + report_html_code['report_author'] = 'DeepBI' + + report_html_code['report_question'] = [] + report_html_code['report_thought'] = [] + report_html_code['report_analyst'] = [] + + question_message = await self.generate_quesiton(q_str, report_file_name) + + print('question_message :', question_message) + + report_html_code['report_thought'] = question_message + + question_list = [] + que_num = 0 + for ques in question_message: + print('ques :', ques) + report_demand = 'i need a echart report , ' + ques['report_name'] + ' : ' + ques['description'] + # report_demand = ' 10-1= ?? 
' + print("report_demand: ", report_demand) + + question = {} + question['question'] = ques + que_num = que_num + 1 + if que_num > max_report_question: + break + + answer_message, echart_code = await self.task_generate_echart(str(report_demand), report_file_name) + if answer_message is not None and echart_code is not None: + question['answer'] = answer_message + question['echart_code'] = echart_code + report_html_code['report_question'].append(question) + + question_obj = {'question': report_demand, 'answer': answer_message, 'echart_code': ""} + question_list.append(question_obj) + + print('question_list: ', question_list) + + planner_user = self.agent_instance_util.get_agent_planner_user(report_file_name=report_file_name) + analyst = self.get_agent_analyst(report_file_name=report_file_name) + + question_supplement = " Make a final summary of the report and give me valuable conclusions. " + + await planner_user.initiate_chat( + analyst, + message=str( + question_list) + '\n' + " This is the goal of this report: " + '\n' + q_str + '\n' + LanguageInfo.question_ask + '\n' + question_supplement, + ) + + last_analyst = planner_user.last_message()["content"] + + print('last_analyst : ', last_analyst) + + match = re.search( + r"\[.*\]", last_analyst.strip(), re.MULTILINE | re.IGNORECASE | re.DOTALL + ) + + if match: + json_str = match.group() + print("json_str : ", json_str) + # report_demand_list = json.loads(json_str) + + last_analyst_str = str(json_str).replace("\n", "") + if len(last_analyst_str) > 0: + print("chart_code_str: ", last_analyst_str) + if base_util.is_json(last_analyst_str): + report_demand_list = json.loads(last_analyst_str) + report_html_code['report_analyst'] = report_demand_list + else: + report_demand_list = ast.literal_eval(last_analyst_str) + report_html_code['report_analyst'] = report_demand_list + + except Exception as e: + traceback.print_exc() + data_to_update = (-1, report_id) + PsgReport().update_data(data_to_update) + else: + # 更新数据 + 
data_to_update = (2, report_id) + PsgReport().update_data(data_to_update) + + print('report_html_code +++++++++++++++++ :', report_html_code) + if len(report_html_code['report_thought']) > 0: + rendered_html = self.generate_report_template(report_html_code) + with open(report_file_name, 'r') as file: + data = json.load(file) + + # 修改其中的值 + data['html_code'] = rendered_html + # if self.log_list is not None: + # data['chat_log'] = self.log_list + + # 将更改后的内容写回文件 + with open(report_file_name, 'w') as file: + json.dump(data, file, indent=4) + + async def generate_quesiton(self, q_str, report_file_name): + questioner = self.get_agent_questioner(report_file_name) + ai_analyst = self.get_agent_ai_analyst(report_file_name) + + message = self.agent_instance_util.base_message + '\n' + LanguageInfo.question_ask + '\n\n' + q_str + print(' generate_quesiton message: ', message) + + await questioner.initiate_chat( + ai_analyst, + message=message, + ) + + base_content = [] + question_message = ai_analyst.chat_messages[questioner] + print('question_message : ', question_message) + for answer_mess in question_message: + # print("answer_mess :", answer_mess) + if answer_mess['content']: + if str(answer_mess['role']) == 'assistant': + + answer_mess_content = str(answer_mess['content']).replace('\n', '') + + print("answer_mess: ", answer_mess) + match = re.search( + r"\[.*\]", answer_mess_content.strip(), re.MULTILINE | re.IGNORECASE | re.DOTALL + ) + json_str = '' + if match: + json_str = match.group() + print("json_str : ", json_str) + # report_demand_list = json.loads(json_str) + + chart_code_str = str(json_str).replace("\n", "") + if len(chart_code_str) > 0: + print("chart_code_str: ", chart_code_str) + report_demand_list = None + if base_util.is_json(chart_code_str): + report_demand_list = json.loads(chart_code_str) + else: + # String instantiated as object + report_demand_list = ast.literal_eval(chart_code_str) + print("report_demand_list: ", report_demand_list) + if 
report_demand_list is not None: + for jstr in report_demand_list: + # 检查列表中是否存在相同名字的对象 + name_exists = any(item['report_name'] == jstr['report_name'] for item in base_content) + + if not name_exists: + if len(base_content) > max_report_question: + break + base_content.append(jstr) + # print("插入成功") + else: + print("对象已存在,不重复插入") + return base_content + + async def task_generate_echart(self, qustion_message, report_file_name): + try: + base_content = [] + base_mess = [] + report_demand_list = [] + json_str = "" + error_times = 0 + use_cache = True + for i in range(max_retry_times): + try: + csv_echart_assistant = self.agent_instance_util.get_agent_csv_echart_assistant( + use_cache=use_cache) + python_executor = self.agent_instance_util.get_agent_python_executor( + report_file_name=report_file_name, is_auto_pilot=True) + + await python_executor.initiate_chat( + csv_echart_assistant, + message=self.agent_instance_util.base_message + '\n' + LanguageInfo.question_ask + '\n' + str( + qustion_message), + ) + + answer_message = csv_echart_assistant.chat_messages[python_executor] + + for answer_mess in answer_message: + # print("answer_mess :", answer_mess) + if answer_mess['content']: + if str(answer_mess['content']).__contains__('execution succeeded'): + + answer_mess_content = str(answer_mess['content']).replace('\n', '') + + print("answer_mess: ", answer_mess) + match = re.search( + r"\[.*\]", answer_mess_content.strip(), re.MULTILINE | re.IGNORECASE | re.DOTALL + ) + + if match: + json_str = match.group() + print("json_str : ", json_str) + # report_demand_list = json.loads(json_str) + + chart_code_str = str(json_str).replace("\n", "") + if len(chart_code_str) > 0: + print("chart_code_str: ", chart_code_str) + if base_util.is_json(chart_code_str): + report_demand_list = json.loads(chart_code_str) + + print("report_demand_list: ", report_demand_list) + + for jstr in report_demand_list: + if str(jstr).__contains__('echart_name') and str(jstr).__contains__( + 
'echart_code') and jstr not in base_content: + base_content.append(jstr) + else: + # String instantiated as object + report_demand_list = ast.literal_eval(chart_code_str) + print("report_demand_list: ", report_demand_list) + for jstr in report_demand_list: + if str(jstr).__contains__('echart_name') and str(jstr).__contains__( + 'echart_code') and jstr not in base_content: + base_content.append(jstr) + + print("base_content: ", base_content) + base_mess = [] + base_mess.append(answer_message) + break + + except Exception as e: + traceback.print_exc() + logger.error("from user:[{}".format(self.user_name) + "] , " + "error: " + str(e)) + error_times = error_times + 1 + use_cache = False + + if error_times >= max_retry_times: + return self.error_message_timeout + + logger.info( + "from user:[{}".format(self.user_name) + "] , " + ",report_demand_list" + str(report_demand_list)) + + # bi_proxy = self.agent_instance_util.get_agent_bi_proxy() + is_chart = False + # Call the interface to generate pictures + last_echart_code = [] + for img_str in base_content: + echart_name = img_str.get('echart_name') + echart_code = img_str.get('echart_code') + + if len(echart_code) > 0 and str(echart_code).__contains__('x'): + is_chart = True + print("echart_name : ", echart_name) + # 格式化echart_code + try: + if base_util.is_json(str(echart_code)): + json_str = json.loads(str(echart_code)) + json_str = json.dumps(json_str) + last_echart_code.append(json_str) + else: + str_obj = ast.literal_eval(str(echart_code)) + json_str = json.dumps(str_obj) + last_echart_code.append(json_str) + except Exception as e: + traceback.print_exc() + logger.error("from user:[{}".format(self.user_name) + "] , " + "error: " + str(e)) + last_echart_code.append(json.dumps(echart_code)) + + # re_str = await bi_proxy.run_echart_code(str(echart_code), echart_name) + # base_mess.append(re_str) + base_mess = [] + base_mess.append(img_str) + + error_times = 0 + for i in range(max_retry_times): + try: + planner_user = 
self.agent_instance_util.get_agent_planner_user(report_file_name=report_file_name) + analyst = self.get_agent_analyst(report_file_name=report_file_name) + + question_supplement = qustion_message + '\n' + "Analyze the above report data and give me valuable conclusions" + print("question_supplement : ", question_supplement) + + await planner_user.initiate_chat( + analyst, + message=str(base_mess) + '\n' + LanguageInfo.question_ask + '\n' + question_supplement, + ) + + answer_message = planner_user.last_message()["content"] + + match = re.search( + r"\[.*\]", answer_message.strip(), re.MULTILINE | re.IGNORECASE | re.DOTALL + ) + + if match: + json_str = match.group() + print("json_str : ", json_str) + # report_demand_list = json.loads(json_str) + + chart_code_str = str(json_str).replace("\n", "") + if len(chart_code_str) > 0: + print("chart_code_str: ", chart_code_str) + if base_util.is_json(chart_code_str): + report_demand_list = json.loads(chart_code_str) + return report_demand_list, last_echart_code + else: + report_demand_list = ast.literal_eval(chart_code_str) + return report_demand_list, last_echart_code + + except Exception as e: + traceback.print_exc() + logger.error("from user:[{}".format(self.user_name) + "] , " + "error: " + str(e)) + error_times = error_times + 1 + + if error_times == max_retry_times: + print(self.error_message_timeout) + return None, None + + except Exception as e: + traceback.print_exc() + logger.error("from user:[{}".format(self.user_name) + "] , " + "error: " + str(e)) + print(self.agent_instance_util.data_analysis_error) + return None, None diff --git a/ai/backend/aidb/autopilot/autopilot_mongodb_api.py b/ai/backend/aidb/autopilot/autopilot_mongodb_api.py index 04aaadd6..f0892597 100644 --- a/ai/backend/aidb/autopilot/autopilot_mongodb_api.py +++ b/ai/backend/aidb/autopilot/autopilot_mongodb_api.py @@ -36,7 +36,7 @@ async def deal_question(self, json_str): print("self.agent_instance_util.api_key_use :", 
self.agent_instance_util.api_key_use) if not self.agent_instance_util.api_key_use: - re_check = await self.check_api_key() + re_check = await self.check_api_key(is_auto_pilot=True) if not re_check: return @@ -274,7 +274,7 @@ async def task_generate_echart(self, qustion_message, report_file_name): use_cache=use_cache, report_file_name=report_file_name) python_executor = self.agent_instance_util.get_agent_python_executor( - report_file_name=report_file_name) + report_file_name=report_file_name, is_auto_pilot=True) # new db await python_executor.initiate_chat( mongodb_echart_assistant, @@ -311,7 +311,7 @@ async def task_generate_echart(self, qustion_message, report_file_name): for jstr in report_demand_list: if str(jstr).__contains__('echart_name') and str(jstr).__contains__( - 'echart_code'): + 'echart_code'): base_content.append(jstr) else: # String instantiated as object @@ -319,7 +319,7 @@ async def task_generate_echart(self, qustion_message, report_file_name): print("report_demand_list: ", report_demand_list) for jstr in report_demand_list: if str(jstr).__contains__('echart_name') and str(jstr).__contains__( - 'echart_code'): + 'echart_code'): base_content.append(jstr) print("base_content: ", base_content) @@ -327,7 +327,6 @@ async def task_generate_echart(self, qustion_message, report_file_name): base_mess.append(answer_message) break - except Exception as e: traceback.print_exc() logger.error("from user:[{}".format(self.user_name) + "] , " + "error: " + str(e)) diff --git a/ai/backend/aidb/autopilot/autopilot_mysql_api.py b/ai/backend/aidb/autopilot/autopilot_mysql_api.py index b28fbdf5..c2ec2798 100644 --- a/ai/backend/aidb/autopilot/autopilot_mysql_api.py +++ b/ai/backend/aidb/autopilot/autopilot_mysql_api.py @@ -36,7 +36,7 @@ async def deal_question(self, json_str): print("self.agent_instance_util.api_key_use :", self.agent_instance_util.api_key_use) if not self.agent_instance_util.api_key_use: - re_check = await self.check_api_key() + re_check = await 
self.check_api_key(is_auto_pilot=True) if not re_check: return @@ -270,7 +270,7 @@ async def task_generate_echart(self, qustion_message, report_file_name): mysql_echart_assistant = self.agent_instance_util.get_agent_mysql_echart_assistant( use_cache=use_cache, report_file_name=report_file_name) python_executor = self.agent_instance_util.get_agent_python_executor( - report_file_name=report_file_name) + report_file_name=report_file_name, is_auto_pilot=True) await python_executor.initiate_chat( mysql_echart_assistant, @@ -307,7 +307,7 @@ async def task_generate_echart(self, qustion_message, report_file_name): for jstr in report_demand_list: if str(jstr).__contains__('echart_name') and str(jstr).__contains__( - 'echart_code'): + 'echart_code'): base_content.append(jstr) else: # String instantiated as object @@ -315,7 +315,7 @@ async def task_generate_echart(self, qustion_message, report_file_name): print("report_demand_list: ", report_demand_list) for jstr in report_demand_list: if str(jstr).__contains__('echart_name') and str(jstr).__contains__( - 'echart_code'): + 'echart_code'): base_content.append(jstr) print("base_content: ", base_content) @@ -323,7 +323,6 @@ async def task_generate_echart(self, qustion_message, report_file_name): base_mess.append(answer_message) break - except Exception as e: traceback.print_exc() logger.error("from user:[{}".format(self.user_name) + "] , " + "error: " + str(e)) diff --git a/ai/backend/aidb/autopilot/autopilot_starrocks_api.py b/ai/backend/aidb/autopilot/autopilot_starrocks_api.py index b077cd1a..3a9d34c8 100644 --- a/ai/backend/aidb/autopilot/autopilot_starrocks_api.py +++ b/ai/backend/aidb/autopilot/autopilot_starrocks_api.py @@ -15,6 +15,7 @@ max_retry_times = CONFIG.max_retry_times max_report_question = 5 + class AutopilotStarrocks(Autopilot): async def deal_question(self, json_str): @@ -35,7 +36,7 @@ async def deal_question(self, json_str): print("self.agent_instance_util.api_key_use :", self.agent_instance_util.api_key_use) 
if not self.agent_instance_util.api_key_use: - re_check = await self.check_api_key() + re_check = await self.check_api_key(is_auto_pilot=True) if not re_check: return @@ -155,7 +156,6 @@ async def start_chatgroup(self, q_str, report_file_name, report_id, q_name): question_message = await self.generate_quesiton(q_str, report_file_name) - print('question_message :', question_message) report_html_code['report_thought'] = question_message @@ -296,7 +296,6 @@ async def generate_quesiton(self, q_str, report_file_name): else: print("对象已存在,不重复插入") - else: # String instantiated as object report_demand_list = ast.literal_eval(chart_code_str) @@ -331,7 +330,7 @@ async def task_generate_echart(self, qustion_message, report_file_name): # mysql_echart_assistant = self.agent_instance_util.get_agent_mysql_echart_assistant35( # use_cache=use_cache, report_file_name=report_file_name) python_executor = self.agent_instance_util.get_agent_python_executor( - report_file_name=report_file_name) + report_file_name=report_file_name, is_auto_pilot=True) await python_executor.initiate_chat( mysql_echart_assistant, @@ -368,7 +367,7 @@ async def task_generate_echart(self, qustion_message, report_file_name): for jstr in report_demand_list: if str(jstr).__contains__('echart_name') and str(jstr).__contains__( - 'echart_code'): + 'echart_code'): base_content.append(jstr) else: # String instantiated as object @@ -376,7 +375,7 @@ async def task_generate_echart(self, qustion_message, report_file_name): print("report_demand_list: ", report_demand_list) for jstr in report_demand_list: if str(jstr).__contains__('echart_name') and str(jstr).__contains__( - 'echart_code'): + 'echart_code'): base_content.append(jstr) print("base_content: ", base_content) @@ -384,7 +383,6 @@ async def task_generate_echart(self, qustion_message, report_file_name): base_mess.append(answer_message) break - except Exception as e: traceback.print_exc() logger.error("from user:[{}".format(self.user_name) + "] , " + "error: " + 
str(e)) diff --git a/ai/backend/aidb/autopilot/html_template/report_2.html b/ai/backend/aidb/autopilot/html_template/report_2.html index 2190a361..63897274 100644 --- a/ai/backend/aidb/autopilot/html_template/report_2.html +++ b/ai/backend/aidb/autopilot/html_template/report_2.html @@ -1,5 +1,6 @@ + {{ report_name }} @@ -82,47 +83,59 @@ + -
-

{{ report_name }}

-
----Powered by {{ report_author }}
+
+

{{ report_name }}

+
----Powered by {{ report_author }}
-
    -

    【分析思路】

    - {% for thought in report_thought %} +
      +

      【分析思路】

      + {% for thought in report_thought %}
    • {{ thought['report_name'] }}: {{ thought['description'] }}
    • - {% endfor %} -
    - {% for question_data in report_question %} + {% endfor %} +
+ {% for question_data in report_question %} + {% set chart_index = loop.index %}

{{ question_data['question']['report_name'] }}

    【数据解读】

    {% for answer_item in question_data['answer'] %} -
  • {{ answer_item['analysis_item'] }}: {{ answer_item['description'] }}
  • +
  • {{ answer_item['analysis_item'] }}: {{ answer_item['description'] }}
  • {% endfor %}
- -
- -
- -
- {% endfor %} - -
-

报告总结

-
    - {% for summary in report_analyst %} + + {% if question_data['echart_code'] is string %} +
    + + {% elif question_data['echart_code'] is sequence %} + {% for echart_code in question_data['echart_code'] %} +
    + + {% endfor %} + {% endif %} +
+ {% endfor %} + +
+

报告总结

+
    + {% for summary in report_analyst %}
  • {{ summary['analysis_item'] }}: {{ summary['description'] }}
  • - {% endfor %} -
+ {% endfor %} + +
-
- + + \ No newline at end of file diff --git a/ai/backend/aidb/autopilot/html_template_en/report_2.html b/ai/backend/aidb/autopilot/html_template_en/report_2.html index 76e5b32b..acf526ae 100644 --- a/ai/backend/aidb/autopilot/html_template_en/report_2.html +++ b/ai/backend/aidb/autopilot/html_template_en/report_2.html @@ -1,5 +1,6 @@ + {{ report_name }} @@ -82,47 +83,58 @@ + -
-

{{ report_name }}

-
----Powered by {{ report_author }}
+
+

{{ report_name }}

+
----Powered by {{ report_author }}
-
    -

    【Analysis Framework】

    - {% for thought in report_thought %} +
      +

      【Analysis Framework】

      + {% for thought in report_thought %}
    • {{ thought['report_name'] }}: {{ thought['description'] }}
    • - {% endfor %} -
    - {% for question_data in report_question %} + {% endfor %} +
+ {% for question_data in report_question %} + {% set chart_index = loop.index %}

{{ question_data['question']['report_name'] }}

    -

    【Data Interpretation】

    +

【Data Interpretation】

    {% for answer_item in question_data['answer'] %} -
  • {{ answer_item['analysis_item'] }}: {{ answer_item['description'] }}
  • +
  • {{ answer_item['analysis_item'] }}: {{ answer_item['description'] }}
  • {% endfor %}
- -
- -
- -
- {% endfor %} - -
-

Report Summary

-
    - {% for summary in report_analyst %} + + {% if question_data['echart_code'] is string %} +
    + + {% elif question_data['echart_code'] is sequence %} + {% for echart_code in question_data['echart_code'] %} +
    + + {% endfor %} + {% endif %} +
+ {% endfor %} +
+

Report Summary

+
    + {% for summary in report_analyst %}
  • {{ summary['analysis_item'] }}: {{ summary['description'] }}
  • - {% endfor %} -
+ {% endfor %} + +
-
- + + \ No newline at end of file diff --git a/ai/backend/app2.py b/ai/backend/app2.py index 9bad69ff..8a5731de 100644 --- a/ai/backend/app2.py +++ b/ai/backend/app2.py @@ -5,6 +5,7 @@ from ai.backend.aidb.autopilot.autopilot_mysql_api import AutopilotMysql from ai.backend.aidb.autopilot.autopilot_starrocks_api import AutopilotStarrocks from ai.backend.aidb.autopilot.autopilot_mongodb_api import AutopilotMongoDB +from ai.backend.aidb.autopilot.autopilot_csv_api import AutopilotCSV from ai.backend.base_config import CONFIG from ai.backend.aidb.dashboard.prettify_dashboard import PrettifyDashboard @@ -52,6 +53,9 @@ async def process_data(self, data): autopilot_mongodb = AutopilotMongoDB(chat_class) # new db await autopilot_mongodb.deal_question(json_str) + elif "csv" == databases_type: + autopilot_csv = AutopilotCSV(chat_class) + await autopilot_csv.deal_question(json_str) else: autopilotMysql = AutopilotMysql(chat_class) await autopilotMysql.deal_question(json_str) diff --git a/ai/backend/base_config.py b/ai/backend/base_config.py index 72566ab7..be475fe2 100644 --- a/ai/backend/base_config.py +++ b/ai/backend/base_config.py @@ -28,7 +28,7 @@ def load_conf(self): self.if_hide_sensitive = False - self.python_base_dependency = """python installed dependency environment: pymysql, pandas, mysql-connector-python, pyecharts, sklearn, psycopg2, sqlalchemy,pymongo""" + self.python_base_dependency = """python installed dependency environment: pymysql, pandas, mysql-connector-python, pyecharts, sklearn, psycopg2, sqlalchemy, pymongo""" self.max_token_num = 7500 diff --git a/ai/backend/chat_task.py b/ai/backend/chat_task.py index 9b7acebe..b266e4f5 100644 --- a/ai/backend/chat_task.py +++ b/ai/backend/chat_task.py @@ -9,7 +9,7 @@ from ai.backend.aidb.report import ReportMysql, ReportPostgresql, ReportStarrocks, ReportMongoDB from ai.backend.aidb.analysis import AnalysisMysql, AnalysisCsv, AnalysisPostgresql, AnalysisStarrocks, AnalysisMongoDB from ai.backend.aidb import AIDB 
-from ai.backend.aidb.autopilot import AutopilotMysql, AutopilotMongoDB +from ai.backend.aidb.autopilot import AutopilotMysql, AutopilotMongoDB, AutopilotCSV message_pool: ChatMemoryManager = ChatMemoryManager(name="message_pool") @@ -59,6 +59,7 @@ def __init__(self, websocket, path): self.autopilotMysql = AutopilotMysql(self) self.autopilotMongoDB = AutopilotMongoDB(self) + self.autopilotCSV = AutopilotCSV(self) async def get_message(self): """ Receive messages and put them into the [pending] message queue """ @@ -152,6 +153,8 @@ async def consume(self): await self.autopilotMysql.deal_question(json_str, message) elif q_database == 'mongodb': await self.autopilotMongoDB.deal_question(json_str, message) + elif q_database == 'csv': + await self.autopilotCSV.deal_question(json_str, message) else: result['state'] = 500 diff --git a/ai/backend/util/base_util.py b/ai/backend/util/base_util.py index 4eb22468..e88ef64b 100644 --- a/ai/backend/util/base_util.py +++ b/ai/backend/util/base_util.py @@ -4,7 +4,7 @@ from pathlib import Path # from bi.settings import DATA_SOURCE_FILE_DIR as docker_data_source_file_dir -docker_data_source_file_dir = "./user_upload_files" +docker_data_source_file_dir = "/app/user_upload_files" host_secret = 'tNGoVq0KpQ4LKr5WMIZM' db_secret = 'aCyBIffJv2OSW5dOvREL' @@ -28,12 +28,8 @@ def get_upload_path(): else: # 获取当前工作目录的路径 current_directory = Path.cwd() - # 获取当前工作目录的父级目录 - # parent_directory = current_directory.parent data_source_file_dir = str(current_directory) + '/user_upload_files/' - - # data_source_file_dir = '/app/user_upload_files/' return data_source_file_dir diff --git a/bi/__init__.py b/bi/__init__.py index 37d59a1d..dcf86f6d 100644 --- a/bi/__init__.py +++ b/bi/__init__.py @@ -15,8 +15,8 @@ from .query_runner import import_query_runners from .destinations import import_destinations -__version__ = "2.0.2" -__DeepBI_version__ = "2.0.2" +__version__ = "2.0.3" +__DeepBI_version__ = "2.0.3" def setup_logging(): diff --git 
a/bi/handlers/data_report_file.py b/bi/handlers/data_report_file.py index 2ef8e088..98f2f4c7 100644 --- a/bi/handlers/data_report_file.py +++ b/bi/handlers/data_report_file.py @@ -103,7 +103,6 @@ def post(self): with open(file_name, 'w') as file: json.dump(data, file) - pass result = models.DataReportFile( user_id=user_id, org_id=self.current_org.id, diff --git a/bi/settings/__init__.py b/bi/settings/__init__.py index cf2ec96c..f54f4ee1 100644 --- a/bi/settings/__init__.py +++ b/bi/settings/__init__.py @@ -494,4 +494,3 @@ def email_server_is_configured(): BLOCKED_DOMAINS = set_from_string(os.environ.get("DEEPBI_BLOCKED_DOMAINS", "")) AI_WEB_SERVER = os.environ.get('AI_WEB_SERVER', '127.0.0.1:8340') - diff --git a/client/dist_source/llm.json b/client/dist_source/llm.json index 5ff24031..cc1016a6 100644 --- a/client/dist_source/llm.json +++ b/client/dist_source/llm.json @@ -51,5 +51,15 @@ "ApiSecret", "Model" ] + }, + "Azure": { + "ApiKey": "", + "Model": "", + "ApiHost": "", + "required": [ + "ApiKey", + "Model", + "ApiHost" + ] } } \ No newline at end of file diff --git a/version.md b/version.md index 1b877177..a75f1e39 100644 --- a/version.md +++ b/version.md @@ -1,4 +1,12 @@ # Version +### 2.0.3 +- The problem of generating pictures in the process of modifying report dialog . +- The problem of not finding function in modifying auxiliary data analysis dialog. +- Added the csv automatic data analysis function. +- Modify the exe startup problem + + + ### 2.0.2 - Adapt an existing function call to the new openai version of tools