Skip to content

Commit

Permalink
Use new typing style (#963)
Browse files Browse the repository at this point in the history
* Use new typing style

* Pass all checks
  • Loading branch information
zrquan authored Oct 10, 2024
1 parent 4ee1ca2 commit 127505b
Show file tree
Hide file tree
Showing 25 changed files with 458 additions and 474 deletions.
20 changes: 11 additions & 9 deletions httpcore/_api.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
from __future__ import annotations

from contextlib import contextmanager
from typing import Iterator, Optional, Union
from typing import Iterator

from ._models import URL, Extensions, HeaderTypes, Response
from ._sync.connection_pool import ConnectionPool


def request(
method: Union[bytes, str],
url: Union[URL, bytes, str],
method: bytes | str,
url: URL | bytes | str,
*,
headers: HeaderTypes = None,
content: Union[bytes, Iterator[bytes], None] = None,
extensions: Optional[Extensions] = None,
content: bytes | Iterator[bytes] | None = None,
extensions: Extensions | None = None,
) -> Response:
"""
Sends an HTTP request, returning the response.
Expand Down Expand Up @@ -47,12 +49,12 @@ def request(

@contextmanager
def stream(
method: Union[bytes, str],
url: Union[URL, bytes, str],
method: bytes | str,
url: URL | bytes | str,
*,
headers: HeaderTypes = None,
content: Union[bytes, Iterator[bytes], None] = None,
extensions: Optional[Extensions] = None,
content: bytes | Iterator[bytes] | None = None,
extensions: Extensions | None = None,
) -> Iterator[Response]:
"""
Sends an HTTP request, returning the response within a content manager.
Expand Down
26 changes: 14 additions & 12 deletions httpcore/_async/connection.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from __future__ import annotations

import itertools
import logging
import ssl
from types import TracebackType
from typing import Iterable, Iterator, Optional, Type
from typing import Iterable, Iterator

from .._backends.auto import AutoBackend
from .._backends.base import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream
Expand Down Expand Up @@ -37,15 +39,15 @@ class AsyncHTTPConnection(AsyncConnectionInterface):
def __init__(
self,
origin: Origin,
ssl_context: Optional[ssl.SSLContext] = None,
keepalive_expiry: Optional[float] = None,
ssl_context: ssl.SSLContext | None = None,
keepalive_expiry: float | None = None,
http1: bool = True,
http2: bool = False,
retries: int = 0,
local_address: Optional[str] = None,
uds: Optional[str] = None,
network_backend: Optional[AsyncNetworkBackend] = None,
socket_options: Optional[Iterable[SOCKET_OPTION]] = None,
local_address: str | None = None,
uds: str | None = None,
network_backend: AsyncNetworkBackend | None = None,
socket_options: Iterable[SOCKET_OPTION] | None = None,
) -> None:
self._origin = origin
self._ssl_context = ssl_context
Expand All @@ -59,7 +61,7 @@ def __init__(
self._network_backend: AsyncNetworkBackend = (
AutoBackend() if network_backend is None else network_backend
)
self._connection: Optional[AsyncConnectionInterface] = None
self._connection: AsyncConnectionInterface | None = None
self._connect_failed: bool = False
self._request_lock = AsyncLock()
self._socket_options = socket_options
Expand Down Expand Up @@ -208,13 +210,13 @@ def __repr__(self) -> str:
# These context managers are not used in the standard flow, but are
# useful for testing or working with connection instances directly.

async def __aenter__(self) -> "AsyncHTTPConnection":
async def __aenter__(self) -> AsyncHTTPConnection:
return self

async def __aexit__(
self,
exc_type: Optional[Type[BaseException]] = None,
exc_value: Optional[BaseException] = None,
traceback: Optional[TracebackType] = None,
exc_type: type[BaseException] | None = None,
exc_value: BaseException | None = None,
traceback: TracebackType | None = None,
) -> None:
await self.aclose()
46 changes: 23 additions & 23 deletions httpcore/_async/connection_pool.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from __future__ import annotations

import ssl
import sys
from types import TracebackType
from typing import AsyncIterable, AsyncIterator, Iterable, List, Optional, Type
from typing import AsyncIterable, AsyncIterator, Iterable

from .._backends.auto import AutoBackend
from .._backends.base import SOCKET_OPTION, AsyncNetworkBackend
Expand All @@ -15,12 +17,10 @@
class AsyncPoolRequest:
def __init__(self, request: Request) -> None:
self.request = request
self.connection: Optional[AsyncConnectionInterface] = None
self.connection: AsyncConnectionInterface | None = None
self._connection_acquired = AsyncEvent()

def assign_to_connection(
self, connection: Optional[AsyncConnectionInterface]
) -> None:
def assign_to_connection(self, connection: AsyncConnectionInterface | None) -> None:
self.connection = connection
self._connection_acquired.set()

Expand All @@ -29,7 +29,7 @@ def clear_connection(self) -> None:
self._connection_acquired = AsyncEvent()

async def wait_for_connection(
self, timeout: Optional[float] = None
self, timeout: float | None = None
) -> AsyncConnectionInterface:
if self.connection is None:
await self._connection_acquired.wait(timeout=timeout)
Expand All @@ -47,17 +47,17 @@ class AsyncConnectionPool(AsyncRequestInterface):

def __init__(
self,
ssl_context: Optional[ssl.SSLContext] = None,
max_connections: Optional[int] = 10,
max_keepalive_connections: Optional[int] = None,
keepalive_expiry: Optional[float] = None,
ssl_context: ssl.SSLContext | None = None,
max_connections: int | None = 10,
max_keepalive_connections: int | None = None,
keepalive_expiry: float | None = None,
http1: bool = True,
http2: bool = False,
retries: int = 0,
local_address: Optional[str] = None,
uds: Optional[str] = None,
network_backend: Optional[AsyncNetworkBackend] = None,
socket_options: Optional[Iterable[SOCKET_OPTION]] = None,
local_address: str | None = None,
uds: str | None = None,
network_backend: AsyncNetworkBackend | None = None,
socket_options: Iterable[SOCKET_OPTION] | None = None,
) -> None:
"""
A connection pool for making HTTP requests.
Expand Down Expand Up @@ -116,8 +116,8 @@ def __init__(

# The mutable state on a connection pool is the queue of incoming requests,
# and the set of connections that are servicing those requests.
self._connections: List[AsyncConnectionInterface] = []
self._requests: List[AsyncPoolRequest] = []
self._connections: list[AsyncConnectionInterface] = []
self._requests: list[AsyncPoolRequest] = []

# We only mutate the state of the connection pool within an 'optional_thread_lock'
# context. This holds a threading lock unless we're running in async mode,
Expand All @@ -139,7 +139,7 @@ def create_connection(self, origin: Origin) -> AsyncConnectionInterface:
)

@property
def connections(self) -> List[AsyncConnectionInterface]:
def connections(self) -> list[AsyncConnectionInterface]:
"""
Return a list of the connections currently in the pool.
Expand Down Expand Up @@ -227,7 +227,7 @@ async def handle_async_request(self, request: Request) -> Response:
extensions=response.extensions,
)

def _assign_requests_to_connections(self) -> List[AsyncConnectionInterface]:
def _assign_requests_to_connections(self) -> list[AsyncConnectionInterface]:
"""
Manage the state of the connection pool, assigning incoming
requests to connections as available.
Expand Down Expand Up @@ -298,7 +298,7 @@ def _assign_requests_to_connections(self) -> List[AsyncConnectionInterface]:

return closing_connections

async def _close_connections(self, closing: List[AsyncConnectionInterface]) -> None:
async def _close_connections(self, closing: list[AsyncConnectionInterface]) -> None:
# Close connections which have been removed from the pool.
with AsyncShieldCancellation():
for connection in closing:
Expand All @@ -312,14 +312,14 @@ async def aclose(self) -> None:
self._connections = []
await self._close_connections(closing_connections)

async def __aenter__(self) -> "AsyncConnectionPool":
async def __aenter__(self) -> AsyncConnectionPool:
return self

async def __aexit__(
self,
exc_type: Optional[Type[BaseException]] = None,
exc_value: Optional[BaseException] = None,
traceback: Optional[TracebackType] = None,
exc_type: type[BaseException] | None = None,
exc_value: BaseException | None = None,
traceback: TracebackType | None = None,
) -> None:
await self.aclose()

Expand Down
45 changes: 18 additions & 27 deletions httpcore/_async/http11.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,11 @@
from __future__ import annotations

import enum
import logging
import ssl
import time
from types import TracebackType
from typing import (
Any,
AsyncIterable,
AsyncIterator,
List,
Optional,
Tuple,
Type,
Union,
)
from typing import Any, AsyncIterable, AsyncIterator, Union

import h11

Expand Down Expand Up @@ -55,12 +48,12 @@ def __init__(
self,
origin: Origin,
stream: AsyncNetworkStream,
keepalive_expiry: Optional[float] = None,
keepalive_expiry: float | None = None,
) -> None:
self._origin = origin
self._network_stream = stream
self._keepalive_expiry: Optional[float] = keepalive_expiry
self._expire_at: Optional[float] = None
self._keepalive_expiry: float | None = keepalive_expiry
self._expire_at: float | None = None
self._state = HTTPConnectionState.NEW
self._state_lock = AsyncLock()
self._request_count = 0
Expand Down Expand Up @@ -167,9 +160,7 @@ async def _send_request_body(self, request: Request) -> None:

await self._send_event(h11.EndOfMessage(), timeout=timeout)

async def _send_event(
self, event: h11.Event, timeout: Optional[float] = None
) -> None:
async def _send_event(self, event: h11.Event, timeout: float | None = None) -> None:
bytes_to_send = self._h11_state.send(event)
if bytes_to_send is not None:
await self._network_stream.write(bytes_to_send, timeout=timeout)
Expand All @@ -178,7 +169,7 @@ async def _send_event(

async def _receive_response_headers(
self, request: Request
) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]], bytes]:
) -> tuple[bytes, int, bytes, list[tuple[bytes, bytes]], bytes]:
timeouts = request.extensions.get("timeout", {})
timeout = timeouts.get("read", None)

Expand Down Expand Up @@ -214,8 +205,8 @@ async def _receive_response_body(self, request: Request) -> AsyncIterator[bytes]
break

async def _receive_event(
self, timeout: Optional[float] = None
) -> Union[h11.Event, Type[h11.PAUSED]]:
self, timeout: float | None = None
) -> h11.Event | type[h11.PAUSED]:
while True:
with map_exceptions({h11.RemoteProtocolError: RemoteProtocolError}):
event = self._h11_state.next_event()
Expand Down Expand Up @@ -316,14 +307,14 @@ def __repr__(self) -> str:
# These context managers are not used in the standard flow, but are
# useful for testing or working with connection instances directly.

async def __aenter__(self) -> "AsyncHTTP11Connection":
async def __aenter__(self) -> AsyncHTTP11Connection:
return self

async def __aexit__(
self,
exc_type: Optional[Type[BaseException]] = None,
exc_value: Optional[BaseException] = None,
traceback: Optional[TracebackType] = None,
exc_type: type[BaseException] | None = None,
exc_value: BaseException | None = None,
traceback: TracebackType | None = None,
) -> None:
await self.aclose()

Expand Down Expand Up @@ -360,15 +351,15 @@ def __init__(self, stream: AsyncNetworkStream, leading_data: bytes) -> None:
self._stream = stream
self._leading_data = leading_data

async def read(self, max_bytes: int, timeout: Optional[float] = None) -> bytes:
async def read(self, max_bytes: int, timeout: float | None = None) -> bytes:
if self._leading_data:
buffer = self._leading_data[:max_bytes]
self._leading_data = self._leading_data[max_bytes:]
return buffer
else:
return await self._stream.read(max_bytes, timeout)

async def write(self, buffer: bytes, timeout: Optional[float] = None) -> None:
async def write(self, buffer: bytes, timeout: float | None = None) -> None:
await self._stream.write(buffer, timeout)

async def aclose(self) -> None:
Expand All @@ -377,8 +368,8 @@ async def aclose(self) -> None:
async def start_tls(
self,
ssl_context: ssl.SSLContext,
server_hostname: Optional[str] = None,
timeout: Optional[float] = None,
server_hostname: str | None = None,
timeout: float | None = None,
) -> AsyncNetworkStream:
return await self._stream.start_tls(ssl_context, server_hostname, timeout)

Expand Down
Loading

0 comments on commit 127505b

Please sign in to comment.