diff --git a/src/dokumetry/cohere.py b/src/dokumetry/cohere.py
index df33175..7e44f3b 100644
--- a/src/dokumetry/cohere.py
+++ b/src/dokumetry/cohere.py
@@ -43,6 +43,7 @@ def init(llm, doku_url, api_key, environment, application_name, skip_resp):
     original_generate = llm.generate
     original_embed = llm.embed
     original_chat = llm.chat
+    original_chat_stream = llm.chat_stream
     original_summarize = llm.summarize
 
     def patched_generate(*args, **kwargs):
@@ -224,6 +225,56 @@ def stream_generator():
 
             return response
 
+    #pylint: disable=too-many-locals
+    def patched_chat_stream(*args, **kwargs):
+        """
+        Patched version of Cohere's chat_stream method.
+
+        Args:
+            *args: Variable positional arguments.
+            **kwargs: Variable keyword arguments.
+
+        Returns:
+            CohereResponse: The response from Cohere's chat_stream.
+        """
+        start_time = time.time()
+        def stream_generator():
+            accumulated_content = ""
+            for event in original_chat_stream(*args, **kwargs):
+                if event.event_type == "stream-end":
+                    accumulated_content = event.response.text
+                    response_id = event.response.response_id
+                    prompt_tokens = event.response.meta["billed_units"]["input_tokens"]
+                    completion_tokens = event.response.meta["billed_units"]["output_tokens"]
+                    total_tokens = event.response.token_count["billed_tokens"]
+                    finish_reason = event.finish_reason
+                yield event
+            end_time = time.time()
+            duration = end_time - start_time
+            prompt = kwargs.get('message', "No prompt provided")
+
+            data = {
+                "llmReqId": response_id,
+                "environment": environment,
+                "applicationName": application_name,
+                "sourceLanguage": "python",
+                "endpoint": "cohere.chat",
+                "skipResp": skip_resp,
+                "requestDuration": duration,
+                "model": kwargs.get('model', "command"),
+                "prompt": prompt,
+                "response": accumulated_content,
+                "promptTokens": prompt_tokens,
+                "completionTokens": completion_tokens,
+                "totalTokens": total_tokens,
+                "finishReason": finish_reason
+            }
+
+            send_data(data, doku_url, api_key)
+
+        return stream_generator()
+
+
     def summarize_generate(*args, **kwargs):
         """
         Patched version of Cohere's summarize generate method.
@@ -266,4 +317,5 @@ def summarize_generate(*args, **kwargs):
     llm.generate = patched_generate
     llm.embed = embeddings_generate
     llm.chat = chat_generate
+    llm.chat_stream = patched_chat_stream
     llm.summarize = summarize_generate
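
For context, here is a minimal usage sketch of the instrumented client. It assumes the `cohere` v4 SDK and the `dokumetry` package are installed, that the top-level `dokumetry.init` forwards the same parameters shown in this patch to the Cohere handler, and that the Doku URL and both API keys are placeholders, not real endpoints or credentials.

```python
# Minimal sketch, not the project's documented quickstart: assumes the
# cohere v4 SDK and dokumetry are installed; keys and URLs are placeholders.
import cohere
import dokumetry

co = cohere.Client("YOUR_COHERE_API_KEY")

# Monkey-patches generate/embed/chat/chat_stream/summarize on the client.
dokumetry.init(
    llm=co,
    doku_url="https://doku.example.com",  # placeholder Doku endpoint
    api_key="YOUR_DOKU_API_KEY",
    environment="dev",
    application_name="demo-app",
    skip_resp=False,
)

# Iterating drives patched_chat_stream's generator; send_data() runs only
# after the final "stream-end" event has been yielded.
for event in co.chat_stream(message="Tell me a joke", model="command"):
    if event.event_type == "text-generation":
        print(event.text, end="", flush=True)
```

One consequence of emitting telemetry from inside `stream_generator` is that `send_data` fires only once the caller consumes the stream through to the `stream-end` event; if the consumer abandons the iterator early, no record is sent.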