Skip to content

Commit

Permalink
make embedding requests parameterized
Browse files Browse the repository at this point in the history
  • Loading branch information
pnadolny13 committed Nov 21, 2024
1 parent 3e47fc2 commit 1650e8b
Showing 1 changed file with 21 additions and 3 deletions.
24 changes: 21 additions & 3 deletions map_gpt_embeddings/mappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,24 @@ def map_schema_message(self, message_dict: dict) -> t.Iterable[Message]:
description="Whether to split document into chunks.",
default=True,
),
th.Property(
"embedding_model",
th.StringType,
description="The embedding model to use.",
default="text-embedding-ada-002",
),
th.Property(
"max_requests_per_minute",
th.NumberType,
description="The embedding model to use.",
default=3_000 * 0.5,
),
th.Property(
"max_tokens_per_minute",
th.NumberType,
description="The embedding model to use.",
default=1_000_000 * 0.5,
),
).to_dict()

def _validate_config(self, *, raise_errors: bool = True) -> list[str]:
Expand Down Expand Up @@ -166,7 +184,7 @@ def map_record_message(self, message_dict: dict) -> t.Iterable[RecordMessage]:
text = message_dict["record"][self.config["document_text_property"]]
request = {
"input": text.replace("\n", " "),
"model":"text-embedding-ada-002",
"model": self.config["embedding_model"],
"metadata": message_dict,
}
file.write(
Expand All @@ -182,8 +200,8 @@ def map_record_message(self, message_dict: dict) -> t.Iterable[RecordMessage]:
self.save_filepath.name,
request_url="https://api.openai.com/v1/embeddings",
api_key=self.config.get("openai_api_key", os.environ.get("OPENAI_API_KEY")),
max_requests_per_minute=3_000 * 0.5,
max_tokens_per_minute=1_000_000 * 0.5,
max_requests_per_minute=self.config["max_requests_per_minute"],
max_tokens_per_minute=self.config["max_tokens_per_minute"],
token_encoding_name="cl100k_base",
max_attempts=5,
logging_level=logging.DEBUG,
Expand Down

0 comments on commit 1650e8b

Please sign in to comment.