-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.py
171 lines (136 loc) · 5.78 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
import base64
import hashlib
import os
import queue
import threading
import time
from typing import Optional
import uvicorn
from aiohttp.web_fileresponse import FileResponse
from fastapi import FastAPI, File, UploadFile, BackgroundTasks, Form, Query, HTTPException
from pydantic import BaseModel
from starlette.middleware.cors import CORSMiddleware
from starlette.status import HTTP_404_NOT_FOUND
import Pdf2MD
from SQLiteManager import SQLiteORM
# Path of the SQLite database file that holds the task table.
db_file = 'minerU-server.db'

# DDL for the task table: one row per task, keyed by task_id, carrying the
# uploaded file path, the Markdown output location and a status string.
sql_create_users_table = """
CREATE TABLE IF NOT EXISTS file_task (
task_id TEXT PRIMARY KEY,
file_path TEXT NOT NULL,
md_file_path TEXT,
status TEXT
);
"""

# Open a connection at import time only to make sure the table exists,
# then close it again; request handlers and workers open their own.
db = SQLiteORM(db_file)
db.create_table(sql_create_users_table)
db.close()

# Despite its name, this is the "data" directory located next to this script;
# every upload is stored in a sub-directory of it.
current_script_dir = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    "data",
)
app = FastAPI()

# Origins allowed to issue cross-origin requests; "*" matches any origin.
origins = ["*"]

# NOTE(review): a wildcard origin combined with allow_credentials=True is very
# permissive -- confirm this is intended before exposing the service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,    # origins permitted to call the API
    allow_credentials=True,
    allow_methods=["*"],      # HTTP methods permitted cross-origin (GET, POST, PUT, ...)
    allow_headers=["*"],      # request headers permitted cross-origin
)
@app.post("/upload")
async def handle(background_task: BackgroundTasks, file: UploadFile = File(...), user_name: Optional[str] = Form(...)):
    """Save an uploaded file to disk and register it as a "waiting" task.

    The task id is the MD5 hex digest of the client file name concatenated
    with ``user_name``, so the same name uploaded by the same user maps to the
    same task id and the same storage directory.

    Args:
        background_task: Injected by FastAPI; not used by this handler.
        file: The uploaded file.
        user_name: Form field combined with the file name to derive the task id.

    Returns:
        ``{"message": "success", "task_id": <id>, "filename": <client name>}``
        on success. On any failure the exception text is returned in
        ``message`` and ``task_id`` is ``None``; no exception propagates.
    """
    try:
        doc_id = hashlib.md5((file.filename + user_name).encode('utf-8')).hexdigest()

        # The file name is client-controlled: keep only its last path
        # component so names like "../../x" cannot escape the task directory.
        safe_name = os.path.basename(file.filename.replace("\\", "/"))
        if safe_name in ("", ".", ".."):
            raise ValueError("Invalid file name")

        # Each task gets its own directory under the data directory.
        # exist_ok avoids the race between an exists() check and creation.
        md_file_path = os.path.join(current_script_dir, f"{doc_id}")
        os.makedirs(md_file_path, exist_ok=True)
        file_path = os.path.join(md_file_path, safe_name)

        # Rewind in case the stream was already consumed, then copy the
        # upload to disk in chunks so it is never fully held in memory.
        await file.seek(0)
        with open(file_path, "wb") as f:
            while chunk := await file.read(1024 * 1024):
                f.write(chunk)

        # Register the task; the producer thread picks up "waiting" rows.
        dbM = SQLiteORM(db_file)
        try:
            dbM.create("file_task",
                       {"task_id": doc_id, "file_path": file_path, "md_file_path": md_file_path, "status": "waiting"})
        finally:
            # Release the connection even when the insert fails.
            dbM.close()

        return {"message": "success", 'task_id': doc_id, 'filename': file.filename}
    except Exception as e:
        # Endpoint boundary: report the failure in the response body.
        print(e)
        return {"message": str(e), 'task_id': None, 'filename': file.filename}
