diff --git a/sqlalchemy_exasol/base.py b/sqlalchemy_exasol/base.py index 4b024d1a..bd75cd7b 100644 --- a/sqlalchemy_exasol/base.py +++ b/sqlalchemy_exasol/base.py @@ -50,8 +50,7 @@ from decimal import Decimal from sqlalchemy import sql, schema, types as sqltypes, util, event from sqlalchemy.schema import AddConstraint, ForeignKeyConstraint -from sqlalchemy.engine import default, reflection, Engine, Connection -from sqlalchemy.orm.session import Session +from sqlalchemy.engine import default, reflection from sqlalchemy.sql import compiler from sqlalchemy.sql.elements import quoted_name from datetime import date, datetime @@ -435,7 +434,8 @@ def _get_default_schema_name(self, connection): def _get_schema_from_url(self, connection, schema): if connection.engine.url is not None and connection.engine.url != "": schema = self.normalize_name( - connection.engine.url.translate_connect_args().get('database')) + connection.engine.url.translate_connect_args().get('database') + ) return schema def normalize_name(self, name): @@ -459,8 +459,7 @@ def denormalize_name(self, name): """ if name is None or len(name) == 0: return None - elif name.lower() == name and \ - not self.identifier_preparer._requires_quotes(name.lower()): + elif name.lower() == name and not self.identifier_preparer._requires_quotes(name.lower()): name = name.upper() return name @@ -471,37 +470,15 @@ def on_connect(self): # TODO: set isolation level pass - def getODBCConnection(self, connection): - if isinstance(connection, Engine): - odbc_connection = connection.raw_connection().connection - elif isinstance(connection, Session): - odbc_connection = connection.connection() - return self.getODBCConnection(odbc_connection) - elif isinstance(connection, Connection): - odbc_connection = connection.connection.connection - else: - return None - if "pyodbc.Connection" in str(type(odbc_connection)): - return odbc_connection - else: - return None - - def use_sql_fallback(self, **kw): - result = "use_sql_fallback" in kw and kw.get("use_sql_fallback") == True - if result: - logger.warning("Using sql fallback instead of odbc functions") - return result + def _get_schema_names_query(self, connection, **kw): + return "select SCHEMA_NAME from SYS.EXA_SCHEMAS" # never called during reflection @reflection.cache def get_schema_names(self, connection, **kw): - if self.use_sql_fallback(**kw): - prefix = "/*snapshot execution*/ " - else: - prefix = "" - sql_stmnt = "%sselect SCHEMA_NAME from SYS.EXA_SCHEMAS" % prefix - rs = connection.execute(sql.text(sql_stmnt)) - return [self.normalize_name(row[0]) for row in rs] + sql_statement = self._get_schema_names_query(connection, **kw) + result = connection.execute(sql.text(sql_statement)) + return [self.normalize_name(row[0]) for row in result] def _get_schema_for_input_or_current(self, connection, schema): schema = self._get_schema_for_input(connection, schema) @@ -510,217 +487,118 @@ def _get_schema_for_input_or_current(self, connection, schema): return self.denormalize_name(schema) def _get_schema_for_input(self, connection, schema): - schema = self.denormalize_name(schema or self._get_schema_from_url(connection, schema)) - return schema + return self.denormalize_name( + schema or self._get_schema_from_url(connection, schema) + ) - def _get_current_schema(self, connection): - current_schema_stmnt = "SELECT CURRENT_SCHEMA" - current_schema = connection.execute(current_schema_stmnt).fetchone()[0] + @staticmethod + def _get_current_schema(connection): + sql_statement = "SELECT CURRENT_SCHEMA" + current_schema = connection.execute(sql_statement).fetchone()[0] return current_schema - @reflection.cache - def _get_tables_for_schema_odbc(self, connection, odbc_connection, schema, table_type=None, table_name=None, **kw): - schema = self._get_schema_for_input_or_current(connection, schema) - table_name = self.denormalize_name(table_name) - with odbc_connection.cursor().tables(schema=schema, tableType=table_type, table=table_name) as table_cursor: - rows = [row for row in table_cursor] - return rows - @reflection.cache def get_table_names(self, connection, schema, **kw): - odbc_connection = self.getODBCConnection(connection) - if odbc_connection is not None and not self.use_sql_fallback(**kw): - return self.get_table_names_odbc(connection, odbc_connection, schema, **kw) - else: - return self.get_table_names_sql(connection, schema, **kw) - - @reflection.cache - def get_table_names_odbc(self, connection, odbc_connection, schema, **kw): - tables = self._get_tables_for_schema_odbc(connection, odbc_connection, schema, table_type="TABLE", **kw) - normalized_tables = [self.normalize_name(row.table_name) - for row in tables] - return normalized_tables - - @reflection.cache - def get_table_names_sql(self, connection, schema, **kw): schema = self._get_schema_for_input(connection, schema) - sql_stmnt = "SELECT table_name FROM SYS.EXA_ALL_TABLES WHERE table_schema = " + sql_statement = "SELECT table_name FROM SYS.EXA_ALL_TABLES WHERE table_schema = " if schema is None: - sql_stmnt += "CURRENT_SCHEMA ORDER BY table_name" - rs = connection.execute(sql_stmnt) + sql_statement += "CURRENT_SCHEMA ORDER BY table_name" + result = connection.execute(sql_statement) else: - sql_stmnt += ":schema ORDER BY table_name" - rs = connection.execute(sql.text(sql_stmnt), - schema=self.denormalize_name(schema)) - tables = [self.normalize_name(row[0]) for row in rs] + sql_statement += ":schema ORDER BY table_name" + result = connection.execute( + sql.text(sql_statement), + schema=self.denormalize_name(schema) + ) + tables = [self.normalize_name(row[0]) for row in result] return tables def has_table(self, connection, table_name, schema=None, **kw): - odbc_connection = self.getODBCConnection(connection) - if odbc_connection is not None and not self.use_sql_fallback(**kw): - result=self.has_table_odbc(connection, odbc_connection, schema=schema, table_name=table_name, **kw) - else: - result=self.has_table_sql(connection, schema=schema, table_name=table_name, **kw) - return result - - def has_table_odbc(self, connection, odbc_connection, table_name, schema=None, **kw): - tables = self.get_table_names_odbc(connection=connection, - odbc_connection=odbc_connection, - schema=schema, table_name=table_name, **kw) - result = self.normalize_name(table_name) in tables - return result - - def has_table_sql(self, connection, table_name, schema=None, **kw): schema = self._get_schema_for_input(connection, schema) - sql_stmnt = "SELECT table_name from SYS.EXA_ALL_TABLES " \ - "WHERE table_name = :table_name " + sql_statement = ( + "SELECT table_name from SYS.EXA_ALL_TABLES " + "WHERE table_name = :table_name " + ) if schema is not None: - sql_stmnt += "AND table_schema = :schema" - rp = connection.execute( - sql.text(sql_stmnt), - table_name=self.denormalize_name(table_name), - schema=self.denormalize_name(schema)) - row = rp.fetchone() + sql_statement += "AND table_schema = :schema" - return (row is not None) + result = connection.execute( + sql.text(sql_statement), + table_name=self.denormalize_name(table_name), + schema=self.denormalize_name(schema) + ) + row = result.fetchone() + return row is not None @reflection.cache def get_view_names(self, connection, schema=None, **kw): - odbc_connection = self.getODBCConnection(connection) - if odbc_connection is not None and not self.use_sql_fallback(**kw): - return self.get_view_names_odbc(connection, odbc_connection, schema, **kw) - else: - return self.get_view_names_sql(connection, schema, **kw) - - @reflection.cache - def get_view_names_odbc(self, connection, odbc_connection, schema=None, **kw): - tables = self._get_tables_for_schema_odbc(connection, odbc_connection, schema, table_type="VIEW", **kw) - return [self.normalize_name(row.table_name) - for row in tables] - - def get_view_names_sql(self, connection, schema=None, **kw): schema = self._get_schema_for_input(connection, schema) - sql_stmnt = "SELECT view_name FROM SYS.EXA_ALL_VIEWS WHERE view_schema = " + sql_statement = "SELECT view_name FROM SYS.EXA_ALL_VIEWS WHERE view_schema = " if schema is None: - sql_stmnt += "CURRENT_SCHEMA ORDER BY view_name" - rs = connection.execute(sql.text(sql_stmnt)) + sql_statement += "CURRENT_SCHEMA ORDER BY view_name" + result = connection.execute(sql.text(sql_statement)) else: - sql_stmnt += ":schema ORDER BY view_name" - rs = connection.execute(sql.text(sql_stmnt), - schema=self.denormalize_name(schema)) - return [self.normalize_name(row[0]) for row in rs] + sql_statement += ":schema ORDER BY view_name" + result = connection.execute( + sql.text(sql_statement), + schema=self.denormalize_name(schema) + ) + return [self.normalize_name(row[0]) for row in result] @reflection.cache def get_view_definition(self, connection, view_name, schema=None, **kw): - odbc_connection = self.getODBCConnection(connection) - if odbc_connection is not None and not self.use_sql_fallback(**kw): - return self.get_view_definition_odbc(connection, odbc_connection, view_name, schema, **kw) - else: - return self.get_view_definition_sql(connection, view_name, schema, **kw) - - def quote_string_value(self, string_value): - return "'%s'" % (string_value.replace("'", "''")) - - @reflection.cache - def get_view_definition_odbc(self, connection, odbc_connection, view_name, schema=None, **kw): - if view_name is None: - return None - tables = self._get_tables_for_schema_odbc(connection, odbc_connection, schema, table_type="VIEW", - table_name=view_name, **kw) - if len(tables) == 1: - quoted_view_name_string = self.quote_string_value(tables[0][2]) - quoted_view_schema_string = self.quote_string_value(tables[0][1]) - sql_stmnt = \ - "/*snapshot execution*/ SELECT view_text FROM sys.exa_all_views WHERE view_name = {view_name} AND view_schema = {view_schema}".format( - view_name=quoted_view_name_string, view_schema=quoted_view_schema_string) - rp = connection.execute(sql.text(sql_stmnt)).scalar() - if rp: - return rp - else: - return None - else: - return None - - def get_view_definition_sql(self, connection, view_name, schema=None, **kw): schema = self._get_schema_for_input(connection, schema) sql_stmnt = "SELECT view_text FROM sys.exa_all_views WHERE view_name = :view_name AND view_schema = " if schema is None: sql_stmnt += "CURRENT_SCHEMA" else: sql_stmnt += ":schema" - rp = connection.execute(sql.text(sql_stmnt), - view_name=self.denormalize_name(view_name), - schema=self.denormalize_name(schema)).scalar() - if rp: - return rp - else: - return None - - def get_column_sql_query_str(self): - return "SELECT " \ - "column_name, " \ - "column_type, " \ - "column_maxsize, " \ - "column_num_prec, " \ - "column_num_scale, " \ - "column_is_nullable, " \ - "column_default, " \ - "column_identity, " \ - "column_is_distribution_key, " \ - "column_table " \ - "FROM sys.exa_all_columns " \ - "WHERE " \ - "column_object_type IN ('TABLE', 'VIEW') AND " \ - "column_schema = {schema} AND " \ - "column_table = {table} " \ - "ORDER BY column_ordinal_position" + result = connection.execute( + sql.text(sql_stmnt), + view_name=self.denormalize_name(view_name), + schema=self.denormalize_name(schema) + ).scalar() + return result if result else None - @reflection.cache - def _get_columns_odbc(self, connection, odbc_connection, table_name, schema, **kw): - tables = self._get_tables_for_schema_odbc(connection, odbc_connection, - schema=schema, table_name=table_name, **kw) - if len(tables) == 1: - # get_columns_sql originally returned all columns of all tables if table_name is None, - # we follow this behavior here for compatibility. However, the documentation for Dialects - # does not mentions this behavior: - # https://docs.sqlalchemy.org/en/13/core/internals.html#sqlalchemy.engine.interfaces.Dialect - quoted_schema_string = self.quote_string_value(tables[0].table_schem) - quoted_table_string = self.quote_string_value(tables[0].table_name) - sql_stmnt = \ - "/*snapshot execution*/ " + \ - self.get_column_sql_query_str() \ - .format(schema=quoted_schema_string, table=quoted_table_string) - rp = connection.execute(sql.text(sql_stmnt)) - return list(rp) - else: - return [] + @staticmethod + def quote_string_value(string_value): + return "'%s'" % (string_value.replace("'", "''")) - @reflection.cache - def _get_columns_sql(self, connection, table_name, schema=None, **kw): - schema = self._get_schema_for_input(connection, schema) - if schema is None: - schema_str = "CURRENT_SCHEMA" - else: - schema_str = ":schema" - table_name_str = ":table" - sql_stmnt = \ - self.get_column_sql_query_str() \ - .format(schema=schema_str, table=table_name_str) - stmnt = sql.text(sql_stmnt) - rp = connection.execute(stmnt, - schema=self.denormalize_name(schema), - table=self.denormalize_name(table_name)) + @staticmethod + def get_column_sql_query_str(): + return ( + "SELECT " + "column_name, " + "column_type, " + "column_maxsize, " + "column_num_prec, " + "column_num_scale, " + "column_is_nullable, " + "column_default, " + "column_identity, " + "column_is_distribution_key, " + "column_table " + "FROM sys.exa_all_columns " + "WHERE " + "column_object_type IN ('TABLE', 'VIEW') AND " + "column_schema = {schema} AND " + "column_table = {table} " + "ORDER BY column_ordinal_position" + ) - return list(rp) @reflection.cache def _get_columns(self, connection, table_name, schema=None, **kw): - odbc_connection = self.getODBCConnection(connection) - if odbc_connection is not None and not self.use_sql_fallback(**kw): - columns = self._get_columns_odbc(connection, odbc_connection, table_name, schema, **kw) - else: - columns = self._get_columns_sql(connection, table_name, schema, **kw) - return columns + schema = self._get_schema_for_input(connection, schema) + schema_str = "CURRENT_SCHEMA" if schema is None else ":schema" + table_name_str = ":table" + sql_statement = self.get_column_sql_query_str().format(schema=schema_str, table=table_name_str) + result = connection.execute( + sql.text(sql_statement), + schema=self.denormalize_name(schema), + table=self.denormalize_name(table_name) + ) + return list(result) @reflection.cache def get_columns(self, connection, table_name, schema=None, **kw): @@ -728,10 +606,12 @@ def get_columns(self, connection, table_name, schema=None, **kw): return [] columns = [] - rows = self._get_columns(connection, - table_name=table_name, - schema=schema, **kw) - table_name = self.denormalize_name(table_name) + rows = self._get_columns( + connection, + table_name=table_name, + schema=schema, + **kw + ) for row in rows: (colname, coltype, length, precision, scale, nullable, default, identity, is_distribution_key) = \ (row[0], row[1], row[2], row[3], row[4], row[5], row[6], row[7], row[8]) @@ -783,41 +663,28 @@ def get_columns(self, connection, table_name, schema=None, **kw): columns.append(cdict) return columns - def _get_constraint_sql_str(self, schema, table_name, contraint_type): - sql_stmnt = \ - "SELECT " \ - "constraint_name, " \ - "column_name, " \ - "referenced_schema, " \ - "referenced_table, " \ - "referenced_column, " \ - "constraint_table, " \ - "constraint_type " \ - "FROM SYS.EXA_ALL_CONSTRAINT_COLUMNS " \ - "WHERE " \ - "constraint_schema={schema} AND " \ - "constraint_table={table_name} AND " \ - "constraint_type='{contraint_type}' " \ - "ORDER BY ordinal_position" \ - .format(schema=schema, table_name=table_name, contraint_type=contraint_type) - return sql_stmnt + @staticmethod + def _get_constraint_sql_str(schema, table_name, contraint_type): + return ( + "SELECT " + "constraint_name, " + "column_name, " + "referenced_schema, " + "referenced_table, " + "referenced_column, " + "constraint_table, " + "constraint_type " + "FROM SYS.EXA_ALL_CONSTRAINT_COLUMNS " + "WHERE " + f"constraint_schema={schema} AND " + f"constraint_table={table_name} AND " + f"constraint_type='{contraint_type}' " + "ORDER BY ordinal_position" + ) - @reflection.cache - def get_pk_constraint_odbc(self, connection, odbc_connection, table_name, schema=None, **kw): - schema = self._get_schema_for_input_or_current(connection, schema) - table_name = self.denormalize_name(table_name) - with odbc_connection.cursor().primaryKeys(table=table_name, schema=schema) as primaryKeys_cursor: - pkeys = [] - constraint_name = None - for row in primaryKeys_cursor: - if row[2] != table_name and table_name is not None: - continue - pkeys.append(self.normalize_name(row[3])) - constraint_name = self.normalize_name(row[5]) - return {'constrained_columns': pkeys, 'name': constraint_name} @reflection.cache - def get_pk_constraint_sql(self, connection, table_name, schema=None, **kw): + def _get_pk_constraint(self, connection, table_name, schema, **kw): schema = self._get_schema_for_input(connection, schema) table_name = self.denormalize_name(table_name) table_name_string = ":table" @@ -825,13 +692,15 @@ def get_pk_constraint_sql(self, connection, table_name, schema=None, **kw): schema_string = "CURRENT_SCHEMA " else: schema_string = ":schema " - sql_stmnt=self._get_constraint_sql_str(schema_string, table_name_string, "PRIMARY KEY") - rp = connection.execute(sql.text(sql_stmnt), - schema=self.denormalize_name(schema), - table=table_name) + sql_statement = self._get_constraint_sql_str(schema_string, table_name_string, "PRIMARY KEY") + result = connection.execute( + sql.text(sql_statement), + schema=self.denormalize_name(schema), + table=table_name + ) pkeys = [] constraint_name = None - for row in list(rp): + for row in list(result): if (row[5] != table_name and table_name is not None) or row[6] != 'PRIMARY KEY': continue pkeys.append(self.normalize_name(row[1])) @@ -843,41 +712,21 @@ def get_pk_constraint_sql(self, connection, table_name, schema=None, **kw): def get_pk_constraint(self, connection, table_name, schema=None, **kw): if table_name is None: return None - odbc_connection = self.getODBCConnection(connection) - if odbc_connection is not None and not self.use_sql_fallback(**kw): - return self.get_pk_constraint_odbc(connection, odbc_connection, table_name=table_name, schema=schema, **kw) - else: - return self.get_pk_constraint_sql(connection, table_name=table_name, schema=schema, **kw) + return self._get_pk_constraint(connection, table_name, schema=schema, **kw) - @reflection.cache - def _get_foreign_keys_odbc(self, connection, odbc_connection, table_name, schema=None, **kw): - # Need to use a workaround, because SQLForeignKeys functions doesn't work for an unknown reason - tables = self._get_tables_for_schema_odbc(connection=connection, odbc_connection=odbc_connection, - schema=schema, table_name=table_name, table_type="TABLE", **kw) - if len(tables) > 0: - quoted_schema_string = self.quote_string_value(tables[0].table_schem) - quoted_table_string = self.quote_string_value(tables[0].table_name) - sql_stmnt = \ - "/*snapshot execution*/ " + \ - self._get_constraint_sql_str(quoted_schema_string,quoted_table_string,"FOREIGN KEY") - rp = connection.execute(sql.text(sql_stmnt)) - return list(rp) - else: - return [] @reflection.cache - def _get_foreign_keys_sql(self, connection, table_name, schema=None, **kw): + def _get_foreign_keys(self, connection, table_name, schema=None, **kw): table_name_string = ":table" - if schema is None: - schema_string = "CURRENT_SCHEMA " - else: - schema_string = ":schema " - sql_stmnt = \ - self._get_constraint_sql_str(schema_string, table_name_string, "FOREIGN KEY") - rp = connection.execute(sql.text(sql_stmnt), - schema=self.denormalize_name(schema), - table=self.denormalize_name(table_name)) - return list(rp) + schema_string = "CURRENT_SCHEMA " if schema is None else ":schema " + sql_statement = self._get_constraint_sql_str(schema_string, table_name_string, "FOREIGN KEY") + result = connection.execute( + sql.text(sql_statement), + schema=self.denormalize_name(schema), + table=self.denormalize_name(table_name) + ) + return list(result) + @reflection.cache def get_foreign_keys(self, connection, table_name, schema=None, **kw): @@ -895,13 +744,7 @@ def fkey_rec(): } fkeys = util.defaultdict(fkey_rec) - odbc_connection = self.getODBCConnection(connection) - if odbc_connection is not None and not self.use_sql_fallback(**kw): - constraints = self._get_foreign_keys_odbc(connection, odbc_connection, table_name=table_name, - schema=schema_int, **kw) - else: - constraints = self._get_foreign_keys_sql(connection, table_name=table_name, schema=schema_int, **kw) - table_name = self.denormalize_name(table_name) + constraints = self._get_foreign_keys(connection, table_name=table_name, schema=schema_int, **kw) for row in constraints: (cons_name, local_column, remote_schema, remote_table, remote_column) = \ (row[0], row[1], row[2], row[3], row[4]) @@ -926,5 +769,5 @@ def fkey_rec(): @reflection.cache def get_indexes(self, connection, table_name, schema=None, **kw): - # EXASolution has no explicit indexes + """ EXASolution has no explicit indexes""" return [] diff --git a/sqlalchemy_exasol/pyodbc.py b/sqlalchemy_exasol/pyodbc.py index dc3a1080..e5af6470 100644 --- a/sqlalchemy_exasol/pyodbc.py +++ b/sqlalchemy_exasol/pyodbc.py @@ -8,16 +8,20 @@ import re import sys +import logging from distutils.version import LooseVersion +from sqlalchemy import sql +from sqlalchemy.engine import reflection from sqlalchemy.connectors.pyodbc import PyODBCConnector from sqlalchemy.util.langhelpers import asbool from sqlalchemy_exasol.base import EXADialect, EXAExecutionContext +logger = logging.getLogger("sqlalchemy_exasol") -class EXADialect_pyodbc(EXADialect, PyODBCConnector): +class EXADialect_pyodbc(EXADialect, PyODBCConnector): execution_ctx_cls = EXAExecutionContext driver_version = None @@ -58,7 +62,6 @@ def _get_server_version_info(self, connection): return self.server_version_info if sys.platform == "darwin": - def connect(self, *cargs, **cparams): # Get connection conn = super().connect(*cargs, **cparams) @@ -164,5 +167,139 @@ def is_disconnect(self, e, connection, cursor): return super().is_disconnect(e, connection, cursor) + @staticmethod + def _is_sql_fallback_requested(**kwargs): + is_fallback_requested = kwargs.get("use_sql_fallback", False) + if is_fallback_requested: + logger.warning("Using sql fallback instead of odbc functions") + return is_fallback_requested + + @reflection.cache + def _tables_for_schema(self, connection, schema, table_type=None, table_name=None): + schema = self._get_schema_for_input_or_current(connection, schema) + table_name = self.denormalize_name(table_name) + conn = connection.engine.raw_connection() + with conn.cursor().tables(schema=schema, tableType=table_type, table=table_name) as table_cursor: + return [row for row in table_cursor] + + @reflection.cache + def get_view_definition(self, connection, view_name, schema=None, **kw): + if self._is_sql_fallback_requested(**kw): + return super().get_view_definition(connection, view_name, schema, **kw) + if view_name is None: + return None + + tables = self._tables_for_schema(connection, schema, table_type="VIEW", table_name=view_name) + if len(tables) != 1: + return None + + quoted_view_name_string = self.quote_string_value(tables[0][2]) + quoted_view_schema_string = self.quote_string_value(tables[0][1]) + sql_statement = ( + "/*snapshot execution*/ SELECT view_text " + f"FROM sys.exa_all_views WHERE view_name = {quoted_view_name_string} " + f"AND view_schema = {quoted_view_schema_string}" + ) + result = connection.execute(sql.text(sql_statement)).scalar() + return result if result else None + + @reflection.cache + def get_table_names(self, connection, schema, **kw): + if self._is_sql_fallback_requested(**kw): + return super().get_table_names(connection, schema, **kw) + tables = self._tables_for_schema(connection, schema, table_type="TABLE") + return [self.normalize_name(row.table_name) for row in tables] + + @reflection.cache + def get_view_names(self, connection, schema=None, **kw): + if self._is_sql_fallback_requested(**kw): + return super().get_view_names(connection, schema, **kw) + tables = self._tables_for_schema(connection, schema, table_type="VIEW") + return [self.normalize_name(row.table_name) for row in tables] + + def has_table(self, connection, table_name, schema=None, **kw): + if self._is_sql_fallback_requested(**kw): + return super().has_table(connection, table_name, schema, **kw) + tables = self.get_table_names( + connection=connection, + schema=schema, + table_name=table_name, + **kw + ) + return self.normalize_name(table_name) in tables + + def _get_schema_names_query(self, connection, **kw): + if self._is_sql_fallback_requested(**kw): + return super()._get_schema_names_query(connection, **kw) + return "/*snapshot execution*/ " + super()._get_schema_names_query(connection, **kw) + + @reflection.cache + def _get_columns(self, connection, table_name, schema=None, **kw): + if self._is_sql_fallback_requested(**kw): + return super()._get_columns(connection, table_name, schema, **kw) + + tables = self._tables_for_schema(connection, schema=schema, table_name=table_name) + if len(tables) != 1: + return [] + + # get_columns_sql originally returned all columns of all tables if table_name is None, + # we follow this behavior here for compatibility. However, the documentation for Dialects + # does not mention this behavior: + # https://docs.sqlalchemy.org/en/13/core/internals.html#sqlalchemy.engine.interfaces.Dialect + quoted_schema_string = self.quote_string_value(tables[0].table_schem) + quoted_table_string = self.quote_string_value(tables[0].table_name) + sql_statement = "/*snapshot execution*/ {query}".format(query=self.get_column_sql_query_str()) + sql_statement = sql_statement.format(schema=quoted_schema_string, table=quoted_table_string) + response = connection.execute(sql_statement) + + return list(response) + + @reflection.cache + def _get_pk_constraint(self, connection, table_name, schema=None, **kw): + if self._is_sql_fallback_requested(**kw): + return super()._get_pk_constraint(connection, table_name, schema, **kw) + + conn = connection.engine.raw_connection() + schema = self._get_schema_for_input_or_current(connection, schema) + table_name = self.denormalize_name(table_name) + with conn.cursor().primaryKeys(table=table_name, schema=schema) as cursor: + pkeys = [] + constraint_name = None + for row in cursor: + table, primary_key, constraint = row[2], row[3], row[5] + if table != table_name and table_name is not None: + continue + pkeys.append(self.normalize_name(primary_key)) + constraint_name = self.normalize_name(constraint) + return {'constrained_columns': pkeys, 'name': constraint_name} + + @reflection.cache + def _get_foreign_keys(self, connection, table_name, schema=None, **kw): + if self._is_sql_fallback_requested(**kw): + return super()._get_foreign_keys(connection, table_name, schema, **kw) + + # Need to use a workaround, because SQLForeignKeys functions doesn't work for an unknown reason + tables = self._tables_for_schema( + connection=connection, + schema=schema, + table_name=table_name, + table_type="TABLE" + ) + if len(tables) == 0: + return [] + + quoted_schema_string = self.quote_string_value(tables[0].table_schem) + quoted_table_string = self.quote_string_value(tables[0].table_name) + sql_statement = "/*snapshot execution*/ {query}".format( + query=self._get_constraint_sql_str( + quoted_schema_string, + quoted_table_string, + "FOREIGN KEY" + ) + ) + response = connection.execute(sql_statement) + + return list(response) + dialect = EXADialect_pyodbc diff --git a/test/test_deadlock.py b/test/test_deadlock.py index ab8a0add..2278d300 100644 --- a/test/test_deadlock.py +++ b/test/test_deadlock.py @@ -6,12 +6,15 @@ import pytest from sqlalchemy import create_engine from sqlalchemy.testing import fixtures, config +from sqlalchemy.engine.reflection import Inspector import sqlalchemy.testing as testing -from sqlalchemy_exasol.base import EXADialect -#TODO get_schema_names, get_view_names and get_view_definition didn't cause deadlocks in this scenario -@pytest.mark.skipif("turbodbc" in str(testing.db.url), reason="We currently don't support snapshot metadata requests for turbodbc") +# TODO: get_schema_names, get_view_names and get_view_definition didn't cause deadlocks in this scenario +@pytest.mark.skipif( + "pyodbc" not in str(testing.db.url), + reason="We currently only support snapshot metadata requests in pyodbc based dialect" +) class MetadataTest(fixtures.TablesTest): __backend__ = True @@ -25,14 +28,14 @@ def create_transaction(self, url, con_name): def test_no_deadlock_for_get_table_names_without_fallback(self): def without_fallback(session2, schema, table): - dialect = EXADialect() + dialect = Inspector(session2).dialect dialect.get_table_names(session2, schema=schema, use_sql_fallback=False) self.run_deadlock_for_table(without_fallback) def test_deadlock_for_get_table_names_with_fallback(self): def with_fallback(session2, schema, table): - dialect = EXADialect() + dialect = Inspector(session2).dialect dialect.get_table_names(session2, schema=schema, use_sql_fallback=True) with pytest.raises(Exception): @@ -40,7 +43,7 @@ def with_fallback(session2, schema, table): def test_no_deadlock_for_get_columns_without_fallback(self): def without_fallback(session2, schema, table): - dialect = EXADialect() + dialect = Inspector(session2).dialect dialect.get_columns(session2, schema=schema, table_name=table, use_sql_fallback=False) self.run_deadlock_for_table(without_fallback) @@ -48,36 +51,35 @@ def without_fallback(session2, schema, table): def test_no_deadlock_for_get_columns_with_fallback(self): # TODO: Doesnt produce a deadlock anymore since last commit? def with_fallback(session2, schema, table): - dialect = EXADialect() + dialect = Inspector(session2).dialect dialect.get_columns(session2, schema=schema, table_name=table, use_sql_fallback=True) self.run_deadlock_for_table(with_fallback) def test_no_deadlock_for_get_pk_constraint_without_fallback(self): def without_fallback(session2, schema, table): - dialect = EXADialect() + dialect = Inspector(session2).dialect dialect.get_pk_constraint(session2, table_name=table, schema=schema, use_sql_fallback=False) self.run_deadlock_for_table(without_fallback) def test_no_deadlock_for_get_pk_constraint_with_fallback(self): def with_fallback(session2, schema, table): - dialect = EXADialect() + dialect = Inspector(session2).dialect dialect.get_pk_constraint(session2, table_name=table, schema=schema, use_sql_fallback=True) self.run_deadlock_for_table(with_fallback) def test_no_deadlock_for_get_foreign_keys_without_fallback(self): def without_fallback(session2, schema, table): - dialect = EXADialect() + dialect = Inspector(session2).dialect dialect.get_foreign_keys(session2, table_name=table, schema=schema, use_sql_fallback=False) self.run_deadlock_for_table(without_fallback) - def test_no_deadlock_for_get_foreign_keys_with_fallback(self): def with_fallback(session2, schema, table): - dialect = EXADialect() + dialect = Inspector(session2).dialect dialect.get_foreign_keys(session2, table_name=table, schema=schema, use_sql_fallback=True) self.run_deadlock_for_table(with_fallback) @@ -85,7 +87,7 @@ def with_fallback(session2, schema, table): def test_no_deadlock_for_get_view_names_without_fallback(self): # TODO: think of other scenarios where metadata deadlocks with view could happen def without_fallback(session2, schema, table): - dialect = EXADialect() + dialect = Inspector(session2).dialect dialect.get_view_names(session2, table_name=table, schema=schema, use_sql_fallback=False) self.run_deadlock_for_table(without_fallback) @@ -93,7 +95,7 @@ def without_fallback(session2, schema, table): def test_no_deadlock_for_get_view_names_with_fallback(self): # TODO: think of other scenarios where metadata deadlocks with view could happen def with_fallback(session2, schema, table): - dialect = EXADialect() + dialect = Inspector(session2).dialect dialect.get_view_names(session2, table_name=table, schema=schema, use_sql_fallback=True) self.run_deadlock_for_table(with_fallback) diff --git a/test/test_get_metadata_functions.py b/test/test_get_metadata_functions.py index c171e9fd..c66dd87d 100644 --- a/test/test_get_metadata_functions.py +++ b/test/test_get_metadata_functions.py @@ -5,8 +5,10 @@ from sqlalchemy.engine.url import URL from sqlalchemy.sql.sqltypes import INTEGER, VARCHAR from sqlalchemy.testing import fixtures, config +from sqlalchemy.engine.reflection import Inspector from sqlalchemy_exasol.base import EXADialect +from sqlalchemy_exasol.pyodbc import EXADialect_pyodbc TEST_GET_METADATA_FUNCTIONS_SCHEMA = "test_get_metadata_functions_schema" ENGINE_NONE_DATABASE = "ENGINE_NONE_DATABASE" @@ -74,14 +76,14 @@ def create_engine_with_database_name(cls, connection, new_database_name): @pytest.mark.parametrize("engine_name", [ENGINE_NONE_DATABASE, ENGINE_SCHEMA_DATABASE, ENGINE_SCHEMA_2_DATABASE]) def test_get_schema_names(self, engine_name, use_sql_fallback): with self.engine_map[engine_name].begin() as c: - dialect = EXADialect() + dialect = Inspector(c).dialect schema_names = dialect.get_schema_names(connection=c, use_sql_fallback=use_sql_fallback) assert self.schema in schema_names and self.schema_2 in schema_names @pytest.mark.parametrize("engine_name", [ENGINE_NONE_DATABASE, ENGINE_SCHEMA_DATABASE, ENGINE_SCHEMA_2_DATABASE]) def test_compare_get_schema_names_for_sql_and_odbc(self, engine_name): with self.engine_map[engine_name].begin() as c: - dialect = EXADialect() + dialect = Inspector(c).dialect schema_names_fallback = dialect.get_schema_names(connection=c, use_sql_fallback=True) schema_names_odbc = dialect.get_schema_names(connection=c) assert sorted(schema_names_fallback) == sorted(schema_names_odbc) @@ -90,7 +92,7 @@ def test_compare_get_schema_names_for_sql_and_odbc(self, engine_name): @pytest.mark.parametrize("engine_name", [ENGINE_NONE_DATABASE, ENGINE_SCHEMA_DATABASE, ENGINE_SCHEMA_2_DATABASE]) def test_get_table_names(self, use_sql_fallback, engine_name): with self.engine_map[engine_name].begin() as c: - dialect = EXADialect() + dialect = Inspector(c).dialect table_names = dialect.get_table_names(connection=c, schema=self.schema, use_sql_fallback=use_sql_fallback) assert "t" in table_names and "s" in table_names @@ -100,7 +102,7 @@ def test_compare_get_table_names_for_sql_and_odbc(self, schema, engine_name): with self.engine_map[engine_name].begin() as c: if schema is None: c.execute("OPEN SCHEMA %s" % self.schema) - dialect = EXADialect() + dialect = Inspector(c).dialect table_names_fallback = dialect.get_table_names(connection=c, schema=schema, use_sql_fallback=True) table_names_odbc = dialect.get_table_names(connection=c, schema=schema) assert table_names_fallback == table_names_odbc @@ -109,7 +111,7 @@ def test_compare_get_table_names_for_sql_and_odbc(self, schema, engine_name): @pytest.mark.parametrize("engine_name", [ENGINE_NONE_DATABASE, ENGINE_SCHEMA_DATABASE, ENGINE_SCHEMA_2_DATABASE]) def test_has_table_table_exists(self, use_sql_fallback, engine_name): with self.engine_map[engine_name].begin() as c: - dialect = EXADialect() + dialect = Inspector(c).dialect has_table = dialect.has_table(connection=c, schema=self.schema, table_name="t", use_sql_fallback=use_sql_fallback) assert has_table, "Table %s.T was not found, but should exist" % self.schema @@ -118,7 +120,7 @@ def test_has_table_table_exists(self, use_sql_fallback, engine_name): @pytest.mark.parametrize("engine_name", [ENGINE_NONE_DATABASE, ENGINE_SCHEMA_DATABASE, ENGINE_SCHEMA_2_DATABASE]) def test_has_table_table_exists_not(self, use_sql_fallback, engine_name): with self.engine_map[engine_name].begin() as c: - dialect = EXADialect() + dialect = Inspector(c).dialect has_table = dialect.has_table(connection=c, schema=self.schema, table_name="not_exist", use_sql_fallback=use_sql_fallback) assert not has_table, "Table %s.not_exist was found, but should not exist" % self.schema @@ -127,7 +129,7 @@ def test_has_table_table_exists_not(self, use_sql_fallback, engine_name): @pytest.mark.parametrize("engine_name", [ENGINE_NONE_DATABASE, ENGINE_SCHEMA_DATABASE, ENGINE_SCHEMA_2_DATABASE]) def test_compare_has_table_for_sql_and_odbc(self, schema, engine_name): with self.engine_map[engine_name].begin() as c: - dialect = EXADialect() + dialect = Inspector(c).dialect has_table_fallback = dialect.has_table(connection=c, schema=schema, use_sql_fallback=True, table_name="t") has_table_odbc = dialect.has_table(connection=c, schema=schema, table_name="t") assert has_table_fallback == has_table_odbc, "Expected table %s.t with odbc and fallback" % schema @@ -136,7 +138,7 @@ def test_compare_has_table_for_sql_and_odbc(self, schema, engine_name): @pytest.mark.parametrize("engine_name", [ENGINE_NONE_DATABASE, ENGINE_SCHEMA_DATABASE, ENGINE_SCHEMA_2_DATABASE]) def test_get_view_names(self, use_sql_fallback,engine_name): with self.engine_map[engine_name].begin() as c: - dialect = EXADialect() + dialect = Inspector(c).dialect view_names = dialect.get_view_names(connection=c, schema=self.schema, use_sql_fallback=use_sql_fallback) assert "v" in view_names @@ -144,7 +146,7 @@ def test_get_view_names(self, use_sql_fallback,engine_name): @pytest.mark.parametrize("engine_name", [ENGINE_NONE_DATABASE, ENGINE_SCHEMA_DATABASE, ENGINE_SCHEMA_2_DATABASE]) def test_get_view_names_for_sys(self, use_sql_fallback, engine_name): with self.engine_map[engine_name].begin() as c: - dialect = EXADialect() + dialect = Inspector(c).dialect view_names = dialect.get_view_names(connection=c, schema="sys", use_sql_fallback=use_sql_fallback) assert len(view_names) == 0 @@ -152,7 +154,7 @@ def test_get_view_names_for_sys(self, use_sql_fallback, engine_name): @pytest.mark.parametrize("engine_name", [ENGINE_NONE_DATABASE, ENGINE_SCHEMA_DATABASE, ENGINE_SCHEMA_2_DATABASE]) def test_get_view_definition(self, use_sql_fallback,engine_name): with self.engine_map[engine_name].begin() as c: - dialect = EXADialect() + dialect = Inspector(c).dialect view_definition = dialect.get_view_definition(connection=c, schema=self.schema, view_name="v", use_sql_fallback=use_sql_fallback) assert self.view_defintion == view_definition @@ -161,7 +163,7 @@ def test_get_view_definition(self, use_sql_fallback,engine_name): @pytest.mark.parametrize("engine_name", [ENGINE_NONE_DATABASE, ENGINE_SCHEMA_DATABASE, ENGINE_SCHEMA_2_DATABASE]) def test_get_view_definition_view_name_none(self, use_sql_fallback,engine_name): with self.engine_map[engine_name].begin() as c: - dialect = EXADialect() + dialect = Inspector(c).dialect view_definition = dialect.get_view_definition(connection=c, schema=self.schema, view_name=None, use_sql_fallback=use_sql_fallback) assert view_definition is None @@ -170,7 +172,7 @@ def test_get_view_definition_view_name_none(self, use_sql_fallback,engine_name): @pytest.mark.parametrize("engine_name", [ENGINE_NONE_DATABASE, ENGINE_SCHEMA_DATABASE, ENGINE_SCHEMA_2_DATABASE]) def test_compare_get_view_names_for_sql_and_odbc(self, schema,engine_name): with self.engine_map[engine_name].begin() as c: - dialect = EXADialect() + dialect = Inspector(c).dialect c.execute("OPEN SCHEMA %s" % self.schema) view_names_fallback = dialect.get_view_names(connection=c, schema=schema, use_sql_fallback=True) view_names_odbc = dialect.get_view_names(connection=c, schema=schema) @@ -183,7 +185,7 @@ def test_compare_get_view_definition_for_sql_and_odbc(self, schema,engine_name): if schema is None: c.execute("OPEN SCHEMA %s" % self.schema) view_name = "v" - dialect = EXADialect() + dialect = Inspector(c).dialect view_definition_fallback = dialect.get_view_definition( connection=c, view_name=view_name, schema=schema, use_sql_fallback=True) view_definition_odbc = dialect.get_view_definition( @@ -195,9 +197,9 @@ def test_compare_get_view_definition_for_sql_and_odbc(self, schema,engine_name): @pytest.mark.parametrize("engine_name", [ENGINE_NONE_DATABASE, ENGINE_SCHEMA_DATABASE, ENGINE_SCHEMA_2_DATABASE]) def test_compare_get_columns_for_sql_and_odbc(self, schema, table, engine_name): with self.engine_map[engine_name].begin() as c: + dialect = Inspector(c).dialect if schema is None: c.execute("OPEN SCHEMA %s" % self.schema) - dialect = EXADialect() columns_fallback = dialect.get_columns(connection=c, table_name=table, schema=schema, use_sql_fallback=True) columns_odbc = dialect.get_columns(connection=c, table_name=table, schema=schema) assert str(columns_fallback) == str(columns_odbc) # object equality doesn't work for sqltypes @@ -208,7 +210,7 @@ def test_compare_get_columns_none_table_for_sql_and_odbc(self, schema, engine_na with self.engine_map[engine_name].begin() as c: if schema is None: c.execute("OPEN SCHEMA %s" % self.schema) - dialect = EXADialect() + dialect = Inspector(c).dialect table = None columns_fallback = dialect.get_columns(connection=c, table_name=table, schema=schema, use_sql_fallback=True) @@ -222,7 +224,7 @@ def make_columns_comparable(self, column_list): # object equality doesn't work @pytest.mark.parametrize("engine_name", [ENGINE_NONE_DATABASE, ENGINE_SCHEMA_DATABASE, ENGINE_SCHEMA_2_DATABASE]) def test_get_columns(self, use_sql_fallback, engine_name): with self.engine_map[engine_name].begin() as c: - dialect = EXADialect() + dialect = Inspector(c).dialect columns = dialect.get_columns(connection=c, schema=self.schema, table_name="t", use_sql_fallback=use_sql_fallback) expected = [{'default': None, @@ -253,7 +255,7 @@ def test_get_columns(self, use_sql_fallback, engine_name): @pytest.mark.parametrize("engine_name", [ENGINE_NONE_DATABASE, ENGINE_SCHEMA_DATABASE, ENGINE_SCHEMA_2_DATABASE]) def test_get_columns_table_name_none(self, use_sql_fallback, engine_name): with self.engine_map[engine_name].begin() as c: - dialect = EXADialect() + dialect = Inspector(c).dialect columns = dialect.get_columns(connection=c, schema=self.schema, table_name=None, use_sql_fallback=use_sql_fallback) assert columns == [] @@ -265,7 +267,7 @@ def test_compare_get_pk_constraint_for_sql_and_odbc(self, schema, table, engine_ with self.engine_map[engine_name].begin() as c: if schema is None: c.execute("OPEN SCHEMA %s" % self.schema) - dialect = EXADialect() + dialect = Inspector(c).dialect pk_constraint_fallback = dialect.get_pk_constraint(connection=c, table_name=table, schema=schema, use_sql_fallback=True) pk_constraint_odbc = dialect.get_pk_constraint(connection=c, table_name=table, schema=schema) @@ -275,7 +277,7 @@ def test_compare_get_pk_constraint_for_sql_and_odbc(self, schema, table, engine_ @pytest.mark.parametrize("engine_name", [ENGINE_NONE_DATABASE, ENGINE_SCHEMA_DATABASE, ENGINE_SCHEMA_2_DATABASE]) def test_get_pk_constraint(self, use_sql_fallback, engine_name): with self.engine_map[engine_name].begin() as c: - dialect = EXADialect() + dialect = Inspector(c).dialect pk_constraint = dialect.get_pk_constraint(connection=c, schema=self.schema, table_name="t", use_sql_fallback=use_sql_fallback) assert pk_constraint["constrained_columns"] == ['pid1', 'pid2'] and \ @@ -285,7 +287,7 @@ def test_get_pk_constraint(self, use_sql_fallback, engine_name): @pytest.mark.parametrize("engine_name", [ENGINE_NONE_DATABASE, ENGINE_SCHEMA_DATABASE, ENGINE_SCHEMA_2_DATABASE]) def test_get_pk_constraint_table_name_none(self, use_sql_fallback, engine_name): with self.engine_map[engine_name].begin() as c: - dialect = EXADialect() + dialect = Inspector(c).dialect pk_constraint = dialect.get_pk_constraint(connection=c, schema=self.schema, table_name=None, use_sql_fallback=use_sql_fallback) assert pk_constraint is None @@ -297,7 +299,7 @@ def test_compare_get_foreign_keys_for_sql_and_odbc(self, schema, table, engine_n with self.engine_map[engine_name].begin() as c: if schema is None: c.execute("OPEN SCHEMA %s" % self.schema_2) - dialect = EXADialect() + dialect = Inspector(c).dialect foreign_keys_fallback = dialect.get_foreign_keys(connection=c, table_name=table, schema=schema, use_sql_fallback=True) foreign_keys_odbc = dialect.get_foreign_keys(connection=c, table_name=table, schema=schema) @@ -307,7 +309,7 @@ def test_compare_get_foreign_keys_for_sql_and_odbc(self, schema, table, engine_n @pytest.mark.parametrize("engine_name", [ENGINE_NONE_DATABASE, ENGINE_SCHEMA_DATABASE, ENGINE_SCHEMA_2_DATABASE]) def test_get_foreign_keys(self, use_sql_fallback, engine_name): with self.engine_map[engine_name].begin() as c: - dialect = EXADialect() + dialect = Inspector(c).dialect foreign_keys = dialect.get_foreign_keys(connection=c, schema=self.schema, table_name="s", use_sql_fallback=use_sql_fallback) expected = [{'name': 'fk_test', @@ -322,7 +324,7 @@ def test_get_foreign_keys(self, use_sql_fallback, engine_name): @pytest.mark.parametrize("engine_name", [ENGINE_NONE_DATABASE, ENGINE_SCHEMA_DATABASE, ENGINE_SCHEMA_2_DATABASE]) def test_get_foreign_keys_table_name_none(self, use_sql_fallback, engine_name): with self.engine_map[engine_name].begin() as c: - dialect = EXADialect() + dialect = Inspector(c).dialect foreign_keys = dialect.get_foreign_keys(connection=c, schema=self.schema, table_name=None, use_sql_fallback=use_sql_fallback) assert foreign_keys == [] diff --git a/test/test_regression.py b/test/test_regression.py index 2f7076c5..78e9817f 100644 --- a/test/test_regression.py +++ b/test/test_regression.py @@ -1,29 +1,40 @@ """This module contains various regression test for issues which have been fixed in the past""" -from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine, schema +import pytest +from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine, inspect +from sqlalchemy.pool import ( + AssertionPool, + NullPool, + QueuePool, + SingletonThreadPool, + StaticPool, +) +from sqlalchemy.schema import CreateSchema, DropSchema from sqlalchemy.testing import fixtures from sqlalchemy.testing.fixtures import config -class RegressionTest(fixtures.TestBase): - def setUp(self): - self.table_name = "my_table" - self.tenant_schema_name = "tenant_schema" +class TranslateMap(fixtures.TestBase): + @classmethod + def setup_class(cls): + cls.table_name = "my_table" + cls.tenant_schema_name = "tenant_schema" engine = config.db with config.db.connect() as conn: - conn.execute(schema.CreateSchema(self.tenant_schema_name)) + conn.execute(CreateSchema(cls.tenant_schema_name)) metadata = MetaData() Table( - self.table_name, + cls.table_name, metadata, Column("id", Integer, primary_key=True), Column("name", String(1000), nullable=False), - schema=self.tenant_schema_name, + schema=cls.tenant_schema_name, ) metadata.create_all(engine) - def tearDown(self): + @classmethod + def teardown_class(cls): with config.db.connect() as conn: - conn.execute(schema.DropSchema(self.tenant_schema_name, cascade=True)) + conn.execute(DropSchema(cls.tenant_schema_name, cascade=True)) def test_use_schema_translate_map_in_get_last_row_id(self): """See also: https://github.com/exasol/sqlalchemy-exasol/issues/104""" @@ -40,3 +51,80 @@ def test_use_schema_translate_map_in_get_last_row_id(self): engine = create_engine(config.db.url, execution_options=options) with engine.connect() as conn: conn.execute(my_table.insert().values(name="John Doe")) + + +class Introspection(fixtures.TestBase): + """Regression(s) for issue: https://github.com/exasol/sqlalchemy-exasol/issues/136""" + + POOL_TYPES = [ + QueuePool, + NullPool, + AssertionPool, + StaticPool, + SingletonThreadPool, + ] + + @classmethod + def setup_class(cls): + def _create_tables(schema, tables): + engine = config.db + with engine.connect(): + metadata = MetaData() + for name in tables: + Table( + name, + metadata, + Column("id", Integer, primary_key=True), + Column("random_field", String(1000)), + schema=schema, + ) + metadata.create_all(engine) + + def _create_views(schema, views): + engine = config.db + with engine.connect() as conn: + for name in views: + conn.execute( + f"CREATE OR REPLACE VIEW {schema}.{name} AS SELECT 1 as COLUMN_1;" + ) + + cls.schema = "test" + cls.tables = ["a_table", "b_table", "c_table", "d_table"] + cls.views = ["a_view", "b_view"] + + _create_tables(cls.schema, cls.tables) + _create_views(cls.schema, cls.views) + + @classmethod + def teardown_class(cls): + engine = config.db + + def _drop_tables(schema): + metadata = MetaData(engine, schema=schema) + metadata.reflect() + to_be_deleted = [metadata.tables[name] for name in metadata.tables] + metadata.drop_all(engine, to_be_deleted) + + def _drop_views(schema, views): + with engine.connect() as conn: + for name in views: + conn.execute(f"DROP VIEW {schema}.{name};") + + _drop_tables(cls.schema) + _drop_views(cls.schema, cls.views) + + @pytest.mark.parametrize("pool_type", POOL_TYPES) + def test_introspection_of_tables_works_with(self, pool_type): + expected = self.tables + engine = create_engine(config.db.url, poolclass=pool_type) + inspector = inspect(engine) + tables = inspector.get_table_names(schema=self.schema) + assert expected == tables + + @pytest.mark.parametrize("pool_type", POOL_TYPES) + def test_introspection_of_views_works_with(self, pool_type): + expected = self.views + engine = create_engine(config.db.url, poolclass=pool_type) + inspector = inspect(engine) + tables = inspector.get_view_names(schema=self.schema) + assert expected == tables