-
Notifications
You must be signed in to change notification settings - Fork 107
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for synchronous TLS-in-TLS
connections.
#732
Changes from 50 commits
5d01fac
9015e03
a2ec2da
81bcf16
ab232f9
73e00b5
8cb7324
d3acf7f
52b2822
54f2310
5727312
e31572e
5427555
5f73c52
ae51284
ca37fcc
c4e9b62
9fe3ea6
ac8ee56
f9e0d4b
6db047f
9876701
3fcb610
d87a901
0e1eae4
8e38d22
bb87261
86e5e70
ebd7990
5307e0c
1adc291
f156785
a268f22
53f45c4
358f07f
968ed16
be8dd62
7b9b019
aa784eb
3d952ca
f80e1b0
80c34f7
47a8c2a
2178d27
1aeda05
07818c4
7324ddc
0ff81e3
f68e768
92041d4
9d74b6a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,6 +2,8 @@ | |
import ssl | ||
import sys | ||
import typing | ||
from functools import partial | ||
from time import perf_counter | ||
|
||
from .._exceptions import ( | ||
ConnectError, | ||
|
@@ -17,6 +19,117 @@ | |
from .base import SOCKET_OPTION, NetworkBackend, NetworkStream | ||
|
||
|
||
class SyncTLSStream(NetworkStream): | ||
""" | ||
Because the standard `SSLContext.wrap_socket` method does | ||
not work for `SSLSocket` objects, we need this class | ||
to implement TLS stream using an underlying `SSLObject` | ||
instance in order to support TLS on top of TLS. | ||
""" | ||
|
||
# Defined in RFC 8449 | ||
TLS_RECORD_SIZE = 16384 | ||
|
||
def __init__( | ||
self, | ||
sock: socket.socket, | ||
ssl_context: ssl.SSLContext, | ||
server_hostname: typing.Optional[str] = None, | ||
timeout: typing.Optional[float] = None, | ||
): | ||
self._sock = sock | ||
self._incoming = ssl.MemoryBIO() | ||
self._outgoing = ssl.MemoryBIO() | ||
|
||
self.ssl_obj = ssl_context.wrap_bio( | ||
incoming=self._incoming, | ||
outgoing=self._outgoing, | ||
server_hostname=server_hostname, | ||
) | ||
|
||
self._perform_io(self.ssl_obj.do_handshake, timeout) | ||
|
||
def _perform_io( | ||
self, | ||
func: typing.Callable[..., typing.Any], | ||
timeout: typing.Optional[float], | ||
) -> typing.Any: | ||
ret = None | ||
karpetrosyan marked this conversation as resolved.
Show resolved
Hide resolved
|
||
timeout = timeout or None # Replaces `0` with `None` | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We need to distinguish between cases where we got 0 after decreasing our timeout and cases where we don't want to handle timeout at all. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if given paramter's value was |
||
|
||
while True: | ||
errno = None | ||
try: | ||
ret = func() | ||
except (ssl.SSLWantReadError, ssl.SSLWantWriteError) as e: | ||
errno = e.errno | ||
|
||
if timeout is not None and timeout <= 0: # pragma: no cover | ||
raise socket.timeout() | ||
|
||
self._sock.settimeout(timeout) | ||
operation_start = perf_counter() | ||
self._sock.sendall(self._outgoing.read()) | ||
# If the timeout is `None`, don't touch it. | ||
karpetrosyan marked this conversation as resolved.
Show resolved
Hide resolved
|
||
timeout = timeout and timeout - (perf_counter() - operation_start) | ||
karpetrosyan marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
if errno == ssl.SSL_ERROR_WANT_READ: | ||
if timeout is not None and timeout <= 0: # pragma: no cover | ||
raise socket.timeout() | ||
|
||
self._sock.settimeout(timeout) | ||
operation_start = perf_counter() | ||
buf = self._sock.recv(self.TLS_RECORD_SIZE) | ||
|
||
# If the timeout is `None`, don't touch it. | ||
karpetrosyan marked this conversation as resolved.
Show resolved
Hide resolved
|
||
timeout = timeout and timeout - (perf_counter() - operation_start) | ||
karpetrosyan marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
if buf: | ||
self._incoming.write(buf) | ||
else: | ||
self._incoming.write_eof() # pragma: no cover | ||
if errno is None: | ||
return ret | ||
|
||
def read(self, max_bytes: int, timeout: typing.Optional[float] = None) -> bytes: | ||
exc_map: ExceptionMapping = {socket.timeout: ReadTimeout, OSError: ReadError} | ||
with map_exceptions(exc_map): | ||
return typing.cast( | ||
bytes, self._perform_io(partial(self.ssl_obj.read, max_bytes), timeout) | ||
) | ||
|
||
def write(self, buffer: bytes, timeout: typing.Optional[float] = None) -> None: | ||
exc_map: ExceptionMapping = {socket.timeout: WriteTimeout, OSError: WriteError} | ||
with map_exceptions(exc_map): | ||
while buffer: | ||
nsent = self._perform_io(partial(self.ssl_obj.write, buffer), timeout) | ||
buffer = buffer[nsent:] | ||
|
||
def close(self) -> None: | ||
self._sock.close() | ||
|
||
def start_tls( | ||
self, | ||
ssl_context: ssl.SSLContext, | ||
server_hostname: typing.Optional[str] = None, | ||
timeout: typing.Optional[float] = None, | ||
) -> "NetworkStream": | ||
raise NotImplementedError() # pragma: no cover | ||
|
||
def get_extra_info(self, info: str) -> typing.Any: # pragma: no cover | ||
if info == "ssl_object": | ||
return self.ssl_obj | ||
if info == "client_addr": | ||
return self._sock.getsockname() | ||
if info == "server_addr": | ||
return self._sock.getpeername() | ||
if info == "socket": | ||
return self._sock | ||
if info == "is_readable": | ||
return is_socket_readable(self._sock) | ||
return None | ||
|
||
|
||
class SyncStream(NetworkStream): | ||
def __init__(self, sock: socket.socket) -> None: | ||
self._sock = sock | ||
|
@@ -53,10 +166,18 @@ def start_tls( | |
} | ||
with map_exceptions(exc_map): | ||
try: | ||
self._sock.settimeout(timeout) | ||
sock = ssl_context.wrap_socket( | ||
self._sock, server_hostname=server_hostname | ||
) | ||
if isinstance(self._sock, ssl.SSLSocket): | ||
karpetrosyan marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# If the underlying socket has already been upgraded | ||
# to the TLS layer (i.e. is an instance of SSLSocket), | ||
# we want to use another stream object that supports TLS-in-TLS. | ||
return SyncTLSStream( | ||
self._sock, ssl_context, server_hostname, timeout | ||
) | ||
else: | ||
self._sock.settimeout(timeout) | ||
sock = ssl_context.wrap_socket( | ||
self._sock, server_hostname=server_hostname | ||
) | ||
except Exception as exc: # pragma: nocover | ||
self.close() | ||
raise exc | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
import pytest | ||
|
||
import httpcore | ||
|
||
READ_TIMEOUT = 2 | ||
WRITE_TIMEOUT = 2 | ||
CONNECT_TIMEOUT = 2 | ||
|
||
|
||
@pytest.mark.anyio | ||
async def test_connect_without_tls(tcp_server): | ||
backend = httpcore.AnyIOBackend() | ||
stream = await backend.connect_tcp( | ||
tcp_server.host, tcp_server.port, timeout=CONNECT_TIMEOUT | ||
) | ||
await stream.aclose() | ||
|
||
|
||
@pytest.mark.anyio | ||
async def test_write_without_tls(tcp_server): | ||
backend = httpcore.AnyIOBackend() | ||
stream = await backend.connect_tcp( | ||
tcp_server.host, tcp_server.port, timeout=CONNECT_TIMEOUT | ||
) | ||
async with stream: | ||
await stream.write(b"ping", timeout=WRITE_TIMEOUT) | ||
|
||
|
||
@pytest.mark.anyio | ||
async def test_read_without_tls(tcp_server): | ||
backend = httpcore.AnyIOBackend() | ||
stream = await backend.connect_tcp( | ||
tcp_server.host, tcp_server.port, timeout=CONNECT_TIMEOUT | ||
) | ||
async with stream: | ||
await stream.write(b"ping", timeout=WRITE_TIMEOUT) | ||
await stream.read(1024, timeout=READ_TIMEOUT) | ||
|
||
|
||
@pytest.mark.anyio | ||
async def test_connect_with_tls(tls_server, client_context): | ||
backend = httpcore.AnyIOBackend() | ||
stream = await backend.connect_tcp( | ||
tls_server.host, tls_server.port, timeout=CONNECT_TIMEOUT | ||
) | ||
async with stream: | ||
tls_stream = await stream.start_tls( | ||
ssl_context=client_context, timeout=CONNECT_TIMEOUT | ||
) | ||
await tls_stream.aclose() | ||
|
||
|
||
@pytest.mark.anyio | ||
async def test_write_with_tls(tls_server, client_context): | ||
backend = httpcore.AnyIOBackend() | ||
stream = await backend.connect_tcp( | ||
tls_server.host, tls_server.port, timeout=CONNECT_TIMEOUT | ||
) | ||
async with stream: | ||
tls_stream = await stream.start_tls( | ||
ssl_context=client_context, timeout=CONNECT_TIMEOUT | ||
) | ||
async with tls_stream: | ||
await tls_stream.write(b"ping", timeout=WRITE_TIMEOUT) | ||
|
||
|
||
@pytest.mark.anyio | ||
async def test_read_with_tls(tls_server, client_context): | ||
backend = httpcore.AnyIOBackend() | ||
stream = await backend.connect_tcp( | ||
tls_server.host, tls_server.port, timeout=CONNECT_TIMEOUT | ||
) | ||
async with stream: | ||
tls_stream = await stream.start_tls( | ||
ssl_context=client_context, timeout=CONNECT_TIMEOUT | ||
) | ||
async with tls_stream: | ||
await tls_stream.write(b"ping", timeout=WRITE_TIMEOUT) | ||
await tls_stream.read(1024, timeout=READ_TIMEOUT) | ||
tomchristie marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
|
||
@pytest.mark.anyio | ||
async def test_connect_with_tls_in_tls(tls_in_tls_server, client_context): | ||
backend = httpcore.AnyIOBackend() | ||
stream = await backend.connect_tcp( | ||
tls_in_tls_server.host, tls_in_tls_server.port, timeout=CONNECT_TIMEOUT | ||
) | ||
async with stream: | ||
tls_stream = await stream.start_tls( | ||
ssl_context=client_context, | ||
server_hostname="localhost", | ||
timeout=CONNECT_TIMEOUT, | ||
) | ||
async with tls_stream: | ||
tls_in_tls_stream = await tls_stream.start_tls( | ||
ssl_context=client_context, | ||
server_hostname="localhost", | ||
timeout=CONNECT_TIMEOUT, | ||
) | ||
await tls_in_tls_stream.aclose() | ||
|
||
|
||
@pytest.mark.anyio | ||
async def test_write_with_tls_in_tls(tls_in_tls_server, client_context): | ||
backend = httpcore.AnyIOBackend() | ||
stream = await backend.connect_tcp( | ||
tls_in_tls_server.host, tls_in_tls_server.port, timeout=CONNECT_TIMEOUT | ||
) | ||
async with stream: | ||
tls_stream = await stream.start_tls( | ||
ssl_context=client_context, | ||
server_hostname="localhost", | ||
timeout=CONNECT_TIMEOUT, | ||
) | ||
async with tls_stream: | ||
tls_in_tls_stream = await tls_stream.start_tls( | ||
ssl_context=client_context, | ||
server_hostname="localhost", | ||
timeout=CONNECT_TIMEOUT, | ||
) | ||
async with tls_in_tls_stream: | ||
await tls_in_tls_stream.write(b"ping", timeout=WRITE_TIMEOUT) | ||
|
||
|
||
@pytest.mark.anyio | ||
async def test_read_with_tls_in_tls(tls_in_tls_server, client_context): | ||
backend = httpcore.AnyIOBackend() | ||
stream = await backend.connect_tcp( | ||
tls_in_tls_server.host, tls_in_tls_server.port, timeout=CONNECT_TIMEOUT | ||
) | ||
async with stream: | ||
tls_stream = await stream.start_tls( | ||
ssl_context=client_context, | ||
server_hostname="localhost", | ||
timeout=CONNECT_TIMEOUT, | ||
) | ||
async with tls_stream: | ||
tls_in_tls_stream = await tls_stream.start_tls( | ||
ssl_context=client_context, | ||
server_hostname="localhost", | ||
timeout=CONNECT_TIMEOUT, | ||
) | ||
async with tls_in_tls_stream: | ||
await tls_in_tls_stream.write(b"ping", timeout=WRITE_TIMEOUT) | ||
await tls_in_tls_stream.read(1024, timeout=READ_TIMEOUT) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think we need the
typing_extensions.Self
here. (?)Can we just have this return
NetworkStream
.The override point is
close()
, not the__enter__
/__exit__
which will stay the same even for subclasses.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we can inherit, I believe we should always use Self to avoid strange type issues, such as when the instance of SlowNetworkStream is NetworkStream.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Example:
OUTPUT test.py:11: note: Revealed type is "httpcore._backends.base.NetworkStream"
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah okay right.
Could we use the
TypeVar
style, then?Eg... in
httpx
https://github.com/encode/httpx/blob/76c9cb65f2a159adb764c2236d139f85b46e1506/httpx/_client.py#L60
https://github.com/encode/httpx/blob/76c9cb65f2a159adb764c2236d139f85b46e1506/httpx/_client.py#L1263
Really prefer us avoiding introducing new third party packages wherever possible.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See https://github.com/encode/httpcore/blob/e31572e0371557d6163d2b7c28676e2b1727673b/httpcore/_backends/base.py
Initially, we used TypeVar, but it was too complicated.