Skip to content

Commit

Permalink
Add uds arg to BaseClient and select tcp or uds in HttpConnection
Browse files Browse the repository at this point in the history
  • Loading branch information
lundberg committed Nov 8, 2019
1 parent 1e40664 commit c066d3e
Show file tree
Hide file tree
Showing 7 changed files with 69 additions and 4 deletions.
3 changes: 3 additions & 0 deletions httpx/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def __init__(
app: typing.Callable = None,
backend: ConcurrencyBackend = None,
trust_env: bool = True,
uds: str = None,
):
if backend is None:
backend = AsyncioBackend()
Expand All @@ -99,6 +100,7 @@ def __init__(
pool_limits=pool_limits,
backend=backend,
trust_env=self.trust_env,
uds=uds,
)
elif isinstance(dispatch, Dispatcher):
async_dispatch = ThreadedDispatcher(dispatch, backend)
Expand Down Expand Up @@ -721,6 +723,7 @@ class Client(BaseClient):
async requests.
* **trust_env** - *(optional)* Enables or disables usage of environment
variables for configuration.
* **uds** - *(optional)* A path to a Unix domain socket to connect through
"""

def check_concurrency_backend(self, backend: ConcurrencyBackend) -> None:
Expand Down
2 changes: 1 addition & 1 deletion httpx/concurrency/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ async def open_uds_stream(
hostname: typing.Optional[str],
ssl_context: typing.Optional[ssl.SSLContext],
timeout: TimeoutConfig,
) -> BaseTCPStream:
) -> BaseSocketStream:
raise NotImplementedError() # pragma: no cover

def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore:
Expand Down
19 changes: 17 additions & 2 deletions httpx/dispatch/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,15 @@ def __init__(
http_versions: HTTPVersionTypes = None,
backend: ConcurrencyBackend = None,
release_func: typing.Optional[ReleaseCallback] = None,
uds: typing.Optional[str] = None,
):
self.origin = Origin(origin) if isinstance(origin, str) else origin
self.ssl = SSLConfig(cert=cert, verify=verify, trust_env=trust_env)
self.timeout = TimeoutConfig(timeout)
self.http_versions = HTTPVersionConfig(http_versions)
self.backend = AsyncioBackend() if backend is None else backend
self.release_func = release_func
self.uds = uds
self.h11_connection = None # type: typing.Optional[HTTP11Connection]
self.h2_connection = None # type: typing.Optional[HTTP2Connection]

Expand Down Expand Up @@ -84,8 +86,21 @@ async def connect(
else:
on_release = functools.partial(self.release_func, self)

logger.trace(f"start_connect host={host!r} port={port!r} timeout={timeout!r}")
stream = await self.backend.open_tcp_stream(host, port, ssl_context, timeout)
if self.uds is None:
logger.trace(
f"start_connect tcp host={host!r} port={port!r} timeout={timeout!r}"
)
stream = await self.backend.open_tcp_stream(
host, port, ssl_context, timeout
)
else:
logger.trace(
f"start_connect uds path={self.uds!r} host={host!r} timeout={timeout!r}"
)
stream = await self.backend.open_uds_stream(
self.uds, host, ssl_context, timeout
)

http_version = stream.get_http_version()
logger.trace(f"connected http_version={http_version!r}")

Expand Down
3 changes: 3 additions & 0 deletions httpx/dispatch/connection_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def __init__(
pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
http_versions: HTTPVersionTypes = None,
backend: ConcurrencyBackend = None,
uds: typing.Optional[str] = None,
):
self.verify = verify
self.cert = cert
Expand All @@ -97,6 +98,7 @@ def __init__(
self.http_versions = http_versions
self.is_closed = False
self.trust_env = trust_env
self.uds = uds

self.keepalive_connections = ConnectionStore()
self.active_connections = ConnectionStore()
Expand Down Expand Up @@ -142,6 +144,7 @@ async def acquire_connection(self, origin: Origin) -> HTTPConnection:
backend=self.backend,
release_func=self.release_connection,
trust_env=self.trust_env,
uds=self.uds,
)
logger.trace(f"new_connection connection={connection!r}")
else:
Expand Down
14 changes: 14 additions & 0 deletions tests/client/test_async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,3 +146,17 @@ async def test_100_continue(server, backend):

assert response.status_code == 200
assert response.content == data


async def test_uds(uds_server, backend):
url = uds_server.url
uds = uds_server.config.uds
assert uds is not None
async with httpx.AsyncClient(backend=backend, uds=uds) as client:
response = await client.get(url)
assert response.status_code == 200
assert response.text == "Hello, world!"
assert response.http_version == "HTTP/1.1"
assert response.headers
assert repr(response) == "<Response [200 OK]>"
assert response.elapsed > timedelta(seconds=0)
19 changes: 19 additions & 0 deletions tests/client/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,25 @@ def test_base_url(server):
assert response.url == base_url


def test_uds(uds_server):
url = uds_server.url
uds = uds_server.config.uds
assert uds is not None
with httpx.Client(uds=uds) as http:
response = http.get(url)
assert response.status_code == 200
assert response.url == url
assert response.content == b"Hello, world!"
assert response.text == "Hello, world!"
assert response.http_version == "HTTP/1.1"
assert response.encoding == "iso-8859-1"
assert response.request.url == url
assert response.headers
assert response.is_redirect is False
assert repr(response) == "<Response [200 OK]>"
assert response.elapsed > timedelta(0)


def test_merge_url():
client = httpx.Client(base_url="https://www.paypal.com/")
url = client.merge_url("http://www.paypal.com")
Expand Down
13 changes: 12 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,15 @@ def server():
yield from serve_in_thread(server)


@pytest.fixture(scope=SERVER_SCOPE)
def uds_server():
uds = "test_server.sock"
config = Config(app=app, lifespan="off", loop="asyncio", uds=uds)
server = TestServer(config=config)
yield from serve_in_thread(server)
os.remove(uds)


@pytest.fixture(scope=SERVER_SCOPE)
def https_server(cert_pem_file, cert_private_key_file):
config = Config(
Expand All @@ -305,13 +314,15 @@ def https_server(cert_pem_file, cert_private_key_file):

@pytest.fixture(scope=SERVER_SCOPE)
def https_uds_server(cert_pem_file, cert_private_key_file):
uds = "https_test_server.sock"
config = Config(
app=app,
lifespan="off",
ssl_certfile=cert_pem_file,
ssl_keyfile=cert_private_key_file,
uds="https_test_server.sock",
uds=uds,
loop="asyncio",
)
server = TestServer(config=config)
yield from serve_in_thread(server)
os.remove(uds)

0 comments on commit c066d3e

Please sign in to comment.