Skip to content

Commit

Permalink
add multi headings (#207)
Browse files Browse the repository at this point in the history
* add multi headings

* add multi headings

* add multi headings

* add multi headings
  • Loading branch information
Ceceliachenen authored Sep 9, 2024
1 parent a9af85e commit feb2836
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 0 deletions.
73 changes: 73 additions & 0 deletions src/pai_rag/integrations/readers/pai_pdf_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from PIL import Image
from rapidocr_onnxruntime import RapidOCR
from rapid_table import RapidTable
from operator import itemgetter

import logging
import os
Expand All @@ -37,6 +38,7 @@
PAGE_TABLE_SUMMARY_MAX_TOKEN = 400
IMAGE_URL_PATTERN = r"(https?://[^\s]+?[\s\w.-]*\.(jpg|jpeg|png|gif|bmp))"
IMAGE_COMBINED_PATTERN = r"!\[.*?\]\((https?://[^\s()]+|/[^()\s]+(?:\s[^()\s]*)?/\S*?\.(jpg|jpeg|png|gif|bmp))\)"
DEFAULT_HEADING_DIFF_THRESHOLD = 2


class PaiPDFReader(BaseReader):
Expand Down Expand Up @@ -277,6 +279,74 @@ def process_table(self, markdown_content, json_data):
print(f"警告:图片文件不存在 {img_path}")
return markdown_content

def post_process_multi_level_headings(self, json_data, md_content):
logger.info(
"*****************************start process headings*****************************"
)
pages_list = json_data["pdf_info"]
if not pages_list:
return md_content
text_height_min = float("inf")
text_height_max = 0
title_list = []
for page in pages_list:
page_infos = page["preproc_blocks"]
for item in page_infos:
if not item.get("lines", None) or len(item["lines"]) <= 0:
continue
x0, y0, x1, y1 = item["lines"][0]["bbox"]
content_height = y1 - y0
if item["type"] == "title":
title_height = int(content_height)
title_text = ""
for line in item["lines"]:
for span in line["spans"]:
if span["type"] == "inline_equation":
span["content"] = " $" + span["content"] + "$ "
title_text += span["content"]
title_text = title_text.replace("\\", "\\\\")
title_list.append((title_text, title_height))
elif item["type"] == "text":
if content_height < text_height_min:
text_height_min = content_height
if content_height > text_height_max:
text_height_max = content_height

sorted_list = sorted(title_list, key=itemgetter(1), reverse=True)
diff_list = [
(sorted_list[i][1] - sorted_list[i + 1][1], i)
for i in range(len(sorted_list) - 1)
]
sorted_diff = sorted(diff_list, key=itemgetter(0), reverse=True)
slice_index = []
for diff, index in sorted_diff:
# 标题差的绝对值超过2,则认为是下一级标题
# markdown 中,# 表示一级标题,## 表示二级标题,以此类推,最多有6级标题,最多能有5次切分
if diff >= DEFAULT_HEADING_DIFF_THRESHOLD and len(slice_index) <= 5:
slice_index.append(index)
slice_index.sort(reverse=True)
rank = 1
cur_index = 0
if len(slice_index) > 0:
cur_index = slice_index.pop()
for index, (title_text, title_height) in enumerate(sorted_list):
if index > cur_index:
rank += 1
if len(slice_index) > 0:
cur_index = slice_index.pop()
else:
cur_index = len(sorted_list) - 1
title_level = "#" * rank + " "
if text_height_min <= text_height_max and int(
text_height_min
) <= title_height <= int(text_height_max):
title_level = ""
old_title = "# " + title_text
new_title = title_level + title_text
md_content = re.sub(re.escape(old_title), new_title, md_content)

return md_content

def parse_pdf(
self,
pdf_path: str,
Expand Down Expand Up @@ -334,6 +404,9 @@ def parse_pdf(
pipe.pipe_parse()
content_list = pipe.pipe_mk_uni_format(temp_file_path, drop_mode="none")
md_content = pipe.pipe_mk_markdown(temp_file_path, drop_mode="none")
md_content = self.post_process_multi_level_headings(
pipe.pdf_mid_data, md_content
)
md_content = self.process_table(md_content, content_list)
new_md_content = self.replace_image_paths(pdf_name, md_content)

Expand Down
54 changes: 54 additions & 0 deletions tests/utils/test_post_process_multi_level_headings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import json
from pai_rag.integrations.readers.pai_pdf_reader import PaiPDFReader

json_str = r"""{
"pdf_info": [
{
"preproc_blocks": [
{
"type": "title",
"bbox": [
59,
67,
196,
89
],
"lines": [
{
"bbox": [
57,
70,
72,
85
],
"spans": [
{
"bbox": [
57,
70,
72,
85
],
"score": 0.75,
"content": "\\leftarrow",
"type": "inline_equation"
}
]
}
]
}
]
}
]
}"""

md_str = """# $\leftarrow$ """


def test_post_process_multi_level_headings():
pdf_process = PaiPDFReader()
json_content = json.loads(json_str)
md_content_escape = pdf_process.post_process_multi_level_headings(
json_content, md_str
)
assert md_content_escape == "# $\leftarrow$ "

0 comments on commit feb2836

Please sign in to comment.