diff --git a/appium/webdriver/appium_connection.py b/appium/webdriver/appium_connection.py index 34cf5709..be7c7058 100644 --- a/appium/webdriver/appium_connection.py +++ b/appium/webdriver/appium_connection.py @@ -25,20 +25,14 @@ PREFIX_HEADER = 'appium/' -_HEADER_IDEMOTENCY_KEY = 'X-Idempotency-Key' +HEADER_IDEMOTENCY_KEY = 'X-Idempotency-Key' def _get_new_headers(key: str, headers: Dict[str, str]) -> Dict[str, str]: """Return a new dictionary of heafers without the given key. The key match is case-insensitive.""" - new_headers = dict() - key = key.lower() - for k, v in headers.items(): - if k.lower() == key: - continue - new_headers[k] = v - return new_headers + return {k: v for k, v in headers.items() if k.lower() != key} class AppiumConnection(RemoteConnection): @@ -63,8 +57,8 @@ def get_remote_connection_headers(cls, parsed_url: 'ParseResult', keep_alive: bo if parsed_url.path.endswith('/session'): # https://github.com/appium/appium-base-driver/pull/400 - cls.extra_headers[_HEADER_IDEMOTENCY_KEY] = str(uuid.uuid4()) + cls.extra_headers[HEADER_IDEMOTENCY_KEY] = str(uuid.uuid4()) else: - cls.extra_headers = _get_new_headers(_HEADER_IDEMOTENCY_KEY, cls.extra_headers) + cls.extra_headers = _get_new_headers(HEADER_IDEMOTENCY_KEY, cls.extra_headers) return {**super().get_remote_connection_headers(parsed_url, keep_alive=keep_alive), **cls.extra_headers} diff --git a/test/unit/webdriver/appium_connection_test.py b/test/unit/webdriver/appium_connection_test.py index 4338c130..c59303e7 100644 --- a/test/unit/webdriver/appium_connection_test.py +++ b/test/unit/webdriver/appium_connection_test.py @@ -29,3 +29,11 @@ def test_get_remote_connection_headers(self): ) self.assertIsNone(headers.get('X-Idempotency-Key')) self.assertEqual(headers.get('custom'), 'header') + + def test_remove_headers_case_insensitive(self): + for h in ['X-Idempotency-Key', 'X-idempotency-Key', 'x-idempotency-key']: + appium_connection.AppiumConnection.extra_headers = {h: 'value'} + appium_connection.AppiumConnection.get_remote_connection_headers( + parse.urlparse('http://http://127.0.0.1:4723/session/session_id') + ) + self.assertEqual(appium_connection.AppiumConnection.extra_headers, {})