Skip to content
This repository has been archived by the owner on Jan 7, 2025. It is now read-only.

Commit

Permalink
Add cwd parameter to Session (#272)
Browse files Browse the repository at this point in the history
  • Loading branch information
waydegilliam authored Nov 20, 2023
1 parent 26749b4 commit 2752338
Show file tree
Hide file tree
Showing 17 changed files with 63 additions and 41 deletions.
11 changes: 11 additions & 0 deletions mentat/git_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,14 @@
from mentat.session_context import SESSION_CONTEXT


def get_git_diff_for_path(path: Path) -> str:
session_context = SESSION_CONTEXT.get()
git_root = session_context.git_root
return subprocess.check_output(
["git", "diff", path], cwd=git_root, text=True, stderr=subprocess.DEVNULL
)


def get_non_gitignored_files(path: Path) -> set[Path]:
return set(
# git returns / separated paths even on windows, convert so we can remove
Expand Down Expand Up @@ -173,6 +181,9 @@ def check_head_exists() -> bool:
git_root = session_context.git_root

try:
subprocess.check_output(
["git", "rev-parse", "HEAD", "--"], cwd=git_root, stderr=subprocess.DEVNULL
)
subprocess.check_output(
["git", "rev-parse", "HEAD", "--"], cwd=git_root, stderr=subprocess.DEVNULL
)
Expand Down
3 changes: 3 additions & 0 deletions mentat/python_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@
class PythonClient:
def __init__(
self,
cwd: Path = Path.cwd(),
paths: List[Path] = [],
exclude_paths: List[Path] = [],
ignore_paths: List[Path] = [],
diff: str | None = None,
pr_diff: str | None = None,
config: Config = Config(),
):
self.cwd = cwd
self.paths = paths
self.exclude_paths = exclude_paths
self.ignore_paths = ignore_paths
Expand Down Expand Up @@ -63,6 +65,7 @@ async def _listen_for_client_exit(self):

async def startup(self):
self.session = Session(
self.cwd,
self.paths,
self.exclude_paths,
self.ignore_paths,
Expand Down
2 changes: 2 additions & 0 deletions mentat/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class Session:

def __init__(
self,
cwd: Path,
paths: List[Path] = [],
exclude_paths: List[Path] = [],
ignore_paths: List[Path] = [],
Expand Down Expand Up @@ -70,6 +71,7 @@ def __init__(
conversation = Conversation()

session_context = SessionContext(
cwd,
stream,
cost_tracker,
git_root,
Expand Down
1 change: 1 addition & 0 deletions mentat/session_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

@attr.define()
class SessionContext:
cwd: Path = attr.field()
stream: SessionStream = attr.field()
cost_tracker: CostTracker = attr.field()
git_root: Path = attr.field()
Expand Down
1 change: 1 addition & 0 deletions mentat/terminal/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def _init_signal_handlers(self):
async def _run(self):
self._init_signal_handlers()
self.session = Session(
Path.cwd(),
self.paths,
self.exclude_paths,
self.ignore_paths,
Expand Down
4 changes: 2 additions & 2 deletions tests/clients/python_client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ async def test_editing_file_auto_accept(mock_call_llm_api, mock_setup_api_key):
# Line 2
@@end""")])

python_client = PythonClient(["."])
python_client = PythonClient(paths=["."])
await python_client.startup()
await python_client.call_mentat_auto_accept("Conversation")
await python_client.wait_for_edit_completion()
Expand Down Expand Up @@ -56,7 +56,7 @@ async def test_collects_mentat_response(mock_call_llm_api, mock_setup_api_key):
# Line 2
@@end""")])

python_client = PythonClient(["."])
python_client = PythonClient(paths=["."])
await python_client.startup()
response = await python_client.call_mentat("Conversation")
response += await python_client.call_mentat("y")
Expand Down
2 changes: 1 addition & 1 deletion tests/clients/terminal_client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def test_editing_file(
# Line 2
@@end""")])

terminal_client = TerminalClient(["."])
terminal_client = TerminalClient(paths=["."])
terminal_client.run()
with open(file_name, "r") as f:
content = f.read()
Expand Down
4 changes: 2 additions & 2 deletions tests/code_file_manager_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ async def test_run_from_subdirectory(
# Hello
@@end""")])

session = Session([Path("calculator.py"), Path("../scripts")])
session = Session(cwd=Path.cwd(), paths=[Path("calculator.py"), Path("../scripts")])
session.start()
await session.stream.recv(channel="client_exit")

Expand Down Expand Up @@ -148,7 +148,7 @@ async def test_change_after_creation(
print("Hello, World!")
@@end""")])

session = Session()
session = Session(cwd=Path.cwd())
session.start()
await session.stream.recv(channel="client_exit")

Expand Down
4 changes: 2 additions & 2 deletions tests/commands_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ async def test_clear_command(
)
mock_call_llm_api.set_generator_values(["Answer"])

session = Session()
session = Session(cwd=Path.cwd())
session.start()
await session.stream.recv(channel="client_exit")

Expand Down Expand Up @@ -215,7 +215,7 @@ async def test_search_command(
"mentat.code_context.CodeContext.search",
return_value=[(mock_feature, mock_score)],
)
session = Session()
session = Session(cwd=Path.cwd())
session.start()
await session.stream.recv(channel="client_exit")

Expand Down
1 change: 1 addition & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ async def _mock_session_context(temp_testbed):
conversation = Conversation()

session_context = SessionContext(
Path.cwd(),
stream,
cost_tracker,
git_root,
Expand Down
2 changes: 1 addition & 1 deletion tests/diff_context_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ async def test_diff_context_end_to_end(
# SESSION_CONTEXT isn't reset between tests
SESSION_CONTEXT.set(None)
mock_call_llm_api.set_generator_values([""])
python_client = PythonClient([], diff="HEAD~2")
python_client = PythonClient(paths=[], diff="HEAD~2")
await python_client.startup()

session_context = SESSION_CONTEXT.get()
Expand Down
3 changes: 2 additions & 1 deletion tests/parser_tests/block_format_error_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from pathlib import Path
from textwrap import dedent

import pytest
Expand Down Expand Up @@ -63,7 +64,7 @@ async def error_test_template(
)
mock_call_llm_api.set_generator_values([changes])

session = Session([temp_file_name])
session = Session(cwd=Path.cwd(), paths=[temp_file_name])
session.start()
await session.stream.recv(channel="client_exit")
with open(temp_file_name, "r") as f:
Expand Down
21 changes: 11 additions & 10 deletions tests/parser_tests/block_format_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from pathlib import Path
from textwrap import dedent

import pytest
Expand Down Expand Up @@ -50,7 +51,7 @@ async def test_insert(
@@end""".format(file_name=temp_file_name))])

