diff --git a/integrations/aws/.port/spec.yaml b/integrations/aws/.port/spec.yaml index 32e91e02b2..15adc57093 100644 --- a/integrations/aws/.port/spec.yaml +++ b/integrations/aws/.port/spec.yaml @@ -43,6 +43,11 @@ configurations: require: false description: The number of concurrent accounts to scan. By default, it is set to 50. default: 50 + - name: assumeRoleDuration + type: integer + require: false + description: The duration in seconds for which the credentials are valid. By default, it is set to 3600 seconds. + default: 3600 deploymentMethodRequirements: - type: default configurations: ['awsAccessKeyId', 'awsSecretAccessKey'] diff --git a/integrations/aws/CHANGELOG.md b/integrations/aws/CHANGELOG.md index a9649a47de..3d21920597 100644 --- a/integrations/aws/CHANGELOG.md +++ b/integrations/aws/CHANGELOG.md @@ -7,6 +7,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 +## 0.2.79 (2025-01-07) + + +### Bug Fixes + +- Fixed a bug where token refresh failed because it was triggered while an active session was still using the old token. 
+ + ## 0.2.78 (2025-01-02) diff --git a/integrations/aws/aws/aws_credentials.py b/integrations/aws/aws/aws_credentials.py index 2184df3e09..42b3009c6a 100644 --- a/integrations/aws/aws/aws_credentials.py +++ b/integrations/aws/aws/aws_credentials.py @@ -1,21 +1,46 @@ -from typing import AsyncIterator, Optional, Iterable +from typing import AsyncIterator, Optional, Iterable, Dict, Any, Callable, Awaitable +import typing import aioboto3 +from aiobotocore.credentials import ( + AioRefreshableCredentials, +) +from aiobotocore.session import get_session +from loguru import logger + +from types_aiobotocore_sts import STSClient + +from datetime import datetime, timezone, timedelta class AwsCredentials: def __init__( self, account_id: str, - access_key_id: str, - secret_access_key: str, + access_key_id: Optional[str] = None, + sts_client: Optional[STSClient] = None, + secret_access_key: Optional[str] = None, session_token: Optional[str] = None, + role_arn: Optional[str] = None, + session_name: Optional[str] = None, + duration: Optional[float] = None, ): self.account_id = account_id self.access_key_id = access_key_id self.secret_access_key = secret_access_key - self.session_token = session_token self.enabled_regions: list[str] = [] self.default_regions: list[str] = [] + self.role_arn = role_arn + self.session_name = session_name + self.session_token = session_token + self.duration = duration + self.sts_client = sts_client + + def is_role(self) -> bool: + return self.role_arn is not None + + def expiry_time(self) -> str: + expiry = datetime.now(timezone.utc) + timedelta(seconds=self.duration or 3600) + return expiry.isoformat() async def update_enabled_regions(self) -> None: session = aioboto3.Session( @@ -33,14 +58,62 @@ async def update_enabled_regions(self) -> None: if region["RegionOptStatus"] == "ENABLED_BY_DEFAULT" ] - def is_role(self) -> bool: - return self.session_token is not None + def _create_refresh_function(self) -> Callable[[], Awaitable[Dict[str, Any]]]: + 
""" + Returns a callable that fetches new credentials when the current credentials are close to expiry. + """ + + async def refresh() -> Dict[str, Any]: + """ + Refreshes AWS credentials by re-assuming the role to get new credentials. + + :return: A dictionary containing the new credentials and their expiration time. + """ + sts_client = typing.cast(STSClient, self.sts_client) + response = await sts_client.assume_role( + RoleArn=str(self.role_arn), + RoleSessionName=str(self.session_name), + ) + credentials = response["Credentials"] + self.access_key_id = credentials["AccessKeyId"] + self.secret_access_key = credentials["SecretAccessKey"] + self.session_token = credentials["SessionToken"] + expiration = credentials["Expiration"].isoformat() + return { + "access_key": self.access_key_id, + "secret_key": self.secret_access_key, + "token": self.session_token, + "expiry_time": expiration, + } + + return refresh async def create_session(self, region: Optional[str] = None) -> aioboto3.Session: + """ + Create a session possibly using AioRefreshableCredentials for auto refresh if these are role-based credentials. 
+ """ if self.is_role(): - return aioboto3.Session( - self.access_key_id, self.secret_access_key, self.session_token, region + # For a role, use a refreshable credentials object + logger.debug( + f"Creating a refreshable session for role {self.role_arn} in account {self.account_id} for region {region}" + ) + + refresh_func = self._create_refresh_function() + + credentials = AioRefreshableCredentials.create_from_metadata( + metadata=await refresh_func(), + refresh_using=refresh_func, + method="sts-assume-role", ) + + botocore_session = get_session() + setattr(botocore_session, "_credentials", credentials) + if region: + botocore_session.set_config_variable("region", region) + + autorefresh_session = aioboto3.Session(botocore_session=botocore_session) + return autorefresh_session + else: return aioboto3.Session( aws_access_key_id=self.access_key_id, @@ -52,5 +125,6 @@ async def create_session_for_each_region( self, allowed_regions: Optional[Iterable[str]] = None ) -> AsyncIterator[aioboto3.Session]: regions = allowed_regions or self.enabled_regions + for region in regions: yield await self.create_session(region) diff --git a/integrations/aws/aws/session_manager.py b/integrations/aws/aws/session_manager.py index ff6d36cef5..eb4e277e97 100644 --- a/integrations/aws/aws/session_manager.py +++ b/integrations/aws/aws/session_manager.py @@ -15,7 +15,7 @@ class AccountNotFoundError(OceanAbortException): pass -ASSUME_ROLE_DURATION_SECONDS = 3600 # 1 hour +ASSUME_ROLE_DURATION_SECONDS = 900 # 1 hour class SessionManager: @@ -153,21 +153,20 @@ async def _assume_role_and_update_credentials( self, sts_client: STSClient, account: dict[str, Any] ) -> None: try: - account_role = await sts_client.assume_role( - RoleArn=f'arn:aws:iam::{account["Id"]}:role/{self._get_account_read_role_name()}', - RoleSessionName="OceanMemberAssumeRoleSession", - DurationSeconds=ASSUME_ROLE_DURATION_SECONDS, - ) - raw_credentials = account_role["Credentials"] + role_arn = 
f"arn:aws:iam::{account['Id']}:role/{self._get_account_read_role_name()}" + session_name = "OceanMemberAssumeRoleSession" + credentials = AwsCredentials( account_id=account["Id"], - access_key_id=raw_credentials["AccessKeyId"], - secret_access_key=raw_credentials["SecretAccessKey"], - session_token=raw_credentials["SessionToken"], + sts_client=sts_client, + role_arn=role_arn, + session_name=session_name, + duration=ASSUME_ROLE_DURATION_SECONDS, ) await credentials.update_enabled_regions() self._aws_credentials.append(credentials) self._aws_accessible_accounts.append(account) + except sts_client.exceptions.ClientError as e: if is_access_denied_exception(e): logger.info(f"Cannot assume role in account {account['Id']}. Skipping.") diff --git a/integrations/aws/main.py b/integrations/aws/main.py index d9edbaffe4..df150a5526 100644 --- a/integrations/aws/main.py +++ b/integrations/aws/main.py @@ -21,7 +21,7 @@ describe_accessible_accounts, get_accounts, get_sessions, - update_available_access_credentials, + initialize_access_credentials, validate_request, ) from port_ocean.context.ocean import ocean @@ -77,6 +77,9 @@ async def resync_resources_for_account( aws_resource_config = typing.cast(AWSResourceConfig, event.resource_config) if is_global_resource(kind): + logger.info( + f"Handling global resource {kind} for account {credentials.account_id}" + ) async for batch in _handle_global_resource_resync( kind, credentials, aws_resource_config ): @@ -106,7 +109,6 @@ async def resync_all(kind: str) -> ASYNC_GENERATOR_RESYNC_TYPE: if kind in iter(ResourceKindsWithSpecialHandling): return - await update_available_access_credentials() tasks = [ semaphore_async_iterator( semaphore, @@ -116,20 +118,17 @@ async def resync_all(kind: str) -> ASYNC_GENERATOR_RESYNC_TYPE: ] if tasks: async for batch in stream_async_iterators_tasks(*tasks): - await update_available_access_credentials() yield batch @ocean.on_resync(kind=ResourceKindsWithSpecialHandling.ACCOUNT) async def 
resync_account(kind: str) -> ASYNC_GENERATOR_RESYNC_TYPE: - await update_available_access_credentials() for account in describe_accessible_accounts(): yield [fix_unserializable_date_properties(account)] @ocean.on_resync(kind=ResourceKindsWithSpecialHandling.ELASTICACHE_CLUSTER) async def resync_elasticache(kind: str) -> ASYNC_GENERATOR_RESYNC_TYPE: - await update_available_access_credentials() aws_resource_config = typing.cast(AWSResourceConfig, event.resource_config) tasks = [ @@ -150,13 +149,11 @@ async def resync_elasticache(kind: str) -> ASYNC_GENERATOR_RESYNC_TYPE: ] if tasks: async for batch in stream_async_iterators_tasks(*tasks): - await update_available_access_credentials() yield batch @ocean.on_resync(kind=ResourceKindsWithSpecialHandling.ELBV2_LOAD_BALANCER) async def resync_elv2_load_balancer(kind: str) -> ASYNC_GENERATOR_RESYNC_TYPE: - await update_available_access_credentials() aws_resource_config = typing.cast(AWSResourceConfig, event.resource_config) tasks = [ @@ -178,13 +175,11 @@ async def resync_elv2_load_balancer(kind: str) -> ASYNC_GENERATOR_RESYNC_TYPE: if tasks: async for batch in stream_async_iterators_tasks(*tasks): - await update_available_access_credentials() yield batch @ocean.on_resync(kind=ResourceKindsWithSpecialHandling.ACM_CERTIFICATE) async def resync_acm(kind: str) -> ASYNC_GENERATOR_RESYNC_TYPE: - await update_available_access_credentials() aws_resource_config = typing.cast(AWSResourceConfig, event.resource_config) tasks = [ @@ -206,13 +201,11 @@ async def resync_acm(kind: str) -> ASYNC_GENERATOR_RESYNC_TYPE: if tasks: async for batch in stream_async_iterators_tasks(*tasks): - await update_available_access_credentials() yield batch @ocean.on_resync(kind=ResourceKindsWithSpecialHandling.AMI_IMAGE) async def resync_ami(kind: str) -> ASYNC_GENERATOR_RESYNC_TYPE: - await update_available_access_credentials() aws_resource_config = typing.cast(AWSResourceConfig, event.resource_config) tasks = [ @@ -234,13 +227,11 @@ async def 
resync_ami(kind: str) -> ASYNC_GENERATOR_RESYNC_TYPE: ] if tasks: async for batch in stream_async_iterators_tasks(*tasks): - await update_available_access_credentials() yield batch @ocean.on_resync(kind=ResourceKindsWithSpecialHandling.CLOUDFORMATION_STACK) async def resync_cloudformation(kind: str) -> ASYNC_GENERATOR_RESYNC_TYPE: - await update_available_access_credentials() aws_resource_config = typing.cast(AWSResourceConfig, event.resource_config) tasks = [ @@ -262,7 +253,6 @@ async def resync_cloudformation(kind: str) -> ASYNC_GENERATOR_RESYNC_TYPE: if tasks: async for batch in stream_async_iterators_tasks(*tasks): - await update_available_access_credentials() yield batch @@ -301,7 +291,6 @@ class ResourceUpdate(BaseModel): @ocean.router.post("/webhook") async def webhook(update: ResourceUpdate, response: Response) -> fastapi.Response: - await update_available_access_credentials() try: logger.info(f"Received AWS Webhook request body: {update}") resource_type = update.resource_type @@ -401,3 +390,16 @@ async def webhook(update: ResourceUpdate, response: Response) -> fastapi.Respons status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content=json.dumps({"ok": False, "error": str(e)}), ) + + +@ocean.on_start() +async def on_start() -> None: + logger.info("Starting Port Ocean AWS integration") + + if not ocean.integration_config.get("live_events_api_key"): + logger.warning( + "No live events api key provided" + "Without setting up the webhook, the integration will not export live changes from AWS" + ) + + await initialize_access_credentials() diff --git a/integrations/aws/pyproject.toml b/integrations/aws/pyproject.toml index e34c145729..bba63c20c1 100644 --- a/integrations/aws/pyproject.toml +++ b/integrations/aws/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "aws" -version = "0.2.78" +version = "0.2.79-rc1" description = "This integration will map all your resources in all the available accounts to your Port entities" authors = ["Shalev Avhar ", "Erik Zaadi 
"] diff --git a/integrations/aws/tests/utils/test_aws.py b/integrations/aws/tests/utils/test_aws.py index c74b16d1bd..e4d16b2681 100644 --- a/integrations/aws/tests/utils/test_aws.py +++ b/integrations/aws/tests/utils/test_aws.py @@ -1,44 +1,12 @@ import unittest from unittest.mock import AsyncMock, patch -from typing import AsyncGenerator, Any, List -from utils.aws import update_available_access_credentials, get_sessions, session_factory -from port_ocean.utils.async_iterators import stream_async_iterators_tasks +from typing import List +from utils.aws import get_sessions, session_factory from aws.aws_credentials import AwsCredentials from aws.session_manager import SessionManager from aioboto3 import Session -class TestUpdateAvailableAccessCredentials(unittest.IsolatedAsyncioTestCase): - """Test cases to simulate and handle the thundering herd problem in AWS credentials reset.""" - - @staticmethod - async def _run_update_access_iterator_result() -> AsyncGenerator[bool, None]: - result: bool = await update_available_access_credentials() - yield result - - @staticmethod - async def _create_iterator_tasks(func: Any, count: int) -> List[Any]: - """Helper to create async tasks.""" - return [func() for _ in range(count)] - - @patch("utils.aws._session_manager.reset", new_callable=AsyncMock) - @patch("utils.aws.lock", new_callable=AsyncMock) - async def test_multiple_task_execution( - self, mock_lock: AsyncMock, mock_reset: AsyncMock - ) -> None: - tasks: List[Any] = await self._create_iterator_tasks( - self._run_update_access_iterator_result, 10 - ) - async for result in stream_async_iterators_tasks(*tasks): - self.assertTrue(result) - - # Assert that the reset method was awaited exactly once (i.e., no thundering herd) - mock_reset.assert_awaited_once() - - mock_lock.__aenter__.assert_awaited_once() - mock_lock.__aexit__.assert_awaited_once() - - class TestAwsSessions(unittest.IsolatedAsyncioTestCase): def setUp(self) -> None: self.session_manager_mock: AsyncMock = 
patch( diff --git a/integrations/aws/utils/aws.py b/integrations/aws/utils/aws.py index 4de0a93a2b..47279b1d03 100644 --- a/integrations/aws/utils/aws.py +++ b/integrations/aws/utils/aws.py @@ -1,38 +1,21 @@ from typing import Any, AsyncIterator, Optional, Union import aioboto3 -from port_ocean.context.ocean import ocean from starlette.requests import Request -from aws.session_manager import SessionManager, ASSUME_ROLE_DURATION_SECONDS -from aws.aws_credentials import AwsCredentials - -from aiocache import cached, Cache # type: ignore -from asyncio import Lock - +from port_ocean.context.ocean import ocean from port_ocean.utils.async_iterators import stream_async_iterators_tasks -_session_manager: SessionManager = SessionManager() - -CACHE_DURATION_SECONDS = ( - 0.80 * ASSUME_ROLE_DURATION_SECONDS -) # Refresh role credentials after exhausting 80% of the session duration +from aws.aws_credentials import AwsCredentials +from aws.session_manager import SessionManager -lock = Lock() +_session_manager: SessionManager = SessionManager() -@cached(ttl=CACHE_DURATION_SECONDS, cache=Cache.MEMORY) -async def update_available_access_credentials() -> bool: - """ - Fetches the AWS account IDs that the current IAM role can access. - and saves them up to use as sessions - :return: List of AWS account IDs. - """ - async with lock: - await _session_manager.reset() - # makes this run once per DurationSeconds - return True +async def initialize_access_credentials() -> bool: + await _session_manager.reset() + return True def describe_accessible_accounts() -> list[dict[str, Any]]: @@ -49,7 +32,7 @@ async def get_accounts() -> AsyncIterator[AwsCredentials]: """ Gets the AWS account IDs that the current IAM role can access. """ - await update_available_access_credentials() + for credentials in _session_manager._aws_credentials: yield credentials @@ -78,7 +61,6 @@ async def get_sessions( """ Gets boto3 sessions for the AWS regions. 
""" - await update_available_access_credentials() if custom_account_id: credentials = _session_manager.find_credentials_by_account_id(custom_account_id)