Skip to content

Commit

Permalink
Add explainability to TLM (#267)
Browse files Browse the repository at this point in the history
  • Loading branch information
ulya-tkch authored Jul 30, 2024
1 parent a72cbea commit 4ad5763
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 35 deletions.
26 changes: 12 additions & 14 deletions cleanlab_studio/internal/api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,16 @@
import io
import os
import time
from typing import Callable, cast, List, Optional, Tuple, Dict, Union, Any
from io import StringIO
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast

import aiohttp
import aiohttp.client_exceptions
import numpy as np
import numpy.typing as npt
import pandas as pd
import requests
from tqdm import tqdm

from cleanlab_studio.errors import (
APIError,
Expand All @@ -15,15 +24,6 @@
)
from cleanlab_studio.internal.tlm.concurrency import TlmRateHandler

import aiohttp
import aiohttp.client_exceptions
import requests
from tqdm import tqdm
import pandas as pd
import numpy as np
import numpy.typing as npt
from io import StringIO

try:
import snowflake
import snowflake.snowpark as snowpark
Expand All @@ -39,12 +39,10 @@
except ImportError:
pyspark_exists = False

from cleanlab_studio.errors import NotInstalledError
from cleanlab_studio.internal.api.api_helper import check_uuid_well_formed
from cleanlab_studio.internal.types import JSONDict, SchemaOverride
from cleanlab_studio.version import __version__
from cleanlab_studio.errors import NotInstalledError
from cleanlab_studio.internal.api.api_helper import (
check_uuid_well_formed,
)

base_url = os.environ.get("CLEANLAB_API_BASE_URL", "https://api.cleanlab.ai/api")
cli_base_url = f"{base_url}/cli/v0"
Expand Down
4 changes: 2 additions & 2 deletions cleanlab_studio/internal/constants.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Tuple, Set
from typing import List, Set, Tuple

# TLM constants
# prepend constants with _ so that they don't show up in help.cleanlab.ai docs
Expand All @@ -16,5 +16,5 @@
TLM_MAX_TOKEN_RANGE: Tuple[int, int] = (64, 512) # (min, max)
TLM_NUM_CANDIDATE_RESPONSES_RANGE: Tuple[int, int] = (1, 20) # (min, max)
TLM_NUM_CONSISTENCY_SAMPLES_RANGE: Tuple[int, int] = (0, 20) # (min, max)
TLM_VALID_LOG_OPTIONS: Set[str] = {"perplexity"}
TLM_VALID_LOG_OPTIONS: Set[str] = {"perplexity", "explanation"}
TLM_VALID_GET_TRUSTWORTHINESS_SCORE_KWARGS: Set[str] = {"perplexity"}
8 changes: 3 additions & 5 deletions cleanlab_studio/internal/tlm/validation.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
import os
from typing import Union, Sequence, List, Dict, Tuple, Any
from typing import Any, Dict, List, Sequence, Union

from cleanlab_studio.errors import ValidationError
from cleanlab_studio.internal.constants import (
_VALID_TLM_MODELS,
TLM_MAX_TOKEN_RANGE,
TLM_NUM_CANDIDATE_RESPONSES_RANGE,
TLM_NUM_CONSISTENCY_SAMPLES_RANGE,
TLM_VALID_LOG_OPTIONS,
TLM_VALID_GET_TRUSTWORTHINESS_SCORE_KWARGS,
TLM_VALID_LOG_OPTIONS,
)


SKIP_VALIDATE_TLM_OPTIONS: bool = (
os.environ.get("CLEANLAB_STUDIO_SKIP_VALIDATE_TLM_OPTIONS", "false").lower() == "true"
)
Expand Down Expand Up @@ -216,7 +216,6 @@ def process_response_and_kwargs(
)
if val is not None and not 0 <= val <= 1:
raise ValidationError("Perplexity values must be between 0 and 1")

elif isinstance(response, Sequence):
if not isinstance(val, Sequence):
raise ValidationError(
Expand All @@ -235,7 +234,6 @@ def process_response_and_kwargs(

if v is not None and not 0 <= v <= 1:
raise ValidationError("Perplexity values must be between 0 and 1")

else:
raise ValidationError(
f"Invalid type {type(val)}, perplexity must be either a sequence or a float"
Expand Down
3 changes: 1 addition & 2 deletions cleanlab_studio/internal/types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from typing import Any, Dict, Optional, TypedDict, Literal

from typing import Any, Dict, Literal, Optional, TypedDict

JSONDict = Dict[str, Any]

Expand Down
27 changes: 15 additions & 12 deletions cleanlab_studio/studio/trustworthy_language_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,29 +10,31 @@

import asyncio
import sys
from typing import Coroutine, List, Optional, Union, cast, Sequence, Any, Dict
from tqdm.asyncio import tqdm_asyncio
import numpy as np
from typing import Any, Coroutine, Dict, List, Optional, Sequence, Union, cast

import aiohttp
from typing_extensions import NotRequired, TypedDict # for Python <3.11 with (Not)Required
from tqdm.asyncio import tqdm_asyncio
from typing_extensions import ( # for Python <3.11 with (Not)Required
NotRequired,
TypedDict,
)

from cleanlab_studio.errors import ValidationError
from cleanlab_studio.internal.api import api
from cleanlab_studio.internal.constants import (
_TLM_MAX_RETRIES,
_VALID_TLM_QUALITY_PRESETS,
)
from cleanlab_studio.internal.tlm.concurrency import TlmRateHandler
from cleanlab_studio.internal.tlm.validation import (
process_response_and_kwargs,
validate_tlm_options,
validate_tlm_prompt,
validate_tlm_try_prompt,
validate_tlm_prompt_response,
validate_tlm_try_prompt,
validate_try_tlm_prompt_response,
validate_tlm_options,
process_response_and_kwargs,
)
from cleanlab_studio.internal.types import TLMQualityPreset
from cleanlab_studio.errors import ValidationError
from cleanlab_studio.internal.constants import (
_VALID_TLM_QUALITY_PRESETS,
_TLM_MAX_RETRIES,
)


class TLM:
Expand Down Expand Up @@ -699,6 +701,7 @@ class TLMOptions(TypedDict):
Setting this to False disables the use of self-reflection and may produce worse TLM trustworthiness scores, but will reduce costs/runtimes.
log (List[str], default = None): optionally specify additional logs or metadata to return.
For instance, include "explanation" here to get explanations of why a response is scored with low trustworthiness.
"""

model: NotRequired[str]
Expand Down

0 comments on commit 4ad5763

Please sign in to comment.