diff --git a/tests/test_torchfix.py b/tests/test_torchfix.py index 56fd05c..b899550 100644 --- a/tests/test_torchfix.py +++ b/tests/test_torchfix.py @@ -92,3 +92,32 @@ def test_errorcodes_distinct(): def test_parse_error_code_str(case, expected): assert process_error_code_str(case) == expected + + +def test_stderr_suppression(tmp_path): + data = "import torchvision.datasets as datasets\n" + data_path = tmp_path / "fixable.py" + data_path.write_text(data) + result = subprocess.run( + ["torchfix", "--select", "TOR203", "--fix", str(data_path)], + stderr=subprocess.PIPE, + text=True, + ) + assert ( + result.stderr + == "Finished checking 1 files.\nTransformed 1 files successfully.\n" + ) + + data = "import torchvision.datasets as datasets\n" + data_path = tmp_path / "fixable.py" + data_path.write_text(data) + result = subprocess.run( + ["torchfix", "--select", "TOR203", "--show-stderr", "--fix", str(data_path)], + stderr=subprocess.PIPE, + text=True, + ) + assert ( + result.stderr + == "Executing codemod...\nFailed to determine module name for {path}: '{path}' is not in the subpath of '' OR one path is relative and the other is absolute.\nFinished checking 1 files.\nTransformed 1 files successfully.\n" + ) +