Skip to content

Commit

Permalink
Ok, everything works... time to refactor and fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
tssweeney committed Jan 4, 2025
1 parent 0b6b723 commit d4591e6
Showing 1 changed file with 64 additions and 4 deletions.
68 changes: 64 additions & 4 deletions tests/trace/test_call_apply_scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,12 +131,72 @@ def score(self, y, output):


def test_scorer_obj_with_context(client: WeaveClient):
raise NotImplementedError()
@weave.op
def predict(x):
return x + 1

class MyScorer(weave.Scorer):
offset: int

@weave.op
def score(self, x, output, correct_answer):
return output - correct_answer - self.offset

scorer = MyScorer(offset=0)

_, call = predict.call(1)
apply_score_res = call.apply_scorer(
scorer, additional_scorer_kwargs={"correct_answer": 2}
)
do_assertions_for_scorer_op(apply_score_res, call, scorer, client)

class MyScorerWithIncorrectArgs(weave.Scorer):
offset: int

@weave.op
def score(self, y, output, incorrect_arg):
return output - incorrect_arg - self.offset

with pytest.raises(OpCallError):
apply_score_res = call.apply_scorer(
MyScorerWithIncorrectArgs(offset=0),
additional_scorer_kwargs={"incorrect_arg": 2},
)

class MyScorerWithIncorrectArgsButCorrectColumnMapping(weave.Scorer):
offset: int

@weave.op
def score(self, y, output, incorrect_arg):
return output - incorrect_arg - self.offset

def test_scorer_obj_with_arg_mapping(client: WeaveClient):
raise NotImplementedError()
scorer = MyScorerWithIncorrectArgsButCorrectColumnMapping(
offset=0, column_map={"y": "x", "incorrect_arg": "correct_answer"}
)

_, call = predict.call(1)
apply_score_res = call.apply_scorer(
scorer, additional_scorer_kwargs={"correct_answer": 2}
)
do_assertions_for_scorer_op(apply_score_res, call, scorer, client)


def test_async_scorer_obj(client: WeaveClient):
raise NotImplementedError()
@weave.op
def predict(x):
return x + 1

class MyScorer(weave.Scorer):
offset: int

@weave.op
async def score(self, x, output):
return output - x - 1

scorer = MyScorer(offset=0)

_, call = predict.call(1)
apply_score_res = call.apply_scorer(
scorer, additional_scorer_kwargs={"correct_answer": 2}
)
do_assertions_for_scorer_op(apply_score_res, call, scorer, client)

0 comments on commit d4591e6

Please sign in to comment.