From 87b2c74c9036aba1cbef805490ffd0ad39501a03 Mon Sep 17 00:00:00 2001 From: Pierre-Loic Doulcet Date: Mon, 29 Jan 2024 16:54:09 -0800 Subject: [PATCH] Pierre/node parser md (#10340) * Change the markdown element to better serialize tables. Fix: Make sure that all nodes have content to prevent empty nodes * Improve table representation * Change assert on markdown test LLAMA2 from 214 to 208 nodes --- .../node_parser/relational/base_element.py | 37 ++++++++++++++++--- tests/node_parser/test_markdown_element.py | 2 +- 2 files changed, 32 insertions(+), 7 deletions(-) diff --git a/llama_index/node_parser/relational/base_element.py b/llama_index/node_parser/relational/base_element.py index 29fbaaa305f8c..8126defba31f4 100644 --- a/llama_index/node_parser/relational/base_element.py +++ b/llama_index/node_parser/relational/base_element.py @@ -209,7 +209,6 @@ def get_nodes_from_elements(self, elements: List[Element]) -> List[BaseNode]: nodes = [] cur_text_el_buffer: List[str] = [] - for element in elements: if element.type == "table": # flush text buffer @@ -224,17 +223,42 @@ def get_nodes_from_elements(self, elements: List[Element]) -> List[BaseNode]: table_df = cast(pd.DataFrame, element.table) table_id = element.id + "_table" table_ref_id = element.id + "_table_ref" - # TODO: figure out what to do with columns - # NOTE: right now they're excluded from embedding + col_schema = "\n\n".join([str(col) for col in table_output.columns]) + + # We build a summary of the table containing the extracted summary, and a description of the columns + table_summary = ( + str(table_output.summary) + ", with the following columns:\n" + ) + + for col in table_output.columns: + table_summary += f"- {col.col_name}: {col.summary}\n" + index_node = IndexNode( - text=str(table_output.summary), + text=table_summary, metadata={"col_schema": col_schema}, excluded_embed_metadata_keys=["col_schema"], id_=table_ref_id, index_id=table_id, ) - table_str = table_df.to_string() + + # We serialize the 
table as markdown as it allows better accuracy + # We do not use the table_df.to_markdown() method as it generates + # a table with a token-hungry format. + table_md = "|" + for col_name, col in table_df.items(): + table_md += f"{col_name}|" + table_md += "\n|" + for col_name, col in table_df.items(): + table_md += f"---|" + table_md += "\n|" + for row in table_df.itertuples(): + table_md += "|" + for col in row[1:]: + table_md += f"{col}|" + table_md += "\n" + + table_str = table_summary + "\n" + table_md text_node = TextNode( text=table_str, id_=table_id, @@ -250,4 +274,5 @@ def get_nodes_from_elements(self, elements: List[Element]) -> List[BaseNode]: nodes.extend(cur_text_nodes) cur_text_el_buffer = [] - return nodes + # remove empty nodes + return [node for node in nodes if len(node.text) > 0] diff --git a/tests/node_parser/test_markdown_element.py b/tests/node_parser/test_markdown_element.py index 8c33961e797db..96b593109ae3b 100644 --- a/tests/node_parser/test_markdown_element.py +++ b/tests/node_parser/test_markdown_element.py @@ -2645,4 +2645,4 @@ def test_llama2_bad_md() -> None: node_parser = MarkdownElementNodeParser(llm=MockLLM()) nodes = node_parser.get_nodes_from_documents([test_data]) - assert len(nodes) == 214 + assert len(nodes) == 208