Skip to content

Commit

Permalink
Added language
Browse files Browse the repository at this point in the history
  • Loading branch information
hrshdhgd committed Feb 10, 2024
1 parent 292dc06 commit 30047e2
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 22 deletions.
6 changes: 3 additions & 3 deletions src/codergpt/commenter/commenter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Commenter Module."""

import os
from typing import Any, Dict
from typing import Any, Dict, Optional

from langchain_core.runnables.base import RunnableSerializable

Expand All @@ -17,7 +17,7 @@ def __init__(self, chain: RunnableSerializable[Dict, Any]):
"""
self.chain = chain

def comment(self, code: str, filename: str, overwrite: bool = False):
def comment(self, code: str, filename: str, overwrite: bool = False, language: Optional[str] = None):
"""
Comment the contents of the code string by invoking the runnable chain and write it to a new file.
Expand All @@ -27,7 +27,7 @@ def comment(self, code: str, filename: str, overwrite: bool = False):
"""
response = self.chain.invoke(
{
"input": f"Rewrite and return this code with\
"input": f"Rewrite and return this {language} code with\
comments and docstrings in :param: format: \n{code}\n"
}
)
Expand Down
10 changes: 6 additions & 4 deletions src/codergpt/explainer/explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ def __init__(self, chain: RunnableSerializable[Dict, Any]):
"""
self.chain = chain

def explain(self, code: str, function: Optional[str] = None, classname: Optional[str] = None):
def explain(
self, code: str, function: Optional[str] = None, classname: Optional[str] = None, language: Optional[str] = None
):
"""
Explain the contents of the code file by invoking the runnable chain.
Expand All @@ -25,15 +27,15 @@ def explain(self, code: str, function: Optional[str] = None, classname: Optional
:param classname: The name of the class to explain. Default is None.
"""
if function:
response = self.chain.invoke({"input": f"Explain the following code: \n\n```\n{code}\n```"})
response = self.chain.invoke({"input": f"Explain the following {language} code: \n\n```\n{code}\n```"})
# Pretty print the response
print(f"Explanation for '{function}':\n{response.content}")
elif classname:
response = self.chain.invoke({"input": f"Explain the following code: \n\n```\n{code}\n```"})
response = self.chain.invoke({"input": f"Explain the following {language} code: \n\n```\n{code}\n```"})
# Pretty print the response
print(f"Explanation for '{classname}':\n{response.content}")
else:
# Explain full code
response = self.chain.invoke({"input": f"Explain the following code: \n\n```\n{code}\n```"})
response = self.chain.invoke({"input": f"Explain the following {language} code: \n\n```\n{code}\n```"})
# Pretty print the response
print(f"Explanation for the code:\n{response.content}")
23 changes: 15 additions & 8 deletions src/codergpt/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,23 +41,27 @@ def inspect_package(self, path: Union[str, Path]):
path = Path(path)

file_language_list = []
file_language_dict = {}

if path.is_dir():
for file in path.rglob("*.*"):
language = extension_to_language["language-map"].get(file.suffix)
if language is not None:
file_language_list.append((str(file), language))
file_language_dict[str(file)] = language

elif path.is_file():
language = extension_to_language["language-map"].get(path.suffix)
if language is not None:
file_language_list.append((str(path), language))
file_language_dict[str(path)] = language

else:
print(f"The path {path} is neither a file nor a directory.")
return
return {}

print(tabulate(file_language_list, headers=INSPECTION_HEADERS))
return file_language_dict

def get_code(
self, filename: str, function_name: Optional[str] = None, class_name: Optional[str] = None
Expand All @@ -73,16 +77,19 @@ def get_code(
with open(filename, "r") as source_file:
source_code = source_file.read()

language_map = self.inspect_package(filename)
language = language_map.get(str(filename))

parsed_code = ast.parse(source_code)

visitor = ExpressionEvaluator(source_code=source_code, function_name=function_name, class_name=class_name)
visitor.visit(parsed_code)
if function_name:
return visitor.function_code
return (visitor.function_code, language)
elif class_name:
return visitor.class_code
return (visitor.class_code, language)
else:
return source_code
return (source_code, language)

def explainer(self, path: Union[str, Path], function: str = None, classname=None):
"""
Expand All @@ -93,8 +100,8 @@ def explainer(self, path: Union[str, Path], function: str = None, classname=None
:param classname: The name of the class to explain. Default is None.
"""
code_explainer = CodeExplainer(self.chain)
code = self.get_code(filename=path, function_name=function, class_name=classname)
code_explainer.explain(code)
code, language = self.get_code(filename=path, function_name=function, class_name=classname)
code_explainer.explain(code, language)

def commenter(self, path: Union[str, Path], overwrite: bool = False):
"""
Expand All @@ -104,8 +111,8 @@ def commenter(self, path: Union[str, Path], overwrite: bool = False):
:param overwrite: Whether to overwrite the existing comments. Default is False.
"""
code_commenter = CodeCommenter(self.chain)
code = self.get_code(filename=path)
code_commenter.comment(code=code, filename=path, overwrite=overwrite)
code, language = self.get_code(filename=path)
code_commenter.comment(code=code, filename=path, overwrite=overwrite, language=language)


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion tests/test_commenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def test_comment_with_overwrite(self):
self.mock_chain.invoke.return_value.content = expected_commented_code

# Act
self.commenter.comment(code=code, filename=filename, overwrite=True)
self.commenter.comment(code=code, filename=filename, overwrite=True, language="python")

# Assert
self.mock_chain.invoke.assert_called_once()
Expand Down
12 changes: 6 additions & 6 deletions tests/test_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@ def test_explain_function(self):
sample_function_name = "example" # Replace with actual function name

# Call the explain method with a sample code snippet and function name
self.code_explainer.explain(code=sample_code, function=sample_function_name)
self.code_explainer.explain(code=sample_code, function=sample_function_name, language="python")

# Verify that invoke was called once with the correct parameters
expected_invoke_input = {"input": f"Explain the following code: \n\n```\n{sample_code}\n```"}
expected_invoke_input = {"input": f"Explain the following python code: \n\n```\n{sample_code}\n```"}
self.mock_chain.invoke.assert_called_once_with(expected_invoke_input)

# Check if the expected explanation message is in the captured output
Expand All @@ -48,10 +48,10 @@ def test_explain_class(self):
sample_class_name = "Example" # Replace with actual class name

# Call the explain method with a sample code snippet and class name
self.code_explainer.explain(code=sample_code, classname=sample_class_name)
self.code_explainer.explain(code=sample_code, classname=sample_class_name, language="python")

# Verify that invoke was called once with the correct parameters
expected_invoke_input = {"input": f"Explain the following code: \n\n```\n{sample_code}\n```"}
expected_invoke_input = {"input": f"Explain the following python code: \n\n```\n{sample_code}\n```"}
self.mock_chain.invoke.assert_called_once_with(expected_invoke_input)

# Check if the expected explanation message is in the captured output
Expand All @@ -64,10 +64,10 @@ def test_explain_full_code(self):
sample_code = "# Your full code here" # Replace with actual code

# Call the explain method with a sample code snippet
self.code_explainer.explain(code=sample_code)
self.code_explainer.explain(code=sample_code, language="python")

# Verify that invoke was called once with the correct parameters
expected_invoke_input = {"input": f"Explain the following code: \n\n```\n{sample_code}\n```"}
expected_invoke_input = {"input": f"Explain the following python code: \n\n```\n{sample_code}\n```"}
self.mock_chain.invoke.assert_called_once_with(expected_invoke_input)

# Check if the expected explanation message is in the captured output
Expand Down

0 comments on commit 30047e2

Please sign in to comment.