diff --git a/docs/src/embeddings/index.md b/docs/src/embeddings/index.md index 0752cabeae..6faee93fa7 100644 --- a/docs/src/embeddings/index.md +++ b/docs/src/embeddings/index.md @@ -68,6 +68,39 @@ table.add( ] ) +query = "greetings" +actual = table.search(query).limit(1).to_pydantic(Words)[0] +print(actual.text) +``` + +### Jina Embeddings +LanceDB registers the JinaAI embeddings function in the registry as `jina`. You can pass any supported model name to `create`. By default, it uses `"jina-clip-v1"`. +`jina-clip-v1` can handle both text and images; the other models only support text. + +You need to set the `JINA_API_KEY` environment variable or pass the key as `api_key` to the `create` method. + +```python +import os +import lancedb +from lancedb.pydantic import LanceModel, Vector +from lancedb.embeddings import get_registry +os.environ['JINA_API_KEY'] = "jina_*" + +db = lancedb.connect("/tmp/db") +func = get_registry().get("jina").create(name="jina-clip-v1") + +class Words(LanceModel): + text: str = func.SourceField() + vector: Vector(func.ndims()) = func.VectorField() + +table = db.create_table("words", schema=Words, mode="overwrite") +table.add( + [ + {"text": "hello world"}, + {"text": "goodbye world"} + ] + ) + query = "greetings" actual = table.search(query).limit(1).to_pydantic(Words)[0] print(actual.text) diff --git a/docs/src/reranking/jina.md b/docs/src/reranking/jina.md new file mode 100644 index 0000000000..085a379a83 --- /dev/null +++ b/docs/src/reranking/jina.md @@ -0,0 +1,78 @@ +# Jina Reranker + +This re-ranker uses the [Jina](https://jina.ai/reranker/) API to rerank the search results. You can use this re-ranker by passing `JinaReranker()` to the `rerank()` method. Note that you'll either need to set the `JINA_API_KEY` environment variable or pass the `api_key` argument to use this re-ranker. + + +!!! 
note + Supported Query Types: Hybrid, Vector, FTS + + +```python +import os +import lancedb +from lancedb.embeddings import get_registry +from lancedb.pydantic import LanceModel, Vector +from lancedb.rerankers import JinaReranker + +os.environ['JINA_API_KEY'] = "jina_*" + + +embedder = get_registry().get("jina").create() +db = lancedb.connect("~/.lancedb") + +class Schema(LanceModel): + text: str = embedder.SourceField() + vector: Vector(embedder.ndims()) = embedder.VectorField() + +data = [ + {"text": "hello world"}, + {"text": "goodbye world"} + ] +tbl = db.create_table("test", schema=Schema, mode="overwrite") +tbl.add(data) +reranker = JinaReranker(api_key="key") + +# Run vector search with a reranker +result = tbl.search("hello").rerank(reranker=reranker).to_list() + +# Run FTS search with a reranker +result = tbl.search("hello", query_type="fts").rerank(reranker=reranker).to_list() + +# Run hybrid search with a reranker +tbl.create_fts_index("text", replace=True) +result = tbl.search("hello", query_type="hybrid").rerank(reranker=reranker).to_list() + +``` + +Accepted Arguments +---------------- +| Argument | Type | Default | Description | +| --- | --- | --- | --- | +| `model_name` | `str` | `"jina-reranker-v2-base-multilingual"` | The name of the reranker model to use. You can find the list of available models at https://jina.ai/reranker/ | +| `column` | `str` | `"text"` | The name of the column to use as input to the cross encoder model. | +| `top_n` | `int` | `None` | The number of results to return. If None, will return all results. | +| `api_key` | `str` | `None` | The API key for the Jina API. If not provided, the `JINA_API_KEY` environment variable is used. | +| `return_score` | `str` | `"relevance"` | Options are "relevance" or "all". The type of score to return. If "relevance", will return only the `_relevance_score`. 
If "all" is supported, will return the relevance score along with the vector and/or FTS scores, depending on the query type. | + + + +## Supported Scores for each query type +You can specify the type of scores you want the reranker to return. The following are the supported scores for each query type: + +### Hybrid Search +|`return_score`| Status | Description | +| --- | --- | --- | +| `relevance` | ✅ Supported | Results only have the `_relevance_score` column | +| `all` | ❌ Not Supported | Results have the vector (`_distance`) and FTS (`score`) scores along with the reranker relevance score (`_relevance_score`) | + +### Vector Search +|`return_score`| Status | Description | +| --- | --- | --- | +| `relevance` | ✅ Supported | Results only have the `_relevance_score` column | +| `all` | ✅ Supported | Results have the vector score (`_distance`) along with the reranker relevance score (`_relevance_score`) | + +### FTS Search +|`return_score`| Status | Description | +| --- | --- | --- | +| `relevance` | ✅ Supported | Results only have the `_relevance_score` column | +| `all` | ✅ Supported | Results have the FTS score (`score`) along with the reranker relevance score (`_relevance_score`) | \ No newline at end of file diff --git a/python/python/lancedb/embeddings/__init__.py b/python/python/lancedb/embeddings/__init__.py index 9567d07163..3bb6d56d92 100644 --- a/python/python/lancedb/embeddings/__init__.py +++ b/python/python/lancedb/embeddings/__init__.py @@ -25,3 +25,4 @@ from .transformers import TransformersEmbeddingFunction, ColbertEmbeddings from .imagebind import ImageBindEmbeddings from .utils import with_embeddings +from .jinaai import JinaEmbeddings diff --git a/python/python/lancedb/embeddings/jinaai.py b/python/python/lancedb/embeddings/jinaai.py new file mode 100644 index 0000000000..8f6d23698d --- /dev/null +++ b/python/python/lancedb/embeddings/jinaai.py @@ -0,0 +1,172 @@ +# Copyright (c) 2023. 
# LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import io
import requests
import base64
import urllib.parse as urlparse
from typing import ClassVar, List, Union, Optional, TYPE_CHECKING

