Skip to content

Commit

Permalink
Add ability to add notes to warnings
Browse files Browse the repository at this point in the history
  • Loading branch information
KyleFromNVIDIA committed Aug 26, 2024
1 parent 2fb7ba6 commit 2ee7329
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 59 deletions.
102 changes: 47 additions & 55 deletions src/rapids_pre_commit_hooks/lint.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import argparse
import bisect
import contextlib
import dataclasses
import functools
import re
import warnings
Expand Down Expand Up @@ -47,43 +48,32 @@ class BinaryFileWarning(Warning):
pass


@dataclasses.dataclass
class Replacement:
def __init__(self, pos: _PosType, newtext: str) -> None:
self.pos: _PosType = pos
self.newtext: str = newtext
pos: _PosType
newtext: str

def __eq__(self, other: object) -> bool:
assert isinstance(other, Replacement)
return self.pos == other.pos and self.newtext == other.newtext

def __repr__(self) -> str:
return f"Replacement(pos={self.pos}, newtext={repr(self.newtext)})"
@dataclasses.dataclass
class Note:
pos: _PosType
msg: str


@dataclasses.dataclass
class LintWarning:
def __init__(self, pos: _PosType, msg: str) -> None:
self.pos: _PosType = pos
self.msg: str = msg
self.replacements: list[Replacement] = []
pos: _PosType
msg: str
replacements: list[Replacement] = dataclasses.field(
default_factory=list, init=False
)
notes: list[Note] = dataclasses.field(default_factory=list, init=False)

def add_replacement(self, pos: _PosType, newtext: str) -> None:
self.replacements.append(Replacement(pos, newtext))

def __eq__(self, other: object) -> bool:
assert isinstance(other, LintWarning)
return (
self.pos == other.pos
and self.msg == other.msg
and self.replacements == other.replacements
)

def __repr__(self) -> str:
return (
"LintWarning("
f"pos={self.pos}, "
f"msg={self.msg}, "
f"replacements={self.replacements})"
)
def add_note(self, pos: _PosType, msg: str) -> None:
self.notes.append(Note(pos, msg))


class Linter:
Expand Down Expand Up @@ -125,22 +115,30 @@ def fix(self) -> str:
replaced_content += self.content[cursor:]
return replaced_content

def _print_note(
self, note_type: str, pos: _PosType, msg: str, newtext: Optional[str] = None
) -> None:
line_index = self._line_for_pos(pos[0])
line_pos = self.lines[line_index]
self.console.print(
f"In file [bold]{escape(self.filename)}:{line_index + 1}:"
f"{pos[0] - line_pos[0] + 1}[/bold]:"
)
self._print_highlighted_code(pos, newtext)
self.console.print(f"[bold]{note_type}:[/bold] {escape(msg)}")
self.console.print()

def print_warnings(self, fix_applied: bool = False) -> None:
sorted_warnings = sorted(self.warnings, key=lambda warning: warning.pos)

for warning in sorted_warnings:
line_index = self.line_for_pos(warning.pos[0])
line_pos = self.lines[line_index]
self.console.print(
f"In file [bold]{escape(self.filename)}:{line_index + 1}:"
f"{warning.pos[0] - line_pos[0] + 1}[/bold]:"
)
self.print_highlighted_code(warning.pos)
self.console.print(f"[bold]warning:[/bold] {escape(warning.msg)}")
self.console.print()
self._print_note("warning", warning.pos, warning.msg)

for note in warning.notes:
self._print_note("note", note.pos, note.msg)

for replacement in warning.replacements:
line_index = self.line_for_pos(replacement.pos[0])
line_index = self._line_for_pos(replacement.pos[0])
line_pos = self.lines[line_index]
newtext = replacement.newtext
if match := self.NEWLINE_RE.search(newtext):
Expand All @@ -151,37 +149,31 @@ def print_warnings(self, fix_applied: bool = False) -> None:
if replacement.pos[1] > line_pos[1]:
long = True

