-
Notifications
You must be signed in to change notification settings - Fork 511
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
cae6a53
commit 4feef0e
Showing 2 changed files with 207 additions and 0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,206 @@ | ||
"""A bot that uses an LLM to write unit tests under coverage conditions. | ||
This requires the following extra packages: | ||
- coverage | ||
- openai | ||
- tiktoken | ||
""" | ||
import importlib | ||
import sys | ||
import time | ||
from textwrap import dedent | ||
from typing import Callable, List | ||
|
||
import coverage | ||
|
||
import outlines.models as models | ||
import outlines.text as text | ||
|
||
|
||
def get_missing_coverage_lines(code: str):
    """Return the lines of ``code`` that are not executed when it is imported.

    The source is written to ``test_mod.py`` in the current working directory,
    imported (or reloaded when ``test_mod`` is already in ``sys.modules``)
    while branch coverage is being recorded, and then analysed.

    Args:
        code: Source text of the module to measure.

    Returns:
        A sorted list of the line numbers that coverage reports as missing.

    Raises:
        Any exception raised while importing ``code`` (for example
        ``SyntaxError`` for invalid source); coverage is stopped first.
    """
    cov = coverage.Coverage(branch=True)
    cov.erase()

    # Since `target_code` is now a Python function, consider evaluating the code
    # instead of creating a python file and loading it.
    modname = "test_mod"
    with open(modname + ".py", "wb") as f:
        f.write(code.encode("utf-8"))

    # The file may have been created after the interpreter started, so the
    # finders' cached directory listings must be refreshed before the import
    # system is guaranteed to notice it.
    importlib.invalidate_caches()

    cov.start()
    try:
        mod = sys.modules.get(modname)
        if mod is not None:
            mod = importlib.reload(mod)
        else:
            mod = importlib.import_module(modname)
    finally:
        cov.stop()

    # `analysis2` is the public counterpart of the private `_analyze` helper;
    # the fourth element of its result holds the missing line numbers.
    missing = cov.analysis2(mod)[3]

    return sorted(missing)
|
||
|
||
def collect_test_functions(string): | ||
"""Collects the names of the test functions in a given string. | ||
Args: | ||
string: The string to collect the test functions from. | ||
Returns: | ||
A list of the names of the test functions. | ||
(This was generated by Bard!) | ||
""" | ||
|
||
# Split the string into lines. | ||
lines = string.splitlines() | ||
|
||
# Create a list to store the test function names. | ||
test_function_names = [] | ||
|
||
# Iterate over the lines, looking for lines that start with the `def` keyword. | ||
for line in lines: | ||
if line.startswith("def"): | ||
# Get the name of the test function from the line. | ||
test_function_name = line.split("def")[1].split("(")[0].strip() | ||
|
||
# Add the test function name to the list. | ||
test_function_names.append(test_function_name) | ||
|
||
# Return the list of test function names. | ||
return test_function_names | ||
|
||
|
||
# Prompt used by `query` for the first attempt, when no test code exists yet.
# The docstring below is not documentation: it is the template text handed to
# `text.prompt`, so editing it changes what is sent to the model.
# `{{ target_code | source }}` is the placeholder for the target function.
@text.prompt
def construct_prompt(target_code: Callable):
    """The Python module code is as follows:
    {{ target_code | source }}
    Print only the code for a completed version of the following Python function named \
    `test_some_function` that attains full coverage over the Python module code above:
    def test_some_function():
    """
|
||
|
||
# Prompt used by `query` for follow-up attempts, when earlier test code exists.
# The docstring below is not documentation: it is the template text handed to
# `text.prompt`, so editing it changes what is sent to the model.
# `lines` holds the still-uncovered line numbers, joined with " and ";
# `test_code` is the previous completion that the model is asked to improve.
@text.prompt
def construct_prompt_test_code(target_code: Callable, test_code, lines: List[int]):
    """The Python module code is as follows:
    {{ target_code | source }}
    Print only the code for a completed version of the following Python function named \
    `test_some_function` that attains coverage for lines {{ lines | join(" and ") }} in the Python \
    module code above:
    {{ test_code }}
    """
