Skip to content

Commit

Permalink
this code is horrible time to clean it up
Browse files Browse the repository at this point in the history
  • Loading branch information
MadcowD committed Aug 22, 2024
1 parent c6d4903 commit 5f9741b
Showing 1 changed file with 6 additions and 9 deletions.
15 changes: 6 additions & 9 deletions src/ell/decorators/lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from functools import wraps
from typing import Optional, List, Callable

def lm(model: str, client: Optional[openai.Client] = None, exempt_from_tracking=False, **lm_kwargs):
def lm(model: str, client: Optional[openai.Client] = None, exempt_from_tracking=False, tools: Optional[List[Callable]] = None, **lm_kwargs):
"""
Defines a basic language model program (a parameterization of an existing foundation model using a particular prompt.)
Expand All @@ -24,9 +24,6 @@ def parameterized_lm_decorator(
prompt: LMP,
) -> InvocableLM:
color = compute_color(prompt)
_under_fn = prompt


_warnings(model, prompt, default_client_from_decorator)


Expand All @@ -35,7 +32,7 @@ def model_call(
*fn_args,
_invocation_origin : str = None,
client: Optional[openai.Client] = None,
lm_params: LMPParams = {},
lm_params: Optional[LMPParams] = {},
invocation_kwargs=False,
**fn_kwargs,
) -> _lstr_generic:
Expand All @@ -47,16 +44,16 @@ def model_call(
if config.verbose and not exempt_from_tracking: model_usage_logger_pre(prompt, fn_args, fn_kwargs, "notimplemented", messages, color)

final_lm_kwargs = _get_lm_kwargs(lm_kwargs, lm_params)
_invocation_kwargs = dict(model=model, messages=messages, lm_kwargs=final_lm_kwargs, client=client or default_client_from_decorator)
api_params = dict(model=model, messages=messages, lm_kwargs=final_lm_kwargs, client=client or default_client_from_decorator)

tracked_str, metadata = _call(**_invocation_kwargs, _invocation_origin=_invocation_origin, exempt_from_tracking=exempt_from_tracking, _logging_color=color, name=prompt.__name__, tools=tools)
tracked_str, metadata = _call(**api_params, _invocation_origin=_invocation_origin, exempt_from_tracking=exempt_from_tracking, _logging_color=color, name=prompt.__name__, tools=tools)


return tracked_str, _invocation_kwargs, metadata
return tracked_str, api_params, metadata

# TODO: # we'll deal with type safety here later
model_call.__ell_lm_kwargs__ = lm_kwargs
model_call.__ell_func__ = _under_fn
model_call.__ell_func__ = prompt
model_call.__ell_type__ = LMPType.LM
model_call.__ell_exempt_from_tracking = exempt_from_tracking

Expand Down

0 comments on commit 5f9741b

Please sign in to comment.