diff --git a/falcon/routing/compiled.py b/falcon/routing/compiled.py index b7d6c3244..cf65cdcc5 100644 --- a/falcon/routing/compiled.py +++ b/falcon/routing/compiled.py @@ -14,12 +14,26 @@ """Default routing engine.""" +from __future__ import annotations + from collections import UserDict from inspect import iscoroutinefunction import keyword import re from threading import Lock -from typing import TYPE_CHECKING +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Pattern, + Set, + Tuple, + Type, + TYPE_CHECKING, + Union, +) from falcon.routing import converters from falcon.routing.util import map_http_methods @@ -29,7 +43,10 @@ from falcon.util.sync import wrap_sync_to_async if TYPE_CHECKING: - from typing import Any # NOQA: F401 + from falcon import Request + + _CxElement = Union['_CxParent', '_CxChild'] + _MethodDict = Dict[str, Callable] _TAB_STR = ' ' * 4 _FIELD_PATTERN = re.compile( @@ -86,10 +103,10 @@ class CompiledRouter: '_compile_lock', ) - def __init__(self): - self._ast = None - self._converters = None - self._finder_src = None + def __init__(self) -> None: + self._ast: _CxParent = _CxParent() + self._converters: List[converters.BaseConverter] = [] + self._finder_src: str = '' self._options = CompiledRouterOptions() @@ -97,9 +114,9 @@ def __init__(self): # here to reduce lookup time. self._converter_map = self._options.converters.data - self._patterns = None - self._return_values = None - self._roots = [] + self._patterns: List[Pattern] = [] + self._return_values: List[CompiledRouterNode] = [] + self._roots: List[CompiledRouterNode] = [] # NOTE(caselit): set _find to the delayed compile method to ensure that # compile is called when the router is first used @@ -107,18 +124,18 @@ def __init__(self): self._compile_lock = Lock() @property - def options(self): + def options(self) -> CompiledRouterOptions: return self._options @property - def finder_src(self): + def finder_src(self) -> str: # NOTE(caselit): ensure that the router is actually compiled before # returning the finder source, since the current value may be out of # date self.find('/') return self._finder_src - def map_http_methods(self, resource, **kwargs): + def map_http_methods(self, resource: object, **kwargs: Any) -> _MethodDict: """Map HTTP methods (e.g., GET, POST) to methods of a resource object. This method is called from :meth:`~.add_route` and may be overridden to @@ -147,7 +164,9 @@ class can use suffixed responders to distinguish requests return map_http_methods(resource, suffix=kwargs.get('suffix', None)) - def add_route(self, uri_template, resource, **kwargs): # noqa: C901 + def add_route( # noqa: C901 + self, uri_template: str, resource: object, **kwargs: Any + ) -> None: """Add a route between a URI path template and a resource. This method may be overridden to customize how a route is added. @@ -186,7 +205,7 @@ class can use suffixed responders to distinguish requests # NOTE(kgriffs): falcon.asgi.App injects this private kwarg; it is # only intended to be used internally. - asgi = kwargs.get('_asgi', False) + asgi: bool = kwargs.get('_asgi', False) method_map = self.map_http_methods(resource, **kwargs) @@ -204,11 +223,11 @@ class can use suffixed responders to distinguish requests path = uri_template.lstrip('/').split('/') - used_names = set() + used_names: Set[str] = set() for segment in path: self._validate_template_segment(segment, used_names) - def find_cmp_converter(node): + def find_cmp_converter(node: CompiledRouterNode) -> Optional[Tuple[str, str]]: value = [ (field, converter) for field, converter, _ in node.var_converter_map @@ -221,7 +240,7 @@ def find_cmp_converter(node): else: return None - def insert(nodes, path_index=0): + def insert(nodes: List[CompiledRouterNode], path_index: int = 0): for node in nodes: segment = path[path_index] if node.matches(segment): @@ -286,7 +305,11 @@ def insert(nodes, path_index=0): else: self._find = self._compile_and_find - def find(self, uri, req=None): + # NOTE(caselit): keep Request as string otherwise sphinx complains that it resolves + # to multiple classes, since the symbol is imported only for type check. + def find( + self, uri: str, req: Optional['Request'] = None + ) -> Optional[Tuple[object, Optional[_MethodDict], Dict[str, Any], Optional[str]]]: """Search for a route that matches the given partial URI. Args: @@ -305,8 +328,8 @@ def find(self, uri, req=None): """ path = uri.lstrip('/').split('/') - params = {} - node = self._find( + params: Dict[str, Any] = {} + node: Optional[CompiledRouterNode] = self._find( path, self._return_values, self._patterns, self._converters, params ) @@ -319,7 +342,7 @@ def find(self, uri, req=None): # Private # ----------------------------------------------------------------- - def _require_coroutine_responders(self, method_map): + def _require_coroutine_responders(self, method_map: _MethodDict) -> None: for method, responder in method_map.items(): # NOTE(kgriffs): We don't simply wrap non-async functions # since they likely perform relatively long blocking @@ -343,7 +366,7 @@ def let(responder=responder): msg = msg.format(responder) raise TypeError(msg) - def _require_non_coroutine_responders(self, method_map): + def _require_non_coroutine_responders(self, method_map: _MethodDict) -> None: for method, responder in method_map.items(): # NOTE(kgriffs): We don't simply wrap non-async functions # since they likely perform relatively long blocking @@ -359,7 +382,7 @@ def _require_non_coroutine_responders(self, method_map): msg = msg.format(responder) raise TypeError(msg) - def _validate_template_segment(self, segment, used_names): + def _validate_template_segment(self, segment: str, used_names: Set[str]) -> None: """Validate a single path segment of a URI template. 1. Ensure field names are valid Python identifiers, since they @@ -414,14 +437,14 @@ def _validate_template_segment(self, segment, used_names): def _generate_ast( # noqa: C901 self, - nodes: list, - parent, - return_values: list, - patterns: list, - params_stack: list, - level=0, - fast_return=True, - ): + nodes: List[CompiledRouterNode], + parent: _CxParent, + return_values: List[CompiledRouterNode], + patterns: List[Pattern], + params_stack: List[_CxElement], + level: int = 0, + fast_return: bool = True, + ) -> None: """Generate a coarse AST for the router.""" # NOTE(caselit): setting of the parameters in the params dict is delayed until # a match has been found by adding them to the param_stack. This way superfluous @@ -457,8 +480,6 @@ def _generate_ast( # noqa: C901 fast_return = not found_var_nodes - construct = None # type: Any - setter = None # type: Any original_params_stack = params_stack.copy() for node in nodes: params_stack = original_params_stack.copy() @@ -473,11 +494,11 @@ def _generate_ast( # noqa: C901 pattern_idx = len(patterns) patterns.append(node.var_pattern) - construct = _CxIfPathSegmentPattern( + cx_segment = _CxIfPathSegmentPattern( level, pattern_idx, node.var_pattern.pattern ) - parent.append_child(construct) - parent = construct + parent.append_child(cx_segment) + parent = cx_segment if node.var_converter_map: parent.append_child(_CxPrefetchGroupsFromPatternMatch()) @@ -486,10 +507,11 @@ def _generate_ast( # noqa: C901 ) else: - construct = _CxVariableFromPatternMatch(len(params_stack) + 1) - setter = _CxSetParamsFromDict(construct.dict_variable_name) - params_stack.append(setter) - parent.append_child(construct) + cx_pattern = _CxVariableFromPatternMatch(len(params_stack) + 1) + params_stack.append( + _CxSetParamsFromDict(cx_pattern.dict_variable_name) + ) + parent.append_child(cx_pattern) else: # NOTE(kgriffs): Simple nodes just capture the entire path @@ -513,16 +535,17 @@ def _generate_ast( # noqa: C901 else: parent.append_child(_CxSetFragmentFromPath(level)) - construct = _CxIfConverterField( + cx_converter = _CxIfConverterField( len(params_stack) + 1, converter_idx ) - setter = _CxSetParamFromValue( - field_name, construct.field_variable_name + params_stack.append( + _CxSetParamFromValue( + field_name, cx_converter.field_variable_name + ) ) - params_stack.append(setter) - parent.append_child(construct) - parent = construct + parent.append_child(cx_converter) + parent = cx_converter else: params_stack.append(_CxSetParamFromPath(node.var_name, level)) @@ -542,9 +565,9 @@ def _generate_ast( # noqa: C901 else: # NOTE(kgriffs): Not a param, so must match exactly - construct = _CxIfPathSegmentLiteral(level, node.raw_segment) - parent.append_child(construct) - parent = construct + cx_literal = _CxIfPathSegmentLiteral(level, node.raw_segment) + parent.append_child(cx_literal) + parent = cx_literal if node.resource is not None: # NOTE(kgriffs): This is a valid route, so we will want to @@ -576,11 +599,11 @@ def _generate_ast( # noqa: C901 # NOTE(kgriffs): Make sure that we have consumed all of # the segments for the requested route; otherwise we could # mistakenly match "/foo/23/bar" against "/foo/{id}". - construct = _CxIfPathLength('==', level + 1) + cx_path_len = _CxIfPathLength('==', level + 1) for params in params_stack: - construct.append_child(params) - construct.append_child(_CxReturnValue(resource_idx)) - parent.append_child(construct) + cx_path_len.append_child(params) + cx_path_len.append_child(_CxReturnValue(resource_idx)) + parent.append_child(cx_path_len) if fast_return: parent.append_child(_CxReturnNone()) @@ -591,10 +614,11 @@ def _generate_ast( # noqa: C901 parent.append_child(_CxReturnNone()) def _generate_conversion_ast( - self, parent, node: 'CompiledRouterNode', params_stack: list - ): - construct = None # type: Any - setter = None # type: Any + self, + parent: _CxParent, + node: CompiledRouterNode, + params_stack: List[_CxElement], + ) -> _CxParent: # NOTE(kgriffs): Unroll the converter loop into # a series of nested "if" constructs. for field_name, converter_name, converter_argstr in node.var_converter_map: @@ -609,24 +633,28 @@ def _generate_conversion_ast( parent.append_child(_CxSetFragmentFromField(field_name)) - construct = _CxIfConverterField(len(params_stack) + 1, converter_idx) - setter = _CxSetParamFromValue(field_name, construct.field_variable_name) - params_stack.append(setter) + cx_converter = _CxIfConverterField(len(params_stack) + 1, converter_idx) + params_stack.append( + _CxSetParamFromValue(field_name, cx_converter.field_variable_name) + ) - parent.append_child(construct) - parent = construct + parent.append_child(cx_converter) + parent = cx_converter # NOTE(kgriffs): Add remaining fields that were not # converted, if any. if node.num_fields > len(node.var_converter_map): - construct = _CxVariableFromPatternMatchPrefetched(len(params_stack) + 1) - setter = _CxSetParamsFromDict(construct.dict_variable_name) - params_stack.append(setter) - parent.append_child(construct) + cx_pattern_match = _CxVariableFromPatternMatchPrefetched( + len(params_stack) + 1 + ) + params_stack.append( + _CxSetParamsFromDict(cx_pattern_match.dict_variable_name) + ) + parent.append_child(cx_pattern_match) return parent - def _compile(self): + def _compile(self) -> Callable: """Generate Python code for the entire routing tree. The generated code is compiled and the resulting Python method @@ -649,19 +677,19 @@ def _compile(self): src_lines.append(self._ast.src(0)) - src_lines.append( - # PERF(kgriffs): Explicit return of None is faster than implicit - _TAB_STR + 'return None' - ) + # PERF(kgriffs): Explicit return of None is faster than implicit + src_lines.append(_TAB_STR + 'return None') self._finder_src = '\n'.join(src_lines) - scope = {} + scope: _MethodDict = {} exec(compile(self._finder_src, '', 'exec'), scope) return scope['find'] - def _instantiate_converter(self, klass, argstr=None): + def _instantiate_converter( + self, klass: type, argstr: Optional[str] = None + ) -> converters.BaseConverter: if argstr is None: return klass() @@ -669,7 +697,14 @@ def _instantiate_converter(self, klass, argstr=None): src = '{0}({1})'.format(klass.__name__, argstr) return eval(src, {klass.__name__: klass}) - def _compile_and_find(self, path, _return_values, _patterns, _converters, params): + def _compile_and_find( + self, + path: List[str], + _return_values: Any, + _patterns: Any, + _converters: Any, + params: Any, + ) -> Any: """Compile the router, set the `_find` attribute and return its result. This method is set to the `_find` attribute to delay the compilation of the @@ -704,8 +739,14 @@ class UnacceptableRouteError(ValueError): class CompiledRouterNode: """Represents a single URI segment in a URI.""" - def __init__(self, raw_segment, method_map=None, resource=None, uri_template=None): - self.children = [] + def __init__( + self, + raw_segment: str, + method_map: Optional[_MethodDict] = None, + resource: Optional[object] = None, + uri_template: Optional[str] = None, + ): + self.children: List[CompiledRouterNode] = [] self.raw_segment = raw_segment self.method_map = method_map @@ -718,9 +759,9 @@ def __init__(self, raw_segment, method_map=None, resource=None, uri_template=Non # TODO(kgriffs): Rename these since the docs talk about "fields" # or "field expressions", not "vars" or "variables". - self.var_name = None - self.var_pattern = None - self.var_converter_map = [] + self.var_name: Optional[str] = None + self.var_pattern: Optional[Pattern] = None + self.var_converter_map: List[Tuple[str, str, str]] = [] # NOTE(kgriffs): CompiledRouter.add_route validates field names, # so here we can just assume they are OK and use the simple @@ -792,12 +833,12 @@ def __init__(self, raw_segment, method_map=None, resource=None, uri_template=Non if self.is_complex: assert self.is_var - def matches(self, segment): + def matches(self, segment: str): """Return True if this node matches the supplied template segment.""" return segment == self.raw_segment - def conflicts_with(self, segment): + def conflicts_with(self, segment: str): """Return True if this node conflicts with a given template segment.""" # NOTE(kgriffs): This method assumes that the caller has already @@ -857,6 +898,8 @@ def conflicts_with(self, segment): class ConverterDict(UserDict): """A dict-like class for storing field converters.""" + data: Dict[str, Type[converters.BaseConverter]] + def __setitem__(self, name, converter): self._validate(name) UserDict.__setitem__(self, name, converter) @@ -906,6 +949,8 @@ class CompiledRouterOptions: __slots__ = ('converters',) + converters: ConverterDict + def __init__(self): object.__setattr__( self, @@ -934,12 +979,12 @@ def __setattr__(self, name, value) -> None: class _CxParent: def __init__(self): - self._children = [] + self._children: List[_CxElement] = [] - def append_child(self, construct): + def append_child(self, construct: _CxElement): self._children.append(construct) - def src(self, indentation): + def src(self, indentation: int) -> str: return self._children_src(indentation + 1) def _children_src(self, indentation): @@ -948,6 +993,12 @@ def _children_src(self, indentation): return '\n'.join(src_lines) +class _CxChild: + # This a base element only to aid pep484 + def src(self, indentation: int) -> str: + raise NotImplementedError + + class _CxIfPathLength(_CxParent): def __init__(self, comparison, length): super().__init__() @@ -1025,7 +1076,7 @@ def src(self, indentation): return '\n'.join(lines) -class _CxSetFragmentFromField: +class _CxSetFragmentFromField(_CxChild): def __init__(self, field_name): self._field_name = field_name @@ -1036,7 +1087,7 @@ def src(self, indentation): ) -class _CxSetFragmentFromPath: +class _CxSetFragmentFromPath(_CxChild): def __init__(self, segment_idx): self._segment_idx = segment_idx @@ -1047,7 +1098,7 @@ def src(self, indentation): ) -class _CxSetFragmentFromRemainingPaths: +class _CxSetFragmentFromRemainingPaths(_CxChild): def __init__(self, segment_idx): self._segment_idx = segment_idx @@ -1058,7 +1109,7 @@ def src(self, indentation): ) -class _CxVariableFromPatternMatch: +class _CxVariableFromPatternMatch(_CxChild): def __init__(self, unique_idx): self._unique_idx = unique_idx self.dict_variable_name = 'dict_match_{0}'.format(unique_idx) @@ -1069,7 +1120,7 @@ def src(self, indentation): ) -class _CxVariableFromPatternMatchPrefetched: +class _CxVariableFromPatternMatchPrefetched(_CxChild): def __init__(self, unique_idx): self._unique_idx = unique_idx self.dict_variable_name = 'dict_groups_{0}'.format(unique_idx) @@ -1078,17 +1129,17 @@ def src(self, indentation): return '{0}{1} = groups'.format(_TAB_STR * indentation, self.dict_variable_name) -class _CxPrefetchGroupsFromPatternMatch: +class _CxPrefetchGroupsFromPatternMatch(_CxChild): def src(self, indentation): return '{0}groups = match.groupdict()'.format(_TAB_STR * indentation) -class _CxReturnNone: +class _CxReturnNone(_CxChild): def src(self, indentation): return '{0}return None'.format(_TAB_STR * indentation) -class _CxReturnValue: +class _CxReturnValue(_CxChild): def __init__(self, value_idx): self._value_idx = value_idx @@ -1098,7 +1149,7 @@ def src(self, indentation): ) -class _CxSetParamFromPath: +class _CxSetParamFromPath(_CxChild): def __init__(self, param_name, segment_idx): self._param_name = param_name self._segment_idx = segment_idx @@ -1111,7 +1162,7 @@ def src(self, indentation): ) -class _CxSetParamFromValue: +class _CxSetParamFromValue(_CxChild): def __init__(self, param_name, field_value_name): self._param_name = param_name self._field_value_name = field_value_name @@ -1124,7 +1175,7 @@ def src(self, indentation): ) -class _CxSetParamsFromDict: +class _CxSetParamsFromDict(_CxChild): def __init__(self, dict_value_name): self._dict_value_name = dict_value_name diff --git a/falcon/routing/converters.py b/falcon/routing/converters.py index 8fd28fa32..2d2bc7fa1 100644 --- a/falcon/routing/converters.py +++ b/falcon/routing/converters.py @@ -58,7 +58,7 @@ def convert(self, value): """ -def _consumes_multiple_segments(converter): +def _consumes_multiple_segments(converter: object) -> bool: return getattr(converter, 'CONSUME_MULTIPLE_SEGMENTS', False) diff --git a/falcon/routing/util.py b/falcon/routing/util.py index b4f1fd8ca..684699a3e 100644 --- a/falcon/routing/util.py +++ b/falcon/routing/util.py @@ -14,7 +14,10 @@ """Routing utilities.""" +from __future__ import annotations + import re +from typing import Callable, Dict, Optional from falcon import constants from falcon import responders @@ -99,7 +102,9 @@ def compile_uri_template(template): return fields, re.compile(pattern, re.IGNORECASE) -def map_http_methods(resource, suffix=None): +def map_http_methods( + resource: object, suffix: Optional[str] = None +) -> Dict[str, Callable]: """Map HTTP methods (e.g., GET, POST) to methods of a resource object. Args: diff --git a/tests/test_compiled_router.py b/tests/test_compiled_router.py index 3ef622fa8..6e10d1b3d 100644 --- a/tests/test_compiled_router.py +++ b/tests/test_compiled_router.py @@ -5,6 +5,7 @@ import pytest +from falcon.routing import compiled from falcon.routing import CompiledRouter from falcon.routing import CompiledRouterOptions @@ -139,3 +140,8 @@ def convert(self, v): assert res is not None assert res[2] == {'bar': 'bar'} assert router.find('/foo/bar/bar') is None + + +def test_base_classes(): + with pytest.raises(NotImplementedError): + compiled._CxChild().src(42)