|
||
|
||
def query(target_code, test_code, lines):
    """Build the prompt for the current state and sample one completion.

    Returns the pair ``(prompt, response)``.
    """
    # With no earlier test code we ask for a test from scratch; otherwise we
    # ask for the existing test to be extended to the uncovered lines.
    prompt = (
        construct_prompt_test_code(target_code, test_code, lines)
        if test_code != ""
        else construct_prompt(target_code)
    )

    complete = models.text_completion.openai("gpt-3.5-turbo")

    # Naive way to prevent ourselves from spamming the service
    # TODO: Do we really need this?
    # Apparently, the limit is 3 / min.
    time.sleep(2)

    return prompt, complete(prompt)
|
||
|
||
def get_missing_lines_for_completion(response):
    """Get the missing lines for the given completion.

    A throw-away module is assembled from the source of the module-level
    ``target_code`` function, the completion, and one guarded call per
    top-level function found in the completion.  That module is then run
    under coverage.

    Note that ``target_code`` is not a parameter: it is read from the module
    globals, so it must be bound before this function is called.

    Args:
        response: Source code of the generated tests.

    Returns:
        The uncovered line numbers of the target function, as strings,
        counted from the first line of the target function's source.
    """
    import inspect

    def create_call(name):
        return dedent(
            f"""
            try:
                {name}()
            except AssertionError:
                pass"""
        )

    test_function_names = collect_test_functions(response)
    run_statements = "\n".join([create_call(fname) for fname in test_function_names])

    target_source = inspect.getsource(target_code)
    if not target_source.endswith("\n"):
        target_source += "\n"
    num_target_lines = target_source.count("\n")

    # The target source goes first, with nothing in front of it, so that line
    # numbers in the generated module equal line numbers in the target source
    # (which is what the prompt shows to the model).
    c1 = f"{target_source}\n{response}\n{run_statements}\n"

    # Get missing coverage lines on the generated code,
    # remove the lines that correspond to the tests.  The bound is inclusive
    # so that the last line of the target is not dropped.
    lines = get_missing_coverage_lines(c1)
    return [str(ln) for ln in lines if ln <= num_target_lines]
|
||
|
||
# Example target for the test-writing loop (listed in `target_code_examples`).
# Its source text is read with `inspect.getsource` and measured line by line,
# so the layout is part of the example: keep notes out of the body, since
# adding lines there shifts the reported line numbers.
def some_function(x, y):
    if x < 0:
        z = y - x
    else:
        z = y + x

    if z > y:
        return True

    return False
|
||
|
||
# Second example target for the test-writing loop (listed in
# `target_code_examples`).  Its source text is read with `inspect.getsource`
# and measured line by line, so the layout is part of the example: keep notes
# out of the body, since adding lines there shifts the reported line numbers.
# When `x <= 0` the loop body never runs and `z` keeps its initial value 0.
def some_other_function(x: int, y: int) -> bool:
    z = 0
    for i in range(x):
        if x < 3:
            z = y - x
        else:
            z = y + x

    if z > y:
        return True

    return False
|
||
|
||
# Driver: for every example function, ask the model for a test up to five
# times, feeding back the still-uncovered lines after each attempt.
#
# NOTE: this loop deliberately stays at module level, because
# `get_missing_lines_for_completion` looks up the loop variable `target_code`
# as a global.
target_code_examples = [some_function, some_other_function]
dialogs = []

for target_code in target_code_examples:
    target_dialog = []
    test_code = ""
    lines: list = []
    for i in range(5):
        prompt, test_code = query(target_code, test_code, lines)

        print(f"\nATTEMPT: {i}\n")
        print(f"PROMPT:\n{prompt.strip()}\n")
        print(f"RESPONSE:\n{test_code.strip()}\n")

        target_dialog.append((target_code, prompt, test_code))

        try:
            lines = get_missing_lines_for_completion(test_code)
        except SyntaxError as err:
            # An unparsable completion tells us nothing about coverage.  It
            # must not be mistaken for "no missing lines" (which would end
            # the loop), so discard it and start again from the initial
            # prompt on the next attempt.
            print(f"SYNTAX ERROR in generated test: {err}\n")
            test_code = ""
            lines = []
            continue

        if not lines:
            break

    dialogs.append([target_code, target_dialog])

i = 0
print(dialogs[i][0])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters