Skip to content

Commit

Permalink
View active and past/expired request (#578)
Browse files Browse the repository at this point in the history
* Adding organizations integrations

* Linking accounts to form

* Resolving additonal merge conflicts

* Adding dynomodb fields

* Making changes so that the request is properly inputted into the database

* Modifying the AWSREquest form

* Adding unit tests, formatting and linting

* Renaming file

* Adding comments to function

* Adding active requests

* Adding functionality to display pending and past requests

* Fixing up unit tests, linting, and fomratting files
  • Loading branch information
sylviamclaughlin authored Jul 16, 2024
1 parent 53d42fd commit ad8fbb4
Show file tree
Hide file tree
Showing 14 changed files with 16,999 additions and 260 deletions.
6 changes: 3 additions & 3 deletions app/jobs/scheduled_tasks.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from jobs.revoke_aws_sso_access import revoke_aws_sso_access
# from jobs.revoke_aws_sso_access import revoke_aws_sso_access
from jobs.notify_stale_incident_channels import notify_stale_incident_channels
import threading
import time
Expand All @@ -19,8 +19,8 @@ def init(bot):
schedule.every().day.at("16:00").do(
notify_stale_incident_channels, client=bot.client
)

schedule.every(10).seconds.do(revoke_aws_sso_access, client=bot.client)
# Commenting out the following line to avoid running the task every 10 seconds. Will be enabled at the time of deployment.
# schedule.every(10).seconds.do(revoke_aws_sso_access, client=bot.client)
schedule.every(5).minutes.do(scheduler_heartbeat)
schedule.every(5).minutes.do(integration_healthchecks)
schedule.every(2).hours.do(provision_aws_identity_center)
Expand Down
44 changes: 44 additions & 0 deletions app/modules/aws/aws_access_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,3 +108,47 @@ def get_expired_requests():
},
)
return response.get("Items", [])


def get_active_requests():
"""
Retrieves active requests from the DynamoDB table.
This function fetches records where the current time is less than the 'end_date_time' attribute,
indicating active requests.
Returns:
list: A list of active items from the DynamoDB table, or an empty list if none are found.
"""
# Get the current timestamp
current_timestamp = datetime.datetime.now().timestamp()

# Query to get records where current date time is less than end_date_time
response = client.scan(
TableName=table,
FilterExpression="end_date_time > :current_time",
ExpressionAttributeValues={":current_time": {"S": str(current_timestamp)}},
)
return response.get("Items", [])


def get_past_requests():
"""
Retrieves past requests from the DynamoDB table.
This function fetches records where the current time is greater than the 'end_date_time' attribute,
indicating past requests.
Returns:
list: A list of past items from the DynamoDB table, or an empty list if none are found.
"""
# Get the current timestamp
current_timestamp = datetime.datetime.now().timestamp()

# Query to get records where current date time is greater than end_date_time
response = client.scan(
TableName=table,
FilterExpression="end_date_time < :current_time",
ExpressionAttributeValues={":current_time": {"S": str(current_timestamp)}},
)
return response.get("Items", [])
31 changes: 31 additions & 0 deletions app/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from datetime import datetime, timezone, timedelta
from integrations.aws.organizations import get_active_account_names
from modules.aws.aws import request_aws_account_access
from modules.aws.aws_access_requests import get_active_requests, get_past_requests

logging.basicConfig(level=logging.INFO)
sns_message_validator = SNSMessageValidator()
Expand Down Expand Up @@ -332,6 +333,36 @@ async def get_accounts(request: Request, user: dict = Depends(get_current_user))
return get_active_account_names()


@handler.get("/active_requests")
@limiter.limit("5/minute")
async def get_aws_active_requests(
request: Request, user: dict = Depends(get_current_user)
):
"""
Retrieves the active access requests from the database.
Args:
request (Request): The HTTP request object.
Returns:
list: The list of active access requests.
"""
return get_active_requests()


@handler.get("/past_requests")
@limiter.limit("5/minute")
async def get_aws_past_requests(
request: Request, user: dict = Depends(get_current_user)
):
"""
Retrieves the past access requests from the database.
Args:
request (Request): The HTTP request object.
Returns:
list: The list of past access requests.
"""
return get_past_requests()


@handler.post("/hook/{id}")
@limiter.limit(
"30/minute"
Expand Down
1 change: 0 additions & 1 deletion app/server/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from jose import JWTError, jwt
from fastapi import HTTPException, status, Request


logging.basicConfig(level=logging.INFO)

load_dotenv()
Expand Down
102 changes: 102 additions & 0 deletions app/tests/modules/aws/test_aws_access_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,3 +193,105 @@ def test_get_expired_requests_returns_list_of_expired_requests(client_mock):
"expired": {"BOOL": True},
}
]


