Skip to content

Commit

Permalink
typing cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
erikbern committed Dec 18, 2023
1 parent 25e0a30 commit 4d2f7cf
Showing 1 changed file with 13 additions and 12 deletions.
25 changes: 13 additions & 12 deletions modal/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
import os
import pickle
import time
import typing
import warnings
from contextvars import ContextVar
from dataclasses import dataclass
from datetime import date
from typing import (
TYPE_CHECKING,
Any,
AsyncGenerator,
AsyncIterable,
Expand All @@ -23,6 +23,7 @@
List,
Optional,
Set,
Type,
Union,
)

Expand Down Expand Up @@ -79,7 +80,7 @@
ATTEMPT_TIMEOUT_GRACE_PERIOD = 5 # seconds


if typing.TYPE_CHECKING:
if TYPE_CHECKING:
import modal.stub


Expand Down Expand Up @@ -1454,7 +1455,7 @@ def __del__(self):
" Did you forget a @stub.function or @stub.cls decorator?"
)

def stack(self, flags) -> "_PartialFunction":
def add_flags(self, flags) -> "_PartialFunction":
# Helper method used internally when stacking decorators
return _PartialFunction(
raw_f=self.raw_f,
Expand Down Expand Up @@ -1709,7 +1710,7 @@ def _build(_warn_parentheses_missing=None) -> Callable[[Union[Callable[[], Any],

def wrapper(f: Union[Callable[[], Any], _PartialFunction]) -> _PartialFunction:
if isinstance(f, _PartialFunction):
return f.stack(_PartialFunctionFlags.BUILD)
return f.add_flags(_PartialFunctionFlags.BUILD)
else:
return _PartialFunction(f, _PartialFunctionFlags.BUILD)

Expand All @@ -1721,22 +1722,22 @@ def _enter(_warn_parentheses_missing=None) -> Callable[[Union[Callable[[], Any],

def wrapper(f: Union[Callable[[], Any], _PartialFunction]) -> _PartialFunction:
if isinstance(f, _PartialFunction):
return f.stack(_PartialFunctionFlags.ENTER)
return f.add_flags(_PartialFunctionFlags.ENTER)
else:
return _PartialFunction(f, _PartialFunctionFlags.ENTER)


# TODO(erikbern): annotate this with the right argument types
# TODO(erikbern): last argument should be Optional[TracebackType]
ExitHandlerType = Callable[[Optional[Type[BaseException]], Optional[BaseException], Any], None]


@typechecked
def _exit(_warn_parentheses_missing=None) -> Callable[[Union[Callable[..., Any], _PartialFunction]], _PartialFunction]:
def _exit(_warn_parentheses_missing=None) -> Callable[[ExitHandlerType], _PartialFunction]:
if _warn_parentheses_missing:
raise InvalidError("Positional arguments are not allowed. Did you forget parentheses? Suggestion: `@exit()`.")

def wrapper(f: Union[Callable[..., Any], _PartialFunction]) -> _PartialFunction:
if isinstance(f, _PartialFunction):
return f.stack(_PartialFunctionFlags.EXIT)
else:
return _PartialFunction(f, _PartialFunctionFlags.EXIT)
def wrapper(f: ExitHandlerType) -> _PartialFunction:
return _PartialFunction(f, _PartialFunctionFlags.EXIT)


method = synchronize_api(_method)
Expand Down

0 comments on commit 4d2f7cf

Please sign in to comment.