diff --git a/mwdb/model/migrations/versions/72a94f88d2b6_oidc_group_referenced_by_id_instead_of_.py b/mwdb/model/migrations/versions/72a94f88d2b6_oidc_group_referenced_by_id_instead_of_.py new file mode 100644 index 00000000..3902fad3 --- /dev/null +++ b/mwdb/model/migrations/versions/72a94f88d2b6_oidc_group_referenced_by_id_instead_of_.py @@ -0,0 +1,78 @@ +"""OIDC group referenced by id instead of name + convert to non-workspace + +Revision ID: 72a94f88d2b6 +Revises: 6fc42e070495 +Create Date: 2024-08-20 13:29:36.839985 + +""" +import logging + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "72a94f88d2b6" +down_revision = "6fc42e070495" +branch_labels = None +depends_on = None + +group_helper = sa.Table( + "group", + sa.MetaData(), + sa.Column("id", sa.Integer()), + sa.Column("name", sa.String(32)), + sa.Column("private", sa.Boolean()), + sa.Column("default", sa.Boolean()), + sa.Column("workspace", sa.Boolean()), +) + +provider_helper = sa.Table( + "openid_provider", + sa.MetaData(), + sa.Column("id", sa.Integer()), + sa.Column("name", sa.String(64)), + sa.Column("group_id", sa.Integer()), +) + +logger = logging.getLogger("alembic") + + +def group_name_from_provider_name(provider_name): + return ("OpenID_" + provider_name)[:32] + + +def upgrade(): + connection = op.get_bind() + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("openid_provider", sa.Column("group_id", sa.Integer(), nullable=True)) + op.create_foreign_key(None, "openid_provider", "group", ["group_id"], ["id"]) + + # Migrate existing providers + for provider in connection.execute(provider_helper.select()): + group_name = group_name_from_provider_name(provider.name) + group = connection.execute( + group_helper.select().where(group_helper.c.name == group_name) + ).first() + connection.execute( + group_helper.update() + .where(group_helper.c.name == group_name) + .values(workspace=False) + ) + connection.execute( + provider_helper.update() + .where(provider_helper.c.id == provider.id) + .values(group_id=group.id) + ) + + # Set group_id as non-nullable + op.alter_column( + "openid_provider", "group_id", existing_type=sa.INTEGER(), nullable=False + ) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint(None, "openid_provider", type_="foreignkey") + op.drop_column("openid_provider", "group_id") + # ### end Alembic commands ### diff --git a/mwdb/model/oauth.py b/mwdb/model/oauth.py index b6d1f422..53ed624b 100644 --- a/mwdb/model/oauth.py +++ b/mwdb/model/oauth.py @@ -1,7 +1,4 @@ -from werkzeug.exceptions import NotFound - from mwdb.core.oauth import OpenIDClient -from mwdb.model import Group from . import db @@ -19,11 +16,17 @@ class OpenIDProvider(db.Model): jwks_endpoint = db.Column(db.Text, nullable=True) logout_endpoint = db.Column(db.Text, nullable=True) + group_id = db.Column(db.Integer, db.ForeignKey("group.id"), nullable=False) + identities = db.relationship( "OpenIDUserIdentity", back_populates="provider", cascade="all, delete-orphan", ) + group = db.relationship( + "Group", + cascade="all, delete", + ) def get_oidc_client(self): return OpenIDClient( @@ -39,12 +42,9 @@ def get_oidc_client(self): state=None, ) - def get_group(self): - group_name = ("OpenID_" + self.name)[:32] - group = db.session.query(Group).filter(Group.name == group_name).first() - if group is None: - raise NotFound("No such group") - return group + @property + def group_name(self): + return ("OpenID_" + self.name)[:32] class OpenIDUserIdentity(db.Model): diff --git a/mwdb/resources/oauth.py b/mwdb/resources/oauth.py index e6fa8c0e..c13f8b96 100644 --- a/mwdb/resources/oauth.py +++ b/mwdb/resources/oauth.py @@ -121,20 +121,22 @@ def post(self): logout_endpoint=logout_endpoint, ) - group_name = ("OpenID_" + obj["name"])[:32] - - group_name_obj = load_schema({"name": group_name}, GroupNameSchemaBase()) + group_name_obj = load_schema( + {"name": provider.group_name}, GroupNameSchemaBase() + ) if db.session.query( exists().where(Group.name == group_name_obj["name"]) ).scalar(): raise Conflict("Group exists yet, choose another provider name") - group = Group(name=group_name_obj["name"], immutable=True) - + group = Group(name=group_name_obj["name"], immutable=True, workspace=False) db.session.add(group) - db.session.add(provider) + db.session.flush() + db.session.refresh(group) + provider.group_id = group.id + db.session.add(provider) db.session.commit() hooks.on_created_group(group) @@ -301,15 +303,14 @@ def delete(self, provider_name): .filter(OpenIDProvider.name == provider_name) .first() ) + provider_group_name = provider.group_name if not provider: raise NotFound(f"Requested provider name '{provider_name}' not found") - group = provider.get_group() db.session.delete(provider) - db.session.delete(group) db.session.commit() - hooks.on_removed_group(("OpenID_" + provider_name)[:32]) + hooks.on_removed_group(provider_group_name) logger.info("Provider was deleted", extra={"provider": provider_name}) schema = OpenIDProviderSuccessResponseSchema() return schema.dump({"name": provider_name}) @@ -429,7 +430,7 @@ def post(self, provider_name): if not provider: raise NotFound(f"Requested provider name '{provider_name}' not found") - group = provider.get_group() + group = provider.group schema = OpenIDAuthorizeRequestSchema() obj = loads_schema(request.get_data(as_text=True), schema) @@ -564,7 +565,7 @@ def post(self, provider_name): if not provider: raise NotFound(f"Requested provider name '{provider_name}' not found") - group = provider.get_group() + group = provider.group schema = OpenIDAuthorizeRequestSchema() obj = loads_schema(request.get_data(as_text=True), schema)