Skip to content

Commit

Permalink
drop foreign key first before dropping column
Browse files Browse the repository at this point in the history
  • Loading branch information
daimor committed Apr 30, 2023
1 parent 4dfc236 commit 4f42588
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 25 deletions.
57 changes: 40 additions & 17 deletions sqlalchemy_iris/alembic.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import logging

from typing import Optional

from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.base import Executable
from sqlalchemy.sql.elements import ClauseElement

from alembic.ddl import DefaultImpl
from alembic.ddl.base import ColumnNullable
from alembic.ddl.base import ColumnType
Expand Down Expand Up @@ -56,25 +61,43 @@ def compare_server_default(
rendered_inspector_default,
)

def correct_for_autogen_constraints(
def drop_column(
self,
conn_unique_constraints,
conn_indexes,
metadata_unique_constraints,
metadata_indexes,
):

doubled_constraints = {
index
for index in conn_indexes
if index.info.get("duplicates_constraint")
}

for ix in doubled_constraints:
conn_indexes.remove(ix)
table_name: str,
column: Column,
schema: Optional[str] = None,
**kw,
) -> None:
column_name = column.name
fkeys = self.dialect.get_foreign_keys(self.connection, table_name, schema)
fkey = [
fkey["name"] for fkey in fkeys if column_name in fkey["constrained_columns"]
]
if len(fkey) == 1:
self._exec(_ExecDropForeignKey(table_name, fkey[0], schema))
super().drop_column(table_name, column, schema, **kw)


class _ExecDropForeignKey(Executable, ClauseElement):
inherit_cache = False

def __init__(
self, table_name: str, foreignkey_name: Column, schema: Optional[str]
) -> None:
self.table_name = table_name
self.foreignkey_name = foreignkey_name
self.schema = schema


@compiles(_ExecDropForeignKey, "iris")
def _exec_drop_foreign_key(
element: _ExecDropForeignKey, compiler: IRISDDLCompiler, **kw
) -> str:
return "%s DROP FOREIGN KEY %s" % (
alter_table(compiler, element.table_name, element.schema),
format_column_name(compiler, element.foreignkey_name),
)

# if not sqla_compat.sqla_2:
# self._skip_functional_indexes(metadata_indexes, conn_indexes)

@compiles(ColumnNullable, "iris")
def visit_column_nullable(
Expand Down
8 changes: 0 additions & 8 deletions sqlalchemy_iris/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1073,7 +1073,6 @@ def get_temp_table_names(self, connection, dblink=None, **kw):

@reflection.cache
def has_table(self, connection, table_name, schema=None, **kw):
self._ensure_has_table_connection(connection)
tables = ischema.tables
schema_name = self.get_schema(schema)

Expand All @@ -1085,14 +1084,7 @@ def has_table(self, connection, table_name, schema=None, **kw):
)
return bool(connection.execute(s).scalar())

def _default_or_error(self, connection, tablename, schema, method, **kw):
if self.has_table(connection, tablename, schema, **kw):
return method()
else:
raise exc.NoSuchTableError(f"{schema}.{tablename}")

def _get_all_objects(self, connection, schema, filter_names, scope, kind, **kw):
self._ensure_has_table_connection(connection)
tables = ischema.tables
schema_name = self.get_schema(schema)

Expand Down
47 changes: 47 additions & 0 deletions tests/test_alembic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,24 @@
except: # noqa
pass
else:
from sqlalchemy import MetaData
from sqlalchemy import Table
from sqlalchemy import inspect
from sqlalchemy import ForeignKey
from sqlalchemy import Column
from sqlalchemy import Integer
from sqlalchemy import text
from sqlalchemy.types import Text
from sqlalchemy.types import LargeBinary

from alembic import op
from alembic.testing import fixture
from alembic.testing import combinations
from alembic.testing import eq_
from alembic.testing.fixtures import TestBase
from alembic.testing.fixtures import op_fixture
from alembic.testing.suite._autogen_fixtures import AutogenFixtureTest

from alembic.testing.suite.test_op import (
BackendAlterColumnTest as _BackendAlterColumnTest,
)
Expand All @@ -23,3 +41,32 @@ def test_alter_column_autoincrement_pk_implicit_true(self):

def test_alter_column_autoincrement_pk_explicit_true(self):
pass

@combinations(
(None,),
("test",),
argnames="schema",
id_="s",
)
class RoundTripTest(TestBase):
@fixture
def tables(self, connection):
self.meta = MetaData()
self.meta.schema = self.schema
self.tbl_other = Table(
"other", self.meta, Column("oid", Integer, primary_key=True)
)
self.tbl = Table(
"round_trip_table",
self.meta,
Column("id", Integer, primary_key=True),
Column("oid_fk", ForeignKey("other.oid")),
)
self.meta.create_all(connection)
yield
self.meta.drop_all(connection)

def test_drop_col_with_fk(self, ops_context, connection, tables):
ops_context.drop_column("round_trip_table", "oid_fk", self.meta.schema)
insp = inspect(connection)
eq_(insp.get_foreign_keys("round_trip_table", schema=self.meta.schema), [])

0 comments on commit 4f42588

Please sign in to comment.