From f49ac17f53e52aef194ec72fbe5a80ad26d0c7a1 Mon Sep 17 00:00:00 2001 From: Pierre-Loic Doulcet Date: Sat, 2 Dec 2023 18:00:40 +0100 Subject: [PATCH] Fix issue with irregular table (#9130) (#9249) * Fix an issue where unstructured extracts a table whose rows have varying column counts, causing pandas to crash #9130 * Treat tables with an irregular layout (likely HTML that uses tables for positioning) as Text. Add a test to verify that the node_parser unstructured_element parser handles tables correctly * Update the irregular-table check so it also works when the irregularity occurs beyond rows 0-1. Add a test for a table that contains empty cells. Add a test for a table in which not all rows contain the same number of columns. --------- Co-authored-by: Pierre --- .../relational/unstructured_element.py | 20 +++- tests/node_parser/test_unstructured.py | 103 ++++++++++++++++++ 2 files changed, 120 insertions(+), 3 deletions(-) create mode 100644 tests/node_parser/test_unstructured.py diff --git a/llama_index/node_parser/relational/unstructured_element.py b/llama_index/node_parser/relational/unstructured_element.py index de3a782e30381..ad601def64b20 100644 --- a/llama_index/node_parser/relational/unstructured_element.py +++ b/llama_index/node_parser/relational/unstructured_element.py @@ -62,13 +62,22 @@ def html_to_df(html_str: str) -> pd.DataFrame: cols = [c.text.strip() if c.text is not None else "" for c in cols] data.append(cols) + """ Check if the table is empty""" + if len(data) == 0: + return None + + """ Check if the all rows have the same number of columns """ + if not all(len(row) == len(data[0]) for row in data): + return None + return pd.DataFrame(data[1:], columns=data[0]) def filter_table(table_element: Any) -> bool: """Filter table.""" table_df = html_to_df(table_element.metadata.text_as_html) - return len(table_df) > 1 and len(table_df.columns) > 1 + """ check if table_df is not None, has more than one row, and more than one column """ + return table_df is not None and not table_df.empty and 
len(table_df.columns) > 1 def extract_elements( @@ -91,7 +100,13 @@ def extract_elements( ) ) else: - pass + """if not a table, keep it as Text as we don't want to loose context""" + from unstructured.documents.html import HTMLText + + newElement = HTMLText(str(element), tag=element.tag) + output_els.append( + Element(id=f"id_{idx}", type="text", element=newElement) + ) else: output_els.append(Element(id=f"id_{idx}", type="text", element=element)) return output_els @@ -261,7 +276,6 @@ def get_nodes_from_node(self, node: TextNode) -> List[BaseNode]: table_elements = get_table_elements(elements) # extract summaries over table elements extract_table_summaries(table_elements, self.llm, self.summary_query_str) - # convert into nodes # will return a list of Nodes and Index Nodes return get_nodes_from_elements(elements) diff --git a/tests/node_parser/test_unstructured.py b/tests/node_parser/test_unstructured.py new file mode 100644 index 0000000000000..85f0d29752f35 --- /dev/null +++ b/tests/node_parser/test_unstructured.py @@ -0,0 +1,103 @@ +import pytest +from llama_index.node_parser.relational.unstructured_element import ( + UnstructuredElementNodeParser, +) +from llama_index.schema import Document, IndexNode, TextNode + +try: + from unstructured.partition.html import partition_html +except ImportError: + partition_html = None # type: ignore + +try: + from lxml import html +except ImportError: + html = None # type: ignore + + +@pytest.mark.skipif(partition_html is None, reason="unstructured not installed") +@pytest.mark.skipif(html is None, reason="lxml not installed") +def test_html_table_extraction() -> None: + test_data = Document( + text=""" + + + + Test Page + + + + + + + + + + +
My title center
Design Website like its 2000Yeah!
+

+ Test paragraph +

+ + + + + + + + + + + + + + + + + +
YearBenefits
202012,000
202110,000
2022130,000
+ + + + + + + + + + + + + + + + + + + +
YearBenefits
202012,000
202110,000202110,000
2022130,000
+ + + + + + + + + +
agegroup
yellow
+ + + """ + ) + + node_parser = UnstructuredElementNodeParser() + + nodes = node_parser.get_nodes_from_documents([test_data]) + print(len(nodes)) + print(nodes) + assert len(nodes) == 4 + assert isinstance(nodes[0], TextNode) + assert isinstance(nodes[1], IndexNode) + assert isinstance(nodes[2], TextNode) + assert isinstance(nodes[3], TextNode)