Skip to content

Commit

Permalink
Fix bug #136
Browse files Browse the repository at this point in the history
  • Loading branch information
Nicoretti committed May 17, 2022
1 parent 17937e4 commit 5dd460c
Showing 1 changed file with 9 additions and 31 deletions.
40 changes: 9 additions & 31 deletions sqlalchemy_exasol/pyodbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,28 +175,13 @@ def _is_sql_fallback_requested(**kwargs):
logger.warning("Using sql fallback instead of odbc functions")
return is_fallback_requested

def _get_odbc_connection(self, connection):
if isinstance(connection, Engine):
odbc_connection = connection.raw_connection().connection
elif isinstance(connection, Session):
odbc_connection = connection.connection()
return self._get_odbc_connection(odbc_connection)
elif isinstance(connection, Connection):
odbc_connection = connection.connection.connection
else:
return None
if "pyodbc.Connection" in str(type(odbc_connection)):
return odbc_connection
else:
return None

@reflection.cache
def _get_tables_for_schema_odbc(self, connection, odbc_connection, schema, table_type=None, table_name=None, **kw):
def _get_tables_for_schema_odbc(self, connection, schema, table_type=None, table_name=None, **kw):
schema = self._get_schema_for_input_or_current(connection, schema)
table_name = self.denormalize_name(table_name)
with odbc_connection.cursor().tables(schema=schema, tableType=table_type, table=table_name) as table_cursor:
rows = [row for row in table_cursor]
return rows
conn = connection.engine.raw_connection()
with conn.cursor().tables(schema=schema, tableType=table_type, table=table_name) as table_cursor:
return [row for row in table_cursor]

@reflection.cache
def get_view_definition(self, connection, view_name, schema=None, **kw):
Expand All @@ -205,9 +190,8 @@ def get_view_definition(self, connection, view_name, schema=None, **kw):
if view_name is None:
return None

odbc_connection = self._get_odbc_connection(connection)
tables = self._get_tables_for_schema_odbc(
connection, odbc_connection, schema,
connection, schema,
table_type="VIEW",
table_name=view_name,
**kw
Expand All @@ -229,10 +213,9 @@ def get_view_definition(self, connection, view_name, schema=None, **kw):
def get_table_names(self, connection, schema, **kw):
if self._is_sql_fallback_requested(**kw):
return super().get_table_names(connection, schema, **kw)
odbc_connection = self._get_odbc_connection(connection)
tables = self._get_tables_for_schema_odbc(
connection,
odbc_connection,

schema, table_type="TABLE",
**kw
)
Expand All @@ -242,10 +225,8 @@ def get_table_names(self, connection, schema, **kw):
def get_view_names(self, connection, schema=None, **kw):
if self._is_sql_fallback_requested(**kw):
return super().get_view_names(connection, schema, **kw)
odbc_connection = self._get_odbc_connection(connection)
tables = self._get_tables_for_schema_odbc(
connection,
odbc_connection,
schema,
table_type="VIEW",
**kw
Expand Down Expand Up @@ -273,9 +254,8 @@ def _get_columns(self, connection, table_name, schema=None, **kw):
if self._is_sql_fallback_requested(**kw):
return super()._get_columns(connection, table_name, schema, **kw)

odbc_connection = self._get_odbc_connection(connection)
tables = self._get_tables_for_schema_odbc(
connection, odbc_connection,
connection,
schema=schema,
table_name=table_name,
**kw
Expand All @@ -301,10 +281,10 @@ def _get_pk_constraint(self, connection, table_name, schema=None, **kw):
if self._is_sql_fallback_requested(**kw):
return super()._get_pk_constraint(connection, table_name, schema, **kw)

odbc_connection = self._get_odbc_connection(connection)
conn = connection.engine.raw_connection()
schema = self._get_schema_for_input_or_current(connection, schema)
table_name = self.denormalize_name(table_name)
with odbc_connection.cursor().primaryKeys(table=table_name, schema=schema) as cursor:
with conn.cursor().primaryKeys(table=table_name, schema=schema) as cursor:
pkeys = []
constraint_name = None
for row in cursor:
Expand All @@ -320,11 +300,9 @@ def _get_foreign_keys(self, connection, table_name, schema=None, **kw):
if self._is_sql_fallback_requested(**kw):
return super()._get_foreign_keys(connection, table_name, schema, **kw)

odbc_connection = self._get_odbc_connection(connection)
# Need to use a workaround, because SQLForeignKeys functions doesn't work for an unknown reason
tables = self._get_tables_for_schema_odbc(
connection=connection,
odbc_connection=odbc_connection,
schema=schema,
table_name=table_name,
table_type="TABLE",
Expand Down

0 comments on commit 5dd460c

Please sign in to comment.