Skip to content

Commit

Permalink
try add cognito unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
xiazhvera committed Jan 2, 2025
1 parent bb8bd88 commit b9c398a
Showing 1 changed file with 21 additions and 2 deletions.
23 changes: 21 additions & 2 deletions test/test_mqtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,13 @@ def has_custom_auth_environment():
class Config:
cache = None

def __init__(self, endpoint, cert, key, region, cognito_creds):
def __init__(self, endpoint, cert, key, region, cognito_creds, cognito_id):
self.cert = cert
self.key = key
self.endpoint = endpoint
self.region = region
self.cognito_creds = cognito_creds
self.cognito_id = cognito_id

@staticmethod
def get():
Expand Down Expand Up @@ -66,7 +67,7 @@ def get():
response = cognito.get_credentials_for_identity(IdentityId=cognito_id)
cognito_creds = response['Credentials']

Config.cache = Config(endpoint, cert, key, region, cognito_creds)
Config.cache = Config(endpoint, cert, key, region, cognito_creds, cognito_id)
except (botocore.exceptions.BotoCoreError, botocore.exceptions.ClientError) as ex:
print(ex)
raise unittest.SkipTest("No credentials")
Expand Down Expand Up @@ -157,6 +158,24 @@ def test_websockets_sts(self):
client_id=create_client_id(),
client_bootstrap=bootstrap)
self._test_connection(connection)

def test_websockets_cognito(self):
"""Websocket connection with X-Amz-Security-Token query param"""
config = Config.get()
elg = EventLoopGroup()
resolver = DefaultHostResolver(elg)
bootstrap = ClientBootstrap(elg, resolver)
cognito_endpoint = f"cognito-identity.{config.region}.amazonaws.com"
credentials_provider = auth.AwsCredentialsProvider.new_cognito(
endpoint=cognito_endpoint,
identity=config.cognito_id)
connection = mqtt_connection_builder.websockets_with_default_aws_signing(
region=config.region,
credentials_provider=cred_provider,
endpoint=config.endpoint,
client_id=create_client_id(),
client_bootstrap=bootstrap)
self._test_connection(connection)

@unittest.skipIf(PROXY_HOST is None, 'requires "proxyhost" and "proxyport" env vars')
def test_websockets_proxy(self):
Expand Down

0 comments on commit b9c398a

Please sign in to comment.