From d4591e69fd8b314b80d682194722ddefbae89b0c Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Fri, 3 Jan 2025 20:46:38 -0800 Subject: [PATCH] Ok, everything works... time to refactor and fix test --- tests/trace/test_call_apply_scorer.py | 68 +++++++++++++++++++++++++-- 1 file changed, 64 insertions(+), 4 deletions(-) diff --git a/tests/trace/test_call_apply_scorer.py b/tests/trace/test_call_apply_scorer.py index 7997dfb2411..91491694dea 100644 --- a/tests/trace/test_call_apply_scorer.py +++ b/tests/trace/test_call_apply_scorer.py @@ -131,12 +131,72 @@ def score(self, y, output): def test_scorer_obj_with_context(client: WeaveClient): - raise NotImplementedError() + @weave.op + def predict(x): + return x + 1 + + class MyScorer(weave.Scorer): + offset: int + + @weave.op + def score(self, x, output, correct_answer): + return output - correct_answer - self.offset + + scorer = MyScorer(offset=0) + + _, call = predict.call(1) + apply_score_res = call.apply_scorer( + scorer, additional_scorer_kwargs={"correct_answer": 2} + ) + do_assertions_for_scorer_op(apply_score_res, call, scorer, client) + + class MyScorerWithIncorrectArgs(weave.Scorer): + offset: int + + @weave.op + def score(self, y, output, incorrect_arg): + return output - incorrect_arg - self.offset + + with pytest.raises(OpCallError): + apply_score_res = call.apply_scorer( + MyScorerWithIncorrectArgs(offset=0), + additional_scorer_kwargs={"incorrect_arg": 2}, + ) + class MyScorerWithIncorrectArgsButCorrectColumnMapping(weave.Scorer): + offset: int + + @weave.op + def score(self, y, output, incorrect_arg): + return output - incorrect_arg - self.offset -def test_scorer_obj_with_arg_mapping(client: WeaveClient): - raise NotImplementedError() + scorer = MyScorerWithIncorrectArgsButCorrectColumnMapping( + offset=0, column_map={"y": "x", "incorrect_arg": "correct_answer"} + ) + + _, call = predict.call(1) + apply_score_res = call.apply_scorer( + scorer, additional_scorer_kwargs={"correct_answer": 2} + ) + do_assertions_for_scorer_op(apply_score_res, call, scorer, client) def test_async_scorer_obj(client: WeaveClient): - raise NotImplementedError() + @weave.op + def predict(x): + return x + 1 + + class MyScorer(weave.Scorer): + offset: int + + @weave.op + async def score(self, x, output): + return output - x - 1 + + scorer = MyScorer(offset=0) + + _, call = predict.call(1) + apply_score_res = call.apply_scorer( + scorer, additional_scorer_kwargs={"correct_answer": 2} + ) + do_assertions_for_scorer_op(apply_score_res, call, scorer, client)