From e470e62395762e52f5bf72347f40657711de04ec Mon Sep 17 00:00:00 2001 From: Federico Caselli Date: Fri, 20 Sep 2024 17:14:04 +0200 Subject: [PATCH] feat(typing): annotate routing package (#2327) * typing: type app * typing: type websocket module * typing: type asgi.reader, asgi.structures, asgi.stream * typing: type most of media * typing: type multipart * typing: type response * style: fix spelling in multipart.py * style(tests): explain referencing the same property multiple times * style: fix linter errors * chore: revert behavioral change to cors middleware. * typing: type falcon.routing package * chore: do not build rapidjson on PyPy --------- Co-authored-by: Vytautas Liuolia --- falcon/routing/compiled.py | 90 +++++++++++++++++------------------- falcon/routing/converters.py | 65 +++++++++++++++++--------- falcon/routing/util.py | 19 ++++---- pyproject.toml | 2 - 4 files changed, 95 insertions(+), 81 deletions(-) diff --git a/falcon/routing/compiled.py b/falcon/routing/compiled.py index 443d0d4f3..6407484c6 100644 --- a/falcon/routing/compiled.py +++ b/falcon/routing/compiled.py @@ -240,7 +240,7 @@ def find_cmp_converter(node: CompiledRouterNode) -> Optional[Tuple[str, str]]: else: return None - def insert(nodes: List[CompiledRouterNode], path_index: int = 0): + def insert(nodes: List[CompiledRouterNode], path_index: int = 0) -> None: for node in nodes: segment = path[path_index] if node.matches(segment): @@ -351,12 +351,7 @@ def _require_coroutine_responders(self, method_map: MethodDict) -> None: # issue. if not iscoroutinefunction(responder) and is_python_func(responder): if _should_wrap_non_coroutines(): - - def let(responder=responder): - method_map[method] = wrap_sync_to_async(responder) - - let() - + method_map[method] = wrap_sync_to_async(responder) else: msg = ( 'The {} responder must be a non-blocking ' @@ -515,12 +510,13 @@ def _generate_ast( # noqa: C901 else: # NOTE(kgriffs): Simple nodes just capture the entire path - # segment as the value for the param. + # segment as the value for the param. They have a var_name defined + field_name = node.var_name + assert field_name is not None if node.var_converter_map: assert len(node.var_converter_map) == 1 - field_name = node.var_name __, converter_name, converter_argstr = node.var_converter_map[0] converter_class = self._converter_map[converter_name] @@ -547,7 +543,7 @@ def _generate_ast( # noqa: C901 parent.append_child(cx_converter) parent = cx_converter else: - params_stack.append(_CxSetParamFromPath(node.var_name, level)) + params_stack.append(_CxSetParamFromPath(field_name, level)) # NOTE(kgriffs): We don't allow multiple simple var nodes # to exist at the same level, e.g.: @@ -745,7 +741,7 @@ def __init__( method_map: Optional[MethodDict] = None, resource: Optional[object] = None, uri_template: Optional[str] = None, - ): + ) -> None: self.children: List[CompiledRouterNode] = [] self.raw_segment = raw_segment @@ -833,12 +829,12 @@ def __init__( if self.is_complex: assert self.is_var - def matches(self, segment: str): + def matches(self, segment: str) -> bool: """Return True if this node matches the supplied template segment.""" return segment == self.raw_segment - def conflicts_with(self, segment: str): + def conflicts_with(self, segment: str) -> bool: """Return True if this node conflicts with a given template segment.""" # NOTE(kgriffs): This method assumes that the caller has already @@ -900,11 +896,11 @@ class ConverterDict(UserDict): data: Dict[str, Type[converters.BaseConverter]] - def __setitem__(self, name, converter): + def __setitem__(self, name: str, converter: Type[converters.BaseConverter]) -> None: self._validate(name) UserDict.__setitem__(self, name, converter) - def _validate(self, name): + def _validate(self, name: str) -> None: if not _IDENTIFIER_PATTERN.match(name): raise ValueError( 'Invalid converter name. Names may not be blank, and may ' @@ -948,14 +944,14 @@ class CompiledRouterOptions: __slots__ = ('converters',) - def __init__(self): + def __init__(self) -> None: object.__setattr__( self, 'converters', ConverterDict((name, converter) for name, converter in converters.BUILTIN), ) - def __setattr__(self, name, value) -> None: + def __setattr__(self, name: str, value: Any) -> None: if name == 'converters': raise AttributeError('Cannot set "converters", please update it in place.') super().__setattr__(name, value) @@ -978,13 +974,13 @@ class _CxParent: def __init__(self) -> None: self._children: List[_CxElement] = [] - def append_child(self, construct: _CxElement): + def append_child(self, construct: _CxElement) -> None: self._children.append(construct) def src(self, indentation: int) -> str: return self._children_src(indentation + 1) - def _children_src(self, indentation): + def _children_src(self, indentation: int) -> str: src_lines = [child.src(indentation) for child in self._children] return '\n'.join(src_lines) @@ -997,12 +993,12 @@ def src(self, indentation: int) -> str: class _CxIfPathLength(_CxParent): - def __init__(self, comparison, length): + def __init__(self, comparison: str, length: int) -> None: super().__init__() self._comparison = comparison self._length = length - def src(self, indentation): + def src(self, indentation: int) -> str: template = '{0}if path_len {1} {2}:\n{3}' return template.format( _TAB_STR * indentation, @@ -1013,12 +1009,12 @@ def src(self, indentation): class _CxIfPathSegmentLiteral(_CxParent): - def __init__(self, segment_idx, literal): + def __init__(self, segment_idx: int, literal: str) -> None: super().__init__() self._segment_idx = segment_idx self._literal = literal - def src(self, indentation): + def src(self, indentation: int) -> str: template = "{0}if path[{1}] == '{2}':\n{3}" return template.format( _TAB_STR * indentation, @@ -1029,13 +1025,13 @@ def src(self, indentation): class _CxIfPathSegmentPattern(_CxParent): - def __init__(self, segment_idx, pattern_idx, pattern_text): + def __init__(self, segment_idx: int, pattern_idx: int, pattern_text: str) -> None: super().__init__() self._segment_idx = segment_idx self._pattern_idx = pattern_idx self._pattern_text = pattern_text - def src(self, indentation): + def src(self, indentation: int) -> str: lines = [ '{0}match = patterns[{1}].match(path[{2}]) # {3}'.format( _TAB_STR * indentation, @@ -1051,13 +1047,13 @@ def src(self, indentation): class _CxIfConverterField(_CxParent): - def __init__(self, unique_idx, converter_idx): + def __init__(self, unique_idx: int, converter_idx: int) -> None: super().__init__() self._converter_idx = converter_idx self._unique_idx = unique_idx self.field_variable_name = 'field_value_{0}'.format(unique_idx) - def src(self, indentation): + def src(self, indentation: int) -> str: lines = [ '{0}{1} = converters[{2}].convert(fragment)'.format( _TAB_STR * indentation, @@ -1074,10 +1070,10 @@ def src(self, indentation): class _CxSetFragmentFromField(_CxChild): - def __init__(self, field_name): + def __init__(self, field_name: str) -> None: self._field_name = field_name - def src(self, indentation): + def src(self, indentation: int) -> str: return "{0}fragment = groups.pop('{1}')".format( _TAB_STR * indentation, self._field_name, @@ -1085,10 +1081,10 @@ def src(self, indentation): class _CxSetFragmentFromPath(_CxChild): - def __init__(self, segment_idx): + def __init__(self, segment_idx: int) -> None: self._segment_idx = segment_idx - def src(self, indentation): + def src(self, indentation: int) -> str: return '{0}fragment = path[{1}]'.format( _TAB_STR * indentation, self._segment_idx, @@ -1096,10 +1092,10 @@ def src(self, indentation): class _CxSetFragmentFromRemainingPaths(_CxChild): - def __init__(self, segment_idx): + def __init__(self, segment_idx: int) -> None: self._segment_idx = segment_idx - def src(self, indentation): + def src(self, indentation: int) -> str: return '{0}fragment = path[{1}:]'.format( _TAB_STR * indentation, self._segment_idx, @@ -1107,51 +1103,51 @@ def src(self, indentation): class _CxVariableFromPatternMatch(_CxChild): - def __init__(self, unique_idx): + def __init__(self, unique_idx: int) -> None: self._unique_idx = unique_idx self.dict_variable_name = 'dict_match_{0}'.format(unique_idx) - def src(self, indentation): + def src(self, indentation: int) -> str: return '{0}{1} = match.groupdict()'.format( _TAB_STR * indentation, self.dict_variable_name ) class _CxVariableFromPatternMatchPrefetched(_CxChild): - def __init__(self, unique_idx): + def __init__(self, unique_idx: int) -> None: self._unique_idx = unique_idx self.dict_variable_name = 'dict_groups_{0}'.format(unique_idx) - def src(self, indentation): + def src(self, indentation: int) -> str: return '{0}{1} = groups'.format(_TAB_STR * indentation, self.dict_variable_name) class _CxPrefetchGroupsFromPatternMatch(_CxChild): - def src(self, indentation): + def src(self, indentation: int) -> str: return '{0}groups = match.groupdict()'.format(_TAB_STR * indentation) class _CxReturnNone(_CxChild): - def src(self, indentation): + def src(self, indentation: int) -> str: return '{0}return None'.format(_TAB_STR * indentation) class _CxReturnValue(_CxChild): - def __init__(self, value_idx): + def __init__(self, value_idx: int) -> None: self._value_idx = value_idx - def src(self, indentation): + def src(self, indentation: int) -> str: return '{0}return return_values[{1}]'.format( _TAB_STR * indentation, self._value_idx ) class _CxSetParamFromPath(_CxChild): - def __init__(self, param_name, segment_idx): + def __init__(self, param_name: str, segment_idx: int) -> None: self._param_name = param_name self._segment_idx = segment_idx - def src(self, indentation): + def src(self, indentation: int) -> str: return "{0}params['{1}'] = path[{2}]".format( _TAB_STR * indentation, self._param_name, @@ -1160,11 +1156,11 @@ def src(self, indentation): class _CxSetParamFromValue(_CxChild): - def __init__(self, param_name, field_value_name): + def __init__(self, param_name: str, field_value_name: str) -> None: self._param_name = param_name self._field_value_name = field_value_name - def src(self, indentation): + def src(self, indentation: int) -> str: return "{0}params['{1}'] = {2}".format( _TAB_STR * indentation, self._param_name, @@ -1173,10 +1169,10 @@ def src(self, indentation): class _CxSetParamsFromDict(_CxChild): - def __init__(self, dict_value_name): + def __init__(self, dict_value_name: str) -> None: self._dict_value_name = dict_value_name - def src(self, indentation): + def src(self, indentation: int) -> str: return '{0}params.update({1})'.format( _TAB_STR * indentation, self._dict_value_name, diff --git a/falcon/routing/converters.py b/falcon/routing/converters.py index 2d2bc7fa1..d50d6b85e 100644 --- a/falcon/routing/converters.py +++ b/falcon/routing/converters.py @@ -11,11 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations import abc from datetime import datetime from math import isfinite -from typing import Optional +from typing import Any, ClassVar, Iterable, Optional, overload, Union import uuid __all__ = ( @@ -34,7 +35,7 @@ class BaseConverter(metaclass=abc.ABCMeta): """Abstract base class for URI template field converters.""" - CONSUME_MULTIPLE_SEGMENTS = False + CONSUME_MULTIPLE_SEGMENTS: ClassVar[bool] = False """When set to ``True`` it indicates that this converter will consume multiple URL path segments. Currently a converter with ``CONSUME_MULTIPLE_SEGMENTS=True`` must be at the end of the URL template @@ -42,8 +43,8 @@ class BaseConverter(metaclass=abc.ABCMeta): segments. """ - @abc.abstractmethod # pragma: no cover - def convert(self, value): + @abc.abstractmethod + def convert(self, value: str) -> Any: """Convert a URI template field value to another format or type. Args: @@ -76,14 +77,19 @@ class IntConverter(BaseConverter): __slots__ = ('_num_digits', '_min', '_max') - def __init__(self, num_digits=None, min=None, max=None): + def __init__( + self, + num_digits: Optional[int] = None, + min: Optional[int] = None, + max: Optional[int] = None, + ) -> None: if num_digits is not None and num_digits < 1: raise ValueError('num_digits must be at least 1') self._num_digits = num_digits self._min = min self._max = max - def convert(self, value): + def convert(self, value: str) -> Optional[int]: if self._num_digits is not None and len(value) != self._num_digits: return None @@ -96,22 +102,35 @@ def convert(self, value): return None try: - value = int(value) + converted = int(value) except ValueError: return None - return self._validate_min_max_value(value) + return _validate_min_max_value(self, converted) - def _validate_min_max_value(self, value): - if self._min is not None and value < self._min: - return None - if self._max is not None and value > self._max: - return None - return value +@overload +def _validate_min_max_value(converter: IntConverter, value: int) -> Optional[int]: ... + + +@overload +def _validate_min_max_value( + converter: FloatConverter, value: float +) -> Optional[float]: ... + + +def _validate_min_max_value( + converter: Union[IntConverter, FloatConverter], value: Union[int, float] +) -> Optional[Union[int, float]]: + if converter._min is not None and value < converter._min: + return None + if converter._max is not None and value > converter._max: + return None + + return value -class FloatConverter(IntConverter): +class FloatConverter(BaseConverter): """Converts a field value to an float. Identifier: `float` @@ -124,19 +143,19 @@ class FloatConverter(IntConverter): nan, inf, and -inf in addition to finite numbers. """ - __slots__ = '_finite' + __slots__ = '_finite', '_min', '_max' def __init__( self, min: Optional[float] = None, max: Optional[float] = None, finite: bool = True, - ): + ) -> None: self._min = min self._max = max self._finite = finite if finite is not None else True - def convert(self, value: str): + def convert(self, value: str) -> Optional[float]: if value.strip() != value: return None @@ -149,7 +168,7 @@ def convert(self, value: str): except ValueError: return None - return self._validate_min_max_value(converted) + return _validate_min_max_value(self, converted) class DateTimeConverter(BaseConverter): @@ -165,10 +184,10 @@ class DateTimeConverter(BaseConverter): __slots__ = ('_format_string',) - def __init__(self, format_string='%Y-%m-%dT%H:%M:%SZ'): + def __init__(self, format_string: str = '%Y-%m-%dT%H:%M:%SZ') -> None: self._format_string = format_string - def convert(self, value): + def convert(self, value: str) -> Optional[datetime]: try: return strptime(value, self._format_string) except ValueError: @@ -185,7 +204,7 @@ class UUIDConverter(BaseConverter): Note, however, that hyphens and the URN prefix are optional. """ - def convert(self, value): + def convert(self, value: str) -> Optional[uuid.UUID]: try: return uuid.UUID(value) except ValueError: @@ -213,7 +232,7 @@ class PathConverter(BaseConverter): CONSUME_MULTIPLE_SEGMENTS = True - def convert(self, value): + def convert(self, value: Iterable[str]) -> str: return '/'.join(value) diff --git a/falcon/routing/util.py b/falcon/routing/util.py index 3d254acc1..b789b0829 100644 --- a/falcon/routing/util.py +++ b/falcon/routing/util.py @@ -17,22 +17,25 @@ from __future__ import annotations import re -from typing import Callable, Dict, Optional +from typing import Optional, Set, Tuple, TYPE_CHECKING from falcon import constants from falcon import responders from falcon.util.deprecation import deprecated +if TYPE_CHECKING: + from falcon.typing import MethodDict + class SuffixedMethodNotFoundError(Exception): - def __init__(self, message): + def __init__(self, message: str) -> None: super(SuffixedMethodNotFoundError, self).__init__(message) self.message = message # NOTE(kgriffs): Published method; take care to avoid breaking changes. @deprecated('This method will be removed in Falcon 4.0.') -def compile_uri_template(template): +def compile_uri_template(template: str) -> Tuple[Set[str], re.Pattern[str]]: """Compile the given URI template string into a pattern matcher. This function can be used to construct custom routing engines that @@ -102,9 +105,7 @@ def compile_uri_template(template): return fields, re.compile(pattern, re.IGNORECASE) -def map_http_methods( - resource: object, suffix: Optional[str] = None -) -> Dict[str, Callable]: +def map_http_methods(resource: object, suffix: Optional[str] = None) -> MethodDict: """Map HTTP methods (e.g., GET, POST) to methods of a resource object. Args: @@ -151,7 +152,7 @@ def map_http_methods( return method_map -def set_default_responders(method_map, asgi=False): +def set_default_responders(method_map: MethodDict, asgi: bool = False) -> None: """Map HTTP methods not explicitly defined on a resource to default responders. Args: @@ -169,11 +170,11 @@ def set_default_responders(method_map, asgi=False): if 'OPTIONS' not in method_map: # OPTIONS itself is intentionally excluded from the Allow header opt_responder = responders.create_default_options(allowed_methods, asgi=asgi) - method_map['OPTIONS'] = opt_responder + method_map['OPTIONS'] = opt_responder # type: ignore[assignment] allowed_methods.append('OPTIONS') na_responder = responders.create_method_not_allowed(allowed_methods, asgi=asgi) for method in constants.COMBINED_METHODS: if method not in method_map: - method_map[method] = na_responder + method_map[method] = na_responder # type: ignore[assignment] diff --git a/pyproject.toml b/pyproject.toml index 679607566..ee74ee31d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -117,8 +117,6 @@ exclude = ["examples", "tests"] [[tool.mypy.overrides]] module = [ "falcon.media.validators.*", - "falcon.routing.*", - "falcon.routing.converters", "falcon.testing.*", "falcon.vendor.*", ]