From db053d496f0da8661c50efc9dcd3e809144dfd2c Mon Sep 17 00:00:00 2001 From: Daniel O'Connell Date: Tue, 28 Nov 2023 13:31:12 +0100 Subject: [PATCH] fix openai version --- .github/workflows/fetch-dataset.yml | 1 + align_data/embeddings/embedding_utils.py | 10 +++++++--- align_data/settings.py | 4 ++-- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/.github/workflows/fetch-dataset.yml b/.github/workflows/fetch-dataset.yml index 80fcb58..7731097 100644 --- a/.github/workflows/fetch-dataset.yml +++ b/.github/workflows/fetch-dataset.yml @@ -73,6 +73,7 @@ jobs: CODA_TOKEN: ${{ secrets.CODA_TOKEN || inputs.coda_token }} AIRTABLE_API_KEY: ${{ secrets.AIRTABLE_API_KEY || inputs.airtable_api_key }} YOUTUBE_API_KEY: ${{ secrets.YOUTUBE_API_KEY || inputs.youtube_api_key }} + OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY || inputs.openai_api_key }} ARD_DB_USER: ${{ secrets.ARD_DB_USER || inputs.db_user }} ARD_DB_PASSWORD: ${{ secrets.ARD_DB_PASSWORD || inputs.db_password }} ARD_DB_HOST: ${{ secrets.ARD_DB_HOST || inputs.db_host }} diff --git a/align_data/embeddings/embedding_utils.py b/align_data/embeddings/embedding_utils.py index 1ad35df..b0b1d16 100644 --- a/align_data/embeddings/embedding_utils.py +++ b/align_data/embeddings/embedding_utils.py @@ -2,7 +2,8 @@ from typing import List, Tuple, Dict, Any, Optional, Callable from functools import wraps -import openai +from openai import OpenAI + from langchain.embeddings import HuggingFaceEmbeddings from openai import ( OpenAIError, @@ -21,11 +22,14 @@ from align_data.settings import ( USE_OPENAI_EMBEDDINGS, OPENAI_EMBEDDINGS_MODEL, + OPENAI_API_KEY, + OPENAI_ORGANIZATION, EMBEDDING_LENGTH_BIAS, SENTENCE_TRANSFORMER_EMBEDDINGS_MODEL, DEVICE, ) +client = OpenAI(api_key=OPENAI_API_KEY, organization=OPENAI_ORGANIZATION) # -------------------- # CONSTANTS & CONFIGURATION @@ -90,7 +94,7 @@ def wrapper(*args, **kwargs): @handle_openai_errors def _single_batch_moderation_check(batch: List[str]) -> List[ModerationInfoType]: """Process a batch for moderation checks.""" - return openai.Moderation.create(input=batch)["results"] + return client.moderations.create(input=batch)["results"] def moderation_check(texts: List[str], max_batch_size: int = 4096, tokens_counter: Callable[[str], int] = len) -> List[ModerationInfoType]: @@ -125,7 +129,7 @@ def moderation_check(texts: List[str], max_batch_size: int = 4096, tokens_counte @handle_openai_errors def _single_batch_compute_openai_embeddings(batch: List[str], **kwargs) -> List[List[float]]: """Compute embeddings for a batch.""" - batch_data = openai.Embedding.create(input=batch, engine=OPENAI_EMBEDDINGS_MODEL, **kwargs).data + batch_data = client.embeddings.create(input=batch, engine=OPENAI_EMBEDDINGS_MODEL, **kwargs).data return [d["embedding"] for d in batch_data] diff --git a/align_data/settings.py b/align_data/settings.py index 6f747d8..383e89f 100644 --- a/align_data/settings.py +++ b/align_data/settings.py @@ -63,8 +63,8 @@ OPENAI_EMBEDDINGS_MODEL = "text-embedding-ada-002" OPENAI_EMBEDDINGS_DIMS = 1536 OPENAI_EMBEDDINGS_RATE_LIMIT = 3500 -openai.api_key = os.environ.get("OPENAI_API_KEY", None) -openai.organization = os.environ.get("OPENAI_ORGANIZATION", None) +OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", None) +OPENAI_ORGANIZATION = os.environ.get("OPENAI_ORGANIZATION", None) SENTENCE_TRANSFORMER_EMBEDDINGS_MODEL = "sentence-transformers/multi-qa-mpnet-base-cos-v1" SENTENCE_TRANSFORMER_EMBEDDINGS_DIMS = 768