import numpy as np
import pyarrow as pa

from ..util import attempt_import_or_raise
from .base import EmbeddingFunction
from .registry import register
from .utils import api_key_not_found_help, TEXT, IMAGES, url_retrieve

if TYPE_CHECKING:
    import PIL

API_URL = "https://api.jina.ai/v1/embeddings"


def _b64_image_payload(image_bytes: bytes) -> dict:
    """Wrap raw image bytes in the ``{"image": <base64 str>}`` form sent to the API."""
    return {"image": base64.b64encode(image_bytes).decode("utf-8")}


def _pil_to_png_bytes(pil_image) -> bytes:
    """Serialise a PIL image to PNG bytes using an in-memory buffer."""
    buffered = io.BytesIO()
    pil_image.save(buffered, format="PNG")
    return buffered.getvalue()


@register("jina")
class JinaEmbeddings(EmbeddingFunction):
    """
    An embedding function that uses the Jina API

    https://jina.ai/embeddings/

    Parameters
    ----------
    name: str, default "jina-clip-v1". Note that some models support both image
        and text embeddings and some just text embedding

    api_key: str, default None
        The api key to access Jina API. If you pass None, you can set JINA_API_KEY
        environment variable

    """

    name: str = "jina-clip-v1"
    api_key: Optional[str] = None
    # One HTTP session shared by every instance; created lazily by _init_client().
    _session: ClassVar = None

    def ndims(self):
        """Return the dimensionality of the produced vectors (currently fixed)."""
        # TODO: fix hardcoding
        return 768

    def sanitize_input(self, inputs: IMAGES) -> Union[List[bytes], np.ndarray]:
        """
        Sanitize the input to the embedding function.

        A single ``str``/``bytes`` value is wrapped in a list and pyarrow
        ``Array``/``ChunkedArray`` inputs are converted to Python lists.
        Anything else is returned unchanged.
        """
        if isinstance(inputs, (str, bytes)):
            inputs = [inputs]
        elif isinstance(inputs, pa.Array):
            inputs = inputs.to_pylist()
        elif isinstance(inputs, pa.ChunkedArray):
            inputs = inputs.combine_chunks().to_pylist()
        return inputs

    def compute_query_embeddings(
        self, query: Union[str, "PIL.Image.Image"], *args, **kwargs
    ) -> List[np.ndarray]:
        """
        Compute the embeddings for a given user query

        Parameters
        ----------
        query : Union[str, PIL.Image.Image]
            The query to embed. A query can be either text or an image.

        Raises
        ------
        TypeError
            If the query is neither a ``str`` nor a PIL image.
        """
        if isinstance(query, str):
            return self.generate_text_embeddings([query])
        else:
            PIL = attempt_import_or_raise("PIL", "pillow")
            if isinstance(query, PIL.Image.Image):
                return [self.generate_image_embedding(query)]
            else:
                raise TypeError(
                    "JinaEmbeddingFunction supports str or PIL Image as query"
                )

    def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
        """Sanitize ``texts`` and return one embedding per input text."""
        texts = self.sanitize_input(texts)
        return self.generate_text_embeddings(texts)

    def generate_image_embedding(
        self, image: Union[str, bytes, "PIL.Image.Image"]
    ) -> np.ndarray:
        """
        Generate the embedding for a single image

        Parameters
        ----------
        image : Union[str, bytes, PIL.Image.Image]
            The image to embed. If the image is a str, it is treated as a uri.
            If the image is bytes, it is treated as the raw image bytes.

        Raises
        ------
        NotImplementedError
            If a str uri uses a scheme other than file, http(s) or no scheme.
        """
        PIL = attempt_import_or_raise("PIL", "pillow")
        if isinstance(image, bytes):
            payload = _b64_image_payload(image)
        elif isinstance(image, PIL.Image.Image):
            payload = _b64_image_payload(_pil_to_png_bytes(image))
        elif isinstance(image, str):
            parsed = urlparse.urlparse(image)
            # TODO handle drive letter on windows.
            if parsed.scheme == "file":
                source = parsed.path
            elif parsed.scheme == "":
                source = image if os.name == "nt" else parsed.path
            elif parsed.scheme.startswith("http"):
                source = io.BytesIO(url_retrieve(image))
            else:
                raise NotImplementedError("Only local and http(s) urls are supported")
            # Context manager closes the underlying file handle once the image
            # has been re-encoded, instead of leaking it.
            with PIL.Image.open(source) as pil_image:
                payload = _b64_image_payload(_pil_to_png_bytes(pil_image))
        else:
            # Any other type is forwarded untouched, matching previous behaviour.
            payload = image
        return self._generate_embeddings(input=[payload])[0]

    def generate_text_embeddings(
        self, texts: Union[List[str], np.ndarray], *args, **kwargs
    ) -> List[np.array]:
        """Return one embedding per entry of ``texts``."""
        return self._generate_embeddings(input=texts)

    def _generate_embeddings(self, input: List, *args, **kwargs) -> List[np.array]:
        """
        Get the embeddings for the given inputs

        Parameters
        ----------
        input: list or np.ndarray
            The items to embed: plain strings for text, or ``{"image": ...}``
            dicts as built by ``generate_image_embedding``.

        Returns
        -------
        list
            The ``embedding`` entries of the response, ordered by their
            ``index`` field. An empty input yields an empty list without
            contacting the API.

        Raises
        ------
        RuntimeError
            If the response has no ``data`` field.
        """
        self._init_client()
        # numpy arrays are not JSON serialisable; send a plain list instead.
        if isinstance(input, np.ndarray):
            input = input.tolist()
        if len(input) == 0:
            return []

        # The session is shared class-wide and carries the key of whichever
        # instance created it. Send this instance's key with the request so
        # instances configured with different keys do not use each other's
        # credentials. With no key of our own, fall back to the session header.
        api_key = self.api_key or os.environ.get("JINA_API_KEY")
        headers = {"Authorization": f"Bearer {api_key}"} if api_key else None
        resp = JinaEmbeddings._session.post(  # type: ignore
            API_URL, json={"input": input, "model": self.name}, headers=headers
        ).json()
        if "data" not in resp:
            # Error payloads are expected to carry "detail"; fall back to the
            # whole payload rather than masking the failure with a KeyError.
            detail = resp.get("detail", resp) if isinstance(resp, dict) else resp
            raise RuntimeError(detail)

        # Sort resulting embeddings by index
        ordered = sorted(resp["data"], key=lambda e: e["index"])  # type: ignore
        return [item["embedding"] for item in ordered]

    def _init_client(self):
        """Create the shared HTTP session on first use.

        The key comes from ``self.api_key`` or, failing that, the
        ``JINA_API_KEY`` environment variable. When neither is set,
        ``api_key_not_found_help("jina")`` is invoked (presumably it reports
        the problem and aborts - confirm against ``.utils``).
        """
        if JinaEmbeddings._session is None:
            if self.api_key is None and os.environ.get("JINA_API_KEY") is None:
                api_key_not_found_help("jina")
            api_key = self.api_key or os.environ.get("JINA_API_KEY")
            JinaEmbeddings._session = requests.Session()
            JinaEmbeddings._session.headers.update(
                {"Authorization": f"Bearer {api_key}", "Accept-Encoding": "identity"}
            )


