diff --git a/app/dao/organisation_dao.py b/app/dao/organisation_dao.py index 8c2ef63ddd..06ed25958d 100644 --- a/app/dao/organisation_dao.py +++ b/app/dao/organisation_dao.py @@ -2,7 +2,14 @@ from app import db from app.dao.dao_utils import transactional, version_class -from app.models import Domain, InvitedOrganisationUser, Organisation, Service, User +from app.models import ( + Domain, + EmailBranding, + InvitedOrganisationUser, + Organisation, + Service, + User, +) def dao_get_organisations(): @@ -55,6 +62,10 @@ def dao_update_organisation(organisation_id, **kwargs): domains = kwargs.pop("domains", None) num_updated = Organisation.query.filter_by(id=organisation_id).update(kwargs) + if "email_branding_id" in kwargs: + email_brand = EmailBranding.query.filter_by(id=kwargs["email_branding_id"]).one() + org = Organisation.query.get(organisation_id) + org.email_branding = email_brand if isinstance(domains, list): Domain.query.filter_by(organisation_id=organisation_id).delete() diff --git a/app/job/rest.py b/app/job/rest.py index 950f1554c9..214eaea866 100644 --- a/app/job/rest.py +++ b/app/job/rest.py @@ -168,27 +168,27 @@ def create_job(service_id): if template.template_type == SMS_TYPE: # calculate the number of simulated recipients - numberOfSimulated = sum( - simulated_recipient(i["phone_number"].data, template.template_type) for i in list(recipient_csv.get_rows()) - ) - mixedRecipients = numberOfSimulated > 0 and numberOfSimulated != len(list(recipient_csv.get_rows())) + numberOfSimulated = sum(simulated_recipient(i["phone_number"].data, template.template_type) for i in recipient_csv.rows) + mixedRecipients = numberOfSimulated > 0 and numberOfSimulated != len(recipient_csv) # if they have specified testing and NON-testing recipients, raise an error if mixedRecipients: raise InvalidRequest(message="Bulk sending to testing and non-testing numbers is not supported", status_code=400) - is_test_notification = len(list(recipient_csv.get_rows())) == numberOfSimulated + is_test_notification = len(recipient_csv) == numberOfSimulated if not is_test_notification: check_sms_daily_limit(service, len(recipient_csv)) increment_sms_daily_count_send_warnings_if_needed(service, len(recipient_csv)) elif template.template_type == EMAIL_TYPE: - check_email_daily_limit(service, len(list(recipient_csv.get_rows()))) + notification_count = int(data.get("notification_count", len(recipient_csv))) + check_email_daily_limit(service, notification_count) + scheduled_for = datetime.fromisoformat(data.get("scheduled_for")) if data.get("scheduled_for") else None if scheduled_for is None or not scheduled_for.date() > datetime.today().date(): - increment_email_daily_count_send_warnings_if_needed(service, len(list(recipient_csv.get_rows()))) + increment_email_daily_count_send_warnings_if_needed(service, notification_count) data.update({"template_version": template.version}) diff --git a/app/models.py b/app/models.py index 704ccf798a..81767406a1 100644 --- a/app/models.py +++ b/app/models.py @@ -276,6 +276,10 @@ class EmailBranding(BaseModel): nullable=False, default=BRANDING_ORG_NEW, ) + organisation_id = db.Column( + UUID(as_uuid=True), db.ForeignKey("organisation.id", ondelete="SET NULL"), index=True, nullable=True + ) + organisation = db.relationship("Organisation", back_populates="email_branding", foreign_keys=[organisation_id]) def serialize(self) -> dict: serialized = { @@ -285,6 +289,7 @@ def serialize(self) -> dict: "name": self.name, "text": self.text, "brand_type": self.brand_type, + "organisation_id": str(self.organisation_id) if self.organisation_id else "", } return serialized @@ -449,10 +454,9 @@ class Organisation(BaseModel): "Domain", ) - email_branding = db.relationship("EmailBranding") + email_branding = db.relationship("EmailBranding", uselist=False) email_branding_id = db.Column( UUID(as_uuid=True), - db.ForeignKey("email_branding.id"), nullable=True, ) diff --git a/migrations/versions/0445_add_org_id_branding.py b/migrations/versions/0445_add_org_id_branding.py new file mode 100644 index 0000000000..0504d5f492 --- /dev/null +++ b/migrations/versions/0445_add_org_id_branding.py @@ -0,0 +1,46 @@ +""" +Revision ID: 0445_add_org_id_branding +Revises: 0444_add_index_n_history2.py +Create Date: 2024-02-27 +""" +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +revision = "0445_add_org_id_branding" +down_revision = "0444_add_index_n_history2" + + +def upgrade(): + op.add_column( + "email_branding", + sa.Column("organisation_id", postgresql.UUID(as_uuid=True), nullable=True), + ) + op.create_index( + op.f("ix_email_branding_organisation_id"), + "email_branding", + ["organisation_id"], + unique=False, + ) + op.create_foreign_key( + "fk_email_branding_organisation", + "email_branding", + "organisation", + ["organisation_id"], + ["id"], + ondelete="SET NULL", + ) + op.drop_constraint("fk_organisation_email_branding_id", "organisation", type_="foreignkey") + + +def downgrade(): + op.drop_index(op.f("ix_email_branding_organisation_id"), table_name="email_branding") + op.drop_constraint("fk_email_branding_organisation", "email_branding", type_="foreignkey") + op.drop_column("email_branding", "organisation_id") + op.create_foreign_key( + "fk_organisation_email_branding_id", + "organisation", + "email_branding", + ["email_branding_id"], + ["id"], + ) diff --git a/tests/app/email_branding/test_rest.py b/tests/app/email_branding/test_rest.py index c09218d62d..9d7bd1f6f7 100644 --- a/tests/app/email_branding/test_rest.py +++ b/tests/app/email_branding/test_rest.py @@ -4,8 +4,8 @@ from tests.app.db import create_email_branding -def test_get_email_branding_options(admin_request, notify_db, notify_db_session): - email_branding1 = EmailBranding(colour="#FFFFFF", logo="/path/image.png", name="Org1") +def test_get_email_branding_options(admin_request, notify_db, notify_db_session, sample_organisation): + email_branding1 = EmailBranding(colour="#FFFFFF", logo="/path/image.png", name="Org1", organisation_id=sample_organisation.id) email_branding2 = EmailBranding(colour="#000000", logo="/path/other.png", name="Org2") notify_db.session.add_all([email_branding1, email_branding2]) notify_db.session.commit() @@ -17,6 +17,8 @@ def test_get_email_branding_options(admin_request, notify_db, notify_db_session) str(email_branding1.id), str(email_branding2.id), } + assert email_branding[0]["organisation_id"] == str(sample_organisation.id) + assert email_branding[1]["organisation_id"] == "" def test_get_email_branding_by_id(admin_request, notify_db, notify_db_session): @@ -37,6 +39,7 @@ def test_get_email_branding_by_id(admin_request, notify_db, notify_db_session): "id", "text", "brand_type", + "organisation_id", } assert response["email_branding"]["colour"] == "#FFFFFF" assert response["email_branding"]["logo"] == "/path/image.png"