Skip to content

Commit

Permalink
Add quote to third-party type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
KyleFromNVIDIA committed Aug 26, 2024
1 parent 62f24b6 commit f10af82
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 29 deletions.
36 changes: 18 additions & 18 deletions src/rapids_pre_commit_hooks/alpha_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,15 @@


@cache
def all_metadata() -> RAPIDSMetadata:
def all_metadata() -> "RAPIDSMetadata":
return fetch_latest()


def node_has_type(node: yaml.Node, tag_type: str) -> bool:
def node_has_type(node: "yaml.Node", tag_type: str) -> bool:
return node.tag == f"tag:yaml.org,2002:{tag_type}"


def get_rapids_version(args: argparse.Namespace) -> RAPIDSVersion:
def get_rapids_version(args: argparse.Namespace) -> "RAPIDSVersion":
md = all_metadata()
return (
md.versions[args.rapids_version]
Expand All @@ -62,7 +62,7 @@ def strip_cuda_suffix(args: argparse.Namespace, name: str) -> str:


def check_and_mark_anchor(
anchors: dict[str, yaml.Node], used_anchors: set[str], node: yaml.Node
anchors: dict[str, "yaml.Node"], used_anchors: set[str], node: "yaml.Node"
) -> tuple[bool, Optional[str]]:
for key, value in anchors.items():
if value == node:
Expand All @@ -80,9 +80,9 @@ def check_and_mark_anchor(
def check_package_spec(
linter: Linter,
args: argparse.Namespace,
anchors: dict[str, yaml.Node],
anchors: dict[str, "yaml.Node"],
used_anchors: set[str],
node: yaml.Node,
node: "yaml.Node",
) -> None:
@total_ordering
class SpecPriority:
Expand Down Expand Up @@ -154,9 +154,9 @@ def create_specifier_string(specifiers: set[str]) -> str:
def check_packages(
linter: Linter,
args: argparse.Namespace,
anchors: dict[str, yaml.Node],
anchors: dict[str, "yaml.Node"],
used_anchors: set[str],
node: yaml.Node,
node: "yaml.Node",
) -> None:
if node_has_type(node, "seq"):
descend, _ = check_and_mark_anchor(anchors, used_anchors, node)
Expand All @@ -168,9 +168,9 @@ def check_packages(
def check_common(
linter: Linter,
args: argparse.Namespace,
anchors: dict[str, yaml.Node],
anchors: dict[str, "yaml.Node"],
used_anchors: set[str],
node: yaml.Node,
node: "yaml.Node",
) -> None:
if node_has_type(node, "seq"):
for dependency_set in node.value:
Expand All @@ -188,9 +188,9 @@ def check_common(
def check_matrices(
linter: Linter,
args: argparse.Namespace,
anchors: dict[str, yaml.Node],
anchors: dict[str, "yaml.Node"],
used_anchors: set[str],
node: yaml.Node,
node: "yaml.Node",
) -> None:
if node_has_type(node, "seq"):
for item in node.value:
Expand All @@ -208,9 +208,9 @@ def check_matrices(
def check_specific(
linter: Linter,
args: argparse.Namespace,
anchors: dict[str, yaml.Node],
anchors: dict[str, "yaml.Node"],
used_anchors: set[str],
node: yaml.Node,
node: "yaml.Node",
) -> None:
if node_has_type(node, "seq"):
for matrix_matcher in node.value:
Expand All @@ -228,9 +228,9 @@ def check_specific(
def check_dependencies(
linter: Linter,
args: argparse.Namespace,
anchors: dict[str, yaml.Node],
anchors: dict[str, "yaml.Node"],
used_anchors: set[str],
node: yaml.Node,
node: "yaml.Node",
) -> None:
if node_has_type(node, "map"):
for _, dependencies_value in node.value:
Expand All @@ -250,9 +250,9 @@ def check_dependencies(
def check_root(
linter: Linter,
args: argparse.Namespace,
anchors: dict[str, yaml.Node],
anchors: dict[str, "yaml.Node"],
used_anchors: set[str],
node: yaml.Node,
node: "yaml.Node",
) -> None:
if node_has_type(node, "map"):
for root_key, root_value in node.value:
Expand Down
10 changes: 5 additions & 5 deletions src/rapids_pre_commit_hooks/copyright.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def apply_copyright_check(linter: Linter, old_content: Optional[str]) -> None:
linter.add_warning((0, 0), "no copyright notice found")


def get_target_branch(repo: git.Repo, args: argparse.Namespace) -> Optional[str]:
def get_target_branch(repo: "git.Repo", args: argparse.Namespace) -> Optional[str]:
"""Determine which branch is the "target" branch.
The target branch is determined in the following order:
Expand Down Expand Up @@ -175,7 +175,7 @@ def get_target_branch(repo: git.Repo, args: argparse.Namespace) -> Optional[str]


def get_target_branch_upstream_commit(
repo: git.Repo, args: argparse.Namespace
repo: "git.Repo", args: argparse.Namespace
) -> Optional[git.Commit]:
# If no target branch can be determined, use HEAD if it exists
target_branch_name = get_target_branch(repo, args)
Expand Down Expand Up @@ -232,7 +232,7 @@ def try_get_ref(remote: git.Remote) -> Optional[git.Reference]:

def get_changed_files(
args: argparse.Namespace,
) -> dict[Union[str, os.PathLike[str]], Optional[git.Blob]]:
) -> dict[Union[str, os.PathLike[str]], Optional["git.Blob"]]:
try:
repo = git.Repo()
except git.InvalidGitRepositoryError:
Expand Down Expand Up @@ -276,8 +276,8 @@ def normalize_git_filename(filename: Union[str, os.PathLike[str]]) -> Optional[s


def find_blob(
tree: git.Tree, filename: Union[str, os.PathLike[str]]
) -> Optional[git.Blob]:
tree: "git.Tree", filename: Union[str, os.PathLike[str]]
) -> Optional["git.Blob"]:
d1, d2 = os.path.split(filename)
split = [d2]
while d1:
Expand Down
13 changes: 8 additions & 5 deletions src/rapids_pre_commit_hooks/lint.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def __init__(self, filename: str, content: str) -> None:
self.filename: str = filename
self.content: str = content
self.warnings: list[LintWarning] = []
self.console: Console = Console(highlight=False)
self.console: "Console" = Console(highlight=False)
self._calculate_lines()

def add_warning(self, pos: _PosType, msg: str) -> LintWarning:
Expand Down Expand Up @@ -207,16 +207,19 @@ def print_highlighted_code(
def line_for_pos(self, index: int) -> int:
@functools.total_ordering
class LineComparator:
def __init__(self, pos: _PosType):
def __init__(self, pos: _PosType) -> None:
self.pos: _PosType = pos

def __lt__(self, other):
def __lt__(self, other: object) -> bool:
assert isinstance(other, LineComparator)
return self.pos[1] < other

def __gt__(self, other):
def __gt__(self, other: object) -> bool:
assert isinstance(other, LineComparator)
return self.pos[0] > other

def __eq__(self, other):
def __eq__(self, other: object) -> bool:
assert isinstance(other, LineComparator)
return self.pos[0] <= other <= self.pos[1]

line_index = bisect.bisect_left(
Expand Down
2 changes: 1 addition & 1 deletion src/rapids_pre_commit_hooks/pyproject_license.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@


def find_value_location(
document: tomlkit.TOMLDocument, key: tuple[str, ...], append: bool
document: "tomlkit.TOMLDocument", key: tuple[str, ...], append: bool
) -> _LocType:
copied_document = copy.deepcopy(document)
placeholder = uuid.uuid4()
Expand Down

0 comments on commit f10af82

Please sign in to comment.