From e959ebee4d68cbd9b87b5796e11edab003449137 Mon Sep 17 00:00:00 2001 From: AmenRa Date: Wed, 23 Aug 2023 20:23:37 +0200 Subject: [PATCH] v0.2.2 --- README.md | 8 +- changelog.md | 7 + docs/advanced_retriever.md | 281 ++++++++ docs/dense_retriever.md | 4 +- docs/filters.md | 76 +++ docs/hybrid_retriever.md | 4 +- docs/sparse_retriever.md | 6 +- retriv/base_retriever.py | 7 +- retriv/dense_retriever/dense_retriever.py | 2 +- retriv/experimental/__init__.py | 4 + retriv/experimental/advanced_retriever.py | 576 +++++++++++++++++ retriv/hybrid_retriever.py | 2 +- retriv/paths.py | 4 + .../sparse_retrieval_models/bm25.py | 17 +- .../sparse_retrieval_models/tf_idf.py | 17 +- retriv/sparse_retriever/sparse_retriever.py | 2 +- retriv/utils/numba_utils.py | 7 +- setup.py | 4 +- .../advanced_retriever_test.py | 605 ++++++++++++++++++ .../numba_utils_test.py | 7 + 20 files changed, 1616 insertions(+), 24 deletions(-) create mode 100755 docs/advanced_retriever.md create mode 100644 docs/filters.md create mode 100644 retriv/experimental/__init__.py create mode 100755 retriv/experimental/advanced_retriever.py create mode 100755 tests/advanced_retriever/advanced_retriever_test.py rename tests/{sparse_retriever => }/numba_utils_test.py (92%) diff --git a/README.md b/README.md index c244437..dce18b3 100755 --- a/README.md +++ b/README.md @@ -24,6 +24,10 @@

