
[bug] Data Loss in Batch Embedding Processing for LlamaIndex #1139

Closed
hu9029 opened this issue Dec 3, 2024 · 2 comments · Fixed by #1166
Assignees
Labels
bug Something isn't working language: python Related to Python integration

Comments

@hu9029

hu9029 commented Dec 3, 2024

Describe the Bug

In the LlamaIndex BaseEmbedding class:

class BaseEmbedding(TransformComponent, DispatcherSpanMixin):
    @dispatcher.span
    def get_text_embedding_batch(
        self,
        texts: List[str],
        show_progress: bool = False,
        **kwargs: Any,
    ) -> List[Embedding]:
        """Get a list of text embeddings, with batching."""
        cur_batch: List[str] = []
        result_embeddings: List[Embedding] = []

        queue_with_progress = enumerate(
            get_tqdm_iterable(texts, show_progress, "Generating embeddings")
        )

        model_dict = self.to_dict()
        model_dict.pop("api_key", None)
        for idx, text in queue_with_progress:
            cur_batch.append(text)
            if idx == len(texts) - 1 or len(cur_batch) == self.embed_batch_size:
                # flush
                dispatcher.event(
                    EmbeddingStartEvent(
                        model_dict=model_dict,
                    )
                )
                with self.callback_manager.event(
                    CBEventType.EMBEDDING,
                    payload={EventPayload.SERIALIZED: self.to_dict()},
                ) as event:
                    embeddings = self._get_text_embeddings(cur_batch)
                    result_embeddings.extend(embeddings)
                    event.on_end(
                        payload={
                            EventPayload.CHUNKS: cur_batch,
                            EventPayload.EMBEDDINGS: embeddings,
                        },
                    )
                dispatcher.event(
                    EmbeddingEndEvent(
                        chunks=cur_batch,
                        embeddings=embeddings,
                    )
                )
                cur_batch = []

        return result_embeddings

When embed_batch_size < len(texts), get_text_embedding_batch emits one EmbeddingEndEvent per batch. However, the event handler in openinference.instrumentation.llama_index._handler processes each event by storing data under keys of the form {EMBEDDING_EMBEDDINGS}.{i}.{EMBEDDING_TEXT}:

@_process_event.register
def _(self, event: EmbeddingEndEvent) -> None:
    for i, (text, vector) in enumerate(zip(event.chunks, event.embeddings)):
        self[f"{EMBEDDING_EMBEDDINGS}.{i}.{EMBEDDING_TEXT}"] = text
        self[f"{EMBEDDING_EMBEDDINGS}.{i}.{EMBEDDING_VECTOR}"] = vector

Because the loop index i restarts at 0 for every event, each batch writes to the same keys as the batch before it: only the last batch's texts and vectors are saved, while data from all earlier batches is overwritten and lost.
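
A minimal sketch of the collision, with a plain dict standing in for the span's attribute store and invented chunk strings (the flattened key names are written out by hand):

# Two EmbeddingEndEvents fire; enumerate() restarts at 0 for each event,
# so the second event reuses the first event's keys and overwrites them.
attributes = {}

batches = [["chunk-0", "chunk-1"], ["chunk-2", "chunk-3"]]  # two flushed batches
for chunks in batches:
    for i, text in enumerate(chunks):  # i is 0, 1 in BOTH iterations
        attributes[f"embedding.embeddings.{i}.embedding.text"] = text

print(attributes)
# {'embedding.embeddings.0.embedding.text': 'chunk-2',
#  'embedding.embeddings.1.embedding.text': 'chunk-3'}
# "chunk-0" and "chunk-1" have been lost.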


To Reproduce

  1. Use LlamaIndex to construct a VectorStoreIndex.
  2. Ensure the number of nodes exceeds the embed_batch_size defined on the embedding model (a minimal sketch follows below).
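
A minimal reproduction sketch, assuming the OpenInference LlamaIndex instrumentor is already registered and OpenAI embeddings are available (the documents and batch size here are arbitrary):

from llama_index.core import Document, Settings, VectorStoreIndex
from llama_index.embeddings.openai import OpenAIEmbedding

# Five nodes with embed_batch_size=2 force three flushes, and therefore
# three EmbeddingEndEvents on the same span.
Settings.embed_model = OpenAIEmbedding(embed_batch_size=2)

documents = [Document(text=f"document number {n}") for n in range(5)]
index = VectorStoreIndex.from_documents(documents)

# Inspecting the exported span shows only the final batch under the
# embedding.embeddings.* attributes; the earlier batches were overwritten.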

Expected Behavior

All texts and vectors, from every batch, should be retained; no embedding data should be lost during the embedding process.



Desktop (please complete the following information)

  • OS: Windows
  • Version: 11

Additional Context

The issue stems from how keys are assigned during event processing: they are derived from a batch-local index, so each batch's data is not given unique keys.

@hu9029 hu9029 added bug Something isn't working triage Issues that require triage labels Dec 3, 2024
@github-project-automation github-project-automation bot moved this to 📘 Todo in phoenix Dec 3, 2024
@dosubot dosubot bot added the language: python Related to Python integration label Dec 3, 2024
@hu9029
Author

hu9029 commented Dec 3, 2024

A change like this works fine:

@_process_event.register
def _(self, event: EmbeddingEndEvent) -> None:
    # Resume numbering where the previous batch's event left off.
    index = self._attributes.get("embedding_index", 0)
    for text, vector in zip(event.chunks, event.embeddings):
        self[f"{EMBEDDING_EMBEDDINGS}.{index}.{EMBEDDING_TEXT}"] = text
        self[f"{EMBEDDING_EMBEDDINGS}.{index}.{EMBEDDING_VECTOR}"] = vector
        index += 1
    # Persist the running offset for the next EmbeddingEndEvent.
    self["embedding_index"] = index

@mikeldking
Contributor

Hey @hu9029 thanks for the detailed report! Makes sense to me. Seems like we need to track the index of these embeddings as the events fire.

@mikeldking mikeldking removed the triage Issues that require triage label Dec 3, 2024
@RogerHYang RogerHYang moved this from 📘 Todo to 👨‍💻 In progress in phoenix Dec 4, 2024
@RogerHYang RogerHYang moved this from 👨‍💻 In progress to 📘 Todo in phoenix Dec 5, 2024
@RogerHYang RogerHYang moved this from 📘 Todo to 👨‍💻 In progress in phoenix Dec 11, 2024
@RogerHYang RogerHYang moved this from 👨‍💻 In progress to 🔍. Needs Review in phoenix Dec 12, 2024
@github-project-automation github-project-automation bot moved this from 🔍. Needs Review to ✅ Done in phoenix Dec 12, 2024