@app.get("/download/{task_id}")
async def download_file(task_id: str):
    """Report the state of a task and, once finished, return its Markdown.

    Args:
        task_id: Identifier returned by the upload endpoint.

    Returns:
        A dict with ``message``, ``task_id`` and ``filename`` (the stored file
        path). ``message`` is ``"success"``, ``"processing"``, ``"waiting"``
        or ``"error"``. On success the dict also has ``data``: the Markdown
        file content, base64-encoded.

    Raises:
        HTTPException: 404 when the task id is unknown, or when the task is
            marked successful but its Markdown file is missing on disk.
    """
    task_db = SQLiteORM(db_file)
    try:
        result = task_db.read("file_task", {"task_id": task_id})
    finally:
        # Release the connection even when the query fails.
        task_db.close()

    if not result:
        raise HTTPException(status_code=HTTP_404_NOT_FOUND, detail="Task ID not found")

    # Assumes rows come back in table column order:
    # (task_id, file_path, md_file_path, status) -- verify against SQLiteORM.read.
    row = result[0]
    file_path, md_file_path, status = row[1], row[2], row[3]

    if status != "success":
        # Any status other than the two in-flight ones is reported as an error.
        message = status if status in ("processing", "waiting") else "error"
        return {"message": message, 'task_id': task_id, 'filename': file_path}

    # The Markdown file is named after the upload, minus its extension.
    dest_name = os.path.splitext(os.path.basename(file_path))[0]
    new_md_file_path = os.path.join(md_file_path, f"{dest_name}.md")

    try:
        with open(new_md_file_path, "rb") as md_file:
            file_data = md_file.read()
    except FileNotFoundError:
        # Answer with a clear 404 instead of an unhandled server error.
        raise HTTPException(status_code=HTTP_404_NOT_FOUND,
                            detail="Converted Markdown file not found") from None

    # Base64 keeps the payload JSON-safe whatever the file contains.
    base64_data = base64.b64encode(file_data).decode('utf-8')
    return {"message": "success", 'task_id': task_id, 'filename': file_path, "data": base64_data}
class AddLink(BaseModel):
    # Schema with a single string field; no endpoint in the visible part of
    # this file references it.
    link: str
# Bounded work queue shared by the producer and consumer threads.
# The capacity is 20: put() blocks while 20 items are already pending.
q = queue.Queue(20)
def producer():
    """Poll the task table forever and enqueue every "waiting" task.

    Intended to run in its own thread. Every 5 seconds it reads the rows
    whose status is "waiting", marks each one "processing" and puts the row
    on the shared queue. Never returns.
    """
    dbP = SQLiteORM(db_file)
    while True:
        try:
            waitingList = dbP.read("file_task", {"status": "waiting"})
            for waiting in waitingList:
                # Mark the task BEFORE enqueueing it: once the row is on the
                # queue a consumer may finish at any moment, and a late
                # "processing" write would overwrite its final status.
                dbP.update("file_task", {"status": "processing"}, {"task_id": waiting[0]})
                q.put(waiting)  # blocks while the queue is full
        except Exception as e:
            # Thread boundary: report and keep polling, otherwise a single
            # failed query would silently stop all scheduling.
            print(f"An error occurred while scheduling tasks: {e}")
        time.sleep(5)  # polling interval
def consumer():
    """Take tasks off the queue forever and convert each PDF to Markdown.

    Intended to run in its own thread. For every queued row it calls
    ``Pdf2MD.processPdf2MD`` and then records either "success" together with
    the output directory, or "error", in the task table. Never returns.
    """
    dbC = SQLiteORM(db_file)
    while True:
        item = q.get()  # blocks until a task is available
        print('Consumer 消费了', item)
        try:
            Pdf2MD.processPdf2MD(item)

            # The output directory recorded for the task is
            # <upload dir>/<file name without extension>/auto.
            base_dir, filename = os.path.split(item[1])
            dest_name = os.path.splitext(filename)[0]
            full_path = os.path.join(base_dir, dest_name, 'auto')
            dbC.update("file_task", {"status": "success", "md_file_path": full_path}, {"task_id": item[0]})
        except Exception as e:
            print(f"An error occurred while processing PDF to MD: {e}")
            try:
                dbC.update("file_task", {"status": "error"}, {"task_id": item[0]})
            except Exception as db_error:
                # Failing to record the error must not kill the worker thread.
                print(f"An error occurred while recording the task failure: {db_error}")
        finally:
            # Always balance q.get(), whatever happened above.
            q.task_done()
        time.sleep(1)  # short pause between tasks
@app.on_event("startup")
async def startup_event():
    """Start the background worker threads when the application starts."""
    # The producer is started first, then the consumer. Both are daemon
    # threads, so they do not keep the process alive on shutdown.
    for worker in (producer, consumer):
        threading.Thread(target=worker, daemon=True).start()
if __name__ == '__main__':
    # Serve the app with uvicorn on all network interfaces, with auto-reload
    # enabled, when this file is executed directly.
    uvicorn.run(
        'main:app',
        host="0.0.0.0",
        reload=True,
    )