diff --git a/httpx/client.py b/httpx/client.py index 45589b9e23..c327c2db00 100644 --- a/httpx/client.py +++ b/httpx/client.py @@ -74,6 +74,7 @@ def __init__( app: typing.Callable = None, backend: ConcurrencyBackend = None, trust_env: bool = True, + uds: str = None, ): if backend is None: backend = AsyncioBackend() @@ -99,6 +100,7 @@ def __init__( pool_limits=pool_limits, backend=backend, trust_env=self.trust_env, + uds=uds, ) elif isinstance(dispatch, Dispatcher): async_dispatch = ThreadedDispatcher(dispatch, backend) @@ -721,6 +723,7 @@ class Client(BaseClient): async requests. * **trust_env** - *(optional)* Enables or disables usage of environment variables for configuration. + * **uds** - *(optional)* A path to a Unix domain socket to connect through. """ def check_concurrency_backend(self, backend: ConcurrencyBackend) -> None: diff --git a/httpx/concurrency/asyncio.py b/httpx/concurrency/asyncio.py index 019876e43e..a0163ed021 100644 --- a/httpx/concurrency/asyncio.py +++ b/httpx/concurrency/asyncio.py @@ -275,6 +275,29 @@ async def open_tcp_stream( stream_reader=stream_reader, stream_writer=stream_writer, timeout=timeout ) + async def open_uds_stream( + self, + path: str, + hostname: typing.Optional[str], + ssl_context: typing.Optional[ssl.SSLContext], + timeout: TimeoutConfig, + ) -> SocketStream: + server_hostname = hostname if ssl_context else None + + try: + stream_reader, stream_writer = await asyncio.wait_for( # type: ignore + asyncio.open_unix_connection( + path, ssl=ssl_context, server_hostname=server_hostname + ), + timeout.connect_timeout, + ) + except asyncio.TimeoutError: + raise ConnectTimeout() + + return SocketStream( + stream_reader=stream_reader, stream_writer=stream_writer, timeout=timeout + ) + async def run_in_threadpool( self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any ) -> typing.Any: diff --git a/httpx/concurrency/base.py b/httpx/concurrency/base.py index 9d5bffde3e..2109c2121a 100644 --- a/httpx/concurrency/base.py +++ b/httpx/concurrency/base.py @@ -124,6 +124,15 @@ async def open_tcp_stream( ) -> BaseSocketStream: raise NotImplementedError() # pragma: no cover + async def open_uds_stream( + self, + path: str, + hostname: typing.Optional[str], + ssl_context: typing.Optional[ssl.SSLContext], + timeout: TimeoutConfig, + ) -> BaseSocketStream: + raise NotImplementedError() # pragma: no cover + def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore: raise NotImplementedError() # pragma: no cover diff --git a/httpx/concurrency/trio.py b/httpx/concurrency/trio.py index 5d3b50dfbb..c84b1e4b8b 100644 --- a/httpx/concurrency/trio.py +++ b/httpx/concurrency/trio.py @@ -191,6 +191,26 @@ async def open_tcp_stream( return SocketStream(stream=stream, timeout=timeout) + async def open_uds_stream( + self, + path: str, + hostname: typing.Optional[str], + ssl_context: typing.Optional[ssl.SSLContext], + timeout: TimeoutConfig, + ) -> SocketStream: + connect_timeout = _or_inf(timeout.connect_timeout) + + with trio.move_on_after(connect_timeout) as cancel_scope: + stream: trio.SocketStream = await trio.open_unix_socket(path) + if ssl_context is not None: + stream = trio.SSLStream(stream, ssl_context, server_hostname=hostname) + await stream.do_handshake() + + if cancel_scope.cancelled_caught: + raise ConnectTimeout() + + return SocketStream(stream=stream, timeout=timeout) + async def run_in_threadpool( self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any ) -> typing.Any: diff --git a/httpx/dispatch/connection.py b/httpx/dispatch/connection.py index 91feb97676..0612bccb86 100644 --- a/httpx/dispatch/connection.py +++ b/httpx/dispatch/connection.py @@ -38,6 +38,7 @@ def __init__( http_versions: HTTPVersionTypes = None, backend: ConcurrencyBackend = None, release_func: typing.Optional[ReleaseCallback] = None, + uds: typing.Optional[str] = None, ): self.origin = Origin(origin) if isinstance(origin, str) else origin self.ssl = SSLConfig(cert=cert, verify=verify, trust_env=trust_env) @@ -45,6 +46,7 @@ def __init__( self.http_versions = HTTPVersionConfig(http_versions) self.backend = AsyncioBackend() if backend is None else backend self.release_func = release_func + self.uds = uds self.h11_connection = None # type: typing.Optional[HTTP11Connection] self.h2_connection = None # type: typing.Optional[HTTP2Connection] @@ -84,8 +86,21 @@ async def connect( else: on_release = functools.partial(self.release_func, self) - logger.trace(f"start_connect host={host!r} port={port!r} timeout={timeout!r}") - stream = await self.backend.open_tcp_stream(host, port, ssl_context, timeout) + if self.uds is None: + logger.trace( + f"start_connect tcp host={host!r} port={port!r} timeout={timeout!r}" + ) + stream = await self.backend.open_tcp_stream( + host, port, ssl_context, timeout + ) + else: + logger.trace( + f"start_connect uds path={self.uds!r} host={host!r} timeout={timeout!r}" + ) + stream = await self.backend.open_uds_stream( + self.uds, host, ssl_context, timeout + ) + http_version = stream.get_http_version() logger.trace(f"connected http_version={http_version!r}") diff --git a/httpx/dispatch/connection_pool.py b/httpx/dispatch/connection_pool.py index 189fcff699..5d7d886dec 100644 --- a/httpx/dispatch/connection_pool.py +++ b/httpx/dispatch/connection_pool.py @@ -89,6 +89,7 @@ def __init__( pool_limits: PoolLimits = DEFAULT_POOL_LIMITS, http_versions: HTTPVersionTypes = None, backend: ConcurrencyBackend = None, + uds: typing.Optional[str] = None, ): self.verify = verify self.cert = cert @@ -97,6 +98,7 @@ def __init__( self.http_versions = http_versions self.is_closed = False self.trust_env = trust_env + self.uds = uds self.keepalive_connections = ConnectionStore() self.active_connections = ConnectionStore() @@ -142,6 +144,7 @@ async def acquire_connection(self, origin: Origin) -> HTTPConnection: backend=self.backend, release_func=self.release_connection, trust_env=self.trust_env, + uds=self.uds, ) logger.trace(f"new_connection connection={connection!r}") else: diff --git a/tests/client/test_async_client.py b/tests/client/test_async_client.py index eaac5ef76a..42202aa180 100644 --- a/tests/client/test_async_client.py +++ b/tests/client/test_async_client.py @@ -146,3 +146,14 @@ async def test_100_continue(server, backend): assert response.status_code == 200 assert response.content == data + + +async def test_uds(uds_server, backend): + url = uds_server.url + uds = uds_server.config.uds + assert uds is not None + async with httpx.AsyncClient(backend=backend, uds=uds) as client: + response = await client.get(url) + assert response.status_code == 200 + assert response.text == "Hello, world!" + assert response.encoding == "iso-8859-1" diff --git a/tests/client/test_client.py b/tests/client/test_client.py index f7be6070e9..5dc196d933 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -138,6 +138,17 @@ def test_base_url(server): assert response.url == base_url +def test_uds(uds_server): + url = uds_server.url + uds = uds_server.config.uds + assert uds is not None + with httpx.Client(uds=uds) as http: + response = http.get(url) + assert response.status_code == 200 + assert response.text == "Hello, world!" + assert response.encoding == "iso-8859-1" + + def test_merge_url(): client = httpx.Client(base_url="https://www.paypal.com/") url = client.merge_url("http://www.paypal.com") diff --git a/tests/conftest.py b/tests/conftest.py index ef57caef30..de67ff7f92 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -288,6 +288,15 @@ def server(): yield from serve_in_thread(server) +@pytest.fixture(scope=SERVER_SCOPE) +def uds_server(): + uds = "test_server.sock" + config = Config(app=app, lifespan="off", loop="asyncio", uds=uds) + server = TestServer(config=config) + yield from serve_in_thread(server) + os.remove(uds) + + @pytest.fixture(scope=SERVER_SCOPE) def https_server(cert_pem_file, cert_private_key_file): config = Config( @@ -301,3 +310,19 @@ def https_server(cert_pem_file, cert_private_key_file): ) server = TestServer(config=config) yield from serve_in_thread(server) + + +@pytest.fixture(scope=SERVER_SCOPE) +def https_uds_server(cert_pem_file, cert_private_key_file): + uds = "https_test_server.sock" + config = Config( + app=app, + lifespan="off", + ssl_certfile=cert_pem_file, + ssl_keyfile=cert_private_key_file, + uds=uds, + loop="asyncio", + ) + server = TestServer(config=config) + yield from serve_in_thread(server) + os.remove(uds) diff --git a/tests/test_concurrency.py b/tests/test_concurrency.py index 7477ea3b96..cf3844cffe 100644 --- a/tests/test_concurrency.py +++ b/tests/test_concurrency.py @@ -24,7 +24,10 @@ ), ], ) -async def test_start_tls_on_socket_stream(https_server, backend, get_cipher): +@pytest.mark.parametrize("use_uds", (False, True)) +async def test_start_tls_on_socket_stream( + https_server, https_uds_server, backend, get_cipher, use_uds +): """ See that the concurrency backend can make a connection without TLS then start TLS on an existing connection. @@ -32,9 +35,15 @@ async def test_start_tls_on_socket_stream(https_server, backend, get_cipher): ctx = SSLConfig().load_ssl_context_no_verify(HTTPVersionConfig()) timeout = TimeoutConfig(5) - stream = await backend.open_tcp_stream( - https_server.url.host, https_server.url.port, None, timeout - ) + if use_uds: + assert https_uds_server.config.uds is not None + stream = await backend.open_uds_stream( + https_uds_server.config.uds, https_uds_server.url.host, None, timeout + ) + else: + stream = await backend.open_tcp_stream( + https_server.url.host, https_server.url.port, None, timeout + ) try: assert stream.is_connection_dropped() is False