Skip to content
This repository has been archived by the owner on Apr 18, 2024. It is now read-only.

Commit

Permalink
Merge pull request #8 from dokulabs/cohere-fixes
Browse files Browse the repository at this point in the history
[Fix]: Cohere `generate` and `summarize` function monitoring
  • Loading branch information
patcher9 authored Mar 24, 2024
2 parents bd30ee8 + 26fa851 commit 8a4de93
Show file tree
Hide file tree
Showing 6 changed files with 92 additions and 26 deletions.
7 changes: 7 additions & 0 deletions .pylintrc
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[DESIGN]

# Maximum number of locals for function / method body
max-locals=25

# Maximum number of arguments for function / method
max-args=25
3 changes: 1 addition & 2 deletions src/dokumetry/__helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,4 @@ def send_data(data, doku_url, doku_token):
timeout=30)
response.raise_for_status()
except requests.exceptions.RequestException as req_err:
logging.error("Error sending data to Doku: %s", req_err)
raise # Re-raise the exception after logging
logging.error("DokuMetry: Error sending data to Doku: %s", req_err)
69 changes: 61 additions & 8 deletions src/dokumetry/cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def count_tokens(text):
return num_tokens

# pylint: disable=too-many-arguments, too-many-statements
def init(llm, doku_url, api_key, environment, application_name, skip_resp):
def init(llm, doku_url, api_key, environment, application_name, skip_resp): #pylint: disable=too-many-locals
"""
Initialize Cohere monitoring for Doku.
Expand All @@ -43,6 +43,7 @@ def init(llm, doku_url, api_key, environment, application_name, skip_resp):
original_generate = llm.generate
original_embed = llm.embed
original_chat = llm.chat
original_chat_stream = llm.chat_stream
original_summarize = llm.summarize

def patched_generate(*args, **kwargs):
Expand Down Expand Up @@ -93,8 +94,9 @@ def stream_generator():
duration = end_time - start_time
model = kwargs.get('model', 'command')
prompt = kwargs.get('prompt')

for generation in response:
prompt_tokens = response.meta.billed_units.input_tokens
completion_tokens = response.meta.billed_units.output_tokens
for generation in response.generations:
data = {
"llmReqId": generation.id,
"environment": environment,
Expand All @@ -103,8 +105,8 @@ def stream_generator():
"endpoint": "cohere.generate",
"skipResp": skip_resp,
"finishReason": generation.finish_reason,
"completionTokens": count_tokens(generation.text),
"promptTokens": count_tokens(prompt),
"completionTokens": completion_tokens,
"promptTokens": prompt_tokens,
"requestDuration": duration,
"model": model,
"prompt": prompt,
Expand Down Expand Up @@ -144,7 +146,7 @@ def embeddings_generate(*args, **kwargs):
"requestDuration": duration,
"model": model,
"prompt": prompt,
"promptTokens": response.meta["billed_units"]["input_tokens"],
"promptTokens": response.meta.billed_units.input_tokens,
}

send_data(data, doku_url, api_key)
Expand Down Expand Up @@ -223,6 +225,56 @@ def stream_generator():

return response

#pylint: disable=too-many-locals
def patched_chat_stream(*args, **kwargs):
    """
    Patched version of Cohere's chat_stream method.

    Wraps the original stream so every event is yielded through to the
    caller unchanged; once the stream is exhausted, a monitoring payload
    is sent to Doku.

    Args:
        *args: Variable positional arguments forwarded to chat_stream.
        **kwargs: Variable keyword arguments forwarded to chat_stream.

    Returns:
        Generator: Yields the events from Cohere's chat_stream.
    """
    start_time = time.time()
    def stream_generator():
        # Initialize everything read after the loop. The originals were
        # only assigned inside the "stream-end" branch, so a stream that
        # ends without that event (error, early disconnect) would raise
        # UnboundLocalError while building the payload below.
        accumulated_content = ""
        response_id = None
        prompt_tokens = 0
        completion_tokens = 0
        total_tokens = 0
        finish_reason = None
        for event in original_chat_stream(*args, **kwargs):
            if event.event_type == "stream-end":
                # Final event carries the full response plus billing info.
                accumulated_content = event.response.text
                response_id = event.response.response_id
                prompt_tokens = event.response.meta["billed_units"]["input_tokens"]
                completion_tokens = event.response.meta["billed_units"]["output_tokens"]
                total_tokens = event.response.token_count["billed_tokens"]
                finish_reason = event.finish_reason
            yield event
        # Duration covers the whole stream, not just the first chunk.
        duration = time.time() - start_time
        prompt = kwargs.get('message', "No prompt provided")

        data = {
            "llmReqId": response_id,
            "environment": environment,
            "applicationName": application_name,
            "sourceLanguage": "python",
            "endpoint": "cohere.chat",
            "skipResp": skip_resp,
            "requestDuration": duration,
            "model": kwargs.get('model', "command"),
            "prompt": prompt,
            "response": accumulated_content,
            "promptTokens": prompt_tokens,
            "completionTokens": completion_tokens,
            "totalTokens": total_tokens,
            "finishReason": finish_reason
        }

        send_data(data, doku_url, api_key)

    return stream_generator()


def summarize_generate(*args, **kwargs):
"""
Patched version of Cohere's summarize generate method.
Expand Down Expand Up @@ -250,8 +302,8 @@ def summarize_generate(*args, **kwargs):
"endpoint": "cohere.summarize",
"skipResp": skip_resp,
"requestDuration": duration,
"completionTokens": response.meta["billed_units"]["output_tokens"],
"promptTokens": response.meta["billed_units"]["input_tokens"],
"completionTokens": response.meta.billed_units.output_tokens,
"promptTokens": response.meta.billed_units.input_tokens,
"model": model,
"prompt": prompt,
"response": response.summary
Expand All @@ -265,4 +317,5 @@ def summarize_generate(*args, **kwargs):
llm.generate = patched_generate
llm.embed = embeddings_generate
llm.chat = chat_generate
llm.chat_stream = patched_chat_stream
llm.summarize = summarize_generate
28 changes: 17 additions & 11 deletions tests/test_anthropic.py.hold → tests/test_anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,20 @@ def test_messages():
"""
Test the 'messages.create' function of the Anthropic client.
"""
message = client.messages.create(
max_tokens=1024,
messages=[
{
"role": "user",
"content": "Hello, Claude",
}
],
model="claude-3-opus-20240229",
)
assert message.type == 'message'
try:
message = client.messages.create(
max_tokens=1,
messages=[
{
"role": "user",
"content": "Hello, Claude",
}
],
model="claude-3-haiku-20240307",
)
assert message.type == 'message'

