diff --git a/CHANGELOG.md b/CHANGELOG.md
index cf113d6b..142d0a7c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,9 @@ All notable changes to this project will be documented in this file.
 This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
 ## [Unreleased]
+### Fixed
+- `pw.xpacks.llm.document_store.DocumentStore` no longer requires a `_metadata` column in the input table.
+
 ## [0.16.0] - 2024-11-29
 
 ### Added
diff --git a/python/pathway/xpacks/llm/document_store.py b/python/pathway/xpacks/llm/document_store.py
index 33632ae4..a412d339 100644
--- a/python/pathway/xpacks/llm/document_store.py
+++ b/python/pathway/xpacks/llm/document_store.py
@@ -7,6 +7,7 @@
 multiple methods for querying.
 """
 
+import warnings
 from collections.abc import Callable
 from typing import TYPE_CHECKING, Iterable, TypeAlias
 
@@ -35,6 +36,11 @@ class DocumentStore:
 
     Args:
     - docs: pathway tables typically coming out of connectors which contain source documents.
+      The table needs to contain a ``data`` column of type bytes - usually by setting
+      the format of the connector to ``"raw"``. Optionally, it can contain
+      a ``_metadata`` column containing a dictionary with metadata which is then
+      used for filters. Some connectors offer a ``with_metadata`` argument for returning
+      the ``_metadata`` column.
     - retriever_factory: factory for building an index, which will be provided texts
       by the ``DocumentStore``.
     - parser: callable that parses file contents into a list of documents.
@@ -262,22 +268,32 @@ def split_doc(text: str, metadata: pw.Json) -> list[dict]:
         return self._apply_processor(post_processed_docs, split_doc)
 
+    def _clean_tables(self, docs: pw.Table | Iterable[pw.Table]) -> list[pw.Table]:
+        if isinstance(docs, pw.Table):
+            docs = [docs]
+
+        def _clean_table(doc: pw.Table) -> pw.Table:
+            if "_metadata" not in doc.column_names():
+                warnings.warn(
+                    f"`_metadata` column is not present in Table {doc}. Filtering will not work for this Table"
+                )
+                doc = doc.with_columns(_metadata=dict())
+
+            return doc.select(pw.this.data, pw.this._metadata)
+
+        return [_clean_table(doc) for doc in docs]
+
     def build_pipeline(self):
 
-        if isinstance(self.docs, pw.Table):
-            docs = self.docs
-        else:
-            docs_list = list(self.docs)
-            if len(docs_list) == 0:
-                raise ValueError(
-                    """Please provide at least one data source, e.g. read files from disk:
-
-pw.io.fs.read('./sample_docs', format='binary', mode='static', with_metadata=True)
-"""
-                )
-            elif len(docs_list) == 1:
-                (docs,) = self.docs
-            else:
-                docs = docs_list[0].concat_reindex(*docs_list[1:])
+        cleaned_tables = self._clean_tables(self.docs)
+        if len(cleaned_tables) == 0:
+            raise ValueError(
+                """Please provide at least one data source, e.g. read files from disk:
+
+pw.io.fs.read('./sample_docs', format='binary', mode='static', with_metadata=True)
+"""
+            )
+
+        docs = pw.Table.concat_reindex(*cleaned_tables)
 
         self.input_docs = docs.select(text=pw.this.data, metadata=pw.this._metadata)
         self.parsed_docs = self.parse_documents(self.input_docs)
diff --git a/python/pathway/xpacks/llm/tests/test_document_store.py b/python/pathway/xpacks/llm/tests/test_document_store.py
index 1a0ff932..59c372e6 100644
--- a/python/pathway/xpacks/llm/tests/test_document_store.py
+++ b/python/pathway/xpacks/llm/tests/test_document_store.py
@@ -594,3 +594,72 @@ def fake_embeddings_model(x: str) -> list[float]:
     (query_result,) = val.as_list()  # extract the single match
     assert isinstance(query_result, dict)
     assert query_result["text"]  # just check if some text was returned
+
+
+def test_docstore_on_table_without_metadata():
+    @pw.udf
+    def fake_embeddings_model(x: str) -> list[float]:
+        return [1.0, 1.0, 0.0]
+
+    docs = pw.debug.table_from_rows(
+        schema=pw.schema_from_types(data=bytes),
+        rows=[("test".encode("utf-8"),)],
+    )
+
+    index_factory = BruteForceKnnFactory(
+        dimensions=3,
+        reserved_space=10,
+        embedder=fake_embeddings_model,
+        metric=BruteForceKnnMetricKind.COS,
+    )
+
+    document_store = DocumentStore(docs, retriever_factory=index_factory)
+
+    retrieve_queries = pw.debug.table_from_rows(
+        schema=DocumentStore.RetrieveQuerySchema,
+        rows=[("Foo", 1, None, None)],
+    )
+
+    retrieve_outputs = document_store.retrieve_query(retrieve_queries)
+    _, rows = pw.debug.table_to_dicts(retrieve_outputs)
+    (val,) = rows["result"].values()
+    assert isinstance(val, pw.Json)
+    (query_result,) = val.as_list()  # extract the single match
+    assert isinstance(query_result, dict)
+    assert query_result["text"] == "test"  # check the document text round-trips
+
+
+def test_docstore_on_tables_with_different_schemas():
+    @pw.udf
+    def fake_embeddings_model(x: str) -> list[float]:
+        return [1.0, 1.0, 0.0]
+
+    docs1 = pw.debug.table_from_rows(
+        schema=pw.schema_from_types(data=bytes),
+        rows=[("test".encode("utf-8"),)],
+    )
+
+    docs2 = pw.debug.table_from_rows(
+        schema=pw.schema_from_types(data=bytes, _metadata=dict, val=int),
+        rows=[("test2".encode("utf-8"), {}, 1)],
+    )
+
+    index_factory = BruteForceKnnFactory(
+        dimensions=3,
+        reserved_space=10,
+        embedder=fake_embeddings_model,
+        metric=BruteForceKnnMetricKind.COS,
+    )
+
+    document_store = DocumentStore([docs1, docs2], retriever_factory=index_factory)
+
+    retrieve_queries = pw.debug.table_from_rows(
+        schema=DocumentStore.RetrieveQuerySchema,
+        rows=[("Foo", 2, None, None)],
+    )
+
+    retrieve_outputs = document_store.retrieve_query(retrieve_queries)
+    _, rows = pw.debug.table_to_dicts(retrieve_outputs)
+    (val,) = rows["result"].values()
+    assert isinstance(val, pw.Json)
+    assert len(val.as_list()) == 2
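
Note for reviewers: the sketch below is not part of the patch; it is a minimal end-to-end illustration of the behavior this change enables - a `DocumentStore` built from a table that has only a `data` column and no `_metadata`. It mirrors `test_docstore_on_table_without_metadata` above. The import path for the KNN factory and the `RetrieveQuerySchema` row layout `(query, k, metadata_filter, filepath_globpattern)` are assumptions inferred from the test module, not pinned down by this diff.

```python
# Minimal sketch of the fixed behavior; assumes the import path for the
# brute-force KNN factory matches the names used in the tests.
import pathway as pw
from pathway.stdlib.indexing import BruteForceKnnFactory, BruteForceKnnMetricKind
from pathway.xpacks.llm.document_store import DocumentStore


@pw.udf
def fake_embeddings_model(x: str) -> list[float]:
    # Constant embedding; enough to exercise the pipeline end to end.
    return [1.0, 1.0, 0.0]


# Only a `data: bytes` column, no `_metadata`. Before this patch,
# build_pipeline() failed on the missing column; now _clean_tables() warns
# and substitutes an empty metadata dictionary, so filtering becomes a no-op.
docs = pw.debug.table_from_rows(
    schema=pw.schema_from_types(data=bytes),
    rows=[(b"hello world",)],
)

document_store = DocumentStore(
    docs,
    retriever_factory=BruteForceKnnFactory(
        dimensions=3,
        reserved_space=10,
        embedder=fake_embeddings_model,
        metric=BruteForceKnnMetricKind.COS,
    ),
)

# Row layout assumed from the tests: (query, k, metadata_filter, filepath_globpattern).
retrieve_queries = pw.debug.table_from_rows(
    schema=DocumentStore.RetrieveQuerySchema,
    rows=[("hello", 1, None, None)],
)

retrieve_outputs = document_store.retrieve_query(retrieve_queries)
_, rows = pw.debug.table_to_dicts(retrieve_outputs)
(val,) = rows["result"].values()
print(val.as_list())  # one match, its "text" field holds the stored document
```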