diff --git a/mentat/git_handler.py b/mentat/git_handler.py index 7097faf60..df682e612 100644 --- a/mentat/git_handler.py +++ b/mentat/git_handler.py @@ -7,6 +7,14 @@ from mentat.session_context import SESSION_CONTEXT +def get_git_diff_for_path(path: Path) -> str: + session_context = SESSION_CONTEXT.get() + git_root = session_context.git_root + return subprocess.check_output( + ["git", "diff", path], cwd=git_root, text=True, stderr=subprocess.DEVNULL + ) + + def get_non_gitignored_files(path: Path) -> set[Path]: return set( # git returns / separated paths even on windows, convert so we can remove @@ -173,6 +181,9 @@ def check_head_exists() -> bool: git_root = session_context.git_root try: + subprocess.check_output( + ["git", "rev-parse", "HEAD", "--"], cwd=git_root, stderr=subprocess.DEVNULL + ) subprocess.check_output( ["git", "rev-parse", "HEAD", "--"], cwd=git_root, stderr=subprocess.DEVNULL ) diff --git a/mentat/python_client/client.py b/mentat/python_client/client.py index 68bec6b67..dd9ca2bac 100644 --- a/mentat/python_client/client.py +++ b/mentat/python_client/client.py @@ -12,6 +12,7 @@ class PythonClient: def __init__( self, + cwd: Path = Path.cwd(), paths: List[Path] = [], exclude_paths: List[Path] = [], ignore_paths: List[Path] = [], @@ -19,6 +20,7 @@ def __init__( pr_diff: str | None = None, config: Config = Config(), ): + self.cwd = cwd self.paths = paths self.exclude_paths = exclude_paths self.ignore_paths = ignore_paths @@ -63,6 +65,7 @@ async def _listen_for_client_exit(self): async def startup(self): self.session = Session( + self.cwd, self.paths, self.exclude_paths, self.ignore_paths, diff --git a/mentat/session.py b/mentat/session.py index 3d4e35ba3..036842e97 100644 --- a/mentat/session.py +++ b/mentat/session.py @@ -37,6 +37,7 @@ class Session: def __init__( self, + cwd: Path, paths: List[Path] = [], exclude_paths: List[Path] = [], ignore_paths: List[Path] = [], @@ -70,6 +71,7 @@ def __init__( conversation = Conversation() session_context = SessionContext( + cwd, stream, cost_tracker, git_root, diff --git a/mentat/session_context.py b/mentat/session_context.py index 6e404dfc9..d2c0a493b 100644 --- a/mentat/session_context.py +++ b/mentat/session_context.py @@ -19,6 +19,7 @@ @attr.define() class SessionContext: + cwd: Path = attr.field() stream: SessionStream = attr.field() cost_tracker: CostTracker = attr.field() git_root: Path = attr.field() diff --git a/mentat/terminal/client.py b/mentat/terminal/client.py index 80c5d5df8..05eb829ef 100644 --- a/mentat/terminal/client.py +++ b/mentat/terminal/client.py @@ -123,6 +123,7 @@ def _init_signal_handlers(self): async def _run(self): self._init_signal_handlers() self.session = Session( + Path.cwd(), self.paths, self.exclude_paths, self.ignore_paths, diff --git a/tests/clients/python_client_test.py b/tests/clients/python_client_test.py index f39d00763..dc9ebdc67 100644 --- a/tests/clients/python_client_test.py +++ b/tests/clients/python_client_test.py @@ -25,7 +25,7 @@ async def test_editing_file_auto_accept(mock_call_llm_api, mock_setup_api_key): # Line 2 @@end""")]) - python_client = PythonClient(["."]) + python_client = PythonClient(paths=["."]) await python_client.startup() await python_client.call_mentat_auto_accept("Conversation") await python_client.wait_for_edit_completion() @@ -56,7 +56,7 @@ async def test_collects_mentat_response(mock_call_llm_api, mock_setup_api_key): # Line 2 @@end""")]) - python_client = PythonClient(["."]) + python_client = PythonClient(paths=["."]) await python_client.startup() response = await python_client.call_mentat("Conversation") response += await python_client.call_mentat("y") diff --git a/tests/clients/terminal_client_test.py b/tests/clients/terminal_client_test.py index 074b250c3..354641a34 100644 --- a/tests/clients/terminal_client_test.py +++ b/tests/clients/terminal_client_test.py @@ -60,7 +60,7 @@ def test_editing_file( # Line 2 @@end""")]) - terminal_client = TerminalClient(["."]) + terminal_client = TerminalClient(paths=["."]) terminal_client.run() with open(file_name, "r") as f: content = f.read() diff --git a/tests/code_file_manager_test.py b/tests/code_file_manager_test.py index 76b98340a..d818b710a 100644 --- a/tests/code_file_manager_test.py +++ b/tests/code_file_manager_test.py @@ -102,7 +102,7 @@ async def test_run_from_subdirectory( # Hello @@end""")]) - session = Session([Path("calculator.py"), Path("../scripts")]) + session = Session(cwd=Path.cwd(), paths=[Path("calculator.py"), Path("../scripts")]) session.start() await session.stream.recv(channel="client_exit") @@ -148,7 +148,7 @@ async def test_change_after_creation( print("Hello, World!") @@end""")]) - session = Session() + session = Session(cwd=Path.cwd()) session.start() await session.stream.recv(channel="client_exit") diff --git a/tests/commands_test.py b/tests/commands_test.py index 122d245ef..ae94448b3 100644 --- a/tests/commands_test.py +++ b/tests/commands_test.py @@ -187,7 +187,7 @@ async def test_clear_command( ) mock_call_llm_api.set_generator_values(["Answer"]) - session = Session() + session = Session(cwd=Path.cwd()) session.start() await session.stream.recv(channel="client_exit") @@ -215,7 +215,7 @@ async def test_search_command( "mentat.code_context.CodeContext.search", return_value=[(mock_feature, mock_score)], ) - session = Session() + session = Session(cwd=Path.cwd()) session.start() await session.stream.recv(channel="client_exit") diff --git a/tests/conftest.py b/tests/conftest.py index 037e63bfc..46c229b97 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -203,6 +203,7 @@ async def _mock_session_context(temp_testbed): conversation = Conversation() session_context = SessionContext( + Path.cwd(), stream, cost_tracker, git_root, diff --git a/tests/diff_context_test.py b/tests/diff_context_test.py index e666d34b0..ff0737e3d 100644 --- a/tests/diff_context_test.py +++ b/tests/diff_context_test.py @@ -172,7 +172,7 @@ async def test_diff_context_end_to_end( # SESSION_CONTEXT isn't reset between tests SESSION_CONTEXT.set(None) mock_call_llm_api.set_generator_values([""]) - python_client = PythonClient([], diff="HEAD~2") + python_client = PythonClient(paths=[], diff="HEAD~2") await python_client.startup() session_context = SESSION_CONTEXT.get() diff --git a/tests/parser_tests/block_format_error_test.py b/tests/parser_tests/block_format_error_test.py index f2afc8782..999996875 100644 --- a/tests/parser_tests/block_format_error_test.py +++ b/tests/parser_tests/block_format_error_test.py @@ -1,3 +1,4 @@ +from pathlib import Path from textwrap import dedent import pytest @@ -63,7 +64,7 @@ async def error_test_template( ) mock_call_llm_api.set_generator_values([changes]) - session = Session([temp_file_name]) + session = Session(cwd=Path.cwd(), paths=[temp_file_name]) session.start() await session.stream.recv(channel="client_exit") with open(temp_file_name, "r") as f: diff --git a/tests/parser_tests/block_format_test.py b/tests/parser_tests/block_format_test.py index a0dfe39fd..9fff9ad57 100644 --- a/tests/parser_tests/block_format_test.py +++ b/tests/parser_tests/block_format_test.py @@ -1,4 +1,5 @@ import os +from pathlib import Path from textwrap import dedent import pytest @@ -50,7 +51,7 @@ async def test_insert( @@end""".format(file_name=temp_file_name))]) # Run the system with the temporary file path - session = Session([temp_file_name]) + session = Session(cwd=Path.cwd(), paths=[temp_file_name]) session.start() await session.stream.recv(channel="client_exit") @@ -100,7 +101,7 @@ async def test_replace( @@end""".format(file_name=temp_file_name))]) # Run the system with the temporary file path - session = Session([temp_file_name]) + session = Session(cwd=Path.cwd(), paths=[temp_file_name]) session.start() await session.stream.recv(channel="client_exit") @@ -148,7 +149,7 @@ async def test_delete( @@end""".format(file_name=temp_file_name))]) # Run the system with the temporary file path - session = Session([temp_file_name]) + session = Session(cwd=Path.cwd(), paths=[temp_file_name]) session.start() await session.stream.recv(channel="client_exit") @@ -189,7 +190,7 @@ async def test_create_file( # I created this file @@end""".format(file_name=temp_file_name))]) - session = Session(["."]) + session = Session(cwd=Path.cwd(), paths=["."]) session.start() await session.stream.recv(channel="client_exit") @@ -232,7 +233,7 @@ async def test_delete_file( @@end""".format(file_name=temp_file_name))]) # Run the system with the temporary file path - session = Session([temp_file_name]) + session = Session(cwd=Path.cwd(), paths=[temp_file_name]) session.start() await session.stream.recv(channel="client_exit") @@ -270,7 +271,7 @@ async def test_rename_file( }} @@end""")]) - session = Session([temp_file_name]) + session = Session(cwd=Path.cwd(), paths=[temp_file_name]) session.start() await session.stream.recv(channel="client_exit") with open(temp_2_file_name) as new_file: @@ -323,7 +324,7 @@ async def test_change_then_rename_file( }} @@end""")]) - session = Session([temp_file_name]) + session = Session(cwd=Path.cwd(), paths=[temp_file_name]) session.start() await session.stream.recv(channel="client_exit") with open(temp_2_file_name) as new_file: @@ -376,7 +377,7 @@ async def test_rename_file_then_change( # I inserted this comment! @@end""")]) - session = Session([temp_file_name]) + session = Session(cwd=Path.cwd(), paths=[temp_file_name]) session.start() await session.stream.recv(channel="client_exit") with open(temp_2_file_name) as new_file: @@ -437,7 +438,7 @@ async def test_multiple_blocks( @@end""".format(file_name=temp_file_name))]) # Run the system with the temporary file path - session = Session([temp_file_name]) + session = Session(cwd=Path.cwd(), paths=[temp_file_name]) session.start() await session.stream.recv(channel="client_exit") @@ -503,7 +504,7 @@ async def test_json_strings( @@end""".format(file_name=temp_file_name))]) # Run the system with the temporary file path - session = Session([temp_file_name]) + session = Session(cwd=Path.cwd(), paths=[temp_file_name]) session.start() await session.stream.recv(channel="client_exit") diff --git a/tests/parser_tests/replacement_format_error_test.py b/tests/parser_tests/replacement_format_error_test.py index fdebde840..84094a238 100644 --- a/tests/parser_tests/replacement_format_error_test.py +++ b/tests/parser_tests/replacement_format_error_test.py @@ -1,3 +1,4 @@ +from pathlib import Path from textwrap import dedent import pytest @@ -42,7 +43,7 @@ async def test_invalid_line_numbers( # I also will not be used @""")]) - session = Session([temp_file_name]) + session = Session(cwd=Path.cwd(), paths=[temp_file_name]) session.start() await session.stream.recv(channel="client_exit") with open(temp_file_name, "r") as f: @@ -84,7 +85,7 @@ async def test_invalid_special_line( # I will not be used @""")]) - session = Session([temp_file_name]) + session = Session(cwd=Path.cwd(), paths=[temp_file_name]) session.start() await session.stream.recv(channel="client_exit") with open(temp_file_name, "r") as f: diff --git a/tests/parser_tests/replacement_format_test.py b/tests/parser_tests/replacement_format_test.py index 195ee9d9f..8d6242a2b 100644 --- a/tests/parser_tests/replacement_format_test.py +++ b/tests/parser_tests/replacement_format_test.py @@ -38,7 +38,7 @@ async def test_insert( # I inserted this comment @""")]) - session = Session([temp_file_name]) + session = Session(cwd=Path.cwd(), paths=[temp_file_name]) session.start() await session.stream.recv(channel="client_exit") with open(temp_file_name, "r") as f: @@ -73,7 +73,7 @@ async def test_delete( @ {temp_file_name} starting_line=1 ending_line=1 @""")]) - session = Session([temp_file_name]) + session = Session(cwd=Path.cwd(), paths=[temp_file_name]) session.start() await session.stream.recv(channel="client_exit") with open(temp_file_name, "r") as f: @@ -107,7 +107,7 @@ async def test_replace( # I inserted this comment @""")]) - session = Session([temp_file_name]) + session = Session(cwd=Path.cwd(), paths=[temp_file_name]) session.start() await session.stream.recv(channel="client_exit") with open(temp_file_name, "r") as f: @@ -138,7 +138,7 @@ async def test_create_file( # New line @""")]) - session = Session([temp_file_name]) + session = Session(cwd=Path.cwd(), paths=[temp_file_name]) session.start() await session.stream.recv(channel="client_exit") with open(temp_file_name, "r") as f: @@ -171,7 +171,7 @@ async def test_delete_file( @ {temp_file_name} -""")]) - session = Session([temp_file_name]) + session = Session(cwd=Path.cwd(), paths=[temp_file_name]) session.start() await session.stream.recv(channel="client_exit") assert not Path(temp_file_name).exists() @@ -200,7 +200,7 @@ async def test_rename_file( @ {temp_file_name} {temp_file_name_2}""")]) - session = Session([temp_file_name]) + session = Session(cwd=Path.cwd(), paths=[temp_file_name]) session.start() await session.stream.recv(channel="client_exit") assert not Path(temp_file_name).exists() @@ -241,7 +241,7 @@ async def test_change_then_rename_then_change( # New line 2 @""")]) - session = Session([temp_file_name]) + session = Session(cwd=Path.cwd(), paths=[temp_file_name]) session.start() await session.stream.recv(channel="client_exit") assert not Path(temp_file_name).exists() diff --git a/tests/parser_tests/unified_diff_format_test.py b/tests/parser_tests/unified_diff_format_test.py index 217fac7b1..2278f81d0 100644 --- a/tests/parser_tests/unified_diff_format_test.py +++ b/tests/parser_tests/unified_diff_format_test.py @@ -45,7 +45,7 @@ async def test_replacement( # 4 lines @@ end @@""")]) - session = Session([temp_file_name]) + session = Session(cwd=Path.cwd(), paths=[temp_file_name]) session.start() await session.stream.recv(channel="client_exit") with open(temp_file_name, "r") as f: @@ -99,7 +99,7 @@ async def test_multiple_replacements( # lines @@ end @@""")]) - session = Session([temp_file_name]) + session = Session(cwd=Path.cwd(), paths=[temp_file_name]) session.start() await session.stream.recv(channel="client_exit") with open(temp_file_name, "r") as f: @@ -153,7 +153,7 @@ async def test_multiple_replacement_spots( +# more than @@ end @@""")]) - session = Session([temp_file_name]) + session = Session(cwd=Path.cwd(), paths=[temp_file_name]) session.start() await session.stream.recv(channel="client_exit") with open(temp_file_name, "r") as f: @@ -206,7 +206,7 @@ async def test_little_context_addition( # with @@ end @@""")]) - session = Session([temp_file_name]) + session = Session(cwd=Path.cwd(), paths=[temp_file_name]) session.start() await session.stream.recv(channel="client_exit") with open(temp_file_name, "r") as f: @@ -250,7 +250,7 @@ async def test_empty_file( +# line @@ end @@""")]) - session = Session([temp_file_name]) + session = Session(cwd=Path.cwd(), paths=[temp_file_name]) session.start() await session.stream.recv(channel="client_exit") with open(temp_file_name, "r") as f: @@ -283,7 +283,7 @@ async def test_creation(mock_call_llm_api, mock_collect_user_input, mock_setup_a @@ end @@ """)]) - session = Session([temp_file_name]) + session = Session(cwd=Path.cwd(), paths=[temp_file_name]) session.start() await session.stream.recv(channel="client_exit") with open(temp_file_name, "r") as f: @@ -322,7 +322,7 @@ async def test_deletion(mock_call_llm_api, mock_collect_user_input, mock_setup_a +++ /dev/null @@ end @@""")]) - session = Session([temp_file_name]) + session = Session(cwd=Path.cwd(), paths=[temp_file_name]) session.start() await session.stream.recv(channel="client_exit") assert not temp_file_name.exists() @@ -359,7 +359,7 @@ async def test_no_ending_marker( +# your captain speaking # 4 lines""")]) - session = Session([temp_file_name]) + session = Session(cwd=Path.cwd(), paths=[temp_file_name]) session.start() await session.stream.recv(channel="client_exit") with open(temp_file_name, "r") as f: diff --git a/tests/system_test.py b/tests/system_test.py index ea93a4d1b..57954f819 100644 --- a/tests/system_test.py +++ b/tests/system_test.py @@ -42,7 +42,7 @@ async def test_system(mock_call_llm_api, mock_setup_api_key, mock_collect_user_i print("Hello, world!") @@end""".format(file_name=temp_file_name))]) - session = Session([temp_file_name]) + session = Session(cwd=Path.cwd(), paths=[temp_file_name]) session.start() await session.stream.recv(channel="client_exit") @@ -67,7 +67,7 @@ async def test_system_exits_on_exception( # with "Task was destroyed but it is pending!" mock_call_llm_api.side_effect = Exception("Something went wrong") - session = Session() + session = Session(cwd=Path.cwd()) session.start() await session.stream.recv(channel="client_exit") @@ -131,7 +131,7 @@ async def test_interactive_change_selection( print("Change 3") @@end""".format(file_name=temp_file_name))]) - session = Session([temp_file_name]) + session = Session(cwd=Path.cwd(), paths=[temp_file_name]) session.start() await session.stream.recv(channel="client_exit") @@ -175,7 +175,7 @@ async def test_without_os_join( @@code print("Hello, world!") @@end""".format(file_name=fake_file_path))]) - session = Session([temp_file_path]) + session = Session(cwd=Path.cwd(), paths=[temp_file_path]) session.start() await session.stream.recv(channel="client_exit") mock_collect_user_input.reset_mock() @@ -214,7 +214,7 @@ async def test_sub_directory( print("Hello, world!") @@end""")]) - session = Session([file_name]) + session = Session(cwd=Path.cwd(), paths=[file_name]) session.start() await session.stream.recv(channel="client_exit")