Skip to content

Commit

Permalink
refactor: make codemod more extensible
Browse files Browse the repository at this point in the history
  • Loading branch information
shiftinv committed Apr 28, 2023
1 parent 890cdfc commit a52790d
Showing 1 changed file with 30 additions and 19 deletions.
49 changes: 30 additions & 19 deletions scripts/codemods/typed_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@ def leave_FunctionDef(self, _: cst.FunctionDef, node: cst.FunctionDef):
# ignore
return node

if node.name.value != "wait_for":
if node.name.value == "wait_for":
generator = self.generate_wait_for_overload
else:
raise RuntimeError(
f"unknown method '{node.name.value}' with @_overload_with_events decorator"
)
Expand All @@ -60,16 +62,12 @@ def leave_FunctionDef(self, _: cst.FunctionDef, node: cst.FunctionDef):
event_data = EVENT_DATA[event]
if event_data.event_only:
continue
new_overloads.append(self.generate_wait_for_overload(node, event, event_data))
new_overloads.append(generator(node, event, event_data))

return cst.FlattenSentinel([*new_overloads, node])

def generate_wait_for_overload(
self, func: cst.FunctionDef, event: Event, event_data: EventData
) -> cst.FunctionDef:
args = event_data.args

new_overload = func.with_changes(
def create_empty_overload(self, func: cst.FunctionDef) -> cst.FunctionDef:
return func.with_changes(
body=cst.IndentedBlock([cst.SimpleStatementLine([cst.Expr(cst.Ellipsis())])]),
decorators=[
cst.Decorator(cst.Name("overload")),
Expand All @@ -78,27 +76,39 @@ def generate_wait_for_overload(
leading_lines=(),
)

# set `event` annotation
new_annotation = cst.parse_expression(
# the lazy way of doing things
def create_literal(self, event: Event) -> cst.BaseExpression:
return cst.parse_expression(
f'Literal[Event.{event.name}, "{event.value}"]',
config=self.module.config_for_parsing,
)

def create_args_list(self, event_data: EventData) -> cst.BaseExpression:
return cst.parse_expression(
f'[{",".join(event_data.args)}]',
config=self.module.config_for_parsing,
)

def generate_wait_for_overload(
self, func: cst.FunctionDef, event: Event, event_data: EventData
) -> cst.FunctionDef:
args = event_data.args

new_overload = self.create_empty_overload(func)

# set `event` annotation
new_overload = new_overload.with_deep_changes(
get_param(new_overload, "event"), annotation=cst.Annotation(new_annotation)
get_param(new_overload, "event"),
annotation=cst.Annotation(self.create_literal(event)),
)

# set `check` annotation
callable_annotation = m.findall(
get_param(new_overload, "check"), m.Subscript(m.Name("Callable"))
)[0]
callable_params = m.findall(callable_annotation, m.Ellipsis())[0]
new_annotation = cst.parse_expression(
f'[{",".join(args)}]',
config=self.module.config_for_parsing,
)
new_overload = cast(
cst.FunctionDef, new_overload.deep_replace(callable_params, new_annotation)
cst.FunctionDef,
new_overload.deep_replace(callable_params, self.create_args_list(event_data)),
)

# set return annotation
Expand All @@ -109,7 +119,8 @@ def generate_wait_for_overload(
else:
new_annotation_str = f'Tuple[{",".join(args)}]'
new_annotation = cst.parse_expression(
f"Coroutine[Any, Any, {new_annotation_str}]", config=self.module.config_for_parsing
f"Coroutine[Any, Any, {new_annotation_str}]",
config=self.module.config_for_parsing,
)
new_overload = new_overload.with_changes(returns=cst.Annotation(new_annotation))

Expand All @@ -136,7 +147,7 @@ class EventData:
EVENT_DATA: Dict[Event, EventData] = {
Event.connect: EventData(()),
Event.disconnect: EventData(()),
# TODO: figure out how to specify varargs for these two
# FIXME: figure out how to specify varargs for these two if we ever add overloads for @event
Event.error: EventData((), event_only=True),
Event.gateway_error: EventData((), event_only=True),
Event.ready: EventData(()),
Expand Down

0 comments on commit a52790d

Please sign in to comment.