Skip to content

Commit

Permalink
Add Unix Domain Socket support (#139)
Browse files Browse the repository at this point in the history
* Add Unix Domain Socket support

* Update tests

* Add uvicorn dep

* Added newline
  • Loading branch information
florimondmanca authored Aug 11, 2020
1 parent bf88f29 commit 2328b0c
Show file tree
Hide file tree
Showing 13 changed files with 215 additions and 6 deletions.
17 changes: 14 additions & 3 deletions httpcore/_async/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,14 @@ def __init__(
self,
origin: Origin,
http2: bool = False,
uds: str = None,
ssl_context: SSLContext = None,
socket: AsyncSocketStream = None,
local_address: str = None,
):
self.origin = origin
self.http2 = http2
self.uds = uds
self.ssl_context = SSLContext() if ssl_context is None else ssl_context
self.socket = socket
self.local_address = local_address
Expand Down Expand Up @@ -98,9 +100,18 @@ async def _open_socket(self, timeout: TimeoutDict = None) -> AsyncSocketStream:
timeout = {} if timeout is None else timeout
ssl_context = self.ssl_context if scheme == b"https" else None
try:
return await self.backend.open_tcp_stream(
hostname, port, ssl_context, timeout, local_address=self.local_address
)
if self.uds is None:
return await self.backend.open_tcp_stream(
hostname,
port,
ssl_context,
timeout,
local_address=self.local_address,
)
else:
return await self.backend.open_uds_stream(
self.uds, hostname, ssl_context, timeout
)
except Exception:
self.connect_failed = True
raise
Expand Down
4 changes: 4 additions & 0 deletions httpcore/_async/connection_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ class AsyncConnectionPool(AsyncHTTPTransport):
* **keepalive_expiry** - `Optional[float]` - The maximum time to allow
before closing a keep-alive connection.
* **http2** - `bool` - Enable HTTP/2 support.
* **uds** - `str` - Path to a Unix Domain Socket to use instead of TCP sockets.
* **local_address** - `Optional[str]` - Local address to connect from. Can
also be used to connect using a particular address family. Using
`local_address="0.0.0.0"` will connect using an `AF_INET` address (IPv4),
Expand All @@ -91,6 +92,7 @@ def __init__(
max_keepalive_connections: int = None,
keepalive_expiry: float = None,
http2: bool = False,
uds: str = None,
local_address: str = None,
max_keepalive: int = None,
):
Expand All @@ -106,6 +108,7 @@ def __init__(
self._max_keepalive = max_keepalive
self._keepalive_expiry = keepalive_expiry
self._http2 = http2
self._uds = uds
self._local_address = local_address
self._connections: Dict[Origin, Set[AsyncHTTPConnection]] = {}
self._thread_lock = ThreadLock()
Expand Down Expand Up @@ -172,6 +175,7 @@ async def request(
connection = AsyncHTTPConnection(
origin=origin,
http2=self._http2,
uds=self._uds,
ssl_context=self._ssl_context,
local_address=self._local_address,
)
Expand Down
20 changes: 20 additions & 0 deletions httpcore/_backends/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,26 @@ async def open_tcp_stream(
stream_reader=stream_reader, stream_writer=stream_writer
)

async def open_uds_stream(
self,
path: str,
hostname: bytes,
ssl_context: Optional[SSLContext],
timeout: TimeoutDict,
) -> AsyncSocketStream:
host = hostname.decode("ascii")
connect_timeout = timeout.get("connect")
kwargs: dict = {"server_hostname": host} if ssl_context is not None else {}
exc_map = {asyncio.TimeoutError: ConnectTimeout, OSError: ConnectError}
with map_exceptions(exc_map):
stream_reader, stream_writer = await asyncio.wait_for(
asyncio.open_unix_connection(path, ssl=ssl_context, **kwargs),
connect_timeout,
)
return SocketStream(
stream_reader=stream_reader, stream_writer=stream_writer
)

def create_lock(self) -> AsyncLock:
return Lock()

Expand Down
9 changes: 9 additions & 0 deletions httpcore/_backends/auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,15 @@ async def open_tcp_stream(
hostname, port, ssl_context, timeout, local_address=local_address
)

async def open_uds_stream(
self,
path: str,
hostname: bytes,
ssl_context: Optional[SSLContext],
timeout: TimeoutDict,
) -> AsyncSocketStream:
return await self.backend.open_uds_stream(path, hostname, ssl_context, timeout)

def create_lock(self) -> AsyncLock:
return self.backend.create_lock()

Expand Down
9 changes: 9 additions & 0 deletions httpcore/_backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,15 @@ async def open_tcp_stream(
) -> AsyncSocketStream:
raise NotImplementedError() # pragma: no cover

async def open_uds_stream(
self,
path: str,
hostname: bytes,
ssl_context: Optional[SSLContext],
timeout: TimeoutDict,
) -> AsyncSocketStream:
raise NotImplementedError() # pragma: no cover

def create_lock(self) -> AsyncLock:
raise NotImplementedError() # pragma: no cover

Expand Down
22 changes: 22 additions & 0 deletions httpcore/_backends/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,28 @@ def open_tcp_stream(
)
return SyncSocketStream(sock=sock)

def open_uds_stream(
self,
path: str,
hostname: bytes,
ssl_context: Optional[SSLContext],
timeout: TimeoutDict,
) -> SyncSocketStream:
connect_timeout = timeout.get("connect")
exc_map = {socket.timeout: ConnectTimeout, socket.error: ConnectError}

with map_exceptions(exc_map):
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
sock.settimeout(connect_timeout)
sock.connect(path)

if ssl_context is not None:
sock = ssl_context.wrap_socket(
sock, server_hostname=hostname.decode("ascii")
)

return SyncSocketStream(sock=sock)

def create_lock(self) -> SyncLock:
return SyncLock()

Expand Down
25 changes: 25 additions & 0 deletions httpcore/_backends/trio.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,31 @@ async def open_tcp_stream(

return SocketStream(stream=stream)

async def open_uds_stream(
self,
path: str,
hostname: bytes,
ssl_context: Optional[SSLContext],
timeout: TimeoutDict,
) -> AsyncSocketStream:
connect_timeout = none_as_inf(timeout.get("connect"))
exc_map = {
trio.TooSlowError: ConnectTimeout,
trio.BrokenResourceError: ConnectError,
}

with map_exceptions(exc_map):
with trio.fail_after(connect_timeout):
stream: trio.abc.Stream = await trio.open_unix_socket(path)

if ssl_context is not None:
stream = trio.SSLStream(
stream, ssl_context, server_hostname=hostname.decode("ascii")
)
await stream.do_handshake()

return SocketStream(stream=stream)

def create_lock(self) -> AsyncLock:
return Lock()

Expand Down
17 changes: 14 additions & 3 deletions httpcore/_sync/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,14 @@ def __init__(
self,
origin: Origin,
http2: bool = False,
uds: str = None,
ssl_context: SSLContext = None,
socket: SyncSocketStream = None,
local_address: str = None,
):
self.origin = origin
self.http2 = http2
self.uds = uds
self.ssl_context = SSLContext() if ssl_context is None else ssl_context
self.socket = socket
self.local_address = local_address
Expand Down Expand Up @@ -98,9 +100,18 @@ def _open_socket(self, timeout: TimeoutDict = None) -> SyncSocketStream:
timeout = {} if timeout is None else timeout
ssl_context = self.ssl_context if scheme == b"https" else None
try:
return self.backend.open_tcp_stream(
hostname, port, ssl_context, timeout, local_address=self.local_address
)
if self.uds is None:
return self.backend.open_tcp_stream(
hostname,
port,
ssl_context,
timeout,
local_address=self.local_address,
)
else:
return self.backend.open_uds_stream(
self.uds, hostname, ssl_context, timeout
)
except Exception:
self.connect_failed = True
raise
Expand Down
4 changes: 4 additions & 0 deletions httpcore/_sync/connection_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ class SyncConnectionPool(SyncHTTPTransport):
* **keepalive_expiry** - `Optional[float]` - The maximum time to allow
before closing a keep-alive connection.
* **http2** - `bool` - Enable HTTP/2 support.
* **uds** - `str` - Path to a Unix Domain Socket to use instead of TCP sockets.
* **local_address** - `Optional[str]` - Local address to connect from. Can
also be used to connect using a particular address family. Using
`local_address="0.0.0.0"` will connect using an `AF_INET` address (IPv4),
Expand All @@ -91,6 +92,7 @@ def __init__(
max_keepalive_connections: int = None,
keepalive_expiry: float = None,
http2: bool = False,
uds: str = None,
local_address: str = None,
max_keepalive: int = None,
):
Expand All @@ -106,6 +108,7 @@ def __init__(
self._max_keepalive = max_keepalive
self._keepalive_expiry = keepalive_expiry
self._http2 = http2
self._uds = uds
self._local_address = local_address
self._connections: Dict[Origin, Set[SyncHTTPConnection]] = {}
self._thread_lock = ThreadLock()
Expand Down Expand Up @@ -172,6 +175,7 @@ def request(
connection = SyncHTTPConnection(
origin=origin,
http2=self._http2,
uds=self._uds,
ssl_context=self._ssl_context,
local_address=self._local_address,
)
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,4 @@ mypy
isort==5.*
mitmproxy
trustme
uvicorn
24 changes: 24 additions & 0 deletions tests/async_tests/test_interfaces.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import ssl
import platform
from pathlib import Path

import pytest

Expand Down Expand Up @@ -304,3 +306,25 @@ async def test_connection_pool_get_connection_info(

stats = await http.get_connection_info()
assert stats == {}


@pytest.mark.skipif(
platform.system() not in ("Linux", "Darwin"),
reason="Unix Domain Sockets only exist on Unix",
)
@pytest.mark.usefixtures("async_environment")
async def test_http_request_unix_domain_socket(uds_server) -> None:
uds = uds_server.config.uds
assert uds is not None
async with httpcore.AsyncConnectionPool(uds=uds) as http:
method = b"GET"
url = (b"http", b"localhost", None, b"/")
headers = [(b"host", b"localhost")]
http_version, status_code, reason, headers, stream = await http.request(
method, url, headers
)
assert http_version == b"HTTP/1.1"
assert status_code == 200
assert reason == b"OK"
body = await read_body(stream)
assert body == b"Hello, world!"
45 changes: 45 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@
import ssl
import threading
import typing
import contextlib
import time

import os
import uvicorn
import pytest
import trustme
from mitmproxy import options, proxy
Expand Down Expand Up @@ -120,3 +124,44 @@ def proxy_server(example_org_cert_path: str) -> typing.Iterator[URL]:
yield (b"http", PROXY_HOST.encode(), PROXY_PORT, b"/")
finally:
thread.join()


class Server(uvicorn.Server):
def install_signal_handlers(self) -> None:
pass

@contextlib.contextmanager
def serve_in_thread(self) -> typing.Iterator[None]:
thread = threading.Thread(target=self.run)
thread.start()
try:
while not self.started:
time.sleep(1e-3)
yield
finally:
self.should_exit = True
thread.join()


async def app(scope: dict, receive: typing.Callable, send: typing.Callable) -> None:
assert scope["type"] == "http"
await send(
{
"type": "http.response.start",
"status": 200,
"headers": [[b"content-type", b"text/plain"]],
}
)
await send({"type": "http.response.body", "body": b"Hello, world!"})


@pytest.fixture(scope="session")
def uds_server() -> typing.Iterator[Server]:
uds = "test_server.sock"
config = uvicorn.Config(app=app, lifespan="off", loop="asyncio", uds=uds)
server = Server(config=config)
try:
with server.serve_in_thread():
yield server
finally:
os.remove(uds)
24 changes: 24 additions & 0 deletions tests/sync_tests/test_interfaces.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import ssl
import platform
from pathlib import Path

import pytest

Expand Down Expand Up @@ -304,3 +306,25 @@ def test_connection_pool_get_connection_info(

stats = http.get_connection_info()
assert stats == {}


@pytest.mark.skipif(
platform.system() not in ("Linux", "Darwin"),
reason="Unix Domain Sockets only exist on Unix",
)

def test_http_request_unix_domain_socket(uds_server) -> None:
uds = uds_server.config.uds
assert uds is not None
with httpcore.SyncConnectionPool(uds=uds) as http:
method = b"GET"
url = (b"http", b"localhost", None, b"/")
headers = [(b"host", b"localhost")]
http_version, status_code, reason, headers, stream = http.request(
method, url, headers
)
assert http_version == b"HTTP/1.1"
assert status_code == 200
assert reason == b"OK"
body = read_body(stream)
assert body == b"Hello, world!"

0 comments on commit 2328b0c

Please sign in to comment.