self.console.print(
f"In file [bold]{escape(self.filename)}:{line_index + 1}:"
f"{replacement.pos[0] - line_pos[0] + 1}[/bold]:"
)
self.print_highlighted_code(replacement.pos, newtext)
if fix_applied:
if long:
self.console.print(
"[bold]note:[/bold] suggested fix applied but is too long "
"to display"
replacement_msg = (
"suggested fix applied but is too long to display"
)
else:
self.console.print("[bold]note:[/bold] suggested fix applied")
replacement_msg = "suggested fix applied"
else:
if long:
self.console.print(
"[bold]note:[/bold] suggested fix is too long to display, "
"use --fix to apply it"
replacement_msg = (
"suggested fix is too long to display, use --fix to apply "
"it"
)
else:
self.console.print("[bold]note:[/bold] suggested fix")
self.console.print()
replacement_msg = "suggested fix"
self._print_note("note", replacement.pos, replacement_msg, newtext)

def print_highlighted_code(
def _print_highlighted_code(
self, pos: _PosType, replacement: Optional[str] = None
) -> None:
line_index = self.line_for_pos(pos[0])
line_index = self._line_for_pos(pos[0])
line_pos = self.lines[line_index]
left = pos[0]

if self.line_for_pos(pos[1]) == line_index:
if self._line_for_pos(pos[1]) == line_index:
right = pos[1]
else:
right = line_pos[1]
Expand All @@ -204,7 +196,7 @@ def print_highlighted_code(
f"{escape(self.content[right:line_pos[1]])}[/green]"
)

def line_for_pos(self, index: int) -> int:
def _line_for_pos(self, index: int) -> int:
@functools.total_ordering
class LineComparator:
def __init__(self, pos: _PosType) -> None:
Expand Down
52 changes: 48 additions & 4 deletions test/rapids_pre_commit_hooks/test_lint.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def test_line_for_pos(
):
linter = Linter("test.txt", contents)
with raises:
assert linter.line_for_pos(pos) == line
assert linter._line_for_pos(pos) == line

def test_fix(self):
linter = Linter("test.txt", "Hello world!")
Expand Down Expand Up @@ -190,9 +190,10 @@ def mock_console(self):

def the_check(self, linter, args):
assert args.check_test
linter.add_warning((0, 5), "say good bye instead").add_replacement(
(0, 5), "Good bye"
)
w = linter.add_warning((0, 5), "say good bye instead")
w.add_replacement((0, 5), "Good bye")
if args.check_test_note:
w.add_note((6, 11), "it's a small world after all")
if linter.content[5] != "!":
linter.add_warning((5, 5), "use punctuation").add_replacement((5, 5), ",")

Expand Down Expand Up @@ -220,6 +221,7 @@ def test_no_warnings_no_fix(self, hello_world_file):
), self.mock_console() as console:
m = LintMain()
m.argparser.add_argument("--check-test", action="store_true")
m.argparser.add_argument("--check-test-note", action="store_true")
with m.execute():
pass
assert hello_world_file.read() == "Hello world!"
Expand All @@ -233,6 +235,7 @@ def test_no_warnings_fix(self, hello_world_file):
), self.mock_console() as console:
m = LintMain()
m.argparser.add_argument("--check-test", action="store_true")
m.argparser.add_argument("--check-test-note", action="store_true")
with m.execute():
pass
assert hello_world_file.read() == "Hello world!"
Expand All @@ -246,6 +249,7 @@ def test_warnings_no_fix(self, hello_world_file):
), self.mock_console() as console, pytest.raises(SystemExit, match=r"^1$"):
m = LintMain()
m.argparser.add_argument("--check-test", action="store_true")
m.argparser.add_argument("--check-test-note", action="store_true")
with m.execute() as ctx:
ctx.add_check(self.the_check)
assert hello_world_file.read() == "Hello world!"
Expand Down Expand Up @@ -277,6 +281,7 @@ def test_warnings_fix(self, hello_world_file):
), self.mock_console() as console, pytest.raises(SystemExit, match=r"^1$"):
m = LintMain()
m.argparser.add_argument("--check-test", action="store_true")
m.argparser.add_argument("--check-test-note", action="store_true")
with m.execute() as ctx:
ctx.add_check(self.the_check)
assert hello_world_file.read() == "Good bye, world!"
Expand All @@ -302,6 +307,43 @@ def test_warnings_fix(self, hello_world_file):
call().print(),
]

def test_warnings_note(self, hello_world_file):
with patch(
"sys.argv",
["check-test", "--check-test", "--check-test-note", hello_world_file.name],
), self.mock_console() as console, pytest.raises(SystemExit, match=r"^1$"):
m = LintMain()
m.argparser.add_argument("--check-test", action="store_true")
m.argparser.add_argument("--check-test-note", action="store_true")
with m.execute() as ctx:
ctx.add_check(self.the_check)
assert hello_world_file.read() == "Hello world!"
assert console.mock_calls == [
call(highlight=False),
call().print(f"In file [bold]{hello_world_file.name}:1:1[/bold]:"),
call().print(" [bold]Hello[/bold] world!"),
call().print("[bold]warning:[/bold] say good bye instead"),
call().print(),
call().print(f"In file [bold]{hello_world_file.name}:1:7[/bold]:"),
call().print(" Hello [bold]world[/bold]!"),
call().print("[bold]note:[/bold] it's a small world after all"),
call().print(),
call().print(f"In file [bold]{hello_world_file.name}:1:1[/bold]:"),
call().print("[red]-[bold]Hello[/bold] world![/red]"),
call().print("[green]+[bold]Good bye[/bold] world![/green]"),
call().print("[bold]note:[/bold] suggested fix"),
call().print(),
call().print(f"In file [bold]{hello_world_file.name}:1:6[/bold]:"),
call().print(" Hello[bold][/bold] world!"),
call().print("[bold]warning:[/bold] use punctuation"),
call().print(),
call().print(f"In file [bold]{hello_world_file.name}:1:6[/bold]:"),
call().print("[red]-Hello[bold][/bold] world![/red]"),
call().print("[green]+Hello[bold],[/bold] world![/green]"),
call().print("[bold]note:[/bold] suggested fix"),
call().print(),
]

def test_multiple_files(self, hello_world_file, hello_file):
with patch(
"sys.argv",
Expand All @@ -315,6 +357,7 @@ def test_multiple_files(self, hello_world_file, hello_file):
), self.mock_console() as console, pytest.raises(SystemExit, match=r"^1$"):
m = LintMain()
m.argparser.add_argument("--check-test", action="store_true")
m.argparser.add_argument("--check-test-note", action="store_true")
with m.execute() as ctx:
ctx.add_check(self.the_check)
assert hello_world_file.read() == "Good bye, world!"
Expand Down Expand Up @@ -367,6 +410,7 @@ def test_binary_file(self, binary_file):
):
m = LintMain()
m.argparser.add_argument("--check-test", action="store_true")
m.argparser.add_argument("--check-test-note", action="store_true")
with m.execute() as ctx:
ctx.add_check(self.the_check)
mock_linter.assert_not_called()
Expand Down

0 comments on commit 2ee7329

Please sign in to comment.