diff --git a/cyaron/tests/__init__.py b/cyaron/tests/__init__.py index 328a930..6fcff5c 100644 --- a/cyaron/tests/__init__.py +++ b/cyaron/tests/__init__.py @@ -5,3 +5,4 @@ from .compare_test import TestCompare from .graph_test import TestGraph from .vector_test import TestVector +from .general_test import TestGeneral diff --git a/cyaron/tests/compare_test.py b/cyaron/tests/compare_test.py index c52a6a7..7b6c487 100644 --- a/cyaron/tests/compare_test.py +++ b/cyaron/tests/compare_test.py @@ -15,10 +15,12 @@ class TestCompare(unittest.TestCase): def setUp(self): + self.original_directory = os.getcwd() self.temp_directory = tempfile.mkdtemp() os.chdir(self.temp_directory) def tearDown(self): + os.chdir(self.original_directory) try: shutil.rmtree(self.temp_directory) except: diff --git a/cyaron/tests/general_test.py b/cyaron/tests/general_test.py index e69de29..69f7a7f 100644 --- a/cyaron/tests/general_test.py +++ b/cyaron/tests/general_test.py @@ -0,0 +1,44 @@ +import subprocess +import unittest +import os +import tempfile +import shutil +import sys + + +class TestGeneral(unittest.TestCase): + + def setUp(self): + self.original_directory = os.getcwd() + self.temp_directory = tempfile.mkdtemp() + os.chdir(self.temp_directory) + + def tearDown(self): + os.chdir(self.original_directory) + try: + shutil.rmtree(self.temp_directory) + except: + pass + + def test_randseed_arg(self): + with open("test_randseed.py", 'w', encoding='utf-8') as f: + f.write("import cyaron as c\n" + "c.process_args()\n" + "for i in range(10):\n" + " print(c.randint(1,1000000000),end=' ')\n") + + env = os.environ.copy() + env['PYTHONPATH'] = self.original_directory + os.pathsep + env.get( + 'PYTHONPATH', '') + result = subprocess.run([ + sys.executable, 'test_randseed.py', + '--randseed=pinkrabbit147154220' + ], + env=env, + stdout=subprocess.PIPE, + universal_newlines=True, + check=True) + self.assertEqual( + result.stdout, + "243842479 490459912 810766286 646030451 191412261 929378523 273000814 982402032 436668773 957169453 " + ) diff --git a/cyaron/tests/io_test.py b/cyaron/tests/io_test.py index 6a0032d..02b5a98 100644 --- a/cyaron/tests/io_test.py +++ b/cyaron/tests/io_test.py @@ -12,10 +12,12 @@ class TestIO(unittest.TestCase): def setUp(self): + self.original_directory = os.getcwd() self.temp_directory = tempfile.mkdtemp() os.chdir(self.temp_directory) def tearDown(self): + os.chdir(self.original_directory) try: shutil.rmtree(self.temp_directory) except: diff --git a/cyaron/utils.py b/cyaron/utils.py index 5b7cfd5..4ef5ad8 100644 --- a/cyaron/utils.py +++ b/cyaron/utils.py @@ -1,67 +1,81 @@ -def ati(array): - """ati(array) -> list - Convert all the elements in the array and return them in a list. - """ +"""Some utility functions.""" +import sys +import random +from typing import cast, Any, Dict, Iterable, Tuple, Union + +__all__ = [ + "ati", "list_like", "int_like", "strtolines", "make_unicode", + "unpack_kwargs", "process_args" +] + + +def ati(array: Iterable[Any]): + """Convert all the elements in the array and return them in a list.""" return [int(i) for i in array] -def list_like(data): - """list_like(data) -> bool - Judge whether the object data is like a list or a tuple. - object data -> the data to judge - """ +def list_like(data: Any): + """Judge whether the object data is like a list or a tuple.""" return isinstance(data, (tuple, list)) -def int_like(data): - isint = False - try: - isint = isint or isinstance(data, long) - except NameError: - pass - isint = isint or isinstance(data, int) - return isint +def int_like(data: Any): + """Judge whether the object data is like a int.""" + return isinstance(data, int) -def strtolines(str): - lines = str.split('\n') - for i in range(len(lines)): +def strtolines(string: str): + """ + Split the string by the newline character, remove trailing spaces from each line, + and remove any blank lines at the end of the the string. + """ + lines = string.split("\n") + for i, _ in enumerate(lines): lines[i] = lines[i].rstrip() - while len(lines) > 0 and len(lines[len(lines) - 1]) == 0: - del lines[len(lines) - 1] + while len(lines) > 0 and len(lines[-1]) == 0: + lines.pop() return lines -def make_unicode(data): +def make_unicode(data: Any): + """Convert the data to a string.""" return str(data) -def unpack_kwargs(funcname, kwargs, arg_pattern): +def unpack_kwargs( + funcname: str, + kwargs: Dict[str, Any], + arg_pattern: Iterable[Union[str, Tuple[str, Any]]], +): + """Parse the keyword arguments.""" rv = {} kwargs = kwargs.copy() for tp in arg_pattern: if list_like(tp): - k, v = tp - rv[k] = kwargs.get(k, v) - try: - del kwargs[k] - except KeyError: - pass + k, v = cast(Tuple[str, Any], tp) + rv[k] = kwargs.pop(k, v) else: - error = False + tp = cast(str, tp) try: - rv[tp] = kwargs[tp] - del kwargs[tp] - except KeyError as e: - error = True - if error: + rv[tp] = kwargs.pop(tp) + except KeyError: raise TypeError( - '{}() missing 1 required keyword-only argument: \'{}\''. - format(funcname, tp)) + f"{funcname}() missing 1 required keyword-only argument: '{tp}'" + ) from None if kwargs: raise TypeError( - '{}() got an unexpected keyword argument \'{}\''.format( - funcname, - next(iter(kwargs.items()))[0])) + f"{funcname}() got an unexpected keyword argument '{next(iter(kwargs.items()))[0]}'" + ) return rv + + +def process_args(): + """ + Process the command line arguments. + Now we support: + - randseed: set the random seed + """ + for s in sys.argv: + if s.startswith("--randseed="): + random.seed(s.split("=")[1])