From 1650e8b1a91c002edc644e3a30969842a93da468 Mon Sep 17 00:00:00 2001 From: pnadolny13 Date: Wed, 20 Nov 2024 23:42:38 -0500 Subject: [PATCH] make embedding requests parameterized --- map_gpt_embeddings/mappers.py | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/map_gpt_embeddings/mappers.py b/map_gpt_embeddings/mappers.py index 16358fe..0b7c32b 100644 --- a/map_gpt_embeddings/mappers.py +++ b/map_gpt_embeddings/mappers.py @@ -96,6 +96,24 @@ def map_schema_message(self, message_dict: dict) -> t.Iterable[Message]: description="Whether to split document into chunks.", default=True, ), + th.Property( + "embedding_model", + th.StringType, + description="The embedding model to use.", + default="text-embedding-ada-002", + ), + th.Property( + "max_requests_per_minute", + th.NumberType, + description="The maximum number of embedding API requests to send per minute.", + default=3_000 * 0.5, + ), + th.Property( + "max_tokens_per_minute", + th.NumberType, + description="The maximum number of tokens to send to the embedding API per minute.", + default=1_000_000 * 0.5, + ), ).to_dict() def _validate_config(self, *, raise_errors: bool = True) -> list[str]: @@ -166,7 +184,7 @@ def map_record_message(self, message_dict: dict) -> t.Iterable[RecordMessage]: text = message_dict["record"][self.config["document_text_property"]] request = { "input": text.replace("\n", " "), - "model":"text-embedding-ada-002", + "model": self.config["embedding_model"], "metadata": message_dict, } file.write( @@ -182,8 +200,8 @@ def map_record_message(self, message_dict: dict) -> t.Iterable[RecordMessage]: self.save_filepath.name, request_url="https://api.openai.com/v1/embeddings", api_key=self.config.get("openai_api_key", os.environ.get("OPENAI_API_KEY")), - max_requests_per_minute=3_000 * 0.5, - max_tokens_per_minute=1_000_000 * 0.5, + max_requests_per_minute=self.config["max_requests_per_minute"], + max_tokens_per_minute=self.config["max_tokens_per_minute"], token_encoding_name="cl100k_base", max_attempts=5, logging_level=logging.DEBUG,