From 092e05e64c5248d6030041930c7df22aac40695d Mon Sep 17 00:00:00 2001 From: Delfer Date: Thu, 4 Jul 2024 05:46:49 +0000 Subject: [PATCH] Added max_tokens check --- bot/openai_helper.py | 44 +++++++++++++++++++++++++++----------------- 1 file changed, 27 insertions(+), 17 deletions(-) diff --git a/bot/openai_helper.py b/bot/openai_helper.py index 5a1896cf..44093800 100644 --- a/bot/openai_helper.py +++ b/bot/openai_helper.py @@ -624,23 +624,31 @@ async def __summarise(self, conversation) -> str: def __max_model_tokens(self): base = 4096 + max_tokens = 0 if self.config['model'] in GPT_3_MODELS: - return base - if self.config['model'] in GPT_3_16K_MODELS: - return base * 4 - if self.config['model'] in GPT_4_MODELS: - return base * 2 - if self.config['model'] in GPT_4_32K_MODELS: - return base * 8 - if self.config['model'] in GPT_4_VISION_MODELS: - return base * 31 - if self.config['model'] in GPT_4_128K_MODELS: - return base * 31 - if self.config['model'] in GPT_4O_MODELS: - return base * 31 - raise NotImplementedError( - f"Max tokens for model {self.config['model']} is not implemented yet." - ) + max_tokens = base + elif self.config['model'] in GPT_3_16K_MODELS: + max_tokens = base * 4 + elif self.config['model'] in GPT_4_MODELS: + max_tokens = base * 2 + elif self.config['model'] in GPT_4_32K_MODELS: + max_tokens = base * 8 + elif self.config['model'] in GPT_4_VISION_MODELS: + max_tokens = base * 31 + elif self.config['model'] in GPT_4_128K_MODELS: + max_tokens = base * 31 + elif self.config['model'] in GPT_4O_MODELS: + max_tokens = base * 31 + else: + logging.warning( + f"Max tokens for model {self.config['model']} is not implemented yet." 
+ ) + max_tokens = 200000 + + if self.config['max_tokens'] >= max_tokens: + raise ValueError(f"max_tokens {self.config['max_tokens']} should be less than max tokens {max_tokens} for model {self.config['model']}.") + + return max_tokens # https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb def __count_tokens(self, messages) -> int: @@ -662,7 +670,9 @@ def __count_tokens(self, messages) -> int: tokens_per_message = 3 tokens_per_name = 1 else: - raise NotImplementedError(f"""num_tokens_from_messages() is not implemented for model {model}.""") + tokens_per_message = 3 + tokens_per_name = 1 + logging.warning(f"""num_tokens_from_messages() is not implemented for model {model}.""") num_tokens = 0 for message in messages: num_tokens += tokens_per_message