@patch("modules.aws.aws_access_requests.client")
def test_get_active_requests(mock_dynamodb_scan):
# Mock the current timestamp
current_timestamp = datetime.datetime(2024, 1, 1).timestamp()

# Define the mock response
mock_response = {
"Items": [
{"id": {"S": "123"}, "end_date_time": {"S": "1720830150.452"}},
{"id": {"S": "456"}, "end_date_time": {"S": "1720830150.999"}},
]
}
mock_dynamodb_scan.scan.return_value = mock_response

# Call the function
with patch("modules.aws.aws_access_requests.datetime") as mock_datetime:
mock_datetime.datetime.now.return_value = datetime.datetime(2024, 1, 1)
items = aws_access_requests.get_active_requests()

# Assertions
mock_dynamodb_scan.scan.assert_called_once_with(
TableName="aws_access_requests",
FilterExpression="end_date_time > :current_time",
ExpressionAttributeValues={":current_time": {"S": str(current_timestamp)}},
)
assert items == mock_response["Items"]


@patch("modules.aws.aws_access_requests.client")
def test_get_active_requests_empty(mock_dynamodb_scan):
# Mock the current timestamp
current_timestamp = datetime.datetime(2024, 1, 1).timestamp()
with patch("modules.aws.aws_access_requests.datetime") as mock_datetime:
mock_datetime.datetime.now.return_value = datetime.datetime(2024, 1, 1)

# Define the mock response
mock_response = {"Items": []}
mock_dynamodb_scan.scan.return_value = mock_response

# Call the function
items = aws_access_requests.get_active_requests()

# Assertions
mock_dynamodb_scan.scan.assert_called_once_with(
TableName="aws_access_requests",
FilterExpression="end_date_time > :current_time",
ExpressionAttributeValues={":current_time": {"S": str(current_timestamp)}},
)
assert items == mock_response["Items"]


@patch("modules.aws.aws_access_requests.client")
def test_get_past_requests(mock_dynamodb_scan):
# Mock the current timestamp
current_timestamp = datetime.datetime(2024, 1, 1).timestamp()
with patch("modules.aws.aws_access_requests.datetime") as mock_datetime:
mock_datetime.datetime.now.return_value = datetime.datetime(2024, 1, 1)

# Define the mock response
mock_response = {
"Items": [
{"id": {"S": "123"}, "end_date_time": {"S": "1720830150.452"}},
{"id": {"S": "456"}, "end_date_time": {"S": "1720830150.999"}},
]
}
mock_dynamodb_scan.scan.return_value = mock_response

# Call the function
items = aws_access_requests.get_past_requests()

# Assertions
mock_dynamodb_scan.scan.assert_called_once_with(
TableName="aws_access_requests",
FilterExpression="end_date_time < :current_time",
ExpressionAttributeValues={":current_time": {"S": str(current_timestamp)}},
)
assert items == mock_response["Items"]


@patch("modules.aws.aws_access_requests.client")
def test_get_past_requests_empty(mock_dynamodb_scan):
# Mock the current timestamp
current_timestamp = datetime.datetime(2024, 1, 1).timestamp()
with patch("modules.aws.aws_access_requests.datetime") as mock_datetime:
mock_datetime.datetime.now.return_value = datetime.datetime(2024, 1, 1)

# Define the mock response
mock_response = {"Items": []}
mock_dynamodb_scan.scan.return_value = mock_response

# Call the function
items = aws_access_requests.get_past_requests()

# Assertions
mock_dynamodb_scan.scan.assert_called_once_with(
TableName="aws_access_requests",
FilterExpression="end_date_time < :current_time",
ExpressionAttributeValues={":current_time": {"S": str(current_timestamp)}},
)
assert items == mock_response["Items"]
110 changes: 104 additions & 6 deletions app/tests/server/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
from starlette.responses import JSONResponse
from httpx import AsyncClient
from starlette.types import Scope
from starlette.datastructures import Headers
from starlette.datastructures import Headers, MutableHeaders
import os
import pytest
import datetime
from fastapi.testclient import TestClient
from fastapi import Request, HTTPException
from fastapi import Request, HTTPException, status

app = server.handler
app.add_middleware(bot_middleware.BotMiddleware, bot=MagicMock())
Expand Down Expand Up @@ -683,11 +683,17 @@ def more_than_24hours_dates_access_request():
)


