diff --git a/src/codeinterpreterapi/schema.py b/src/codeinterpreterapi/schema.py index 79f4654f..4dd3638b 100644 --- a/src/codeinterpreterapi/schema.py +++ b/src/codeinterpreterapi/schema.py @@ -88,6 +88,7 @@ def __repr__(self) -> str: class CodeInput(BaseModel): + filename: str code: str diff --git a/src/codeinterpreterapi/tools/python.py b/src/codeinterpreterapi/tools/python.py index 986929ca..98473657 100644 --- a/src/codeinterpreterapi/tools/python.py +++ b/src/codeinterpreterapi/tools/python.py @@ -1,4 +1,5 @@ import base64 +import os import re import subprocess from io import BytesIO @@ -13,6 +14,9 @@ from codeinterpreterapi.schema import CodeInput, File from codeinterpreterapi.utils.file_util import FileUtil +CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) +INVOKE_TASKS_DIR = os.path.abspath(os.path.join(CURRENT_DIR, "../invoke_tasks")) + class PythonTools: def __init__(self, ci_params: CodeInterpreterParams): @@ -46,14 +50,13 @@ def get_tools_python(cls, ci_params: CodeInterpreterParams) -> None: return tools - def _get_handler_local_command(self, code: str): - python_file_path = FileUtil.write_python_file(code) - command = f"cd src/codeinterpreterapi/invoke_tasks && invoke -c python run-code-file '{python_file_path}'" + def _get_handler_local_command(self, filename: str, code: str): + python_file_path = FileUtil.write_python_file(filename, code) + command = f"cd {INVOKE_TASKS_DIR} && invoke -c python run-code-file '{python_file_path}'" return command - def _run_handler_local(self, code: str): - print("_run_handler_local code=", code) - command = self._get_handler_local_command(code) + def _run_handler_local(self, filename: str, code: str): + command = self._get_handler_local_command(filename, code) try: output_content = subprocess.check_output(command, shell=True, universal_newlines=True) self.code_log.append((code, output_content)) @@ -62,9 +65,9 @@ def _run_handler_local(self, code: str): print(f"An error occurred: {e}") return None - async def _arun_handler_local(self, code: str): + async def _arun_handler_local(self, filename: str, code: str): print("_arun_handler_local code=", code) - command = self._get_handler_local_command(code) + command = self._get_handler_local_command(filename, code) try: output_content = await subprocess.check_output(command, shell=True, universal_newlines=True) self.code_log.append((code, output_content)) @@ -73,11 +76,11 @@ async def _arun_handler_local(self, code: str): print(f"An error occurred: {e}") return None - def _run_handler(self, code: str) -> str: + def _run_handler(self, filename: str, code: str) -> str: """Run code in container and send the output to the user""" self.show_code(code) if self.ci_params.codebox is None: - return self._run_handler_local(code) + return self._run_handler_local(filename, code) output: CodeBoxOutput = self.ci_params.codebox.run(code) self.code_log.append((code, output.content)) @@ -118,7 +121,7 @@ def _run_handler(self, code: str) -> str: return output.content - async def _arun_handler(self, code: str) -> str: + async def _arun_handler(self, filename: str, code: str) -> str: """Run code in container and send the output to the user""" await self.ashow_code(code) if self.ci_params.codebox is None: @@ -179,7 +182,7 @@ def test(): ci_params = CodeInterpreterParams.get_test_params(llm=llm, llm_tools=llm_tools, runnable_config=runnable_config) tools_instance = PythonTools(ci_params=ci_params) test_code = "print('test output')" - result = tools_instance._run_handler(test_code) + result = tools_instance._run_handler("main.py", test_code) print("result=", result) assert "test output" in result diff --git a/src/codeinterpreterapi/utils/file_util.py b/src/codeinterpreterapi/utils/file_util.py index 35ac4596..c83b04ed 100644 --- a/src/codeinterpreterapi/utils/file_util.py +++ b/src/codeinterpreterapi/utils/file_util.py @@ -16,17 +16,18 @@ def get_python_file_path(filename: str) -> str: return python_file_path @staticmethod - def write_python_file(code: str): + def write_python_file(filename: str, code: str): if FileUtil.is_raw_string(code): - print("FileUtil write_python_file raw string by ast.parse()") + print("FileUtil write_python_file raw string by ast.parse() filename=", filename) parsed_code = ast.parse(code) code_content = ast.unparse(parsed_code) else: - print("FileUtil write_python_file regular string by ast.literal_eval()") - code_content = ast.literal_eval(f'"""{code}"""') + print("FileUtil write_python_file regular string by ast.literal_eval() filename=", filename) + code_content = code.replace('"""', '\\"\\"\\"') + code_content = ast.literal_eval(f'"""{code_content}"""') - if settings.PYTHON_OUT_FILE: - python_file_path = FileUtil.get_python_file_path(filename=settings.PYTHON_OUT_FILE) + if filename: + python_file_path = FileUtil.get_python_file_path(filename=filename) with open(python_file_path, "w", encoding="utf-8") as python_file: python_file.write(code_content) return python_file_path