Skip to content

Commit

Permalink
feat: add Jina integration in Python for Embedding and Reranker (lanc…
Browse files Browse the repository at this point in the history
…edb#1424)

Integration of Jina Embeddings and Rerankers through its API
  • Loading branch information
JoanFM authored Jul 4, 2024
1 parent a5ff623 commit 08d25c5
Show file tree
Hide file tree
Showing 6 changed files with 408 additions and 0 deletions.
33 changes: 33 additions & 0 deletions docs/src/embeddings/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,39 @@ table.add(
]
)

query = "greetings"
actual = table.search(query).limit(1).to_pydantic(Words)[0]
print(actual.text)
```

### Jina Embeddings
LanceDB registers the JinaAI embedding function in the registry as `jina`. You can pass any supported model name to `create`; by default it uses `"jina-clip-v1"`.
`jina-clip-v1` can handle both text and images, while the other models support text only.

You need to either set the `JINA_API_KEY` environment variable or pass the key as `api_key` to the `create` method.

```python
import os
import lancedb
from lancedb.pydantic import LanceModel, Vector
from lancedb.embeddings import get_registry

# The key can alternatively be passed as `api_key` to `create`.
os.environ['JINA_API_KEY'] = "jina_*"

db = lancedb.connect("/tmp/db")
func = get_registry().get("jina").create(name="jina-clip-v1")

class Words(LanceModel):
    # Source column and the vector column sized from the embedding function.
    text: str = func.SourceField()
    vector: Vector(func.ndims()) = func.VectorField()

table = db.create_table("words", schema=Words, mode="overwrite")
table.add(
    [
        {"text": "hello world"},
        {"text": "goodbye world"}
    ]
)

# Search with a plain-text query and print the closest row's text.
query = "greetings"
actual = table.search(query).limit(1).to_pydantic(Words)[0]
print(actual.text)
Expand Down
78 changes: 78 additions & 0 deletions docs/src/reranking/jina.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Jina Reranker