# Patch metadata that shared the original line, preserved verbatim:
# diff --git a/python/python/lancedb/rerankers/__init__.py b/python/python/lancedb/rerankers/__init__.py index af833fd7cb..c7e34d578a 100644 --- a/python/python/lancedb/rerankers/__init__.py +++
# Patch metadata and the rerankers/__init__.py hunk that shared the original
# line, preserved verbatim:
# b/python/python/lancedb/rerankers/__init__.py @@ -4,6 +4,7 @@
#  from .cross_encoder import CrossEncoderReranker
#  from .linear_combination import LinearCombinationReranker
#  from .openai import OpenaiReranker
# +from .jinaai import JinaReranker
#  __all__ = [ "Reranker", @@ -12,4 +13,5 @@ "LinearCombinationReranker", "OpenaiReranker", "ColbertReranker",
# + "JinaReranker",
#  ]
# diff --git a/python/python/lancedb/rerankers/jinaai.py b/python/python/lancedb/rerankers/jinaai.py new file mode 100644 index 0000000000..3d3f13c0ef --- /dev/null +++ b/python/python/lancedb/rerankers/jinaai.py @@ -0,0 +1,122 @@

import os
import requests
from functools import cached_property
from typing import Union

import pyarrow as pa

from .base import Reranker

API_URL = "https://api.jina.ai/v1/rerank"


