From a52790d008139633723d128c6de9a5d71c367ba4 Mon Sep 17 00:00:00 2001 From: shiftinv Date: Fri, 28 Apr 2023 14:51:30 +0200 Subject: [PATCH] refactor: make codemod more extensible --- scripts/codemods/typed_events.py | 49 +++++++++++++++++++------------- 1 file changed, 30 insertions(+), 19 deletions(-) diff --git a/scripts/codemods/typed_events.py b/scripts/codemods/typed_events.py index 8ba0f93165..846675bc01 100644 --- a/scripts/codemods/typed_events.py +++ b/scripts/codemods/typed_events.py @@ -49,7 +49,9 @@ def leave_FunctionDef(self, _: cst.FunctionDef, node: cst.FunctionDef): # ignore return node - if node.name.value != "wait_for": + if node.name.value == "wait_for": + generator = self.generate_wait_for_overload + else: raise RuntimeError( f"unknown method '{node.name.value}' with @_overload_with_events decorator" ) @@ -60,16 +62,12 @@ def leave_FunctionDef(self, _: cst.FunctionDef, node: cst.FunctionDef): event_data = EVENT_DATA[event] if event_data.event_only: continue - new_overloads.append(self.generate_wait_for_overload(node, event, event_data)) + new_overloads.append(generator(node, event, event_data)) return cst.FlattenSentinel([*new_overloads, node]) - def generate_wait_for_overload( - self, func: cst.FunctionDef, event: Event, event_data: EventData - ) -> cst.FunctionDef: - args = event_data.args - - new_overload = func.with_changes( + def create_empty_overload(self, func: cst.FunctionDef) -> cst.FunctionDef: + return func.with_changes( body=cst.IndentedBlock([cst.SimpleStatementLine([cst.Expr(cst.Ellipsis())])]), decorators=[ cst.Decorator(cst.Name("overload")), @@ -78,14 +76,29 @@ def generate_wait_for_overload( leading_lines=(), ) - # set `event` annotation - new_annotation = cst.parse_expression( - # the lazy way of doing things + def create_literal(self, event: Event) -> cst.BaseExpression: + return cst.parse_expression( f'Literal[Event.{event.name}, "{event.value}"]', config=self.module.config_for_parsing, ) + + def create_args_list(self, event_data: EventData) -> cst.BaseExpression: + return cst.parse_expression( + f'[{",".join(event_data.args)}]', + config=self.module.config_for_parsing, + ) + + def generate_wait_for_overload( + self, func: cst.FunctionDef, event: Event, event_data: EventData + ) -> cst.FunctionDef: + args = event_data.args + + new_overload = self.create_empty_overload(func) + + # set `event` annotation new_overload = new_overload.with_deep_changes( - get_param(new_overload, "event"), annotation=cst.Annotation(new_annotation) + get_param(new_overload, "event"), + annotation=cst.Annotation(self.create_literal(event)), ) # set `check` annotation @@ -93,12 +106,9 @@ def generate_wait_for_overload( get_param(new_overload, "check"), m.Subscript(m.Name("Callable")) )[0] callable_params = m.findall(callable_annotation, m.Ellipsis())[0] - new_annotation = cst.parse_expression( - f'[{",".join(args)}]', - config=self.module.config_for_parsing, - ) new_overload = cast( - cst.FunctionDef, new_overload.deep_replace(callable_params, new_annotation) + cst.FunctionDef, + new_overload.deep_replace(callable_params, self.create_args_list(event_data)), ) # set return annotation @@ -109,7 +119,8 @@ def generate_wait_for_overload( else: new_annotation_str = f'Tuple[{",".join(args)}]' new_annotation = cst.parse_expression( - f"Coroutine[Any, Any, {new_annotation_str}]", config=self.module.config_for_parsing + f"Coroutine[Any, Any, {new_annotation_str}]", + config=self.module.config_for_parsing, ) new_overload = new_overload.with_changes(returns=cst.Annotation(new_annotation)) @@ -136,7 +147,7 @@ class EventData: EVENT_DATA: Dict[Event, EventData] = { Event.connect: EventData(()), Event.disconnect: EventData(()), - # TODO: figure out how to specify varargs for these two + # FIXME: figure out how to specify varargs for these two if we ever add overloads for @event Event.error: EventData((), event_only=True), Event.gateway_error: EventData((), event_only=True), Event.ready: EventData(()),