From 3316648bff3f58e177693a150ad011902c6e6ed4 Mon Sep 17 00:00:00 2001 From: Daniel O'Connell Date: Tue, 28 Nov 2023 13:31:12 +0100 Subject: [PATCH] fix openai version --- align_data/embeddings/embedding_utils.py | 9 ++++++--- align_data/settings.py | 3 +-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/align_data/embeddings/embedding_utils.py b/align_data/embeddings/embedding_utils.py index 1ad35df..f83c2d6 100644 --- a/align_data/embeddings/embedding_utils.py +++ b/align_data/embeddings/embedding_utils.py @@ -2,7 +2,8 @@ from typing import List, Tuple, Dict, Any, Optional, Callable from functools import wraps -import openai +from openai import OpenAI + from langchain.embeddings import HuggingFaceEmbeddings from openai import ( OpenAIError, @@ -21,11 +22,13 @@ from align_data.settings import ( USE_OPENAI_EMBEDDINGS, OPENAI_EMBEDDINGS_MODEL, + OPENAI_ORGANIZATION, EMBEDDING_LENGTH_BIAS, SENTENCE_TRANSFORMER_EMBEDDINGS_MODEL, DEVICE, ) +client = OpenAI(organization=OPENAI_ORGANIZATION) # -------------------- # CONSTANTS & CONFIGURATION @@ -90,7 +93,7 @@ def wrapper(*args, **kwargs): @handle_openai_errors def _single_batch_moderation_check(batch: List[str]) -> List[ModerationInfoType]: """Process a batch for moderation checks.""" - return openai.Moderation.create(input=batch)["results"] + return client.moderations.create(input=batch)["results"] def moderation_check(texts: List[str], max_batch_size: int = 4096, tokens_counter: Callable[[str], int] = len) -> List[ModerationInfoType]: @@ -125,7 +128,7 @@ def moderation_check(texts: List[str], max_batch_size: int = 4096, tokens_counte @handle_openai_errors def _single_batch_compute_openai_embeddings(batch: List[str], **kwargs) -> List[List[float]]: """Compute embeddings for a batch.""" - batch_data = openai.Embedding.create(input=batch, engine=OPENAI_EMBEDDINGS_MODEL, **kwargs).data + batch_data = client.embeddings.create(input=batch, engine=OPENAI_EMBEDDINGS_MODEL, **kwargs).data return [d["embedding"] for d in batch_data] diff --git a/align_data/settings.py b/align_data/settings.py index 6f747d8..95f8e7f 100644 --- a/align_data/settings.py +++ b/align_data/settings.py @@ -63,8 +63,7 @@ OPENAI_EMBEDDINGS_MODEL = "text-embedding-ada-002" OPENAI_EMBEDDINGS_DIMS = 1536 OPENAI_EMBEDDINGS_RATE_LIMIT = 3500 -openai.api_key = os.environ.get("OPENAI_API_KEY", None) -openai.organization = os.environ.get("OPENAI_ORGANIZATION", None) +OPENAI_ORGANIZATION = os.environ.get("OPENAI_ORGANIZATION", None) SENTENCE_TRANSFORMER_EMBEDDINGS_MODEL = "sentence-transformers/multi-qa-mpnet-base-cos-v1" SENTENCE_TRANSFORMER_EMBEDDINGS_DIMS = 768