Skip to content

Commit

Permalink
Refactor MM Vector store and Index for empty collection (run-llama#9717)
Browse files Browse the repository at this point in the history
  • Loading branch information
hatianzhang authored Dec 27, 2023
1 parent 68e0346 commit a81c009
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 83 deletions.
100 changes: 65 additions & 35 deletions llama_index/indices/multi_modal/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,19 @@ def __init__(
image_vector_store: Optional[VectorStore] = None,
image_embed_model: EmbedType = "clip",
is_image_to_text: bool = False,
# is_image_vector_store_empty / is_text_vector_store_empty indicate whether the
# corresponding vector store is empty; they cover cases where only one vector store is used
is_image_vector_store_empty: bool = False,
is_text_vector_store_empty: bool = False,
**kwargs: Any,
) -> None:
"""Initialize params."""
image_embed_model = resolve_embed_model(image_embed_model)
assert isinstance(image_embed_model, MultiModalEmbedding)
self._image_embed_model = image_embed_model
self._is_image_to_text = is_image_to_text
self._is_image_vector_store_empty = is_image_vector_store_empty
self._is_text_vector_store_empty = is_text_vector_store_empty
storage_context = storage_context or StorageContext.from_defaults()

if image_vector_store is not None:
Expand Down Expand Up @@ -96,6 +102,14 @@ def image_vector_store(self) -> VectorStore:
def image_embed_model(self) -> MultiModalEmbedding:
return self._image_embed_model

@property
def is_image_vector_store_empty(self) -> bool:
# Read-only view of the flag that is initialised from the constructor argument
# and set to True when an add-nodes call receives no image nodes.
return self._is_image_vector_store_empty

@property
def is_text_vector_store_empty(self) -> bool:
# Read-only view of the flag that is initialised from the constructor argument
# and set to True when an add-nodes call receives no nodes with text.
return self._is_text_vector_store_empty

def as_retriever(self, **kwargs: Any) -> BaseRetriever:
# NOTE: lazy import
from llama_index.indices.multi_modal.retriever import (
Expand Down Expand Up @@ -263,30 +277,38 @@ async def _async_add_nodes_to_index(

image_nodes: List[ImageNode] = []
text_nodes: List[BaseNode] = []
new_text_ids: List[str] = []
new_img_ids: List[str] = []

for node in nodes:
if isinstance(node, ImageNode):
image_nodes.append(node)
if node.text:
text_nodes.append(node)

# embed all nodes as text - include image nodes that have text attached
text_nodes = await self._aget_node_with_embedding(
text_nodes, show_progress, is_image=False
)
new_text_ids = await self.storage_context.vector_stores[
DEFAULT_VECTOR_STORE
].async_add(text_nodes, **insert_kwargs)

# embed image nodes as images directly
image_nodes = await self._aget_node_with_embedding(
image_nodes,
show_progress,
is_image=True,
)
new_img_ids = await self.storage_context.vector_stores[
self.image_namespace
].async_add(image_nodes, **insert_kwargs)
if len(text_nodes) > 0:
# embed all nodes as text - include image nodes that have text attached
text_nodes = await self._aget_node_with_embedding(
text_nodes, show_progress, is_image=False
)
new_text_ids = await self.storage_context.vector_stores[
DEFAULT_VECTOR_STORE
].async_add(text_nodes, **insert_kwargs)
else:
self._is_text_vector_store_empty = True

if len(image_nodes) > 0:
# embed image nodes as images directly
image_nodes = await self._aget_node_with_embedding(
image_nodes,
show_progress,
is_image=True,
)
new_img_ids = await self.storage_context.vector_stores[
self.image_namespace
].async_add(image_nodes, **insert_kwargs)
else:
self._is_image_vector_store_empty = True

# if the vector store doesn't store text, we need to add the nodes to the
# index struct and document store
Expand Down Expand Up @@ -316,31 +338,39 @@ def _add_nodes_to_index(

image_nodes: List[ImageNode] = []
text_nodes: List[BaseNode] = []
new_text_ids: List[str] = []
new_img_ids: List[str] = []

for node in nodes:
if isinstance(node, ImageNode):
image_nodes.append(node)
if node.text:
text_nodes.append(node)

# embed all nodes as text - include image nodes that have text attached
text_nodes = self._get_node_with_embedding(
text_nodes, show_progress, is_image=False
)
new_text_ids = self.storage_context.vector_stores[DEFAULT_VECTOR_STORE].add(
text_nodes, **insert_kwargs
)

# embed image nodes as images directly
# check if we should use text embedding for images instead of default
image_nodes = self._get_node_with_embedding(
image_nodes,
show_progress,
is_image=True,
)
new_img_ids = self.storage_context.vector_stores[self.image_namespace].add(
image_nodes, **insert_kwargs
)
if len(text_nodes) > 0:
# embed all nodes as text - include image nodes that have text attached
text_nodes = self._get_node_with_embedding(
text_nodes, show_progress, is_image=False
)
new_text_ids = self.storage_context.vector_stores[DEFAULT_VECTOR_STORE].add(
text_nodes, **insert_kwargs
)
else:
self._is_text_vector_store_empty = True

if len(image_nodes) > 0:
# embed image nodes as images directly
# check if we should use text embedding for images instead of default
image_nodes = self._get_node_with_embedding(
image_nodes,
show_progress,
is_image=True,
)
new_img_ids = self.storage_context.vector_stores[self.image_namespace].add(
image_nodes, **insert_kwargs
)
else:
self._is_image_vector_store_empty = True

# if the vector store doesn't store text, we need to add the nodes to the
# index struct and document store
Expand Down
115 changes: 67 additions & 48 deletions llama_index/indices/multi_modal/retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,16 +133,20 @@ def _text_retrieve(
self,
query_bundle: QueryBundle,
) -> List[NodeWithScore]:
if self._vector_store.is_embedding_query:
if query_bundle.embedding is None and len(query_bundle.embedding_strs) > 0:
query_bundle.embedding = (
self._service_context.embed_model.get_agg_embedding_from_queries(
if not self._index.is_text_vector_store_empty:
if self._vector_store.is_embedding_query:
if (
query_bundle.embedding is None
and len(query_bundle.embedding_strs) > 0
):
query_bundle.embedding = self._service_context.embed_model.get_agg_embedding_from_queries(
query_bundle.embedding_strs
)
)
return self._get_nodes_with_embeddings(
query_bundle, self._similarity_top_k, self._vector_store
)
return self._get_nodes_with_embeddings(
query_bundle, self._similarity_top_k, self._vector_store
)
else:
return []

def text_retrieve(self, str_or_query_bundle: QueryType) -> List[NodeWithScore]:
if isinstance(str_or_query_bundle, str):
Expand All @@ -153,16 +157,19 @@ def _text_to_image_retrieve(
self,
query_bundle: QueryBundle,
) -> List[NodeWithScore]:
if self._image_vector_store.is_embedding_query:
# change the embedding for query bundle to Multi Modal Text encoder
query_bundle.embedding = (
self._image_embed_model.get_agg_embedding_from_queries(
query_bundle.embedding_strs
if not self._index.is_image_vector_store_empty:
if self._image_vector_store.is_embedding_query:
# change the embedding for query bundle to Multi Modal Text encoder
query_bundle.embedding = (
self._image_embed_model.get_agg_embedding_from_queries(
query_bundle.embedding_strs
)
)
return self._get_nodes_with_embeddings(
query_bundle, self._image_similarity_top_k, self._image_vector_store
)
return self._get_nodes_with_embeddings(
query_bundle, self._image_similarity_top_k, self._image_vector_store
)
else:
return []

def text_to_image_retrieve(
self, str_or_query_bundle: QueryType
Expand All @@ -175,15 +182,18 @@ def _image_to_image_retrieve(
self,
query_bundle: QueryBundle,
) -> List[NodeWithScore]:
if self._image_vector_store.is_embedding_query:
# change the embedding for query bundle to Multi Modal Image encoder for image input
assert isinstance(self._index.image_embed_model, MultiModalEmbedding)
query_bundle.embedding = self._image_embed_model.get_image_embedding(
query_bundle.embedding_image[0]
if not self._index.is_image_vector_store_empty:
if self._image_vector_store.is_embedding_query:
# change the embedding for query bundle to Multi Modal Image encoder for image input
assert isinstance(self._index.image_embed_model, MultiModalEmbedding)
query_bundle.embedding = self._image_embed_model.get_image_embedding(
query_bundle.embedding_image[0]
)
return self._get_nodes_with_embeddings(
query_bundle, self._image_similarity_top_k, self._image_vector_store
)
return self._get_nodes_with_embeddings(
query_bundle, self._image_similarity_top_k, self._image_vector_store
)
else:
return []

def image_to_image_retrieve(
self, str_or_query_bundle: QueryType
Expand Down Expand Up @@ -270,16 +280,17 @@ async def _atext_retrieve(
self,
query_bundle: QueryBundle,
) -> List[NodeWithScore]:
if self._vector_store.is_embedding_query:
# change the embedding for query bundle to Multi Modal Text encoder
query_bundle.embedding = (
await self._service_context.embed_model.aget_agg_embedding_from_queries(
if not self._index.is_text_vector_store_empty:
if self._vector_store.is_embedding_query:
# compute the query bundle embedding with the service context's text embed model
query_bundle.embedding = await self._service_context.embed_model.aget_agg_embedding_from_queries(
query_bundle.embedding_strs
)
return await self._aget_nodes_with_embeddings(
query_bundle, self._similarity_top_k, self._vector_store
)
return await self._aget_nodes_with_embeddings(
query_bundle, self._similarity_top_k, self._vector_store
)
else:
return []

async def atext_retrieve(
self, str_or_query_bundle: QueryType
Expand All @@ -292,16 +303,19 @@ async def _atext_to_image_retrieve(
self,
query_bundle: QueryBundle,
) -> List[NodeWithScore]:
if self._image_vector_store.is_embedding_query:
# change the embedding for query bundle to Multi Modal Text encoder
query_bundle.embedding = (
await self._image_embed_model.aget_agg_embedding_from_queries(
query_bundle.embedding_strs
if not self._index.is_image_vector_store_empty:
if self._image_vector_store.is_embedding_query:
# change the embedding for query bundle to Multi Modal Text encoder
query_bundle.embedding = (
await self._image_embed_model.aget_agg_embedding_from_queries(
query_bundle.embedding_strs
)
)
return await self._aget_nodes_with_embeddings(
query_bundle, self._image_similarity_top_k, self._image_vector_store
)
return await self._aget_nodes_with_embeddings(
query_bundle, self._image_similarity_top_k, self._image_vector_store
)
else:
return []

async def atext_to_image_retrieve(
self, str_or_query_bundle: QueryType
Expand All @@ -326,16 +340,21 @@ async def _aimage_to_image_retrieve(
self,
query_bundle: QueryBundle,
) -> List[NodeWithScore]:
if self._image_vector_store.is_embedding_query:
# change the embedding for query bundle to Multi Modal Image encoder for image input
assert isinstance(self._index.image_embed_model, MultiModalEmbedding)
# Using the first image in the list for image retrieval
query_bundle.embedding = await self._image_embed_model.aget_image_embedding(
query_bundle.embedding_image[0]
if not self._index.is_image_vector_store_empty:
if self._image_vector_store.is_embedding_query:
# change the embedding for query bundle to Multi Modal Image encoder for image input
assert isinstance(self._index.image_embed_model, MultiModalEmbedding)
# Using the first image in the list for image retrieval
query_bundle.embedding = (
await self._image_embed_model.aget_image_embedding(
query_bundle.embedding_image[0]
)
)
return await self._aget_nodes_with_embeddings(
query_bundle, self._image_similarity_top_k, self._image_vector_store
)
return await self._aget_nodes_with_embeddings(
query_bundle, self._image_similarity_top_k, self._image_vector_store
)
else:
return []

async def aimage_to_image_retrieve(
self, str_or_query_bundle: QueryType
Expand Down

0 comments on commit a81c009

Please sign in to comment.