Skip to content

Commit

Permalink
Max concurrent stream improvements (#90)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomchristie authored May 14, 2020
1 parent 717da48 commit 30847a0
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 66 deletions.
66 changes: 33 additions & 33 deletions httpcore/_async/http2.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def max_streams_semaphore(self) -> AsyncSemaphore:
# We do this lazily, to make sure backend autodetection always
# runs within an async context.
if not hasattr(self, "_max_streams_semaphore"):
max_streams = self.h2_state.remote_settings.max_concurrent_streams
max_streams = self.h2_state.local_settings.max_concurrent_streams
self._max_streams_semaphore = self.backend.create_semaphore(
max_streams, exc_class=PoolTimeout
)
Expand Down Expand Up @@ -102,6 +102,8 @@ async def request(
await self.send_connection_init(timeout)
self.sent_connection_init = True

await self.max_streams_semaphore.acquire()
try:
try:
stream_id = self.h2_state.get_next_available_stream_id()
except NoAvailableStreamIDError:
Expand All @@ -110,10 +112,13 @@ async def request(
else:
self.state = ConnectionState.ACTIVE

h2_stream = AsyncHTTP2Stream(stream_id=stream_id, connection=self)
self.streams[stream_id] = h2_stream
self.events[stream_id] = []
return await h2_stream.request(method, url, headers, stream, timeout)
h2_stream = AsyncHTTP2Stream(stream_id=stream_id, connection=self)
self.streams[stream_id] = h2_stream
self.events[stream_id] = []
return await h2_stream.request(method, url, headers, stream, timeout)
except Exception:
self.max_streams_semaphore.release()
raise

async def send_connection_init(self, timeout: TimeoutDict) -> None:
"""
Expand Down Expand Up @@ -242,15 +247,18 @@ async def acknowledge_received_data(
await self.socket.write(data_to_send, timeout)

async def close_stream(self, stream_id: int) -> None:
logger.trace("close_stream stream_id=%r", stream_id)
del self.streams[stream_id]
del self.events[stream_id]

if not self.streams:
if self.state == ConnectionState.ACTIVE:
self.state = ConnectionState.IDLE
elif self.state == ConnectionState.FULL:
await self.aclose()
try:
logger.trace("close_stream stream_id=%r", stream_id)
del self.streams[stream_id]
del self.events[stream_id]

if not self.streams:
if self.state == ConnectionState.ACTIVE:
self.state = ConnectionState.IDLE
elif self.state == ConnectionState.FULL:
await self.aclose()
finally:
self.max_streams_semaphore.release()


class AsyncHTTP2Stream:
Expand All @@ -276,21 +284,16 @@ async def request(
b"content-length" in seen_headers or b"transfer-encoding" in seen_headers
)

await self.connection.max_streams_semaphore.acquire()
try:
await self.send_headers(method, url, headers, has_body, timeout)
if has_body:
await self.send_body(stream, timeout)

# Receive the response.
status_code, headers = await self.receive_response(timeout)
reason_phrase = get_reason_phrase(status_code)
stream = AsyncByteStream(
aiterator=self.body_iter(timeout), aclose_func=self._response_closed
)
except Exception:
self.connection.max_streams_semaphore.release()
raise
await self.send_headers(method, url, headers, has_body, timeout)
if has_body:
await self.send_body(stream, timeout)

# Receive the response.
status_code, headers = await self.receive_response(timeout)
reason_phrase = get_reason_phrase(status_code)
stream = AsyncByteStream(
aiterator=self.body_iter(timeout), aclose_func=self._response_closed
)

return (b"HTTP/2", status_code, reason_phrase, headers, stream)

Expand Down Expand Up @@ -362,7 +365,4 @@ async def body_iter(self, timeout: TimeoutDict) -> AsyncIterator[bytes]:
break

async def _response_closed(self) -> None:
try:
await self.connection.close_stream(self.stream_id)
finally:
self.connection.max_streams_semaphore.release()
await self.connection.close_stream(self.stream_id)
66 changes: 33 additions & 33 deletions httpcore/_sync/http2.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def max_streams_semaphore(self) -> SyncSemaphore:
# We do this lazily, to make sure backend autodetection always
# runs within an async context.
if not hasattr(self, "_max_streams_semaphore"):
max_streams = self.h2_state.remote_settings.max_concurrent_streams
max_streams = self.h2_state.local_settings.max_concurrent_streams
self._max_streams_semaphore = self.backend.create_semaphore(
max_streams, exc_class=PoolTimeout
)
Expand Down Expand Up @@ -102,6 +102,8 @@ def request(
self.send_connection_init(timeout)
self.sent_connection_init = True

self.max_streams_semaphore.acquire()
try:
try:
stream_id = self.h2_state.get_next_available_stream_id()
except NoAvailableStreamIDError:
Expand All @@ -110,10 +112,13 @@ def request(
else:
self.state = ConnectionState.ACTIVE

h2_stream = SyncHTTP2Stream(stream_id=stream_id, connection=self)
self.streams[stream_id] = h2_stream
self.events[stream_id] = []
return h2_stream.request(method, url, headers, stream, timeout)
h2_stream = SyncHTTP2Stream(stream_id=stream_id, connection=self)
self.streams[stream_id] = h2_stream
self.events[stream_id] = []
return h2_stream.request(method, url, headers, stream, timeout)
except Exception:
self.max_streams_semaphore.release()
raise

def send_connection_init(self, timeout: TimeoutDict) -> None:
"""
Expand Down Expand Up @@ -242,15 +247,18 @@ def acknowledge_received_data(
self.socket.write(data_to_send, timeout)

def close_stream(self, stream_id: int) -> None:
logger.trace("close_stream stream_id=%r", stream_id)
del self.streams[stream_id]
del self.events[stream_id]

if not self.streams:
if self.state == ConnectionState.ACTIVE:
self.state = ConnectionState.IDLE
elif self.state == ConnectionState.FULL:
self.close()
try:
logger.trace("close_stream stream_id=%r", stream_id)
del self.streams[stream_id]
del self.events[stream_id]

if not self.streams:
if self.state == ConnectionState.ACTIVE:
self.state = ConnectionState.IDLE
elif self.state == ConnectionState.FULL:
self.close()
finally:
self.max_streams_semaphore.release()


class SyncHTTP2Stream:
Expand All @@ -276,21 +284,16 @@ def request(
b"content-length" in seen_headers or b"transfer-encoding" in seen_headers
)

self.connection.max_streams_semaphore.acquire()
try:
self.send_headers(method, url, headers, has_body, timeout)
if has_body:
self.send_body(stream, timeout)

# Receive the response.
status_code, headers = self.receive_response(timeout)
reason_phrase = get_reason_phrase(status_code)
stream = SyncByteStream(
iterator=self.body_iter(timeout), close_func=self._response_closed
)
except Exception:
self.connection.max_streams_semaphore.release()
raise
self.send_headers(method, url, headers, has_body, timeout)
if has_body:
self.send_body(stream, timeout)

# Receive the response.
status_code, headers = self.receive_response(timeout)
reason_phrase = get_reason_phrase(status_code)
stream = SyncByteStream(
iterator=self.body_iter(timeout), close_func=self._response_closed
)

return (b"HTTP/2", status_code, reason_phrase, headers, stream)

Expand Down Expand Up @@ -362,7 +365,4 @@ def body_iter(self, timeout: TimeoutDict) -> Iterator[bytes]:
break

def _response_closed(self) -> None:
try:
self.connection.close_stream(self.stream_id)
finally:
self.connection.max_streams_semaphore.release()
self.connection.close_stream(self.stream_id)

0 comments on commit 30847a0

Please sign in to comment.