# Run the system with the temporary file path
session = Session([temp_file_name])
session = Session(cwd=Path.cwd(), paths=[temp_file_name])
session.start()
await session.stream.recv(channel="client_exit")

Expand Down Expand Up @@ -100,7 +101,7 @@ async def test_replace(
@@end""".format(file_name=temp_file_name))])

# Run the system with the temporary file path
session = Session([temp_file_name])
session = Session(cwd=Path.cwd(), paths=[temp_file_name])
session.start()
await session.stream.recv(channel="client_exit")

Expand Down Expand Up @@ -148,7 +149,7 @@ async def test_delete(
@@end""".format(file_name=temp_file_name))])

# Run the system with the temporary file path
session = Session([temp_file_name])
session = Session(cwd=Path.cwd(), paths=[temp_file_name])
session.start()
await session.stream.recv(channel="client_exit")

Expand Down Expand Up @@ -189,7 +190,7 @@ async def test_create_file(
# I created this file
@@end""".format(file_name=temp_file_name))])

session = Session(["."])
session = Session(cwd=Path.cwd(), paths=["."])
session.start()
await session.stream.recv(channel="client_exit")

Expand Down Expand Up @@ -232,7 +233,7 @@ async def test_delete_file(
@@end""".format(file_name=temp_file_name))])

# Run the system with the temporary file path
session = Session([temp_file_name])
session = Session(cwd=Path.cwd(), paths=[temp_file_name])
session.start()
await session.stream.recv(channel="client_exit")

Expand Down Expand Up @@ -270,7 +271,7 @@ async def test_rename_file(
}}
@@end""")])

session = Session([temp_file_name])
session = Session(cwd=Path.cwd(), paths=[temp_file_name])
session.start()
await session.stream.recv(channel="client_exit")
with open(temp_2_file_name) as new_file:
Expand Down Expand Up @@ -323,7 +324,7 @@ async def test_change_then_rename_file(
}}
@@end""")])

