Skip to content

Commit

Permalink
Fixed Channel to construct valid authority header when host is the IP…
Browse files Browse the repository at this point in the history
…v6 address
  • Loading branch information
vmagamedov committed Jul 21, 2024
1 parent b98d2a0 commit be6379b
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 8 deletions.
13 changes: 11 additions & 2 deletions grpclib/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import time
import asyncio
import warnings
import ipaddress

from types import TracebackType
from typing import Generic, Optional, Union, Type, List, Sequence, Any, cast
Expand Down Expand Up @@ -683,9 +684,8 @@ def __init__(
self._codec = codec
self._status_details_codec = status_details_codec
self._ssl = ssl or None
self._authority = '{}:{}'.format(self._host, self._port)
self._scheme = 'https' if self._ssl else 'http'
self._authority = '{}:{}'.format(self._host, self._port)
self._authority = self._get_authority(self._host, self._port)
self._h2_config = H2Configuration(
client_side=True,
header_encoding='ascii',
Expand Down Expand Up @@ -779,6 +779,15 @@ def _get_default_ssl_context(
ctx.set_alpn_protocols(['h2'])
return ctx

def _get_authority(self, host, port):
try:
ipv6_address = ipaddress.IPv6Address(host)
except ipaddress.AddressValueError:
pass
else:
host = f"[{ipv6_address}]"
return '{}:{}'.format(host, port)

def request(
self,
name: str,
Expand Down
29 changes: 23 additions & 6 deletions tests/test_functional.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import socket
import tempfile
import ipaddress

import pytest

Expand Down Expand Up @@ -46,19 +47,27 @@ class ClientServer:
channel = None
channel_ctx = None

def __init__(self, *, host="127.0.0.1"):
self.host = host

async def __aenter__(self):
host = '127.0.0.1'
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(('127.0.0.1', 0))
_, port = s.getsockname()
try:
ipaddress.IPv6Address(self.host)
except ipaddress.AddressValueError:
family = socket.AF_INET6
else:
family = socket.AF_INET6
with socket.socket(family, socket.SOCK_STREAM) as s:
s.bind((self.host, 0))
_, port, *_ = s.getsockname()

dummy_service = DummyService()

self.server = Server([dummy_service])
await self.server.start(host, port)
await self.server.start(self.host, port)
self.server_ctx = await self.server.__aenter__()

self.channel = Channel(host=host, port=port)
self.channel = Channel(host=self.host, port=port)
self.channel_ctx = await self.channel.__aenter__()
dummy_stub = DummyServiceStub(self.channel)
return dummy_service, dummy_stub
Expand Down Expand Up @@ -211,3 +220,11 @@ async def test_stream_stream_advanced():
assert await stream.recv_message() == DummyReply(value='baz')

assert await stream.recv_message() is None


@pytest.mark.asyncio
async def test_ipv6():
async with ClientServer(host="::1") as (handler, stub):
reply = await stub.UnaryUnary(DummyRequest(value='ping'))
assert reply == DummyReply(value='pong')
assert handler.log == [DummyRequest(value='ping')]

0 comments on commit be6379b

Please sign in to comment.