Skip to content

Commit

Permalink
Rename BaseTCPStream/TCPStream to BaseSocketStream/SocketStream (#517)
Browse files Browse the repository at this point in the history
  • Loading branch information
lundberg authored and florimondmanca committed Nov 8, 2019
1 parent 586acdd commit 1a32cf0
Show file tree
Hide file tree
Showing 8 changed files with 37 additions and 32 deletions.
4 changes: 2 additions & 2 deletions httpx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from .concurrency.base import (
BaseBackgroundManager,
BasePoolSemaphore,
BaseTCPStream,
BaseSocketStream,
ConcurrencyBackend,
)
from .config import (
Expand Down Expand Up @@ -114,7 +114,7 @@
"TooManyRedirects",
"WriteTimeout",
"AsyncDispatcher",
"BaseTCPStream",
"BaseSocketStream",
"ConcurrencyBackend",
"Dispatcher",
"URL",
Expand Down
16 changes: 8 additions & 8 deletions httpx/concurrency/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
BaseEvent,
BasePoolSemaphore,
BaseQueue,
BaseTCPStream,
BaseSocketStream,
ConcurrencyBackend,
TimeoutFlag,
)
Expand Down Expand Up @@ -41,7 +41,7 @@ def _fixed_write(self, data: bytes) -> None: # type: ignore
MonkeyPatch.write = _fixed_write


class TCPStream(BaseTCPStream):
class SocketStream(BaseSocketStream):
def __init__(
self,
stream_reader: asyncio.StreamReader,
Expand All @@ -52,11 +52,11 @@ def __init__(
self.stream_writer = stream_writer
self.timeout = timeout

self._inner: typing.Optional[TCPStream] = None
self._inner: typing.Optional[SocketStream] = None

async def start_tls(
self, hostname: str, ssl_context: ssl.SSLContext, timeout: TimeoutConfig
) -> BaseTCPStream:
) -> "SocketStream":
loop = asyncio.get_event_loop()
if not hasattr(loop, "start_tls"): # pragma: no cover
raise NotImplementedError(
Expand All @@ -83,8 +83,8 @@ async def start_tls(
transport=transport, protocol=protocol, reader=stream_reader, loop=loop
)

ssl_stream = TCPStream(stream_reader, stream_writer, self.timeout)
# When we return a new TCPStream with new StreamReader/StreamWriter instances,
ssl_stream = SocketStream(stream_reader, stream_writer, self.timeout)
# When we return a new SocketStream with new StreamReader/StreamWriter instances
# we need to keep references to the old StreamReader/StreamWriter so that they
# are not garbage collected and closed while we're still using them.
ssl_stream._inner = self
Expand Down Expand Up @@ -229,7 +229,7 @@ async def open_tcp_stream(
port: int,
ssl_context: typing.Optional[ssl.SSLContext],
timeout: TimeoutConfig,
) -> BaseTCPStream:
) -> SocketStream:
try:
stream_reader, stream_writer = await asyncio.wait_for( # type: ignore
asyncio.open_connection(hostname, port, ssl=ssl_context),
Expand All @@ -238,7 +238,7 @@ async def open_tcp_stream(
except asyncio.TimeoutError:
raise ConnectTimeout()

return TCPStream(
return SocketStream(
stream_reader=stream_reader, stream_writer=stream_writer, timeout=timeout
)

Expand Down
8 changes: 4 additions & 4 deletions httpx/concurrency/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ def set_write_timeouts(self) -> None:
self.raise_on_write_timeout = True


class BaseTCPStream:
class BaseSocketStream:
"""
A TCP stream with read/write operations. Abstracts away any asyncio-specific
A socket stream with read/write operations. Abstracts away any asyncio-specific
interfaces into a more generic base class, that we can use with alternate
backends, or for stand-alone test cases.
"""
Expand All @@ -49,7 +49,7 @@ def get_http_version(self) -> str:

async def start_tls(
self, hostname: str, ssl_context: ssl.SSLContext, timeout: TimeoutConfig
) -> "BaseTCPStream":
) -> "BaseSocketStream":
raise NotImplementedError() # pragma: no cover

async def read(
Expand Down Expand Up @@ -121,7 +121,7 @@ async def open_tcp_stream(
port: int,
ssl_context: typing.Optional[ssl.SSLContext],
timeout: TimeoutConfig,
) -> BaseTCPStream:
) -> BaseSocketStream:
raise NotImplementedError() # pragma: no cover

def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore:
Expand Down
12 changes: 6 additions & 6 deletions httpx/concurrency/trio.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
BaseEvent,
BasePoolSemaphore,
BaseQueue,
BaseTCPStream,
BaseSocketStream,
ConcurrencyBackend,
TimeoutFlag,
)
Expand All @@ -23,7 +23,7 @@ def _or_inf(value: typing.Optional[float]) -> float:
return value if value is not None else float("inf")


class TCPStream(BaseTCPStream):
class SocketStream(BaseSocketStream):
def __init__(
self,
stream: typing.Union[trio.SocketStream, trio.SSLStream],
Expand All @@ -36,7 +36,7 @@ def __init__(

async def start_tls(
self, hostname: str, ssl_context: ssl.SSLContext, timeout: TimeoutConfig
) -> BaseTCPStream:
) -> "SocketStream":
# Check that the write buffer is empty. We should never start a TLS stream
# while there is still pending data to write.
assert self.write_buffer == b""
Expand All @@ -52,7 +52,7 @@ async def start_tls(
if cancel_scope.cancelled_caught:
raise ConnectTimeout()

return TCPStream(ssl_stream, self.timeout)
return SocketStream(ssl_stream, self.timeout)

def get_http_version(self) -> str:
if not isinstance(self.stream, trio.SSLStream):
Expand Down Expand Up @@ -177,7 +177,7 @@ async def open_tcp_stream(
port: int,
ssl_context: typing.Optional[ssl.SSLContext],
timeout: TimeoutConfig,
) -> TCPStream:
) -> SocketStream:
connect_timeout = _or_inf(timeout.connect_timeout)

with trio.move_on_after(connect_timeout) as cancel_scope:
Expand All @@ -189,7 +189,7 @@ async def open_tcp_stream(
if cancel_scope.cancelled_caught:
raise ConnectTimeout()

return TCPStream(stream=stream, timeout=timeout)
return SocketStream(stream=stream, timeout=timeout)

async def run_in_threadpool(
self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
Expand Down
4 changes: 2 additions & 2 deletions httpx/dispatch/http11.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import h11

from ..concurrency.base import BaseTCPStream, ConcurrencyBackend, TimeoutFlag
from ..concurrency.base import BaseSocketStream, ConcurrencyBackend, TimeoutFlag
from ..config import TimeoutConfig, TimeoutTypes
from ..models import AsyncRequest, AsyncResponse
from ..utils import get_logger
Expand Down Expand Up @@ -31,7 +31,7 @@ class HTTP11Connection:

def __init__(
self,
stream: BaseTCPStream,
stream: BaseSocketStream,
backend: ConcurrencyBackend,
on_release: typing.Optional[OnReleaseCallback] = None,
):
Expand Down
9 changes: 7 additions & 2 deletions httpx/dispatch/http2.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@
import h2.events
from h2.settings import SettingCodes, Settings

from ..concurrency.base import BaseEvent, BaseTCPStream, ConcurrencyBackend, TimeoutFlag
from ..concurrency.base import (
BaseEvent,
BaseSocketStream,
ConcurrencyBackend,
TimeoutFlag,
)
from ..config import TimeoutConfig, TimeoutTypes
from ..exceptions import ProtocolError
from ..models import AsyncRequest, AsyncResponse
Expand All @@ -19,7 +24,7 @@ class HTTP2Connection:

def __init__(
self,
stream: BaseTCPStream,
stream: BaseSocketStream,
backend: ConcurrencyBackend,
on_release: typing.Callable = None,
):
Expand Down
2 changes: 1 addition & 1 deletion httpx/dispatch/proxy_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ async def tunnel_start_tls(
stream = http_connection.stream

# If we need to start TLS again for the target server
# we need to pull the TCP stream off the internal
# we need to pull the socket stream off the internal
# HTTP connection object and run start_tls()
if origin.is_ssl:
ssl_config = SSLConfig(cert=self.cert, verify=self.verify)
Expand Down
14 changes: 7 additions & 7 deletions tests/dispatch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import h2.connection
import h2.events

from httpx import AsyncioBackend, BaseTCPStream, Request, TimeoutConfig
from httpx import AsyncioBackend, BaseSocketStream, Request, TimeoutConfig
from tests.concurrency import sleep


Expand All @@ -21,7 +21,7 @@ async def open_tcp_stream(
port: int,
ssl_context: typing.Optional[ssl.SSLContext],
timeout: TimeoutConfig,
) -> BaseTCPStream:
) -> BaseSocketStream:
self.server = MockHTTP2Server(self.app, backend=self.backend)
return self.server

Expand All @@ -30,7 +30,7 @@ def __getattr__(self, name: str) -> typing.Any:
return getattr(self.backend, name)


class MockHTTP2Server(BaseTCPStream):
class MockHTTP2Server(BaseSocketStream):
def __init__(self, app, backend):
config = h2.config.H2Configuration(client_side=False)
self.conn = h2.connection.H2Connection(config=config)
Expand All @@ -43,7 +43,7 @@ def __init__(self, app, backend):
self.returning = {}
self.settings_changed = []

# TCP stream interface
# Socket stream interface

def get_http_version(self) -> str:
return "HTTP/2"
Expand Down Expand Up @@ -178,7 +178,7 @@ async def open_tcp_stream(
port: int,
ssl_context: typing.Optional[ssl.SSLContext],
timeout: TimeoutConfig,
) -> BaseTCPStream:
) -> BaseSocketStream:
self.received_data.append(
b"--- CONNECT(%s, %d) ---" % (hostname.encode(), port)
)
Expand All @@ -189,13 +189,13 @@ def __getattr__(self, name: str) -> typing.Any:
return getattr(self.backend, name)


class MockRawSocketStream(BaseTCPStream):
class MockRawSocketStream(BaseSocketStream):
def __init__(self, backend: MockRawSocketBackend):
self.backend = backend

async def start_tls(
self, hostname: str, ssl_context: ssl.SSLContext, timeout: TimeoutConfig
) -> BaseTCPStream:
) -> BaseSocketStream:
self.backend.received_data.append(b"--- START_TLS(%s) ---" % hostname.encode())
return MockRawSocketStream(self.backend)

Expand Down

0 comments on commit 1a32cf0

Please sign in to comment.