diff --git a/app/dao/email_branding_dao.py b/app/dao/email_branding_dao.py index d8738e9200..1ed90ae1e6 100644 --- a/app/dao/email_branding_dao.py +++ b/app/dao/email_branding_dao.py @@ -3,7 +3,9 @@ from app.models import EmailBranding -def dao_get_email_branding_options(): +def dao_get_email_branding_options(filter_by_organisation_id=None): + if filter_by_organisation_id: + return EmailBranding.query.filter_by(organisation_id=filter_by_organisation_id).all() return EmailBranding.query.all() diff --git a/app/email_branding/rest.py b/app/email_branding/rest.py index 3dc5086148..6ae95745be 100644 --- a/app/email_branding/rest.py +++ b/app/email_branding/rest.py @@ -20,7 +20,10 @@ @email_branding_blueprint.route("", methods=["GET"]) def get_email_branding_options(): - email_branding_options = [o.serialize() for o in dao_get_email_branding_options()] + filter_by_organisation_id = request.args.get("organisation_id", None) + email_branding_options = [ + o.serialize() for o in dao_get_email_branding_options(filter_by_organisation_id=filter_by_organisation_id) + ] return jsonify(email_branding=email_branding_options) diff --git a/tests/app/dao/test_email_branding_dao.py b/tests/app/dao/test_email_branding_dao.py index a69c912577..a3bc948a34 100644 --- a/tests/app/dao/test_email_branding_dao.py +++ b/tests/app/dao/test_email_branding_dao.py @@ -5,11 +5,12 @@ dao_update_email_branding, ) from app.models import EmailBranding -from tests.app.db import create_email_branding +from tests.app.db import create_email_branding, create_organisation def test_get_email_branding_options_gets_all_email_branding(notify_db, notify_db_session): - email_branding_1 = create_email_branding(name="test_email_branding_1") + org_1 = create_organisation() + email_branding_1 = create_email_branding(name="test_email_branding_1", organisation_id=org_1.id) email_branding_2 = create_email_branding(name="test_email_branding_2") email_branding = dao_get_email_branding_options() @@ -18,6 +19,13 @@ def test_get_email_branding_options_gets_all_email_branding(notify_db, notify_db assert email_branding_1 == email_branding[0] assert email_branding_2 == email_branding[1] + org_1_id = email_branding_1.organisation_id + + email_branding = dao_get_email_branding_options(filter_by_organisation_id=org_1_id) + assert len(email_branding) == 1 + assert email_branding_1 == email_branding[0] + assert email_branding[0].organisation_id == org_1_id + def test_get_email_branding_by_id_gets_correct_email_branding(notify_db, notify_db_session): email_branding = create_email_branding() diff --git a/tests/app/db.py b/tests/app/db.py index 1dacec37bc..c9ff33427c 100644 --- a/tests/app/db.py +++ b/tests/app/db.py @@ -502,13 +502,15 @@ def create_service_callback_api( return service_callback_api -def create_email_branding(colour="blue", logo="test_x2.png", name="test_org_1", text="DisplayName"): +def create_email_branding(colour="blue", logo="test_x2.png", name="test_org_1", text="DisplayName", organisation_id=None): data = { "colour": colour, "logo": logo, "name": name, "text": text, } + if organisation_id: + data["organisation_id"] = organisation_id email_branding = EmailBranding(**data) dao_create_email_branding(email_branding) diff --git a/tests/app/email_branding/test_rest.py b/tests/app/email_branding/test_rest.py index 9d7bd1f6f7..05e0b5a48e 100644 --- a/tests/app/email_branding/test_rest.py +++ b/tests/app/email_branding/test_rest.py @@ -21,6 +21,23 @@ def test_get_email_branding_options(admin_request, notify_db, notify_db_session, assert email_branding[1]["organisation_id"] == "" +def test_get_email_branding_options_filter_org(admin_request, notify_db, notify_db_session, sample_organisation): + email_branding1 = EmailBranding(colour="#FFFFFF", logo="/path/image.png", name="Org1", organisation_id=sample_organisation.id) + email_branding2 = EmailBranding(colour="#000000", logo="/path/other.png", name="Org2") + notify_db.session.add_all([email_branding1, email_branding2]) + notify_db.session.commit() + email_branding = admin_request.get("email_branding.get_email_branding_options", organisation_id=sample_organisation.id)[ + "email_branding" + ] + + assert len(email_branding) == 1 + assert email_branding[0]["organisation_id"] == str(sample_organisation.id) + + email_branding2 = admin_request.get("email_branding.get_email_branding_options")["email_branding"] + + assert len(email_branding2) == 2 + + def test_get_email_branding_by_id(admin_request, notify_db, notify_db_session): email_branding = EmailBranding(colour="#FFFFFF", logo="/path/image.png", name="Some Org", text="My Org") notify_db.session.add(email_branding)