From 890cdfcb71902ad12094ced1f05dc669cc33925a Mon Sep 17 00:00:00 2001 From: shiftinv Date: Thu, 27 Apr 2023 13:30:58 +0200 Subject: [PATCH] refactor: move thingy to a separate method --- scripts/codemods/typed_events.py | 103 +++++++++++++++++-------------- 1 file changed, 56 insertions(+), 47 deletions(-) diff --git a/scripts/codemods/typed_events.py b/scripts/codemods/typed_events.py index de57083041..8ba0f93165 100644 --- a/scripts/codemods/typed_events.py +++ b/scripts/codemods/typed_events.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: MIT +from __future__ import annotations + import types from dataclasses import dataclass from typing import Dict, List, Optional, Tuple, cast @@ -58,60 +60,67 @@ def leave_FunctionDef(self, _: cst.FunctionDef, node: cst.FunctionDef): event_data = EVENT_DATA[event] if event_data.event_only: continue - args = event_data.args + new_overloads.append(self.generate_wait_for_overload(node, event, event_data)) - new_overload = node.with_changes( - body=cst.IndentedBlock([cst.SimpleStatementLine([cst.Expr(cst.Ellipsis())])]), - decorators=[ - cst.Decorator(cst.Name("overload")), - cst.Decorator(cst.Name("_generated")), - ], - leading_lines=(), - ) + return cst.FlattenSentinel([*new_overloads, node]) - # set `event` annotation - new_annotation = cst.parse_expression( - # the lazy way of doing things - f'Literal[Event.{event.name}, "{event.value}"]', - config=self.module.config_for_parsing, - ) - new_overload = new_overload.with_deep_changes( - get_param(new_overload, "event"), annotation=cst.Annotation(new_annotation) - ) + def generate_wait_for_overload( + self, func: cst.FunctionDef, event: Event, event_data: EventData + ) -> cst.FunctionDef: + args = event_data.args - # set `check` annotation - callable_annotation = m.findall( - get_param(new_overload, "check"), m.Subscript(m.Name("Callable")) - )[0] - callable_params = m.findall(callable_annotation, m.Ellipsis())[0] - new_annotation = cst.parse_expression( - f'[{",".join(args)}]', - config=self.module.config_for_parsing, - ) - new_overload = new_overload.deep_replace(callable_params, new_annotation) + new_overload = func.with_changes( + body=cst.IndentedBlock([cst.SimpleStatementLine([cst.Expr(cst.Ellipsis())])]), + decorators=[ + cst.Decorator(cst.Name("overload")), + cst.Decorator(cst.Name("_generated")), + ], + leading_lines=(), + ) - # set return annotation - if len(args) == 0: - new_annotation_str = "None" - elif len(args) == 1: - new_annotation_str = args[0] - else: - new_annotation_str = f'Tuple[{",".join(args)}]' - new_annotation = cst.parse_expression( - f"Coroutine[Any, Any, {new_annotation_str}]", config=self.module.config_for_parsing - ) - new_overload = new_overload.with_changes(returns=cst.Annotation(new_annotation)) + # set `event` annotation + new_annotation = cst.parse_expression( + # the lazy way of doing things + f'Literal[Event.{event.name}, "{event.value}"]', + config=self.module.config_for_parsing, + ) + new_overload = new_overload.with_deep_changes( + get_param(new_overload, "event"), annotation=cst.Annotation(new_annotation) + ) - # set `self` annotation as a workaround for overloads in subclasses - if event_data.bot: - new_overload = new_overload.with_deep_changes( - get_param(new_overload, "self"), - annotation=cst.Annotation(cst.Name("AnyBot")), - ) + # set `check` annotation + callable_annotation = m.findall( + get_param(new_overload, "check"), m.Subscript(m.Name("Callable")) + )[0] + callable_params = m.findall(callable_annotation, m.Ellipsis())[0] + new_annotation = cst.parse_expression( + f'[{",".join(args)}]', + config=self.module.config_for_parsing, + ) + new_overload = cast( + cst.FunctionDef, new_overload.deep_replace(callable_params, new_annotation) + ) - new_overloads.append(new_overload) + # set return annotation + if len(args) == 0: + new_annotation_str = "None" + elif len(args) == 1: + new_annotation_str = args[0] + else: + new_annotation_str = f'Tuple[{",".join(args)}]' + new_annotation = cst.parse_expression( + f"Coroutine[Any, Any, {new_annotation_str}]", config=self.module.config_for_parsing + ) + new_overload = new_overload.with_changes(returns=cst.Annotation(new_annotation)) - return cst.FlattenSentinel([*new_overloads, node]) + # set `self` annotation as a workaround for overloads in subclasses + if event_data.bot: + new_overload = new_overload.with_deep_changes( + get_param(new_overload, "self"), + annotation=cst.Annotation(cst.Name("AnyBot")), + ) + + return new_overload @dataclass