diff --git a/src/dokumetry/cohere.py b/src/dokumetry/cohere.py index 7e44f3b..bb3dd94 100644 --- a/src/dokumetry/cohere.py +++ b/src/dokumetry/cohere.py @@ -26,7 +26,7 @@ def count_tokens(text): return num_tokens -# pylint: disable=too-many-arguments, too-many-statements +# pylint: disable=too-many-arguments, too-many-statements, too-many-locals def init(llm, doku_url, api_key, environment, application_name, skip_resp): """ Initialize Cohere monitoring for Doku. @@ -94,8 +94,8 @@ def stream_generator(): duration = end_time - start_time model = kwargs.get('model', 'command') prompt = kwargs.get('prompt') - promptTokens = response.meta.billed_units.input_tokens - completionTokens = response.meta.billed_units.output_tokens + prompt_tokens = response.meta.billed_units.input_tokens + completion_tokens = response.meta.billed_units.output_tokens for generation in response.generations: data = { "llmReqId": generation.id, @@ -105,8 +105,8 @@ def stream_generator(): "endpoint": "cohere.generate", "skipResp": skip_resp, "finishReason": generation.finish_reason, - "completionTokens": completionTokens, - "promptTokens": promptTokens, + "completionTokens": completion_tokens, + "promptTokens": prompt_tokens, "requestDuration": duration, "model": model, "prompt": prompt, @@ -220,7 +220,7 @@ def stream_generator(): "totalTokens": response.token_count["billed_tokens"], "response": response.text } - + send_data(data, doku_url, api_key) return response diff --git a/tests/test_cohere.py b/tests/test_cohere.py index 7682c3e..09d43ad 100644 --- a/tests/test_cohere.py +++ b/tests/test_cohere.py @@ -68,7 +68,7 @@ def test_summarize(): ) assert summarize_resp.id is not None - except cohere.error.CohereAPIError as e: + except cohere.core.api_error.ApiError as e: print("Rate Limited:", e) def test_generate_with_prompt(): @@ -82,7 +82,7 @@ def test_generate_with_prompt(): ) assert generate_resp.prompt == 'Doku' - except cohere.error.CohereAPIError as e: + except cohere.core.api_error.ApiError as e: print("Rate Limited:", e) def test_embed(): @@ -95,7 +95,7 @@ def test_embed(): ) assert embeddings_resp.meta is not None - except cohere.error.CohereAPIError as e: + except cohere.core.api_error.ApiError as e: print("Rate Limited:", e) def test_chat(): @@ -109,5 +109,5 @@ def test_chat(): ) assert chat_resp.response_id is not None - except cohere.error.CohereAPIError as e: + except cohere.core.api_error.ApiError as e: print("Rate Limited:", e)