Skip to content

Commit

Permalink
Add an internal api so that we might revoke api keys (#1973)
Browse files Browse the repository at this point in the history
* Add an internal api so that we might revoke api keys

* Add tests for api keys revokation
  • Loading branch information
jzbahrai authored Sep 13, 2023
1 parent 3fb5cfe commit 69072a2
Show file tree
Hide file tree
Showing 6 changed files with 179 additions and 2 deletions.
103 changes: 101 additions & 2 deletions app/api_key/rest.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,20 @@
from flask import Blueprint, jsonify
from datetime import datetime

import werkzeug
from flask import Blueprint, current_app, jsonify, request

from app import DATETIME_FORMAT
from app.dao.api_key_dao import (
expire_api_key,
get_api_key_by_secret,
update_compromised_api_key_info,
)
from app.dao.fact_notification_status_dao import (
get_api_key_ranked_by_notifications_created,
get_last_send_for_api_key,
get_total_notifications_sent_for_api_key,
)
from app.errors import register_errors
from app.errors import InvalidRequest, register_errors

api_key_blueprint = Blueprint("api_key", __name__)
register_errors(api_key_blueprint)
Expand Down Expand Up @@ -59,3 +67,94 @@ def get_api_keys_ranked(n_days_back):
}
)
return jsonify(data=data)


def send_api_key_revokation_email(service_id, api_key_name, api_key_information):
"""
TODO: this function if not ready yet. It needs a template to be created.
email = email_data_request_schema.load(request.get_json())
users_to_send_to = dao_fetch_active_users_for_service(service_id)
template = dao_get_template_by_id(current_app.config["API_KEY_REVOKED_TEMPLATE_ID"]) # this template currently doesn't exist
service = Service.query.get(current_app.config["NOTIFY_SERVICE_ID"])
users_service = Service.query.get(service_id)
for user_to_send_to in users_to_send_to:
saved_notification = persist_notification(
template_id=template.id,
template_version=template.version,
recipient=email["email"],
service=service,
personalisation={
"user_name": user_to_send_to.name,
"api_key_name": api_key_name,
"service_name": users_service.name,
"api_key_information": api_key_information,
},
notification_type=template.template_type,
api_key_id=None,
key_type=KEY_TYPE_NORMAL,
reply_to_text=service.get_default_reply_to_email_address(),
)
send_notification_to_queue(saved_notification, False, queue=QueueNames.NOTIFY)
"""
return


@api_key_blueprint.route("/revoke-api-keys", methods=["POST"])
def revoke_api_keys():
"""
We take a list of api keys and revoke them. The data is of the form:
[
{
"token": "NMIfyYncKcRALEXAMPLE",
"type": "mycompany_api_token",
"url": "https://github.com/octocat/Hello-World/blob/12345600b9cbe38a219f39a9941c9319b600c002/foo/bar.txt",
"source": "content",
}
]
The function does 3 things:
1. Finds the api key by the token
2. Revokes the api key
3. Saves the source and url into the compromised_key_info field
4. Sends the service owners of the api key an email notification indicating that the key has been revoked
"""
try:
data = request.get_json()
except werkzeug.exceptions.BadRequest as errors:
raise InvalidRequest(errors, status_code=400)

# Step 1
for api_key_data in data:
try:
# take last 36 chars of string so that it works even if the full key is provided.
api_key_token = api_key_data["token"][-36:]
api_key = get_api_key_by_secret(api_key_token)
except Exception:
current_app.logger.error(f"API key not found for token {api_key_data['type']}")
continue # skip to next api key

# Step 2
expire_api_key(api_key.service_id, api_key.id)

current_app.logger.info("Expired api key {} for service {}".format(api_key.id, api_key.service_id))

# Step 3
update_compromised_api_key_info(
api_key.service_id,
api_key.id,
{
"time_of_revocation": str(datetime.utcnow()),
"type": api_key_data["type"],
"url": api_key_data["url"],
"source": api_key_data["source"],
},
)

# Step 4
send_api_key_revokation_email(api_key.service_id, api_key.name, api_key_data)

return jsonify(result="ok"), 201
8 changes: 8 additions & 0 deletions app/dao/api_key_dao.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,14 @@ def expire_api_key(service_id, api_key_id):
db.session.add(api_key)


@transactional
@version_class(ApiKey)
def update_compromised_api_key_info(service_id, api_key_id, compromised_info):
api_key = ApiKey.query.filter_by(id=api_key_id, service_id=service_id).one()
api_key.compromised_key_info = compromised_info
db.session.add(api_key)


