diff --git a/httpcore/_async/connection.py b/httpcore/_async/connection.py index 8ce3d545..95ddb526 100644 --- a/httpcore/_async/connection.py +++ b/httpcore/_async/connection.py @@ -20,12 +20,14 @@ def __init__( self, origin: Origin, http2: bool = False, + uds: str = None, ssl_context: SSLContext = None, socket: AsyncSocketStream = None, local_address: str = None, ): self.origin = origin self.http2 = http2 + self.uds = uds self.ssl_context = SSLContext() if ssl_context is None else ssl_context self.socket = socket self.local_address = local_address @@ -98,9 +100,18 @@ async def _open_socket(self, timeout: TimeoutDict = None) -> AsyncSocketStream: timeout = {} if timeout is None else timeout ssl_context = self.ssl_context if scheme == b"https" else None try: - return await self.backend.open_tcp_stream( - hostname, port, ssl_context, timeout, local_address=self.local_address - ) + if self.uds is None: + return await self.backend.open_tcp_stream( + hostname, + port, + ssl_context, + timeout, + local_address=self.local_address, + ) + else: + return await self.backend.open_uds_stream( + self.uds, hostname, ssl_context, timeout + ) except Exception: self.connect_failed = True raise diff --git a/httpcore/_async/connection_pool.py b/httpcore/_async/connection_pool.py index 9ce7ad43..ce2bb439 100644 --- a/httpcore/_async/connection_pool.py +++ b/httpcore/_async/connection_pool.py @@ -77,6 +77,7 @@ class AsyncConnectionPool(AsyncHTTPTransport): * **keepalive_expiry** - `Optional[float]` - The maximum time to allow before closing a keep-alive connection. * **http2** - `bool` - Enable HTTP/2 support. + * **uds** - `str` - Path to a Unix Domain Socket to use instead of TCP sockets. * **local_address** - `Optional[str]` - Local address to connect from. Can also be used to connect using a particular address family. Using `local_address="0.0.0.0"` will connect using an `AF_INET` address (IPv4), @@ -91,6 +92,7 @@ def __init__( max_keepalive_connections: int = None, keepalive_expiry: float = None, http2: bool = False, + uds: str = None, local_address: str = None, max_keepalive: int = None, ): @@ -106,6 +108,7 @@ def __init__( self._max_keepalive = max_keepalive self._keepalive_expiry = keepalive_expiry self._http2 = http2 + self._uds = uds self._local_address = local_address self._connections: Dict[Origin, Set[AsyncHTTPConnection]] = {} self._thread_lock = ThreadLock() @@ -172,6 +175,7 @@ async def request( connection = AsyncHTTPConnection( origin=origin, http2=self._http2, + uds=self._uds, ssl_context=self._ssl_context, local_address=self._local_address, ) diff --git a/httpcore/_backends/asyncio.py b/httpcore/_backends/asyncio.py index 1364e54b..5af810a9 100644 --- a/httpcore/_backends/asyncio.py +++ b/httpcore/_backends/asyncio.py @@ -244,6 +244,26 @@ async def open_tcp_stream( stream_reader=stream_reader, stream_writer=stream_writer ) + async def open_uds_stream( + self, + path: str, + hostname: bytes, + ssl_context: Optional[SSLContext], + timeout: TimeoutDict, + ) -> AsyncSocketStream: + host = hostname.decode("ascii") + connect_timeout = timeout.get("connect") + kwargs: dict = {"server_hostname": host} if ssl_context is not None else {} + exc_map = {asyncio.TimeoutError: ConnectTimeout, OSError: ConnectError} + with map_exceptions(exc_map): + stream_reader, stream_writer = await asyncio.wait_for( + asyncio.open_unix_connection(path, ssl=ssl_context, **kwargs), + connect_timeout, + ) + return SocketStream( + stream_reader=stream_reader, stream_writer=stream_writer + ) + def create_lock(self) -> AsyncLock: return Lock() diff --git a/httpcore/_backends/auto.py b/httpcore/_backends/auto.py index ee06a05a..19bc62c8 100644 --- a/httpcore/_backends/auto.py +++ b/httpcore/_backends/auto.py @@ -41,6 +41,15 @@ async def open_tcp_stream( hostname, port, ssl_context, timeout, local_address=local_address ) + async def open_uds_stream( + self, + path: str, + hostname: bytes, + ssl_context: Optional[SSLContext], + timeout: TimeoutDict, + ) -> AsyncSocketStream: + return await self.backend.open_uds_stream(path, hostname, ssl_context, timeout) + def create_lock(self) -> AsyncLock: return self.backend.create_lock() diff --git a/httpcore/_backends/base.py b/httpcore/_backends/base.py index 9bb05af3..1c80156d 100644 --- a/httpcore/_backends/base.py +++ b/httpcore/_backends/base.py @@ -81,6 +81,15 @@ async def open_tcp_stream( ) -> AsyncSocketStream: raise NotImplementedError() # pragma: no cover + async def open_uds_stream( + self, + path: str, + hostname: bytes, + ssl_context: Optional[SSLContext], + timeout: TimeoutDict, + ) -> AsyncSocketStream: + raise NotImplementedError() # pragma: no cover + def create_lock(self) -> AsyncLock: raise NotImplementedError() # pragma: no cover diff --git a/httpcore/_backends/sync.py b/httpcore/_backends/sync.py index 69f03b25..8e9a7bbe 100644 --- a/httpcore/_backends/sync.py +++ b/httpcore/_backends/sync.py @@ -147,6 +147,28 @@ def open_tcp_stream( ) return SyncSocketStream(sock=sock) + def open_uds_stream( + self, + path: str, + hostname: bytes, + ssl_context: Optional[SSLContext], + timeout: TimeoutDict, + ) -> SyncSocketStream: + connect_timeout = timeout.get("connect") + exc_map = {socket.timeout: ConnectTimeout, socket.error: ConnectError} + + with map_exceptions(exc_map): + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + sock.settimeout(connect_timeout) + sock.connect(path) + + if ssl_context is not None: + sock = ssl_context.wrap_socket( + sock, server_hostname=hostname.decode("ascii") + ) + + return SyncSocketStream(sock=sock) + def create_lock(self) -> SyncLock: return SyncLock() diff --git a/httpcore/_backends/trio.py b/httpcore/_backends/trio.py index 91e0cf44..476bb8df 100644 --- a/httpcore/_backends/trio.py +++ b/httpcore/_backends/trio.py @@ -170,6 +170,31 @@ async def open_tcp_stream( return SocketStream(stream=stream) + async def open_uds_stream( + self, + path: str, + hostname: bytes, + ssl_context: Optional[SSLContext], + timeout: TimeoutDict, + ) -> AsyncSocketStream: + connect_timeout = none_as_inf(timeout.get("connect")) + exc_map = { + trio.TooSlowError: ConnectTimeout, + trio.BrokenResourceError: ConnectError, + } + + with map_exceptions(exc_map): + with trio.fail_after(connect_timeout): + stream: trio.abc.Stream = await trio.open_unix_socket(path) + + if ssl_context is not None: + stream = trio.SSLStream( + stream, ssl_context, server_hostname=hostname.decode("ascii") + ) + await stream.do_handshake() + + return SocketStream(stream=stream) + def create_lock(self) -> AsyncLock: return Lock() diff --git a/httpcore/_sync/connection.py b/httpcore/_sync/connection.py index 529a8193..bc7d1894 100644 --- a/httpcore/_sync/connection.py +++ b/httpcore/_sync/connection.py @@ -20,12 +20,14 @@ def __init__( self, origin: Origin, http2: bool = False, + uds: str = None, ssl_context: SSLContext = None, socket: SyncSocketStream = None, local_address: str = None, ): self.origin = origin self.http2 = http2 + self.uds = uds self.ssl_context = SSLContext() if ssl_context is None else ssl_context self.socket = socket self.local_address = local_address @@ -98,9 +100,18 @@ def _open_socket(self, timeout: TimeoutDict = None) -> SyncSocketStream: timeout = {} if timeout is None else timeout ssl_context = self.ssl_context if scheme == b"https" else None try: - return self.backend.open_tcp_stream( - hostname, port, ssl_context, timeout, local_address=self.local_address - ) + if self.uds is None: + return self.backend.open_tcp_stream( + hostname, + port, + ssl_context, + timeout, + local_address=self.local_address, + ) + else: + return self.backend.open_uds_stream( + self.uds, hostname, ssl_context, timeout + ) except Exception: self.connect_failed = True raise diff --git a/httpcore/_sync/connection_pool.py b/httpcore/_sync/connection_pool.py index 7a1e8c93..0270fd46 100644 --- a/httpcore/_sync/connection_pool.py +++ b/httpcore/_sync/connection_pool.py @@ -77,6 +77,7 @@ class SyncConnectionPool(SyncHTTPTransport): * **keepalive_expiry** - `Optional[float]` - The maximum time to allow before closing a keep-alive connection. * **http2** - `bool` - Enable HTTP/2 support. + * **uds** - `str` - Path to a Unix Domain Socket to use instead of TCP sockets. * **local_address** - `Optional[str]` - Local address to connect from. Can also be used to connect using a particular address family. Using `local_address="0.0.0.0"` will connect using an `AF_INET` address (IPv4), @@ -91,6 +92,7 @@ def __init__( max_keepalive_connections: int = None, keepalive_expiry: float = None, http2: bool = False, + uds: str = None, local_address: str = None, max_keepalive: int = None, ): @@ -106,6 +108,7 @@ def __init__( self._max_keepalive = max_keepalive self._keepalive_expiry = keepalive_expiry self._http2 = http2 + self._uds = uds self._local_address = local_address self._connections: Dict[Origin, Set[SyncHTTPConnection]] = {} self._thread_lock = ThreadLock() @@ -172,6 +175,7 @@ def request( connection = SyncHTTPConnection( origin=origin, http2=self._http2, + uds=self._uds, ssl_context=self._ssl_context, local_address=self._local_address, ) diff --git a/requirements.txt b/requirements.txt index f0f87640..5dce2bd4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -24,3 +24,4 @@ mypy isort==5.* mitmproxy trustme +uvicorn diff --git a/tests/async_tests/test_interfaces.py b/tests/async_tests/test_interfaces.py index 8e11ed03..94635b3c 100644 --- a/tests/async_tests/test_interfaces.py +++ b/tests/async_tests/test_interfaces.py @@ -1,4 +1,6 @@ import ssl +import platform +from pathlib import Path import pytest @@ -304,3 +306,25 @@ async def test_connection_pool_get_connection_info( stats = await http.get_connection_info() assert stats == {} + + +@pytest.mark.skipif( + platform.system() not in ("Linux", "Darwin"), + reason="Unix Domain Sockets only exist on Unix", +) +@pytest.mark.usefixtures("async_environment") +async def test_http_request_unix_domain_socket(uds_server) -> None: + uds = uds_server.config.uds + assert uds is not None + async with httpcore.AsyncConnectionPool(uds=uds) as http: + method = b"GET" + url = (b"http", b"localhost", None, b"/") + headers = [(b"host", b"localhost")] + http_version, status_code, reason, headers, stream = await http.request( + method, url, headers + ) + assert http_version == b"HTTP/1.1" + assert status_code == 200 + assert reason == b"OK" + body = await read_body(stream) + assert body == b"Hello, world!" diff --git a/tests/conftest.py b/tests/conftest.py index b20c56b9..9a64c425 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,7 +2,11 @@ import ssl import threading import typing +import contextlib +import time +import os +import uvicorn import pytest import trustme from mitmproxy import options, proxy @@ -120,3 +124,44 @@ def proxy_server(example_org_cert_path: str) -> typing.Iterator[URL]: yield (b"http", PROXY_HOST.encode(), PROXY_PORT, b"/") finally: thread.join() + + +class Server(uvicorn.Server): + def install_signal_handlers(self) -> None: + pass + + @contextlib.contextmanager + def serve_in_thread(self) -> typing.Iterator[None]: + thread = threading.Thread(target=self.run) + thread.start() + try: + while not self.started: + time.sleep(1e-3) + yield + finally: + self.should_exit = True + thread.join() + + +async def app(scope: dict, receive: typing.Callable, send: typing.Callable) -> None: + assert scope["type"] == "http" + await send( + { + "type": "http.response.start", + "status": 200, + "headers": [[b"content-type", b"text/plain"]], + } + ) + await send({"type": "http.response.body", "body": b"Hello, world!"}) + + +@pytest.fixture(scope="session") +def uds_server() -> typing.Iterator[Server]: + uds = "test_server.sock" + config = uvicorn.Config(app=app, lifespan="off", loop="asyncio", uds=uds) + server = Server(config=config) + try: + with server.serve_in_thread(): + yield server + finally: + os.remove(uds) diff --git a/tests/sync_tests/test_interfaces.py b/tests/sync_tests/test_interfaces.py index ac3d8eca..c000a644 100644 --- a/tests/sync_tests/test_interfaces.py +++ b/tests/sync_tests/test_interfaces.py @@ -1,4 +1,6 @@ import ssl +import platform +from pathlib import Path import pytest @@ -304,3 +306,25 @@ def test_connection_pool_get_connection_info( stats = http.get_connection_info() assert stats == {} + + +@pytest.mark.skipif( + platform.system() not in ("Linux", "Darwin"), + reason="Unix Domain Sockets only exist on Unix", +) + +def test_http_request_unix_domain_socket(uds_server) -> None: + uds = uds_server.config.uds + assert uds is not None + with httpcore.SyncConnectionPool(uds=uds) as http: + method = b"GET" + url = (b"http", b"localhost", None, b"/") + headers = [(b"host", b"localhost")] + http_version, status_code, reason, headers, stream = http.request( + method, url, headers + ) + assert http_version == b"HTTP/1.1" + assert status_code == 200 + assert reason == b"OK" + body = read_body(stream) + assert body == b"Hello, world!"