diff --git a/fawltydeps/extract_imports.py b/fawltydeps/extract_imports.py index 7c06f08d..c59ebd9d 100644 --- a/fawltydeps/extract_imports.py +++ b/fawltydeps/extract_imports.py @@ -5,7 +5,7 @@ import logging import tokenize from pathlib import Path -from typing import Iterable, Iterator, Optional, TextIO, Tuple, Union +from typing import BinaryIO, Iterable, Iterator, Optional, Tuple, Union import isort @@ -167,7 +167,7 @@ def parse_python_file( def parse_source( - src: CodeSource, stdin: Optional[TextIO] = None + src: CodeSource, stdin: Optional[BinaryIO] = None ) -> Iterator[ParsedImport]: """Invoke a suitable parser for the given source. @@ -206,7 +206,7 @@ def parse_source( def parse_sources( - sources: Iterable[CodeSource], stdin: Optional[TextIO] = None + sources: Iterable[CodeSource], stdin: Optional[BinaryIO] = None ) -> Iterator[ParsedImport]: """Parse import statements from the given sources.""" for source in sources: diff --git a/fawltydeps/main.py b/fawltydeps/main.py index 6e387419..1b969265 100644 --- a/fawltydeps/main.py +++ b/fawltydeps/main.py @@ -15,7 +15,7 @@ import sys from functools import partial from operator import attrgetter -from typing import Dict, Iterator, List, Optional, Set, TextIO, Type +from typing import BinaryIO, Dict, Iterator, List, Optional, Set, TextIO, Type try: # import from Pydantic V2 from pydantic.v1.json import custom_pydantic_encoder @@ -78,7 +78,7 @@ class Analysis: # pylint: disable=too-many-instance-attributes .imports). """ - def __init__(self, settings: Settings, stdin: Optional[TextIO] = None): + def __init__(self, settings: Settings, stdin: Optional[BinaryIO] = None): self.settings = settings self.stdin = stdin self.version = version() @@ -167,7 +167,7 @@ def unused_deps(self) -> List[UnusedDependency]: ) @classmethod - def create(cls, settings: Settings, stdin: Optional[TextIO] = None) -> "Analysis": + def create(cls, settings: Settings, stdin: Optional[BinaryIO] = None) -> "Analysis": """Exercise FawltyDeps' core logic according to the given settings. Perform the actions specified in 'settings.actions' and apply the other @@ -347,7 +347,7 @@ def print_output( def main( cmdline_args: Optional[List[str]] = None, # defaults to sys.argv[1:] - stdin: TextIO = sys.stdin, + stdin: BinaryIO = sys.stdin.buffer, stdout: TextIO = sys.stdout, ) -> int: """Command-line entry point.""" diff --git a/tests/test_cmdline.py b/tests/test_cmdline.py index d3257609..2c08ca8a 100644 --- a/tests/test_cmdline.py +++ b/tests/test_cmdline.py @@ -23,6 +23,7 @@ from .test_extract_imports_simple import generate_notebook from .utils import ( assert_unordered_equivalence, + dedent_bytes, run_fawltydeps_function, run_fawltydeps_subprocess, ) @@ -282,6 +283,25 @@ def test_list_imports__pick_multiple_files_dir_and_code__prints_all_imports( assert returncode == 0 +def test_list_imports__stdin_with_legacy_encoding__prints_all_imports(): + code = dedent_bytes( + b"""\ + # -*- coding: big5 -*- + + # Some Traditional Chinese characters: + chars = "\xa4@\xa8\xc7\xa4\xa4\xa4\xe5\xa6r\xb2\xc5" + + import numpy + """ + ) + output, returncode = run_fawltydeps_function( + "--list-imports", "--code", "-", to_stdin=code + ) + expect = ["numpy"] + assert_unordered_equivalence(output.splitlines()[:-2], expect) + assert returncode == 0 + + def test_list_deps_detailed__dir__prints_deps_from_requirements_txt(fake_project): tmp_path = fake_project( imports=["requests", "pandas"], diff --git a/tests/test_extract_imports_simple.py b/tests/test_extract_imports_simple.py index a6cfc18c..5907dd91 100644 --- a/tests/test_extract_imports_simple.py +++ b/tests/test_extract_imports_simple.py @@ -1,6 +1,7 @@ """Test that we can extract simple imports from Python code.""" import json import logging +from io import BytesIO from pathlib import Path from textwrap import dedent from typing import Dict, List, Tuple, Union @@ -561,3 +562,19 @@ def test_parse_sources__ignore_first_party_imports( ] assert list(parse_sources(code_sources)) == expect + + +def test_parse_sources__legacy_encoding_on_stdin__extracts_import(): + code = dedent_bytes( + b"""\ + # -*- coding: big5 -*- + + # Some Traditional Chinese characters: + chars = "\xa4@\xa8\xc7\xa4\xa4\xa4\xe5\xa6r\xb2\xc5" + + import numpy + """ + ) + + expect = imports_w_linenos([("numpy", 6)], "") + assert list(parse_sources([CodeSource("")], BytesIO(code))) == expect diff --git a/tests/utils.py b/tests/utils.py index 19aafc45..417bf9de 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -7,7 +7,7 @@ from dataclasses import dataclass, field, replace from pathlib import Path from textwrap import dedent -from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple +from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Union from fawltydeps.main import main from fawltydeps.packages import IdentityMapping, LocalPackageResolver, Package @@ -129,19 +129,21 @@ def run_fawltydeps_subprocess( def run_fawltydeps_function( *args: str, config_file: Path = Path("/dev/null"), - to_stdin: Optional[str] = None, + to_stdin: Optional[Union[str, bytes]] = None, basepath: Optional[Path] = None, ) -> Tuple[str, int]: """Run FawltyDeps with `main` function. Designed for unit tests. Ignores logging output and returns stdout and the exit code """ + if isinstance(to_stdin, str): + to_stdin = to_stdin.encode() output = io.StringIO() exit_code = main( cmdline_args=([str(basepath)] if basepath else []) + [f"--config-file={str(config_file)}"] + list(args), - stdin=io.StringIO(to_stdin), + stdin=io.BytesIO(to_stdin or b""), stdout=output, )