# pylint: disable=broad-exception-caught
except Exception as e:
if "rate limit" in str(e).lower():
print("Rate Limited:", e)
8 changes: 4 additions & 4 deletions tests/test_cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def test_summarize():
)
assert summarize_resp.id is not None

except cohere.error.CohereAPIError as e:
except cohere.core.api_error.ApiError as e:
print("Rate Limited:", e)

def test_generate_with_prompt():
Expand All @@ -82,7 +82,7 @@ def test_generate_with_prompt():
)
assert generate_resp.prompt == 'Doku'

except cohere.error.CohereAPIError as e:
except cohere.core.api_error.ApiError as e:
print("Rate Limited:", e)

def test_embed():
Expand All @@ -95,7 +95,7 @@ def test_embed():
)
assert embeddings_resp.meta is not None

except cohere.error.CohereAPIError as e:
except cohere.core.api_error.ApiError as e:
print("Rate Limited:", e)

def test_chat():
Expand All @@ -109,5 +109,5 @@ def test_chat():
)
assert chat_resp.response_id is not None

except cohere.error.CohereAPIError as e:
except cohere.core.api_error.ApiError as e:
print("Rate Limited:", e)
3 changes: 2 additions & 1 deletion tests/test_mistral.py.hold → tests/test_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,9 @@ def test_chat():

# No streaming
message = client.chat(
model="mistral-large-latest",
model="open-mistral-7b",
messages=messages,
max_tokens=1,
)
assert message.object == 'chat.completion'

Expand Down

0 comments on commit 8a4de93

Please sign in to comment.