diff --git a/sqlalchemy_exasol/pyodbc.py b/sqlalchemy_exasol/pyodbc.py index a6d91f49..69946b79 100644 --- a/sqlalchemy_exasol/pyodbc.py +++ b/sqlalchemy_exasol/pyodbc.py @@ -175,28 +175,13 @@ def _is_sql_fallback_requested(**kwargs): logger.warning("Using sql fallback instead of odbc functions") return is_fallback_requested - def _get_odbc_connection(self, connection): - if isinstance(connection, Engine): - odbc_connection = connection.raw_connection().connection - elif isinstance(connection, Session): - odbc_connection = connection.connection() - return self._get_odbc_connection(odbc_connection) - elif isinstance(connection, Connection): - odbc_connection = connection.connection.connection - else: - return None - if "pyodbc.Connection" in str(type(odbc_connection)): - return odbc_connection - else: - return None - @reflection.cache - def _get_tables_for_schema_odbc(self, connection, odbc_connection, schema, table_type=None, table_name=None, **kw): + def _get_tables_for_schema_odbc(self, connection, schema, table_type=None, table_name=None, **kw): schema = self._get_schema_for_input_or_current(connection, schema) table_name = self.denormalize_name(table_name) - with odbc_connection.cursor().tables(schema=schema, tableType=table_type, table=table_name) as table_cursor: - rows = [row for row in table_cursor] - return rows + conn = connection.engine.raw_connection() + with conn.cursor().tables(schema=schema, tableType=table_type, table=table_name) as table_cursor: + return [row for row in table_cursor] @reflection.cache def get_view_definition(self, connection, view_name, schema=None, **kw): @@ -205,9 +190,8 @@ def get_view_definition(self, connection, view_name, schema=None, **kw): if view_name is None: return None - odbc_connection = self._get_odbc_connection(connection) tables = self._get_tables_for_schema_odbc( - connection, odbc_connection, schema, + connection, schema, table_type="VIEW", table_name=view_name, **kw @@ -229,10 +213,9 @@ def get_view_definition(self, connection, view_name, schema=None, **kw): def get_table_names(self, connection, schema, **kw): if self._is_sql_fallback_requested(**kw): return super().get_table_names(connection, schema, **kw) - odbc_connection = self._get_odbc_connection(connection) tables = self._get_tables_for_schema_odbc( connection, - odbc_connection, + schema, table_type="TABLE", **kw ) @@ -242,10 +225,8 @@ def get_table_names(self, connection, schema, **kw): def get_view_names(self, connection, schema=None, **kw): if self._is_sql_fallback_requested(**kw): return super().get_view_names(connection, schema, **kw) - odbc_connection = self._get_odbc_connection(connection) tables = self._get_tables_for_schema_odbc( connection, - odbc_connection, schema, table_type="VIEW", **kw @@ -273,9 +254,8 @@ def _get_columns(self, connection, table_name, schema=None, **kw): if self._is_sql_fallback_requested(**kw): return super()._get_columns(connection, table_name, schema, **kw) - odbc_connection = self._get_odbc_connection(connection) tables = self._get_tables_for_schema_odbc( - connection, odbc_connection, + connection, schema=schema, table_name=table_name, **kw @@ -301,10 +281,10 @@ def _get_pk_constraint(self, connection, table_name, schema=None, **kw): if self._is_sql_fallback_requested(**kw): return super()._get_pk_constraint(connection, table_name, schema, **kw) - odbc_connection = self._get_odbc_connection(connection) + conn = connection.engine.raw_connection() schema = self._get_schema_for_input_or_current(connection, schema) table_name = self.denormalize_name(table_name) - with odbc_connection.cursor().primaryKeys(table=table_name, schema=schema) as cursor: + with conn.cursor().primaryKeys(table=table_name, schema=schema) as cursor: pkeys = [] constraint_name = None for row in cursor: @@ -320,11 +300,9 @@ def _get_foreign_keys(self, connection, table_name, schema=None, **kw): if self._is_sql_fallback_requested(**kw): return super()._get_foreign_keys(connection, table_name, schema, **kw) - odbc_connection = self._get_odbc_connection(connection) # Need to use a workaround, because SQLForeignKeys functions doesn't work for an unknown reason tables = self._get_tables_for_schema_odbc( connection=connection, - odbc_connection=odbc_connection, schema=schema, table_name=table_name, table_type="TABLE",