diff --git a/tests/chain_test.py b/tests/chain_test.py new file mode 100644 index 00000000..d654dde7 --- /dev/null +++ b/tests/chain_test.py @@ -0,0 +1,37 @@ +from codeinterpreterapi.chains import remove_download_link, get_file_modifications + + +def test_remove_download_link() -> None: + example = ( + "I have created the plot to your dataset.\n\n" + "Link to the file [here](sandbox:/plot.png)." + ) + assert ( + remove_download_link(example).formatted_response.strip() + == "I have created the plot to your dataset." + ) + + +def test_get_file_modifications() -> None: + base_code = """ + import matplotlib.pyplot as plt + + x = list(range(1, 11)) + y = [29, 39, 23, 32, 4, 43, 43, 23, 43, 77] + + plt.plot(x, y, marker='o') + plt.xlabel('Index') + plt.ylabel('Value') + plt.title('Data Plot') + """ + code_with_mod = base_code + "\nplt.savefig('plot.png')" + + code_no_mod = base_code + "\nplt.show()" + + assert get_file_modifications(code_with_mod).modifications == ["plot.png"] + assert get_file_modifications(code_no_mod).modifications == [] + + +if __name__ == "__main__": + # test_remove_download_link() + test_get_file_modifications() diff --git a/tests/general_test.py b/tests/general_test.py index 1880203e..80a6205c 100644 --- a/tests/general_test.py +++ b/tests/general_test.py @@ -40,6 +40,15 @@ def run_sync(session: CodeInterpreterSession) -> bool: .name ) + assert ( + ".png" + in session.generate_response( + "Plot the current stock price of Tesla.", + ) + .files[0] + .name + ) + finally: assert session.stop() == "stopped" @@ -71,6 +80,17 @@ async def run_async(session: CodeInterpreterSession) -> bool: .name ) + assert ( + ".jpeg" + in ( + await session.agenerate_response( + "Plot the current stock price of Tesla.", + ) + ) + .files[0] + .name + ) + finally: assert await session.astop() == "stopped"