From e0818b49cbb30ccbd029bc0c565827b9d32668fd Mon Sep 17 00:00:00 2001 From: An Long Date: Tue, 26 Jan 2021 21:23:54 +0800 Subject: [PATCH] Fix the path issue caused by #148 --- tests/test_http.py | 20 ++++++++++++++++++-- thriftpy2/http.py | 8 +++++++- 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/tests/test_http.py b/tests/test_http.py index 032eabd6..5d27cb68 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -117,12 +117,23 @@ def client_context_with_url(timeout=3000): url="http://127.0.0.1:6080", timeout=timeout) +def client_context_with_malformed_path(timeout=3000): + return client_context(addressbook.AddressBookService, host="127.0.0.1", + port=6080, path="foo", timeout=timeout) + + def client_with_url(timeout=3000): return make_client(addressbook.AddressBookService, url="http://127.0.0.1:6080", timeout=timeout) def client_without_url(timeout=3000): + return make_client(addressbook.AddressBookService, host="127.0.0.1", + port=6080, path="/foo", timeout=timeout) + + +@pytest.fixture +def client_with_malformed_path(timeout=3000): return make_client(addressbook.AddressBookService, host="127.0.0.1", port=6080, path="foo", timeout=timeout) @@ -157,8 +168,9 @@ def client_with_custom_header_factory(timeout=3000): def test_client_context(server): - with client() as c1, client_context_with_url() as c2: - assert c1.hello("world") == c2.hello("world") + with client() as c1, client_context_with_url() as c2,\ + client_context_with_malformed_path() as c3: + assert c1.hello("world") == c2.hello("world") == c3.hello("world") def test_clients(server): @@ -173,6 +185,10 @@ def test_clients_without_url(server): assert c.hello("world") == "hello world" +def test_client_with_malformed_path(client_with_malformed_path): + assert client_with_malformed_path.hello("world") == "hello world" + + def test_client_context_with_header_factory(server): with client_context_with_header_factory() as c: assert c.hello("world") == "hello world" diff --git a/thriftpy2/http.py b/thriftpy2/http.py index 887018aa..8a4340e0 100644 --- a/thriftpy2/http.py +++ b/thriftpy2/http.py @@ -60,7 +60,7 @@ from thriftpy2.transport import TBufferedTransportFactory -HTTP_URI = '{scheme}://{host}:{port}/{path}' +HTTP_URI = '{scheme}://{host}:{port}{path}' DEFAULT_HTTP_CLIENT_TIMEOUT_MS = 30000 # 30 seconds @@ -306,6 +306,9 @@ def make_client(service, host='localhost', port=9090, path='', scheme='http', port = parsed_url.port or port scheme = parsed_url.scheme or scheme path = parsed_url.path or path + if path and path[0] != "/": + # path should have `/` prefix, but we can make a compatible here. + path = "/" + path uri = HTTP_URI.format(scheme=scheme, host=host, port=port, path=path) http_socket = THttpClient(uri, timeout, ssl_context_factory, http_header_factory) transport = trans_factory.get_transport(http_socket) @@ -327,6 +330,9 @@ def client_context(service, host='localhost', port=9090, path='', scheme='http', port = parsed_url.port or port scheme = parsed_url.scheme or scheme path = parsed_url.path or path + if path and path[0] != "/": + # path should have `/` prefix, but we can make a compatible here. + path = "/" + path uri = HTTP_URI.format(scheme=scheme, host=host, port=port, path=path) http_socket = THttpClient(uri, timeout, ssl_context_factory, http_header_factory) transport = trans_factory.get_transport(http_socket)