Skip to content

Commit

Permalink
Add and implement open_uds_stream in concurrency backends
Browse files Browse the repository at this point in the history
  • Loading branch information
lundberg committed Nov 8, 2019
1 parent 1a32cf0 commit 1e40664
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 18 deletions.
29 changes: 27 additions & 2 deletions httpx/concurrency/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,11 +229,36 @@ async def open_tcp_stream(
port: int,
ssl_context: typing.Optional[ssl.SSLContext],
timeout: TimeoutConfig,
) -> SocketStream:
return await self._open_stream(
asyncio.open_connection(hostname, port, ssl=ssl_context), timeout
)

async def open_uds_stream(
self,
path: str,
hostname: typing.Optional[str],
ssl_context: typing.Optional[ssl.SSLContext],
timeout: TimeoutConfig,
) -> SocketStream:
server_hostname = hostname if ssl_context else None
return await self._open_stream(
asyncio.open_unix_connection(
path, ssl=ssl_context, server_hostname=server_hostname
),
timeout,
)

async def _open_stream(
self,
socket_stream: typing.Awaitable[
typing.Tuple[asyncio.StreamReader, asyncio.StreamWriter]
],
timeout: TimeoutConfig,
) -> SocketStream:
try:
stream_reader, stream_writer = await asyncio.wait_for( # type: ignore
asyncio.open_connection(hostname, port, ssl=ssl_context),
timeout.connect_timeout,
socket_stream, timeout.connect_timeout,
)
except asyncio.TimeoutError:
raise ConnectTimeout()
Expand Down
9 changes: 9 additions & 0 deletions httpx/concurrency/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,15 @@ async def open_tcp_stream(
) -> BaseSocketStream:
raise NotImplementedError() # pragma: no cover

async def open_uds_stream(
self,
path: str,
hostname: typing.Optional[str],
ssl_context: typing.Optional[ssl.SSLContext],
timeout: TimeoutConfig,
) -> BaseTCPStream:
raise NotImplementedError() # pragma: no cover

def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore:
raise NotImplementedError() # pragma: no cover

Expand Down
25 changes: 24 additions & 1 deletion httpx/concurrency/trio.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,11 +177,34 @@ async def open_tcp_stream(
port: int,
ssl_context: typing.Optional[ssl.SSLContext],
timeout: TimeoutConfig,
) -> SocketStream:
return await self._open_stream(
trio.open_tcp_stream(hostname, port), hostname, ssl_context, timeout
)

async def open_uds_stream(
self,
path: str,
hostname: typing.Optional[str],
ssl_context: typing.Optional[ssl.SSLContext],
timeout: TimeoutConfig,
) -> SocketStream:
hostname = hostname if ssl_context else None
return await self._open_stream(
trio.open_unix_socket(path), hostname, ssl_context, timeout
)

async def _open_stream(
self,
socket_stream: typing.Awaitable[trio.SocketStream],
hostname: typing.Optional[str],
ssl_context: typing.Optional[ssl.SSLContext],
timeout: TimeoutConfig,
) -> SocketStream:
connect_timeout = _or_inf(timeout.connect_timeout)

with trio.move_on_after(connect_timeout) as cancel_scope:
stream: trio.SocketStream = await trio.open_tcp_stream(hostname, port)
stream: trio.SocketStream = await socket_stream
if ssl_context is not None:
stream = trio.SSLStream(stream, ssl_context, server_hostname=hostname)
await stream.do_handshake()
Expand Down
14 changes: 14 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,3 +301,17 @@ def https_server(cert_pem_file, cert_private_key_file):
)
server = TestServer(config=config)
yield from serve_in_thread(server)


@pytest.fixture(scope=SERVER_SCOPE)
def https_uds_server(cert_pem_file, cert_private_key_file):
config = Config(
app=app,
lifespan="off",
ssl_certfile=cert_pem_file,
ssl_keyfile=cert_private_key_file,
uds="https_test_server.sock",
loop="asyncio",
)
server = TestServer(config=config)
yield from serve_in_thread(server)
40 changes: 25 additions & 15 deletions tests/test_concurrency.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,26 +7,30 @@
from httpx.concurrency.trio import TrioBackend


def get_asyncio_cipher(stream):
return stream.stream_writer.get_extra_info("cipher", default=None)


def get_trio_cipher(stream):
return stream.stream.cipher() if isinstance(stream.stream, trio.SSLStream) else None


@pytest.mark.parametrize(
"backend, get_cipher",
"backend, test_uds, get_cipher",
[
pytest.param(
AsyncioBackend(),
lambda stream: stream.stream_writer.get_extra_info("cipher", default=None),
marks=pytest.mark.asyncio,
AsyncioBackend(), False, get_asyncio_cipher, marks=pytest.mark.asyncio
),
pytest.param(
TrioBackend(),
lambda stream: (
stream.stream.cipher()
if isinstance(stream.stream, trio.SSLStream)
else None
),
marks=pytest.mark.trio,
AsyncioBackend(), True, get_asyncio_cipher, marks=pytest.mark.asyncio
),
pytest.param(TrioBackend(), False, get_trio_cipher, marks=pytest.mark.trio),
pytest.param(TrioBackend(), True, get_trio_cipher, marks=pytest.mark.trio),
],
)
async def test_start_tls_on_socket_stream(https_server, backend, get_cipher):
async def test_start_tls_on_socket_stream(
https_server, https_uds_server, backend, test_uds, get_cipher
):
"""
See that the concurrency backend can make a connection without TLS then
start TLS on an existing connection.
Expand All @@ -37,9 +41,15 @@ async def test_start_tls_on_socket_stream(https_server, backend, get_cipher):
ctx = SSLConfig().load_ssl_context_no_verify(HTTPVersionConfig())
timeout = TimeoutConfig(5)

stream = await backend.open_tcp_stream(
https_server.url.host, https_server.url.port, None, timeout
)
if test_uds:
assert https_uds_server.config.uds is not None
stream = await backend.open_uds_stream(
https_uds_server.config.uds, https_uds_server.url.host, None, timeout
)
else:
stream = await backend.open_tcp_stream(
https_server.url.host, https_server.url.port, None, timeout
)

try:
assert stream.is_connection_dropped() is False
Expand Down

0 comments on commit 1e40664

Please sign in to comment.