Skip to content

Commit

Permalink
Add a coverage bot example
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Sep 25, 2023
1 parent cae6a53 commit 4feef0e
Show file tree
Hide file tree
Showing 2 changed files with 207 additions and 0 deletions.
206 changes: 206 additions & 0 deletions examples/cover_bot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
"""A bot that uses an LLM to write unit tests under coverage conditions.
This requires the following extra packages:
- coverage
- openai
- tiktoken
"""
import importlib
import sys
import time
from textwrap import dedent
from typing import Callable, List

import coverage

import outlines.models as models
import outlines.text as text


def get_missing_coverage_lines(code: str):
cov = coverage.Coverage(branch=True)
cov.erase()

# Since `target_code` is now a Python function, consider evaluating the code
# instead of creating a python file and loading it.
modname = "test_mod"
with open(modname + ".py", "wb") as f:
f.write(code.encode("utf-8"))

cov.start()
try:
mod = sys.modules.get(modname, None)
if mod:
mod = importlib.reload(mod)
else:
mod = importlib.import_module(modname)
finally:
cov.stop()

analysis = cov._analyze(mod)

return sorted(analysis.missing)


def collect_test_functions(string):
"""Collects the names of the test functions in a given string.
Args:
string: The string to collect the test functions from.
Returns:
A list of the names of the test functions.
(This was generated by Bard!)
"""

# Split the string into lines.
lines = string.splitlines()

# Create a list to store the test function names.
test_function_names = []

# Iterate over the lines, looking for lines that start with the `def` keyword.
for line in lines:
if line.startswith("def"):
# Get the name of the test function from the line.
test_function_name = line.split("def")[1].split("(")[0].strip()

# Add the test function name to the list.
test_function_names.append(test_function_name)

# Return the list of test function names.
return test_function_names


@text.prompt
def construct_prompt(target_code: Callable):
"""The Python module code is as follows:
{{ target_code | source }}
Print only the code for a completed version of the following Python function named \
`test_some_function` that attains full coverage over the Python module code above:
def test_some_function():
"""


@text.prompt
def construct_prompt_test_code(target_code: Callable, test_code, lines: List[int]):
"""The Python module code is as follows:
{{ target_code | source }}
Print only the code for a completed version of the following Python function named \
`test_some_function` that attains coverage for lines {{ lines | join(" and ") }} in the Python \
module code above:
{{ test_code }}
"""


def query(target_code, test_code, lines):
"""Sample a query completion."""
if test_code == "":
prompt = construct_prompt(target_code)
else:
prompt = construct_prompt_test_code(target_code, test_code, lines)

completer = models.text_completion.openai("gpt-3.5-turbo")

# Naive way to prevent ourselves from spamming the service
# TODO: Do we really need this?
# Apparently, the limit is 3 / min.
time.sleep(2)
response = completer(prompt)

return prompt, response


def get_missing_lines_for_completion(response):
"""Get the missing lines for the given completion."""
import inspect

def create_call(name):
return dedent(
f"""
try:
{name}()
except AssertionError:
pass"""
)

test_function_names = collect_test_functions(response)
run_statements = "\n".join([create_call(fname) for fname in test_function_names])

c1 = f"""
{inspect.getsource(target_code)}
{response}
{run_statements}
"""

# Get missing coverage lines on the generated code,
# remove the lines that correspond to the tests.
lines = get_missing_coverage_lines(c1)
lines = [str(ln) for ln in lines if ln < inspect.getsource(target_code).count("\n")]
return lines


def some_function(x, y):
if x < 0:
z = y - x
else:
z = y + x

if z > y:
return True

return False


def some_other_function(x: int, y: int) -> bool:
z = 0
for i in range(x):
if x < 3:
z = y - x
else:
z = y + x

if z > y:
return True

return False


target_code_examples = [some_function, some_other_function]
dialogs = []

for target_code in target_code_examples:
target_dialog = []
test_code = ""
lines: list = []
for i in range(5):
prompt, test_code = query(target_code, test_code, lines)

print(f"\nATTEMPT: {i}\n")
print(f"PROMPT:\n{prompt.strip()}\n")
print(f"RESPONSE:\n{test_code.strip()}\n")

target_dialog.append((target_code, prompt, test_code))

try:
lines = get_missing_lines_for_completion(test_code)
except SyntaxError:
pass

if len(lines) == 0:
break

dialogs.append([target_code, target_dialog])

i = 0
print(dialogs[i][0])
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ module = [
"transformers.*",
"lark.*",
"interegular.*",
"coverage.*",
]
ignore_missing_imports = true

Expand Down

0 comments on commit 4feef0e

Please sign in to comment.