diff --git a/src/rapids_pre_commit_hooks/alpha_spec.py b/src/rapids_pre_commit_hooks/alpha_spec.py index c48d474..9d1e7d0 100644 --- a/src/rapids_pre_commit_hooks/alpha_spec.py +++ b/src/rapids_pre_commit_hooks/alpha_spec.py @@ -36,15 +36,15 @@ @cache -def all_metadata() -> RAPIDSMetadata: +def all_metadata() -> "RAPIDSMetadata": return fetch_latest() -def node_has_type(node: yaml.Node, tag_type: str) -> bool: +def node_has_type(node: "yaml.Node", tag_type: str) -> bool: return node.tag == f"tag:yaml.org,2002:{tag_type}" -def get_rapids_version(args: argparse.Namespace) -> RAPIDSVersion: +def get_rapids_version(args: argparse.Namespace) -> "RAPIDSVersion": md = all_metadata() return ( md.versions[args.rapids_version] @@ -62,7 +62,7 @@ def strip_cuda_suffix(args: argparse.Namespace, name: str) -> str: def check_and_mark_anchor( - anchors: dict[str, yaml.Node], used_anchors: set[str], node: yaml.Node + anchors: dict[str, "yaml.Node"], used_anchors: set[str], node: "yaml.Node" ) -> tuple[bool, Optional[str]]: for key, value in anchors.items(): if value == node: @@ -80,9 +80,9 @@ def check_and_mark_anchor( def check_package_spec( linter: Linter, args: argparse.Namespace, - anchors: dict[str, yaml.Node], + anchors: dict[str, "yaml.Node"], used_anchors: set[str], - node: yaml.Node, + node: "yaml.Node", ) -> None: @total_ordering class SpecPriority: @@ -154,9 +154,9 @@ def create_specifier_string(specifiers: set[str]) -> str: def check_packages( linter: Linter, args: argparse.Namespace, - anchors: dict[str, yaml.Node], + anchors: dict[str, "yaml.Node"], used_anchors: set[str], - node: yaml.Node, + node: "yaml.Node", ) -> None: if node_has_type(node, "seq"): descend, _ = check_and_mark_anchor(anchors, used_anchors, node) @@ -168,9 +168,9 @@ def check_packages( def check_common( linter: Linter, args: argparse.Namespace, - anchors: dict[str, yaml.Node], + anchors: dict[str, "yaml.Node"], used_anchors: set[str], - node: yaml.Node, + node: "yaml.Node", ) -> None: if node_has_type(node, "seq"): for dependency_set in node.value: @@ -188,9 +188,9 @@ def check_common( def check_matrices( linter: Linter, args: argparse.Namespace, - anchors: dict[str, yaml.Node], + anchors: dict[str, "yaml.Node"], used_anchors: set[str], - node: yaml.Node, + node: "yaml.Node", ) -> None: if node_has_type(node, "seq"): for item in node.value: @@ -208,9 +208,9 @@ def check_matrices( def check_specific( linter: Linter, args: argparse.Namespace, - anchors: dict[str, yaml.Node], + anchors: dict[str, "yaml.Node"], used_anchors: set[str], - node: yaml.Node, + node: "yaml.Node", ) -> None: if node_has_type(node, "seq"): for matrix_matcher in node.value: @@ -228,9 +228,9 @@ def check_specific( def check_dependencies( linter: Linter, args: argparse.Namespace, - anchors: dict[str, yaml.Node], + anchors: dict[str, "yaml.Node"], used_anchors: set[str], - node: yaml.Node, + node: "yaml.Node", ) -> None: if node_has_type(node, "map"): for _, dependencies_value in node.value: @@ -250,9 +250,9 @@ def check_dependencies( def check_root( linter: Linter, args: argparse.Namespace, - anchors: dict[str, yaml.Node], + anchors: dict[str, "yaml.Node"], used_anchors: set[str], - node: yaml.Node, + node: "yaml.Node", ) -> None: if node_has_type(node, "map"): for root_key, root_value in node.value: diff --git a/src/rapids_pre_commit_hooks/copyright.py b/src/rapids_pre_commit_hooks/copyright.py index 8f236c4..e282869 100644 --- a/src/rapids_pre_commit_hooks/copyright.py +++ b/src/rapids_pre_commit_hooks/copyright.py @@ -108,7 +108,7 @@ def apply_copyright_check(linter: Linter, old_content: Optional[str]) -> None: linter.add_warning((0, 0), "no copyright notice found") -def get_target_branch(repo: git.Repo, args: argparse.Namespace) -> Optional[str]: +def get_target_branch(repo: "git.Repo", args: argparse.Namespace) -> Optional[str]: """Determine which branch is the "target" branch. The target branch is determined in the following order: @@ -175,7 +175,7 @@ def get_target_branch(repo: git.Repo, args: argparse.Namespace) -> Optional[str] def get_target_branch_upstream_commit( - repo: git.Repo, args: argparse.Namespace + repo: "git.Repo", args: argparse.Namespace ) -> Optional[git.Commit]: # If no target branch can be determined, use HEAD if it exists target_branch_name = get_target_branch(repo, args) @@ -232,7 +232,7 @@ def try_get_ref(remote: git.Remote) -> Optional[git.Reference]: def get_changed_files( args: argparse.Namespace, -) -> dict[Union[str, os.PathLike[str]], Optional[git.Blob]]: +) -> dict[Union[str, os.PathLike[str]], Optional["git.Blob"]]: try: repo = git.Repo() except git.InvalidGitRepositoryError: @@ -276,8 +276,8 @@ def normalize_git_filename(filename: Union[str, os.PathLike[str]]) -> Optional[s def find_blob( - tree: git.Tree, filename: Union[str, os.PathLike[str]] -) -> Optional[git.Blob]: + tree: "git.Tree", filename: Union[str, os.PathLike[str]] +) -> Optional["git.Blob"]: d1, d2 = os.path.split(filename) split = [d2] while d1: diff --git a/src/rapids_pre_commit_hooks/lint.py b/src/rapids_pre_commit_hooks/lint.py index ebf0110..670a6a7 100644 --- a/src/rapids_pre_commit_hooks/lint.py +++ b/src/rapids_pre_commit_hooks/lint.py @@ -93,7 +93,7 @@ def __init__(self, filename: str, content: str) -> None: self.filename: str = filename self.content: str = content self.warnings: list[LintWarning] = [] - self.console: Console = Console(highlight=False) + self.console: "Console" = Console(highlight=False) self._calculate_lines() def add_warning(self, pos: _PosType, msg: str) -> LintWarning: @@ -207,16 +207,19 @@ def print_highlighted_code( def line_for_pos(self, index: int) -> int: @functools.total_ordering class LineComparator: - def __init__(self, pos: _PosType): + def __init__(self, pos: _PosType) -> None: self.pos: _PosType = pos - def __lt__(self, other): + def __lt__(self, other: object) -> bool: + assert isinstance(other, LineComparator) return self.pos[1] < other - def __gt__(self, other): + def __gt__(self, other: object) -> bool: + assert isinstance(other, LineComparator) return self.pos[0] > other - def __eq__(self, other): + def __eq__(self, other: object) -> bool: + assert isinstance(other, LineComparator) return self.pos[0] <= other <= self.pos[1] line_index = bisect.bisect_left( diff --git a/src/rapids_pre_commit_hooks/pyproject_license.py b/src/rapids_pre_commit_hooks/pyproject_license.py index 602816a..cd16723 100644 --- a/src/rapids_pre_commit_hooks/pyproject_license.py +++ b/src/rapids_pre_commit_hooks/pyproject_license.py @@ -32,7 +32,7 @@ def find_value_location( - document: tomlkit.TOMLDocument, key: tuple[str, ...], append: bool + document: "tomlkit.TOMLDocument", key: tuple[str, ...], append: bool ) -> _LocType: copied_document = copy.deepcopy(document) placeholder = uuid.uuid4()