diff --git a/src/rapids_pre_commit_hooks/lint.py b/src/rapids_pre_commit_hooks/lint.py index d10cbff..781cecb 100644 --- a/src/rapids_pre_commit_hooks/lint.py +++ b/src/rapids_pre_commit_hooks/lint.py @@ -15,6 +15,7 @@ import argparse import bisect import contextlib +import dataclasses import functools import re import warnings @@ -47,43 +48,32 @@ class BinaryFileWarning(Warning): pass +@dataclasses.dataclass class Replacement: - def __init__(self, pos: _PosType, newtext: str) -> None: - self.pos: _PosType = pos - self.newtext: str = newtext + pos: _PosType + newtext: str - def __eq__(self, other: object) -> bool: - assert isinstance(other, Replacement) - return self.pos == other.pos and self.newtext == other.newtext - def __repr__(self) -> str: - return f"Replacement(pos={self.pos}, newtext={repr(self.newtext)})" +@dataclasses.dataclass +class Note: + pos: _PosType + msg: str +@dataclasses.dataclass class LintWarning: - def __init__(self, pos: _PosType, msg: str) -> None: - self.pos: _PosType = pos - self.msg: str = msg - self.replacements: list[Replacement] = [] + pos: _PosType + msg: str + replacements: list[Replacement] = dataclasses.field( + default_factory=list, init=False + ) + notes: list[Note] = dataclasses.field(default_factory=list, init=False) def add_replacement(self, pos: _PosType, newtext: str) -> None: self.replacements.append(Replacement(pos, newtext)) - def __eq__(self, other: object) -> bool: - assert isinstance(other, LintWarning) - return ( - self.pos == other.pos - and self.msg == other.msg - and self.replacements == other.replacements - ) - - def __repr__(self) -> str: - return ( - "LintWarning(" - f"pos={self.pos}, " - f"msg={self.msg}, " - f"replacements={self.replacements})" - ) + def add_note(self, pos: _PosType, msg: str) -> None: + self.notes.append(Note(pos, msg)) class Linter: @@ -125,22 +115,30 @@ def fix(self) -> str: replaced_content += self.content[cursor:] return replaced_content + def _print_note( + self, note_type: str, pos: _PosType, msg: str, newtext: Optional[str] = None + ) -> None: + line_index = self._line_for_pos(pos[0]) + line_pos = self.lines[line_index] + self.console.print( + f"In file [bold]{escape(self.filename)}:{line_index + 1}:" + f"{pos[0] - line_pos[0] + 1}[/bold]:" + ) + self._print_highlighted_code(pos, newtext) + self.console.print(f"[bold]{note_type}:[/bold] {escape(msg)}") + self.console.print() + def print_warnings(self, fix_applied: bool = False) -> None: sorted_warnings = sorted(self.warnings, key=lambda warning: warning.pos) for warning in sorted_warnings: - line_index = self.line_for_pos(warning.pos[0]) - line_pos = self.lines[line_index] - self.console.print( - f"In file [bold]{escape(self.filename)}:{line_index + 1}:" - f"{warning.pos[0] - line_pos[0] + 1}[/bold]:" - ) - self.print_highlighted_code(warning.pos) - self.console.print(f"[bold]warning:[/bold] {escape(warning.msg)}") - self.console.print() + self._print_note("warning", warning.pos, warning.msg) + + for note in warning.notes: + self._print_note("note", note.pos, note.msg) for replacement in warning.replacements: - line_index = self.line_for_pos(replacement.pos[0]) + line_index = self._line_for_pos(replacement.pos[0]) line_pos = self.lines[line_index] newtext = replacement.newtext if match := self.NEWLINE_RE.search(newtext): @@ -151,37 +149,31 @@ def print_warnings(self, fix_applied: bool = False) -> None: if replacement.pos[1] > line_pos[1]: long = True - self.console.print( - f"In file [bold]{escape(self.filename)}:{line_index + 1}:" - f"{replacement.pos[0] - line_pos[0] + 1}[/bold]:" - ) - self.print_highlighted_code(replacement.pos, newtext) if fix_applied: if long: - self.console.print( - "[bold]note:[/bold] suggested fix applied but is too long " - "to display" + replacement_msg = ( + "suggested fix applied but is too long to display" ) else: - self.console.print("[bold]note:[/bold] suggested fix applied") + replacement_msg = "suggested fix applied" else: if long: - self.console.print( - "[bold]note:[/bold] suggested fix is too long to display, " - "use --fix to apply it" + replacement_msg = ( + "suggested fix is too long to display, use --fix to apply " + "it" ) else: - self.console.print("[bold]note:[/bold] suggested fix") - self.console.print() + replacement_msg = "suggested fix" + self._print_note("note", replacement.pos, replacement_msg, newtext) - def print_highlighted_code( + def _print_highlighted_code( self, pos: _PosType, replacement: Optional[str] = None ) -> None: - line_index = self.line_for_pos(pos[0]) + line_index = self._line_for_pos(pos[0]) line_pos = self.lines[line_index] left = pos[0] - if self.line_for_pos(pos[1]) == line_index: + if self._line_for_pos(pos[1]) == line_index: right = pos[1] else: right = line_pos[1] @@ -204,7 +196,7 @@ def print_highlighted_code( f"{escape(self.content[right:line_pos[1]])}[/green]" ) - def line_for_pos(self, index: int) -> int: + def _line_for_pos(self, index: int) -> int: @functools.total_ordering class LineComparator: def __init__(self, pos: _PosType) -> None: diff --git a/test/rapids_pre_commit_hooks/test_lint.py b/test/rapids_pre_commit_hooks/test_lint.py index 285b5c4..aef3da2 100644 --- a/test/rapids_pre_commit_hooks/test_lint.py +++ b/test/rapids_pre_commit_hooks/test_lint.py @@ -113,7 +113,7 @@ def test_line_for_pos( ): linter = Linter("test.txt", contents) with raises: - assert linter.line_for_pos(pos) == line + assert linter._line_for_pos(pos) == line def test_fix(self): linter = Linter("test.txt", "Hello world!") @@ -190,9 +190,10 @@ def mock_console(self): def the_check(self, linter, args): assert args.check_test - linter.add_warning((0, 5), "say good bye instead").add_replacement( - (0, 5), "Good bye" - ) + w = linter.add_warning((0, 5), "say good bye instead") + w.add_replacement((0, 5), "Good bye") + if args.check_test_note: + w.add_note((6, 11), "it's a small world after all") if linter.content[5] != "!": linter.add_warning((5, 5), "use punctuation").add_replacement((5, 5), ",") @@ -220,6 +221,7 @@ def test_no_warnings_no_fix(self, hello_world_file): ), self.mock_console() as console: m = LintMain() m.argparser.add_argument("--check-test", action="store_true") + m.argparser.add_argument("--check-test-note", action="store_true") with m.execute(): pass assert hello_world_file.read() == "Hello world!" @@ -233,6 +235,7 @@ def test_no_warnings_fix(self, hello_world_file): ), self.mock_console() as console: m = LintMain() m.argparser.add_argument("--check-test", action="store_true") + m.argparser.add_argument("--check-test-note", action="store_true") with m.execute(): pass assert hello_world_file.read() == "Hello world!" @@ -246,6 +249,7 @@ def test_warnings_no_fix(self, hello_world_file): ), self.mock_console() as console, pytest.raises(SystemExit, match=r"^1$"): m = LintMain() m.argparser.add_argument("--check-test", action="store_true") + m.argparser.add_argument("--check-test-note", action="store_true") with m.execute() as ctx: ctx.add_check(self.the_check) assert hello_world_file.read() == "Hello world!" @@ -277,6 +281,7 @@ def test_warnings_fix(self, hello_world_file): ), self.mock_console() as console, pytest.raises(SystemExit, match=r"^1$"): m = LintMain() m.argparser.add_argument("--check-test", action="store_true") + m.argparser.add_argument("--check-test-note", action="store_true") with m.execute() as ctx: ctx.add_check(self.the_check) assert hello_world_file.read() == "Good bye, world!" @@ -302,6 +307,43 @@ def test_warnings_fix(self, hello_world_file): call().print(), ] + def test_warnings_note(self, hello_world_file): + with patch( + "sys.argv", + ["check-test", "--check-test", "--check-test-note", hello_world_file.name], + ), self.mock_console() as console, pytest.raises(SystemExit, match=r"^1$"): + m = LintMain() + m.argparser.add_argument("--check-test", action="store_true") + m.argparser.add_argument("--check-test-note", action="store_true") + with m.execute() as ctx: + ctx.add_check(self.the_check) + assert hello_world_file.read() == "Hello world!" + assert console.mock_calls == [ + call(highlight=False), + call().print(f"In file [bold]{hello_world_file.name}:1:1[/bold]:"), + call().print(" [bold]Hello[/bold] world!"), + call().print("[bold]warning:[/bold] say good bye instead"), + call().print(), + call().print(f"In file [bold]{hello_world_file.name}:1:7[/bold]:"), + call().print(" Hello [bold]world[/bold]!"), + call().print("[bold]note:[/bold] it's a small world after all"), + call().print(), + call().print(f"In file [bold]{hello_world_file.name}:1:1[/bold]:"), + call().print("[red]-[bold]Hello[/bold] world![/red]"), + call().print("[green]+[bold]Good bye[/bold] world![/green]"), + call().print("[bold]note:[/bold] suggested fix"), + call().print(), + call().print(f"In file [bold]{hello_world_file.name}:1:6[/bold]:"), + call().print(" Hello[bold][/bold] world!"), + call().print("[bold]warning:[/bold] use punctuation"), + call().print(), + call().print(f"In file [bold]{hello_world_file.name}:1:6[/bold]:"), + call().print("[red]-Hello[bold][/bold] world![/red]"), + call().print("[green]+Hello[bold],[/bold] world![/green]"), + call().print("[bold]note:[/bold] suggested fix"), + call().print(), + ] + def test_multiple_files(self, hello_world_file, hello_file): with patch( "sys.argv", @@ -315,6 +357,7 @@ def test_multiple_files(self, hello_world_file, hello_file): ), self.mock_console() as console, pytest.raises(SystemExit, match=r"^1$"): m = LintMain() m.argparser.add_argument("--check-test", action="store_true") + m.argparser.add_argument("--check-test-note", action="store_true") with m.execute() as ctx: ctx.add_check(self.the_check) assert hello_world_file.read() == "Good bye, world!" @@ -367,6 +410,7 @@ def test_binary_file(self, binary_file): ): m = LintMain() m.argparser.add_argument("--check-test", action="store_true") + m.argparser.add_argument("--check-test-note", action="store_true") with m.execute() as ctx: ctx.add_check(self.the_check) mock_linter.assert_not_called()