## 🔥 News +- [August 23, 2023] `retriv` 0.2.2 is out! +This release adds _experimental_ support for multi-field documents and filters. +Please, refer to [Advanced Retriever](https://github.com/AmenRa/retriv/blob/main/docs/advanced_retriever.md) documentation. + - [February 18, 2023] `retriv` 0.2.0 is out! This release adds support for Dense and Hybrid Retrieval. Dense Retrieval leverages the semantic similarity of the queries' and documents' vector representations, which can be computed directly by `retriv` or imported from other sources. @@ -51,6 +55,8 @@ Click [here](https://github.com/AmenRa/retriv/blob/main/docs/sparse_retriever.md Click [here](https://github.com/AmenRa/retriv/blob/main/docs/dense_retriever.md) to learn more. - [Hybrid Retriever](https://github.com/AmenRa/retriv/blob/main/docs/hybrid_retriever.md): an hybrid retriever is a retrieval model built on top of a sparse and a dense retriever. Click [here](https://github.com/AmenRa/retriv/blob/main/docs/hybrid_retriever.md) to learn more. +- [Advanced Retriever](https://github.com/AmenRa/retriv/blob/main/docs/advanced_retriever.md): an advanced sparse retriever supporting filters. This is and experimental feature. +Click [here](https://github.com/AmenRa/retriv/blob/main/docs/advanced_retriever.md) to learn more. ### Unified Search Interface All the supported retrievers share the same search interface: @@ -101,7 +107,7 @@ se = SearchEngine("new-index").index(collection) se.search("witches masses") ``` Output: -```python +```json [ { "id": "doc_2", diff --git a/changelog.md b/changelog.md index ff5b77b..ae5c4b5 100644 --- a/changelog.md +++ b/changelog.md @@ -4,6 +4,13 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.2.2] - 2023-08-23 +### Added +- Added `advanced_retriever.py`. 
+ +### Changed +- Text preprocessing refactoring. + ## [0.2.1] - 2023-05-16 ### Added - Added doc strings to `sparse_retriver.py`. diff --git a/docs/advanced_retriever.md b/docs/advanced_retriever.md new file mode 100755 index 0000000..aa5ab5f --- /dev/null +++ b/docs/advanced_retriever.md @@ -0,0 +1,281 @@ +# Advanced Retriever + +⚠️ This is an experimental feature. + +The Advanced Retriever is a searcher based on lexical matching and search filters. +It supports [BM25](https://en.wikipedia.org/wiki/Okapi_BM25) and [TF-IDF](https://en.wikipedia.org/wiki/Tf–idf) as the [Sparse Retriever](https://github.com/AmenRa/retriv/blob/main/docs/sparse_retriever.md) and provides the same resources for multi-lingual [text pre-processing](https://github.com/AmenRa/retriv/blob/main/docs/text_preprocessing.md). In addition, it supports search filters, i.e., a set of rules that can be used to filter out documents from the search results. + +In the following, we show how to build a search engine employing an advanced retriever, index a document collection, and search it. + +## Schema + +The first step to building an Advanced Retriever is to define the `schema` of document collection. +The `schema` is a dictionary describing the documents' `fields` and their `data types`. +Based on the `data types`, search `filters` can be defined and applied to the search results. + +[retriv](https://github.com/AmenRa/retriv) supports the following data types: +- __id:__ field used for the document IDs. +- __text:__ text field used for lexical matching. +- __number:__ numeric value. +- __bool:__ boolean value (True or False). +- __keyword:__ string or number representing a keyword or a category. +- __keywords:__ list of keywords. + +An example of `schema` for a collection of books is shown below. +NB: At the time of writing, [retriv](https://github.com/AmenRa/retriv) supports only one text field per schema. +Therefore, the `content` field is used for both the title and the abstract of the books. 
+ +```json +schema = { + "isbn": "id", + "content": "text", + "year": "number", + "is_english": "bool", + "author": "keyword", + "genres": "keywords", +} +``` + +## Build + +The Advanced Retriever provides several options to tailor its functioning to you preferences, as shown below. + +```python +from retriv.experimental import AdvancedRetriever + +ar = AdvancedRetriever( + schema=schema, + index_name="new-index", + model="bm25", + min_df=1, + tokenizer="whitespace", + stemmer="english", + stopwords="english", + do_lowercasing=True, + do_ampersand_normalization=True, + do_special_chars_normalization=True, + do_acronyms_normalization=True, + do_punctuation_removal=True, +) +``` + +- `schema`: the documents' schema. +- `index_name`: [retriv](https://github.com/AmenRa/retriv) will use `index_name` as the identifier of your index. +- `model`: defines the retrieval model to use for searching (`bm25` or `tf-idf`). +- `min_df`: terms that appear in less than `min_df` documents will be ignored. +If integer, the parameter indicates the absolute count. +If float, it represents a proportion of documents. +- `tokenizer`: [tokenizer](https://github.com/AmenRa/retriv/blob/main/docs/text_preprocessing.md) to use during preprocessing. You can pass a custom callable tokenizer or disable tokenization by setting the parameter to `None`. +- `stemmer`: [stemmer](https://github.com/AmenRa/retriv/blob/main/docs/text_preprocessing.md) to use during preprocessing. You can pass a custom callable stemmer or disable stemming setting the parameter to `None`. +- `stopwords`: [stopwords](https://github.com/AmenRa/retriv/blob/main/docs/text_preprocessing.md) to remove during preprocessing. You can pass a custom stop-word list or disable stop-words removal by setting the parameter to `None`. +- `do_lowercasing`: whether to lowercase texts. +- `do_ampersand_normalization`: whether to convert `&` in `and` during pre-processing. 
+- `do_special_chars_normalization`: whether to remove special characters for letters, e.g., `übermensch` → `ubermensch`. +- `do_acronyms_normalization`: whether to remove full stop symbols from acronyms without splitting them in multiple words, e.g., `P.C.I.` → `PCI`. +- `do_punctuation_removal`: whether to remove punctuation. + +__Note:__ text pre-processing is equally applied to documents during indexing and to queries at search time. + +## Index + +### Create +You can index a document collection from JSONl, CSV, or TSV files. +CSV and TSV files must have a header. +[retriv](https://github.com/AmenRa/retriv) automatically infers the file kind, so there's no need to specify it. +Use the `callback` parameter to pass a function for converting your documents in the format defined by your `schema` on the fly. +Indexes are automatically persisted on disk at the end of the process. + +```python +ar = ar.index_file( + path="path/to/collection", # File kind is automatically inferred + show_progress=True, # Default value + callback=lambda doc: { # Callback defaults to None. + "id": doc["id"], + "text": doc["title"] + ". " + doc["text"], + ... + ) +``` + +### Load +```python +ar = AdvancedRetriever.load("index-name") +``` + +### Delete +```python +AdvancedRetriever.delete("index-name") +``` + +## Search + +### Query & Filters + +Advanced Retriever search query can be either a string or a dictionary. +In the former case, the string is used as the query text and no filters are applied. +In the latter case, the dictionary defines the query text and the filters to apply to the search results. If the query text is omitted from the dictionary, documents matching the filters will be returned. + +[retriv](https://github.com/AmenRa/retriv) supports two way of filtering the search results (`where` and `where_not`) and several type-specific operators. + +- `where` means that only the documents matching the filter will be considered during search. 
+- `where_not` means that the documents matching the filter will be ignored during search. + +Below we describe the effects of the supported operators for each data type and way of filtering. + +#### Where + +| Field Type | Operator | Value | Meaning | +| ---------- | --------- | -------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------- | +| number | `eq` | number | Only the documents whose field value is **equal to** the provided value will be considered during search. | +| number | `gt` | number | Only the documents whose field value is **greater than** the provided value will be considered during search. | +| number | `gte` | number | Only the documents whose field value is **greater or equal to** the provided value will be considered during search. | +| number | `lt` | number | Only the documents whose field value is **less than** the provided value will be considered during search. | +| number | `lte` | number | Only the documents whose field value is **less or equal to** the provided value will be considered during search. | +| number | `between` | number | Only the documents whose field value is **between** the provided values (included) will be considered during search. | +| bool | | True / False | Only the documents whose field value is **equal to** the provided value will be considered during search. | +| keyword | | any value / list of values | Only the documents whose field value is **equal to** the provided value or **among** the provided values will be considered during search. | +| keywords | `or` | any value / list of values | Only the documents whose field value is **contains** the provided value or **contains one of** the provided values will be considered during search. | +| keywords | `and` | any value / list of values | Only the documents whose field value **contains all** the provided values will be considered during search. 
| + +Query example: +```python +query = { + "text": "search terms", + "where": { + "numeric_field_name": ("gte", 1970), + "boolean_field_name": True, + "keyword_field_name": "kw_1", + "keywords_field_name": ("or", ["kws_23", "kws_666"]), + } +} +``` + +Alternatively, you can omit the `where` key and use the following syntax: +```python +query = { + "text": "search terms", + "numeric_field_name": ("gte", 1970), + "boolean_field_name": True, + "keyword_field_name": "kw_1", + "keywords_field_name": ("or", ["kws_23", "kws_666"]), +} +``` + + +#### Where not + +| Field Type | Operator | Value | Meaning | +| ---------- | --------- | -------------------------- | ------------------------------------------------------------------------------------------------------------------------------ | +| number | `eq` | number | The documents whose field value is **equal to** the provided value will be ignored. | +| number | `gt` | number | The documents whose field value is **greater than** the provided value will be ignored. | +| number | `gte` | number | The documents whose field value is **greater or equal to** the provided value will be ignored. | +| number | `lt` | number | The documents whose field value is **less than** the provided value will be ignored. | +| number | `lte` | number | The documents whose field value is **less or equal to** the provided value will be ignored. | +| number | `between` | number | The documents whose field value is **between** the provided values (included) will be ignored. | +| bool | | True / False | The documents whose field value is **equal to** the provided value will be ignored. | +| keyword | | any value / list of values | The documents whose field value is **equal to** the provided value or **among** the provided values will be ignored. | +| keywords | `or` | any value / list of values | The documents whose field value is **contains** the provided value or **contains one of** the provided values will be ignored. 
| +| keywords | `and` | any value / list of values | The documents whose field value **contains all** the provided values will be ignored. | + +Query example: +```python +query = { + "text": "search terms", + "where": { + "numeric_field_name": ("gte", 1970), + "boolean_field_name": True, + "keyword_field_name": "kw_1", + "keywords_field_name": ("or", ["kws_23", "kws_666"]), + } +} +``` + +### Search + +```python +ar.search( + query: ... + return_docs=True # Default value. + cutoff=100 # Default value. + operator="OR" # Default value. + subset_doc_ids=None # Default value. +) +``` + +- `query`: what to search for and which filters to apply. See the section [Query & Filters](#query--filters) for more details. +- `return_docs`: whether to return documents or only their IDs. +- `cutoff`: number of results to return. +- `operator`: whether to perform conjunctive (`AND`) or disjunctive (`OR`) search. Conjunctive search retrieves documents that contain **all** the query terms. Disjunctive search retrieves documents that contain **at least one** of the query terms. +- `subset_doc_ids`: restrict the search to the subset of documents having the provided IDs. 
+ +Sample output: +```json +[ + { + "id": "doc_2", + "text": "Just like witches at black masses", + "score": 1.7536403 + }, + { + "id": "doc_1", + "text": "Generals gathered in their masses", + "score": 0.6931472 + } +] +``` + + + + + + + diff --git a/docs/dense_retriever.md b/docs/dense_retriever.md index 977e1a5..0ca7fb4 100755 --- a/docs/dense_retriever.md +++ b/docs/dense_retriever.md @@ -84,7 +84,7 @@ dr.search( ) ``` Output: -```python +```json [ { "id": "doc_2", @@ -111,7 +111,7 @@ dr.msearch( ) ``` Output: -```python +```json { "q_1": { "doc_2": 1.7536403, diff --git a/docs/filters.md b/docs/filters.md new file mode 100644 index 0000000..fa12d3f --- /dev/null +++ b/docs/filters.md @@ -0,0 +1,76 @@ +## Filtering Search Results + +[retriv](https://github.com/AmenRa/retriv) supports two way of filtering the search results (`where` and `where_not`) and several type-specific operators. + +- `where` means that only the documents matching the filter will be considered during search. +- `where_not` means that the documents matching the filter will be ignored during search. + +Below we describe the effects of the supported operators for each data type and way of filtering. + +### Where + +| Field Type | Operator | Value | Meaning | +| ---------- | --------- | -------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------- | +| number | `eq` | number | Only the documents whose field value is **equal to** the provided value will be considered during search. | +| number | `gt` | number | Only the documents whose field value is **greater than** the provided value will be considered during search. | +| number | `gte` | number | Only the documents whose field value is **greater or equal to** the provided value will be considered during search. 
| +| number | `lt` | number | Only the documents whose field value is **less than** the provided value will be considered during search. | +| number | `lte` | number | Only the documents whose field value is **less or equal to** the provided value will be considered during search. | +| number | `between` | number | Only the documents whose field value is **between** the provided values (included) will be considered during search. | +| bool | | True / False | Only the documents whose field value is **equal to** the provided value will be considered during search. | +| keyword | | any value / list of values | Only the documents whose field value is **equal to** the provided value or **among** the provided values will be considered during search. | +| keywords | `or` | any value / list of values | Only the documents whose field value is **contains** the provided value or **contains one of** the provided values will be considered during search. | +| keywords | `and` | any value / list of values | Only the documents whose field value **contains all** the provided values will be considered during search. 
| + +Query example: +```python +query = { + "text": "search terms", + "where": { + "numeric_field_name": ("gte", 1970), + "boolean_field_name": True, + "keyword_field_name": "kw_1", + "keywords_field_name": ("or", ["kws_23", "kws_666"]), + } +} +``` + +Alternatively, you can omit the `where` key and use the following syntax: +```python +query = { + "text": "search terms", + "numeric_field_name": ("gte", 1970), + "boolean_field_name": True, + "keyword_field_name": "kw_1", + "keywords_field_name": ("or", ["kws_23", "kws_666"]), +} +``` + + +### Where not + +| Field Type | Operator | Value | Meaning | +| ---------- | --------- | -------------------------- | ------------------------------------------------------------------------------------------------------------------------------ | +| number | `eq` | number | The documents whose field value is **equal to** the provided value will be ignored. | +| number | `gt` | number | The documents whose field value is **greater than** the provided value will be ignored. | +| number | `gte` | number | The documents whose field value is **greater or equal to** the provided value will be ignored. | +| number | `lt` | number | The documents whose field value is **less than** the provided value will be ignored. | +| number | `lte` | number | The documents whose field value is **less or equal to** the provided value will be ignored. | +| number | `between` | number | The documents whose field value is **between** the provided values (included) will be ignored. | +| bool | | True / False | The documents whose field value is **equal to** the provided value will be ignored. | +| keyword | | any value / list of values | The documents whose field value is **equal to** the provided value or **among** the provided values will be ignored. | +| keywords | `or` | any value / list of values | The documents whose field value is **contains** the provided value or **contains one of** the provided values will be ignored. 
| +| keywords | `and` | any value / list of values | The documents whose field value **contains all** the provided values will be ignored. | + +Query example: +```python +query = { + "text": "search terms", + "where": { + "numeric_field_name": ("gte", 1970), + "boolean_field_name": True, + "keyword_field_name": "kw_1", + "keywords_field_name": ("or", ["kws_23", "kws_666"]), + } +} +``` \ No newline at end of file diff --git a/docs/hybrid_retriever.md b/docs/hybrid_retriever.md index 7521028..a94796b 100755 --- a/docs/hybrid_retriever.md +++ b/docs/hybrid_retriever.md @@ -113,7 +113,7 @@ hr.search( ) ``` Output: -```python +```json [ { "id": "doc_2", @@ -140,7 +140,7 @@ hr.msearch( ) ``` Output: -```python +```json { "q_1": { "doc_2": 1.7536403, diff --git a/docs/sparse_retriever.md b/docs/sparse_retriever.md index 4e8ea18..a3d95c6 100755 --- a/docs/sparse_retriever.md +++ b/docs/sparse_retriever.md @@ -2,7 +2,7 @@ The Sparse Retriever is a traditional searcher based on lexical matching. It supports [BM25](https://en.wikipedia.org/wiki/Okapi_BM25), the retrieval model used by major search engines libraries, such as [Lucene](https://en.wikipedia.org/wiki/Apache_Lucene) and [Elasticsearch](https://en.wikipedia.org/wiki/Elasticsearch). -[retriv](https://github.com/AmenRa/retriv) also implements the classic relevance model TF-IDF for educational purposes. +[retriv](https://github.com/AmenRa/retriv) also implements the classic relevance model [TF-IDF](https://en.wikipedia.org/wiki/Tf–idf) for educational purposes. The Sparse Retriever also provides several resources for multi-lingual [text pre-processing](https://github.com/AmenRa/retriv/blob/main/docs/text_preprocessing.md), aiming to maximize its retrieval effectiveness. 
@@ -94,7 +94,7 @@ sr.search( ) ``` Output: -```python +```json [ { "id": "doc_2", @@ -120,7 +120,7 @@ sr.msearch( ) ``` Output: -```python +```json { "q_1": { "doc_2": 1.7536403, diff --git a/retriv/base_retriever.py b/retriv/base_retriever.py index 6b7fd71..8e7bec6 100644 --- a/retriv/base_retriever.py +++ b/retriv/base_retriever.py @@ -44,12 +44,7 @@ def collection_generator(self, path: str, callback: callable = None): return collection - def save_collection( - self, - collection: Iterable, - callback: callable = None, - show_progress: bool = True, - ): + def save_collection(self, collection: Iterable, callback: callable = None): with open(docs_path(self.index_name), "wb") as f: for doc in collection: x = callback(doc) if callback is not None else doc diff --git a/retriv/dense_retriever/dense_retriever.py b/retriv/dense_retriever/dense_retriever.py index 0ad89ca..90b38ae 100644 --- a/retriv/dense_retriever/dense_retriever.py +++ b/retriv/dense_retriever/dense_retriever.py @@ -168,7 +168,7 @@ def index( DenseRetriever: Dense Retriever """ - self.save_collection(collection, callback, show_progress) + self.save_collection(collection, callback) self.initialize_doc_index() self.initialize_id_mapping() self.doc_count = len(self.id_mapping) diff --git a/retriv/experimental/__init__.py b/retriv/experimental/__init__.py new file mode 100644 index 0000000..c045e53 --- /dev/null +++ b/retriv/experimental/__init__.py @@ -0,0 +1,4 @@ +__all__ = ["AdvancedRetriever"] + + +from .advanced_retriever import AdvancedRetriever diff --git a/retriv/experimental/advanced_retriever.py b/retriv/experimental/advanced_retriever.py new file mode 100755 index 0000000..a0b17fd --- /dev/null +++ b/retriv/experimental/advanced_retriever.py @@ -0,0 +1,576 @@ +import os +from collections import defaultdict +from typing import Any, Dict, Iterable, List, Set, Union + +import numba as nb +import numpy as np +import orjson +from numba.typed import List as TypedList +from oneliner_utils import 
create_path, read_jsonl
from tqdm import tqdm

from ..autotune import tune_bm25
from ..base_retriever import BaseRetriever
from ..paths import docs_path, fr_state_path
from ..sparse_retriever.preprocessing import (
    get_stemmer,
    get_stopwords,
    get_tokenizer,
    preprocessing,
    preprocessing_multi,
)
from ..sparse_retriever.sparse_retrieval_models.bm25 import bm25, bm25_multi
from ..sparse_retriever.sparse_retrieval_models.tf_idf import tf_idf, tf_idf_multi
from ..sparse_retriever.sparse_retriever import SparseRetriever, build_inverted_index
from ..utils.numba_utils import diff_sorted, intersect_sorted_multi, union_sorted_multi

# Accepted filter clauses ("where" maps to "must", "where_not" to "must not"),
# filter operators, and the schema field kinds supported by the retriever.
CLAUSE_LIST = ["must", "must not"]
OPERATOR_LIST = ["eq", "gt", "gte", "lt", "lte", "between", "and", "or"]
KIND_LIST = ["id", "text", "number", "bool", "keyword", "keywords"]


class AdvancedRetriever(SparseRetriever):
    """Experimental sparse retriever with multi-field documents and filters.

    Extends SparseRetriever with a document `schema` (field name -> kind,
    see KIND_LIST) and per-field metadata indexes used to filter search
    results. Only one field of kind "text" is supported; it is the field
    used for lexical matching (BM25 / TF-IDF).
    """

    def __init__(
        self,
        schema: Dict[str, str],
        index_name: str = "new-index",
        model: str = "bm25",
        min_df: int = 1,
        tokenizer: Union[str, callable] = "whitespace",
        stemmer: Union[str, callable] = "english",
        stopwords: Union[str, List[str], Set[str]] = "english",
        do_lowercasing: bool = True,
        do_ampersand_normalization: bool = True,
        do_special_chars_normalization: bool = True,
        do_acronyms_normalization: bool = True,
        do_punctuation_removal: bool = True,
        hyperparams: dict = None,
    ):
        """Build an (un-indexed) Advanced Retriever.

        Args:
            schema: mapping field name -> kind; kinds listed in KIND_LIST.
            index_name: identifier of the on-disk index.
            model: retrieval model, "bm25" or "tf-idf".
            min_df: terms appearing in fewer than `min_df` docs are ignored.
            tokenizer / stemmer / stopwords: preprocessing resources or
                custom callables; `None` disables the corresponding step.
            do_*: individual text-normalization switches.
            hyperparams: BM25 hyper-parameters; defaults to b=0.75, k1=1.2.
        """
        # NOTE(review): `assert`-based validation is stripped under `python -O`;
        # explicit raises would be more robust — confirm intended.
        assert model.lower() in {"bm25", "tf-idf"}
        assert min_df > 0, "`min_df` must be greater than zero."
        self.check_schema(schema)
        # Kept verbatim so `save()`/`load()` can re-create the retriever.
        self.init_args = {
            "schema": schema,
            "model": model.lower(),
            "min_df": min_df,
            "index_name": index_name,
            "do_lowercasing": do_lowercasing,
            "do_ampersand_normalization": do_ampersand_normalization,
            "do_special_chars_normalization": do_special_chars_normalization,
            "do_acronyms_normalization": do_acronyms_normalization,
            "do_punctuation_removal": do_punctuation_removal,
            "tokenizer": tokenizer,
            "stemmer": stemmer,
            "stopwords": stopwords,
        }

        self.schema = schema
        # NOTE(review): raises IndexError if the schema has no "text" field;
        # check_schema() does not enforce its presence — confirm intended.
        self.text_field = [k for k, v in self.schema.items() if v == "text"][0]
        self.model = model.lower()
        self.min_df = min_df
        self.index_name = index_name
        self.do_lowercasing = do_lowercasing
        self.do_ampersand_normalization = do_ampersand_normalization
        self.do_special_chars_normalization = do_special_chars_normalization
        self.do_acronyms_normalization = do_acronyms_normalization
        self.do_punctuation_removal = do_punctuation_removal
        self.tokenizer = get_tokenizer(tokenizer)
        self.stemmer = get_stemmer(stemmer)
        # Stop-words are stemmed so they match stemmed document terms.
        self.stopwords = [self.stemmer(sw) for sw in get_stopwords(stopwords)]
        # Index state; populated by index()/load().
        self.id_mapping = None
        self.reversed_id_mapping = None
        self.inverted_index = None
        self.vocabulary = None
        self.doc_count = None
        self.doc_ids = None
        self.doc_lens = None
        self.avg_doc_len = None
        self.relative_doc_lens = None
        self.doc_index = None
        # self.mappings = None
        self.metadata = None

        self.preprocessing_kwargs = {
            "tokenizer": self.tokenizer,
            "stemmer": self.stemmer,
            "stopwords": self.stopwords,
            "do_lowercasing": self.do_lowercasing,
            "do_ampersand_normalization": self.do_ampersand_normalization,
            "do_special_chars_normalization": self.do_special_chars_normalization,
            "do_acronyms_normalization": self.do_acronyms_normalization,
            "do_punctuation_removal": self.do_punctuation_removal,
        }

        self.preprocessing_pipe = preprocessing_multi(**self.preprocessing_kwargs)

        self.hyperparams = dict(b=0.75, k1=1.2) if hyperparams is None else hyperparams

    def save(self) -> None:
        """Save the state of the retriever to be able to restore it later."""

        state = {
            "init_args": self.init_args,
            "id_mapping": self.id_mapping,
            "doc_count": self.doc_count,
            "inverted_index": self.inverted_index,
            "vocabulary": self.vocabulary,
            "doc_lens": self.doc_lens,
            "relative_doc_lens": self.relative_doc_lens,
            "hyperparams": self.hyperparams,
        }

        # NOTE(review): `self.metadata` (the filter indexes) is not included
        # in the saved state, yet load() does not rebuild it — confirm that
        # filters survive a save/load round trip.
        np.savez_compressed(fr_state_path(self.index_name), state=state)

    @staticmethod
    def load(index_name: str = "new-index"):
        """Load a retriever and its index.

        Args:
            index_name (str, optional): Name of the index. Defaults to "new-index".

        Returns:
            AdvancedRetriever: the restored retriever.
        """

        state = np.load(fr_state_path(index_name), allow_pickle=True)["state"][()]

        se = AdvancedRetriever(**state["init_args"])
        se.initialize_doc_index()
        se.id_mapping = state["id_mapping"]
        se.reversed_id_mapping = {v: k for k, v in state["id_mapping"].items()}
        se.doc_count = state["doc_count"]
        se.doc_ids = np.arange(state["doc_count"], dtype=np.int32)
        se.inverted_index = state["inverted_index"]
        se.vocabulary = set(se.inverted_index)
        se.doc_lens = state["doc_lens"]
        se.relative_doc_lens = state["relative_doc_lens"]
        se.hyperparams = state["hyperparams"]

        # NOTE(review): the dict below is rebuilt but never used or returned —
        # dead code; candidate for removal.
        state = {
            "init_args": se.init_args,
            "id_mapping": se.id_mapping,
            "doc_count": se.doc_count,
            "inverted_index": se.inverted_index,
            "vocabulary": se.vocabulary,
            "doc_lens": se.doc_lens,
            "relative_doc_lens": se.relative_doc_lens,
            "hyperparams": se.hyperparams,
        }

        return se

    def check_schema(self, schema: Dict[str, str]) -> bool:
        """Check if schema is valid.

        Raises ValueError/TypeError on an invalid schema; returns True
        otherwise (annotation fixed from `-> None` to match the body).
        """
        text_found = False

        # NOTE(review): this requires a field literally NAMED "id"; the docs'
        # example schema maps "isbn" to the "id" KIND and would be rejected
        # here — confirm which behavior is intended.
        if "id" not in schema:
            raise ValueError("Schema must contain an id field")

        for k in schema:
            if not isinstance(k, str):
                raise TypeError("Schema keys must be strings")

        for value in schema.values():
            if value not in KIND_LIST:
                raise ValueError(f"Type {value} not supported")
            if value == "text":
                if text_found:
                    # At most one "text" field is supported per schema.
                    raise ValueError("Only one field can be text")
                text_found = True

        return True

    def check_collection(self, collection: Iterable, schema: Dict[str, str]) -> bool:
        """Check collection against a schema.

        Validates that every document has exactly the schema's fields and
        that each value matches its declared kind. Returns True on success
        (annotation fixed from `-> None` to match the body).

        NOTE(review): fully consumes `collection`; callers passing a plain
        generator will find it exhausted afterwards — see index().
        """
        for i, doc in enumerate(collection):
            if "id" not in doc:
                raise ValueError(f"Doc #{i} has no id")

            doc_id = doc["id"]

            # Schema and document must declare exactly the same fields.
            for field in schema:
                if field not in doc:
                    raise ValueError(f"Field {field} not in doc {doc_id}")

            for field in doc:
                if field not in schema:
                    raise ValueError(f"Field {field} not in schema")

                kind = schema[field]
                value = doc[field]

                if kind == "id" and not isinstance(value, (int, str)):
                    raise TypeError(f"Field {field} of doc #{i} has wrong type")

                elif kind == "text" and not isinstance(value, str):
                    raise TypeError(f"Field {field} of doc {doc_id} has wrong type")

                elif kind == "number" and not isinstance(value, (int, float)):
                    raise TypeError(f"Field {field} of doc {doc_id} has wrong type")

                elif kind == "bool" and not isinstance(value, bool):
                    raise TypeError(f"Field {field} of doc {doc_id} has wrong type")

                elif kind == "keyword" and not isinstance(value, str):
                    raise TypeError(f"Field {field} of doc {doc_id} has wrong type")

                elif kind == "keywords" and not isinstance(value, (list, set, tuple)):
                    raise TypeError(f"Field {field} of doc {doc_id} has wrong type")

        return True

    def initialize_metadata(self, schema: Dict[str, str]) -> dict:
        """Create empty per-field filter containers for each filterable kind.

        number -> list of raw values; bool -> {True/False: doc-id lists};
        keyword/keywords -> value -> doc-id list. "id"/"text" fields get no
        metadata container.
        """
        metadata = {}

        for field, kind in schema.items():
            if kind == "number":
                metadata[field] = []
            if kind == "bool":
                metadata[field] = {True: [], False: []}
            elif kind in {"keyword", "keywords"}:
                metadata[field] = defaultdict(list)

        return metadata

    def fill_metadata(self, metadata: dict, collection: Iterable, schema: Dict[str, str]) -> dict:
        """Populate the metadata containers from the documents.

        Doc ids are the (0-based) positions of the documents in `collection`,
        matching the internal ids produced during indexing.
        """
        for i, doc in enumerate(collection):
            for field, kind in schema.items():
                if kind == "number":
                    metadata[field].append(doc[field])
                elif kind in ["bool", "keyword"]:
                    metadata[field][doc[field]].append(i)
                elif kind == "keywords":
                    # Each keyword of the list indexes this doc.
                    for keyword in doc[field]:
                        metadata[field][keyword].append(i)

        return metadata

    def index_metadata(self, collection: Iterable, schema: Dict[str, str]) -> dict:
        """Build the filter indexes and convert them to NumPy int32 arrays."""
        metadata = self.initialize_metadata(schema)
        metadata = self.fill_metadata(metadata, collection, schema)

        # Convert to numpy arrays
        for field, kind in schema.items():
            if kind == "number":
                metadata[field] = np.array(metadata[field])
            elif kind == "bool":
                metadata[field][True] = np.array(metadata[field][True], dtype=np.int32)
                metadata[field][False] = np.array(
                    metadata[field][False], dtype=np.int32
                )
            elif kind in ["keyword", "keywords"]:
                # Freeze the defaultdict so missing keys no longer auto-create.
                metadata[field] = dict(metadata[field])
                for keyword in metadata[field]:
                    metadata[field][keyword] = np.array(
                        metadata[field][keyword], dtype=np.int32
                    )
        return metadata

    def index_aux(self, text_field: str, show_progress: bool = True):
        """Internal usage: build the inverted index from the saved collection,
        reading only `text_field` of each stored document."""
        collection = read_jsonl(
            docs_path(self.index_name),
            generator=True,
            callback=lambda x: x[text_field],
        )

        # Preprocessing --------------------------------------------------------
        collection = self.preprocessing_pipe(collection, generator=True)

        # Inverted index -------------------------------------------------------
        (
            self.inverted_index,
            self.doc_lens,
            self.relative_doc_lens,
        ) = build_inverted_index(
            collection=collection,
            n_docs=self.doc_count,
            min_df=self.min_df,
            show_progress=show_progress,
        )
        self.avg_doc_len = np.mean(self.doc_lens, dtype=np.float32)
        self.vocabulary = set(self.inverted_index)

    def index(
        self,
        collection: Iterable,
        callback: callable = None,
        show_progress: bool = True,
    ):
        """Index a given collection of documents.

        Args:
            collection (Iterable): collection of documents to index.

            callback (callable, optional): callback to apply before indexing the documents to modify them on the fly if needed. Defaults to None.

            show_progress (bool, optional): whether to show a progress bar for the indexing process. Defaults to True.

        Returns:
            AdvancedRetriever: the retriever itself (fluent interface).
        """
        # NOTE(review): `collection` is iterated three times (check_collection,
        # save_collection, index_metadata). A one-shot generator is exhausted
        # after the first pass, silently producing empty filter metadata.
        # index_file() guards against this with a re-iterable wrapper;
        # direct callers must pass a re-iterable — confirm/ document upstream.
        self.check_collection(collection, self.schema)
        self.save_collection(collection, callback)
        self.initialize_doc_index()
        self.initialize_id_mapping()
        # Reverse map: external doc id -> internal integer id.
        self.reversed_id_mapping = {v: k for k, v in self.id_mapping.items()}
        self.doc_count = len(self.id_mapping)
        self.doc_ids = np.arange(self.doc_count, dtype=np.int32)
        self.index_aux(
            text_field=self.text_field,
            show_progress=show_progress,
        )
        self.metadata = self.index_metadata(collection, schema=self.schema)
        self.save()
        return self

    def index_file(
        self, path: str, callback: callable = None, show_progress: bool = True
    ):
        """Index the collection contained in a given file.

        Args:
            path (str): path of file containing the collection to index.

            callback (callable, optional): callback to apply before indexing the documents to modify them on the fly if needed. Defaults to None.

            show_progress (bool, optional): whether to show a progress bar for the indexing process. Defaults to True.

        Returns:
            AdvancedRetriever: the retriever itself (fluent interface).
        """

        # collection = self.collection_generator(path=path, callback=callback)
        collection_generator = self.collection_generator

        # Re-iterable wrapper: index() iterates the collection multiple
        # times, so each __iter__ must start a fresh generator.
        class Collection:
            def __init__(self, path, callback):
                self.path = path
                self.callback = callback

            def __iter__(self):
                yield from collection_generator(path=self.path, callback=self.callback)

        return self.index(
            collection=Collection(path, callback), show_progress=show_progress
        )

    def filter_doc_ids(
        self,
        field: str,
        clause: str,
        value: Any = None,
        operator: str = None,
        raise_error: bool = True,
    ):
        # Resolve a single filter to a NumPy array of internal doc ids.
        if clause not in CLAUSE_LIST:
            raise ValueError(f"Clause must be one of {CLAUSE_LIST}")
        if operator is not None and operator not in OPERATOR_LIST:
            raise ValueError(f"Operator must be one of {OPERATOR_LIST}")
        if field not in self.schema:
            raise ValueError(f"Field `{field}` not in schema")

        kind = self.schema[field]
        doc_ids = self.doc_ids
        id_mapping = self.reversed_id_mapping
        metadata = self.metadata

        def get_value(field, value):
            # Doc ids having `value` in `field`; empty array (or error) if unseen.
            if raise_error and value not in metadata[field]:
                raise ValueError(f"No document has value `{value}` in field `{field}`.")

            return metadata[field].get(value, np.array([], dtype=np.int32))

        if kind == "id":
            if clause == "must":
                return doc_ids[np.isin(doc_ids, [id_mapping[v] for v in value])]
            elif clause == "must not":
                return doc_ids[~np.isin(doc_ids, [id_mapping[v] for v in value])]

        elif kind == "bool":
            if clause == "must":
                return metadata[field][value]
            elif clause == "must not":
                # Complement via the opposite boolean bucket.
                return metadata[field][not value]

        elif kind == "keyword":
            if clause == "must":
                if isinstance(value, list):
                    return union_sorted_multi(
                        TypedList([get_value(field, v) for v in value])
                    )
                else:
                    return get_value(field, value)

            elif clause == "must not":
                if isinstance(value, list):
                    ids = [
                        get_value(field, v) for v in metadata[field] if v not in value
                    ]

                else:
                    # NOTE(review): the reviewed patch chunk is truncated here,
                    # mid-assignment; the remainder of this method (and of the
                    # patch) is not visible in this view.
                    ids =
[get_value(field, v) for v in metadata[field] if v != value] + + return union_sorted_multi(TypedList(ids)) + + elif kind == "keywords": + if clause == "must": + if isinstance(value, list): + if operator == "and": + return intersect_sorted_multi( + TypedList([get_value(field, v) for v in value]) + ) + elif operator == "or": + return union_sorted_multi( + TypedList([get_value(field, v) for v in value]) + ) + else: + raise ValueError( + f"Operator `{operator}`not supported for keywords field" + ) + else: + return get_value(field, value) + + elif clause == "must not": + if isinstance(value, list): + must_not_ids = [get_value(field, v) for v in value] + if operator == "and": + must_not_ids = intersect_sorted_multi(TypedList(must_not_ids)) + elif operator == "or": + must_not_ids = union_sorted_multi(TypedList(must_not_ids)) + else: + raise ValueError( + f"Operator `{operator}`not supported for keywords field" + ) + + return diff_sorted( + np.arange(self.doc_count, dtype=np.int32), must_not_ids + ) + + else: + return diff_sorted( + np.arange(self.doc_count, dtype=np.int32), + get_value(field, value), + ) + + elif kind == "number": + if operator == "eq": + mask = metadata[field] == value + elif operator == "gt": + mask = metadata[field] > value + elif operator == "gte": + mask = metadata[field] >= value + elif operator == "lt": + mask = metadata[field] < value + elif operator == "lte": + mask = metadata[field] <= value + elif operator == "between": + data, min_v, max_v = metadata[field], value[0], value[1] + mask = np.logical_and(data >= min_v, data <= max_v) + else: + raise ValueError("Operator not supported for numeric field") + + if clause == "must": + return doc_ids[mask] + elif clause == "must not": + return doc_ids[~mask] + + else: + raise ValueError( + f"Field {field} of type {kind} not supported for filtering" + ) + + def get_filtered_doc_ids(self, filters: List[Dict]) -> np.ndarray: + if len(filters) == 1: + return self.filter_doc_ids(**filters[0]) + 
filtered_doc_ids = TypedList([self.filter_doc_ids(**f) for f in filters]) + return intersect_sorted_multi(filtered_doc_ids) + + def format_filters(self, filters: Dict, clause: str = "must") -> List[Dict]: + formatted_filters = [] + + for field, value in filters.items(): + if self.schema[field] in {"id", "bool", "keyword"}: + f = dict(field=field, clause=clause, value=value) + + elif self.schema[field] in {"number", "keywords"}: + f = dict(field=field, clause=clause, value=value[1], operator=value[0]) + + formatted_filters.append(f) + + return formatted_filters + + def search( + self, + query: Union[Dict, str], + return_docs: bool = True, + cutoff: int = 100, + operator: str = "OR", + subset_doc_ids: List = None, + ) -> List: + if isinstance(query, str): + query_text = query + if subset_doc_ids is not None: + subset_doc_ids = np.array( + [self.reversed_id_mapping[doc_id] for doc_id in subset_doc_ids], + dtype=np.int32, + ) + else: + query_text = query.get("text", "") + must_filters = query.get("where", {}) + must_not_filters = query.get("where_not", {}) + must_single_filters = { + k: v + for k, v in query.items() + if k not in {"text", "where", "where_not"} + } + + must_filters = self.format_filters(must_filters) + must_not_filters = self.format_filters(must_not_filters, clause="must not") + must_single_filters = self.format_filters(must_single_filters) + filters = must_filters + must_not_filters + must_single_filters + subset_doc_ids = self.get_filtered_doc_ids(filters) + + query_terms = self.query_preprocessing(query_text) + query_terms = [t for t in query_terms if t in self.vocabulary] + + if query_terms: + doc_ids = self.get_doc_ids(query_terms) + term_doc_freqs = self.get_term_doc_freqs(query_terms) + + if self.model == "bm25": + unique_doc_ids, scores = bm25( + term_doc_freqs=term_doc_freqs, + doc_ids=doc_ids, + relative_doc_lens=self.relative_doc_lens, + doc_count=self.doc_count, + cutoff=cutoff, + operator=operator, + subset_doc_ids=subset_doc_ids, + 
**self.hyperparams, + ) + elif self.model == "tf-idf": + unique_doc_ids, scores = tf_idf( + term_doc_freqs=term_doc_freqs, + doc_ids=doc_ids, + doc_lens=self.doc_lens, + cutoff=cutoff, + operator=operator, + subset_doc_ids=subset_doc_ids, + ) + else: + raise NotImplementedError() + else: + if subset_doc_ids is None: + unique_doc_ids = self.doc_ids + scores = np.ones(self.doc_count) + else: + unique_doc_ids = subset_doc_ids + scores = np.ones(len(subset_doc_ids)) + + unique_doc_ids = self.map_internal_ids_to_original_ids(unique_doc_ids) + + if not return_docs: + return dict(zip(unique_doc_ids, scores)) + + return self.prepare_results(unique_doc_ids, scores) diff --git a/retriv/hybrid_retriever.py b/retriv/hybrid_retriever.py index f8662ca..431e9ad 100644 --- a/retriv/hybrid_retriever.py +++ b/retriv/hybrid_retriever.py @@ -134,7 +134,7 @@ def index( HybridRetriever: Hybrid Retriever """ - self.save_collection(collection, callback, show_progress) + self.save_collection(collection, callback) self.initialize_doc_index() self.initialize_id_mapping() diff --git a/retriv/paths.py b/retriv/paths.py index 7b878a7..de4fcdd 100644 --- a/retriv/paths.py +++ b/retriv/paths.py @@ -28,6 +28,10 @@ def sr_state_path(index_name: str): return index_path(index_name) / "sr_state.npz" +def fr_state_path(index_name: str): + return index_path(index_name) / "fr_state.npz" + + def embeddings_path(index_name: str): return index_path(index_name) / "embeddings.h5" diff --git a/retriv/sparse_retriever/sparse_retrieval_models/bm25.py b/retriv/sparse_retriever/sparse_retrieval_models/bm25.py index 59572e1..5a20671 100755 --- a/retriv/sparse_retriever/sparse_retrieval_models/bm25.py +++ b/retriv/sparse_retriever/sparse_retrieval_models/bm25.py @@ -5,7 +5,12 @@ from numba import njit, prange from numba.typed import List as TypedList -from ...utils.numba_utils import union_sorted_multi, unsorted_top_k +from ...utils.numba_utils import ( + intersect_sorted, + intersect_sorted_multi, + 
union_sorted_multi, + unsorted_top_k, +) @njit(cache=True) @@ -17,8 +22,16 @@ def bm25( relative_doc_lens: nb.typed.List[np.ndarray], doc_count: int, cutoff: int, + operator: str = "OR", + subset_doc_ids: np.ndarray = None, ) -> Tuple[np.ndarray]: - unique_doc_ids = union_sorted_multi(doc_ids) + if operator == "AND": + unique_doc_ids = intersect_sorted_multi(doc_ids) + elif operator == "OR": + unique_doc_ids = union_sorted_multi(doc_ids) + + if subset_doc_ids is not None: + unique_doc_ids = intersect_sorted(unique_doc_ids, subset_doc_ids) scores = np.empty(doc_count, dtype=np.float32) scores[unique_doc_ids] = 0.0 # Initialize scores diff --git a/retriv/sparse_retriever/sparse_retrieval_models/tf_idf.py b/retriv/sparse_retriever/sparse_retrieval_models/tf_idf.py index bc6cc9d..94799ba 100755 --- a/retriv/sparse_retriever/sparse_retrieval_models/tf_idf.py +++ b/retriv/sparse_retriever/sparse_retrieval_models/tf_idf.py @@ -5,7 +5,12 @@ from numba import njit, prange from numba.typed import List as TypedList -from ...utils.numba_utils import union_sorted_multi, unsorted_top_k +from ...utils.numba_utils import ( + intersect_sorted, + intersect_sorted_multi, + union_sorted_multi, + unsorted_top_k, +) @njit(cache=True) @@ -14,8 +19,16 @@ def tf_idf( doc_ids: nb.typed.List[np.ndarray], doc_lens: nb.typed.List[np.ndarray], cutoff: int, + operator: str = "OR", + subset_doc_ids: np.ndarray = None, ) -> Tuple[np.ndarray]: - unique_doc_ids = union_sorted_multi(doc_ids) + if operator == "AND": + unique_doc_ids = intersect_sorted_multi(doc_ids) + elif operator == "OR": + unique_doc_ids = union_sorted_multi(doc_ids) + + if subset_doc_ids is not None: + unique_doc_ids = intersect_sorted(unique_doc_ids, subset_doc_ids) doc_count = len(doc_lens) scores = np.empty(doc_count, dtype=np.float32) diff --git a/retriv/sparse_retriever/sparse_retriever.py b/retriv/sparse_retriever/sparse_retriever.py index 5c70644..592b019 100644 --- a/retriv/sparse_retriever/sparse_retriever.py +++ 
b/retriv/sparse_retriever/sparse_retriever.py @@ -217,7 +217,7 @@ def index( SparseRetriever: Sparse Retriever. """ - self.save_collection(collection, callback, show_progress) + self.save_collection(collection, callback) self.initialize_doc_index() self.initialize_id_mapping() self.doc_count = len(self.id_mapping) diff --git a/retriv/utils/numba_utils.py b/retriv/utils/numba_utils.py index c311e41..30edb11 100644 --- a/retriv/utils/numba_utils.py +++ b/retriv/utils/numba_utils.py @@ -96,7 +96,12 @@ def diff_sorted(a1: np.array, a2: np.array): i += 1 j += 1 - return result[:k] + result = result[:k] + + if i < len(a1): + result = np.concatenate((result, a1[i:])) + + return result # ----------------------------------------------------------------------------- diff --git a/setup.py b/setup.py index a980c32..2d98cf7 100644 --- a/setup.py +++ b/setup.py @@ -5,10 +5,10 @@ setuptools.setup( name="retriv", - version="0.2.1", + version="0.2.2", author="Elias Bassani", author_email="elias.bssn@gmail.com", - description="retriv: A Blazing-Fast Python Search Engine.", + description="retriv: A Python Search Engine for Humans.", long_description=long_description, long_description_content_type="text/markdown", url="https://github.com/AmenRa/retriv", diff --git a/tests/advanced_retriever/advanced_retriever_test.py b/tests/advanced_retriever/advanced_retriever_test.py new file mode 100755 index 0000000..daaa47b --- /dev/null +++ b/tests/advanced_retriever/advanced_retriever_test.py @@ -0,0 +1,605 @@ +from collections import defaultdict + +import pytest + +from retriv.experimental import AdvancedRetriever + + +# FIXTURES ===================================================================== +@pytest.fixture +def schema(): + return { + "id": "id", + "lyrics": "text", + "year": "number", + "ozzy": "bool", + "album": "keyword", + "genre": "keywords", + } + + +@pytest.fixture +def collection(): + return [ + { + "id": "doc_0", + "lyrics": "Generals gathered in their masses", + "album": 
"Black Sabbath", + "year": 1969, + "ozzy": True, + "genre": ["Doom", "Heavy Metal"], + }, + { + "id": "doc_1", + "lyrics": "Just like witches at black masses", + "album": "Paranoid", + "year": 1970, + "ozzy": True, + "genre": ["Doom", "Heavy Metal"], + }, + { + "id": "doc_2", + "lyrics": "Evil minds that plot destruction", + "album": "Heaven and Hell", + "year": 1980, + "ozzy": False, + "genre": ["Heavy Metal"], + }, + ] + + +def test_check_schema_no_id(): + with pytest.raises(Exception, match="Schema must contain an id field"): + AdvancedRetriever({"text": "text"}) + + +def test_check_schema_invalid_key(): + with pytest.raises(Exception, match="Schema keys must be strings"): + AdvancedRetriever({"id": "id", 1: "text"}) + + +def test_check_schema_invalid_value(): + with pytest.raises(Exception, match="Type invalid not supported"): + AdvancedRetriever({"id": "id", "text": "invalid"}) + + +def test_check_schema_double_text(): + with pytest.raises(Exception, match="Only one field can be text"): + AdvancedRetriever({"id": "id", "title": "text", "body": "text"}) + + +def test_check_schema_pass(schema): + se = AdvancedRetriever(schema) + assert se.schema == schema + + +def test_check_collection_field_no_id(schema): + collection = [ + { + "lyrics": "Generals gathered in their masses", + "album": "Black Sabbath", + "year": 1969, + "ozzy": True, + "genre": ["Doom", "Heavy Metal"], + } + ] + se = AdvancedRetriever(schema) + with pytest.raises(Exception, match="has no id"): + se.check_collection(collection, schema) + + +def test_check_collection_missing_field(schema): + collection = [ + { + "id": "doc_0", + "lyrics": "Generals gathered in their masses", + "year": 1969, + "ozzy": True, + "genre": ["Doom", "Heavy Metal"], + } + ] + se = AdvancedRetriever(schema) + with pytest.raises(Exception, match="Field album not in doc"): + se.check_collection(collection, schema) + + +def test_check_collection_additional_field(schema): + collection = [ + { + "id": "doc_0", + "lyrics": 
"Generals gathered in their masses", + "album": "Black Sabbath", + "year": 1969, + "ozzy": True, + "genre": ["Doom", "Heavy Metal"], + "invalid": "value", + } + ] + se = AdvancedRetriever(schema) + with pytest.raises(Exception, match="Field invalid not in schema"): + se.check_collection(collection, schema) + + +def test_check_collection_field_id_wrong_type(schema): + collection = [ + { + "id": [0], + "lyrics": "Generals gathered in their masses", + "album": "Black Sabbath", + "year": 1969, + "ozzy": True, + "genre": ["Doom", "Heavy Metal"], + }, + ] + se = AdvancedRetriever(schema) + with pytest.raises(Exception, match="Field id"): + se.check_collection(collection, schema) + + +def test_check_collection_field_lyrics_wrong_type(schema): + collection = [ + { + "id": "doc_0", + "lyrics": 666, + "album": "Black Sabbath", + "year": 1969, + "ozzy": True, + "genre": ["Doom", "Heavy Metal"], + }, + ] + se = AdvancedRetriever(schema) + with pytest.raises(Exception, match="Field lyrics"): + se.check_collection(collection, schema) + + +def test_check_collection_field_year_wrong_type(schema): + collection = [ + { + "id": "doc_0", + "lyrics": "Generals gathered in their masses", + "album": "Black Sabbath", + "year": "1969", + "ozzy": True, + "genre": ["Doom", "Heavy Metal"], + }, + ] + se = AdvancedRetriever(schema) + with pytest.raises(Exception, match="Field year"): + se.check_collection(collection, schema) + + +def test_check_collection_field_ozzy_wrong_type(schema): + collection = [ + { + "id": "doc_0", + "lyrics": "Generals gathered in their masses", + "album": "Black Sabbath", + "year": 1969, + "ozzy": "True", + "genre": ["Doom", "Heavy Metal"], + }, + ] + se = AdvancedRetriever(schema) + with pytest.raises(Exception, match="Field ozzy"): + se.check_collection(collection, schema) + + +def test_check_collection_field_album_wrong_type(schema): + collection = [ + { + "id": "doc_0", + "lyrics": "Generals gathered in their masses", + "album": ["Black Sabbath"], + "year": 1969, 
+ "ozzy": True, + "genre": ["Doom", "Heavy Metal"], + }, + ] + se = AdvancedRetriever(schema) + with pytest.raises(Exception, match="Field album"): + se.check_collection(collection, schema) + + +def test_check_collection_field_genre_wrong_type(schema): + collection = [ + { + "id": "doc_0", + "lyrics": "Generals gathered in their masses", + "album": "Black Sabbath", + "year": 1969, + "ozzy": True, + "genre": "Doom", + }, + ] + se = AdvancedRetriever(schema) + with pytest.raises(Exception, match="Field genre"): + se.check_collection(collection, schema) + + +def test_check_collection(collection, schema): + se = AdvancedRetriever(schema) + assert se.check_collection(collection, schema) == True + + +def initialize_metadata(schema): + se = AdvancedRetriever(schema) + metadata = se.initialize_metadata(schema) + + assert metadata == { + "year": [], + "ozzy": {True: [], False: []}, + "album": defaultdict(list), + "genre": defaultdict(list), + } + + +def test_fill_metadata(schema, collection): + se = AdvancedRetriever(schema) + se.metadata = se.initialize_metadata(schema) + + metadata = se.fill_metadata( + metadata=se.metadata, collection=collection, schema=schema + ) + + assert metadata == { + "year": [1969, 1970, 1980], + "ozzy": {True: [0, 1], False: [2]}, + "album": { + "Black Sabbath": [0], + "Paranoid": [1], + "Heaven and Hell": [2], + }, + "genre": { + "Doom": [0, 1], + "Heavy Metal": [0, 1, 2], + }, + } + + +def test_index_metadata(collection, schema): + se = AdvancedRetriever(schema) + assert se.check_collection(collection, schema) == True + + metadata = se.index_metadata(collection=collection, schema=schema) + + assert len(metadata) == 4 + assert "year" in list(metadata) + assert "ozzy" in list(metadata) + assert "album" in list(metadata) + assert "genre" in list(metadata) + + assert metadata["year"].tolist() == [1969, 1970, 1980] + + assert len(metadata["ozzy"]) == 2 + assert metadata["ozzy"][True].tolist() == [0, 1] + assert metadata["ozzy"][False].tolist() == 
[2] + + assert len(metadata["album"]) == 3 + assert metadata["album"]["Black Sabbath"].tolist() == [0] + assert metadata["album"]["Paranoid"].tolist() == [1] + assert metadata["album"]["Heaven and Hell"].tolist() == [2] + + assert len(metadata["genre"]) == 2 + assert metadata["genre"]["Doom"].tolist() == [0, 1] + assert metadata["genre"]["Heavy Metal"].tolist() == [0, 1, 2] + + +def test_index(collection, schema): + se = AdvancedRetriever(schema).index(collection) + + assert len(se.doc_ids) == 3 + + +def test_filter_doc_ids_bool_must(collection, schema): + se = AdvancedRetriever(schema).index(collection) + field, clause, value = "ozzy", "must", True + assert se.filter_doc_ids(field, clause, value).tolist() == [0, 1] + field, clause, value = "ozzy", "must", False + assert se.filter_doc_ids(field, clause, value).tolist() == [2] + + +def test_filter_doc_ids_bool_must_not(collection, schema): + se = AdvancedRetriever(schema).index(collection) + field, clause, value = "ozzy", "must not", True + assert se.filter_doc_ids(field, clause, value).tolist() == [2] + field, clause, value = "ozzy", "must not", False + assert se.filter_doc_ids(field, clause, value).tolist() == [0, 1] + + +def test_filter_doc_ids_keyword_must(collection, schema): + se = AdvancedRetriever(schema).index(collection) + field, clause, value = "album", "must", "Black Sabbath" + assert se.filter_doc_ids(field, clause, value).tolist() == [0] + field, clause, value = "album", "must", "Paranoid" + assert se.filter_doc_ids(field, clause, value).tolist() == [1] + field, clause, value = "album", "must", "Heaven and Hell" + assert se.filter_doc_ids(field, clause, value).tolist() == [2] + + +def test_filter_doc_ids_keyword_must_multi(collection, schema): + se = AdvancedRetriever(schema).index(collection) + field, clause, value = "album", "must", ["Black Sabbath", "Heaven and Hell"] + assert se.filter_doc_ids(field, clause, value).tolist() == [0, 2] + + +def test_filter_doc_ids_keyword_must_not(collection, 
schema): + se = AdvancedRetriever(schema).index(collection) + field, clause, value = "album", "must not", "Black Sabbath" + assert se.filter_doc_ids(field, clause, value).tolist() == [1, 2] + field, clause, value = "album", "must not", "Heaven and Hell" + assert se.filter_doc_ids(field, clause, value).tolist() == [0, 1] + + +def test_filter_doc_ids_keyword_must_not_multi(collection, schema): + se = AdvancedRetriever(schema).index(collection) + field, clause, value = "album", "must not", ["Black Sabbath", "Heaven and Hell"] + assert se.filter_doc_ids(field, clause, value).tolist() == [1] + field, clause, value = "album", "must not", ["Black Sabbath", "Paranoid"] + assert se.filter_doc_ids(field, clause, value).tolist() == [2] + + +def test_filter_doc_ids_number_must(collection, schema): + se = AdvancedRetriever(schema).index(collection) + + field, clause, value, operator = "year", "must", 1969, "eq" + assert se.filter_doc_ids(field, clause, value, operator).tolist() == [0] + + field, clause, value, operator = "year", "must", 1969, "gt" + assert se.filter_doc_ids(field, clause, value, operator).tolist() == [1, 2] + + field, clause, value, operator = "year", "must", 1969, "gte" + assert se.filter_doc_ids(field, clause, value, operator).tolist() == [0, 1, 2] + + field, clause, value, operator = "year", "must", 1970, "lt" + assert se.filter_doc_ids(field, clause, value, operator).tolist() == [0] + + field, clause, value, operator = "year", "must", 1970, "lte" + assert se.filter_doc_ids(field, clause, value, operator).tolist() == [0, 1] + + field, clause, value, operator = "year", "must", [1970, 1980], "between" + assert se.filter_doc_ids(field, clause, value, operator).tolist() == [1, 2] + + +def test_filter_doc_ids_number_must_not(collection, schema): + se = AdvancedRetriever(schema).index(collection) + + field, clause, value, operator = "year", "must not", 1969, "eq" + assert se.filter_doc_ids(field, clause, value, operator).tolist() == [1, 2] + + field, clause, 
value, operator = "year", "must not", 1969, "gt" + assert se.filter_doc_ids(field, clause, value, operator).tolist() == [0] + + field, clause, value, operator = "year", "must not", 1969, "gte" + assert se.filter_doc_ids(field, clause, value, operator).tolist() == [] + + field, clause, value, operator = "year", "must not", 1970, "lt" + assert se.filter_doc_ids(field, clause, value, operator).tolist() == [1, 2] + + field, clause, value, operator = "year", "must not", 1970, "lte" + assert se.filter_doc_ids(field, clause, value, operator).tolist() == [2] + + field, clause, value, operator = "year", "must not", [1970, 1975], "between" + assert se.filter_doc_ids(field, clause, value, operator).tolist() == [0, 2] + + +def test_filter_doc_ids_keywords_must_or(collection, schema): + se = AdvancedRetriever(schema).index(collection) + field, clause, value, operator = "genre", "must", "Doom", "or" + assert se.filter_doc_ids(field, clause, value, operator).tolist() == [0, 1] + + +def test_filter_doc_ids_keywords_must_multi_or(collection, schema): + se = AdvancedRetriever(schema).index(collection) + field, clause, value, operator = "genre", "must", ["Doom", "Heavy Metal"], "or" + assert se.filter_doc_ids(field, clause, value, operator).tolist() == [0, 1, 2] + + +def test_filter_doc_ids_keywords_must_not_or(collection, schema): + se = AdvancedRetriever(schema).index(collection) + field, clause, value, operator = "genre", "must not", "Doom", "or" + assert se.filter_doc_ids(field, clause, value, operator).tolist() == [2] + + +def test_filter_doc_ids_keywords_must_not_multi_or(collection, schema): + se = AdvancedRetriever(schema).index(collection) + field, clause, value, operator = "genre", "must not", ["Doom"], "or" + assert se.filter_doc_ids(field, clause, value, operator).tolist() == [2] + field, clause, value, operator = "genre", "must not", ["Doom", "Heavy Metal"], "or" + assert se.filter_doc_ids(field, clause, value, operator).tolist() == [] + + +def 
test_filter_doc_ids_keywords_must_and(collection, schema): + se = AdvancedRetriever(schema).index(collection) + field, clause, value, operator = "genre", "must", "Doom", "and" + assert se.filter_doc_ids(field, clause, value, operator).tolist() == [0, 1] + + +def test_filter_doc_ids_keywords_must_multi_and(collection, schema): + se = AdvancedRetriever(schema).index(collection) + field, clause, value, operator = "genre", "must", ["Doom", "Heavy Metal"], "and" + assert se.filter_doc_ids(field, clause, value, operator).tolist() == [0, 1] + + +def test_filter_doc_ids_keywords_must_not_and(collection, schema): + se = AdvancedRetriever(schema).index(collection) + field, clause, value, operator = "genre", "must not", "Doom", "and" + assert se.filter_doc_ids(field, clause, value, operator).tolist() == [2] + + +def test_filter_doc_ids_keywords_must_not_multi_and(collection, schema): + se = AdvancedRetriever(schema).index(collection) + field, clause, value, operator = "genre", "must not", ["Doom"], "and" + assert se.filter_doc_ids(field, clause, value, operator).tolist() == [2] + field, clause, value, operator = "genre", "must not", ["Doom", "Heavy Metal"], "and" + assert se.filter_doc_ids(field, clause, value, operator).tolist() == [2] + + +def test_get_filtered_doc_ids(collection, schema): + se = AdvancedRetriever(schema).index(collection) + + filters = [ + dict(field="genre", clause="must", value=["Doom", "Heavy Metal"], operator="or") + ] + assert se.get_filtered_doc_ids(filters).tolist() == [0, 1, 2] + + filters = [dict(field="year", clause="must not", value=1970, operator="lt")] + assert se.get_filtered_doc_ids(filters).tolist() == [1, 2] + + filters = [dict(field="ozzy", clause="must", value=True)] + assert se.get_filtered_doc_ids(filters).tolist() == [0, 1] + + filters = [ + dict( + field="genre", clause="must", value=["Doom", "Heavy Metal"], operator="or" + ), # 0, 1, 2 + dict(field="year", clause="must not", value=1970, operator="lt"), # 1, 2 + dict(field="ozzy", 
clause="must", value=True), # 0, 1 + ] + assert se.get_filtered_doc_ids(filters).tolist() == [1] + + +def test_format_filters(collection, schema): + se = AdvancedRetriever(schema).index(collection) + + filters = { + "year": ("gte", 1980), + "ozzy": True, + "album": ["Paranoid", "Master of Reality"], + "genre": ("or", ["Doom", "Heavy Metal"]), + } + + formatted_filters = se.format_filters(filters) + assert len(formatted_filters) == 4 + assert formatted_filters[0] == dict( + field="year", clause="must", value=1980, operator="gte" + ) + assert formatted_filters[1] == dict(field="ozzy", clause="must", value=True) + assert formatted_filters[2] == dict( + field="album", clause="must", value=["Paranoid", "Master of Reality"] + ) + assert formatted_filters[3] == dict( + field="genre", clause="must", value=["Doom", "Heavy Metal"], operator="or" + ) + + formatted_filters = se.format_filters(filters, clause="must not") + assert len(formatted_filters) == 4 + assert formatted_filters[0] == dict( + field="year", clause="must not", value=1980, operator="gte" + ) + assert formatted_filters[1] == dict(field="ozzy", clause="must not", value=True) + assert formatted_filters[2] == dict( + field="album", clause="must not", value=["Paranoid", "Master of Reality"] + ) + assert formatted_filters[3] == dict( + field="genre", clause="must not", value=["Doom", "Heavy Metal"], operator="or" + ) + + formatted_filters = se.format_filters({}) + assert formatted_filters == [] + + +def test_search_filters_only(collection, schema): + se = AdvancedRetriever(schema).index(collection) + + query = { + "year": ("gte", 1970), + "ozzy": True, + "album": ["Paranoid", "Heaven and Hell"], + "genre": ("or", ["Doom", "Heavy Metal"]), + } + + res = se.search(query=query, return_docs=False) + + assert len(res) == 1 + assert res["doc_1"] == 1.0 + + query = { + "where": { + "year": ("gt", 1969), + "album": ["Paranoid", "Heaven and Hell"], + "genre": ("or", ["Doom", "Heavy Metal"]), + } + } + + res = 
se.search(query=query, return_docs=False) + + assert len(res) == 2 + assert res["doc_1"] == 1.0 + assert res["doc_2"] == 1.0 + + query = { + "where_not": { + "year": ("gt", 1969), + "ozzy": False, + "album": ["Paranoid", "Heaven and Hell"], + } + } + + res = se.search(query=query, return_docs=False) + + assert len(res) == 1 + assert res["doc_0"] == 1.0 + + +def test_search_or(collection, schema): + se = AdvancedRetriever(schema).index(collection) + res = se.search(query="witches masses", return_docs=False) + assert len(res) == 2 + assert "doc_0" in res + assert "doc_1" in res + + +def test_search_and(collection, schema): + se = AdvancedRetriever(schema).index(collection) + res = se.search(query="witches masses", return_docs=False, operator="AND") + assert len(res) == 1 + assert "doc_1" in res + + +def test_advanced_search(collection, schema): + se = AdvancedRetriever(schema).index(collection) + + query = { + "text": "witches masses", + "year": ("gte", 1970), + "ozzy": True, + "album": ["Paranoid", "Heaven and Hell"], + "genre": ("or", ["Doom", "Heavy Metal"]), + } + + res = se.search(query=query, return_docs=False) + assert len(res) == 1 + assert "doc_1" in res + + +def test_search_with_subset_doc_ids(collection, schema): + se = AdvancedRetriever(schema).index(collection) + + res = se.search( + query="witches masses", subset_doc_ids=["doc_1", "doc_2"], return_docs=False + ) + assert len(res) == 1 + assert "doc_1" in res + + +def test_index_file(schema): + se = AdvancedRetriever(schema).index_file( + "tests/test_data/multifield_collection.jsonl" + ) + + query = { + "text": "witches masses", + "year": ("gte", 1970), + "ozzy": True, + "album": ["Paranoid", "Heaven and Hell"], + "genre": ("or", ["Doom", "Heavy Metal"]), + } + + res = se.search(query=query, return_docs=False) + assert len(res) == 1 + assert "doc_1" in res + + +def test_load(schema): + se = AdvancedRetriever.load("new-index") + assert se.schema == schema diff --git 
a/tests/sparse_retriever/numba_utils_test.py b/tests/numba_utils_test.py similarity index 92% rename from tests/sparse_retriever/numba_utils_test.py rename to tests/numba_utils_test.py index db8c3a3..462f311 100644 --- a/tests/sparse_retriever/numba_utils_test.py +++ b/tests/numba_utils_test.py @@ -71,6 +71,13 @@ def test_diff_sorted(): assert np.array_equal(result, expected) + a1 = np.array([1, 3, 4, 7, 11], dtype=np.int32) + a2 = np.array([1, 4, 7, 9], dtype=np.int32) + result = diff_sorted(a1, a2) + expected = np.array([3, 11], dtype=np.int32) + + assert np.array_equal(result, expected) + def test_concat1d(): a1 = np.array([1, 3, 4, 7], dtype=np.int32)