Skip to content

Commit

Permalink
Add type hints.
Browse files Browse the repository at this point in the history
  • Loading branch information
tonyandrewmeyer committed Sep 22, 2023
1 parent c4e3266 commit fcef4c8
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions test/test_infra.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,20 @@
import subprocess
import sys
import tempfile
import typing
import unittest

import ops


def get_python_filepaths(include_tests=True):
def get_python_filepaths(include_tests: bool = True):
"""Helper to retrieve paths of Python files."""
python_paths = ['setup.py']
roots = ['ops']
if include_tests:
roots.append('test')
for root in roots:
for dirpath, dirnames, filenames in os.walk(root):
for dirpath, _, filenames in os.walk(root):
for filename in filenames:
if filename.endswith(".py"):
python_paths.append(os.path.join(dirpath, filename))
Expand All @@ -42,7 +43,7 @@ class InfrastructureTests(unittest.TestCase):

def test_quote_backslashes(self):
# ensure we're not using unneeded backslash to escape strings
issues = []
issues : typing.List[typing.Tuple[str, int, str]] = []
for filepath in get_python_filepaths():
with open(filepath, "rt", encoding="utf8") as fh:
for idx, line in enumerate(fh, 1):
Expand All @@ -54,7 +55,7 @@ def test_quote_backslashes(self):

def test_ensure_copyright(self):
# all non-empty Python files must have a proper copyright somewhere in the first 5 lines
issues = []
issues : typing.List[str] = []
regex = re.compile(r"# Copyright \d\d\d\d(-\d\d\d\d)? Canonical Ltd.\n")
for filepath in get_python_filepaths():
if os.stat(filepath).st_size == 0:
Expand All @@ -69,7 +70,7 @@ def test_ensure_copyright(self):
if issues:
self.fail("Please add copyright headers to the following files:\n" + "\n".join(issues))

def _run_setup(self, *args):
def _run_setup(self, *args: str) -> str:
proc = subprocess.run(
(sys.executable, 'setup.py') + args,
stdout=subprocess.PIPE,
Expand Down Expand Up @@ -101,7 +102,7 @@ def test_install_requires(self):

# For some reason "setup.py --requires" doesn't work, so do this the hard way
with open('setup.py', encoding='utf-8') as f:
lines = []
lines : typing.List[str] = []
for line in f:
if 'install_requires=[' in line:
break
Expand Down Expand Up @@ -131,7 +132,7 @@ def test_imports(self):
with self.subTest(name=name):
self.check(name)

def check(self, name):
def check(self, name : str):
"""Helper function to run the test."""
fd, testfile = tempfile.mkstemp()
self.addCleanup(os.unlink, testfile)
Expand Down

0 comments on commit fcef4c8

Please sign in to comment.