def get_api_key_by_secret(secret):
signed_with_all_keys = signer_api_key.sign_with_all_keys(str(secret))
for signed_secret in signed_with_all_keys:
Expand Down
1 change: 1 addition & 0 deletions app/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -917,6 +917,7 @@ class ApiKey(BaseModel, Versioned):
)
created_by = db.relationship("User")
created_by_id = db.Column(UUID(as_uuid=True), db.ForeignKey("users.id"), index=True, nullable=False)
compromised_key_info = db.Column(JSONB(none_as_null=True), nullable=True, default={})

__table_args__ = (
Index(
Expand Down
26 changes: 26 additions & 0 deletions migrations/versions/0436_add_columns_api_keys.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
"""
Revision ID: 0436_add_columns_api_keys
Revises: 0435_update_email_templates_2.py
Create Date: 2023-09-01
"""
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects.postgresql import JSONB

revision = "0436_add_columns_api_keys"
down_revision = "0435_update_email_templates_2"

user = "postgres"
timeout = 60 # in seconds, i.e. 1 minute


def upgrade():
op.add_column("api_keys", sa.Column("compromised_key_info", JSONB, nullable=True))
op.add_column("api_keys_history", sa.Column("compromised_key_info", JSONB, nullable=True))


def downgrade():
op.drop_column("api_keys", "compromised_key_info")
op.drop_column("api_keys_history", "compromised_key_info")
22 changes: 22 additions & 0 deletions tests/app/api_key/test_rest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from datetime import datetime

from app import DATETIME_FORMAT
from app.dao.api_key_dao import get_api_key_by_secret, get_unsigned_secret
from app.models import KEY_TYPE_NORMAL
from tests.app.db import (
create_api_key,
Expand Down Expand Up @@ -75,3 +76,24 @@ def test_get_api_keys_ranked(admin_request, notify_db, notify_db_session):
assert api_keys_ranked[1]["email_notifications"] == total_sends
assert api_keys_ranked[1]["total_notifications"] == total_sends
assert "last_notification_created" in api_keys_ranked[0]


class TestApiKeyRevocation:
def test_revoke_api_keys(self, admin_request, notify_db, notify_db_session):
service = create_service(service_name="Service 1")
api_key_1 = create_api_key(service, key_type=KEY_TYPE_NORMAL, key_name="Key 1")
unsigned_secret = get_unsigned_secret(api_key_1.id)

admin_request.post(
"api_key.revoke_api_keys",
_data=[{"token": unsigned_secret, "type": "cds-tester", "url": "https://example.com", "source": "cds-tester"}],
_expected_status=201,
)

# Get api key from DB
api_key_1 = get_api_key_by_secret(api_key_1.secret)
assert api_key_1.expiry_date is not None
assert api_key_1.compromised_key_info["type"] == "cds-tester"
assert api_key_1.compromised_key_info["url"] == "https://example.com"
assert api_key_1.compromised_key_info["source"] == "cds-tester"
assert api_key_1.compromised_key_info["time_of_revocation"]
21 changes: 21 additions & 0 deletions tests/app/dao/test_api_key_dao.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
get_unsigned_secrets,
resign_api_keys,
save_model_api_key,
update_compromised_api_key_info,
)
from app.models import KEY_TYPE_NORMAL, ApiKey
from tests.app.db import create_api_key
Expand Down Expand Up @@ -59,6 +60,26 @@ def test_expire_api_key_should_update_the_api_key_and_create_history_record(noti
sorted_all_history[1].version = 2


def test_update_compromised_api_key_info_and_create_history_record(notify_api, sample_api_key):
update_compromised_api_key_info(
service_id=sample_api_key.service_id, api_key_id=sample_api_key.id, compromised_info={"key": "value"}
)
all_api_keys = get_model_api_keys(service_id=sample_api_key.service_id)
assert len(all_api_keys) == 1
assert all_api_keys[0].secret == sample_api_key.secret
assert all_api_keys[0].id == sample_api_key.id
assert all_api_keys[0].service_id == sample_api_key.service_id
assert all_api_keys[0].compromised_key_info == {"key": "value"}

all_history = sample_api_key.get_history_model().query.all()
assert len(all_history) == 2
assert all_history[0].id == sample_api_key.id
assert all_history[1].id == sample_api_key.id
sorted_all_history = sorted(all_history, key=lambda hist: hist.version)
sorted_all_history[0].version = 1
sorted_all_history[1].version = 2


def test_get_api_key_should_raise_exception_when_api_key_does_not_exist(sample_service, fake_uuid):
with pytest.raises(NoResultFound):
get_model_api_keys(sample_service.id, id=fake_uuid)
Expand Down

0 comments on commit 69072a2

Please sign in to comment.