diff --git a/httpx/concurrency/asyncio.py b/httpx/concurrency/asyncio.py index 4aeb7ca53d..87e5944461 100644 --- a/httpx/concurrency/asyncio.py +++ b/httpx/concurrency/asyncio.py @@ -229,11 +229,36 @@ async def open_tcp_stream( port: int, ssl_context: typing.Optional[ssl.SSLContext], timeout: TimeoutConfig, + ) -> BaseTCPStream: + return await self._open_stream( + asyncio.open_connection(hostname, port, ssl=ssl_context), timeout + ) + + async def open_uds_stream( + self, + path: str, + hostname: typing.Optional[str], + ssl_context: typing.Optional[ssl.SSLContext], + timeout: TimeoutConfig, + ) -> BaseTCPStream: + server_hostname = hostname if ssl_context else None + return await self._open_stream( + asyncio.open_unix_connection( + path, ssl=ssl_context, server_hostname=server_hostname + ), + timeout, + ) + + async def _open_stream( + self, + socket_stream: typing.Awaitable[ + typing.Tuple[asyncio.StreamReader, asyncio.StreamWriter] + ], + timeout: TimeoutConfig, ) -> BaseTCPStream: try: stream_reader, stream_writer = await asyncio.wait_for( # type: ignore - asyncio.open_connection(hostname, port, ssl=ssl_context), - timeout.connect_timeout, + socket_stream, timeout.connect_timeout, ) except asyncio.TimeoutError: raise ConnectTimeout() diff --git a/httpx/concurrency/base.py b/httpx/concurrency/base.py index a23d89bd30..9abf126ea0 100644 --- a/httpx/concurrency/base.py +++ b/httpx/concurrency/base.py @@ -124,6 +124,15 @@ async def open_tcp_stream( ) -> BaseTCPStream: raise NotImplementedError() # pragma: no cover + async def open_uds_stream( + self, + path: str, + hostname: typing.Optional[str], + ssl_context: typing.Optional[ssl.SSLContext], + timeout: TimeoutConfig, + ) -> BaseTCPStream: + raise NotImplementedError() # pragma: no cover + def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore: raise NotImplementedError() # pragma: no cover diff --git a/httpx/concurrency/trio.py b/httpx/concurrency/trio.py index da8e38a0ef..0aa8ebd908 100644 --- a/httpx/concurrency/trio.py +++ b/httpx/concurrency/trio.py @@ -177,11 +177,34 @@ async def open_tcp_stream( port: int, ssl_context: typing.Optional[ssl.SSLContext], timeout: TimeoutConfig, + ) -> TCPStream: + return await self._open_stream( + trio.open_tcp_stream(hostname, port), hostname, ssl_context, timeout + ) + + async def open_uds_stream( + self, + path: str, + hostname: typing.Optional[str], + ssl_context: typing.Optional[ssl.SSLContext], + timeout: TimeoutConfig, + ) -> BaseTCPStream: + hostname = hostname if ssl_context else None + return await self._open_stream( + trio.open_unix_socket(path), hostname, ssl_context, timeout + ) + + async def _open_stream( + self, + socket_stream: typing.Awaitable[trio.SocketStream], + hostname: typing.Optional[str], + ssl_context: typing.Optional[ssl.SSLContext], + timeout: TimeoutConfig, ) -> TCPStream: connect_timeout = _or_inf(timeout.connect_timeout) with trio.move_on_after(connect_timeout) as cancel_scope: - stream: trio.SocketStream = await trio.open_tcp_stream(hostname, port) + stream: trio.SocketStream = await socket_stream if ssl_context is not None: stream = trio.SSLStream(stream, ssl_context, server_hostname=hostname) await stream.do_handshake() diff --git a/tests/conftest.py b/tests/conftest.py index c6968894d5..a2d6d67dd6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -288,3 +288,17 @@ def https_server(cert_pem_file, cert_private_key_file): ) server = TestServer(config=config) yield from serve_in_thread(server) + + +@pytest.fixture(scope=SERVER_SCOPE) +def https_uds_server(cert_pem_file, cert_private_key_file): + config = Config( + app=app, + lifespan="off", + ssl_certfile=cert_pem_file, + ssl_keyfile=cert_private_key_file, + uds="https_test_server.sock", + loop="asyncio", + ) + server = TestServer(config=config) + yield from serve_in_thread(server) diff --git a/tests/test_concurrency.py b/tests/test_concurrency.py index 8bb933b697..b058ff3c45 100644 --- a/tests/test_concurrency.py +++ b/tests/test_concurrency.py @@ -7,26 +7,30 @@ from httpx.concurrency.trio import TrioBackend +def get_asyncio_cipher(stream): + return stream.stream_writer.get_extra_info("cipher", default=None) + + +def get_trio_cipher(stream): + return stream.stream.cipher() if isinstance(stream.stream, trio.SSLStream) else None + + @pytest.mark.parametrize( - "backend, get_cipher", + "backend, test_uds, get_cipher", [ pytest.param( - AsyncioBackend(), - lambda stream: stream.stream_writer.get_extra_info("cipher", default=None), - marks=pytest.mark.asyncio, + AsyncioBackend(), False, get_asyncio_cipher, marks=pytest.mark.asyncio ), pytest.param( - TrioBackend(), - lambda stream: ( - stream.stream.cipher() - if isinstance(stream.stream, trio.SSLStream) - else None - ), - marks=pytest.mark.trio, + AsyncioBackend(), True, get_asyncio_cipher, marks=pytest.mark.asyncio ), + pytest.param(TrioBackend(), False, get_trio_cipher, marks=pytest.mark.trio), + pytest.param(TrioBackend(), True, get_trio_cipher, marks=pytest.mark.trio), ], ) -async def test_start_tls_on_socket_stream(https_server, backend, get_cipher): +async def test_start_tls_on_socket_stream( + https_server, https_uds_server, backend, test_uds, get_cipher +): """ See that the concurrency backend can make a connection without TLS then start TLS on an existing connection. @@ -37,9 +41,15 @@ async def test_start_tls_on_socket_stream(https_server, backend, get_cipher): ctx = SSLConfig().load_ssl_context_no_verify(HTTPVersionConfig()) timeout = TimeoutConfig(5) - stream = await backend.open_tcp_stream( - https_server.url.host, https_server.url.port, None, timeout - ) + if test_uds: + assert https_uds_server.config.uds is not None + stream = await backend.open_uds_stream( + https_uds_server.config.uds, https_uds_server.url.host, None, timeout + ) + else: + stream = await backend.open_tcp_stream( + https_server.url.host, https_server.url.port, None, timeout + ) try: assert stream.is_connection_dropped() is False