From 2f7eb24d00a7d41f274bde5c50901125d588aef8 Mon Sep 17 00:00:00 2001
From: patcher99
Date: Fri, 8 Mar 2024 15:06:37 +0530
Subject: [PATCH] Update for Anthropic

---
 src/dokumetry/__init__.py  |   4 +-
 src/dokumetry/anthropic.py | 141 ++++++++++++++++++++++++++++---------
 2 files changed, 108 insertions(+), 37 deletions(-)

diff --git a/src/dokumetry/__init__.py b/src/dokumetry/__init__.py
index 22af90d..f579eb0 100644
--- a/src/dokumetry/__init__.py
+++ b/src/dokumetry/__init__.py
@@ -41,13 +41,13 @@ def init(llm, doku_url, api_key, environment="default", application_name="defaul
     DokuConfig.skip_resp = skip_resp
 
     # pylint: disable=no-else-return, line-too-long
-    if hasattr(llm.chat, 'completions') and callable(llm.chat.completions.create) and ('.openai.azure.com/' not in str(llm.base_url)):
+    if hasattr(llm, 'chat') and callable(llm.chat.completions.create) and ('.openai.azure.com/' not in str(llm.base_url)):
         init_openai(llm, doku_url, api_key, environment, application_name, skip_resp)
         return
     # pylint: disable=no-else-return
     elif hasattr(llm, 'generate') and callable(llm.generate):
         init_cohere(llm, doku_url, api_key, environment, application_name, skip_resp)
         return
-    elif hasattr(llm, 'count_tokens') and callable(llm.count_tokens):
+    elif hasattr(llm, 'messages') and callable(llm.messages.create):
         init_anthropic(llm, doku_url, api_key, environment, application_name, skip_resp)
         return
 
diff --git a/src/dokumetry/anthropic.py b/src/dokumetry/anthropic.py
index a0e3688..ff4508c 100644
--- a/src/dokumetry/anthropic.py
+++ b/src/dokumetry/anthropic.py
@@ -19,11 +19,11 @@ def init(llm, doku_url, api_key, environment, application_name, skip_resp):
         skip_resp (bool): Skip response processing.
     """
 
-    original_completions_create = llm.completions.create
+    original_messages_create = llm.messages.create
 
-    def patched_completions_create(*args, **kwargs):
+    def patched_messages_create(*args, **kwargs):
         """
-        Patched version of Anthropic's completions.create method.
+        Patched version of Anthropic's messages.create method.
 
         Args:
             *args: Variable positional arguments.
@@ -32,36 +32,107 @@ def patched_completions_create(*args, **kwargs):
         Returns:
             AnthropicResponse: The response from Anthropic's completions.create.
""" - + streaming = kwargs.get('stream', False) start_time = time.time() - response = original_completions_create(*args, **kwargs) - end_time = time.time() - duration = end_time - start_time - - model = kwargs.get('model') if 'model' in kwargs else args[0] - prompt = kwargs.get('prompt') if 'prompt' in kwargs else args[2] - - prompt_tokens = llm.count_tokens(prompt) - completion_tokens = llm.count_tokens(response.completion) - - data = { - "llmReqId": response.id, - "environment": environment, - "applicationName": application_name, - "sourceLanguage": "python", - "endpoint": "anthropic.completions", - "skipResp": skip_resp, - "completionTokens": completion_tokens, - "promptTokens": prompt_tokens, - "requestDuration": duration, - "model": model, - "prompt": prompt, - "finishReason": response.stop_reason, - "response": response.completion - } - - send_data(data, doku_url, api_key) - - return response - - llm.completions.create = patched_completions_create + if streaming: + def stream_generator(): + accumulated_content = "" + for event in original_messages_create(*args, **kwargs): + if event.type == "message_start": + response_id = event.message.id + prompt_tokens = event.message.usage.input_tokens + if event.type == "content_block_delta": + accumulated_content += event.delta.text + if event.type == "message_delta": + completion_tokens = event.usage.output_tokens + yield event + end_time = time.time() + duration = end_time - start_time + message_prompt = kwargs.get('messages', "No prompt provided") + formatted_messages = [] + for message in message_prompt: + role = message["role"] + content = message["content"] + + if isinstance(content, list): + content_str = ", ".join( + #pylint: disable=line-too-long + f"{item['type']}: {item['text'] if 'text' in item else item['image_url']}" + if 'type' in item else f"text: {item['text']}" + for item in content + ) + formatted_messages.append(f"{role}: {content_str}") + else: + formatted_messages.append(f"{role}: {content}") + + prompt = "\n".join(formatted_messages) + data = { + "llmReqId": response_id, + "environment": environment, + "applicationName": application_name, + "sourceLanguage": "python", + "endpoint": "anthropic.messages", + "skipResp": skip_resp, + "requestDuration": duration, + "model": kwargs.get('model', "command"), + "prompt": prompt, + "response": accumulated_content, + "promptTokens": prompt_tokens, + "completionTokens": completion_tokens, + } + data["totalTokens"] = data["completionTokens"] + data["promptTokens"] + + send_data(data, doku_url, api_key) + + return stream_generator() + else: + start_time = time.time() + response = original_messages_create(*args, **kwargs) + end_time = time.time() + duration = end_time - start_time + message_prompt = kwargs.get('messages', "No prompt provided") + formatted_messages = [] + + for message in message_prompt: + role = message["role"] + content = message["content"] + + if isinstance(content, list): + content_str = ", ".join( + f"{item['type']}: {item['text'] if 'text' in item else item['image_url']}" + if 'type' in item else f"text: {item['text']}" + for item in content + ) + formatted_messages.append(f"{role}: {content_str}") + else: + formatted_messages.append(f"{role}: {content}") + + prompt = "\n".join(formatted_messages) + + model = kwargs.get('model') + + prompt_tokens = response.usage.input_tokens + completion_tokens = response.usage.output_tokens + + data = { + "llmReqId": response.id, + "environment": environment, + "applicationName": application_name, + "sourceLanguage": "python", + 
"endpoint": "anthropic.messages", + "skipResp": skip_resp, + "completionTokens": completion_tokens, + "promptTokens": prompt_tokens, + "requestDuration": duration, + "model": model, + "prompt": prompt, + "finishReason": response.stop_reason, + "response": response.content[0].text + } + + send_data(data, doku_url, api_key) + + return response + + + llm.messages.create = patched_messages_create \ No newline at end of file