From 1a32cf036a825f6eb35395af5388a3b23180a82e Mon Sep 17 00:00:00 2001 From: Jonas Lundberg Date: Fri, 8 Nov 2019 17:09:38 +0100 Subject: [PATCH] Rename BaseTCPStream/TCPStream to BaseSocketStream/SocketStream (#517) --- httpx/__init__.py | 4 ++-- httpx/concurrency/asyncio.py | 16 ++++++++-------- httpx/concurrency/base.py | 8 ++++---- httpx/concurrency/trio.py | 12 ++++++------ httpx/dispatch/http11.py | 4 ++-- httpx/dispatch/http2.py | 9 +++++++-- httpx/dispatch/proxy_http.py | 2 +- tests/dispatch/utils.py | 14 +++++++------- 8 files changed, 37 insertions(+), 32 deletions(-) diff --git a/httpx/__init__.py b/httpx/__init__.py index 6859964f8a..5eb85162c5 100644 --- a/httpx/__init__.py +++ b/httpx/__init__.py @@ -5,7 +5,7 @@ from .concurrency.base import ( BaseBackgroundManager, BasePoolSemaphore, - BaseTCPStream, + BaseSocketStream, ConcurrencyBackend, ) from .config import ( @@ -114,7 +114,7 @@ "TooManyRedirects", "WriteTimeout", "AsyncDispatcher", - "BaseTCPStream", + "BaseSocketStream", "ConcurrencyBackend", "Dispatcher", "URL", diff --git a/httpx/concurrency/asyncio.py b/httpx/concurrency/asyncio.py index 4aeb7ca53d..010d8215a8 100644 --- a/httpx/concurrency/asyncio.py +++ b/httpx/concurrency/asyncio.py @@ -12,7 +12,7 @@ BaseEvent, BasePoolSemaphore, BaseQueue, - BaseTCPStream, + BaseSocketStream, ConcurrencyBackend, TimeoutFlag, ) @@ -41,7 +41,7 @@ def _fixed_write(self, data: bytes) -> None: # type: ignore MonkeyPatch.write = _fixed_write -class TCPStream(BaseTCPStream): +class SocketStream(BaseSocketStream): def __init__( self, stream_reader: asyncio.StreamReader, @@ -52,11 +52,11 @@ def __init__( self.stream_writer = stream_writer self.timeout = timeout - self._inner: typing.Optional[TCPStream] = None + self._inner: typing.Optional[SocketStream] = None async def start_tls( self, hostname: str, ssl_context: ssl.SSLContext, timeout: TimeoutConfig - ) -> BaseTCPStream: + ) -> "SocketStream": loop = asyncio.get_event_loop() if not hasattr(loop, "start_tls"): # pragma: no cover raise NotImplementedError( @@ -83,8 +83,8 @@ async def start_tls( transport=transport, protocol=protocol, reader=stream_reader, loop=loop ) - ssl_stream = TCPStream(stream_reader, stream_writer, self.timeout) - # When we return a new TCPStream with new StreamReader/StreamWriter instances, + ssl_stream = SocketStream(stream_reader, stream_writer, self.timeout) + # When we return a new SocketStream with new StreamReader/StreamWriter instances # we need to keep references to the old StreamReader/StreamWriter so that they # are not garbage collected and closed while we're still using them. ssl_stream._inner = self @@ -229,7 +229,7 @@ async def open_tcp_stream( port: int, ssl_context: typing.Optional[ssl.SSLContext], timeout: TimeoutConfig, - ) -> BaseTCPStream: + ) -> SocketStream: try: stream_reader, stream_writer = await asyncio.wait_for( # type: ignore asyncio.open_connection(hostname, port, ssl=ssl_context), @@ -238,7 +238,7 @@ async def open_tcp_stream( except asyncio.TimeoutError: raise ConnectTimeout() - return TCPStream( + return SocketStream( stream_reader=stream_reader, stream_writer=stream_writer, timeout=timeout ) diff --git a/httpx/concurrency/base.py b/httpx/concurrency/base.py index a23d89bd30..9d5bffde3e 100644 --- a/httpx/concurrency/base.py +++ b/httpx/concurrency/base.py @@ -37,9 +37,9 @@ def set_write_timeouts(self) -> None: self.raise_on_write_timeout = True -class BaseTCPStream: +class BaseSocketStream: """ - A TCP stream with read/write operations. Abstracts away any asyncio-specific + A socket stream with read/write operations. Abstracts away any asyncio-specific interfaces into a more generic base class, that we can use with alternate backends, or for stand-alone test cases. """ @@ -49,7 +49,7 @@ def get_http_version(self) -> str: async def start_tls( self, hostname: str, ssl_context: ssl.SSLContext, timeout: TimeoutConfig - ) -> "BaseTCPStream": + ) -> "BaseSocketStream": raise NotImplementedError() # pragma: no cover async def read( @@ -121,7 +121,7 @@ async def open_tcp_stream( port: int, ssl_context: typing.Optional[ssl.SSLContext], timeout: TimeoutConfig, - ) -> BaseTCPStream: + ) -> BaseSocketStream: raise NotImplementedError() # pragma: no cover def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore: diff --git a/httpx/concurrency/trio.py b/httpx/concurrency/trio.py index da8e38a0ef..5d3b50dfbb 100644 --- a/httpx/concurrency/trio.py +++ b/httpx/concurrency/trio.py @@ -13,7 +13,7 @@ BaseEvent, BasePoolSemaphore, BaseQueue, - BaseTCPStream, + BaseSocketStream, ConcurrencyBackend, TimeoutFlag, ) @@ -23,7 +23,7 @@ def _or_inf(value: typing.Optional[float]) -> float: return value if value is not None else float("inf") -class TCPStream(BaseTCPStream): +class SocketStream(BaseSocketStream): def __init__( self, stream: typing.Union[trio.SocketStream, trio.SSLStream], @@ -36,7 +36,7 @@ def __init__( async def start_tls( self, hostname: str, ssl_context: ssl.SSLContext, timeout: TimeoutConfig - ) -> BaseTCPStream: + ) -> "SocketStream": # Check that the write buffer is empty. We should never start a TLS stream # while there is still pending data to write. assert self.write_buffer == b"" @@ -52,7 +52,7 @@ async def start_tls( if cancel_scope.cancelled_caught: raise ConnectTimeout() - return TCPStream(ssl_stream, self.timeout) + return SocketStream(ssl_stream, self.timeout) def get_http_version(self) -> str: if not isinstance(self.stream, trio.SSLStream): @@ -177,7 +177,7 @@ async def open_tcp_stream( port: int, ssl_context: typing.Optional[ssl.SSLContext], timeout: TimeoutConfig, - ) -> TCPStream: + ) -> SocketStream: connect_timeout = _or_inf(timeout.connect_timeout) with trio.move_on_after(connect_timeout) as cancel_scope: @@ -189,7 +189,7 @@ async def open_tcp_stream( if cancel_scope.cancelled_caught: raise ConnectTimeout() - return TCPStream(stream=stream, timeout=timeout) + return SocketStream(stream=stream, timeout=timeout) async def run_in_threadpool( self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any diff --git a/httpx/dispatch/http11.py b/httpx/dispatch/http11.py index fba58f731f..b1781bffa4 100644 --- a/httpx/dispatch/http11.py +++ b/httpx/dispatch/http11.py @@ -2,7 +2,7 @@ import h11 -from ..concurrency.base import BaseTCPStream, ConcurrencyBackend, TimeoutFlag +from ..concurrency.base import BaseSocketStream, ConcurrencyBackend, TimeoutFlag from ..config import TimeoutConfig, TimeoutTypes from ..models import AsyncRequest, AsyncResponse from ..utils import get_logger @@ -31,7 +31,7 @@ class HTTP11Connection: def __init__( self, - stream: BaseTCPStream, + stream: BaseSocketStream, backend: ConcurrencyBackend, on_release: typing.Optional[OnReleaseCallback] = None, ): diff --git a/httpx/dispatch/http2.py b/httpx/dispatch/http2.py index c76f99f7ce..5c6643103d 100644 --- a/httpx/dispatch/http2.py +++ b/httpx/dispatch/http2.py @@ -5,7 +5,12 @@ import h2.events from h2.settings import SettingCodes, Settings -from ..concurrency.base import BaseEvent, BaseTCPStream, ConcurrencyBackend, TimeoutFlag +from ..concurrency.base import ( + BaseEvent, + BaseSocketStream, + ConcurrencyBackend, + TimeoutFlag, +) from ..config import TimeoutConfig, TimeoutTypes from ..exceptions import ProtocolError from ..models import AsyncRequest, AsyncResponse @@ -19,7 +24,7 @@ class HTTP2Connection: def __init__( self, - stream: BaseTCPStream, + stream: BaseSocketStream, backend: ConcurrencyBackend, on_release: typing.Callable = None, ): diff --git a/httpx/dispatch/proxy_http.py b/httpx/dispatch/proxy_http.py index f6d52e58a6..8ad0ca8597 100644 --- a/httpx/dispatch/proxy_http.py +++ b/httpx/dispatch/proxy_http.py @@ -179,7 +179,7 @@ async def tunnel_start_tls( stream = http_connection.stream # If we need to start TLS again for the target server - # we need to pull the TCP stream off the internal + # we need to pull the socket stream off the internal # HTTP connection object and run start_tls() if origin.is_ssl: ssl_config = SSLConfig(cert=self.cert, verify=self.verify) diff --git a/tests/dispatch/utils.py b/tests/dispatch/utils.py index 3b0d534000..e2916c7b00 100644 --- a/tests/dispatch/utils.py +++ b/tests/dispatch/utils.py @@ -5,7 +5,7 @@ import h2.connection import h2.events -from httpx import AsyncioBackend, BaseTCPStream, Request, TimeoutConfig +from httpx import AsyncioBackend, BaseSocketStream, Request, TimeoutConfig from tests.concurrency import sleep @@ -21,7 +21,7 @@ async def open_tcp_stream( port: int, ssl_context: typing.Optional[ssl.SSLContext], timeout: TimeoutConfig, - ) -> BaseTCPStream: + ) -> BaseSocketStream: self.server = MockHTTP2Server(self.app, backend=self.backend) return self.server @@ -30,7 +30,7 @@ def __getattr__(self, name: str) -> typing.Any: return getattr(self.backend, name) -class MockHTTP2Server(BaseTCPStream): +class MockHTTP2Server(BaseSocketStream): def __init__(self, app, backend): config = h2.config.H2Configuration(client_side=False) self.conn = h2.connection.H2Connection(config=config) @@ -43,7 +43,7 @@ def __init__(self, app, backend): self.returning = {} self.settings_changed = [] - # TCP stream interface + # Socket stream interface def get_http_version(self) -> str: return "HTTP/2" @@ -178,7 +178,7 @@ async def open_tcp_stream( port: int, ssl_context: typing.Optional[ssl.SSLContext], timeout: TimeoutConfig, - ) -> BaseTCPStream: + ) -> BaseSocketStream: self.received_data.append( b"--- CONNECT(%s, %d) ---" % (hostname.encode(), port) ) @@ -189,13 +189,13 @@ def __getattr__(self, name: str) -> typing.Any: return getattr(self.backend, name) -class MockRawSocketStream(BaseTCPStream): +class MockRawSocketStream(BaseSocketStream): def __init__(self, backend: MockRawSocketBackend): self.backend = backend async def start_tls( self, hostname: str, ssl_context: ssl.SSLContext, timeout: TimeoutConfig - ) -> BaseTCPStream: + ) -> BaseSocketStream: self.backend.received_data.append(b"--- START_TLS(%s) ---" % hostname.encode()) return MockRawSocketStream(self.backend)