From 5f9741b7e0072a81d39b5d1b79464acaf412a97c Mon Sep 17 00:00:00 2001
From: William Guss
Date: Wed, 21 Aug 2024 18:46:06 -0700
Subject: [PATCH] this code is horrible time to clean it up

---
 src/ell/decorators/lm.py | 15 ++++++---------
 1 file changed, 6 insertions(+), 9 deletions(-)

diff --git a/src/ell/decorators/lm.py b/src/ell/decorators/lm.py
index a25a4480..0b91ccba 100644
--- a/src/ell/decorators/lm.py
+++ b/src/ell/decorators/lm.py
@@ -11,7 +11,7 @@
 from functools import wraps
 from typing import Optional, List, Callable
 
-def lm(model: str, client: Optional[openai.Client] = None, exempt_from_tracking=False, **lm_kwargs):
+def lm(model: str, client: Optional[openai.Client] = None, exempt_from_tracking=False, tools: Optional[List[Callable]] = None, **lm_kwargs):
     """
     Defines a basic language model program (a parameterization of an existing foundation model using a particular prompt.)
 
@@ -24,9 +24,6 @@ def parameterized_lm_decorator(
         prompt: LMP,
     ) -> InvocableLM:
         color = compute_color(prompt)
-        _under_fn = prompt
-
-        _warnings(model, prompt, default_client_from_decorator)
 
@@ -35,7 +32,7 @@ def model_call(
             *fn_args,
             _invocation_origin : str = None,
             client: Optional[openai.Client] = None,
-            lm_params: LMPParams = {},
+            lm_params: Optional[LMPParams] = {},
             invocation_kwargs=False,
             **fn_kwargs,
         ) -> _lstr_generic:
@@ -47,16 +44,16 @@ def model_call(
             if config.verbose and not exempt_from_tracking: model_usage_logger_pre(prompt, fn_args, fn_kwargs, "notimplemented", messages, color)
             final_lm_kwargs = _get_lm_kwargs(lm_kwargs, lm_params)
-            _invocation_kwargs = dict(model=model, messages=messages, lm_kwargs=final_lm_kwargs, client=client or default_client_from_decorator)
+            api_params = dict(model=model, messages=messages, lm_kwargs=final_lm_kwargs, client=client or default_client_from_decorator)
 
-            tracked_str, metadata = _call(**_invocation_kwargs, _invocation_origin=_invocation_origin, exempt_from_tracking=exempt_from_tracking, _logging_color=color, name=prompt.__name__, tools=tools)
+            tracked_str, metadata = _call(**api_params, _invocation_origin=_invocation_origin, exempt_from_tracking=exempt_from_tracking, _logging_color=color, name=prompt.__name__, tools=tools)
 
-            return tracked_str, _invocation_kwargs, metadata
+            return tracked_str, api_params, metadata
 
         # TODO:
         # we'll deal with type safety here later
         model_call.__ell_lm_kwargs__ = lm_kwargs
-        model_call.__ell_func__ = _under_fn
+        model_call.__ell_func__ = prompt
         model_call.__ell_type__ = LMPType.LM
         model_call.__ell_exempt_from_tracking = exempt_from_tracking
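
Note: a minimal usage sketch of the decorator after this patch, assuming it is exposed as ell.lm, that extra keyword arguments flow through lm_kwargs, and that the usual docstring-as-system-prompt / return-value-as-user-message convention applies. The get_weather helper, model name, and temperature are placeholders for illustration of the new tools parameter, not part of this commit.

    import ell

    def get_weather(city: str) -> str:
        """Hypothetical tool: return a canned weather report for `city`."""
        return f"It is 22 degrees and sunny in {city}."

    @ell.lm(model="gpt-4o", temperature=0.1, tools=[get_weather])
    def weather_assistant(city: str) -> str:
        """You are a concise weather assistant."""      # system prompt (assumed convention)
        return f"What is the weather like in {city}?"   # user message

    # Calling the decorated function runs the tracked model call; the tools
    # list is forwarded through api_params into _call, as in the patch above.
    # report = weather_assistant("Tokyo")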