Skip to content

Commit

Permalink
feat: Add LDAIConfigTracker.get_summary method (#31)
Browse files Browse the repository at this point in the history
  • Loading branch information
keelerm84 authored Dec 13, 2024
1 parent fcc720a commit e425b1f
Show file tree
Hide file tree
Showing 3 changed files with 162 additions and 25 deletions.
6 changes: 3 additions & 3 deletions ldai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ class LDAIClient:
"""The LaunchDarkly AI SDK client object."""

def __init__(self, client: LDClient):
self.client = client
self._client = client

def config(
self,
Expand All @@ -147,7 +147,7 @@ def config(
:param variables: Additional variables for the model configuration.
:return: The value of the model configuration along with a tracker used for gathering metrics.
"""
variation = self.client.variation(key, context, default_value.to_dict())
variation = self._client.variation(key, context, default_value.to_dict())

all_variables = {}
if variables:
Expand Down Expand Up @@ -184,7 +184,7 @@ def config(
)

tracker = LDAIConfigTracker(
self.client,
self._client,
variation.get('_ldMeta', {}).get('variationKey', ''),
key,
context,
Expand Down
97 changes: 97 additions & 0 deletions ldai/testing/test_tracker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
from unittest.mock import MagicMock

import pytest
from ldclient import Config, Context, LDClient
from ldclient.integrations.test_data import TestData

from ldai.tracker import FeedbackKind, LDAIConfigTracker


@pytest.fixture
def td() -> TestData:
td = TestData.data_source()
td.update(
td.flag('model-config')
.variations(
{
'model': {'name': 'fakeModel', 'parameters': {'temperature': 0.5, 'maxTokens': 4096}, 'custom': {'extra-attribute': 'value'}},
'provider': {'name': 'fakeProvider'},
'messages': [{'role': 'system', 'content': 'Hello, {{name}}!'}],
'_ldMeta': {'enabled': True, 'variationKey': 'abcd'},
},
"green",
)
.variation_for_all(0)
)

return td


@pytest.fixture
def client(td: TestData) -> LDClient:
config = Config('sdk-key', update_processor_class=td, send_events=False)
client = LDClient(config=config)
client.track = MagicMock() # type: ignore
return client


def test_summary_starts_empty(client: LDClient):
context = Context.create('user-key')
tracker = LDAIConfigTracker(client, "variation-key", "config-key", context)

assert tracker.get_summary().duration is None
assert tracker.get_summary().feedback is None
assert tracker.get_summary().success is None
assert tracker.get_summary().usage is None


def test_tracks_duration(client: LDClient):
context = Context.create('user-key')
tracker = LDAIConfigTracker(client, "variation-key", "config-key", context)
tracker.track_duration(100)

client.track.assert_called_with( # type: ignore
'$ld:ai:duration:total',
context,
{'variationKey': 'variation-key', 'configKey': 'config-key'},
100
)

assert tracker.get_summary().duration == 100


@pytest.mark.parametrize(
"kind,label",
[
pytest.param(FeedbackKind.Positive, "positive", id="positive"),
pytest.param(FeedbackKind.Negative, "negative", id="negative"),
],
)
def test_tracks_feedback(client: LDClient, kind: FeedbackKind, label: str):
context = Context.create('user-key')
tracker = LDAIConfigTracker(client, "variation-key", "config-key", context)

tracker.track_feedback({'kind': kind})

client.track.assert_called_with( # type: ignore
f'$ld:ai:feedback:user:{label}',
context,
{'variationKey': 'variation-key', 'configKey': 'config-key'},
1
)
assert tracker.get_summary().feedback == {'kind': kind}


def test_tracks_success(client: LDClient):
context = Context.create('user-key')
tracker = LDAIConfigTracker(client, "variation-key", "config-key", context)
tracker.track_success()

client.track.assert_called_with( # type: ignore
'$ld:ai:generation',
context,
{'variationKey': 'variation-key', 'configKey': 'config-key'},
1
)

assert tracker.get_summary().success is True
84 changes: 62 additions & 22 deletions ldai/tracker.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import time
from dataclasses import dataclass
from enum import Enum
from typing import Dict, Union
from typing import Dict, Optional, Union

from ldclient import Context, LDClient

Expand All @@ -21,7 +21,6 @@ class TokenMetrics:
output: int # type: ignore


@dataclass
class FeedbackKind(Enum):
"""
Types of feedback that can be provided for AI operations.
Expand Down Expand Up @@ -131,6 +130,34 @@ def to_metrics(self) -> TokenMetrics:
)


class LDAIMetricSummary:
"""
Summary of metrics which have been tracked.
"""

def __init__(self):
self._duration = None
self._success = None
self._feedback = None
self._usage = None

@property
def duration(self) -> Optional[int]:
return self._duration

@property
def success(self) -> Optional[bool]:
return self._success

@property
def feedback(self) -> Optional[Dict[str, FeedbackKind]]:
return self._feedback

@property
def usage(self) -> Optional[Union[TokenUsage, BedrockTokenUsage]]:
return self._usage


class LDAIConfigTracker:
"""
Tracks configuration and usage metrics for LaunchDarkly AI operations.
Expand All @@ -147,10 +174,11 @@ def __init__(
:param config_key: Configuration key for tracking.
:param context: Context for evaluation.
"""
self.ld_client = ld_client
self.variation_key = variation_key
self.config_key = config_key
self.context = context
self._ld_client = ld_client
self._variation_key = variation_key
self._config_key = config_key
self._context = context
self._summary = LDAIMetricSummary()

def __get_track_data(self):
"""
Expand All @@ -159,8 +187,8 @@ def __get_track_data(self):
:return: Dictionary containing variation and config keys.
"""
return {
'variationKey': self.variation_key,
'configKey': self.config_key,
'variationKey': self._variation_key,
'configKey': self._config_key,
}

def track_duration(self, duration: int) -> None:
Expand All @@ -169,8 +197,9 @@ def track_duration(self, duration: int) -> None:
:param duration: Duration in milliseconds.
"""
self.ld_client.track(
'$ld:ai:duration:total', self.context, self.__get_track_data(), duration
self._summary._duration = duration
self._ld_client.track(
'$ld:ai:duration:total', self._context, self.__get_track_data(), duration
)

def track_duration_of(self, func):
Expand All @@ -193,17 +222,18 @@ def track_feedback(self, feedback: Dict[str, FeedbackKind]) -> None:
:param feedback: Dictionary containing feedback kind.
"""
self._summary._feedback = feedback
if feedback['kind'] == FeedbackKind.Positive:
self.ld_client.track(
self._ld_client.track(
'$ld:ai:feedback:user:positive',
self.context,
self._context,
self.__get_track_data(),
1,
)
elif feedback['kind'] == FeedbackKind.Negative:
self.ld_client.track(
self._ld_client.track(
'$ld:ai:feedback:user:negative',
self.context,
self._context,
self.__get_track_data(),
1,
)
Expand All @@ -212,8 +242,9 @@ def track_success(self) -> None:
"""
Track a successful AI generation.
"""
self.ld_client.track(
'$ld:ai:generation', self.context, self.__get_track_data(), 1
self._summary._success = True
self._ld_client.track(
'$ld:ai:generation', self._context, self.__get_track_data(), 1
)

def track_openai_metrics(self, func):
Expand Down Expand Up @@ -253,25 +284,34 @@ def track_tokens(self, tokens: Union[TokenUsage, BedrockTokenUsage]) -> None:
:param tokens: Token usage data from either custom, OpenAI, or Bedrock sources.
"""
self._summary._usage = tokens
token_metrics = tokens.to_metrics()
if token_metrics.total > 0:
self.ld_client.track(
self._ld_client.track(
'$ld:ai:tokens:total',
self.context,
self._context,
self.__get_track_data(),
token_metrics.total,
)
if token_metrics.input > 0:
self.ld_client.track(
self._ld_client.track(
'$ld:ai:tokens:input',
self.context,
self._context,
self.__get_track_data(),
token_metrics.input,
)
if token_metrics.output > 0:
self.ld_client.track(
self._ld_client.track(
'$ld:ai:tokens:output',
self.context,
self._context,
self.__get_track_data(),
token_metrics.output,
)

def get_summary(self) -> LDAIMetricSummary:
"""
Get the current summary of AI metrics.
:return: Summary of AI metrics.
"""
return self._summary

0 comments on commit e425b1f

Please sign in to comment.