diff --git a/examples/quick_chat.py b/examples/quick_chat.py index 35643a35..fac2a0a7 100644 --- a/examples/quick_chat.py +++ b/examples/quick_chat.py @@ -35,13 +35,13 @@ def format_message_history(message_history : List[Tuple[str, str]]) -> str: @ell.lm(model="gpt-4o-mini", temperature=0.3, max_tokens=20) def chat(message_history : List[Tuple[str, str]], *, personality : str): - return [ - ell.system(f"""You are - {personality}. + return [ + ell.system(f"""Here is your description. +{personality}. - Your goal is to come up with a response to a chat. Only respond in one sentence (should be like a text message in informality.) Never use Emojis."""), - ell.user(format_message_history(message_history)), - ] +Your goal is to come up with a response to a chat. Only respond in one sentence (should be like a text message in informality.) Never use Emojis."""), + ell.user(format_message_history(message_history)), + ] diff --git a/poetry.lock b/poetry.lock index e34fb313..da86d491 100644 --- a/poetry.lock +++ b/poetry.lock @@ -52,6 +52,52 @@ tests = ["attrs[tests-no-zope]", "zope-interface"] tests-mypy = ["mypy (>=1.6)", "pytest-mypy-plugins"] tests-no-zope = ["attrs[tests-mypy]", "cloudpickle", "hypothesis", "pympler", "pytest (>=4.3.0)", "pytest-xdist[psutil]"] +[[package]] +name = "black" +version = "24.8.0" +description = "The uncompromising code formatter." +optional = false +python-versions = ">=3.8" +files = [ + {file = "black-24.8.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:09cdeb74d494ec023ded657f7092ba518e8cf78fa8386155e4a03fdcc44679e6"}, + {file = "black-24.8.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:81c6742da39f33b08e791da38410f32e27d632260e599df7245cccee2064afeb"}, + {file = "black-24.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:707a1ca89221bc8a1a64fb5e15ef39cd755633daa672a9db7498d1c19de66a42"}, + {file = "black-24.8.0-cp310-cp310-win_amd64.whl", hash = "sha256:d6417535d99c37cee4091a2f24eb2b6d5ec42b144d50f1f2e436d9fe1916fe1a"}, + {file = "black-24.8.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:fb6e2c0b86bbd43dee042e48059c9ad7830abd5c94b0bc518c0eeec57c3eddc1"}, + {file = "black-24.8.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:837fd281f1908d0076844bc2b801ad2d369c78c45cf800cad7b61686051041af"}, + {file = "black-24.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:62e8730977f0b77998029da7971fa896ceefa2c4c4933fcd593fa599ecbf97a4"}, + {file = "black-24.8.0-cp311-cp311-win_amd64.whl", hash = "sha256:72901b4913cbac8972ad911dc4098d5753704d1f3c56e44ae8dce99eecb0e3af"}, + {file = "black-24.8.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:7c046c1d1eeb7aea9335da62472481d3bbf3fd986e093cffd35f4385c94ae368"}, + {file = "black-24.8.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:649f6d84ccbae73ab767e206772cc2d7a393a001070a4c814a546afd0d423aed"}, + {file = "black-24.8.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2b59b250fdba5f9a9cd9d0ece6e6d993d91ce877d121d161e4698af3eb9c1018"}, + {file = "black-24.8.0-cp312-cp312-win_amd64.whl", hash = "sha256:6e55d30d44bed36593c3163b9bc63bf58b3b30e4611e4d88a0c3c239930ed5b2"}, + {file = "black-24.8.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:505289f17ceda596658ae81b61ebbe2d9b25aa78067035184ed0a9d855d18afd"}, + {file = "black-24.8.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:b19c9ad992c7883ad84c9b22aaa73562a16b819c1d8db7a1a1a49fb7ec13c7d2"}, + {file = "black-24.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1f13f7f386f86f8121d76599114bb8c17b69d962137fc70efe56137727c7047e"}, + {file = "black-24.8.0-cp38-cp38-win_amd64.whl", hash = "sha256:f490dbd59680d809ca31efdae20e634f3fae27fba3ce0ba3208333b713bc3920"}, + {file = "black-24.8.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:eab4dd44ce80dea27dc69db40dab62d4ca96112f87996bca68cd75639aeb2e4c"}, + {file = "black-24.8.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:3c4285573d4897a7610054af5a890bde7c65cb466040c5f0c8b732812d7f0e5e"}, + {file = "black-24.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9e84e33b37be070ba135176c123ae52a51f82306def9f7d063ee302ecab2cf47"}, + {file = "black-24.8.0-cp39-cp39-win_amd64.whl", hash = "sha256:73bbf84ed136e45d451a260c6b73ed674652f90a2b3211d6a35e78054563a9bb"}, + {file = "black-24.8.0-py3-none-any.whl", hash = "sha256:972085c618ee94f402da1af548a4f218c754ea7e5dc70acb168bfaca4c2542ed"}, + {file = "black-24.8.0.tar.gz", hash = "sha256:2500945420b6784c38b9ee885af039f5e7471ef284ab03fa35ecdde4688cd83f"}, +] + +[package.dependencies] +click = ">=8.0.0" +mypy-extensions = ">=0.4.3" +packaging = ">=22.0" +pathspec = ">=0.9.0" +platformdirs = ">=2" +tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} +typing-extensions = {version = ">=4.0.1", markers = "python_version < \"3.11\""} + +[package.extras] +colorama = ["colorama (>=0.4.3)"] +d = ["aiohttp (>=3.7.4)", "aiohttp (>=3.7.4,!=3.9.0)"] +jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"] +uvloop = ["uvloop (>=0.15.2)"] + [[package]] name = "cattrs" version = "23.2.3" @@ -647,6 +693,17 @@ files = [ {file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"}, ] +[[package]] +name = "mypy-extensions" +version = "1.0.0" +description = "Type system extensions for programs checked with the mypy type checker." +optional = false +python-versions = ">=3.5" +files = [ + {file = "mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d"}, + {file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"}, +] + [[package]] name = "numpy" version = "2.0.1" @@ -735,6 +792,33 @@ files = [ {file = "packaging-24.1.tar.gz", hash = "sha256:026ed72c8ed3fcce5bf8950572258698927fd1dbda10a5e981cdf0ac37f4f002"}, ] +[[package]] +name = "pathspec" +version = "0.12.1" +description = "Utility library for gitignore style pattern matching of file paths." +optional = false +python-versions = ">=3.8" +files = [ + {file = "pathspec-0.12.1-py3-none-any.whl", hash = "sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08"}, + {file = "pathspec-0.12.1.tar.gz", hash = "sha256:a482d51503a1ab33b1c67a6c3813a26953dbdc71c31dacaef9a838c4e29f5712"}, +] + +[[package]] +name = "platformdirs" +version = "4.2.2" +description = "A small Python package for determining appropriate platform-specific dirs, e.g. a `user data dir`." +optional = false +python-versions = ">=3.8" +files = [ + {file = "platformdirs-4.2.2-py3-none-any.whl", hash = "sha256:2d7a1657e36a80ea911db832a8a6ece5ee53d8de21edd5cc5879af6530b1bfee"}, + {file = "platformdirs-4.2.2.tar.gz", hash = "sha256:38b7b51f512eed9e84a22788b4bce1de17c0adb134d6becb09836e37d8654cd3"}, +] + +[package.extras] +docs = ["furo (>=2023.9.10)", "proselint (>=0.13)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.25.2)"] +test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)"] +type = ["mypy (>=1.8)"] + [[package]] name = "pluggy" version = "1.5.0" @@ -1499,4 +1583,4 @@ npm-install = [] [metadata] lock-version = "2.0" python-versions = ">=3.9" -content-hash = "5877feb54330301831212a72804ed9af0ccbe7bbebf5504e222c7f6d856ceea6" +content-hash = "0f7033afc73feeffedbe46b95f08e51723520a396bdeb8e86326b5ee88110fed" diff --git a/pyproject.toml b/pyproject.toml index 114d16ef..15c42dae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,7 @@ requests = "^2.32.3" typing-extensions = "^4.12.2" +black = "^24.8.0" [tool.poetry.group.dev.dependencies] pytest = "^8.3.2" diff --git a/src/ell/util/closure.py b/src/ell/util/closure.py index a7416d48..2919ce8c 100644 --- a/src/ell/util/closure.py +++ b/src/ell/util/closure.py @@ -40,6 +40,7 @@ def xD(): import importlib.util import re from collections import deque +import black DELIM = "$$$$$$$$$$$$$$$$$$$$$$$$$" FORBIDDEN_NAMES = ["ell", "lstr"] @@ -94,12 +95,26 @@ def lexical_closure( CLOSURE_SOURCE[hash(func)] = dirty_src dsrc = _clean_src(dirty_src_without_func) + + # Format the sorce and dsrc soruce using Black + source = _format_source(source) + dsrc = _format_source(dsrc) + fn_hash = _generate_function_hash(source, dsrc, func.__qualname__) _update_ell_func(outer_ell_func, source, dsrc, globals_and_frees['globals'], globals_and_frees['frees'], fn_hash, uses) return (dirty_src, (source, dsrc), ({fn_hash} if not initial_call and hasattr(outer_ell_func, "__ell_func__") else uses)) + +def _format_source(source: str) -> str: + """Format the source code using Black.""" + try: + return black.format_str(source, mode=black.Mode()) + except: + # If Black formatting fails, return the original source + return source + def _get_globals_and_frees(func: Callable) -> Dict[str, Dict]: """Get global and free variables for a function.""" globals_dict = collections.OrderedDict(dill.detect.globalvars(func)) @@ -331,7 +346,10 @@ def get_referenced_names(code: str, module_name: str): def lexically_closured_source(func): _, fnclosure, uses = lexical_closure(func, initial_call=True, recursion_stack=[]) - return fnclosure, uses + source, dsrc = fnclosure + formatted_source = _format_source(source) + formatted_dsrc = _format_source(dsrc) + return (formatted_source, formatted_dsrc), uses import ast diff --git a/tests/test_closure.py b/tests/test_closure.py index 0c587679..3712ee6a 100644 --- a/tests/test_closure.py +++ b/tests/test_closure.py @@ -78,13 +78,13 @@ def test_get_referenced_names(): def test_is_function_called(): code = """ - def foo(): - pass - - def bar(): - foo() - - x = 1 + 2 +def foo(): + pass + +def bar(): + foo() + +x = 1 + 2 """ assert is_function_called("foo", code) assert not is_function_called("bar", code)