Skip to content

Commit

Permalink
Make open stream methods in backends more explicit
Browse files Browse the repository at this point in the history
  • Loading branch information
lundberg committed Nov 18, 2019
1 parent fdefc2a commit 1b0f54c
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 32 deletions.
30 changes: 14 additions & 16 deletions httpx/concurrency/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,8 +263,16 @@ async def open_tcp_stream(
ssl_context: typing.Optional[ssl.SSLContext],
timeout: TimeoutConfig,
) -> SocketStream:
return await self._open_stream(
asyncio.open_connection(hostname, port, ssl=ssl_context), timeout
try:
stream_reader, stream_writer = await asyncio.wait_for( # type: ignore
asyncio.open_connection(hostname, port, ssl=ssl_context),
timeout.connect_timeout,
)
except asyncio.TimeoutError:
raise ConnectTimeout()

return SocketStream(
stream_reader=stream_reader, stream_writer=stream_writer, timeout=timeout
)

async def open_uds_stream(
Expand All @@ -275,23 +283,13 @@ async def open_uds_stream(
timeout: TimeoutConfig,
) -> SocketStream:
server_hostname = hostname if ssl_context else None
return await self._open_stream(
asyncio.open_unix_connection(
path, ssl=ssl_context, server_hostname=server_hostname
),
timeout,
)

async def _open_stream(
self,
socket_stream: typing.Awaitable[
typing.Tuple[asyncio.StreamReader, asyncio.StreamWriter]
],
timeout: TimeoutConfig,
) -> SocketStream:
try:
stream_reader, stream_writer = await asyncio.wait_for( # type: ignore
socket_stream, timeout.connect_timeout,
asyncio.open_unix_connection(
path, ssl=ssl_context, server_hostname=server_hostname
),
timeout.connect_timeout,
)
except asyncio.TimeoutError:
raise ConnectTimeout()
Expand Down
29 changes: 13 additions & 16 deletions httpx/concurrency/trio.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,33 +178,30 @@ async def open_tcp_stream(
ssl_context: typing.Optional[ssl.SSLContext],
timeout: TimeoutConfig,
) -> SocketStream:
return await self._open_stream(
trio.open_tcp_stream(hostname, port), hostname, ssl_context, timeout
)
connect_timeout = _or_inf(timeout.connect_timeout)

with trio.move_on_after(connect_timeout) as cancel_scope:
stream: trio.SocketStream = await trio.open_tcp_stream(hostname, port)
if ssl_context is not None:
stream = trio.SSLStream(stream, ssl_context, server_hostname=hostname)
await stream.do_handshake()

if cancel_scope.cancelled_caught:
raise ConnectTimeout()

return SocketStream(stream=stream, timeout=timeout)

async def open_uds_stream(
self,
path: str,
hostname: typing.Optional[str],
ssl_context: typing.Optional[ssl.SSLContext],
timeout: TimeoutConfig,
) -> SocketStream:
hostname = hostname if ssl_context else None
return await self._open_stream(
trio.open_unix_socket(path), hostname, ssl_context, timeout
)

async def _open_stream(
self,
socket_stream: typing.Awaitable[trio.SocketStream],
hostname: typing.Optional[str],
ssl_context: typing.Optional[ssl.SSLContext],
timeout: TimeoutConfig,
) -> SocketStream:
connect_timeout = _or_inf(timeout.connect_timeout)

with trio.move_on_after(connect_timeout) as cancel_scope:
stream: trio.SocketStream = await socket_stream
stream: trio.SocketStream = await trio.open_unix_socket(path)
if ssl_context is not None:
stream = trio.SSLStream(stream, ssl_context, server_hostname=hostname)
await stream.do_handshake()
Expand Down

0 comments on commit 1b0f54c

Please sign in to comment.