Skip to content

Commit

Permalink
refactor: move thingy to a separate method
Browse files Browse the repository at this point in the history
  • Loading branch information
shiftinv committed Apr 28, 2023
1 parent caf558d commit 890cdfc
Showing 1 changed file with 56 additions and 47 deletions.
103 changes: 56 additions & 47 deletions scripts/codemods/typed_events.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# SPDX-License-Identifier: MIT

from __future__ import annotations

import types
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, cast
Expand Down Expand Up @@ -58,60 +60,67 @@ def leave_FunctionDef(self, _: cst.FunctionDef, node: cst.FunctionDef):
event_data = EVENT_DATA[event]
if event_data.event_only:
continue
args = event_data.args
new_overloads.append(self.generate_wait_for_overload(node, event, event_data))

new_overload = node.with_changes(
body=cst.IndentedBlock([cst.SimpleStatementLine([cst.Expr(cst.Ellipsis())])]),
decorators=[
cst.Decorator(cst.Name("overload")),
cst.Decorator(cst.Name("_generated")),
],
leading_lines=(),
)
return cst.FlattenSentinel([*new_overloads, node])

# set `event` annotation
new_annotation = cst.parse_expression(
# the lazy way of doing things
f'Literal[Event.{event.name}, "{event.value}"]',
config=self.module.config_for_parsing,
)
new_overload = new_overload.with_deep_changes(
get_param(new_overload, "event"), annotation=cst.Annotation(new_annotation)
)
def generate_wait_for_overload(
self, func: cst.FunctionDef, event: Event, event_data: EventData
) -> cst.FunctionDef:
args = event_data.args

# set `check` annotation
callable_annotation = m.findall(
get_param(new_overload, "check"), m.Subscript(m.Name("Callable"))
)[0]
callable_params = m.findall(callable_annotation, m.Ellipsis())[0]
new_annotation = cst.parse_expression(
f'[{",".join(args)}]',
config=self.module.config_for_parsing,
)
new_overload = new_overload.deep_replace(callable_params, new_annotation)
new_overload = func.with_changes(
body=cst.IndentedBlock([cst.SimpleStatementLine([cst.Expr(cst.Ellipsis())])]),
decorators=[
cst.Decorator(cst.Name("overload")),
cst.Decorator(cst.Name("_generated")),
],
leading_lines=(),
)

# set return annotation
if len(args) == 0:
new_annotation_str = "None"
elif len(args) == 1:
new_annotation_str = args[0]
else:
new_annotation_str = f'Tuple[{",".join(args)}]'
new_annotation = cst.parse_expression(
f"Coroutine[Any, Any, {new_annotation_str}]", config=self.module.config_for_parsing
)
new_overload = new_overload.with_changes(returns=cst.Annotation(new_annotation))
# set `event` annotation
new_annotation = cst.parse_expression(
# the lazy way of doing things
f'Literal[Event.{event.name}, "{event.value}"]',
config=self.module.config_for_parsing,
)
new_overload = new_overload.with_deep_changes(
get_param(new_overload, "event"), annotation=cst.Annotation(new_annotation)
)

# set `self` annotation as a workaround for overloads in subclasses
if event_data.bot:
new_overload = new_overload.with_deep_changes(
get_param(new_overload, "self"),
annotation=cst.Annotation(cst.Name("AnyBot")),
)
# set `check` annotation
callable_annotation = m.findall(
get_param(new_overload, "check"), m.Subscript(m.Name("Callable"))
)[0]
callable_params = m.findall(callable_annotation, m.Ellipsis())[0]
new_annotation = cst.parse_expression(
f'[{",".join(args)}]',
config=self.module.config_for_parsing,
)
new_overload = cast(
cst.FunctionDef, new_overload.deep_replace(callable_params, new_annotation)
)

new_overloads.append(new_overload)
# set return annotation
if len(args) == 0:
new_annotation_str = "None"
elif len(args) == 1:
new_annotation_str = args[0]
else:
new_annotation_str = f'Tuple[{",".join(args)}]'
new_annotation = cst.parse_expression(
f"Coroutine[Any, Any, {new_annotation_str}]", config=self.module.config_for_parsing
)
new_overload = new_overload.with_changes(returns=cst.Annotation(new_annotation))

return cst.FlattenSentinel([*new_overloads, node])
# set `self` annotation as a workaround for overloads in subclasses
if event_data.bot:
new_overload = new_overload.with_deep_changes(
get_param(new_overload, "self"),
annotation=cst.Annotation(cst.Name("AnyBot")),
)

return new_overload


@dataclass
Expand Down

0 comments on commit 890cdfc

Please sign in to comment.