diff --git a/src/codergpt/commenter/commenter.py b/src/codergpt/commenter/commenter.py index 18864f0..1915254 100644 --- a/src/codergpt/commenter/commenter.py +++ b/src/codergpt/commenter/commenter.py @@ -1,7 +1,7 @@ """Commenter Module.""" import os -from typing import Any, Dict +from typing import Any, Dict, Optional from langchain_core.runnables.base import RunnableSerializable @@ -17,7 +17,7 @@ def __init__(self, chain: RunnableSerializable[Dict, Any]): """ self.chain = chain - def comment(self, code: str, filename: str, overwrite: bool = False): + def comment(self, code: str, filename: str, overwrite: bool = False, language: Optional[str] = None): """ Comment the contents of the code string by invoking the runnable chain and write it to a new file. @@ -27,7 +27,7 @@ def comment(self, code: str, filename: str, overwrite: bool = False): """ response = self.chain.invoke( { - "input": f"Rewrite and return this code with\ + "input": f"Rewrite and return this {language} code with\ comments and docstrings in :param: format: \n{code}\n" } ) diff --git a/src/codergpt/explainer/explainer.py b/src/codergpt/explainer/explainer.py index f43dafd..dd651f6 100644 --- a/src/codergpt/explainer/explainer.py +++ b/src/codergpt/explainer/explainer.py @@ -16,7 +16,9 @@ def __init__(self, chain: RunnableSerializable[Dict, Any]): """ self.chain = chain - def explain(self, code: str, function: Optional[str] = None, classname: Optional[str] = None): + def explain( + self, code: str, function: Optional[str] = None, classname: Optional[str] = None, language: Optional[str] = None + ): """ Explain the contents of the code file by invoking the runnable chain. @@ -25,15 +27,15 @@ def explain(self, code: str, function: Optional[str] = None, classname: Optional :param classname: The name of the class to explain. Default is None. """ if function: - response = self.chain.invoke({"input": f"Explain the following code: \n\n```\n{code}\n```"}) + response = self.chain.invoke({"input": f"Explain the following {language} code: \n\n```\n{code}\n```"}) # Pretty print the response print(f"Explanation for '{function}':\n{response.content}") elif classname: - response = self.chain.invoke({"input": f"Explain the following code: \n\n```\n{code}\n```"}) + response = self.chain.invoke({"input": f"Explain the following {language} code: \n\n```\n{code}\n```"}) # Pretty print the response print(f"Explanation for '{classname}':\n{response.content}") else: # Explain full code - response = self.chain.invoke({"input": f"Explain the following code: \n\n```\n{code}\n```"}) + response = self.chain.invoke({"input": f"Explain the following {language} code: \n\n```\n{code}\n```"}) # Pretty print the response print(f"Explanation for the code:\n{response.content}") diff --git a/src/codergpt/main.py b/src/codergpt/main.py index a6c676e..9391e31 100644 --- a/src/codergpt/main.py +++ b/src/codergpt/main.py @@ -41,23 +41,27 @@ def inspect_package(self, path: Union[str, Path]): path = Path(path) file_language_list = [] + file_language_dict = {} if path.is_dir(): for file in path.rglob("*.*"): language = extension_to_language["language-map"].get(file.suffix) if language is not None: file_language_list.append((str(file), language)) + file_language_dict[str(file)] = language elif path.is_file(): language = extension_to_language["language-map"].get(path.suffix) if language is not None: file_language_list.append((str(path), language)) + file_language_dict[str(path)] = language else: print(f"The path {path} is neither a file nor a directory.") - return + return {} print(tabulate(file_language_list, headers=INSPECTION_HEADERS)) + return file_language_dict def get_code( self, filename: str, function_name: Optional[str] = None, class_name: Optional[str] = None @@ -73,16 +77,19 @@ def get_code( with open(filename, "r") as source_file: source_code = source_file.read() + language_map = self.inspect_package(filename) + language = language_map.get(str(filename)) + parsed_code = ast.parse(source_code) visitor = ExpressionEvaluator(source_code=source_code, function_name=function_name, class_name=class_name) visitor.visit(parsed_code) if function_name: - return visitor.function_code + return (visitor.function_code, language) elif class_name: - return visitor.class_code + return (visitor.class_code, language) else: - return source_code + return (source_code, language) def explainer(self, path: Union[str, Path], function: str = None, classname=None): """ @@ -93,8 +100,8 @@ def explainer(self, path: Union[str, Path], function: str = None, classname=None :param classname: The name of the class to explain. Default is None. """ code_explainer = CodeExplainer(self.chain) - code = self.get_code(filename=path, function_name=function, class_name=classname) - code_explainer.explain(code) + code, language = self.get_code(filename=path, function_name=function, class_name=classname) + code_explainer.explain(code, language) def commenter(self, path: Union[str, Path], overwrite: bool = False): """ @@ -104,8 +111,8 @@ def commenter(self, path: Union[str, Path], overwrite: bool = False): :param overwrite: Whether to overwrite the existing comments. Default is False. """ code_commenter = CodeCommenter(self.chain) - code = self.get_code(filename=path) - code_commenter.comment(code=code, filename=path, overwrite=overwrite) + code, language = self.get_code(filename=path) + code_commenter.comment(code=code, filename=path, overwrite=overwrite, language=language) if __name__ == "__main__": diff --git a/tests/test_commenter.py b/tests/test_commenter.py index 5507f13..e1be3b2 100644 --- a/tests/test_commenter.py +++ b/tests/test_commenter.py @@ -28,7 +28,7 @@ def test_comment_with_overwrite(self): self.mock_chain.invoke.return_value.content = expected_commented_code # Act - self.commenter.comment(code=code, filename=filename, overwrite=True) + self.commenter.comment(code=code, filename=filename, overwrite=True, language="python") # Assert self.mock_chain.invoke.assert_called_once() diff --git a/tests/test_explainer.py b/tests/test_explainer.py index 45b13b6..3ebb19c 100644 --- a/tests/test_explainer.py +++ b/tests/test_explainer.py @@ -30,10 +30,10 @@ def test_explain_function(self): sample_function_name = "example" # Replace with actual function name # Call the explain method with a sample code snippet and function name - self.code_explainer.explain(code=sample_code, function=sample_function_name) + self.code_explainer.explain(code=sample_code, function=sample_function_name, language="python") # Verify that invoke was called once with the correct parameters - expected_invoke_input = {"input": f"Explain the following code: \n\n```\n{sample_code}\n```"} + expected_invoke_input = {"input": f"Explain the following python code: \n\n```\n{sample_code}\n```"} self.mock_chain.invoke.assert_called_once_with(expected_invoke_input) # Check if the expected explanation message is in the captured output @@ -48,10 +48,10 @@ def test_explain_class(self): sample_class_name = "Example" # Replace with actual class name # Call the explain method with a sample code snippet and class name - self.code_explainer.explain(code=sample_code, classname=sample_class_name) + self.code_explainer.explain(code=sample_code, classname=sample_class_name, language="python") # Verify that invoke was called once with the correct parameters - expected_invoke_input = {"input": f"Explain the following code: \n\n```\n{sample_code}\n```"} + expected_invoke_input = {"input": f"Explain the following python code: \n\n```\n{sample_code}\n```"} self.mock_chain.invoke.assert_called_once_with(expected_invoke_input) # Check if the expected explanation message is in the captured output @@ -64,10 +64,10 @@ def test_explain_full_code(self): sample_code = "# Your full code here" # Replace with actual code # Call the explain method with a sample code snippet - self.code_explainer.explain(code=sample_code) + self.code_explainer.explain(code=sample_code, language="python") # Verify that invoke was called once with the correct parameters - expected_invoke_input = {"input": f"Explain the following code: \n\n```\n{sample_code}\n```"} + expected_invoke_input = {"input": f"Explain the following python code: \n\n```\n{sample_code}\n```"} self.mock_chain.invoke.assert_called_once_with(expected_invoke_input) # Check if the expected explanation message is in the captured output