Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for synchronous TLS-in-TLS connections. #732

Closed
Closed
Show file tree
Hide file tree
Changes from 50 commits
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
5d01fac
Add tests for the backends, without TLS
karpetrosyan Jun 16, 2023
9015e03
Rename tcp connection test
karpetrosyan Jun 16, 2023
a2ec2da
Rename tcp connection test
karpetrosyan Jun 16, 2023
81bcf16
Add write tests without tls
karpetrosyan Jun 16, 2023
ab232f9
Add read tests without tls
karpetrosyan Jun 16, 2023
73e00b5
Remove unneeded assertions
karpetrosyan Jun 16, 2023
8cb7324
Add connect, read and write tests for tls connections
karpetrosyan Jun 19, 2023
d3acf7f
Add context manager support for stream classes
karpetrosyan Jun 19, 2023
52b2822
Use context manager instead of try/expect
karpetrosyan Jun 19, 2023
54f2310
Remove .idea
karpetrosyan Jun 19, 2023
5727312
Use pytest-httpbin certificates instead of disable verifying
karpetrosyan Jun 20, 2023
e31572e
Add timeouts for the read/write/connect tests
karpetrosyan Jun 20, 2023
5427555
Use typing_extensions Self
karpetrosyan Jun 27, 2023
5f73c52
Use sockets instead of httpbin
karpetrosyan Jun 30, 2023
ae51284
Add trustme into the requirements
karpetrosyan Jun 30, 2023
ca37fcc
Improve conftest.py
karpetrosyan Jun 30, 2023
c4e9b62
Handle BrokenPipeError
karpetrosyan Jul 1, 2023
9fe3ea6
Add tls_in_tls server
karpetrosyan Jul 3, 2023
ac8ee56
Add failing tests for tls_in_tls
karpetrosyan Jul 3, 2023
f9e0d4b
Lint
karpetrosyan Jul 3, 2023
6db047f
Suppress unexpected message error
karpetrosyan Jul 3, 2023
9876701
Suppress SSL_ERROR_ZERO_RETURN
karpetrosyan Jul 3, 2023
3fcb610
Use ssl error constants instead of the hard coded values
karpetrosyan Jul 3, 2023
d87a901
Lint
karpetrosyan Jul 3, 2023
0e1eae4
Instead of listing all possible exceptions, use OSError
karpetrosyan Jul 3, 2023
8e38d22
Add TLS-in-TLS implementation
karpetrosyan Jul 4, 2023
bb87261
Merge branch 'master' into add-tests-for-network-backends
tomchristie Jul 4, 2023
86e5e70
Add pragma: no cover for timeout cases
karpetrosyan Jul 4, 2023
ebd7990
improve SyncTLSStream
T-256 Jul 5, 2023
5307e0c
typo
T-256 Jul 5, 2023
1adc291
paranthes
T-256 Jul 5, 2023
f156785
Update httpcore/_backends/sync.py
T-256 Jul 6, 2023
a268f22
drop setter
T-256 Jul 6, 2023
53f45c4
Merge pull request #1 from T-256/patch-2
karpetrosyan Jul 6, 2023
358f07f
Linting
karpetrosyan Jul 6, 2023
968ed16
Add docstrings
karpetrosyan Jul 6, 2023
be8dd62
Drop OverallTimeout class
karpetrosyan Jul 6, 2023
7b9b019
Drop OverallTimeout class
karpetrosyan Jul 6, 2023
aa784eb
Merge branch 'master' into add-tests-for-network-backends
karpetrosyan Jul 6, 2023
3d952ca
Update httpcore/_backends/sync.py
karpetrosyan Jul 6, 2023
f80e1b0
Fix timeout, add TLS_RECORD_SIZE
karpetrosyan Jul 7, 2023
80c34f7
Typo
karpetrosyan Jul 7, 2023
47a8c2a
Move socket timeout raise
karpetrosyan Jul 7, 2023
2178d27
Replace 0 with None
karpetrosyan Jul 7, 2023
1aeda05
Merge branch 'master' into add-tests-for-network-backends
karpetrosyan Jul 8, 2023
07818c4
Merge branch 'master' into add-tests-for-network-backends
karpetrosyan Jul 12, 2023
7324ddc
Add changelog
karpetrosyan Jul 13, 2023
0ff81e3
Merge branch 'master' into add-tests-for-network-backends
karpetrosyan Jul 14, 2023
f68e768
Merge branch 'master' into add-tests-for-network-backends
karpetrosyan Jul 31, 2023
92041d4
Merge branch 'master' into add-tests-for-network-backends
karpetrosyan Aug 8, 2023
9d74b6a
Merge branch 'master' into add-tests-for-network-backends
karpetrosyan Aug 25, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

