Skip to content

Commit

Permalink
Honor max concurrent streams (#89)
Browse files Browse the repository at this point in the history
* Use wait_closed with asyncio, with socket unwrapping workaround.

* Fix for Python 3.6, and comments

* Add type: ignore for Python 3.6

* Honor MAX_CONCURRENT_STREAMS

* Drop erronous commit

* Don't release stream concurrency semaphore until *after* network closing the stream

* Don't use bare except
  • Loading branch information
tomchristie authored May 14, 2020
1 parent eddcc69 commit 717da48
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 26 deletions.
45 changes: 32 additions & 13 deletions httpcore/_async/http2.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from h2.exceptions import NoAvailableStreamIDError
from h2.settings import SettingCodes, Settings

from .._backends.auto import AsyncLock, AsyncSocketStream, AutoBackend
from .._exceptions import ProtocolError
from .._backends.auto import AsyncLock, AsyncSemaphore, AsyncSocketStream, AutoBackend
from .._exceptions import PoolTimeout, ProtocolError
from .._types import URL, Headers, TimeoutDict
from .._utils import get_logger
from .base import (
Expand Down Expand Up @@ -67,6 +67,17 @@ def read_lock(self) -> AsyncLock:
self._read_lock = self.backend.create_lock()
return self._read_lock

@property
def max_streams_semaphore(self) -> AsyncSemaphore:
# We do this lazily, to make sure backend autodetection always
# runs within an async context.
if not hasattr(self, "_max_streams_semaphore"):
max_streams = self.h2_state.remote_settings.max_concurrent_streams
self._max_streams_semaphore = self.backend.create_semaphore(
max_streams, exc_class=PoolTimeout
)
return self._max_streams_semaphore

async def start_tls(self, hostname: bytes, timeout: TimeoutDict = None) -> None:
pass

Expand Down Expand Up @@ -265,16 +276,21 @@ async def request(
b"content-length" in seen_headers or b"transfer-encoding" in seen_headers
)

await self.send_headers(method, url, headers, has_body, timeout)
if has_body:
await self.send_body(stream, timeout)

# Receive the response.
status_code, headers = await self.receive_response(timeout)
reason_phrase = get_reason_phrase(status_code)
stream = AsyncByteStream(
aiterator=self.body_iter(timeout), aclose_func=self._response_closed
)
await self.connection.max_streams_semaphore.acquire()
try:
await self.send_headers(method, url, headers, has_body, timeout)
if has_body:
await self.send_body(stream, timeout)

# Receive the response.
status_code, headers = await self.receive_response(timeout)
reason_phrase = get_reason_phrase(status_code)
stream = AsyncByteStream(
aiterator=self.body_iter(timeout), aclose_func=self._response_closed
)
except Exception:
self.connection.max_streams_semaphore.release()
raise

return (b"HTTP/2", status_code, reason_phrase, headers, stream)

Expand Down Expand Up @@ -346,4 +362,7 @@ async def body_iter(self, timeout: TimeoutDict) -> AsyncIterator[bytes]:
break

async def _response_closed(self) -> None:
await self.connection.close_stream(self.stream_id)
try:
await self.connection.close_stream(self.stream_id)
finally:
self.connection.max_streams_semaphore.release()
45 changes: 32 additions & 13 deletions httpcore/_sync/http2.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from h2.exceptions import NoAvailableStreamIDError
from h2.settings import SettingCodes, Settings

from .._backends.auto import SyncLock, SyncSocketStream, SyncBackend
from .._exceptions import ProtocolError
from .._backends.auto import SyncLock, SyncSemaphore, SyncSocketStream, SyncBackend
from .._exceptions import PoolTimeout, ProtocolError
from .._types import URL, Headers, TimeoutDict
from .._utils import get_logger
from .base import (
Expand Down Expand Up @@ -67,6 +67,17 @@ def read_lock(self) -> SyncLock:
self._read_lock = self.backend.create_lock()
return self._read_lock

@property
def max_streams_semaphore(self) -> SyncSemaphore:
# We do this lazily, to make sure backend autodetection always
# runs within an async context.
if not hasattr(self, "_max_streams_semaphore"):
max_streams = self.h2_state.remote_settings.max_concurrent_streams
self._max_streams_semaphore = self.backend.create_semaphore(
max_streams, exc_class=PoolTimeout
)
return self._max_streams_semaphore

def start_tls(self, hostname: bytes, timeout: TimeoutDict = None) -> None:
pass

Expand Down Expand Up @@ -265,16 +276,21 @@ def request(
b"content-length" in seen_headers or b"transfer-encoding" in seen_headers
)

self.send_headers(method, url, headers, has_body, timeout)
if has_body:
self.send_body(stream, timeout)

# Receive the response.
status_code, headers = self.receive_response(timeout)
reason_phrase = get_reason_phrase(status_code)
stream = SyncByteStream(
iterator=self.body_iter(timeout), close_func=self._response_closed
)
self.connection.max_streams_semaphore.acquire()
try:
self.send_headers(method, url, headers, has_body, timeout)
if has_body:
self.send_body(stream, timeout)

# Receive the response.
status_code, headers = self.receive_response(timeout)
reason_phrase = get_reason_phrase(status_code)
stream = SyncByteStream(
iterator=self.body_iter(timeout), close_func=self._response_closed
)
except Exception:
self.connection.max_streams_semaphore.release()
raise

return (b"HTTP/2", status_code, reason_phrase, headers, stream)

Expand Down Expand Up @@ -346,4 +362,7 @@ def body_iter(self, timeout: TimeoutDict) -> Iterator[bytes]:
break

def _response_closed(self) -> None:
self.connection.close_stream(self.stream_id)
try:
self.connection.close_stream(self.stream_id)
finally:
self.connection.max_streams_semaphore.release()

0 comments on commit 717da48

Please sign in to comment.