diff --git a/app/api_key/rest.py b/app/api_key/rest.py index 376f3ed057..6d8fdd4301 100644 --- a/app/api_key/rest.py +++ b/app/api_key/rest.py @@ -1,12 +1,20 @@ -from flask import Blueprint, jsonify +from datetime import datetime + +import werkzeug +from flask import Blueprint, current_app, jsonify, request from app import DATETIME_FORMAT +from app.dao.api_key_dao import ( + expire_api_key, + get_api_key_by_secret, + update_compromised_api_key_info, +) from app.dao.fact_notification_status_dao import ( get_api_key_ranked_by_notifications_created, get_last_send_for_api_key, get_total_notifications_sent_for_api_key, ) -from app.errors import register_errors +from app.errors import InvalidRequest, register_errors api_key_blueprint = Blueprint("api_key", __name__) register_errors(api_key_blueprint) @@ -59,3 +67,94 @@ def get_api_keys_ranked(n_days_back): } ) return jsonify(data=data) + + +def send_api_key_revokation_email(service_id, api_key_name, api_key_information): + """ + TODO: this function if not ready yet. It needs a template to be created. + email = email_data_request_schema.load(request.get_json()) + + users_to_send_to = dao_fetch_active_users_for_service(service_id) + + template = dao_get_template_by_id(current_app.config["API_KEY_REVOKED_TEMPLATE_ID"]) # this template currently doesn't exist + service = Service.query.get(current_app.config["NOTIFY_SERVICE_ID"]) + users_service = Service.query.get(service_id) + for user_to_send_to in users_to_send_to: + saved_notification = persist_notification( + template_id=template.id, + template_version=template.version, + recipient=email["email"], + service=service, + personalisation={ + "user_name": user_to_send_to.name, + "api_key_name": api_key_name, + "service_name": users_service.name, + "api_key_information": api_key_information, + }, + notification_type=template.template_type, + api_key_id=None, + key_type=KEY_TYPE_NORMAL, + reply_to_text=service.get_default_reply_to_email_address(), + ) + + send_notification_to_queue(saved_notification, False, queue=QueueNames.NOTIFY) + + """ + return + + +@api_key_blueprint.route("/revoke-api-keys", methods=["POST"]) +def revoke_api_keys(): + """ + We take a list of api keys and revoke them. The data is of the form: + [ + { + "token": "NMIfyYncKcRALEXAMPLE", + "type": "mycompany_api_token", + "url": "https://github.com/octocat/Hello-World/blob/12345600b9cbe38a219f39a9941c9319b600c002/foo/bar.txt", + "source": "content", + } + ] + + The function does 3 things: + 1. Finds the api key by the token + 2. Revokes the api key + 3. Saves the source and url into the compromised_key_info field + 4. Sends the service owners of the api key an email notification indicating that the key has been revoked + """ + try: + data = request.get_json() + except werkzeug.exceptions.BadRequest as errors: + raise InvalidRequest(errors, status_code=400) + + # Step 1 + for api_key_data in data: + try: + # take last 36 chars of string so that it works even if the full key is provided. + api_key_token = api_key_data["token"][-36:] + api_key = get_api_key_by_secret(api_key_token) + except Exception: + current_app.logger.error(f"API key not found for token {api_key_data['type']}") + continue # skip to next api key + + # Step 2 + expire_api_key(api_key.service_id, api_key.id) + + current_app.logger.info("Expired api key {} for service {}".format(api_key.id, api_key.service_id)) + + # Step 3 + update_compromised_api_key_info( + api_key.service_id, + api_key.id, + { + "time_of_revocation": str(datetime.utcnow()), + "type": api_key_data["type"], + "url": api_key_data["url"], + "source": api_key_data["source"], + }, + ) + + # Step 4 + send_api_key_revokation_email(api_key.service_id, api_key.name, api_key_data) + + return jsonify(result="ok"), 201 diff --git a/app/dao/api_key_dao.py b/app/dao/api_key_dao.py index 3206be5474..679467f4f4 100644 --- a/app/dao/api_key_dao.py +++ b/app/dao/api_key_dao.py @@ -68,6 +68,14 @@ def expire_api_key(service_id, api_key_id): db.session.add(api_key) +@transactional +@version_class(ApiKey) +def update_compromised_api_key_info(service_id, api_key_id, compromised_info): + api_key = ApiKey.query.filter_by(id=api_key_id, service_id=service_id).one() + api_key.compromised_key_info = compromised_info + db.session.add(api_key) + + def get_api_key_by_secret(secret): signed_with_all_keys = signer_api_key.sign_with_all_keys(str(secret)) for signed_secret in signed_with_all_keys: diff --git a/app/models.py b/app/models.py index 3f63794374..fd640e1473 100644 --- a/app/models.py +++ b/app/models.py @@ -917,6 +917,7 @@ class ApiKey(BaseModel, Versioned): ) created_by = db.relationship("User") created_by_id = db.Column(UUID(as_uuid=True), db.ForeignKey("users.id"), index=True, nullable=False) + compromised_key_info = db.Column(JSONB(none_as_null=True), nullable=True, default={}) __table_args__ = ( Index( diff --git a/migrations/versions/0436_add_columns_api_keys.py b/migrations/versions/0436_add_columns_api_keys.py new file mode 100644 index 0000000000..44eb063048 --- /dev/null +++ b/migrations/versions/0436_add_columns_api_keys.py @@ -0,0 +1,26 @@ +""" + +Revision ID: 0436_add_columns_api_keys +Revises: 0435_update_email_templates_2.py +Create Date: 2023-09-01 + +""" +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects.postgresql import JSONB + +revision = "0436_add_columns_api_keys" +down_revision = "0435_update_email_templates_2" + +user = "postgres" +timeout = 60 # in seconds, i.e. 1 minute + + +def upgrade(): + op.add_column("api_keys", sa.Column("compromised_key_info", JSONB, nullable=True)) + op.add_column("api_keys_history", sa.Column("compromised_key_info", JSONB, nullable=True)) + + +def downgrade(): + op.drop_column("api_keys", "compromised_key_info") + op.drop_column("api_keys_history", "compromised_key_info") diff --git a/tests/app/api_key/test_rest.py b/tests/app/api_key/test_rest.py index 9c1e363098..669d2a32ee 100644 --- a/tests/app/api_key/test_rest.py +++ b/tests/app/api_key/test_rest.py @@ -1,6 +1,7 @@ from datetime import datetime from app import DATETIME_FORMAT +from app.dao.api_key_dao import get_api_key_by_secret, get_unsigned_secret from app.models import KEY_TYPE_NORMAL from tests.app.db import ( create_api_key, @@ -75,3 +76,24 @@ def test_get_api_keys_ranked(admin_request, notify_db, notify_db_session): assert api_keys_ranked[1]["email_notifications"] == total_sends assert api_keys_ranked[1]["total_notifications"] == total_sends assert "last_notification_created" in api_keys_ranked[0] + + +class TestApiKeyRevocation: + def test_revoke_api_keys(self, admin_request, notify_db, notify_db_session): + service = create_service(service_name="Service 1") + api_key_1 = create_api_key(service, key_type=KEY_TYPE_NORMAL, key_name="Key 1") + unsigned_secret = get_unsigned_secret(api_key_1.id) + + admin_request.post( + "api_key.revoke_api_keys", + _data=[{"token": unsigned_secret, "type": "cds-tester", "url": "https://example.com", "source": "cds-tester"}], + _expected_status=201, + ) + + # Get api key from DB + api_key_1 = get_api_key_by_secret(api_key_1.secret) + assert api_key_1.expiry_date is not None + assert api_key_1.compromised_key_info["type"] == "cds-tester" + assert api_key_1.compromised_key_info["url"] == "https://example.com" + assert api_key_1.compromised_key_info["source"] == "cds-tester" + assert api_key_1.compromised_key_info["time_of_revocation"] diff --git a/tests/app/dao/test_api_key_dao.py b/tests/app/dao/test_api_key_dao.py index fb5846ba1a..9535ce0343 100644 --- a/tests/app/dao/test_api_key_dao.py +++ b/tests/app/dao/test_api_key_dao.py @@ -13,6 +13,7 @@ get_unsigned_secrets, resign_api_keys, save_model_api_key, + update_compromised_api_key_info, ) from app.models import KEY_TYPE_NORMAL, ApiKey from tests.app.db import create_api_key @@ -59,6 +60,26 @@ def test_expire_api_key_should_update_the_api_key_and_create_history_record(noti sorted_all_history[1].version = 2 +def test_update_compromised_api_key_info_and_create_history_record(notify_api, sample_api_key): + update_compromised_api_key_info( + service_id=sample_api_key.service_id, api_key_id=sample_api_key.id, compromised_info={"key": "value"} + ) + all_api_keys = get_model_api_keys(service_id=sample_api_key.service_id) + assert len(all_api_keys) == 1 + assert all_api_keys[0].secret == sample_api_key.secret + assert all_api_keys[0].id == sample_api_key.id + assert all_api_keys[0].service_id == sample_api_key.service_id + assert all_api_keys[0].compromised_key_info == {"key": "value"} + + all_history = sample_api_key.get_history_model().query.all() + assert len(all_history) == 2 + assert all_history[0].id == sample_api_key.id + assert all_history[1].id == sample_api_key.id + sorted_all_history = sorted(all_history, key=lambda hist: hist.version) + sorted_all_history[0].version = 1 + sorted_all_history[1].version = 2 + + def test_get_api_key_should_raise_exception_when_api_key_does_not_exist(sample_service, fake_uuid): with pytest.raises(NoResultFound): get_model_api_keys(sample_service.id, id=fake_uuid)