def get_mock_request(session_data=None):
def get_mock_request(session_data=None, cookies=None):
headers = Headers({"content-type": "application/json"})
if cookies:
cookie_header = "; ".join([f"{key}={value}" for key, value in cookies.items()])
headers = MutableHeaders(headers)
headers.append("cookie", cookie_header)

scope: Scope = {
"type": "http",
"method": "POST",
"headers": Headers({"content-type": "application/json"}).raw,
"headers": headers.raw,
"path": "/request_access",
"raw_path": b"/request_access",
"session": session_data or {},
Expand Down Expand Up @@ -832,7 +838,8 @@ async def test_create_access_request_more_than_24_hours(
):
# Arrange
session_data = {"user": {"username": "test_user", "email": "[email protected]"}}
request = get_mock_request(session_data)
cookies = {"access_token": "mocked_jwt_token"}
request = get_mock_request(session_data, cookies)
mock_accounts = [
{
"Id": "345678901234",
Expand All @@ -849,7 +856,7 @@ async def test_create_access_request_more_than_24_hours(
mock_get_user_email_from_request.return_value = "[email protected]"
mock_get_user_id.return_value = "user_id_456"
mock_get_current_user.return_value = {"user": "test_user"}
mock_create_aws_access_request.return_value = False
mock_create_aws_access_request.return_value = True

# Act & Assert
with pytest.raises(HTTPException) as excinfo:
Expand Down Expand Up @@ -900,3 +907,94 @@ async def test_create_access_request_failure(
await server.create_access_request(request, valid_access_request)
assert excinfo.value.status_code == 500
assert excinfo.value.detail == "Failed to create access request"


@pytest.mark.asyncio
async def test_get_aws_active_requests_unauthenticated():
# Mock get_current_user to raise an HTTPException
with patch("modules.aws.aws_access_requests.get_active_requests"):
with patch(
"server.utils.get_current_user",
side_effect=HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated"
),
):
# Create an invalid JWT token
invalid_jwt_token = "invalid_jwt_token"

# Mock the cookie in the request
request = get_mock_request(cookies={"access_token": invalid_jwt_token})

# Call the dependency function directly to see if it raises an exception
with pytest.raises(HTTPException):
await server.get_current_user(request)

# If you need to test the actual endpoint, use the TestClient
response = client.get(
"/active_requests", cookies={"access_token": invalid_jwt_token}
)

# Assertions for the endpoint
assert response.status_code == 401
assert response.json() == {"detail": "Invalid token"}


@patch("server.utils.get_current_user", new_callable=AsyncMock)
@patch("modules.aws.aws_access_requests.client")
@patch("modules.aws.aws_access_requests.get_active_requests")
@pytest.mark.asyncio
async def test_get_aws_active_requests_success(
mock_get_active_requests, mock_dynamodbscan, mock_get_current_user
):
mock_get_current_user.return_value = {"username": "test_user"}

mock_response = [
{
"id": {"S": "123"},
"account_name": {"S": "ExampleAccount"},
"access_type": {"S": "read"},
"reason_for_access": {"S": "test_reason"},
"start_date_time": {"S": "1720820150.452"},
"end_date_time": {"S": "1720830150.452"},
"expired": {"BOOL": False},
},
{
"id": {"S": "456"},
"account_name": {"S": "ExampleAccount2"},
"access_type": {"S": "write"},
"reason_for_access": {"S": "test_reason2"},
"start_date_time": {"S": "1720820150.999"},
"end_date_time": {"S": "1720830150.999"},
"expired": {"BOOL": False},
},
]
mock_dynamo_response = {"Items": mock_response}
mock_dynamodbscan.scan.return_value = mock_dynamo_response

# Create a mock request with the cookie
request = get_mock_request(cookies={"access_token": "mocked_jwt_token"})

# Act
mock_get_active_requests.return_value = mock_response
response = await server.get_aws_active_requests(request)

# Assertions
assert response == mock_response


@pytest.mark.asyncio
async def test_get_aws_active_requests_exception_unauthenticated():
# Mock get_current_user to raise an HTTPException
with patch("modules.aws.aws_access_requests.get_active_requests"):
with patch(
"server.utils.get_current_user",
side_effect=HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated"
),
):
# Make the GET request
response = client.get("/active_requests")

# Assertions
assert response.status_code == 401
assert response.json() == {"detail": "Not authenticated"}
Loading

0 comments on commit ad8fbb4

Please sign in to comment.