Skip to content

Commit

Permalink
simplify params serialization, add tool reference
Browse files Browse the repository at this point in the history
  • Loading branch information
alex-dixon committed Nov 18, 2024
1 parent 26d1eeb commit 1276eb6
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 16 deletions.
2 changes: 1 addition & 1 deletion src/ell/providers/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def _content_block_to_anthropic_format(content_block: ContentBlock):
type="tool_use",
id=tool_call.tool_call_id,
name=tool_call.tool.__name__,
input=tool_call.params.model_dump() if isinstance(tool_call.params, BaseModel) else tool_call.params,
input=tool_call.serialize_params(),
)
elif (tool_result := content_block.tool_result):
return dict(
Expand Down
2 changes: 1 addition & 1 deletion src/ell/providers/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def content_block_to_bedrock_format(content_block: ContentBlock) -> Dict[str, An
"toolUse": {
"toolUseId": content_block.tool_call.tool_call_id,
"name": content_block.tool_call.tool.__name__,
"input": content_block.tool_call.params.model_dump() if isinstance(content_block.tool_call.params, BaseModel) else content_block.tool_call.params,
"input": content_block.tool_call.serialize_params(),
}
}
elif content_block.tool_result:
Expand Down
2 changes: 1 addition & 1 deletion src/ell/providers/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def translate_to_provider(self, ell_call : EllCallParams) -> Dict[str, Any]:
type="function",
function=dict(
name=tool_call.tool.__name__,
arguments=tool_call.params.model_dump_json() if isinstance(tool_call.params,BaseModel) else json.dumps(tool_call.params, ensure_ascii=False)
arguments=json.dumps(tool_call.serialize_params(), ensure_ascii=False)
)
) for tool_call in tool_calls ],
role="assistant",
Expand Down
38 changes: 25 additions & 13 deletions src/ell/types/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,15 @@ def text_only(self) -> str:
def __repr__(self):
return f"{self.__class__.__name__}(tool_call_id={self.tool_call_id}, result={_content_to_text(self.result)})"

class ToolReference(BaseModel):
"""A reference to an invocable tool"""
fqn: str = Field(description="The fully qualified name of the tool")
hash: str = Field(description="The hash of the tool and its dependencies")

class ToolCall(BaseModel):
tool: Union[InvocableTool, str] = Field(description="The tool function to call or a reference to it when serialized")
tool: Union[InvocableTool, ToolReference] = Field(description="The tool function to call or a reference to it when serialized")
tool_call_id: Optional[_lstr_generic] = Field(default=None)
params: Union[Dict[str, Any], BaseModel]
params: Union[Dict[str, Any], BaseModel] = Field(description="Arguments for the tool call provided by the model.")

def __init__(self, tool, params: Optional[Union[BaseModel, Dict[str, Any]]], tool_call_id: Optional[_lstr_generic]=None):
if (not isinstance(params, BaseModel)) and isinstance(tool, FunctionType) and hasattr(tool, '__ell_params_model__'):
Expand All @@ -53,18 +58,25 @@ def __init__(self, tool, params: Optional[Union[BaseModel, Dict[str, Any]]], too

super().__init__(tool=tool, tool_call_id=tool_call_id, params=params)

# TODO. This should reference a tool fqn + version if possible
# ell should have an InvocableTool with __ properties that have this info at serialization time
@field_serializer('tool')
def serialize_tool(self, tool: InvocableTool, _info):
return tool.__name__ if hasattr(tool, '__name__') else str(tool)
def serialize_tool(self, tool: Union[InvocableTool, ToolReference], _info):
if isinstance(tool, ToolReference):
return tool
return ToolReference(
# todo(alex). add the value of fqn we want to standardize on to all lmps so we don't keep using qualname
fqn=tool.__qualname__,
hash=getattr(tool, '__ell_hash__', 'unknown')
)

@field_serializer('params')
def serialize_params(self, params: Union[Dict[str,Any],BaseModel], _info):
def _serialize_params(self, params: Union[Dict[str, Any], BaseModel]) -> Dict[str, Any]:
if isinstance(params, dict):
return params
return params.model_dump(exclude_none=True, exclude_unset=True)

def serialize_params(self) -> Dict[str, Any]:
return self._serialize_params(self.params)

@field_serializer('tool_call_id')
def serialize_tool_call_id(self, tool_call_id: _lstr_generic):
if tool_call_id is None:
Expand All @@ -76,19 +88,19 @@ def serialize_tool_call_id(self, tool_call_id: _lstr_generic):

def __call__(self, **kwargs):
assert not kwargs, "Unexpected arguments provided. Calling a tool uses the params provided in the ToolCall."
assert not isinstance(self.tool, str), "ToolCall.tool is a string. Tools are not invocable once serialized."
assert not isinstance(self.tool, ToolReference), f"Tools are not invocable once serialized. ToolCall.tool is a ToolReference: {self.tool}"

# XXX: TODO: MOVE TRACKING CODE TO _TRACK AND OUT OF HERE AND API.
return self.tool(**self.params.model_dump())
return self.tool(**self.serialize_params())

# XXX: Deprecate in 0.1.0
def call_and_collect_as_message_block(self):
raise DeprecationWarning("call_and_collect_as_message_block is deprecated. Use collect_as_content_block instead.")

def call_and_collect_as_content_block(self):
if isinstance(self.tool, str):
raise ValueError("Cannot call a tool that is a string reference.")
res = self.tool(**(self.params.model_dump() if isinstance(self.params, BaseModel) else self.params),
if isinstance(self.tool, ToolReference):
raise ValueError(f"Cannot call a tool that is a ToolReference: {self.tool}")
res = self.tool(**self.serialize_params(),
_tool_call_id=self.tool_call_id)
return ContentBlock(tool_result=res)

Expand Down Expand Up @@ -203,7 +215,7 @@ def type(self):

@property
def content(self):
return getattr(self, self.type)
return getattr(self, self.type) # type: ignore

@classmethod
def coerce(cls, content: AnyContent) -> "ContentBlock":
Expand Down
4 changes: 4 additions & 0 deletions tests/api/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,8 @@ def my_sample_tool(args: MySampleToolInput = Field(


def test_invocation_json_round_trip():
# pretend it's being tracked
my_sample_tool.__ell_hash__ = "lmp-123"
invocation_id = "invocation-" + uuid4().hex
tool_call = ToolCall(
tool=my_sample_tool,
Expand Down Expand Up @@ -281,6 +283,8 @@ def test_write_invocation_tool_call(async_sqlite_serializer: AsyncSQLiteSerializ
print(response.json())
raise e

# pretend it's being tracked
my_sample_tool.__ell_hash__ = "lmp-123"
invocation_id = "invocation-" + uuid4().hex
tool_call = ToolCall(
tool=my_sample_tool,
Expand Down

0 comments on commit 1276eb6

Please sign in to comment.