This re-ranker uses the [Jina](https://jina.ai/reranker/) API to rerank the search results. You can use this re-ranker by passing `JinaReranker()` to the `rerank()` method. Note that you'll either need to set the `JINA_API_KEY` environment variable or pass the `api_key` argument to use this re-ranker.


!!! note
    Supported Query Types: Hybrid, Vector, FTS


```python
import os
import lancedb
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector
from lancedb.rerankers import JinaReranker

os.environ['JINA_API_KEY'] = "jina_*"


embedder = get_registry().get("jina").create()
db = lancedb.connect("~/.lancedb")

class Schema(LanceModel):
    text: str = embedder.SourceField()
    vector: Vector(embedder.ndims()) = embedder.VectorField()

data = [
    {"text": "hello world"},
    {"text": "goodbye world"}
]
tbl = db.create_table("test", schema=Schema, mode="overwrite")
tbl.add(data)
reranker = JinaReranker(api_key="key")

# Run vector search with a reranker
result = tbl.search("hello").rerank(reranker=reranker).to_list()

# FTS and hybrid queries need a full-text index on the searched column,
# so build it before the first FTS query.
tbl.create_fts_index("text", replace=True)

# Run FTS search with a reranker
result = tbl.search("hello", query_type="fts").rerank(reranker=reranker).to_list()

# Run hybrid search with a reranker
result = tbl.search("hello", query_type="hybrid").rerank(reranker=reranker).to_list()

```

Accepted Arguments
----------------
| Argument | Type | Default | Description |
| --- | --- | --- | --- |
| `model_name` | `str` | `"jina-reranker-v2-base-multilingual"` | The name of the reranker model to use. You can find the list of available models in https://jina.ai/reranker/|
| `column` | `str` | `"text"` | The name of the column to use as input to the cross encoder model. |
| `top_n` | `int` | `None` | The number of results to return. If None, will return all results. |
| `api_key` | `str` | `None` | The API key for the Jina API. If not provided, the `JINA_API_KEY` environment variable is used. |
| `return_score` | `str` | `"relevance"` | The type of score to return. Options are "relevance" or "all". If "relevance", only the `_relevance_score` is returned. If "all" (where supported), the relevance score is returned along with the vector and/or FTS scores, depending on the query type. |



## Supported Scores for each query type
You can specify the type of scores you want the reranker to return. The following are the supported scores for each query type:

### Hybrid Search
|`return_score`| Status | Description |
| --- | --- | --- |
| `relevance` | ✅ Supported | Results only have the `_relevance_score` column |
| `all` | ❌ Not Supported | Results have the vector (`_distance`) and FTS (`score`) scores along with the Hybrid Search score (`_relevance_score`) |

### Vector Search
|`return_score`| Status | Description |
| --- | --- | --- |
| `relevance` | ✅ Supported | Results only have the `_relevance_score` column |
| `all` | ✅ Supported | Results have the vector (`_distance`) score along with the relevance score (`_relevance_score`) |

### FTS Search
|`return_score`| Status | Description |
| --- | --- | --- |
| `relevance` | ✅ Supported | Results only have the `_relevance_score` column |
| `all` | ✅ Supported | Results have the FTS (`score`) score along with the relevance score (`_relevance_score`) |
1 change: 1 addition & 0 deletions python/python/lancedb/embeddings/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,4 @@
from .transformers import TransformersEmbeddingFunction, ColbertEmbeddings
from .imagebind import ImageBindEmbeddings
from .utils import with_embeddings
from .jinaai import JinaEmbeddings
172 changes: 172 additions & 0 deletions python/python/lancedb/embeddings/jinaai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import io
import requests
import base64
import urllib.parse as urlparse
from typing import ClassVar, List, Union, Optional, TYPE_CHECKING

import numpy as np
import pyarrow as pa

from ..util import attempt_import_or_raise
from .base import EmbeddingFunction
from .registry import register
from .utils import api_key_not_found_help, TEXT, IMAGES, url_retrieve

if TYPE_CHECKING:
    # Needed only for the "PIL.Image.Image" string annotations below; at
    # runtime Pillow is loaded on demand through attempt_import_or_raise.
    import PIL

# Endpoint every embedding request in this module is POSTed to.
API_URL = "https://api.jina.ai/v1/embeddings"


@register("jina")
class JinaEmbeddings(EmbeddingFunction):
    """
    An embedding function that uses the Jina API

    https://jina.ai/embeddings/

    Parameters
    ----------
    name: str, default "jina-clip-v1"
        The model to request. Note that some models support both image
        and text embeddings and some just text embedding
    api_key: str, default None
        The api key to access Jina API. If you pass None, you can set JINA_API_KEY
        environment variable
    """

    name: str = "jina-clip-v1"
    api_key: Optional[str] = None
    # HTTP session shared by all instances; created lazily by _init_client().
    # NOTE(review): the Authorization header is set once, by whichever instance
    # triggers initialisation first, so instances configured with different
    # api_key values end up sharing that first key -- confirm this is intended.
    _session: ClassVar = None

    def ndims(self):
        """Return the dimensionality of the embedding vectors."""
        # TODO: fix hardcoding
        return 768

    def sanitize_input(
        self, inputs: Union[TEXT, IMAGES]
    ) -> Union[List[Union[str, bytes]], np.ndarray]:
        """
        Sanitize the input to the embedding function.

        A single ``str``/``bytes`` value is wrapped in a list and pyarrow
        ``Array``/``ChunkedArray`` values are converted to Python lists.
        Any other input is returned unchanged.
        """
        if isinstance(inputs, (str, bytes)):
            return [inputs]
        if isinstance(inputs, pa.Array):
            return inputs.to_pylist()
        if isinstance(inputs, pa.ChunkedArray):
            return inputs.combine_chunks().to_pylist()
        return inputs

    def compute_query_embeddings(
        self, query: Union[str, "PIL.Image.Image"], *args, **kwargs
    ) -> List[np.ndarray]:
        """
        Compute the embeddings for a given user query

        Parameters
        ----------
        query : Union[str, PIL.Image.Image]
            The query to embed. A query can be either text or an image.

        Returns
        -------
        A one-element list holding the embedding of the query.

        Raises
        ------
        TypeError
            If the query is neither a str nor a PIL image.
        """
        if isinstance(query, str):
            return self.generate_text_embeddings([query])

        PIL = self._load_pil()
        if isinstance(query, PIL.Image.Image):
            return [self.generate_image_embedding(query)]
        raise TypeError("JinaEmbeddingFunction supports str or PIL Image as query")

    def compute_source_embeddings(
        self, texts: TEXT, *args, **kwargs
    ) -> List[np.ndarray]:
        """Embed the source texts, returning one embedding per input item."""
        texts = self.sanitize_input(texts)
        return self.generate_text_embeddings(texts)

    def generate_image_embedding(
        self, image: Union[str, bytes, "PIL.Image.Image"]
    ) -> np.ndarray:
        """
        Generate the embedding for a single image

        Parameters
        ----------
        image : Union[str, bytes, PIL.Image.Image]
            The image to embed. If the image is a str, it is treated as a uri.
            If the image is bytes, it is treated as the raw image bytes.

        Raises
        ------
        NotImplementedError
            If a str uri uses a scheme other than local file or http(s).
        """
        PIL = self._load_pil()
        if isinstance(image, bytes):
            # Raw bytes are forwarded as-is (base64 encoded), without re-encoding.
            payload = {"image": base64.b64encode(image).decode("utf-8")}
        elif isinstance(image, PIL.Image.Image):
            payload = self._png_payload(image)
        elif isinstance(image, str):
            payload = self._png_payload(self._open_image_uri(PIL, image))
        else:
            # Any other value is sent to the API unchanged, as before.
            payload = image
        return self._generate_embeddings(input=[payload])[0]

    def generate_text_embeddings(
        self, texts: Union[List[str], np.ndarray], *args, **kwargs
    ) -> List[np.ndarray]:
        """Embed a batch of texts, returning one embedding per text."""
        return self._generate_embeddings(input=texts)

    def _generate_embeddings(self, input: List, *args, **kwargs) -> List[np.ndarray]:
        """
        Get the embeddings for the given inputs

        Parameters
        ----------
        input: list or np.ndarray
            The items to embed: strings for text, or ``{"image": <base64>}``
            dicts as built by generate_image_embedding.

        Returns
        -------
        The ``embedding`` value of every result, ordered by the ``index``
        field of the response.

        Raises
        ------
        RuntimeError
            If the response carries no ``data`` field.
        """
        self._init_client()
        if isinstance(input, np.ndarray):
            # numpy arrays are not JSON serialisable; send a plain list instead.
            input = input.tolist()
        resp = JinaEmbeddings._session.post(  # type: ignore
            API_URL, json={"input": input, "model": self.name}
        ).json()
        if "data" not in resp:
            # Prefer the API's "detail" message, but fall back to the whole
            # payload so a missing key cannot mask the real error.
            detail = resp.get("detail", resp) if isinstance(resp, dict) else resp
            raise RuntimeError(detail)

        # Sort resulting embeddings by index
        ordered = sorted(resp["data"], key=lambda e: e["index"])  # type: ignore
        return [item["embedding"] for item in ordered]

    def _init_client(self):
        """Create the shared HTTP session on first use and attach the API key."""
        if JinaEmbeddings._session is not None:
            return
        if self.api_key is None and os.environ.get("JINA_API_KEY") is None:
            api_key_not_found_help("jina")
        api_key = self.api_key or os.environ.get("JINA_API_KEY")
        JinaEmbeddings._session = requests.Session()
        JinaEmbeddings._session.headers.update(
            {"Authorization": f"Bearer {api_key}", "Accept-Encoding": "identity"}
        )

    def _load_pil(self):
        """Return the ``PIL`` package with its ``Image`` submodule loaded."""
        # Availability check and its error reporting are delegated to the helper.
        attempt_import_or_raise("PIL", "pillow")
        # Importing the package alone does not load ``PIL.Image``; import the
        # submodule explicitly so ``PIL.Image.Image``/``PIL.Image.open`` resolve.
        import PIL.Image

        return PIL

    def _open_image_uri(self, PIL, uri: str):
        """Open the image referenced by a local path, file:// or http(s) uri."""
        parsed = urlparse.urlparse(uri)
        if parsed.scheme == "file":
            # TODO handle drive letter on windows for file:// uris.
            return PIL.Image.open(parsed.path)
        if parsed.scheme == "":
            return PIL.Image.open(uri if os.name == "nt" else parsed.path)
        if os.name == "nt" and len(parsed.scheme) == 1:
            # A one-letter "scheme" on Windows is a drive letter (C:\...),
            # i.e. a plain local path.
            return PIL.Image.open(uri)
        if parsed.scheme.startswith("http"):
            return PIL.Image.open(io.BytesIO(url_retrieve(uri)))
        raise NotImplementedError("Only local and http(s) urls are supported")

    def _png_payload(self, pil_image) -> dict:
        """Encode a PIL image as PNG and wrap it in the API's image payload."""
        buffered = io.BytesIO()
        pil_image.save(buffered, format="PNG")
        return {"image": base64.b64encode(buffered.getvalue()).decode("utf-8")}
2 changes: 2 additions & 0 deletions python/python/lancedb/rerankers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .cross_encoder import CrossEncoderReranker
from .linear_combination import LinearCombinationReranker
from .openai import OpenaiReranker
from .jinaai import JinaReranker

__all__ = [
"Reranker",
Expand All @@ -12,4 +13,5 @@
"LinearCombinationReranker",
"OpenaiReranker",
"ColbertReranker",
"JinaReranker",
]
Loading

0 comments on commit 08d25c5

Please sign in to comment.