From 066477f71f386824d49a0fcf716788bf3aee0416 Mon Sep 17 00:00:00 2001 From: Almar Klein Date: Mon, 4 Nov 2024 13:18:41 +0100 Subject: [PATCH] Fix codegen single-line logic for return types (#634) --- codegen/tests/test_codegen_utils.py | 28 ++++++ codegen/utils.py | 133 +++++++++++++++++++--------- 2 files changed, 120 insertions(+), 41 deletions(-) diff --git a/codegen/tests/test_codegen_utils.py b/codegen/tests/test_codegen_utils.py index 288f2828..8e1de9c0 100644 --- a/codegen/tests/test_codegen_utils.py +++ b/codegen/tests/test_codegen_utils.py @@ -144,6 +144,34 @@ def foo(a1, a2, a3): # hi ha ho assert code3 == code2 +def test_format_code_return_type(): + code1 = """ + def foo() -> None: + pass + def foo( + a1, + a2, + a3, + ) -> None: + pass + """ + + code2 = """ + def foo() -> None: + pass + def foo(a1, a2, a3) -> None: + pass + """ + + code1 = dedent(code1).strip() + code2 = dedent(code2).strip() + + code3 = format_code(code1, True) + code3 = code3.replace("\n\n", "\n").replace("\n\n", "\n").strip() + + assert code3 == code2 + + def test_patcher(): code = """ class Foo1: diff --git a/codegen/utils.py b/codegen/utils.py index 1d1df8ff..5b3b2a7a 100644 --- a/codegen/utils.py +++ b/codegen/utils.py @@ -144,51 +144,102 @@ def format_code(src, singleline=False): # Make defs single-line. You'd think that setting the line length # to a very high number would do the trick, but it does not. if singleline: - lines1 = result.splitlines() - lines2 = [] - in_sig = False - comment = "" - for line in lines1: - if in_sig: - # Handle comment - line, _, c = line.partition("#") - line = line.rstrip() - c = c.strip() - if c: - comment += " " + c.strip() - # Detect end - if line.endswith("):"): - in_sig = False - # Compose line - current_line = lines2[-1] - if not current_line.endswith("("): - current_line += " " - current_line += line.lstrip() - # Finalize - if not in_sig: - # Remove trailing spaces and commas - current_line = current_line.replace(" ):", "):") - current_line = current_line.replace(",):", "):") - # Add comment - if comment: - current_line += " #" + comment - comment = "" - lines2[-1] = current_line - else: - lines2.append(line) - line_nc = line.split("#")[0].strip() - if ( - line_nc.startswith(("def ", "async def", "class ")) - and "(" in line_nc - ): - if not line_nc.endswith("):"): - in_sig = True - lines2.append("") - result = "\n".join(lines2) + result = _make_sigs_singline(result) return result +def _make_sigs_singline(code): + lines1 = code.splitlines() + lines2 = [] + + sig_state = 0 + sig_brace_depth = 0 + sig_line = "" + sig_comment = "" + + for line in lines1: + # Check to enter in signature-retrieval mode + if not sig_state: + if line.lstrip().startswith(("def ", "async def", "class ")): + sig_state = 1 + sig_line = "" + sig_comment = "" + sig_brace_depth = 0 + if line.lstrip().startswith("class ") and "(" not in line: + sig_state = 3 # search for closing colon directly + else: + lines2.append(line) + continue + + # If we get here, we're in a signature + + # Handle comment + line, _, c = line.partition("#") + line = line.rstrip() + c = c.strip() + if c: + sig_comment += " " + c.strip() + + if sig_state == 1: + # Find the first opening brace + i = line.find("(") + if i >= 0: + i += 1 + sig_brace_depth = 1 + sig_line += line[:i] + line = line[i:] + sig_state = 2 + + line = line.lstrip() + if sig_line.endswith(","): + line = " " + line + + if sig_state == 2: + # Resolve braces until we find the closing brace + while True: + i1 = line.find("(") + i2 = line.find(")") + if i1 >= 0 and i1 < i2: + i = i1 + 1 + sig_brace_depth += 1 + sig_line += line[:i] + line = line[i:] + elif i2 >= 0: + i = i2 + 1 + sig_brace_depth -= 1 + sig_line += line[:i] + line = line[i:] + else: + break + if sig_brace_depth == 0: + sig_state = 3 + break + + if sig_state == 3: + # Find the closing colon + i = line.find(":") + if i >= 0: + i += 1 + sig_line += line[:i] + line = line[i:] + # Finish the signature line + sig_line = sig_line.replace(", )", ")") + if sig_comment: + sig_line += " #" + sig_comment + lines2.append(sig_line) + if line.strip(): + lines2.append(" " * 12 + line.strip()) + # End the search + sig_state = 0 + line = "" + + sig_line += line + + lines2.append("") + return "\n".join(lines2) + + class Patcher: """Class to help patch a Python module. Supports iterating (over lines, classes, properties, methods), and applying diffs (replace,