diff --git a/tests/integration/helpers.py b/tests/integration/helpers.py index db64ec13d..889561243 100644 --- a/tests/integration/helpers.py +++ b/tests/integration/helpers.py @@ -155,7 +155,6 @@ async def get_application_relation_data( # Filter the data based on the relation name. relation_data = [v for v in data[unit_name]["relation-info"] if v["endpoint"] == relation_name] - if relation_id: # Filter the data based on the relation id. relation_data = [v for v in relation_data if v["relation-id"] == relation_id] @@ -212,13 +211,15 @@ async def check_or_scale_app(ops_test: OpsTest, user_app_name: str, required_uni # check if we need to scale current_units = len(ops_test.model.applications[user_app_name].units) - if current_units > required_units: + if current_units == required_units: + return + elif current_units > required_units: for i in range(0, current_units): unit_to_remove = [ops_test.model.applications[user_app_name].units[i].name] await ops_test.model.destroy_units(*unit_to_remove) await ops_test.model.wait_for_idle() - - units_to_add = required_units - current_units + else: + units_to_add = required_units - current_units await ops_test.model.applications[user_app_name].add_unit(count=units_to_add) await ops_test.model.wait_for_idle() diff --git a/tests/integration/metrics_tests/test_metrics.py b/tests/integration/metrics_tests/test_metrics.py index 0a6a5b0bc..d7a9eec5c 100644 --- a/tests/integration/metrics_tests/test_metrics.py +++ b/tests/integration/metrics_tests/test_metrics.py @@ -32,7 +32,7 @@ async def test_build_and_deploy(ops_test: OpsTest) -> None: """Build and deploy one unit of MongoDB.""" app_name = await get_app_name(ops_test) if app_name: - return check_or_scale_app(ops_test, app_name, len(UNIT_IDS)) + return await check_or_scale_app(ops_test, app_name, len(UNIT_IDS)) if await get_app_name(ops_test): return diff --git a/tests/integration/test_charm.py b/tests/integration/test_charm.py index 82ed8ff7c..7ffb6b783 100644 --- a/tests/integration/test_charm.py +++ b/tests/integration/test_charm.py @@ -46,7 +46,7 @@ async def test_build_and_deploy(ops_test: OpsTest) -> None: # is a pre-existing cluster. app_name = await get_app_name(ops_test) if app_name: - return check_or_scale_app(ops_test, app_name, len(UNIT_IDS)) + return await check_or_scale_app(ops_test, app_name, len(UNIT_IDS)) my_charm = await ops_test.build_charm(".") await ops_test.model.deploy(my_charm, num_units=len(UNIT_IDS)) diff --git a/tests/integration/tls_tests/test_tls.py b/tests/integration/tls_tests/test_tls.py index cd0e4640b..8ef6c32b1 100644 --- a/tests/integration/tls_tests/test_tls.py +++ b/tests/integration/tls_tests/test_tls.py @@ -39,7 +39,7 @@ async def test_build_and_deploy(ops_test: OpsTest) -> None: # is a pre-existing cluster. app_name = await get_app_name(ops_test) if app_name: - check_or_scale_app(ops_test, app_name, len(UNIT_IDS)) + await check_or_scale_app(ops_test, app_name, len(UNIT_IDS)) else: app_name = DATABASE_APP_NAME async with ops_test.fast_forward():