removed EmbeddingType
Thomas-Lemoine committed Aug 23, 2023
1 parent 3d8a9d9 commit 18fe229
Showing 1 changed file with 22 additions and 26 deletions.
48 changes: 22 additions & 26 deletions align_data/embeddings/embedding_utils.py
@@ -40,7 +40,6 @@
     encode_kwargs={"show_progress_bar": False},
 )
 
-EmbeddingType = List[float]
 ModerationInfoType = Dict[str, Any]
 
@@ -88,22 +87,22 @@ def wrapper(*args, **kwargs):
 
 
 @handle_openai_errors
-def moderation_check(texts: List[str]):
+def moderation_check(texts: List[str]) -> List[ModerationInfoType]:
     return openai.Moderation.create(input=texts)["results"]
 
 
 @handle_openai_errors
-def compute_openai_embeddings(non_flagged_texts: List[str], engine: str, **kwargs):
+def _compute_openai_embeddings(non_flagged_texts: List[str], engine: str, **kwargs) -> List[List[float]]:
     data = openai.Embedding.create(input=non_flagged_texts, engine=engine, **kwargs).data
     return [d["embedding"] for d in data]
 
 
 def get_embeddings_without_moderation(
     texts: List[str],
-    engine=OPENAI_EMBEDDINGS_MODEL,
+    engine: str = OPENAI_EMBEDDINGS_MODEL,
     source: Optional[str] = None,
     **kwargs,
-) -> List[EmbeddingType]:
+) -> List[List[float]]:
     """
     Obtain embeddings without moderation checks.
@@ -114,17 +113,18 @@ def get_embeddings_without_moderation(
     - **kwargs: Additional keyword arguments passed to the embedding function.
 
     Returns:
-    - List[EmbeddingType]: List of embeddings for the provided texts.
+    - List[List[float]]: List of embeddings for the provided texts.
     """
+    if not texts:
+        return []
 
-    embeddings = []
-    if texts:  # Only call the embedding function if there are non-flagged texts
-        if USE_OPENAI_EMBEDDINGS:
-            embeddings = compute_openai_embeddings(texts, engine, **kwargs)
-        elif hf_embedding_model:
-            embeddings = hf_embedding_model.embed_documents(texts)
-        else:
-            raise ValueError("No embedding model available.")
+    texts = [text.replace("\n", " ") for text in texts]
+    if USE_OPENAI_EMBEDDINGS:
+        embeddings = _compute_openai_embeddings(texts, engine, **kwargs)
+    elif hf_embedding_model:
+        embeddings = hf_embedding_model.embed_documents(texts)
+    else:
+        raise ValueError("No embedding model available.")
 
     # Bias adjustment
     if bias := EMBEDDING_LENGTH_BIAS.get(source or "", 1.0):
@@ -138,7 +138,7 @@ def get_embeddings_or_none_if_flagged(
     engine=OPENAI_EMBEDDINGS_MODEL,
     source: Optional[str] = None,
     **kwargs,
-) -> Tuple[Optional[List[EmbeddingType]], List[ModerationInfoType]]:
+) -> Tuple[List[List[float]] | None, List[ModerationInfoType]]:
     """
     Obtain embeddings for the provided texts. If any text is flagged during moderation,
     the function returns None for the embeddings while still providing the moderation results.
@@ -150,7 +150,7 @@
     - **kwargs: Additional keyword arguments passed to the embedding function.
 
     Returns:
-    - Tuple[Optional[List[EmbeddingType]], ModerationInfoListType]: Tuple containing the list of embeddings (or None if any text is flagged) and the moderation results.
+    - Tuple[Optional[List[List[float]]], ModerationInfoListType]: Tuple containing the list of embeddings (or None if any text is flagged) and the moderation results.
     """
     moderation_results = moderation_check(texts)
     if any(result["flagged"] for result in moderation_results):
@@ -165,7 +165,7 @@ def get_embeddings(
     engine=OPENAI_EMBEDDINGS_MODEL,
     source: Optional[str] = None,
     **kwargs,
-) -> Tuple[List[Optional[EmbeddingType]], List[ModerationInfoType]]:
+) -> Tuple[List[List[float] | None], List[ModerationInfoType]]:
     """
     Obtain embeddings for the provided texts, replacing the embeddings of flagged texts with `None`.
@@ -176,7 +176,7 @@
     - **kwargs: Additional keyword arguments passed to the embedding function.
 
     Returns:
-    - Tuple[List[Optional[EmbeddingType]], ModerationInfoListType]: Tuple containing the list of embeddings (with None for flagged texts) and the moderation results.
+    - Tuple[List[Optional[List[float]]], ModerationInfoListType]: Tuple containing the list of embeddings (with None for flagged texts) and the moderation results.
     """
     assert len(texts) <= 2048, "The batch size should not be larger than 2048."
     assert all(texts), "No empty strings allowed in the input list."
@@ -186,23 +186,19 @@
 
     # Check all texts for moderation flags
     moderation_results = moderation_check(texts)
-    flagged_bools = [result["flagged"] for result in moderation_results]
+    flags = [result["flagged"] for result in moderation_results]
 
-    non_flagged_texts = [text for text, flagged in zip(texts, flagged_bools) if not flagged]
+    non_flagged_texts = [text for text, flag in zip(texts, flags) if not flag]
     non_flagged_embeddings = get_embeddings_without_moderation(
         non_flagged_texts, engine, source, **kwargs
     )
-
-    embeddings = []
-    for flagged in flagged_bools:
-        embeddings.append(None if flagged else non_flagged_embeddings.pop(0))
-
+    embeddings = [None if flag else non_flagged_embeddings.pop(0) for flag in flags]
     return embeddings, moderation_results
 
 
 def get_embedding(
     text: str, engine=OPENAI_EMBEDDINGS_MODEL, source: Optional[str] = None, **kwargs
-) -> Tuple[Optional[EmbeddingType], ModerationInfoType]:
+) -> Tuple[List[float] | None, ModerationInfoType]:
     """Obtain an embedding for a single text."""
     embedding, moderation_result = get_embeddings([text], engine, source, **kwargs)
     return embedding[0], moderation_result[0]

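For orientation (not part of the commit), here is a minimal usage sketch of the helpers as they look after this change. It assumes an OpenAI API key and the module's embedding settings are already configured; the input texts and the source value are illustrative.

# Hypothetical usage sketch; not part of this diff.
# Assumes OPENAI_API_KEY and the module's embedding settings are configured.
from align_data.embeddings.embedding_utils import get_embedding, get_embeddings

texts = ["What is AI alignment?", "Reward hacking in RL agents."]  # illustrative inputs

# Batch call: embeddings are plain List[List[float]] now that EmbeddingType is gone;
# texts flagged by moderation come back as None alongside their moderation results.
embeddings, moderation_results = get_embeddings(texts, source="arxiv")  # "arxiv" is an example source
for text, embedding, moderation in zip(texts, embeddings, moderation_results):
    if embedding is None:
        print(f"flagged by moderation: {text!r} -> {moderation['flagged']}")
    else:
        print(f"{text!r}: {len(embedding)}-dimensional embedding")

# Single-text wrapper: returns (List[float] | None, ModerationInfoType).
embedding, moderation_result = get_embedding("A single document to embed.")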