Skip to content

Commit

Permalink
Minor refactorings to clean up the bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Nicoretti committed May 17, 2022
1 parent 5dd460c commit 107cffc
Showing 1 changed file with 8 additions and 31 deletions.
39 changes: 8 additions & 31 deletions sqlalchemy_exasol/pyodbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
from distutils.version import LooseVersion

from sqlalchemy import sql
from sqlalchemy.engine import reflection, Engine, Connection
from sqlalchemy.orm.session import Session
from sqlalchemy.engine import reflection
from sqlalchemy.connectors.pyodbc import PyODBCConnector
from sqlalchemy.util.langhelpers import asbool

Expand Down Expand Up @@ -176,7 +175,7 @@ def _is_sql_fallback_requested(**kwargs):
return is_fallback_requested

@reflection.cache
def _get_tables_for_schema_odbc(self, connection, schema, table_type=None, table_name=None, **kw):
def _tables_for_schema(self, connection, schema, table_type=None, table_name=None):
schema = self._get_schema_for_input_or_current(connection, schema)
table_name = self.denormalize_name(table_name)
conn = connection.engine.raw_connection()
Expand All @@ -190,12 +189,7 @@ def get_view_definition(self, connection, view_name, schema=None, **kw):
if view_name is None:
return None

tables = self._get_tables_for_schema_odbc(
connection, schema,
table_type="VIEW",
table_name=view_name,
**kw
)
tables = self._tables_for_schema(connection, schema, table_type="VIEW", table_name=view_name)
if len(tables) != 1:
return None

Expand All @@ -213,24 +207,14 @@ def get_view_definition(self, connection, view_name, schema=None, **kw):
def get_table_names(self, connection, schema, **kw):
if self._is_sql_fallback_requested(**kw):
return super().get_table_names(connection, schema, **kw)
tables = self._get_tables_for_schema_odbc(
connection,

schema, table_type="TABLE",
**kw
)
tables = self._tables_for_schema(connection, schema, table_type="TABLE")
return [self.normalize_name(row.table_name) for row in tables]

@reflection.cache
def get_view_names(self, connection, schema=None, **kw):
if self._is_sql_fallback_requested(**kw):
return super().get_view_names(connection, schema, **kw)
tables = self._get_tables_for_schema_odbc(
connection,
schema,
table_type="VIEW",
**kw
)
tables = self._tables_for_schema(connection, schema, table_type="VIEW")
return [self.normalize_name(row.table_name) for row in tables]

def has_table(self, connection, table_name, schema=None, **kw):
Expand All @@ -254,13 +238,7 @@ def _get_columns(self, connection, table_name, schema=None, **kw):
if self._is_sql_fallback_requested(**kw):
return super()._get_columns(connection, table_name, schema, **kw)

tables = self._get_tables_for_schema_odbc(
connection,
schema=schema,
table_name=table_name,
**kw
)

tables = self._tables_for_schema(connection, schema=schema, table_name=table_name)
if len(tables) != 1:
return []

Expand Down Expand Up @@ -301,12 +279,11 @@ def _get_foreign_keys(self, connection, table_name, schema=None, **kw):
return super()._get_foreign_keys(connection, table_name, schema, **kw)

# Need to use a workaround, because SQLForeignKeys functions doesn't work for an unknown reason
tables = self._get_tables_for_schema_odbc(
tables = self._tables_for_schema(
connection=connection,
schema=schema,
table_name=table_name,
table_type="TABLE",
**kw
table_type="TABLE"
)
if len(tables) == 0:
return []
Expand Down

0 comments on commit 107cffc

Please sign in to comment.