diff --git a/httpcore/_async/http2.py b/httpcore/_async/http2.py index c1b0a5ce..c3c5277c 100644 --- a/httpcore/_async/http2.py +++ b/httpcore/_async/http2.py @@ -72,7 +72,7 @@ def max_streams_semaphore(self) -> AsyncSemaphore: # We do this lazily, to make sure backend autodetection always # runs within an async context. if not hasattr(self, "_max_streams_semaphore"): - max_streams = self.h2_state.remote_settings.max_concurrent_streams + max_streams = self.h2_state.local_settings.max_concurrent_streams self._max_streams_semaphore = self.backend.create_semaphore( max_streams, exc_class=PoolTimeout ) @@ -102,6 +102,8 @@ async def request( await self.send_connection_init(timeout) self.sent_connection_init = True + await self.max_streams_semaphore.acquire() + try: try: stream_id = self.h2_state.get_next_available_stream_id() except NoAvailableStreamIDError: @@ -110,10 +112,13 @@ async def request( else: self.state = ConnectionState.ACTIVE - h2_stream = AsyncHTTP2Stream(stream_id=stream_id, connection=self) - self.streams[stream_id] = h2_stream - self.events[stream_id] = [] - return await h2_stream.request(method, url, headers, stream, timeout) + h2_stream = AsyncHTTP2Stream(stream_id=stream_id, connection=self) + self.streams[stream_id] = h2_stream + self.events[stream_id] = [] + return await h2_stream.request(method, url, headers, stream, timeout) + except Exception: + self.max_streams_semaphore.release() + raise async def send_connection_init(self, timeout: TimeoutDict) -> None: """ @@ -242,15 +247,18 @@ async def acknowledge_received_data( await self.socket.write(data_to_send, timeout) async def close_stream(self, stream_id: int) -> None: - logger.trace("close_stream stream_id=%r", stream_id) - del self.streams[stream_id] - del self.events[stream_id] - - if not self.streams: - if self.state == ConnectionState.ACTIVE: - self.state = ConnectionState.IDLE - elif self.state == ConnectionState.FULL: - await self.aclose() + try: + logger.trace("close_stream stream_id=%r", stream_id) + del self.streams[stream_id] + del self.events[stream_id] + + if not self.streams: + if self.state == ConnectionState.ACTIVE: + self.state = ConnectionState.IDLE + elif self.state == ConnectionState.FULL: + await self.aclose() + finally: + self.max_streams_semaphore.release() class AsyncHTTP2Stream: @@ -276,21 +284,16 @@ async def request( b"content-length" in seen_headers or b"transfer-encoding" in seen_headers ) - await self.connection.max_streams_semaphore.acquire() - try: - await self.send_headers(method, url, headers, has_body, timeout) - if has_body: - await self.send_body(stream, timeout) - - # Receive the response. - status_code, headers = await self.receive_response(timeout) - reason_phrase = get_reason_phrase(status_code) - stream = AsyncByteStream( - aiterator=self.body_iter(timeout), aclose_func=self._response_closed - ) - except Exception: - self.connection.max_streams_semaphore.release() - raise + await self.send_headers(method, url, headers, has_body, timeout) + if has_body: + await self.send_body(stream, timeout) + + # Receive the response. + status_code, headers = await self.receive_response(timeout) + reason_phrase = get_reason_phrase(status_code) + stream = AsyncByteStream( + aiterator=self.body_iter(timeout), aclose_func=self._response_closed + ) return (b"HTTP/2", status_code, reason_phrase, headers, stream) @@ -362,7 +365,4 @@ async def body_iter(self, timeout: TimeoutDict) -> AsyncIterator[bytes]: break async def _response_closed(self) -> None: - try: - await self.connection.close_stream(self.stream_id) - finally: - self.connection.max_streams_semaphore.release() + await self.connection.close_stream(self.stream_id) diff --git a/httpcore/_sync/http2.py b/httpcore/_sync/http2.py index 35213600..e12c92bf 100644 --- a/httpcore/_sync/http2.py +++ b/httpcore/_sync/http2.py @@ -72,7 +72,7 @@ def max_streams_semaphore(self) -> SyncSemaphore: # We do this lazily, to make sure backend autodetection always # runs within an async context. if not hasattr(self, "_max_streams_semaphore"): - max_streams = self.h2_state.remote_settings.max_concurrent_streams + max_streams = self.h2_state.local_settings.max_concurrent_streams self._max_streams_semaphore = self.backend.create_semaphore( max_streams, exc_class=PoolTimeout ) @@ -102,6 +102,8 @@ def request( self.send_connection_init(timeout) self.sent_connection_init = True + self.max_streams_semaphore.acquire() + try: try: stream_id = self.h2_state.get_next_available_stream_id() except NoAvailableStreamIDError: @@ -110,10 +112,13 @@ def request( else: self.state = ConnectionState.ACTIVE - h2_stream = SyncHTTP2Stream(stream_id=stream_id, connection=self) - self.streams[stream_id] = h2_stream - self.events[stream_id] = [] - return h2_stream.request(method, url, headers, stream, timeout) + h2_stream = SyncHTTP2Stream(stream_id=stream_id, connection=self) + self.streams[stream_id] = h2_stream + self.events[stream_id] = [] + return h2_stream.request(method, url, headers, stream, timeout) + except Exception: + self.max_streams_semaphore.release() + raise def send_connection_init(self, timeout: TimeoutDict) -> None: """ @@ -242,15 +247,18 @@ def acknowledge_received_data( self.socket.write(data_to_send, timeout) def close_stream(self, stream_id: int) -> None: - logger.trace("close_stream stream_id=%r", stream_id) - del self.streams[stream_id] - del self.events[stream_id] - - if not self.streams: - if self.state == ConnectionState.ACTIVE: - self.state = ConnectionState.IDLE - elif self.state == ConnectionState.FULL: - self.close() + try: + logger.trace("close_stream stream_id=%r", stream_id) + del self.streams[stream_id] + del self.events[stream_id] + + if not self.streams: + if self.state == ConnectionState.ACTIVE: + self.state = ConnectionState.IDLE + elif self.state == ConnectionState.FULL: + self.close() + finally: + self.max_streams_semaphore.release() class SyncHTTP2Stream: @@ -276,21 +284,16 @@ def request( b"content-length" in seen_headers or b"transfer-encoding" in seen_headers ) - self.connection.max_streams_semaphore.acquire() - try: - self.send_headers(method, url, headers, has_body, timeout) - if has_body: - self.send_body(stream, timeout) - - # Receive the response. - status_code, headers = self.receive_response(timeout) - reason_phrase = get_reason_phrase(status_code) - stream = SyncByteStream( - iterator=self.body_iter(timeout), close_func=self._response_closed - ) - except Exception: - self.connection.max_streams_semaphore.release() - raise + self.send_headers(method, url, headers, has_body, timeout) + if has_body: + self.send_body(stream, timeout) + + # Receive the response. + status_code, headers = self.receive_response(timeout) + reason_phrase = get_reason_phrase(status_code) + stream = SyncByteStream( + iterator=self.body_iter(timeout), close_func=self._response_closed + ) return (b"HTTP/2", status_code, reason_phrase, headers, stream) @@ -362,7 +365,4 @@ def body_iter(self, timeout: TimeoutDict) -> Iterator[bytes]: break def _response_closed(self) -> None: - try: - self.connection.close_stream(self.stream_id) - finally: - self.connection.max_streams_semaphore.release() + self.connection.close_stream(self.stream_id)