Skip to content

Commit

Permalink
Add support for unix domain sockets (#511)
Browse files Browse the repository at this point in the history
* Add and implement open_uds_stream in concurrency backends

* Add uds arg to BaseClient and select tcp or uds in HttpConnection

* Make open stream methods in backends more explicit

* Close sentence

Co-Authored-By: Florimond Manca <[email protected]>

* Refactor uds concurrency test

* Remove redundant uds test assertions
  • Loading branch information
lundberg authored and florimondmanca committed Nov 19, 2019
1 parent a5f9983 commit 7a96a2c
Show file tree
Hide file tree
Showing 10 changed files with 135 additions and 6 deletions.
3 changes: 3 additions & 0 deletions httpx/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def __init__(
app: typing.Callable = None,
backend: ConcurrencyBackend = None,
trust_env: bool = True,
uds: str = None,
):
if backend is None:
backend = AsyncioBackend()
Expand All @@ -99,6 +100,7 @@ def __init__(
pool_limits=pool_limits,
backend=backend,
trust_env=self.trust_env,
uds=uds,
)
elif isinstance(dispatch, Dispatcher):
async_dispatch = ThreadedDispatcher(dispatch, backend)
Expand Down Expand Up @@ -721,6 +723,7 @@ class Client(BaseClient):
async requests.
* **trust_env** - *(optional)* Enables or disables usage of environment
variables for configuration.
* **uds** - *(optional)* A path to a Unix domain socket to connect through.
"""

def check_concurrency_backend(self, backend: ConcurrencyBackend) -> None:
Expand Down
23 changes: 23 additions & 0 deletions httpx/concurrency/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,29 @@ async def open_tcp_stream(
stream_reader=stream_reader, stream_writer=stream_writer, timeout=timeout
)

async def open_uds_stream(
self,
path: str,
hostname: typing.Optional[str],
ssl_context: typing.Optional[ssl.SSLContext],
timeout: TimeoutConfig,
) -> SocketStream:
server_hostname = hostname if ssl_context else None

try:
stream_reader, stream_writer = await asyncio.wait_for( # type: ignore
asyncio.open_unix_connection(
path, ssl=ssl_context, server_hostname=server_hostname
),
timeout.connect_timeout,
)
except asyncio.TimeoutError:
raise ConnectTimeout()

return SocketStream(
stream_reader=stream_reader, stream_writer=stream_writer, timeout=timeout
)

async def run_in_threadpool(
self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
) -> typing.Any:
Expand Down
9 changes: 9 additions & 0 deletions httpx/concurrency/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,15 @@ async def open_tcp_stream(
) -> BaseSocketStream:
raise NotImplementedError() # pragma: no cover

async def open_uds_stream(
self,
path: str,
hostname: typing.Optional[str],
ssl_context: typing.Optional[ssl.SSLContext],
timeout: TimeoutConfig,
) -> BaseSocketStream:
raise NotImplementedError() # pragma: no cover

def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore:
raise NotImplementedError() # pragma: no cover

Expand Down
20 changes: 20 additions & 0 deletions httpx/concurrency/trio.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,26 @@ async def open_tcp_stream(

return SocketStream(stream=stream, timeout=timeout)

async def open_uds_stream(
self,
path: str,
hostname: typing.Optional[str],
ssl_context: typing.Optional[ssl.SSLContext],
timeout: TimeoutConfig,
) -> SocketStream:
connect_timeout = _or_inf(timeout.connect_timeout)

with trio.move_on_after(connect_timeout) as cancel_scope:
stream: trio.SocketStream = await trio.open_unix_socket(path)
if ssl_context is not None:
stream = trio.SSLStream(stream, ssl_context, server_hostname=hostname)
await stream.do_handshake()

if cancel_scope.cancelled_caught:
raise ConnectTimeout()

return SocketStream(stream=stream, timeout=timeout)

async def run_in_threadpool(
self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
) -> typing.Any:
Expand Down
19 changes: 17 additions & 2 deletions httpx/dispatch/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,15 @@ def __init__(
http_versions: HTTPVersionTypes = None,
backend: ConcurrencyBackend = None,
release_func: typing.Optional[ReleaseCallback] = None,
uds: typing.Optional[str] = None,
):
self.origin = Origin(origin) if isinstance(origin, str) else origin
self.ssl = SSLConfig(cert=cert, verify=verify, trust_env=trust_env)
self.timeout = TimeoutConfig(timeout)
self.http_versions = HTTPVersionConfig(http_versions)
self.backend = AsyncioBackend() if backend is None else backend
self.release_func = release_func
self.uds = uds
self.h11_connection = None # type: typing.Optional[HTTP11Connection]
self.h2_connection = None # type: typing.Optional[HTTP2Connection]

Expand Down Expand Up @@ -84,8 +86,21 @@ async def connect(
else:
on_release = functools.partial(self.release_func, self)

logger.trace(f"start_connect host={host!r} port={port!r} timeout={timeout!r}")
stream = await self.backend.open_tcp_stream(host, port, ssl_context, timeout)
if self.uds is None:
logger.trace(
f"start_connect tcp host={host!r} port={port!r} timeout={timeout!r}"
)
stream = await self.backend.open_tcp_stream(
host, port, ssl_context, timeout
)
else:
logger.trace(
f"start_connect uds path={self.uds!r} host={host!r} timeout={timeout!r}"
)
stream = await self.backend.open_uds_stream(
self.uds, host, ssl_context, timeout
)

http_version = stream.get_http_version()
logger.trace(f"connected http_version={http_version!r}")

Expand Down
3 changes: 3 additions & 0 deletions httpx/dispatch/connection_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def __init__(
pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
http_versions: HTTPVersionTypes = None,
backend: ConcurrencyBackend = None,
uds: typing.Optional[str] = None,
):
self.verify = verify
self.cert = cert
Expand All @@ -97,6 +98,7 @@ def __init__(
self.http_versions = http_versions
self.is_closed = False
self.trust_env = trust_env
self.uds = uds

self.keepalive_connections = ConnectionStore()
self.active_connections = ConnectionStore()
Expand Down Expand Up @@ -142,6 +144,7 @@ async def acquire_connection(self, origin: Origin) -> HTTPConnection:
backend=self.backend,
release_func=self.release_connection,
trust_env=self.trust_env,
uds=self.uds,
)
logger.trace(f"new_connection connection={connection!r}")
else:
Expand Down
11 changes: 11 additions & 0 deletions tests/client/test_async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,3 +146,14 @@ async def test_100_continue(server, backend):

assert response.status_code == 200
assert response.content == data


async def test_uds(uds_server, backend):
url = uds_server.url
uds = uds_server.config.uds
assert uds is not None
async with httpx.AsyncClient(backend=backend, uds=uds) as client:
response = await client.get(url)
assert response.status_code == 200
assert response.text == "Hello, world!"
assert response.encoding == "iso-8859-1"
11 changes: 11 additions & 0 deletions tests/client/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,17 @@ def test_base_url(server):
assert response.url == base_url


def test_uds(uds_server):
url = uds_server.url
uds = uds_server.config.uds
assert uds is not None
with httpx.Client(uds=uds) as http:
response = http.get(url)
assert response.status_code == 200
assert response.text == "Hello, world!"
assert response.encoding == "iso-8859-1"


def test_merge_url():
client = httpx.Client(base_url="https://www.paypal.com/")
url = client.merge_url("http://www.paypal.com")
Expand Down
25 changes: 25 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,15 @@ def server():
yield from serve_in_thread(server)


@pytest.fixture(scope=SERVER_SCOPE)
def uds_server():
uds = "test_server.sock"
config = Config(app=app, lifespan="off", loop="asyncio", uds=uds)
server = TestServer(config=config)
yield from serve_in_thread(server)
os.remove(uds)


@pytest.fixture(scope=SERVER_SCOPE)
def https_server(cert_pem_file, cert_private_key_file):
config = Config(
Expand All @@ -301,3 +310,19 @@ def https_server(cert_pem_file, cert_private_key_file):
)
server = TestServer(config=config)
yield from serve_in_thread(server)


@pytest.fixture(scope=SERVER_SCOPE)
def https_uds_server(cert_pem_file, cert_private_key_file):
uds = "https_test_server.sock"
config = Config(
app=app,
lifespan="off",
ssl_certfile=cert_pem_file,
ssl_keyfile=cert_private_key_file,
uds=uds,
loop="asyncio",
)
server = TestServer(config=config)
yield from serve_in_thread(server)
os.remove(uds)
17 changes: 13 additions & 4 deletions tests/test_concurrency.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,26 @@
),
],
)
async def test_start_tls_on_socket_stream(https_server, backend, get_cipher):
@pytest.mark.parametrize("use_uds", (False, True))
async def test_start_tls_on_socket_stream(
https_server, https_uds_server, backend, get_cipher, use_uds
):
"""
See that the concurrency backend can make a connection without TLS then
start TLS on an existing connection.
"""
ctx = SSLConfig().load_ssl_context_no_verify(HTTPVersionConfig())
timeout = TimeoutConfig(5)

stream = await backend.open_tcp_stream(
https_server.url.host, https_server.url.port, None, timeout
)
if use_uds:
assert https_uds_server.config.uds is not None
stream = await backend.open_uds_stream(
https_uds_server.config.uds, https_uds_server.url.host, None, timeout
)
else:
stream = await backend.open_tcp_stream(
https_server.url.host, https_server.url.port, None, timeout
)

try:
assert stream.is_connection_dropped() is False
Expand Down

0 comments on commit 7a96a2c

Please sign in to comment.