class JinaReranker(Reranker):
    """
    Reranks the results using the Jina Rerank API.
    https://jina.ai/rerank

    Parameters
    ----------
    model_name : str, default "jina-reranker-v2-base-multilingual"
        The name of the cross reranker model to use
    column : str, default "text"
        The name of the column to use as input to the cross encoder model.
    top_n : int, default None
        The number of results to return. If None, will return all results.
    return_score : str, default "relevance"
        Passed to the base ``Reranker``. With "relevance" only the
        ``_relevance_score`` column is kept; "all" is not implemented for
        hybrid search.
    api_key : str, default None
        The api key to access Jina API. If you pass None, you can set JINA_API_KEY
        environment variable
    """

    def __init__(
        self,
        model_name: str = "jina-reranker-v2-base-multilingual",
        column: str = "text",
        top_n: Union[int, None] = None,
        return_score="relevance",
        api_key: Union[str, None] = None,
    ):
        super().__init__(return_score)
        self.model_name = model_name
        self.column = column
        self.top_n = top_n
        self.api_key = api_key

    @cached_property
    def _client(self):
        """Lazily build the authenticated ``requests.Session`` (cached).

        Raises
        ------
        ValueError
            If no key was passed and ``JINA_API_KEY`` is not set.
        """
        api_key = self.api_key or os.environ.get("JINA_API_KEY")
        if not api_key:
            # Implicit string concatenation keeps the message free of the
            # indentation whitespace a backslash continuation would embed.
            raise ValueError(
                "JINA_API_KEY not set. Either set it in your environment or "
                "pass it as `api_key` argument to the JinaReranker."
            )
        self.api_key = api_key
        self._session = requests.Session()
        self._session.headers.update(
            {"Authorization": f"Bearer {self.api_key}", "Accept-Encoding": "identity"}
        )
        return self._session

    def _rerank(self, result_set: pa.Table, query: str):
        """Score ``result_set`` rows against ``query`` via the rerank API.

        Returns the rows selected by the API, in the order the API returned
        them, with a float32 ``_relevance_score`` column appended.

        Raises
        ------
        RuntimeError
            If the response has no ``results`` field.
        """
        docs = result_set[self.column].to_pylist()
        if not docs:
            # Nothing to rank: skip the API round-trip and just add the
            # (empty) score column so the schema matches the normal path.
            return result_set.append_column(
                "_relevance_score", pa.array([], type=pa.float32())
            )

        response = self._client.post(  # type: ignore
            API_URL,
            json={
                "query": query,
                "documents": docs,
                "model": self.model_name,
                "top_n": self.top_n,
            },
        ).json()
        if "results" not in response:
            # Error payloads are expected to carry "detail"; fall back to the
            # whole payload rather than masking the failure with a KeyError.
            detail = (
                response.get("detail", response)
                if isinstance(response, dict)
                else response
            )
            raise RuntimeError(detail)

        results = response["results"]
        # Two comprehensions instead of zip(*...) so an empty `results` list
        # does not fail on tuple unpacking.
        indices = [result["index"] for result in results]
        scores = [result["relevance_score"] for result in results]

        # Explicit int64 type keeps `take` valid even when `indices` is empty.
        result_set = result_set.take(pa.array(indices, type=pa.int64()))
        # add the scores
        result_set = result_set.append_column(
            "_relevance_score", pa.array(scores, type=pa.float32())
        )

        return result_set

    def rerank_hybrid(
        self,
        query: str,
        vector_results: pa.Table,
        fts_results: pa.Table,
    ):
        """Merge vector and FTS results, then rerank the combined set.

        Raises
        ------
        NotImplementedError
            If the reranker was created with ``return_score="all"``.
        """
        if self.score == "all":
            # Fail before merging and calling the remote API, so an
            # unsupported configuration does not cost a request.
            raise NotImplementedError(
                "return_score='all' not implemented for JinaReranker"
            )
        combined_results = self.merge_results(vector_results, fts_results)
        combined_results = self._rerank(combined_results, query)
        if self.score == "relevance":
            combined_results = combined_results.drop_columns(["score", "_distance"])
        return combined_results

    def rerank_vector(
        self,
        query: str,
        vector_results: pa.Table,
    ):
        """Rerank vector-search results; drops ``_distance`` for "relevance"."""
        result_set = self._rerank(vector_results, query)
        if self.score == "relevance":
            result_set = result_set.drop_columns(["_distance"])

        return result_set

    def rerank_fts(
        self,
        query: str,
        fts_results: pa.Table,
    ):
        """Rerank full-text-search results; drops ``score`` for "relevance"."""
        result_set = self._rerank(fts_results, query)
        if self.score == "relevance":
            result_set = result_set.drop_columns(["score"])

        return result_set