## Unreleased

- Add support for synchronous **TLS-in-TLS** connections. (#732)
- Change the type of `Extensions` from `Mapping[Str, Any]` to `MutableMapping[Str, Any]`. (#762)
- Handle HTTP/1.1 half-closed connections gracefully. (#641)

Expand Down
28 changes: 28 additions & 0 deletions httpcore/_backends/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import ssl
import time
import typing
from types import TracebackType

if typing.TYPE_CHECKING: # pragma: no cover
from typing_extensions import Self

SOCKET_OPTION = typing.Union[
typing.Tuple[int, int, int],
Expand All @@ -16,6 +20,17 @@ def read(self, max_bytes: int, timeout: typing.Optional[float] = None) -> bytes:
def write(self, buffer: bytes, timeout: typing.Optional[float] = None) -> None:
raise NotImplementedError() # pragma: nocover

def __enter__(self) -> "Self":
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we need the typing_extensions.Self here. (?)
Can we just have this return NetworkStream.

The override point is close(), not the __enter__/__exit__ which will stay the same even for subclasses.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we can inherit, I believe we should always use Self to avoid strange type issues, such as when the instance of SlowNetworkStream is NetworkStream.

Copy link
Member Author

@karpetrosyan karpetrosyan Aug 8, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Example:

import typing
from httpcore._backends.base import NetworkStream
from time import sleep

class SlowNetworkStream(NetworkStream):
    
    def read(self, max_bytes: int, timeout: float | None = None) -> bytes:
        sleep(100)

with SlowNetworkStream() as stream:
    reveal_type(stream)

OUTPUT test.py:11: note: Revealed type is "httpcore._backends.base.NetworkStream"

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah okay right.
Could we use the TypeVar style, then?

Eg... in httpx

https://github.com/encode/httpx/blob/76c9cb65f2a159adb764c2236d139f85b46e1506/httpx/_client.py#L60
https://github.com/encode/httpx/blob/76c9cb65f2a159adb764c2236d139f85b46e1506/httpx/_client.py#L1263

Really prefer us avoiding introducing new third party packages wherever possible.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

return self

def __exit__(
self,
exc_type: typing.Optional[typing.Type[BaseException]] = None,
exc_value: typing.Optional[BaseException] = None,
traceback: typing.Optional[TracebackType] = None,
) -> None:
self.close()

def close(self) -> None:
raise NotImplementedError() # pragma: nocover

Expand Down Expand Up @@ -65,6 +80,19 @@ async def write(
) -> None:
raise NotImplementedError() # pragma: nocover

async def __aenter__(
self,
) -> "Self":
return self

async def __aexit__(
self,
exc_type: typing.Optional[typing.Type[BaseException]] = None,
exc_value: typing.Optional[BaseException] = None,
traceback: typing.Optional[TracebackType] = None,
) -> None:
await self.aclose()

async def aclose(self) -> None:
raise NotImplementedError() # pragma: nocover

Expand Down
129 changes: 125 additions & 4 deletions httpcore/_backends/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import ssl
import sys
import typing
from functools import partial
from time import perf_counter

from .._exceptions import (
ConnectError,
Expand All @@ -17,6 +19,117 @@
from .base import SOCKET_OPTION, NetworkBackend, NetworkStream


class SyncTLSStream(NetworkStream):
"""
Because the standard `SSLContext.wrap_socket` method does
not work for `SSLSocket` objects, we need this class
to implement TLS stream using an underlying `SSLObject`
instance in order to support TLS on top of TLS.
"""

# Defined in RFC 8449
TLS_RECORD_SIZE = 16384

def __init__(
self,
sock: socket.socket,
ssl_context: ssl.SSLContext,
server_hostname: typing.Optional[str] = None,
timeout: typing.Optional[float] = None,
):
self._sock = sock
self._incoming = ssl.MemoryBIO()
self._outgoing = ssl.MemoryBIO()

self.ssl_obj = ssl_context.wrap_bio(
incoming=self._incoming,
outgoing=self._outgoing,
server_hostname=server_hostname,
)

self._perform_io(self.ssl_obj.do_handshake, timeout)

def _perform_io(
self,
func: typing.Callable[..., typing.Any],
timeout: typing.Optional[float],
) -> typing.Any:
ret = None
karpetrosyan marked this conversation as resolved.
Show resolved Hide resolved
timeout = timeout or None # Replaces `0` with `None`
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to distinguish between cases where we got 0 after decreasing our timeout and cases where we don't want to handle timeout at all.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if given paramter's value was 0.0, it replaces it with None to avoid socket.timeout at here


while True:
errno = None
try:
ret = func()
except (ssl.SSLWantReadError, ssl.SSLWantWriteError) as e:
errno = e.errno

if timeout is not None and timeout <= 0: # pragma: no cover
raise socket.timeout()

self._sock.settimeout(timeout)
operation_start = perf_counter()
self._sock.sendall(self._outgoing.read())
# If the timeout is `None`, don't touch it.
karpetrosyan marked this conversation as resolved.
Show resolved Hide resolved
timeout = timeout and timeout - (perf_counter() - operation_start)
karpetrosyan marked this conversation as resolved.
Show resolved Hide resolved

if errno == ssl.SSL_ERROR_WANT_READ:
if timeout is not None and timeout <= 0: # pragma: no cover
raise socket.timeout()

self._sock.settimeout(timeout)
operation_start = perf_counter()
buf = self._sock.recv(self.TLS_RECORD_SIZE)

# If the timeout is `None`, don't touch it.
karpetrosyan marked this conversation as resolved.
Show resolved Hide resolved
timeout = timeout and timeout - (perf_counter() - operation_start)
karpetrosyan marked this conversation as resolved.
Show resolved Hide resolved

if buf:
self._incoming.write(buf)
else:
self._incoming.write_eof() # pragma: no cover
if errno is None:
return ret

def read(self, max_bytes: int, timeout: typing.Optional[float] = None) -> bytes:
exc_map: ExceptionMapping = {socket.timeout: ReadTimeout, OSError: ReadError}
with map_exceptions(exc_map):
return typing.cast(
bytes, self._perform_io(partial(self.ssl_obj.read, max_bytes), timeout)
)

def write(self, buffer: bytes, timeout: typing.Optional[float] = None) -> None:
exc_map: ExceptionMapping = {socket.timeout: WriteTimeout, OSError: WriteError}
with map_exceptions(exc_map):
while buffer:
nsent = self._perform_io(partial(self.ssl_obj.write, buffer), timeout)
buffer = buffer[nsent:]

def close(self) -> None:
self._sock.close()

def start_tls(
self,
ssl_context: ssl.SSLContext,
server_hostname: typing.Optional[str] = None,
timeout: typing.Optional[float] = None,
) -> "NetworkStream":
raise NotImplementedError() # pragma: no cover

def get_extra_info(self, info: str) -> typing.Any: # pragma: no cover
if info == "ssl_object":
return self.ssl_obj
if info == "client_addr":
return self._sock.getsockname()
if info == "server_addr":
return self._sock.getpeername()
if info == "socket":
return self._sock
if info == "is_readable":
return is_socket_readable(self._sock)
return None


class SyncStream(NetworkStream):
def __init__(self, sock: socket.socket) -> None:
self._sock = sock
Expand Down Expand Up @@ -53,10 +166,18 @@ def start_tls(
}
with map_exceptions(exc_map):
try:
self._sock.settimeout(timeout)
sock = ssl_context.wrap_socket(
self._sock, server_hostname=server_hostname
)
if isinstance(self._sock, ssl.SSLSocket):
karpetrosyan marked this conversation as resolved.
Show resolved Hide resolved
# If the underlying socket has already been upgraded
# to the TLS layer (i.e. is an instance of SSLSocket),
# we want to use another stream object that supports TLS-in-TLS.
return SyncTLSStream(
self._sock, ssl_context, server_hostname, timeout
)
else:
self._sock.settimeout(timeout)
sock = ssl_context.wrap_socket(
self._sock, server_hostname=server_hostname
)
except Exception as exc: # pragma: nocover
self.close()
raise exc
Expand Down
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ black==23.3.0
coverage[toml]==7.2.7
ruff==0.0.277
mypy==1.4.1
trustme==1.0.0
typing-extensions==4.6.3
trio-typing==0.8.0
types-certifi==2021.10.8.3
pytest==7.4.0
Expand Down
145 changes: 145 additions & 0 deletions tests/_backends/test_anyio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
import pytest

import httpcore

READ_TIMEOUT = 2
WRITE_TIMEOUT = 2
CONNECT_TIMEOUT = 2


@pytest.mark.anyio
async def test_connect_without_tls(tcp_server):
backend = httpcore.AnyIOBackend()
stream = await backend.connect_tcp(
tcp_server.host, tcp_server.port, timeout=CONNECT_TIMEOUT
)
await stream.aclose()


@pytest.mark.anyio
async def test_write_without_tls(tcp_server):
backend = httpcore.AnyIOBackend()
stream = await backend.connect_tcp(
tcp_server.host, tcp_server.port, timeout=CONNECT_TIMEOUT
)
async with stream:
await stream.write(b"ping", timeout=WRITE_TIMEOUT)


@pytest.mark.anyio
async def test_read_without_tls(tcp_server):
backend = httpcore.AnyIOBackend()
stream = await backend.connect_tcp(
tcp_server.host, tcp_server.port, timeout=CONNECT_TIMEOUT
)
async with stream:
await stream.write(b"ping", timeout=WRITE_TIMEOUT)
await stream.read(1024, timeout=READ_TIMEOUT)


@pytest.mark.anyio
async def test_connect_with_tls(tls_server, client_context):
backend = httpcore.AnyIOBackend()
stream = await backend.connect_tcp(
tls_server.host, tls_server.port, timeout=CONNECT_TIMEOUT
)
async with stream:
tls_stream = await stream.start_tls(
ssl_context=client_context, timeout=CONNECT_TIMEOUT
)
await tls_stream.aclose()


@pytest.mark.anyio
async def test_write_with_tls(tls_server, client_context):
backend = httpcore.AnyIOBackend()
stream = await backend.connect_tcp(
tls_server.host, tls_server.port, timeout=CONNECT_TIMEOUT
)
async with stream:
tls_stream = await stream.start_tls(
ssl_context=client_context, timeout=CONNECT_TIMEOUT
)
async with tls_stream:
await tls_stream.write(b"ping", timeout=WRITE_TIMEOUT)


@pytest.mark.anyio
async def test_read_with_tls(tls_server, client_context):
backend = httpcore.AnyIOBackend()
stream = await backend.connect_tcp(
tls_server.host, tls_server.port, timeout=CONNECT_TIMEOUT
)
async with stream:
tls_stream = await stream.start_tls(
ssl_context=client_context, timeout=CONNECT_TIMEOUT
)
async with tls_stream:
await tls_stream.write(b"ping", timeout=WRITE_TIMEOUT)
await tls_stream.read(1024, timeout=READ_TIMEOUT)
tomchristie marked this conversation as resolved.
Show resolved Hide resolved


@pytest.mark.anyio
async def test_connect_with_tls_in_tls(tls_in_tls_server, client_context):
backend = httpcore.AnyIOBackend()
stream = await backend.connect_tcp(
tls_in_tls_server.host, tls_in_tls_server.port, timeout=CONNECT_TIMEOUT
)
async with stream:
tls_stream = await stream.start_tls(
ssl_context=client_context,
server_hostname="localhost",
timeout=CONNECT_TIMEOUT,
)
async with tls_stream:
tls_in_tls_stream = await tls_stream.start_tls(
ssl_context=client_context,
server_hostname="localhost",
timeout=CONNECT_TIMEOUT,
)
await tls_in_tls_stream.aclose()


@pytest.mark.anyio
async def test_write_with_tls_in_tls(tls_in_tls_server, client_context):
backend = httpcore.AnyIOBackend()
stream = await backend.connect_tcp(
tls_in_tls_server.host, tls_in_tls_server.port, timeout=CONNECT_TIMEOUT
)
async with stream:
tls_stream = await stream.start_tls(
ssl_context=client_context,
server_hostname="localhost",
timeout=CONNECT_TIMEOUT,
)
async with tls_stream:
tls_in_tls_stream = await tls_stream.start_tls(
ssl_context=client_context,
server_hostname="localhost",
timeout=CONNECT_TIMEOUT,
)
async with tls_in_tls_stream:
await tls_in_tls_stream.write(b"ping", timeout=WRITE_TIMEOUT)


@pytest.mark.anyio
async def test_read_with_tls_in_tls(tls_in_tls_server, client_context):
backend = httpcore.AnyIOBackend()
stream = await backend.connect_tcp(
tls_in_tls_server.host, tls_in_tls_server.port, timeout=CONNECT_TIMEOUT
)
async with stream:
tls_stream = await stream.start_tls(
ssl_context=client_context,
server_hostname="localhost",
timeout=CONNECT_TIMEOUT,
)
async with tls_stream:
tls_in_tls_stream = await tls_stream.start_tls(
ssl_context=client_context,
server_hostname="localhost",
timeout=CONNECT_TIMEOUT,
)
async with tls_in_tls_stream:
await tls_in_tls_stream.write(b"ping", timeout=WRITE_TIMEOUT)
await tls_in_tls_stream.read(1024, timeout=READ_TIMEOUT)
Loading