session = Session([temp_file_name])
session = Session(cwd=Path.cwd(), paths=[temp_file_name])
session.start()
await session.stream.recv(channel="client_exit")
with open(temp_2_file_name) as new_file:
Expand Down Expand Up @@ -376,7 +377,7 @@ async def test_rename_file_then_change(
# I inserted this comment!
@@end""")])

session = Session([temp_file_name])
session = Session(cwd=Path.cwd(), paths=[temp_file_name])
session.start()
await session.stream.recv(channel="client_exit")
with open(temp_2_file_name) as new_file:
Expand Down Expand Up @@ -437,7 +438,7 @@ async def test_multiple_blocks(
@@end""".format(file_name=temp_file_name))])

# Run the system with the temporary file path
session = Session([temp_file_name])
session = Session(cwd=Path.cwd(), paths=[temp_file_name])
session.start()
await session.stream.recv(channel="client_exit")

Expand Down Expand Up @@ -503,7 +504,7 @@ async def test_json_strings(
@@end""".format(file_name=temp_file_name))])

# Run the system with the temporary file path
session = Session([temp_file_name])
session = Session(cwd=Path.cwd(), paths=[temp_file_name])
session.start()
await session.stream.recv(channel="client_exit")

Expand Down
5 changes: 3 additions & 2 deletions tests/parser_tests/replacement_format_error_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from pathlib import Path
from textwrap import dedent

import pytest
Expand Down Expand Up @@ -42,7 +43,7 @@ async def test_invalid_line_numbers(
# I also will not be used
@""")])

session = Session([temp_file_name])
session = Session(cwd=Path.cwd(), paths=[temp_file_name])
session.start()
await session.stream.recv(channel="client_exit")
with open(temp_file_name, "r") as f:
Expand Down Expand Up @@ -84,7 +85,7 @@ async def test_invalid_special_line(
# I will not be used
@""")])

session = Session([temp_file_name])
session = Session(cwd=Path.cwd(), paths=[temp_file_name])
session.start()
await session.stream.recv(channel="client_exit")
with open(temp_file_name, "r") as f:
Expand Down
14 changes: 7 additions & 7 deletions tests/parser_tests/replacement_format_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ async def test_insert(
# I inserted this comment
@""")])

session = Session([temp_file_name])
session = Session(cwd=Path.cwd(), paths=[temp_file_name])
session.start()
await session.stream.recv(channel="client_exit")
with open(temp_file_name, "r") as f:
Expand Down Expand Up @@ -73,7 +73,7 @@ async def test_delete(
@ {temp_file_name} starting_line=1 ending_line=1
@""")])

session = Session([temp_file_name])
session = Session(cwd=Path.cwd(), paths=[temp_file_name])
session.start()
await session.stream.recv(channel="client_exit")
with open(temp_file_name, "r") as f:
Expand Down Expand Up @@ -107,7 +107,7 @@ async def test_replace(
# I inserted this comment
@""")])

session = Session([temp_file_name])
session = Session(cwd=Path.cwd(), paths=[temp_file_name])
session.start()
await session.stream.recv(channel="client_exit")
with open(temp_file_name, "r") as f:
Expand Down Expand Up @@ -138,7 +138,7 @@ async def test_create_file(
# New line
@""")])

session = Session([temp_file_name])
session = Session(cwd=Path.cwd(), paths=[temp_file_name])
session.start()
await session.stream.recv(channel="client_exit")
with open(temp_file_name, "r") as f:
Expand Down Expand Up @@ -171,7 +171,7 @@ async def test_delete_file(
@ {temp_file_name} -""")])

session = Session([temp_file_name])
session = Session(cwd=Path.cwd(), paths=[temp_file_name])
session.start()
await session.stream.recv(channel="client_exit")
assert not Path(temp_file_name).exists()
Expand Down Expand Up @@ -200,7 +200,7 @@ async def test_rename_file(
@ {temp_file_name} {temp_file_name_2}""")])

session = Session([temp_file_name])
session = Session(cwd=Path.cwd(), paths=[temp_file_name])
session.start()
await session.stream.recv(channel="client_exit")
assert not Path(temp_file_name).exists()
Expand Down Expand Up @@ -241,7 +241,7 @@ async def test_change_then_rename_then_change(
# New line 2
@""")])

session = Session([temp_file_name])
session = Session(cwd=Path.cwd(), paths=[temp_file_name])
session.start()
await session.stream.recv(channel="client_exit")
assert not Path(temp_file_name).exists()
Expand Down
Loading

0 comments on commit 2752338

Please sign in to comment.