From 76968d12d101872799a85980ae927364b2daab4f Mon Sep 17 00:00:00 2001 From: Jonas Lundberg Date: Thu, 7 Nov 2019 23:38:48 +0100 Subject: [PATCH] Add uds arg to BaseClient and select tcp or uds in HttpConnection --- httpx/client.py | 2 ++ httpx/concurrency/base.py | 2 +- httpx/dispatch/connection.py | 19 +++++++++++++++++-- httpx/dispatch/connection_pool.py | 3 +++ tests/client/test_async_client.py | 14 ++++++++++++++ tests/client/test_client.py | 19 +++++++++++++++++++ tests/conftest.py | 7 +++++++ 7 files changed, 63 insertions(+), 3 deletions(-) diff --git a/httpx/client.py b/httpx/client.py index 45589b9e23..d4167c22b0 100644 --- a/httpx/client.py +++ b/httpx/client.py @@ -74,6 +74,7 @@ def __init__( app: typing.Callable = None, backend: ConcurrencyBackend = None, trust_env: bool = True, + uds: str = None, ): if backend is None: backend = AsyncioBackend() @@ -99,6 +100,7 @@ def __init__( pool_limits=pool_limits, backend=backend, trust_env=self.trust_env, + uds=uds, ) elif isinstance(dispatch, Dispatcher): async_dispatch = ThreadedDispatcher(dispatch, backend) diff --git a/httpx/concurrency/base.py b/httpx/concurrency/base.py index 1e9607b18d..2109c2121a 100644 --- a/httpx/concurrency/base.py +++ b/httpx/concurrency/base.py @@ -130,7 +130,7 @@ async def open_uds_stream( hostname: typing.Optional[str], ssl_context: typing.Optional[ssl.SSLContext], timeout: TimeoutConfig, - ) -> BaseTCPStream: + ) -> BaseSocketStream: raise NotImplementedError() # pragma: no cover def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore: diff --git a/httpx/dispatch/connection.py b/httpx/dispatch/connection.py index 91feb97676..0612bccb86 100644 --- a/httpx/dispatch/connection.py +++ b/httpx/dispatch/connection.py @@ -38,6 +38,7 @@ def __init__( http_versions: HTTPVersionTypes = None, backend: ConcurrencyBackend = None, release_func: typing.Optional[ReleaseCallback] = None, + uds: typing.Optional[str] = None, ): self.origin = Origin(origin) if isinstance(origin, str) else origin self.ssl = SSLConfig(cert=cert, verify=verify, trust_env=trust_env) @@ -45,6 +46,7 @@ def __init__( self.http_versions = HTTPVersionConfig(http_versions) self.backend = AsyncioBackend() if backend is None else backend self.release_func = release_func + self.uds = uds self.h11_connection = None # type: typing.Optional[HTTP11Connection] self.h2_connection = None # type: typing.Optional[HTTP2Connection] @@ -84,8 +86,21 @@ async def connect( else: on_release = functools.partial(self.release_func, self) - logger.trace(f"start_connect host={host!r} port={port!r} timeout={timeout!r}") - stream = await self.backend.open_tcp_stream(host, port, ssl_context, timeout) + if self.uds is None: + logger.trace( + f"start_connect tcp host={host!r} port={port!r} timeout={timeout!r}" + ) + stream = await self.backend.open_tcp_stream( + host, port, ssl_context, timeout + ) + else: + logger.trace( + f"start_connect uds path={self.uds!r} host={host!r} timeout={timeout!r}" + ) + stream = await self.backend.open_uds_stream( + self.uds, host, ssl_context, timeout + ) + http_version = stream.get_http_version() logger.trace(f"connected http_version={http_version!r}") diff --git a/httpx/dispatch/connection_pool.py b/httpx/dispatch/connection_pool.py index 189fcff699..5d7d886dec 100644 --- a/httpx/dispatch/connection_pool.py +++ b/httpx/dispatch/connection_pool.py @@ -89,6 +89,7 @@ def __init__( pool_limits: PoolLimits = DEFAULT_POOL_LIMITS, http_versions: HTTPVersionTypes = None, backend: ConcurrencyBackend = None, + uds: typing.Optional[str] = None, ): self.verify = verify self.cert = cert @@ -97,6 +98,7 @@ def __init__( self.http_versions = http_versions self.is_closed = False self.trust_env = trust_env + self.uds = uds self.keepalive_connections = ConnectionStore() self.active_connections = ConnectionStore() @@ -142,6 +144,7 @@ async def acquire_connection(self, origin: Origin) -> HTTPConnection: backend=self.backend, release_func=self.release_connection, trust_env=self.trust_env, + uds=self.uds, ) logger.trace(f"new_connection connection={connection!r}") else: diff --git a/tests/client/test_async_client.py b/tests/client/test_async_client.py index eaac5ef76a..cd53c8fc08 100644 --- a/tests/client/test_async_client.py +++ b/tests/client/test_async_client.py @@ -146,3 +146,17 @@ async def test_100_continue(server, backend): assert response.status_code == 200 assert response.content == data + + +async def test_uds(uds_server, backend): + url = uds_server.url + uds = uds_server.config.uds + assert uds is not None + async with httpx.AsyncClient(backend=backend, uds=uds) as client: + response = await client.get(url) + assert response.status_code == 200 + assert response.text == "Hello, world!" + assert response.http_version == "HTTP/1.1" + assert response.headers + assert repr(response) == "" + assert response.elapsed > timedelta(seconds=0) diff --git a/tests/client/test_client.py b/tests/client/test_client.py index f7be6070e9..e9940bb231 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -138,6 +138,25 @@ def test_base_url(server): assert response.url == base_url +def test_uds(uds_server): + url = uds_server.url + uds = uds_server.config.uds + assert uds is not None + with httpx.Client(uds=uds) as http: + response = http.get(url) + assert response.status_code == 200 + assert response.url == url + assert response.content == b"Hello, world!" + assert response.text == "Hello, world!" + assert response.http_version == "HTTP/1.1" + assert response.encoding == "iso-8859-1" + assert response.request.url == url + assert response.headers + assert response.is_redirect is False + assert repr(response) == "" + assert response.elapsed > timedelta(0) + + def test_merge_url(): client = httpx.Client(base_url="https://www.paypal.com/") url = client.merge_url("http://www.paypal.com") diff --git a/tests/conftest.py b/tests/conftest.py index b1bb038ad4..5c44f2e375 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -288,6 +288,13 @@ def server(): yield from serve_in_thread(server) +@pytest.fixture(scope=SERVER_SCOPE) +def uds_server(): + config = Config(app=app, lifespan="off", loop="asyncio", uds="test_server.sock") + server = TestServer(config=config) + yield from serve_in_thread(server) + + @pytest.fixture(scope=SERVER_SCOPE) def https_server(cert_pem_file, cert_private_key_file): config = Config(