forked from wxywb/history_rag
-
Notifications
You must be signed in to change notification settings - Fork 0
/
cli.py
121 lines (106 loc) · 4.57 KB
/
cli.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
from executor import MilvusExecutor
from executor import PipelineExecutor
import yaml
from easydict import EasyDict
import argparse
def read_yaml_config(file_path):
    """Load a YAML configuration file and wrap it in an ``EasyDict``.

    Args:
        file_path: Path of the YAML file to read.

    Returns:
        An ``EasyDict`` built from the parsed YAML document, so entries can
        be read with attribute access.

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the content is not valid YAML.
    """
    # Decode explicitly as UTF-8 instead of relying on the platform default
    # encoding, which is not UTF-8 on every system and would corrupt or
    # reject non-ASCII configuration values.
    with open(file_path, "r", encoding="utf-8") as config_file:
        config_data = yaml.safe_load(config_file)
    return EasyDict(config_data)
class CommandLine:
    """Interactive ``(rag)`` shell.

    ``run`` first asks which executor to use (``milvus`` or ``pipeline``),
    then reads commands in a loop and hands each line to ``parse_input``,
    which dispatches ``build``, ``ask``, ``remove`` and ``quit``.
    """

    def __init__(self, config_path):
        """Remember the config path; the executor is created later in ``run``.

        Args:
            config_path: Path of the YAML file passed to ``read_yaml_config``.
        """
        self._mode = None
        self._executor = None
        self.config_path = config_path

    def show_start_info(self):
        """Print the contents of ``./start_info.txt`` (relative to the CWD)."""
        # Explicit encoding so the banner is decoded the same way on every
        # platform rather than with the locale default.
        with open('./start_info.txt', encoding='utf-8') as banner:
            print(banner.read())

    def run(self):
        """Run the shell: choose an executor, then process commands forever.

        The loop only ends when a command terminates the process (``quit``).
        """
        self.show_start_info()
        # The configuration does not change between prompts, so read it once
        # instead of on every iteration of the selection loop.
        conf = read_yaml_config(self.config_path)
        while True:
            print('(rag) 选择[milvus|pipeline]方案')
            mode = input('(rag) ')
            if mode == 'milvus':
                self._executor = MilvusExecutor(conf)
                print('(rag) milvus模式已选择')
                print(' 1.使用`build data/history_24/baihuasanguozhi.txt`来进行知识库构建。')
                print(' 2.已有索引可以使用`ask`进行提问, `-d`参数以debug模式进入。')
                print(' 3.删除已有索引可以使用`remove baihuasanguozhi.txt`。')
                self._mode = 'milvus'
                break
            elif mode == 'pipeline':
                self._executor = PipelineExecutor(conf)
                print('(rag) pipeline模式已选择, 使用`build https://raw.githubusercontent.com/wxywb/history_rag/master/data/history_24/baihuasanguozhi.txt`来进行知识库构建。')
                print(' 1.使用`build https://raw.githubusercontent.com/wxywb/history_rag/master/data/history_24/baihuasanguozhi.txt`来进行知识库构建。')
                print(' 2.已有索引可以使用`ask`进行提问, `-d`参数以debug模式进入。')
                print(' 3.删除已有索引可以使用`remove baihuasanguozhi.txt`。')
                self._mode = 'pipeline'
                break
            elif mode == 'quit':
                self._exit()
                # Only reached if _exit() is overridden not to terminate:
                # there is no executor, so do not enter the command loop.
                return
            else:
                print(f'(rag) {mode}不是已知方案,选择[milvus|pipeline]方案,或者quit退出。')

        # Both `break` paths above have set self._mode and self._executor.
        while True:
            command_text = input("(rag) ")
            self.parse_input(command_text)

    def parse_input(self, text):
        """Parse one command line and dispatch it.

        Supported forms::

            build [-overwrite] <path>
            ask [-d]
            remove <filename>
            quit

        Args:
            text: The raw line typed at the ``(rag)`` prompt.
        """
        # split() without a separator collapses repeated / leading / trailing
        # whitespace, so stray spaces no longer create empty "arguments".
        commands = text.split()
        if not commands:
            # Blank line: just prompt again, as question_answer() does.
            return

        action = commands[0]
        if action == 'build':
            if len(commands) == 2:
                self.build_index(path=commands[1], overwrite=False)
            elif len(commands) == 3:
                if commands[1] == '-overwrite':
                    self.build_index(path=commands[2], overwrite=True)
                else:
                    print('(rag) build仅支持 `-overwrite`参数')
            else:
                # Previously a silent no-op; tell the user what is expected.
                print('(rag) build需要1个文件路径, 可选 `-overwrite`参数。')
        elif action == 'ask':
            if len(commands) == 2:
                if commands[1] != '-d':
                    print('(rag) ask仅支持 `-d`参数 ')
                    # Do not start a Q&A session on an invalid flag.
                    return
                self._executor.set_debug(True)
            else:
                self._executor.set_debug(False)
            self.question_answer()
        elif action == 'remove':
            if len(commands) != 2:
                print('(rag) remove只接受1个参数。')
                # Stop here: there is no single, unambiguous filename.
                return
            self.remove(commands[1])
        elif 'quit' in action:
            # Substring match kept from the original behaviour.
            self._exit()
        else:
            print('(rag) 只有[build|ask|remove|quit]中的操作, 请重新尝试。')

    def query(self, question):
        """Send ``question`` to the executor and print the answer and a rule."""
        ans = self._executor.query(question)
        print(ans)
        print('+---------------------------------------------------------------------------------------------------------------------+')
        print('\n')

    def build_index(self, path, overwrite):
        """Ask the executor to build the index for ``path``.

        Args:
            path: Source location given on the ``build`` command.
            overwrite: ``True`` when ``-overwrite`` was given.
        """
        self._executor.build_index(path, overwrite)
        print('(rag) 索引构建完成')

    def remove(self, filename):
        """Ask the executor to delete ``filename``."""
        self._executor.delete_file(filename)

    def question_answer(self):
        """Run the question loop until the user types ``quit``.

        Empty input is skipped; every other line is passed to ``query``.
        """
        self._executor.build_query_engine()
        while True:
            question = input("(rag) Question: ")
            if question == 'quit':
                print('(rag) 退出问答')
                break
            if question == "":
                continue
            self.query(question)

    def _exit(self):
        """Terminate the program by raising ``SystemExit``."""
        # The bare `exit()` helper is injected by the `site` module for
        # interactive use and is not guaranteed to exist in every run mode.
        raise SystemExit
if __name__ == '__main__':
    # Script entry point: take the config path from --cfg and start the shell.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        '--cfg',
        default='cfgs/config.yaml',
        type=str,
        help='Path to the configuration file',
    )
    options = arg_parser.parse_args()
    CommandLine(options.cfg).run()