Skip to content

Commit

Permalink
fix: DIA-1688: Fix non-serializable agent (#266)
Browse files Browse the repository at this point in the history
Co-authored-by: nik <[email protected]>
  • Loading branch information
niklub and nik authored Nov 27, 2024
1 parent 70b38bd commit ed9e277
Show file tree
Hide file tree
Showing 6 changed files with 120 additions and 90 deletions.
13 changes: 11 additions & 2 deletions adala/skills/collection/label_studio.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
import pandas as pd
from typing import Type, Iterator, Optional
from functools import cached_property
from copy import deepcopy
from collections import defaultdict
from adala.skills._base import TransformSkill
from adala.runtimes import AsyncLiteLLMVisionRuntime
from adala.runtimes._litellm import MessageChunkType
from pydantic import BaseModel, Field, model_validator
from pydantic import BaseModel, Field, model_validator, computed_field

from adala.runtimes import Runtime, AsyncRuntime
from adala.utils.internal_data import InternalDataFrame
Expand Down Expand Up @@ -35,7 +36,7 @@ class LabelStudioSkill(TransformSkill):
label_config: str = "<View></View>"
allowed_control_tags: Optional[list[str]] = None
allowed_object_tags: Optional[list[str]] = None

# TODO: implement postprocessing to verify Taxonomy

@cached_property
Expand Down Expand Up @@ -63,6 +64,14 @@ def image_tags(self) -> Iterator[ObjectTag]:
tag = self.label_interface.get_object(tag_name)
if tag.tag.lower() == "image":
yield tag

def __getstate__(self):
"""Exclude cached properties when pickling - otherwise the 'Agent' can not be serialized in celery"""
state = deepcopy(super().__getstate__())
# Remove cached_property values
for key in ['label_interface', 'ner_tags', 'image_tags']:
state['__dict__'].pop(key, None)
return state

@model_validator(mode="after")
def validate_response_model(self):
Expand Down
1 change: 1 addition & 0 deletions server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ async def submit_streaming(request: SubmitStreamingRequest):
"""

task = streaming_parent_task

result = task.apply_async(
kwargs={"agent": request.agent, "result_handler": request.result_handler}
)
Expand Down
38 changes: 21 additions & 17 deletions tests/cassettes/test_serialization/test_agent_is_pickleable.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ interactions:
host:
- api.openai.com
user-agent:
- OpenAI/Python 1.42.0
- OpenAI/Python 1.47.1
x-stainless-arch:
- arm64
x-stainless-async:
Expand All @@ -26,7 +26,9 @@ interactions:
x-stainless-os:
- MacOS
x-stainless-package-version:
- 1.42.0
- 1.47.1
x-stainless-raw-response:
- 'true'
x-stainless-runtime:
- CPython
x-stainless-runtime-version:
Expand All @@ -36,45 +38,47 @@ interactions:
response:
body:
string: !!binary |
H4sIAAAAAAAAAwAAAP//dJFNbtswEIX3OsWUm26swpIdy/AmSIEC6R9QdNNFUQi0NJLYkDMsOWrq
BAZ6jVyvJykoO7a76IYA38f38Gb4mAEo06oNqGbQ0jhv85vX1IyVv33/wbc/Hh7erT5tqZObn28+
c/VRzZKDt9+xkWfXq4adtyiG6YCbgFowpRZVWS2qq2qxmoDjFm2y9V7yJefOkMnLebnM51VerI/u
gU2DUW3gawYA8DidqSe1+EttYD57VhzGqHtUm9MjABXYJkXpGE0UTaJmZ9gwCdJU/e1LBy0b6uEe
rZ2BDJruYMfjC7jle9BbHiVdr+HLoOXP76cITEkI4Ay1INzq3fVleMBujDoNSKO1R31/amu594G3
8chPemfIxKEOqCNTahaFvZroPgP4Nm1l/GdQ5QM7L7XwHVIKLJaHOHX+iwu4PkJh0fasL8r/meoW
RRsbLzarDv0M9eeA+ankNKWKuyjo6s5Qj8EHc1h05+uFLourVbHGtcr22V8AAAD//wMAdXW2u3YC
AAA=
H4sIAAAAAAAAA4xSwWrcMBS8+ytedenFLrZ3gzd7CSGXltIe0rKhlGK00rOtRtYTkkyyhIX+Rn+v
XxLk3awdkkAvAs28GWae9JAAMCXZGpjoeBC91dnlj6tLubrh/gtuvl4VeNbcX19vxPD982b4xtKo
oO1vFOFJ9UFQbzUGReZAC4c8YHQtqkVZFYvy/HwkepKoo6y1IVtS1iujsjIvl1leZcXqqO5ICfRs
DT8TAICH8Yw5jcR7toY8fUJ69J63yNanIQDmSEeEce+VD9wElk6kIBPQjNE/ve9BkjIt3KHWKYSO
m1vY0fAOPtId8C0NIV4v4Kbj4d+fvx7IRMBBr4yEQJLvLubmDpvB81jQDFof8f0prabWOtr6I3/C
G2WU72qH3JOJyXwgy0Z2nwD8GrcyPCvKrKPehjrQLZpoWCwPdmx6ixm5OpKBAtcTvijTV9xqiYEr
7WdbZYKLDuWknJ6AD1LRjEhmnV+Gec370FuZ9n/sJ0IItAFlbR1KJZ4XnsYcxp/61thpx2Ng5nc+
YF83yrTorFOHf9LYOq/ys22zqkTOkn3yCAAA//8DAGWxe8c1AwAA
headers:
CF-Cache-Status:
- DYNAMIC
CF-RAY:
- 8c955acf3f0b03f6-LIS
- 8e92626a7bfb950c-LIS
Connection:
- keep-alive
Content-Encoding:
- gzip
Content-Type:
- application/json
Date:
- Thu, 26 Sep 2024 18:35:36 GMT
- Wed, 27 Nov 2024 13:15:00 GMT
Server:
- cloudflare
Set-Cookie:
- __cf_bm=N41cAS4gVSI3zIFhJgqnrJ1khMPelzutGHoZlTnMfAQ-1727375736-1.0.1.1-kYjE_oAWX3ZPAfxbqAEFKFQBA52gPWCFUIiK2wlGRrkNWiw6xImz50dKz_ZMuV91Yyx1YUQXKqRZxA_w6oWwlA;
path=/; expires=Thu, 26-Sep-24 19:05:36 GMT; domain=.api.openai.com; HttpOnly;
- __cf_bm=IA8w_0VVIJ_PmrdxOVx0wcdGucdEpE8vwXV0a.ekRmM-1732713300-1.0.1.1-O249f2Bl8pr69z65Ahy5jc.Ly8ioxedNWjJHlep7qW2vP4W_PMZYHTADufXH5iCEMDgSfMIe8d1a8WoJ0CMwpQ;
path=/; expires=Wed, 27-Nov-24 13:45:00 GMT; domain=.api.openai.com; HttpOnly;
Secure; SameSite=None
- _cfuvid=WW2hWPahLLIuqKIgXgTqyKSFskKsB09DNTbdc.uq4U8-1727375736877-0.0.1.1-604800000;
- _cfuvid=zk6dA9Fqv3CUJnz6.w.36xh03L.8I0Qd7lwPspYQvBE-1732713300177-0.0.1.1-604800000;
path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None
Transfer-Encoding:
- chunked
X-Content-Type-Options:
- nosniff
access-control-expose-headers:
- X-Request-ID
alt-svc:
- h3=":443"; ma=86400
openai-organization:
- heartex
openai-processing-ms:
- '460'
- '398'
openai-version:
- '2020-10-01'
strict-transport-security:
Expand All @@ -92,7 +96,7 @@ interactions:
x-ratelimit-reset-tokens:
- 0s
x-request-id:
- req_80121246ed67b5036574096065af32e9
- req_4f6221dcc9c7a139212ea46990b5c6e2
status:
code: 200
message: OK
Expand Down
91 changes: 48 additions & 43 deletions tests/cassettes/test_stream_inference/test_run_streaming.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ interactions:
host:
- api.openai.com
user-agent:
- OpenAI/Python 1.42.0
- OpenAI/Python 1.47.1
x-stainless-arch:
- arm64
x-stainless-async:
Expand All @@ -26,7 +26,9 @@ interactions:
x-stainless-os:
- MacOS
x-stainless-package-version:
- 1.42.0
- 1.47.1
x-stainless-raw-response:
- 'true'
x-stainless-runtime:
- CPython
x-stainless-runtime-version:
Expand All @@ -36,33 +38,34 @@ interactions:
response:
body:
string: !!binary |
H4sIAAAAAAAAA1SRQW7bMBBF9zrFlJtsrMJWpdrxJsgiQbroMkiBojAoaSQxIWcIcpTUCAzkGr1e
T1JQlp10Q4D/8Q/+H75mAMq0aguqGbQ0ztv8unq+acbHrtV+Qz+ub/d0eW/q6rn6+n28V4vk4PoR
Gzm5PjfsvEUxTEfcBNSCaepqXVSX62q5KifguEWbbL2XvOTcGTJ5sSzKfLnOV5vZPbBpMKot/MwA
AF6nM+WkFn+rLSwXJ8VhjLpHtT0/AlCBbVKUjtFE0SRq8Q4bJkGaon+7cNCyoR5e0NoFyKDpCfY8
foI7fgFd8yjpegUPg5a/b38iMCUhgDPUgnCr91cfhwfsxqhTQRqtnfXDOa3l3geu48zPemfIxGEX
UEemlCwKezXRQwbwa9rK+F9R5QM7LzvhJ6Q08LjgqeDpLz7AzQyFRdt3/UuRzflU3EdBt+sM9Rh8
MMcVdX5XlVh05aasUWWH7B8AAAD//wMAN8RVOzACAAA=
H4sIAAAAAAAAA4xSwWobMRS871e86pKLt+zaDja+hJJLA6GFQjGhlEWW3u7K0eoJ6S2pHQz9jf5e
v6Ro7XgdmkIvAs28GWae9JwBCKPFCoRqJavO2/zDw+31vv9k9tu4LL+2xfr+7vN2v79t1l8etmKS
FLTZouIX1XtFnbfIhtyRVgElY3ItF7PpopwVZTkQHWm0SdZ4zueUd8aZfFpM53mxyMvlSd2SURjF
Cr5lAADPw5lyOo0/xAqKyQvSYYyyQbE6DwGIQDYhQsZoIkvHYjKSihyjG6LfXXWgybgGntDaCXAr
3SPsqH8HH+kJ5IZ6TtcbWLeSf//8FYFcAgJ0xmlg0nJ3c2kesO6jTAVdb+0JP5zTWmp8oE088We8
Ns7EtgooI7mULDJ5MbCHDOD7sJX+VVHhA3WeK6ZHdMmwnB/txPgWF+TyRDKxtCM+m07ecKs0sjQ2
XmxVKKla1KNyfALZa0MXRHbR+e8wb3kfexvX/I/9SCiFnlFXPqA26nXhcSxg+qn/GjvveAgs4i4y
dlVtXIPBB3P8J7WvikVxvamXC1WI7JD9AQAA//8DAN/P8VI1AwAA
headers:
CF-Cache-Status:
- DYNAMIC
CF-RAY:
- 8c0fc57f6d26489d-LIS
- 8e925b5dbbf74895-LIS
Connection:
- keep-alive
Content-Encoding:
- gzip
Content-Type:
- application/json
Date:
- Tue, 10 Sep 2024 13:30:15 GMT
- Wed, 27 Nov 2024 13:10:11 GMT
Server:
- cloudflare
Set-Cookie:
- __cf_bm=6jyHjMeVDBsV2QOhGV.4fQ3DSG.DHWyTNlg_qh.AHo8-1725975015-1.0.1.1-d_p5ctpNSGSs8ufQv10EFXhWNDTymg6.kMM128eJ6Mn42KAGrt41BCGG01CYqnwSLM0F9zqHlepmIFMm7Iju6w;
path=/; expires=Tue, 10-Sep-24 14:00:15 GMT; domain=.api.openai.com; HttpOnly;
- __cf_bm=s_1bk.xRBNtMM7HEMNwZjDq6aB06ekamNvokE.5sIOs-1732713011-1.0.1.1-h8FdrFbPq7Nm6gTo49b_M2Pg8v9Tau3bA2cBNZ3R055F4uRSS_i7TluiYfbDyes4kJk8gUQw.LKz.rlG_JF2ww;
path=/; expires=Wed, 27-Nov-24 13:40:11 GMT; domain=.api.openai.com; HttpOnly;
Secure; SameSite=None
- _cfuvid=YEhEqNmHj89iNxt0.TSP8pkGrJRIO3px8FsOHkdcFg8-1725975015002-0.0.1.1-604800000;
- _cfuvid=np_qT96suW_XXE9HjcKqg4LzjbaSpV0zCu3mhWZrOzU-1732713011428-0.0.1.1-604800000;
path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None
Transfer-Encoding:
- chunked
Expand All @@ -75,38 +78,37 @@ interactions:
openai-organization:
- heartex
openai-processing-ms:
- '374'
- '435'
openai-version:
- '2020-10-01'
strict-transport-security:
- max-age=15552000; includeSubDomains; preload
- max-age=31536000; includeSubDomains; preload
x-ratelimit-limit-requests:
- '30000'
x-ratelimit-limit-tokens:
- '150000000'
x-ratelimit-remaining-requests:
- '29999'
x-ratelimit-remaining-tokens:
- '149999794'
- '149999793'
x-ratelimit-reset-requests:
- 2ms
x-ratelimit-reset-tokens:
- 0s
x-request-id:
- req_7db8a1cb3f38c7fbf47baddf6843be63
- req_2b086571666146a6f3cff4a9d6dfab0f
status:
code: 200
message: OK
- request:
body: '{"messages": [{"role": "user", "content": "Classify sentiment of the input
text: I am happy"}], "model": "gpt-4o-mini", "max_tokens": 200, "seed": 47,
"temperature": 0.0, "tool_choice": {"type": "function", "function": {"name":
"ClassificationResult"}}, "tools": [{"type": "function", "function": {"name":
"ClassificationResult", "description": "Correctly extracted `ClassificationResult`
with all the required parameters with correct types", "parameters": {"properties":
{"output": {"description": "The classification label", "enum": ["positive",
"negative", "neutral"], "title": "Output", "type": "string"}}, "required": ["output"],
"type": "object"}}}]}'
"MyModel"}}, "tools": [{"type": "function", "function": {"name": "MyModel",
"description": "Correctly extracted `MyModel` with all the required parameters
with correct types", "parameters": {"properties": {"output": {"description":
"Choices for text", "enum": ["positive", "negative", "neutral"], "title": "Output",
"type": "string"}}, "required": ["output"], "type": "object"}}}]}'
headers:
accept:
- application/json
Expand All @@ -115,13 +117,13 @@ interactions:
connection:
- keep-alive
content-length:
- '656'
- '609'
content-type:
- application/json
host:
- api.openai.com
user-agent:
- AsyncOpenAI/Python 1.42.0
- AsyncOpenAI/Python 1.47.1
x-stainless-arch:
- arm64
x-stainless-async:
Expand All @@ -131,7 +133,9 @@ interactions:
x-stainless-os:
- MacOS
x-stainless-package-version:
- 1.42.0
- 1.47.1
x-stainless-raw-response:
- 'true'
x-stainless-runtime:
- CPython
x-stainless-runtime-version:
Expand All @@ -141,34 +145,35 @@ interactions:
response:
body:
string: !!binary |
H4sIAAAAAAAAAwAAAP//bFJdj5swEHznV1j7HCogoSS8XXtVpVPTqndqr1JTIccY4p6/ai9RT1H+
e2XgAheVB8va8czsznKKCAFRQ0mAHSgyZWV8kx8/1MfbZrn9/vHxXfo5sw/7J67p3SctvsIiMMz+
N2f4wnrDjLKSozB6gJnjFHlQTYss3xR5kuY9oEzNZaC1FuOViZXQIs6SbBUnRZyuR/bBCMY9lORn
RAghp/4Mfeqa/4WSJIuXiuLe05ZDeXlECDgjQwWo98Ij1QiLCWRGI9ehdd1JOQPQGFkxKuVkPHyn
2X0Ki0pZqbtt/Zi1tw9b+ye5ORTf7Jrd//giZ36D9LPtG2o6zS4hzfBLvbwyIwQ0VT33vQzjNILR
8PKe+07ilQ4hQF3bKa4xzACnHZgObYc7KHdgjRcojnwHZ3hFO0f/u/+aReN403kqx8zG+vmyBGla
68zeX2UKjdDCHyrHqe9nA4/GDt7Bp3eA7tX+wDqjLFZonrgOgpt0kIPpF5vAfMTQIJUzzttobA/8
s0euqkboljvrRL94aGy1Wi/r5aZe5wVE5+gfAAAA//8DAFbyr/sGAwAA
H4sIAAAAAAAAA4xTwYrbMBC9+yvEnOPiJBuS9a0sJUu7WZY2lG6bYhR5bGsrS0Iap01D/r3YTmwn
3UJ9EGLevDczT+NDwBjIFGIGouAkSqvCt893s98/P3xO/Wq/oOVu+fW2+miLh0e7/uJgVDPM9gUF
nVlvhCmtQpJGt7BwyAlr1fF8OpmPp9F43AClSVHVtNxSeGPCUmoZTqLJTRjNw/HixC6MFOghZt8C
xhg7NGfdp07xF8QsGp0jJXrPc4S4S2IMnFF1BLj30hPXBKMeFEYT6rp1XSk1AMgYlQiuVF+4/Q6D
e28WVyqZvHtafnpY31fvX7LH3brQz3ez9f30aVCvld7bpqGs0qIzaYB38fiqGGOgedlwV/tV493o
OoG7vCpRU902HDZgKrIVbSDegDVektzhBo5wQTsGr92/D9xwmFWeq5NNp/ix812Z3Dqz9Vc2Qia1
9EXikPtmHPBkbFu7rtNUgOriycA6U1pKyPxAXQvenp4X+q3qwdkJI0NcDTjn+IVYkiJx2Txot0OC
iwLTntnvEq9SaQZAMBj5715e027Hljr/H/keEAItYZpYh6kUl/P2aQ7rX+5faZ3FTcPg956wTDKp
c3TWyWbhIbNJNI9m22wxFxEEx+APAAAA//8DAKvv9cv+AwAA
headers:
CF-Cache-Status:
- DYNAMIC
CF-RAY:
- 8c0fc5854c1603f6-LIS
- 8e925b6249ff94f4-LIS
Connection:
- keep-alive
Content-Encoding:
- gzip
Content-Type:
- application/json
Date:
- Tue, 10 Sep 2024 13:30:15 GMT
- Wed, 27 Nov 2024 13:10:11 GMT
Server:
- cloudflare
Set-Cookie:
- __cf_bm=80Qq9DZEAAwobsWLqKKXkKHd2kN4XyizsBMG3l6kcBc-1725975015-1.0.1.1-4w1gOtPVyGlthYp_kxGwvCyi26lkYZ_gcrB77wbqd.HNVR5NiCIg2_Ez462BdkQE_kPL5aFfMzYoDZtPfmTrdg;
path=/; expires=Tue, 10-Sep-24 14:00:15 GMT; domain=.api.openai.com; HttpOnly;
- __cf_bm=Rze_tf1qsWwR_8QujkWAs6xDPCjhF1EhOyAHTJRVIS4-1732713011-1.0.1.1-2LaAGNN.HICqWJ6mxxl2CLvL0leU5OPxcevV_E2cIBFl0F0jS_Xe2hHQvI4hUCTgKum4ONWEsP.xZvTYrthTZA;
path=/; expires=Wed, 27-Nov-24 13:40:11 GMT; domain=.api.openai.com; HttpOnly;
Secure; SameSite=None
- _cfuvid=iDPi3T9WoxBP.ET8vIGgTRLDBcPZg7._9XDWqC5OAHs-1725975015715-0.0.1.1-604800000;
- _cfuvid=casFCUelDW.rt7msWDCnalImxZa8O8ek7gps6QicC_g-1732713011934-0.0.1.1-604800000;
path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None
Transfer-Encoding:
- chunked
Expand All @@ -181,11 +186,11 @@ interactions:
openai-organization:
- heartex
openai-processing-ms:
- '143'
- '222'
openai-version:
- '2020-10-01'
strict-transport-security:
- max-age=15552000; includeSubDomains; preload
- max-age=31536000; includeSubDomains; preload
x-ratelimit-limit-requests:
- '30000'
x-ratelimit-limit-tokens:
Expand All @@ -199,7 +204,7 @@ interactions:
x-ratelimit-reset-tokens:
- 0s
x-request-id:
- req_012ec55a485db03c159a4a224cf04517
- req_50b976601cf6c0af3d7fe6944c839fc9
status:
code: 200
message: OK
Expand Down
27 changes: 15 additions & 12 deletions tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,17 +148,20 @@ def test_agent_is_pickleable():
"timeout_ms": 1000,
},
"skills": [
{
"type": "ClassificationSkill",
"name": "ClassificationResult",
"instructions": "",
"input_template": "Classify sentiment of the input text: {input}",
"field_schema": {
"output": {
"type": "string",
"enum": ["positive", "negative", "neutral"],
}
},
{
"name": "label_studio_skill",
"type": "LabelStudioSkill",
"input_template": "Classify sentiment of the input text: {input}",
"label_config": """
<View>
<Text name="text" value="$text" />
<Choices name="output" toName="text">
<Choice value="positive" />
<Choice value="negative" />
<Choice value="neutral" />
</Choices>
</View>
"""
}
],
}
Expand All @@ -168,5 +171,5 @@ def test_agent_is_pickleable():
agent_roundtrip = pickle.loads(agent_pickle)
assert (
agent_json["skills"][0]["input_template"]
== agent_roundtrip.skills["ClassificationResult"].input_template
== agent_roundtrip.skills["label_studio_skill"].input_template
)
Loading

0 comments on commit ed9e277

Please sign in to comment.