Skip to content

Commit

Permalink
Fix moderation batching (#191)
Browse files Browse the repository at this point in the history
  • Loading branch information
mruwnik authored Oct 18, 2023
1 parent 57db939 commit 714b252
Showing 1 changed file with 24 additions and 4 deletions.
28 changes: 24 additions & 4 deletions align_data/embeddings/embedding_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import List, Tuple, Dict, Any, Optional
from typing import List, Tuple, Dict, Any, Optional, Callable
from functools import wraps

import openai
Expand Down Expand Up @@ -93,11 +93,31 @@ def _single_batch_moderation_check(batch: List[str]) -> List[ModerationInfoType]
return openai.Moderation.create(input=batch)["results"]


def moderation_check(texts: List[str], max_texts_num: int = 32) -> List[ModerationInfoType]:
"""Batch moderation checks on list of texts."""
def moderation_check(texts: List[str], max_batch_size: int = 4096, tokens_counter: Callable[[str], int] = len) -> List[ModerationInfoType]:
"""Batch moderation checks on list of texts.
:param List[str] texts: the texts to be checked
:param int max_batch_size: the max size in tokens for a single batch
:param Callable[[str], int] tokens_counter: the function used to count tokens
"""
# A very ugly loop that will split the `texts` into smaller batches so that the
# total sum of tokens in each batch will not exceed `max_batch_size`
parts = []
part = []
part_count = 0
for item in texts:
if part_count + tokens_counter(item) > max_batch_size:
parts.append(part)
part = []
part_count = 0
part.append(item)
part_count += tokens_counter(item)
if part:
parts.append(part)

return [
result
for batch in (texts[i : i + max_texts_num] for i in range(0, len(texts), max_texts_num))
for batch in parts
for result in _single_batch_moderation_check(batch)
]

Expand Down

0 comments on commit 714b252

Please sign in to comment.