Skip to content

Commit

Permalink
Larger Refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
tssweeney committed Jan 6, 2025
1 parent 95261ab commit 4174628
Show file tree
Hide file tree
Showing 4 changed files with 304 additions and 283 deletions.
289 changes: 9 additions & 280 deletions weave/flow/eval.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,6 @@
import asyncio
import inspect
import logging
import textwrap
import time
import traceback
from collections.abc import Coroutine
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Callable, Literal, Optional, Union, cast

Expand All @@ -16,7 +11,12 @@
import weave
from weave.flow import util
from weave.flow.dataset import Dataset
from weave.flow.model import Model, get_infer_method
from weave.flow.model import (
ApplyModelError,
Model,
PreprocessModelInput,
apply_model_async,
)
from weave.flow.obj import Object
from weave.flow.util import make_memorable_name
from weave.scorers import (
Expand All @@ -27,6 +27,7 @@
get_scorer_attributes,
transpose,
)
from weave.scorers.base_scorer import apply_scorer_async
from weave.trace.context.weave_client_context import get_weave_client
from weave.trace.env import get_weave_parallelism
from weave.trace.errors import OpCallError
Expand All @@ -44,36 +45,12 @@
)


PreprocessModelInput = Callable[[dict], dict]


def default_evaluation_display_name(call: Call) -> str:
date = datetime.now().strftime("%Y-%m-%d")
unique_name = make_memorable_name()
return f"eval-{date}-{unique_name}"


def async_call(func: Union[Callable, Op], *args: Any, **kwargs: Any) -> Coroutine:
is_async = False
if is_op(func):
func = as_op(func)
is_async = inspect.iscoroutinefunction(func.resolve_fn)
else:
is_async = inspect.iscoroutinefunction(func)
if is_async:
return func(*args, **kwargs) # type: ignore
return asyncio.to_thread(func, *args, **kwargs)


def async_call_op(
func: Op, *args: Any, **kwargs: Any
) -> Coroutine[Any, Any, tuple[Any, "Call"]]:
call_res = func.call(*args, __should_raise=True, **kwargs)
if inspect.iscoroutine(call_res):
return call_res
return asyncio.to_thread(lambda: call_res)


class EvaluationResults(Object):
rows: weave.Table

Expand Down Expand Up @@ -212,7 +189,7 @@ def _post_init_scorers(self) -> list[Union[Op, Scorer]]:

@weave.op()
async def predict_and_score(self, model: Union[Op, Model], example: dict) -> dict:
apply_model_result = await _apply_model_async(
apply_model_result = await apply_model_async(
model, example, self.preprocess_model_input
)

Expand All @@ -231,7 +208,7 @@ async def predict_and_score(self, model: Union[Op, Model], example: dict) -> dic
scorers = self._post_init_scorers

for scorer in scorers:
apply_scorer_result = await _apply_scorer_async(
apply_scorer_result = await apply_scorer_async(
scorer, example, model_output
)
result = apply_scorer_result.result
Expand Down Expand Up @@ -363,251 +340,3 @@ def is_valid_model(model: Any) -> bool:
and is_op(model.predict)
)
)


# Using `dataclass` because pydantic does not like `Call` as a property
@dataclass
class ApplyModelSuccess:
model_output: Any
model_call: Call
model_latency: float


@dataclass
class ApplyModelError:
model_latency: float


ApplyModelResult = Union[ApplyModelSuccess, ApplyModelError]


async def _apply_model_async(
model: Union[Op, Model],
example: dict,
preprocess_model_input: Optional[PreprocessModelInput] = None,
) -> ApplyModelResult:
if preprocess_model_input is None:
model_input = example
else:
model_input = preprocess_model_input(example) # type: ignore

model_self = None
model_predict_op: Op
if is_op(model):
model_predict_op = as_op(model)
elif weave_isinstance(model, Model):
model_self = model
model_predict_op = get_infer_method(model)
else:
raise ValueError(f"Unknown model type: {model}")

model_predict_fn_name = model_predict_op.name

predict_signature = inspect.signature(model_predict_op)
model_predict_arg_names = list(predict_signature.parameters.keys())

model_predict_args = {
k: v for k, v in model_input.items() if k in model_predict_arg_names
}
try:
model_start_time = time.time()
model_predict_op = as_op(model_predict_op)
if model_self is not None:
model_predict_args = {
**model_predict_args,
"self": model_self,
}
model_output, model_call = await async_call_op(
model_predict_op, **model_predict_args
)
except OpCallError as e:
dataset_column_names = list(example.keys())
dataset_column_names_str = ", ".join(dataset_column_names[:3])
if len(dataset_column_names) > 3:
dataset_column_names_str += ", ..."
required_arg_names = [
param.name
for param in predict_signature.parameters.values()
if param.default == inspect.Parameter.empty
]

message = textwrap.dedent(
f"""
Call error: {e}
Options for resolving:
a. change {model_predict_fn_name} argument names to match a subset of dataset column names: {dataset_column_names_str}
b. change dataset column names to match expected {model_predict_fn_name} argument names: {required_arg_names}
c. construct Evaluation with a preprocess_model_input function that accepts a dataset example and returns a dict with keys expected by {model_predict_fn_name}
"""
)
raise OpCallError(message)
except Exception:
print("model_output failed")
traceback.print_exc()
return ApplyModelError(model_latency=time.time() - model_start_time)

