Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(document-search): Option to choose between image and text embeddings for ImageElements #205

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion examples/apps/documents_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,9 @@ async def _handle_message(
if not self._documents_ingested:
yield self.NO_DOCUMENTS_INGESTED_MESSAGE
results = await self.document_search.search(message[-1])
prompt = RAGPrompt(QueryWithContext(query=message, context=[i.text_representation for i in results]))
prompt = RAGPrompt(
QueryWithContext(query=message, context=[i.text_representation for i in results if i.text_representation])
)
response = await self._llm.generate(prompt)
yield response.answer

Expand Down
4 changes: 2 additions & 2 deletions packages/ragbits-core/src/ragbits/core/embeddings/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

from ragbits.core.utils.config_handling import get_cls_from_config

from .base import Embeddings
from .base import Embeddings, EmbeddingTypes
from .noop import NoopEmbeddings

__all__ = ["Embeddings", "NoopEmbeddings"]
__all__ = ["EmbeddingTypes", "Embeddings", "NoopEmbeddings"]

module = sys.modules[__name__]

Expand Down
10 changes: 10 additions & 0 deletions packages/ragbits-core/src/ragbits/core/embeddings/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,14 @@
from abc import ABC, abstractmethod
from enum import Enum


class EmbeddingTypes(Enum):
"""
Enum for listing supported embedding types
"""

TEXT: str = "text"
IMAGE: str = "image"


class Embeddings(ABC):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pydantic import BaseModel, Field

from ragbits.core.audit import traceable
from ragbits.core.embeddings import Embeddings, get_embeddings
from ragbits.core.embeddings import Embeddings, EmbeddingTypes, get_embeddings
from ragbits.core.vector_stores import VectorStore, get_vector_store
from ragbits.core.vector_stores.base import VectorStoreOptions
from ragbits.document_search.documents.document import Document, DocumentMeta
Expand Down Expand Up @@ -150,16 +150,29 @@ async def insert_elements(self, elements: list[Element]) -> None:
Args:
elements: The list of Elements to insert.
"""
vectors = await self.embedder.embed_text([element.key for element in elements])
elements_with_text = [element for element in elements if element.key]
images_with_text = [element for element in elements_with_text if isinstance(element, ImageElement)]
vectors = await self.embedder.embed_text([element.key for element in elements_with_text if element.key])

image_elements = [element for element in elements if isinstance(element, ImageElement)]
entries = [element.to_vector_db_entry(vector) for element, vector in zip(elements, vectors, strict=False)]

num_images_with_no_textual_repr = len(image_elements) - len(images_with_text)
if num_images_with_no_textual_repr > 0:
warnings.warn(
f"{num_images_with_no_textual_repr} of {len(image_elements)}"
"Have no textual representation and have not been text emedded"
)

entries = [
element.to_vector_db_entry(vector, EmbeddingTypes.TEXT)
for element, vector in zip(elements_with_text, vectors, strict=False)
]

if image_elements and self.embedder.image_support():
image_vectors = await self.embedder.embed_image([element.image_bytes for element in image_elements])
entries.extend(
[
element.to_vector_db_entry(vector)
element.to_vector_db_entry(vector, EmbeddingTypes.IMAGE)
for element, vector in zip(image_elements, image_vectors, strict=False)
]
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from pydantic import BaseModel, computed_field

from ragbits.core.embeddings import EmbeddingTypes
from ragbits.core.vector_stores.base import VectorStoreEntry
from ragbits.document_search.documents.document import DocumentMeta

Expand Down Expand Up @@ -42,15 +43,14 @@ def id(self) -> str:
id_components = [
self.document_meta.id,
self.element_type,
self.key,
self.text_representation,
self.key or "null",
str(self.location),
]
return str(uuid.uuid5(uuid.NAMESPACE_OID, ";".join(id_components)))

@computed_field # type: ignore[prop-decorator]
@property
def key(self) -> str:
def key(self) -> str | None:
"""
Get the representation of the element for embedding.

Expand All @@ -62,7 +62,7 @@ def key(self) -> str:
@computed_field # type: ignore[prop-decorator]
@property
@abstractmethod
def text_representation(self) -> str:
def text_representation(self) -> str | None:
"""
Get the text representation of the element.

Expand Down Expand Up @@ -90,23 +90,35 @@ def from_vector_db_entry(cls, db_entry: VectorStoreEntry) -> "Element":
"""
element_type = db_entry.metadata["element_type"]
element_cls = Element._elements_registry[element_type]
if "embedding_type" in db_entry.metadata:
del db_entry.metadata["embedding_type"]
return element_cls(**db_entry.metadata)

def to_vector_db_entry(self, vector: list[float]) -> VectorStoreEntry:
def to_vector_db_entry(self, vector: list[float], embedding_type: EmbeddingTypes | None = None) -> VectorStoreEntry:
"""
Create a vector database entry from the element.

Args:
vector: The vector.
embedding_type: EmbeddingTypes.TEXT, EmbeddingTypes.IMAGE or None

Returns:
The vector database entry
"""
metadata = self.model_dump(exclude={"id", "key"})
vector_store_entry_id = self.id
if embedding_type:
id_components = [
vector_store_entry_id,
str(embedding_type),
]
vector_store_entry_id = str(uuid.uuid5(uuid.NAMESPACE_OID, ";".join(id_components)))
metadata["embedding_type"] = str(embedding_type)
return VectorStoreEntry(
id=self.id,
key=self.key,
id=vector_store_entry_id,
key=self.key or "null",
vector=vector,
metadata=self.model_dump(exclude={"id", "key"}),
metadata=metadata,
)


Expand Down Expand Up @@ -142,11 +154,18 @@ class ImageElement(Element):

@computed_field # type: ignore[prop-decorator]
@property
def text_representation(self) -> str:
def text_representation(self) -> str | None:
"""
Get the text representation of the element.

Returns:
The text representation.
"""
return f"Description: {self.description}\nExtracted text: {self.ocr_extracted_text}"
if not self.description and not self.ocr_extracted_text:
return None
repr = ""
if self.description:
repr += f"Description: {self.description}\n"
if self.ocr_extracted_text:
repr += f"Extracted text: {self.ocr_extracted_text}"
return repr
Loading