return ApplyModelSuccess(
model_output=model_output,
model_call=model_call,
model_latency=time.time() - model_start_time,
)


@dataclass
class ApplyScorerSuccess:
result: Any
score_call: Call


ApplyScorerResult = ApplyScorerSuccess


async def _apply_scorer_async(
scorer: Union[Op, Scorer], example: dict, model_output: dict
) -> ApplyScorerResult:
scorer_self = None
if weave_isinstance(scorer, Scorer):
scorer_self = scorer
scorer_name, score_op, _ = get_scorer_attributes(scorer)
score_signature = inspect.signature(score_op)
score_arg_names = list(score_signature.parameters.keys())

# the actual kwarg name depends on the scorer
if "output" in score_arg_names:
score_output_name = "output"
elif "model_output" in score_arg_names:
score_output_name = "model_output"
else:
message = textwrap.dedent(
f"""
Scorer {scorer_name} must have an `output` or `model_output` argument, to receive the
output of the model function.
"""
)
raise OpCallError(message)

# The keys of `score_args` must match the argument names of the scorer's `score` method.
# If scorer.column_map is set, then user is indicating that the dataset column(s)
# being passed to the scorer have different names to the `score` functions' argument names.
# So we need to remap the dataset columns to the expected argument names in the scorer,
#
# column_map k:v pairs must be structured as `scorer param name : dataset column name`
#
# For instance, if the scorer expects "input" and "ground_truth" and we have a dataset
# with columns "question" and "answer", column_map should be defined as follows:
# {"input": "question", "ground_truth": "answer"}
#
# input: is the full row, we have access to it via example
# output: is the model output, we have access to it via model_output
score_arg_names = [param for param in score_arg_names if (param != "self")]
score_args = {}

if isinstance(scorer, Scorer) and scorer.column_map is not None:
# Ensure that all keys in column_map are in score_arg_names
for key in scorer.column_map.keys():
if key not in score_arg_names:
message = textwrap.dedent(
f"""
You have created `{scorer_name}(column_map={scorer.column_map}, ...)`.
The `column_map` contains a key, `{key}`, which is not in the `score` methods' argument names.
`score` methods' argument names: {score_arg_names}
Hint:
- Ensure that the keys in `column_map` match the scorer's argument names.
"""
)
raise ValueError(message)

for arg in score_arg_names:
if arg == "output" or arg == "model_output":
continue
if arg in example:
score_args[arg] = example[arg]
elif arg in scorer.column_map:
dataset_column_name = scorer.column_map[arg]
if dataset_column_name in example:
score_args[arg] = example[dataset_column_name]
else:
message = textwrap.dedent(
f"""
You have created `{scorer_name}(column_map={scorer.column_map}, ...)`.
You are mapping `{arg}` to `{dataset_column_name}`, but `{dataset_column_name}`
was not found in the dataset columns.
Available dataset columns: {list(example.keys())}
Hint:
- Ensure that `column_map` maps the `score` methods' argument names to existing dataset column names.
"""
)
raise ValueError(message)
else:
message = textwrap.dedent(
f"""
You have created `{scorer_name}(column_map={scorer.column_map}, ...)`.
`score` method argument `{arg}` is not found in the dataset columns and is not mapped in `column_map`.
Available dataset columns: {list(example.keys())}
`column_map`: {scorer.column_map}
Hint:
Either:
- map the argument name to the dataset column using the scorers `column_map` attribute, in the form {{score_arg_name : dataset_column_name}} or
- rename a column in the dataset to `{arg}` or
- re-name the `{arg}` argument in your `score` method to match a dataset column name
"""
)
raise ValueError(message)
else:
score_args = {k: v for k, v in example.items() if k in score_arg_names}

score_args[score_output_name] = model_output

try:
score_op = as_op(score_op)
if scorer_self is not None:
score_args = {
**score_args,
"self": scorer_self,
}
result, score_call = await async_call_op(score_op, **score_args)
except OpCallError as e:
dataset_column_names = list(example.keys())
dataset_column_names_str = ", ".join(dataset_column_names[:3])
if len(dataset_column_names) > 10:
dataset_column_names_str += ", ..."
required_arg_names = [
param.name
for param in score_signature.parameters.values()
if param.default == inspect.Parameter.empty
]
required_arg_names.remove(score_output_name)

message = textwrap.dedent(
f"""
Call error: {e}
If using the `Scorer` weave class, you can set the `scorer.column_map`
attribute to map scorer argument names to dataset columns.
For example, if the `score` expects "output", "input" and "ground_truth" and we have a dataset
with columns "question" and "answer", `column_map` can be used to map the non-output parameter like so:
{{"input": "question", "ground_truth": "answer"}}
scorer argument names: {score_arg_names}
dataset keys: {example.keys()}
scorer.column_map: {getattr(scorer, 'column_map', '{}')}
Options for resolving:
a. if using the `Scorer` weave class, you can set the `scorer.column_map` attribute to map scorer argument names to dataset column names or
b. change the argument names the in the scoring function of {scorer_name} to match a subset of dataset column names: ({dataset_column_names_str}) or
c. change dataset column names to match expected {scorer_name} argument names: {required_arg_names}
"""
)
raise OpCallError(message)

return ApplyScorerSuccess(result=result, score_call=score_call)
Loading

0 comments on commit 4174628

Please sign in to comment.