From fe81fe423b90e2ad374b58ce6fb493a7b827c8cd Mon Sep 17 00:00:00 2001 From: Nicola Coretti Date: Thu, 12 May 2022 11:27:59 +0200 Subject: [PATCH 01/27] Add regression test for Github issue #136 --- test/test_regression.py | 92 +++++++++++++++++++++++++++++++++++++++-- 1 file changed, 89 insertions(+), 3 deletions(-) diff --git a/test/test_regression.py b/test/test_regression.py index 2f7076c5..ecdde09d 100644 --- a/test/test_regression.py +++ b/test/test_regression.py @@ -1,5 +1,14 @@ """This module contains various regression test for issues which have been fixed in the past""" -from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine, schema +import pytest +from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine, inspect +from sqlalchemy.pool import ( + AssertionPool, + NullPool, + QueuePool, + SingletonThreadPool, + StaticPool, +) +from sqlalchemy.schema import CreateSchema, DropSchema from sqlalchemy.testing import fixtures from sqlalchemy.testing.fixtures import config @@ -10,7 +19,7 @@ def setUp(self): self.tenant_schema_name = "tenant_schema" engine = config.db with config.db.connect() as conn: - conn.execute(schema.CreateSchema(self.tenant_schema_name)) + conn.execute(CreateSchema(self.tenant_schema_name)) metadata = MetaData() Table( self.table_name, @@ -23,7 +32,7 @@ def setUp(self): def tearDown(self): with config.db.connect() as conn: - conn.execute(schema.DropSchema(self.tenant_schema_name, cascade=True)) + conn.execute(DropSchema(self.tenant_schema_name, cascade=True)) def test_use_schema_translate_map_in_get_last_row_id(self): """See also: https://github.com/exasol/sqlalchemy-exasol/issues/104""" @@ -40,3 +49,80 @@ def test_use_schema_translate_map_in_get_last_row_id(self): engine = create_engine(config.db.url, execution_options=options) with engine.connect() as conn: conn.execute(my_table.insert().values(name="John Doe")) + + +class Introspection(fixtures.TestBase): + """Regression(s) for issue: https://github.com/exasol/sqlalchemy-exasol/issues/136""" + + POOL_TYPES = [ + QueuePool, + NullPool, + AssertionPool, + StaticPool, + SingletonThreadPool, + ] + + @classmethod + def setup_class(cls): + def _create_tables(schema, tables): + engine = config.db + with engine.connect(): + metadata = MetaData() + for name in tables: + Table( + name, + metadata, + Column("id", Integer, primary_key=True), + Column("random_field", String(1000)), + schema=schema, + ) + metadata.create_all(engine) + + def _create_views(schema, views): + engine = config.db + with engine.connect() as conn: + for name in views: + conn.execute( + f"CREATE OR REPLACE VIEW {schema}.{name} AS SELECT 1 as COLUMN_1;" + ) + + cls.schema = "test" + cls.tables = ["a_table", "b_table", "c_table", "d_table"] + cls.views = ["a_view", "b_view"] + + _create_tables(cls.schema, cls.tables) + _create_views(cls.schema, cls.views) + + @classmethod + def teardown_class(cls): + engine = config.db + + def _drop_tables(schema, tables): + metadata = MetaData(engine, schema=cls.schema) + metadata.reflect() + to_be_deleted = [metadata.tables[name] for name in metadata.tables] + metadata.drop_all(engine, to_be_deleted) + + def _drop_views(schema, views): + with engine.connect() as conn: + for name in views: + conn.execute(f"DROP VIEW {schema}.{name};") + + _drop_tables(cls.schema, cls.tables) + _drop_views(cls.schema, cls.views) + + @pytest.mark.parametrize("pool_type", POOL_TYPES) + def test_introinspection_of_tables_works_with(self, pool_type): + expected = self.tables + engine = 
create_engine(config.db.url, poolclass=pool_type) + inspector = inspect(engine) + tables = inspector.get_table_names(schema=self.schema) + assert expected == tables + + @pytest.mark.parametrize("pool_type", POOL_TYPES) + def test_introinspection_of_views_works_with(self, pool_type): + expected = self.views + engine = create_engine(config.db.url, poolclass=pool_type) + inspector = inspect(engine) + tables = inspector.get_view_names(schema=self.schema) + assert expected == tables From 44c8e8e10919b75c6b05479774e6c08a4395fb16 Mon Sep 17 00:00:00 2001 From: Nicola Coretti Date: Thu, 12 May 2022 11:30:56 +0200 Subject: [PATCH 02/27] Improve translate map regression * Reduce noise in test name * Change setup/teardown to pytests class based mechanism --- test/test_regression.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/test/test_regression.py b/test/test_regression.py index ecdde09d..96c4202b 100644 --- a/test/test_regression.py +++ b/test/test_regression.py @@ -13,26 +13,29 @@ from sqlalchemy.testing.fixtures import config -class RegressionTest(fixtures.TestBase): - def setUp(self): - self.table_name = "my_table" - self.tenant_schema_name = "tenant_schema" +class TranslateMap(fixtures.TestBase): + + @classmethod + def setup_class(cls): + cls.table_name = "my_table" + cls.tenant_schema_name = "tenant_schema" engine = config.db with config.db.connect() as conn: - conn.execute(CreateSchema(self.tenant_schema_name)) + conn.execute(CreateSchema(cls.tenant_schema_name)) metadata = MetaData() Table( - self.table_name, + cls.table_name, metadata, Column("id", Integer, primary_key=True), Column("name", String(1000), nullable=False), - schema=self.tenant_schema_name, + schema=cls.tenant_schema_name, ) metadata.create_all(engine) - def tearDown(self): + @classmethod + def teardown_class(cls): with config.db.connect() as conn: - conn.execute(DropSchema(self.tenant_schema_name, cascade=True)) + conn.execute(DropSchema(cls.tenant_schema_name, cascade=True)) def test_use_schema_translate_map_in_get_last_row_id(self): """See also: https://github.com/exasol/sqlalchemy-exasol/issues/104""" From e9904393e524b73479bf8978d3a335944978beaf Mon Sep 17 00:00:00 2001 From: Nicola Coretti Date: Thu, 12 May 2022 11:35:59 +0200 Subject: [PATCH 03/27] Remove unnecessary parameters from function --- test/test_regression.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/test/test_regression.py b/test/test_regression.py index 96c4202b..4e53e20a 100644 --- a/test/test_regression.py +++ b/test/test_regression.py @@ -14,7 +14,6 @@ class TranslateMap(fixtures.TestBase): - @classmethod def setup_class(cls): cls.table_name = "my_table" @@ -100,8 +99,8 @@ def _create_views(schema, views): def teardown_class(cls): engine = config.db - def _drop_tables(schema, tables): - metadata = MetaData(engine, schema=cls.schema) + def _drop_tables(schema): + metadata = MetaData(engine, schema=schema) metadata.reflect() to_be_deleted = [metadata.tables[name] for name in metadata.tables] metadata.drop_all(engine, to_be_deleted) @@ -111,7 +110,7 @@ def _drop_views(schema, views): for name in views: conn.execute(f"DROP VIEW {schema}.{name};") - _drop_tables(cls.schema, cls.tables) + _drop_tables(cls.schema) _drop_views(cls.schema, cls.views) @pytest.mark.parametrize("pool_type", POOL_TYPES) From 89cc5146a56c1b0383512bbe57d36b6cc63dcf1e Mon Sep 17 00:00:00 2001 From: Nicola Coretti Date: Thu, 12 May 2022 16:15:20 +0200 Subject: [PATCH 04/27] Fix typo's in test names --- 
test/test_regression.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_regression.py b/test/test_regression.py index 4e53e20a..78e9817f 100644 --- a/test/test_regression.py +++ b/test/test_regression.py @@ -114,7 +114,7 @@ def _drop_views(schema, views): _drop_views(cls.schema, cls.views) @pytest.mark.parametrize("pool_type", POOL_TYPES) - def test_introinspection_of_tables_works_with(self, pool_type): + def test_introspection_of_tables_works_with(self, pool_type): expected = self.tables engine = create_engine(config.db.url, poolclass=pool_type) inspector = inspect(engine) @@ -122,7 +122,7 @@ def test_introinspection_of_tables_works_with(self, pool_type): assert expected == tables @pytest.mark.parametrize("pool_type", POOL_TYPES) - def test_introinspection_of_views_works_with(self, pool_type): + def test_introspection_of_views_works_with(self, pool_type): expected = self.views engine = create_engine(config.db.url, poolclass=pool_type) inspector = inspect(engine) From 20814e96d0b278d8c612b9fb958bb0c02467defb Mon Sep 17 00:00:00 2001 From: Nicola Coretti Date: Mon, 16 May 2022 09:45:00 +0200 Subject: [PATCH 05/27] Move odbc based get_foreign_keys to pyodbc dialect --- sqlalchemy_exasol/base.py | 43 ++++++++++--------------------------- sqlalchemy_exasol/pyodbc.py | 38 ++++++++++++++++++++++++++++++-- 2 files changed, 47 insertions(+), 34 deletions(-) diff --git a/sqlalchemy_exasol/base.py b/sqlalchemy_exasol/base.py index 4b024d1a..8846b480 100644 --- a/sqlalchemy_exasol/base.py +++ b/sqlalchemy_exasol/base.py @@ -849,35 +849,19 @@ def get_pk_constraint(self, connection, table_name, schema=None, **kw): else: return self.get_pk_constraint_sql(connection, table_name=table_name, schema=schema, **kw) - @reflection.cache - def _get_foreign_keys_odbc(self, connection, odbc_connection, table_name, schema=None, **kw): - # Need to use a workaround, because SQLForeignKeys functions doesn't work for an unknown reason - tables = self._get_tables_for_schema_odbc(connection=connection, odbc_connection=odbc_connection, - schema=schema, table_name=table_name, table_type="TABLE", **kw) - if len(tables) > 0: - quoted_schema_string = self.quote_string_value(tables[0].table_schem) - quoted_table_string = self.quote_string_value(tables[0].table_name) - sql_stmnt = \ - "/*snapshot execution*/ " + \ - self._get_constraint_sql_str(quoted_schema_string,quoted_table_string,"FOREIGN KEY") - rp = connection.execute(sql.text(sql_stmnt)) - return list(rp) - else: - return [] @reflection.cache - def _get_foreign_keys_sql(self, connection, table_name, schema=None, **kw): + def _get_foreign_keys(self, connection, table_name, schema=None, **kw): table_name_string = ":table" - if schema is None: - schema_string = "CURRENT_SCHEMA " - else: - schema_string = ":schema " - sql_stmnt = \ - self._get_constraint_sql_str(schema_string, table_name_string, "FOREIGN KEY") - rp = connection.execute(sql.text(sql_stmnt), - schema=self.denormalize_name(schema), - table=self.denormalize_name(table_name)) - return list(rp) + schema_string = "CURRENT_SCHEMA " if schema is None else ":schema " + sql_statement = self._get_constraint_sql_str(schema_string, table_name_string, "FOREIGN KEY") + response = connection.execute( + sql.text(sql_statement), + schema=self.denormalize_name(schema), + table=self.denormalize_name(table_name) + ) + return list(response) + @reflection.cache def get_foreign_keys(self, connection, table_name, schema=None, **kw): @@ -895,12 +879,7 @@ def fkey_rec(): } fkeys = 
util.defaultdict(fkey_rec) - odbc_connection = self.getODBCConnection(connection) - if odbc_connection is not None and not self.use_sql_fallback(**kw): - constraints = self._get_foreign_keys_odbc(connection, odbc_connection, table_name=table_name, - schema=schema_int, **kw) - else: - constraints = self._get_foreign_keys_sql(connection, table_name=table_name, schema=schema_int, **kw) + constraints = self._get_foreign_keys(connection, table_name=table_name, schema=schema_int, **kw) table_name = self.denormalize_name(table_name) for row in constraints: (cons_name, local_column, remote_schema, remote_table, remote_column) = \ diff --git a/sqlalchemy_exasol/pyodbc.py b/sqlalchemy_exasol/pyodbc.py index dc3a1080..24ade76c 100644 --- a/sqlalchemy_exasol/pyodbc.py +++ b/sqlalchemy_exasol/pyodbc.py @@ -8,16 +8,19 @@ import re import sys +import logging from distutils.version import LooseVersion +from sqlalchemy.engine import reflection from sqlalchemy.connectors.pyodbc import PyODBCConnector from sqlalchemy.util.langhelpers import asbool from sqlalchemy_exasol.base import EXADialect, EXAExecutionContext +logger = logging.getLogger("sqlalchemy_exasol") -class EXADialect_pyodbc(EXADialect, PyODBCConnector): +class EXADialect_pyodbc(EXADialect, PyODBCConnector): execution_ctx_cls = EXAExecutionContext driver_version = None @@ -58,7 +61,6 @@ def _get_server_version_info(self, connection): return self.server_version_info if sys.platform == "darwin": - def connect(self, *cargs, **cparams): # Get connection conn = super().connect(*cargs, **cparams) @@ -164,5 +166,37 @@ def is_disconnect(self, e, connection, cursor): return super().is_disconnect(e, connection, cursor) + @staticmethod + def _is_sql_fallback_requested(**kwargs): + is_fallback_requested = kwargs.get("use_sql_fallback", False) + if is_fallback_requested: + logger.warning("Using sql fallback instead of odbc functions") + return is_fallback_requested + + + @reflection.cache + def _get_foreign_keys(self, connection, table_name, schema=None, **kw): + if self._is_sql_fallback_requested(): + return super().get_foreign_keys(connection, table_name, schema, **kw) + + odbc_connection = self.getODBCConnection(connection) + # Need to use a workaround, because SQLForeignKeys functions doesn't work for an unknown reason + tables = self._get_tables_for_schema_odbc(connection=connection, odbc_connection=odbc_connection, + schema=schema, table_name=table_name, table_type="TABLE", **kw) + if len(tables) == 0: + return [] + + quoted_schema_string = self.quote_string_value(tables[0].table_schem) + quoted_table_string = self.quote_string_value(tables[0].table_name) + sql_statement = "/*snapshot execution*/ {query}".format( + query=self._get_constraint_sql_str( + quoted_schema_string, + quoted_table_string, + "FOREIGN KEY" + ) + ) + response = connection.execute(sql_statement) + return list(response) + dialect = EXADialect_pyodbc From 8fca3161274d22be4eb56adb3d97eb8004848e5f Mon Sep 17 00:00:00 2001 From: Nicola Coretti Date: Mon, 16 May 2022 09:59:00 +0200 Subject: [PATCH 06/27] Addition fk refactoring --- sqlalchemy_exasol/pyodbc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sqlalchemy_exasol/pyodbc.py b/sqlalchemy_exasol/pyodbc.py index 24ade76c..0a940c5d 100644 --- a/sqlalchemy_exasol/pyodbc.py +++ b/sqlalchemy_exasol/pyodbc.py @@ -177,7 +177,7 @@ def _is_sql_fallback_requested(**kwargs): @reflection.cache def _get_foreign_keys(self, connection, table_name, schema=None, **kw): if self._is_sql_fallback_requested(): - return 
super().get_foreign_keys(connection, table_name, schema, **kw) + return super()._get_foreign_keys(connection, table_name, schema, **kw) odbc_connection = self.getODBCConnection(connection) # Need to use a workaround, because SQLForeignKeys functions doesn't work for an unknown reason From 26fc0f23793a622b7e89812bef7e29910db07330 Mon Sep 17 00:00:00 2001 From: Nicola Coretti Date: Mon, 16 May 2022 10:10:53 +0200 Subject: [PATCH 07/27] Move the odbc based get_pk_constraint to pyodbc dialect --- sqlalchemy_exasol/base.py | 21 ++------------------- sqlalchemy_exasol/pyodbc.py | 18 ++++++++++++++++++ 2 files changed, 20 insertions(+), 19 deletions(-) diff --git a/sqlalchemy_exasol/base.py b/sqlalchemy_exasol/base.py index 8846b480..b5e56e3d 100644 --- a/sqlalchemy_exasol/base.py +++ b/sqlalchemy_exasol/base.py @@ -802,22 +802,9 @@ def _get_constraint_sql_str(self, schema, table_name, contraint_type): .format(schema=schema, table_name=table_name, contraint_type=contraint_type) return sql_stmnt - @reflection.cache - def get_pk_constraint_odbc(self, connection, odbc_connection, table_name, schema=None, **kw): - schema = self._get_schema_for_input_or_current(connection, schema) - table_name = self.denormalize_name(table_name) - with odbc_connection.cursor().primaryKeys(table=table_name, schema=schema) as primaryKeys_cursor: - pkeys = [] - constraint_name = None - for row in primaryKeys_cursor: - if row[2] != table_name and table_name is not None: - continue - pkeys.append(self.normalize_name(row[3])) - constraint_name = self.normalize_name(row[5]) - return {'constrained_columns': pkeys, 'name': constraint_name} @reflection.cache - def get_pk_constraint_sql(self, connection, table_name, schema=None, **kw): + def _get_pk_constraint(self, connection, table_name, schema, **kw): schema = self._get_schema_for_input(connection, schema) table_name = self.denormalize_name(table_name) table_name_string = ":table" @@ -843,11 +830,7 @@ def get_pk_constraint_sql(self, connection, table_name, schema=None, **kw): def get_pk_constraint(self, connection, table_name, schema=None, **kw): if table_name is None: return None - odbc_connection = self.getODBCConnection(connection) - if odbc_connection is not None and not self.use_sql_fallback(**kw): - return self.get_pk_constraint_odbc(connection, odbc_connection, table_name=table_name, schema=schema, **kw) - else: - return self.get_pk_constraint_sql(connection, table_name=table_name, schema=schema, **kw) + return self._get_pk_constraint(connection, table_name, schema=schema, **kw) @reflection.cache diff --git a/sqlalchemy_exasol/pyodbc.py b/sqlalchemy_exasol/pyodbc.py index 0a940c5d..76ebeb84 100644 --- a/sqlalchemy_exasol/pyodbc.py +++ b/sqlalchemy_exasol/pyodbc.py @@ -173,6 +173,24 @@ def _is_sql_fallback_requested(**kwargs): logger.warning("Using sql fallback instead of odbc functions") return is_fallback_requested + @reflection.cache + def _get_pk_constraint(self, connection, table_name, schema=None, **kw): + if self._is_sql_fallback_requested(): + return super()._get_pk_constraint(connection, table_name, schema, **kw) + + odbc_connection = self.getODBCConnection(connection) + schema = self._get_schema_for_input_or_current(connection, schema) + table_name = self.denormalize_name(table_name) + with odbc_connection.cursor().primaryKeys(table=table_name, schema=schema) as cursor: + pkeys = [] + constraint_name = None + for row in cursor: + table, primary_key, constraint = row[2], row[3], row[5] + if table != table_name and table_name is not None: + continue + 
pkeys.append(self.normalize_name(primary_key)) + constraint_name = self.normalize_name(constraint) + return {'constrained_columns': pkeys, 'name': constraint_name} @reflection.cache def _get_foreign_keys(self, connection, table_name, schema=None, **kw): From 21f4457a9d149cd1ee450b907a43b4534301a8f2 Mon Sep 17 00:00:00 2001 From: Nicola Coretti Date: Mon, 16 May 2022 10:19:21 +0200 Subject: [PATCH 08/27] Refactor _get_constraint_sql_str method --- sqlalchemy_exasol/base.py | 40 +++++++++++++++++++-------------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/sqlalchemy_exasol/base.py b/sqlalchemy_exasol/base.py index b5e56e3d..1a608d15 100644 --- a/sqlalchemy_exasol/base.py +++ b/sqlalchemy_exasol/base.py @@ -560,7 +560,7 @@ def has_table(self, connection, table_name, schema=None, **kw): if odbc_connection is not None and not self.use_sql_fallback(**kw): result=self.has_table_odbc(connection, odbc_connection, schema=schema, table_name=table_name, **kw) else: - result=self.has_table_sql(connection, schema=schema, table_name=table_name, **kw) + result=self.has_table_sql(connection, schema=schema, table_name=table_name, **kw) return result def has_table_odbc(self, connection, odbc_connection, table_name, schema=None, **kw): @@ -707,7 +707,7 @@ def _get_columns_sql(self, connection, table_name, schema=None, **kw): self.get_column_sql_query_str() \ .format(schema=schema_str, table=table_name_str) stmnt = sql.text(sql_stmnt) - rp = connection.execute(stmnt, + rp = connection.execute(stmnt, schema=self.denormalize_name(schema), table=self.denormalize_name(table_name)) @@ -783,24 +783,24 @@ def get_columns(self, connection, table_name, schema=None, **kw): columns.append(cdict) return columns - def _get_constraint_sql_str(self, schema, table_name, contraint_type): - sql_stmnt = \ - "SELECT " \ - "constraint_name, " \ - "column_name, " \ - "referenced_schema, " \ - "referenced_table, " \ - "referenced_column, " \ - "constraint_table, " \ - "constraint_type " \ - "FROM SYS.EXA_ALL_CONSTRAINT_COLUMNS " \ - "WHERE " \ - "constraint_schema={schema} AND " \ - "constraint_table={table_name} AND " \ - "constraint_type='{contraint_type}' " \ - "ORDER BY ordinal_position" \ - .format(schema=schema, table_name=table_name, contraint_type=contraint_type) - return sql_stmnt + @staticmethod + def _get_constraint_sql_str(schema, table_name, contraint_type): + return ( + "SELECT " + "constraint_name, " + "column_name, " + "referenced_schema, " + "referenced_table, " + "referenced_column, " + "constraint_table, " + "constraint_type " + "FROM SYS.EXA_ALL_CONSTRAINT_COLUMNS " + "WHERE " + f"constraint_schema={schema} AND " + f"constraint_table={table_name} AND " + f"constraint_type='{contraint_type}' " + "ORDER BY ordinal_position" + ) @reflection.cache From 1b0279c9d5213e302bafc5cd4f54d7cb2ccc62c7 Mon Sep 17 00:00:00 2001 From: Nicola Coretti Date: Mon, 16 May 2022 11:09:02 +0200 Subject: [PATCH 09/27] Adjust deadlock tests and markers to take odbc refactoring into account --- test/test_deadlock.py | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/test/test_deadlock.py b/test/test_deadlock.py index ab8a0add..2278d300 100644 --- a/test/test_deadlock.py +++ b/test/test_deadlock.py @@ -6,12 +6,15 @@ import pytest from sqlalchemy import create_engine from sqlalchemy.testing import fixtures, config +from sqlalchemy.engine.reflection import Inspector import sqlalchemy.testing as testing -from sqlalchemy_exasol.base import EXADialect -#TODO 
get_schema_names, get_view_names and get_view_definition didn't cause deadlocks in this scenario -@pytest.mark.skipif("turbodbc" in str(testing.db.url), reason="We currently don't support snapshot metadata requests for turbodbc") +# TODO: get_schema_names, get_view_names and get_view_definition didn't cause deadlocks in this scenario +@pytest.mark.skipif( + "pyodbc" not in str(testing.db.url), + reason="We currently only support snapshot metadata requests in pyodbc based dialect" +) class MetadataTest(fixtures.TablesTest): __backend__ = True @@ -25,14 +28,14 @@ def create_transaction(self, url, con_name): def test_no_deadlock_for_get_table_names_without_fallback(self): def without_fallback(session2, schema, table): - dialect = EXADialect() + dialect = Inspector(session2).dialect dialect.get_table_names(session2, schema=schema, use_sql_fallback=False) self.run_deadlock_for_table(without_fallback) def test_deadlock_for_get_table_names_with_fallback(self): def with_fallback(session2, schema, table): - dialect = EXADialect() + dialect = Inspector(session2).dialect dialect.get_table_names(session2, schema=schema, use_sql_fallback=True) with pytest.raises(Exception): @@ -40,7 +43,7 @@ def with_fallback(session2, schema, table): def test_no_deadlock_for_get_columns_without_fallback(self): def without_fallback(session2, schema, table): - dialect = EXADialect() + dialect = Inspector(session2).dialect dialect.get_columns(session2, schema=schema, table_name=table, use_sql_fallback=False) self.run_deadlock_for_table(without_fallback) @@ -48,36 +51,35 @@ def without_fallback(session2, schema, table): def test_no_deadlock_for_get_columns_with_fallback(self): # TODO: Doesnt produce a deadlock anymore since last commit? def with_fallback(session2, schema, table): - dialect = EXADialect() + dialect = Inspector(session2).dialect dialect.get_columns(session2, schema=schema, table_name=table, use_sql_fallback=True) self.run_deadlock_for_table(with_fallback) def test_no_deadlock_for_get_pk_constraint_without_fallback(self): def without_fallback(session2, schema, table): - dialect = EXADialect() + dialect = Inspector(session2).dialect dialect.get_pk_constraint(session2, table_name=table, schema=schema, use_sql_fallback=False) self.run_deadlock_for_table(without_fallback) def test_no_deadlock_for_get_pk_constraint_with_fallback(self): def with_fallback(session2, schema, table): - dialect = EXADialect() + dialect = Inspector(session2).dialect dialect.get_pk_constraint(session2, table_name=table, schema=schema, use_sql_fallback=True) self.run_deadlock_for_table(with_fallback) def test_no_deadlock_for_get_foreign_keys_without_fallback(self): def without_fallback(session2, schema, table): - dialect = EXADialect() + dialect = Inspector(session2).dialect dialect.get_foreign_keys(session2, table_name=table, schema=schema, use_sql_fallback=False) self.run_deadlock_for_table(without_fallback) - def test_no_deadlock_for_get_foreign_keys_with_fallback(self): def with_fallback(session2, schema, table): - dialect = EXADialect() + dialect = Inspector(session2).dialect dialect.get_foreign_keys(session2, table_name=table, schema=schema, use_sql_fallback=True) self.run_deadlock_for_table(with_fallback) @@ -85,7 +87,7 @@ def with_fallback(session2, schema, table): def test_no_deadlock_for_get_view_names_without_fallback(self): # TODO: think of other scenarios where metadata deadlocks with view could happen def without_fallback(session2, schema, table): - dialect = EXADialect() + dialect = Inspector(session2).dialect 
dialect.get_view_names(session2, table_name=table, schema=schema, use_sql_fallback=False) self.run_deadlock_for_table(without_fallback) @@ -93,7 +95,7 @@ def without_fallback(session2, schema, table): def test_no_deadlock_for_get_view_names_with_fallback(self): # TODO: think of other scenarios where metadata deadlocks with view could happen def with_fallback(session2, schema, table): - dialect = EXADialect() + dialect = Inspector(session2).dialect dialect.get_view_names(session2, table_name=table, schema=schema, use_sql_fallback=True) self.run_deadlock_for_table(with_fallback) From 24b29d369d72ec56257b54cd509c406a93061d58 Mon Sep 17 00:00:00 2001 From: Nicola Coretti Date: Mon, 16 May 2022 11:47:04 +0200 Subject: [PATCH 10/27] Move odbc based get_columns functionality to pyodbc dialect --- sqlalchemy_exasol/base.py | 51 +++++++------------------------------ sqlalchemy_exasol/pyodbc.py | 28 ++++++++++++++++++++ 2 files changed, 37 insertions(+), 42 deletions(-) diff --git a/sqlalchemy_exasol/base.py b/sqlalchemy_exasol/base.py index 1a608d15..7a16eac6 100644 --- a/sqlalchemy_exasol/base.py +++ b/sqlalchemy_exasol/base.py @@ -675,52 +675,19 @@ def get_column_sql_query_str(self): "column_table = {table} " \ "ORDER BY column_ordinal_position" - @reflection.cache - def _get_columns_odbc(self, connection, odbc_connection, table_name, schema, **kw): - tables = self._get_tables_for_schema_odbc(connection, odbc_connection, - schema=schema, table_name=table_name, **kw) - if len(tables) == 1: - # get_columns_sql originally returned all columns of all tables if table_name is None, - # we follow this behavior here for compatibility. However, the documentation for Dialects - # does not mentions this behavior: - # https://docs.sqlalchemy.org/en/13/core/internals.html#sqlalchemy.engine.interfaces.Dialect - quoted_schema_string = self.quote_string_value(tables[0].table_schem) - quoted_table_string = self.quote_string_value(tables[0].table_name) - sql_stmnt = \ - "/*snapshot execution*/ " + \ - self.get_column_sql_query_str() \ - .format(schema=quoted_schema_string, table=quoted_table_string) - rp = connection.execute(sql.text(sql_stmnt)) - return list(rp) - else: - return [] @reflection.cache - def _get_columns_sql(self, connection, table_name, schema=None, **kw): + def _get_columns(self, connection, table_name, schema=None, **kw): schema = self._get_schema_for_input(connection, schema) - if schema is None: - schema_str = "CURRENT_SCHEMA" - else: - schema_str = ":schema" + schema_str = "CURRENT_SCHEMA" if schema is None else ":schema" table_name_str = ":table" - sql_stmnt = \ - self.get_column_sql_query_str() \ - .format(schema=schema_str, table=table_name_str) - stmnt = sql.text(sql_stmnt) - rp = connection.execute(stmnt, - schema=self.denormalize_name(schema), - table=self.denormalize_name(table_name)) - - return list(rp) - - @reflection.cache - def _get_columns(self, connection, table_name, schema=None, **kw): - odbc_connection = self.getODBCConnection(connection) - if odbc_connection is not None and not self.use_sql_fallback(**kw): - columns = self._get_columns_odbc(connection, odbc_connection, table_name, schema, **kw) - else: - columns = self._get_columns_sql(connection, table_name, schema, **kw) - return columns + sql_statement = self.get_column_sql_query_str().format(schema=schema_str, table=table_name_str) + response = connection.execute( + sql_statement, + schema=self.denormalize_name(schema), + table=self.denormalize_name(table_name) + ) + return list(response) @reflection.cache def 
get_columns(self, connection, table_name, schema=None, **kw): diff --git a/sqlalchemy_exasol/pyodbc.py b/sqlalchemy_exasol/pyodbc.py index 76ebeb84..33e0cadf 100644 --- a/sqlalchemy_exasol/pyodbc.py +++ b/sqlalchemy_exasol/pyodbc.py @@ -173,6 +173,34 @@ def _is_sql_fallback_requested(**kwargs): logger.warning("Using sql fallback instead of odbc functions") return is_fallback_requested + @reflection.cache + def _get_columns(self, connection, table_name, schema=None, **kw): + if self._is_sql_fallback_requested(): + return super()._get_columns(connection, table_name, schema, **kw) + + odbc_connection = self.getODBCConnection(connection) + tables = self._get_tables_for_schema_odbc( + connection, odbc_connection, + schema=schema, + table_name=table_name, + **kw + ) + + if len(tables) != 1: + return [] + + # get_columns_sql originally returned all columns of all tables if table_name is None, + # we follow this behavior here for compatibility. However, the documentation for Dialects + # does not mention this behavior: + # https://docs.sqlalchemy.org/en/13/core/internals.html#sqlalchemy.engine.interfaces.Dialect + quoted_schema_string = self.quote_string_value(tables[0].table_schem) + quoted_table_string = self.quote_string_value(tables[0].table_name) + sql_statement = "/*snapshot execution*/ {query}".format(query=self.get_column_sql_query_str()) + sql_statement = sql_statement.format(schema=quoted_schema_string, table=quoted_table_string) + response = connection.execute(sql_statement) + + return list(response) + @reflection.cache def _get_pk_constraint(self, connection, table_name, schema=None, **kw): if self._is_sql_fallback_requested(): From 320d9478dd0fae418158644de1df179398f28cf9 Mon Sep 17 00:00:00 2001 From: Nicola Coretti Date: Mon, 16 May 2022 11:54:58 +0200 Subject: [PATCH 11/27] Refactor get_column_sql_query_str method --- sqlalchemy_exasol/base.py | 39 +++++++++++++++++++++------------------ 1 file changed, 21 insertions(+), 18 deletions(-) diff --git a/sqlalchemy_exasol/base.py b/sqlalchemy_exasol/base.py index 7a16eac6..cf7ba620 100644 --- a/sqlalchemy_exasol/base.py +++ b/sqlalchemy_exasol/base.py @@ -656,24 +656,27 @@ def get_view_definition_sql(self, connection, view_name, schema=None, **kw): else: return None - def get_column_sql_query_str(self): - return "SELECT " \ - "column_name, " \ - "column_type, " \ - "column_maxsize, " \ - "column_num_prec, " \ - "column_num_scale, " \ - "column_is_nullable, " \ - "column_default, " \ - "column_identity, " \ - "column_is_distribution_key, " \ - "column_table " \ - "FROM sys.exa_all_columns " \ - "WHERE " \ - "column_object_type IN ('TABLE', 'VIEW') AND " \ - "column_schema = {schema} AND " \ - "column_table = {table} " \ - "ORDER BY column_ordinal_position" + @staticmethod + def get_column_sql_query_str(): + return ( + "SELECT " + "column_name, " + "column_type, " + "column_maxsize, " + "column_num_prec, " + "column_num_scale, " + "column_is_nullable, " + "column_default, " + "column_identity, " + "column_is_distribution_key, " + "column_table " + "FROM sys.exa_all_columns " + "WHERE " + "column_object_type IN ('TABLE', 'VIEW') AND " + "column_schema = {schema} AND " + "column_table = {table} " + "ORDER BY column_ordinal_position" + ) @reflection.cache From 700845aa90b1d1c4c46217a8e240b3253189f388 Mon Sep 17 00:00:00 2001 From: Nicola Coretti Date: Mon, 16 May 2022 15:32:22 +0200 Subject: [PATCH 12/27] Addition columns refactoring --- sqlalchemy_exasol/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/sqlalchemy_exasol/base.py b/sqlalchemy_exasol/base.py index cf7ba620..cae0d56e 100644 --- a/sqlalchemy_exasol/base.py +++ b/sqlalchemy_exasol/base.py @@ -686,7 +686,7 @@ def _get_columns(self, connection, table_name, schema=None, **kw): table_name_str = ":table" sql_statement = self.get_column_sql_query_str().format(schema=schema_str, table=table_name_str) response = connection.execute( - sql_statement, + sql.text(sql_statement), schema=self.denormalize_name(schema), table=self.denormalize_name(table_name) ) From 4c5adaf2551ae26f4dfcc476b7f14bff32ebbafb Mon Sep 17 00:00:00 2001 From: Nicola Coretti Date: Mon, 16 May 2022 15:40:57 +0200 Subject: [PATCH 13/27] Fix fallback checks --- sqlalchemy_exasol/pyodbc.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sqlalchemy_exasol/pyodbc.py b/sqlalchemy_exasol/pyodbc.py index 33e0cadf..393d496d 100644 --- a/sqlalchemy_exasol/pyodbc.py +++ b/sqlalchemy_exasol/pyodbc.py @@ -175,7 +175,7 @@ def _is_sql_fallback_requested(**kwargs): @reflection.cache def _get_columns(self, connection, table_name, schema=None, **kw): - if self._is_sql_fallback_requested(): + if self._is_sql_fallback_requested(**kw): return super()._get_columns(connection, table_name, schema, **kw) odbc_connection = self.getODBCConnection(connection) @@ -203,7 +203,7 @@ def _get_columns(self, connection, table_name, schema=None, **kw): @reflection.cache def _get_pk_constraint(self, connection, table_name, schema=None, **kw): - if self._is_sql_fallback_requested(): + if self._is_sql_fallback_requested(**kw): return super()._get_pk_constraint(connection, table_name, schema, **kw) odbc_connection = self.getODBCConnection(connection) @@ -222,7 +222,7 @@ def _get_pk_constraint(self, connection, table_name, schema=None, **kw): @reflection.cache def _get_foreign_keys(self, connection, table_name, schema=None, **kw): - if self._is_sql_fallback_requested(): + if self._is_sql_fallback_requested(**kw): return super()._get_foreign_keys(connection, table_name, schema, **kw) odbc_connection = self.getODBCConnection(connection) From ccbec12f9567b95f71504532cd9153fc21edeff3 Mon Sep 17 00:00:00 2001 From: Nicola Coretti Date: Mon, 16 May 2022 15:44:25 +0200 Subject: [PATCH 14/27] Ensure test take into account the correct dialect --- test/test_get_metadata_functions.py | 48 +++++++++++++++-------------- 1 file changed, 25 insertions(+), 23 deletions(-) diff --git a/test/test_get_metadata_functions.py b/test/test_get_metadata_functions.py index c171e9fd..c66dd87d 100644 --- a/test/test_get_metadata_functions.py +++ b/test/test_get_metadata_functions.py @@ -5,8 +5,10 @@ from sqlalchemy.engine.url import URL from sqlalchemy.sql.sqltypes import INTEGER, VARCHAR from sqlalchemy.testing import fixtures, config +from sqlalchemy.engine.reflection import Inspector from sqlalchemy_exasol.base import EXADialect +from sqlalchemy_exasol.pyodbc import EXADialect_pyodbc TEST_GET_METADATA_FUNCTIONS_SCHEMA = "test_get_metadata_functions_schema" ENGINE_NONE_DATABASE = "ENGINE_NONE_DATABASE" @@ -74,14 +76,14 @@ def create_engine_with_database_name(cls, connection, new_database_name): @pytest.mark.parametrize("engine_name", [ENGINE_NONE_DATABASE, ENGINE_SCHEMA_DATABASE, ENGINE_SCHEMA_2_DATABASE]) def test_get_schema_names(self, engine_name, use_sql_fallback): with self.engine_map[engine_name].begin() as c: - dialect = EXADialect() + dialect = Inspector(c).dialect schema_names = dialect.get_schema_names(connection=c, use_sql_fallback=use_sql_fallback) assert self.schema in schema_names and 
self.schema_2 in schema_names @pytest.mark.parametrize("engine_name", [ENGINE_NONE_DATABASE, ENGINE_SCHEMA_DATABASE, ENGINE_SCHEMA_2_DATABASE]) def test_compare_get_schema_names_for_sql_and_odbc(self, engine_name): with self.engine_map[engine_name].begin() as c: - dialect = EXADialect() + dialect = Inspector(c).dialect schema_names_fallback = dialect.get_schema_names(connection=c, use_sql_fallback=True) schema_names_odbc = dialect.get_schema_names(connection=c) assert sorted(schema_names_fallback) == sorted(schema_names_odbc) @@ -90,7 +92,7 @@ def test_compare_get_schema_names_for_sql_and_odbc(self, engine_name): @pytest.mark.parametrize("engine_name", [ENGINE_NONE_DATABASE, ENGINE_SCHEMA_DATABASE, ENGINE_SCHEMA_2_DATABASE]) def test_get_table_names(self, use_sql_fallback, engine_name): with self.engine_map[engine_name].begin() as c: - dialect = EXADialect() + dialect = Inspector(c).dialect table_names = dialect.get_table_names(connection=c, schema=self.schema, use_sql_fallback=use_sql_fallback) assert "t" in table_names and "s" in table_names @@ -100,7 +102,7 @@ def test_compare_get_table_names_for_sql_and_odbc(self, schema, engine_name): with self.engine_map[engine_name].begin() as c: if schema is None: c.execute("OPEN SCHEMA %s" % self.schema) - dialect = EXADialect() + dialect = Inspector(c).dialect table_names_fallback = dialect.get_table_names(connection=c, schema=schema, use_sql_fallback=True) table_names_odbc = dialect.get_table_names(connection=c, schema=schema) assert table_names_fallback == table_names_odbc @@ -109,7 +111,7 @@ def test_compare_get_table_names_for_sql_and_odbc(self, schema, engine_name): @pytest.mark.parametrize("engine_name", [ENGINE_NONE_DATABASE, ENGINE_SCHEMA_DATABASE, ENGINE_SCHEMA_2_DATABASE]) def test_has_table_table_exists(self, use_sql_fallback, engine_name): with self.engine_map[engine_name].begin() as c: - dialect = EXADialect() + dialect = Inspector(c).dialect has_table = dialect.has_table(connection=c, schema=self.schema, table_name="t", use_sql_fallback=use_sql_fallback) assert has_table, "Table %s.T was not found, but should exist" % self.schema @@ -118,7 +120,7 @@ def test_has_table_table_exists(self, use_sql_fallback, engine_name): @pytest.mark.parametrize("engine_name", [ENGINE_NONE_DATABASE, ENGINE_SCHEMA_DATABASE, ENGINE_SCHEMA_2_DATABASE]) def test_has_table_table_exists_not(self, use_sql_fallback, engine_name): with self.engine_map[engine_name].begin() as c: - dialect = EXADialect() + dialect = Inspector(c).dialect has_table = dialect.has_table(connection=c, schema=self.schema, table_name="not_exist", use_sql_fallback=use_sql_fallback) assert not has_table, "Table %s.not_exist was found, but should not exist" % self.schema @@ -127,7 +129,7 @@ def test_has_table_table_exists_not(self, use_sql_fallback, engine_name): @pytest.mark.parametrize("engine_name", [ENGINE_NONE_DATABASE, ENGINE_SCHEMA_DATABASE, ENGINE_SCHEMA_2_DATABASE]) def test_compare_has_table_for_sql_and_odbc(self, schema, engine_name): with self.engine_map[engine_name].begin() as c: - dialect = EXADialect() + dialect = Inspector(c).dialect has_table_fallback = dialect.has_table(connection=c, schema=schema, use_sql_fallback=True, table_name="t") has_table_odbc = dialect.has_table(connection=c, schema=schema, table_name="t") assert has_table_fallback == has_table_odbc, "Expected table %s.t with odbc and fallback" % schema @@ -136,7 +138,7 @@ def test_compare_has_table_for_sql_and_odbc(self, schema, engine_name): @pytest.mark.parametrize("engine_name", [ENGINE_NONE_DATABASE, 
ENGINE_SCHEMA_DATABASE, ENGINE_SCHEMA_2_DATABASE]) def test_get_view_names(self, use_sql_fallback,engine_name): with self.engine_map[engine_name].begin() as c: - dialect = EXADialect() + dialect = Inspector(c).dialect view_names = dialect.get_view_names(connection=c, schema=self.schema, use_sql_fallback=use_sql_fallback) assert "v" in view_names @@ -144,7 +146,7 @@ def test_get_view_names(self, use_sql_fallback,engine_name): @pytest.mark.parametrize("engine_name", [ENGINE_NONE_DATABASE, ENGINE_SCHEMA_DATABASE, ENGINE_SCHEMA_2_DATABASE]) def test_get_view_names_for_sys(self, use_sql_fallback, engine_name): with self.engine_map[engine_name].begin() as c: - dialect = EXADialect() + dialect = Inspector(c).dialect view_names = dialect.get_view_names(connection=c, schema="sys", use_sql_fallback=use_sql_fallback) assert len(view_names) == 0 @@ -152,7 +154,7 @@ def test_get_view_names_for_sys(self, use_sql_fallback, engine_name): @pytest.mark.parametrize("engine_name", [ENGINE_NONE_DATABASE, ENGINE_SCHEMA_DATABASE, ENGINE_SCHEMA_2_DATABASE]) def test_get_view_definition(self, use_sql_fallback,engine_name): with self.engine_map[engine_name].begin() as c: - dialect = EXADialect() + dialect = Inspector(c).dialect view_definition = dialect.get_view_definition(connection=c, schema=self.schema, view_name="v", use_sql_fallback=use_sql_fallback) assert self.view_defintion == view_definition @@ -161,7 +163,7 @@ def test_get_view_definition(self, use_sql_fallback,engine_name): @pytest.mark.parametrize("engine_name", [ENGINE_NONE_DATABASE, ENGINE_SCHEMA_DATABASE, ENGINE_SCHEMA_2_DATABASE]) def test_get_view_definition_view_name_none(self, use_sql_fallback,engine_name): with self.engine_map[engine_name].begin() as c: - dialect = EXADialect() + dialect = Inspector(c).dialect view_definition = dialect.get_view_definition(connection=c, schema=self.schema, view_name=None, use_sql_fallback=use_sql_fallback) assert view_definition is None @@ -170,7 +172,7 @@ def test_get_view_definition_view_name_none(self, use_sql_fallback,engine_name): @pytest.mark.parametrize("engine_name", [ENGINE_NONE_DATABASE, ENGINE_SCHEMA_DATABASE, ENGINE_SCHEMA_2_DATABASE]) def test_compare_get_view_names_for_sql_and_odbc(self, schema,engine_name): with self.engine_map[engine_name].begin() as c: - dialect = EXADialect() + dialect = Inspector(c).dialect c.execute("OPEN SCHEMA %s" % self.schema) view_names_fallback = dialect.get_view_names(connection=c, schema=schema, use_sql_fallback=True) view_names_odbc = dialect.get_view_names(connection=c, schema=schema) @@ -183,7 +185,7 @@ def test_compare_get_view_definition_for_sql_and_odbc(self, schema,engine_name): if schema is None: c.execute("OPEN SCHEMA %s" % self.schema) view_name = "v" - dialect = EXADialect() + dialect = Inspector(c).dialect view_definition_fallback = dialect.get_view_definition( connection=c, view_name=view_name, schema=schema, use_sql_fallback=True) view_definition_odbc = dialect.get_view_definition( @@ -195,9 +197,9 @@ def test_compare_get_view_definition_for_sql_and_odbc(self, schema,engine_name): @pytest.mark.parametrize("engine_name", [ENGINE_NONE_DATABASE, ENGINE_SCHEMA_DATABASE, ENGINE_SCHEMA_2_DATABASE]) def test_compare_get_columns_for_sql_and_odbc(self, schema, table, engine_name): with self.engine_map[engine_name].begin() as c: + dialect = Inspector(c).dialect if schema is None: c.execute("OPEN SCHEMA %s" % self.schema) - dialect = EXADialect() columns_fallback = dialect.get_columns(connection=c, table_name=table, schema=schema, use_sql_fallback=True) columns_odbc 
= dialect.get_columns(connection=c, table_name=table, schema=schema) assert str(columns_fallback) == str(columns_odbc) # object equality doesn't work for sqltypes @@ -208,7 +210,7 @@ def test_compare_get_columns_none_table_for_sql_and_odbc(self, schema, engine_na with self.engine_map[engine_name].begin() as c: if schema is None: c.execute("OPEN SCHEMA %s" % self.schema) - dialect = EXADialect() + dialect = Inspector(c).dialect table = None columns_fallback = dialect.get_columns(connection=c, table_name=table, schema=schema, use_sql_fallback=True) @@ -222,7 +224,7 @@ def make_columns_comparable(self, column_list): # object equality doesn't work @pytest.mark.parametrize("engine_name", [ENGINE_NONE_DATABASE, ENGINE_SCHEMA_DATABASE, ENGINE_SCHEMA_2_DATABASE]) def test_get_columns(self, use_sql_fallback, engine_name): with self.engine_map[engine_name].begin() as c: - dialect = EXADialect() + dialect = Inspector(c).dialect columns = dialect.get_columns(connection=c, schema=self.schema, table_name="t", use_sql_fallback=use_sql_fallback) expected = [{'default': None, @@ -253,7 +255,7 @@ def test_get_columns(self, use_sql_fallback, engine_name): @pytest.mark.parametrize("engine_name", [ENGINE_NONE_DATABASE, ENGINE_SCHEMA_DATABASE, ENGINE_SCHEMA_2_DATABASE]) def test_get_columns_table_name_none(self, use_sql_fallback, engine_name): with self.engine_map[engine_name].begin() as c: - dialect = EXADialect() + dialect = Inspector(c).dialect columns = dialect.get_columns(connection=c, schema=self.schema, table_name=None, use_sql_fallback=use_sql_fallback) assert columns == [] @@ -265,7 +267,7 @@ def test_compare_get_pk_constraint_for_sql_and_odbc(self, schema, table, engine_ with self.engine_map[engine_name].begin() as c: if schema is None: c.execute("OPEN SCHEMA %s" % self.schema) - dialect = EXADialect() + dialect = Inspector(c).dialect pk_constraint_fallback = dialect.get_pk_constraint(connection=c, table_name=table, schema=schema, use_sql_fallback=True) pk_constraint_odbc = dialect.get_pk_constraint(connection=c, table_name=table, schema=schema) @@ -275,7 +277,7 @@ def test_compare_get_pk_constraint_for_sql_and_odbc(self, schema, table, engine_ @pytest.mark.parametrize("engine_name", [ENGINE_NONE_DATABASE, ENGINE_SCHEMA_DATABASE, ENGINE_SCHEMA_2_DATABASE]) def test_get_pk_constraint(self, use_sql_fallback, engine_name): with self.engine_map[engine_name].begin() as c: - dialect = EXADialect() + dialect = Inspector(c).dialect pk_constraint = dialect.get_pk_constraint(connection=c, schema=self.schema, table_name="t", use_sql_fallback=use_sql_fallback) assert pk_constraint["constrained_columns"] == ['pid1', 'pid2'] and \ @@ -285,7 +287,7 @@ def test_get_pk_constraint(self, use_sql_fallback, engine_name): @pytest.mark.parametrize("engine_name", [ENGINE_NONE_DATABASE, ENGINE_SCHEMA_DATABASE, ENGINE_SCHEMA_2_DATABASE]) def test_get_pk_constraint_table_name_none(self, use_sql_fallback, engine_name): with self.engine_map[engine_name].begin() as c: - dialect = EXADialect() + dialect = Inspector(c).dialect pk_constraint = dialect.get_pk_constraint(connection=c, schema=self.schema, table_name=None, use_sql_fallback=use_sql_fallback) assert pk_constraint is None @@ -297,7 +299,7 @@ def test_compare_get_foreign_keys_for_sql_and_odbc(self, schema, table, engine_n with self.engine_map[engine_name].begin() as c: if schema is None: c.execute("OPEN SCHEMA %s" % self.schema_2) - dialect = EXADialect() + dialect = Inspector(c).dialect foreign_keys_fallback = dialect.get_foreign_keys(connection=c, table_name=table, 
schema=schema, use_sql_fallback=True) foreign_keys_odbc = dialect.get_foreign_keys(connection=c, table_name=table, schema=schema) @@ -307,7 +309,7 @@ def test_compare_get_foreign_keys_for_sql_and_odbc(self, schema, table, engine_n @pytest.mark.parametrize("engine_name", [ENGINE_NONE_DATABASE, ENGINE_SCHEMA_DATABASE, ENGINE_SCHEMA_2_DATABASE]) def test_get_foreign_keys(self, use_sql_fallback, engine_name): with self.engine_map[engine_name].begin() as c: - dialect = EXADialect() + dialect = Inspector(c).dialect foreign_keys = dialect.get_foreign_keys(connection=c, schema=self.schema, table_name="s", use_sql_fallback=use_sql_fallback) expected = [{'name': 'fk_test', @@ -322,7 +324,7 @@ def test_get_foreign_keys(self, use_sql_fallback, engine_name): @pytest.mark.parametrize("engine_name", [ENGINE_NONE_DATABASE, ENGINE_SCHEMA_DATABASE, ENGINE_SCHEMA_2_DATABASE]) def test_get_foreign_keys_table_name_none(self, use_sql_fallback, engine_name): with self.engine_map[engine_name].begin() as c: - dialect = EXADialect() + dialect = Inspector(c).dialect foreign_keys = dialect.get_foreign_keys(connection=c, schema=self.schema, table_name=None, use_sql_fallback=use_sql_fallback) assert foreign_keys == [] From f160052bd03b54ae4a6bbdc2312522e179825580 Mon Sep 17 00:00:00 2001 From: Nicola Coretti Date: Mon, 16 May 2022 16:14:50 +0200 Subject: [PATCH 15/27] Refactor get_schema_names --- sqlalchemy_exasol/base.py | 11 +++++------ sqlalchemy_exasol/pyodbc.py | 5 +++++ 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/sqlalchemy_exasol/base.py b/sqlalchemy_exasol/base.py index cae0d56e..56d454b6 100644 --- a/sqlalchemy_exasol/base.py +++ b/sqlalchemy_exasol/base.py @@ -492,15 +492,14 @@ def use_sql_fallback(self, **kw): logger.warning("Using sql fallback instead of odbc functions") return result + def _get_schema_names_query(self, connection, **kw): + return "select SCHEMA_NAME from SYS.EXA_SCHEMAS" + # never called during reflection @reflection.cache def get_schema_names(self, connection, **kw): - if self.use_sql_fallback(**kw): - prefix = "/*snapshot execution*/ " - else: - prefix = "" - sql_stmnt = "%sselect SCHEMA_NAME from SYS.EXA_SCHEMAS" % prefix - rs = connection.execute(sql.text(sql_stmnt)) + sql_statement = self._get_schema_names_query(connection, **kw) + rs = connection.execute(sql.text(sql_statement)) return [self.normalize_name(row[0]) for row in rs] def _get_schema_for_input_or_current(self, connection, schema): diff --git a/sqlalchemy_exasol/pyodbc.py b/sqlalchemy_exasol/pyodbc.py index 393d496d..e05551e1 100644 --- a/sqlalchemy_exasol/pyodbc.py +++ b/sqlalchemy_exasol/pyodbc.py @@ -173,6 +173,11 @@ def _is_sql_fallback_requested(**kwargs): logger.warning("Using sql fallback instead of odbc functions") return is_fallback_requested + def _get_schema_names_query(self, connection, **kw): + if self._is_sql_fallback_requested(**kw): + return super()._get_schema_names_query(connection, **kw) + return "/*snapshot execution*/ " + super()._get_schema_names_query(connection, **kw) + @reflection.cache def _get_columns(self, connection, table_name, schema=None, **kw): if self._is_sql_fallback_requested(**kw): From 897b034fb447ca136a36daff103f4f2d9ea677b3 Mon Sep 17 00:00:00 2001 From: Nicola Coretti Date: Tue, 17 May 2022 08:24:00 +0200 Subject: [PATCH 16/27] Move get_table_names and has_table to pyodbc based dialaect --- sqlalchemy_exasol/base.py | 33 +-------------------------------- sqlalchemy_exasol/pyodbc.py | 20 ++++++++++++++++++++ 2 files changed, 21 insertions(+), 32 deletions(-) diff 
--git a/sqlalchemy_exasol/base.py b/sqlalchemy_exasol/base.py index 56d454b6..5273a6fb 100644 --- a/sqlalchemy_exasol/base.py +++ b/sqlalchemy_exasol/base.py @@ -527,21 +527,6 @@ def _get_tables_for_schema_odbc(self, connection, odbc_connection, schema, table @reflection.cache def get_table_names(self, connection, schema, **kw): - odbc_connection = self.getODBCConnection(connection) - if odbc_connection is not None and not self.use_sql_fallback(**kw): - return self.get_table_names_odbc(connection, odbc_connection, schema, **kw) - else: - return self.get_table_names_sql(connection, schema, **kw) - - @reflection.cache - def get_table_names_odbc(self, connection, odbc_connection, schema, **kw): - tables = self._get_tables_for_schema_odbc(connection, odbc_connection, schema, table_type="TABLE", **kw) - normalized_tables = [self.normalize_name(row.table_name) - for row in tables] - return normalized_tables - - @reflection.cache - def get_table_names_sql(self, connection, schema, **kw): schema = self._get_schema_for_input(connection, schema) sql_stmnt = "SELECT table_name FROM SYS.EXA_ALL_TABLES WHERE table_schema = " if schema is None: @@ -555,21 +540,6 @@ def get_table_names_sql(self, connection, schema, **kw): return tables def has_table(self, connection, table_name, schema=None, **kw): - odbc_connection = self.getODBCConnection(connection) - if odbc_connection is not None and not self.use_sql_fallback(**kw): - result=self.has_table_odbc(connection, odbc_connection, schema=schema, table_name=table_name, **kw) - else: - result=self.has_table_sql(connection, schema=schema, table_name=table_name, **kw) - return result - - def has_table_odbc(self, connection, odbc_connection, table_name, schema=None, **kw): - tables = self.get_table_names_odbc(connection=connection, - odbc_connection=odbc_connection, - schema=schema, table_name=table_name, **kw) - result = self.normalize_name(table_name) in tables - return result - - def has_table_sql(self, connection, table_name, schema=None, **kw): schema = self._get_schema_for_input(connection, schema) sql_stmnt = "SELECT table_name from SYS.EXA_ALL_TABLES " \ "WHERE table_name = :table_name " @@ -580,8 +550,7 @@ def has_table_sql(self, connection, table_name, schema=None, **kw): table_name=self.denormalize_name(table_name), schema=self.denormalize_name(schema)) row = rp.fetchone() - - return (row is not None) + return row is not None @reflection.cache def get_view_names(self, connection, schema=None, **kw): diff --git a/sqlalchemy_exasol/pyodbc.py b/sqlalchemy_exasol/pyodbc.py index e05551e1..f5a75ca1 100644 --- a/sqlalchemy_exasol/pyodbc.py +++ b/sqlalchemy_exasol/pyodbc.py @@ -173,6 +173,26 @@ def _is_sql_fallback_requested(**kwargs): logger.warning("Using sql fallback instead of odbc functions") return is_fallback_requested + def get_table_names(self, connection, schema, **kw): + if self._is_sql_fallback_requested(**kw): + return super().get_table_names(connection, schema, **kw) + odbc_connection = self.getODBCConnection(connection) + tables = self._get_tables_for_schema_odbc(connection, odbc_connection, schema, table_type="TABLE", **kw) + normalized_tables = [self.normalize_name(row.table_name) for row in tables] + return normalized_tables + + def has_table(self, connection, table_name, schema=None, **kw): + if self._is_sql_fallback_requested(**kw): + return super().has_table(connection, table_name, schema, **kw) + tables = self.get_table_names( + connection=connection, + schema=schema, + table_name=table_name, + **kw + ) + result = 
self.normalize_name(table_name) in tables + return result + def _get_schema_names_query(self, connection, **kw): if self._is_sql_fallback_requested(**kw): return super()._get_schema_names_query(connection, **kw) From 8aefefbf2e4588977ac9c0f63bcd3a8ceb3fd5c0 Mon Sep 17 00:00:00 2001 From: Nicola Coretti Date: Tue, 17 May 2022 08:46:47 +0200 Subject: [PATCH 17/27] Move odbc based get_view_names to pyodbc dialect --- sqlalchemy_exasol/base.py | 13 ------------- sqlalchemy_exasol/pyodbc.py | 14 ++++++++++++++ 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/sqlalchemy_exasol/base.py b/sqlalchemy_exasol/base.py index 5273a6fb..9fe40d14 100644 --- a/sqlalchemy_exasol/base.py +++ b/sqlalchemy_exasol/base.py @@ -554,19 +554,6 @@ def has_table(self, connection, table_name, schema=None, **kw): @reflection.cache def get_view_names(self, connection, schema=None, **kw): - odbc_connection = self.getODBCConnection(connection) - if odbc_connection is not None and not self.use_sql_fallback(**kw): - return self.get_view_names_odbc(connection, odbc_connection, schema, **kw) - else: - return self.get_view_names_sql(connection, schema, **kw) - - @reflection.cache - def get_view_names_odbc(self, connection, odbc_connection, schema=None, **kw): - tables = self._get_tables_for_schema_odbc(connection, odbc_connection, schema, table_type="VIEW", **kw) - return [self.normalize_name(row.table_name) - for row in tables] - - def get_view_names_sql(self, connection, schema=None, **kw): schema = self._get_schema_for_input(connection, schema) sql_stmnt = "SELECT view_name FROM SYS.EXA_ALL_VIEWS WHERE view_schema = " if schema is None: diff --git a/sqlalchemy_exasol/pyodbc.py b/sqlalchemy_exasol/pyodbc.py index f5a75ca1..a1121ea6 100644 --- a/sqlalchemy_exasol/pyodbc.py +++ b/sqlalchemy_exasol/pyodbc.py @@ -181,6 +181,20 @@ def get_table_names(self, connection, schema, **kw): normalized_tables = [self.normalize_name(row.table_name) for row in tables] return normalized_tables + @reflection.cache + def get_view_names(self, connection, schema=None, **kw): + if self._is_sql_fallback_requested(**kw): + return super().get_view_names(connection, schema, **kw) + odbc_connection = self.getODBCConnection(connection) + tables = self._get_tables_for_schema_odbc( + connection, + odbc_connection, + schema, + table_type="VIEW", + **kw + ) + return [self.normalize_name(row.table_name) for row in tables] + def has_table(self, connection, table_name, schema=None, **kw): if self._is_sql_fallback_requested(**kw): return super().has_table(connection, table_name, schema, **kw) From 2641cbd6eaf64b38186567aaf6753d9319c1b04a Mon Sep 17 00:00:00 2001 From: Nicola Coretti Date: Tue, 17 May 2022 09:05:38 +0200 Subject: [PATCH 18/27] Move odbc based get_view_definition to pyodbc based dialect --- sqlalchemy_exasol/base.py | 38 +++++-------------------------------- sqlalchemy_exasol/pyodbc.py | 24 +++++++++++++++++++++++ 2 files changed, 29 insertions(+), 33 deletions(-) diff --git a/sqlalchemy_exasol/base.py b/sqlalchemy_exasol/base.py index 9fe40d14..8400c0cd 100644 --- a/sqlalchemy_exasol/base.py +++ b/sqlalchemy_exasol/base.py @@ -567,36 +567,6 @@ def get_view_names(self, connection, schema=None, **kw): @reflection.cache def get_view_definition(self, connection, view_name, schema=None, **kw): - odbc_connection = self.getODBCConnection(connection) - if odbc_connection is not None and not self.use_sql_fallback(**kw): - return self.get_view_definition_odbc(connection, odbc_connection, view_name, schema, **kw) - else: - return 
self.get_view_definition_sql(connection, view_name, schema, **kw) - - def quote_string_value(self, string_value): - return "'%s'" % (string_value.replace("'", "''")) - - @reflection.cache - def get_view_definition_odbc(self, connection, odbc_connection, view_name, schema=None, **kw): - if view_name is None: - return None - tables = self._get_tables_for_schema_odbc(connection, odbc_connection, schema, table_type="VIEW", - table_name=view_name, **kw) - if len(tables) == 1: - quoted_view_name_string = self.quote_string_value(tables[0][2]) - quoted_view_schema_string = self.quote_string_value(tables[0][1]) - sql_stmnt = \ - "/*snapshot execution*/ SELECT view_text FROM sys.exa_all_views WHERE view_name = {view_name} AND view_schema = {view_schema}".format( - view_name=quoted_view_name_string, view_schema=quoted_view_schema_string) - rp = connection.execute(sql.text(sql_stmnt)).scalar() - if rp: - return rp - else: - return None - else: - return None - - def get_view_definition_sql(self, connection, view_name, schema=None, **kw): schema = self._get_schema_for_input(connection, schema) sql_stmnt = "SELECT view_text FROM sys.exa_all_views WHERE view_name = :view_name AND view_schema = " if schema is None: @@ -606,10 +576,12 @@ def get_view_definition_sql(self, connection, view_name, schema=None, **kw): rp = connection.execute(sql.text(sql_stmnt), view_name=self.denormalize_name(view_name), schema=self.denormalize_name(schema)).scalar() - if rp: - return rp - else: + if not rp: return None + return rp + + def quote_string_value(self, string_value): + return "'%s'" % (string_value.replace("'", "''")) @staticmethod def get_column_sql_query_str(): diff --git a/sqlalchemy_exasol/pyodbc.py b/sqlalchemy_exasol/pyodbc.py index a1121ea6..45f1cfc7 100644 --- a/sqlalchemy_exasol/pyodbc.py +++ b/sqlalchemy_exasol/pyodbc.py @@ -11,6 +11,7 @@ import logging from distutils.version import LooseVersion +from sqlalchemy import sql from sqlalchemy.engine import reflection from sqlalchemy.connectors.pyodbc import PyODBCConnector from sqlalchemy.util.langhelpers import asbool @@ -173,6 +174,29 @@ def _is_sql_fallback_requested(**kwargs): logger.warning("Using sql fallback instead of odbc functions") return is_fallback_requested + @reflection.cache + def get_view_definition(self, connection, view_name, schema=None, **kw): + if self._is_sql_fallback_requested(**kw): + return super().get_view_definition(connection, view_name, schema, **kw) + if view_name is None: + return None + + odbc_connection = self.getODBCConnection(connection) + tables = self._get_tables_for_schema_odbc(connection, odbc_connection, schema, table_type="VIEW", + table_name=view_name, **kw) + if len(tables) != 1: + return None + + quoted_view_name_string = self.quote_string_value(tables[0][2]) + quoted_view_schema_string = self.quote_string_value(tables[0][1]) + sql_stmnt = \ + "/*snapshot execution*/ SELECT view_text FROM sys.exa_all_views WHERE view_name = {view_name} AND view_schema = {view_schema}".format( + view_name=quoted_view_name_string, view_schema=quoted_view_schema_string) + rp = connection.execute(sql.text(sql_stmnt)).scalar() + if not rp: + return None + return rp + def get_table_names(self, connection, schema, **kw): if self._is_sql_fallback_requested(**kw): return super().get_table_names(connection, schema, **kw) From e12b6b0a4cc15e4444cf5299c36854fb305ac389 Mon Sep 17 00:00:00 2001 From: Nicola Coretti Date: Tue, 17 May 2022 09:06:33 +0200 Subject: [PATCH 19/27] Remove use_sql_fallback method from base dialect --- 
sqlalchemy_exasol/base.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/sqlalchemy_exasol/base.py b/sqlalchemy_exasol/base.py index 8400c0cd..4adbfa1e 100644 --- a/sqlalchemy_exasol/base.py +++ b/sqlalchemy_exasol/base.py @@ -486,12 +486,6 @@ def getODBCConnection(self, connection): else: return None - def use_sql_fallback(self, **kw): - result = "use_sql_fallback" in kw and kw.get("use_sql_fallback") == True - if result: - logger.warning("Using sql fallback instead of odbc functions") - return result - def _get_schema_names_query(self, connection, **kw): return "select SCHEMA_NAME from SYS.EXA_SCHEMAS" From c5d25a1d68410fbe782b50ecff772d3407132583 Mon Sep 17 00:00:00 2001 From: Nicola Coretti Date: Tue, 17 May 2022 09:13:42 +0200 Subject: [PATCH 20/27] Move getODBCConnection from base dialect to pyodbc dialect --- sqlalchemy_exasol/base.py | 18 +----------------- sqlalchemy_exasol/pyodbc.py | 18 +++++++++++++++++- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/sqlalchemy_exasol/base.py b/sqlalchemy_exasol/base.py index 4adbfa1e..c296e714 100644 --- a/sqlalchemy_exasol/base.py +++ b/sqlalchemy_exasol/base.py @@ -50,8 +50,7 @@ from decimal import Decimal from sqlalchemy import sql, schema, types as sqltypes, util, event from sqlalchemy.schema import AddConstraint, ForeignKeyConstraint -from sqlalchemy.engine import default, reflection, Engine, Connection -from sqlalchemy.orm.session import Session +from sqlalchemy.engine import default, reflection from sqlalchemy.sql import compiler from sqlalchemy.sql.elements import quoted_name from datetime import date, datetime @@ -471,21 +470,6 @@ def on_connect(self): # TODO: set isolation level pass - def getODBCConnection(self, connection): - if isinstance(connection, Engine): - odbc_connection = connection.raw_connection().connection - elif isinstance(connection, Session): - odbc_connection = connection.connection() - return self.getODBCConnection(odbc_connection) - elif isinstance(connection, Connection): - odbc_connection = connection.connection.connection - else: - return None - if "pyodbc.Connection" in str(type(odbc_connection)): - return odbc_connection - else: - return None - def _get_schema_names_query(self, connection, **kw): return "select SCHEMA_NAME from SYS.EXA_SCHEMAS" diff --git a/sqlalchemy_exasol/pyodbc.py b/sqlalchemy_exasol/pyodbc.py index 45f1cfc7..85f3b3b7 100644 --- a/sqlalchemy_exasol/pyodbc.py +++ b/sqlalchemy_exasol/pyodbc.py @@ -12,7 +12,8 @@ from distutils.version import LooseVersion from sqlalchemy import sql -from sqlalchemy.engine import reflection +from sqlalchemy.engine import reflection, Engine, Connection +from sqlalchemy.orm.session import Session from sqlalchemy.connectors.pyodbc import PyODBCConnector from sqlalchemy.util.langhelpers import asbool @@ -174,6 +175,21 @@ def _is_sql_fallback_requested(**kwargs): logger.warning("Using sql fallback instead of odbc functions") return is_fallback_requested + def getODBCConnection(self, connection): + if isinstance(connection, Engine): + odbc_connection = connection.raw_connection().connection + elif isinstance(connection, Session): + odbc_connection = connection.connection() + return self.getODBCConnection(odbc_connection) + elif isinstance(connection, Connection): + odbc_connection = connection.connection.connection + else: + return None + if "pyodbc.Connection" in str(type(odbc_connection)): + return odbc_connection + else: + return None + @reflection.cache def get_view_definition(self, connection, view_name, schema=None, **kw): if
self._is_sql_fallback_requested(**kw): From a89bd0ad4ef36934b9da8bb28a2f584e96ef7ee1 Mon Sep 17 00:00:00 2001 From: Nicola Coretti Date: Tue, 17 May 2022 09:25:29 +0200 Subject: [PATCH 21/27] Move _get_tables_for_schema_odbc from base to pyodbc dialect --- sqlalchemy_exasol/base.py | 8 -------- sqlalchemy_exasol/pyodbc.py | 8 ++++++++ 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/sqlalchemy_exasol/base.py b/sqlalchemy_exasol/base.py index c296e714..f575d62b 100644 --- a/sqlalchemy_exasol/base.py +++ b/sqlalchemy_exasol/base.py @@ -495,14 +495,6 @@ def _get_current_schema(self, connection): current_schema = connection.execute(current_schema_stmnt).fetchone()[0] return current_schema - @reflection.cache - def _get_tables_for_schema_odbc(self, connection, odbc_connection, schema, table_type=None, table_name=None, **kw): - schema = self._get_schema_for_input_or_current(connection, schema) - table_name = self.denormalize_name(table_name) - with odbc_connection.cursor().tables(schema=schema, tableType=table_type, table=table_name) as table_cursor: - rows = [row for row in table_cursor] - return rows - @reflection.cache def get_table_names(self, connection, schema, **kw): schema = self._get_schema_for_input(connection, schema) diff --git a/sqlalchemy_exasol/pyodbc.py b/sqlalchemy_exasol/pyodbc.py index 85f3b3b7..fc3b5103 100644 --- a/sqlalchemy_exasol/pyodbc.py +++ b/sqlalchemy_exasol/pyodbc.py @@ -190,6 +190,14 @@ def getODBCConnection(self, connection): else: return None + @reflection.cache + def _get_tables_for_schema_odbc(self, connection, odbc_connection, schema, table_type=None, table_name=None, **kw): + schema = self._get_schema_for_input_or_current(connection, schema) + table_name = self.denormalize_name(table_name) + with odbc_connection.cursor().tables(schema=schema, tableType=table_type, table=table_name) as table_cursor: + rows = [row for row in table_cursor] + return rows + @reflection.cache def get_view_definition(self, connection, view_name, schema=None, **kw): if self._is_sql_fallback_requested(**kw): From 5fb3bd0cef714c0b560fb17a50751bc67eec4a98 Mon Sep 17 00:00:00 2001 From: Nicola Coretti Date: Tue, 17 May 2022 12:28:32 +0200 Subject: [PATCH 22/27] Minor refactorings on Exasol base dialect --- sqlalchemy_exasol/base.py | 118 ++++++++++++++++++++++---------------- 1 file changed, 67 insertions(+), 51 deletions(-) diff --git a/sqlalchemy_exasol/base.py b/sqlalchemy_exasol/base.py index f575d62b..70c445d3 100644 --- a/sqlalchemy_exasol/base.py +++ b/sqlalchemy_exasol/base.py @@ -434,7 +434,8 @@ def _get_default_schema_name(self, connection): def _get_schema_from_url(self, connection, schema): if connection.engine.url is not None and connection.engine.url != "": schema = self.normalize_name( - connection.engine.url.translate_connect_args().get('database')) + connection.engine.url.translate_connect_args().get('database') + ) return schema def normalize_name(self, name): @@ -458,8 +459,7 @@ def denormalize_name(self, name): """ if name is None or len(name) == 0: return None - elif name.lower() == name and \ - not self.identifier_preparer._requires_quotes(name.lower()): + elif name.lower() == name and not self.identifier_preparer._requires_quotes(name.lower()): name = name.upper() return name @@ -477,8 +477,8 @@ def _get_schema_names_query(self, connection, **kw): @reflection.cache def get_schema_names(self, connection, **kw): sql_statement = self._get_schema_names_query(connection, **kw) - rs = connection.execute(sql.text(sql_statement)) - return 
[self.normalize_name(row[0]) for row in rs] + result = connection.execute(sql.text(sql_statement)) + return [self.normalize_name(row[0]) for row in result] def _get_schema_for_input_or_current(self, connection, schema): schema = self._get_schema_for_input(connection, schema) @@ -487,53 +487,63 @@ def _get_schema_for_input_or_current(self, connection, schema): return self.denormalize_name(schema) def _get_schema_for_input(self, connection, schema): - schema = self.denormalize_name(schema or self._get_schema_from_url(connection, schema)) - return schema + return self.denormalize_name( + schema or self._get_schema_from_url(connection, schema) + ) - def _get_current_schema(self, connection): - current_schema_stmnt = "SELECT CURRENT_SCHEMA" - current_schema = connection.execute(current_schema_stmnt).fetchone()[0] + @staticmethod + def _get_current_schema(connection): + sql_statement = "SELECT CURRENT_SCHEMA" + current_schema = connection.execute(sql_statement).fetchone()[0] return current_schema @reflection.cache def get_table_names(self, connection, schema, **kw): schema = self._get_schema_for_input(connection, schema) - sql_stmnt = "SELECT table_name FROM SYS.EXA_ALL_TABLES WHERE table_schema = " + sql_statement = "SELECT table_name FROM SYS.EXA_ALL_TABLES WHERE table_schema = " if schema is None: - sql_stmnt += "CURRENT_SCHEMA ORDER BY table_name" - rs = connection.execute(sql_stmnt) + sql_statement += "CURRENT_SCHEMA ORDER BY table_name" + result = connection.execute(sql_statement) else: - sql_stmnt += ":schema ORDER BY table_name" - rs = connection.execute(sql.text(sql_stmnt), - schema=self.denormalize_name(schema)) - tables = [self.normalize_name(row[0]) for row in rs] + sql_statement += ":schema ORDER BY table_name" + result = connection.execute( + sql.text(sql_statement), + schema=self.denormalize_name(schema) + ) + tables = [self.normalize_name(row[0]) for row in result] return tables def has_table(self, connection, table_name, schema=None, **kw): schema = self._get_schema_for_input(connection, schema) - sql_stmnt = "SELECT table_name from SYS.EXA_ALL_TABLES " \ - "WHERE table_name = :table_name " + sql_statement = ( + "SELECT table_name from SYS.EXA_ALL_TABLES " + "WHERE table_name = :table_name " + ) if schema is not None: - sql_stmnt += "AND table_schema = :schema" - rp = connection.execute( - sql.text(sql_stmnt), + sql_statement += "AND table_schema = :schema" + + result = connection.execute( + sql.text(sql_statement), table_name=self.denormalize_name(table_name), - schema=self.denormalize_name(schema)) - row = rp.fetchone() + schema=self.denormalize_name(schema) + ) + row = result.fetchone() return row is not None @reflection.cache def get_view_names(self, connection, schema=None, **kw): schema = self._get_schema_for_input(connection, schema) - sql_stmnt = "SELECT view_name FROM SYS.EXA_ALL_VIEWS WHERE view_schema = " + sql_statement = "SELECT view_name FROM SYS.EXA_ALL_VIEWS WHERE view_schema = " if schema is None: - sql_stmnt += "CURRENT_SCHEMA ORDER BY view_name" - rs = connection.execute(sql.text(sql_stmnt)) + sql_statement += "CURRENT_SCHEMA ORDER BY view_name" + result = connection.execute(sql.text(sql_statement)) else: - sql_stmnt += ":schema ORDER BY view_name" - rs = connection.execute(sql.text(sql_stmnt), - schema=self.denormalize_name(schema)) - return [self.normalize_name(row[0]) for row in rs] + sql_statement += ":schema ORDER BY view_name" + result = connection.execute( + sql.text(sql_statement), + schema=self.denormalize_name(schema) + ) + return 
[self.normalize_name(row[0]) for row in result] @reflection.cache def get_view_definition(self, connection, view_name, schema=None, **kw): @@ -543,14 +553,15 @@ def get_view_definition(self, connection, view_name, schema=None, **kw): sql_stmnt += "CURRENT_SCHEMA" else: sql_stmnt += ":schema" - rp = connection.execute(sql.text(sql_stmnt), - view_name=self.denormalize_name(view_name), - schema=self.denormalize_name(schema)).scalar() - if not rp: - return None - return rp + result = connection.execute( + sql.text(sql_stmnt), + view_name=self.denormalize_name(view_name), + schema=self.denormalize_name(schema) + ).scalar() + return result if result else None - def quote_string_value(self, string_value): + @staticmethod + def quote_string_value(string_value): return "'%s'" % (string_value.replace("'", "''")) @staticmethod @@ -582,12 +593,12 @@ def _get_columns(self, connection, table_name, schema=None, **kw): schema_str = "CURRENT_SCHEMA" if schema is None else ":schema" table_name_str = ":table" sql_statement = self.get_column_sql_query_str().format(schema=schema_str, table=table_name_str) - response = connection.execute( + result = connection.execute( sql.text(sql_statement), schema=self.denormalize_name(schema), table=self.denormalize_name(table_name) ) - return list(response) + return list(result) @reflection.cache def get_columns(self, connection, table_name, schema=None, **kw): @@ -595,9 +606,12 @@ def get_columns(self, connection, table_name, schema=None, **kw): return [] columns = [] - rows = self._get_columns(connection, - table_name=table_name, - schema=schema, **kw) + rows = self._get_columns( + connection, + table_name=table_name, + schema=schema, + **kw + ) table_name = self.denormalize_name(table_name) for row in rows: (colname, coltype, length, precision, scale, nullable, default, identity, is_distribution_key) = \ @@ -679,13 +693,15 @@ def _get_pk_constraint(self, connection, table_name, schema, **kw): schema_string = "CURRENT_SCHEMA " else: schema_string = ":schema " - sql_stmnt=self._get_constraint_sql_str(schema_string, table_name_string, "PRIMARY KEY") - rp = connection.execute(sql.text(sql_stmnt), - schema=self.denormalize_name(schema), - table=table_name) + sql_statement = self._get_constraint_sql_str(schema_string, table_name_string, "PRIMARY KEY") + result = connection.execute( + sql.text(sql_statement), + schema=self.denormalize_name(schema), + table=table_name + ) pkeys = [] constraint_name = None - for row in list(rp): + for row in list(result): if (row[5] != table_name and table_name is not None) or row[6] != 'PRIMARY KEY': continue pkeys.append(self.normalize_name(row[1])) @@ -705,12 +721,12 @@ def _get_foreign_keys(self, connection, table_name, schema=None, **kw): table_name_string = ":table" schema_string = "CURRENT_SCHEMA " if schema is None else ":schema " sql_statement = self._get_constraint_sql_str(schema_string, table_name_string, "FOREIGN KEY") - response = connection.execute( + result = connection.execute( sql.text(sql_statement), schema=self.denormalize_name(schema), table=self.denormalize_name(table_name) ) - return list(response) + return list(result) @reflection.cache @@ -755,5 +771,5 @@ def fkey_rec(): @reflection.cache def get_indexes(self, connection, table_name, schema=None, **kw): - # EXASolution has no explicit indexes + """ EXASolution has no explicit indexes""" return [] From b88f0560f7f27cb1a77b6f810174eda76aa20924 Mon Sep 17 00:00:00 2001 From: Nicola Coretti Date: Tue, 17 May 2022 12:32:30 +0200 Subject: [PATCH 23/27] Remove unused 
variables/instructions --- sqlalchemy_exasol/base.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/sqlalchemy_exasol/base.py b/sqlalchemy_exasol/base.py index 70c445d3..bd75cd7b 100644 --- a/sqlalchemy_exasol/base.py +++ b/sqlalchemy_exasol/base.py @@ -612,7 +612,6 @@ def get_columns(self, connection, table_name, schema=None, **kw): schema=schema, **kw ) - table_name = self.denormalize_name(table_name) for row in rows: (colname, coltype, length, precision, scale, nullable, default, identity, is_distribution_key) = \ (row[0], row[1], row[2], row[3], row[4], row[5], row[6], row[7], row[8]) @@ -746,7 +745,6 @@ def fkey_rec(): fkeys = util.defaultdict(fkey_rec) constraints = self._get_foreign_keys(connection, table_name=table_name, schema=schema_int, **kw) - table_name = self.denormalize_name(table_name) for row in constraints: (cons_name, local_column, remote_schema, remote_table, remote_column) = \ (row[0], row[1], row[2], row[3], row[4]) From 7ced60543808b0051fc848a0eed165f81bbf4987 Mon Sep 17 00:00:00 2001 From: Nicola Coretti Date: Tue, 17 May 2022 12:43:34 +0200 Subject: [PATCH 24/27] Minor refactorings on pyodbc based dialect --- sqlalchemy_exasol/pyodbc.py | 62 +++++++++++++++++++++++-------------- 1 file changed, 38 insertions(+), 24 deletions(-) diff --git a/sqlalchemy_exasol/pyodbc.py b/sqlalchemy_exasol/pyodbc.py index fc3b5103..827a8a74 100644 --- a/sqlalchemy_exasol/pyodbc.py +++ b/sqlalchemy_exasol/pyodbc.py @@ -175,12 +175,12 @@ def _is_sql_fallback_requested(**kwargs): logger.warning("Using sql fallback instead of odbc functions") return is_fallback_requested - def getODBCConnection(self, connection): + def _get_odbc_connection(self, connection): if isinstance(connection, Engine): odbc_connection = connection.raw_connection().connection elif isinstance(connection, Session): odbc_connection = connection.connection() - return self.getODBCConnection(odbc_connection) + return self._get_odbc_connection(odbc_connection) elif isinstance(connection, Connection): odbc_connection = connection.connection.connection else: return None @@ -205,35 +205,43 @@ def get_view_definition(self, connection, view_name, schema=None, **kw): if view_name is None: return None - odbc_connection = self.getODBCConnection(connection) - tables = self._get_tables_for_schema_odbc(connection, odbc_connection, schema, table_type="VIEW", - table_name=view_name, **kw) + odbc_connection = self._get_odbc_connection(connection) + tables = self._get_tables_for_schema_odbc( + connection, odbc_connection, schema, + table_type="VIEW", + table_name=view_name, + **kw + ) if len(tables) != 1: return None quoted_view_name_string = self.quote_string_value(tables[0][2]) quoted_view_schema_string = self.quote_string_value(tables[0][1]) - sql_stmnt = \ - "/*snapshot execution*/ SELECT view_text FROM sys.exa_all_views WHERE view_name = {view_name} AND view_schema = {view_schema}".format( - view_name=quoted_view_name_string, view_schema=quoted_view_schema_string) - rp = connection.execute(sql.text(sql_stmnt)).scalar() - if not rp: - return None - return rp + sql_statement = ( + "/*snapshot execution*/ SELECT view_text " + f"FROM sys.exa_all_views WHERE view_name = {quoted_view_name_string} " + f"AND view_schema = {quoted_view_schema_string}" + ) + result = connection.execute(sql.text(sql_statement)).scalar() + return result if result else None def get_table_names(self, connection, schema, **kw): if self._is_sql_fallback_requested(**kw): return super().get_table_names(connection, schema, **kw) odbc_connection =
self.getODBCConnection(connection) - tables = self._get_tables_for_schema_odbc(connection, odbc_connection, schema, table_type="TABLE", **kw) - normalized_tables = [self.normalize_name(row.table_name) for row in tables] - return normalized_tables + odbc_connection = self._get_odbc_connection(connection) + tables = self._get_tables_for_schema_odbc( + connection, + odbc_connection, + schema, table_type="TABLE", + **kw + ) + return [self.normalize_name(row.table_name) for row in tables] @reflection.cache def get_view_names(self, connection, schema=None, **kw): if self._is_sql_fallback_requested(**kw): return super().get_view_names(connection, schema, **kw) - odbc_connection = self.getODBCConnection(connection) + odbc_connection = self._get_odbc_connection(connection) tables = self._get_tables_for_schema_odbc( connection, odbc_connection, @@ -252,8 +260,7 @@ def has_table(self, connection, table_name, schema=None, **kw): table_name=table_name, **kw ) - result = self.normalize_name(table_name) in tables - return result + return self.normalize_name(table_name) in tables def _get_schema_names_query(self, connection, **kw): if self._is_sql_fallback_requested(**kw): @@ -265,7 +272,7 @@ def _get_columns(self, connection, table_name, schema=None, **kw): if self._is_sql_fallback_requested(**kw): return super()._get_columns(connection, table_name, schema, **kw) - odbc_connection = self.getODBCConnection(connection) + odbc_connection = self._get_odbc_connection(connection) tables = self._get_tables_for_schema_odbc( connection, odbc_connection, schema=schema, @@ -293,7 +300,7 @@ def _get_pk_constraint(self, connection, table_name, schema=None, **kw): if self._is_sql_fallback_requested(**kw): return super()._get_pk_constraint(connection, table_name, schema, **kw) - odbc_connection = self.getODBCConnection(connection) + odbc_connection = self._get_odbc_connection(connection) schema = self._get_schema_for_input_or_current(connection, schema) table_name = self.denormalize_name(table_name) with odbc_connection.cursor().primaryKeys(table=table_name, schema=schema) as cursor: @@ -312,10 +319,16 @@ def _get_foreign_keys(self, connection, table_name, schema=None, **kw): if self._is_sql_fallback_requested(**kw): return super()._get_foreign_keys(connection, table_name, schema, **kw) - odbc_connection = self.getODBCConnection(connection) + odbc_connection = self._get_odbc_connection(connection) # Need to use a workaround, because SQLForeignKeys functions doesn't work for an unknown reason - tables = self._get_tables_for_schema_odbc(connection=connection, odbc_connection=odbc_connection, - schema=schema, table_name=table_name, table_type="TABLE", **kw) + tables = self._get_tables_for_schema_odbc( + connection=connection, + odbc_connection=odbc_connection, + schema=schema, + table_name=table_name, + table_type="TABLE", + **kw + ) if len(tables) == 0: return [] @@ -329,6 +342,7 @@ def _get_foreign_keys(self, connection, table_name, schema=None, **kw): ) ) response = connection.execute(sql_statement) + return list(response) From 17937e40b360d9227b23a4624022a864c5cdeb77 Mon Sep 17 00:00:00 2001 From: Nicola Coretti Date: Tue, 17 May 2022 12:53:45 +0200 Subject: [PATCH 25/27] Add missing cache annotation --- sqlalchemy_exasol/pyodbc.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sqlalchemy_exasol/pyodbc.py b/sqlalchemy_exasol/pyodbc.py index 827a8a74..a6d91f49 100644 --- a/sqlalchemy_exasol/pyodbc.py +++ b/sqlalchemy_exasol/pyodbc.py @@ -225,6 +225,7 @@ def get_view_definition(self, connection, view_name, 
schema=None, **kw): result = connection.execute(sql.text(sql_statement)).scalar() return result if result else None + @reflection.cache def get_table_names(self, connection, schema, **kw): if self._is_sql_fallback_requested(**kw): return super().get_table_names(connection, schema, **kw) From 5dd460ce814716bcdda1fd3ce5d5b5805dc3f9af Mon Sep 17 00:00:00 2001 From: Nicola Coretti Date: Tue, 17 May 2022 13:07:59 +0200 Subject: [PATCH 26/27] Fix bug #136 --- sqlalchemy_exasol/pyodbc.py | 40 +++++++++---------------------------- 1 file changed, 9 insertions(+), 31 deletions(-) diff --git a/sqlalchemy_exasol/pyodbc.py b/sqlalchemy_exasol/pyodbc.py index a6d91f49..69946b79 100644 --- a/sqlalchemy_exasol/pyodbc.py +++ b/sqlalchemy_exasol/pyodbc.py @@ -175,28 +175,13 @@ def _is_sql_fallback_requested(**kwargs): logger.warning("Using sql fallback instead of odbc functions") return is_fallback_requested - def _get_odbc_connection(self, connection): - if isinstance(connection, Engine): - odbc_connection = connection.raw_connection().connection - elif isinstance(connection, Session): - odbc_connection = connection.connection() - return self._get_odbc_connection(odbc_connection) - elif isinstance(connection, Connection): - odbc_connection = connection.connection.connection - else: - return None - if "pyodbc.Connection" in str(type(odbc_connection)): - return odbc_connection - else: - return None - @reflection.cache - def _get_tables_for_schema_odbc(self, connection, odbc_connection, schema, table_type=None, table_name=None, **kw): + def _get_tables_for_schema_odbc(self, connection, schema, table_type=None, table_name=None, **kw): schema = self._get_schema_for_input_or_current(connection, schema) table_name = self.denormalize_name(table_name) - with odbc_connection.cursor().tables(schema=schema, tableType=table_type, table=table_name) as table_cursor: - rows = [row for row in table_cursor] - return rows + conn = connection.engine.raw_connection() + with conn.cursor().tables(schema=schema, tableType=table_type, table=table_name) as table_cursor: + return [row for row in table_cursor] @reflection.cache def get_view_definition(self, connection, view_name, schema=None, **kw): @@ -205,9 +190,8 @@ def get_view_definition(self, connection, view_name, schema=None, **kw): if view_name is None: return None - odbc_connection = self._get_odbc_connection(connection) tables = self._get_tables_for_schema_odbc( - connection, odbc_connection, schema, + connection, schema, table_type="VIEW", table_name=view_name, **kw @@ -229,10 +213,9 @@ def get_view_definition(self, connection, view_name, schema=None, **kw): def get_table_names(self, connection, schema, **kw): if self._is_sql_fallback_requested(**kw): return super().get_table_names(connection, schema, **kw) - odbc_connection = self._get_odbc_connection(connection) tables = self._get_tables_for_schema_odbc( connection, - odbc_connection, + schema, table_type="TABLE", **kw ) @@ -242,10 +225,8 @@ def get_table_names(self, connection, schema, **kw): def get_view_names(self, connection, schema=None, **kw): if self._is_sql_fallback_requested(**kw): return super().get_view_names(connection, schema, **kw) - odbc_connection = self._get_odbc_connection(connection) tables = self._get_tables_for_schema_odbc( connection, - odbc_connection, schema, table_type="VIEW", **kw @@ -273,9 +254,8 @@ def _get_columns(self, connection, table_name, schema=None, **kw): if self._is_sql_fallback_requested(**kw): return super()._get_columns(connection, table_name, schema, **kw) - odbc_connection = 
self._get_odbc_connection(connection) tables = self._get_tables_for_schema_odbc( - connection, odbc_connection, + connection, schema=schema, table_name=table_name, **kw @@ -301,10 +281,10 @@ def _get_pk_constraint(self, connection, table_name, schema=None, **kw): if self._is_sql_fallback_requested(**kw): return super()._get_pk_constraint(connection, table_name, schema, **kw) - odbc_connection = self._get_odbc_connection(connection) + conn = connection.engine.raw_connection() schema = self._get_schema_for_input_or_current(connection, schema) table_name = self.denormalize_name(table_name) - with odbc_connection.cursor().primaryKeys(table=table_name, schema=schema) as cursor: + with conn.cursor().primaryKeys(table=table_name, schema=schema) as cursor: pkeys = [] constraint_name = None for row in cursor: @@ -320,11 +300,9 @@ def _get_foreign_keys(self, connection, table_name, schema=None, **kw): if self._is_sql_fallback_requested(**kw): return super()._get_foreign_keys(connection, table_name, schema, **kw) - odbc_connection = self._get_odbc_connection(connection) # Need to use a workaround, because SQLForeignKeys functions doesn't work for an unknown reason tables = self._get_tables_for_schema_odbc( connection=connection, - odbc_connection=odbc_connection, schema=schema, table_name=table_name, table_type="TABLE", From 107cffc11c9282891e6f107d43a9af1e2ef491f4 Mon Sep 17 00:00:00 2001 From: Nicola Coretti Date: Tue, 17 May 2022 13:18:35 +0200 Subject: [PATCH 27/27] Minor refactorings to clean up the bug fix --- sqlalchemy_exasol/pyodbc.py | 39 ++++++++----------------------------- 1 file changed, 8 insertions(+), 31 deletions(-) diff --git a/sqlalchemy_exasol/pyodbc.py b/sqlalchemy_exasol/pyodbc.py index 69946b79..e5af6470 100644 --- a/sqlalchemy_exasol/pyodbc.py +++ b/sqlalchemy_exasol/pyodbc.py @@ -12,8 +12,7 @@ from distutils.version import LooseVersion from sqlalchemy import sql -from sqlalchemy.engine import reflection, Engine, Connection -from sqlalchemy.orm.session import Session +from sqlalchemy.engine import reflection from sqlalchemy.connectors.pyodbc import PyODBCConnector from sqlalchemy.util.langhelpers import asbool @@ -176,7 +175,7 @@ def _is_sql_fallback_requested(**kwargs): return is_fallback_requested @reflection.cache - def _get_tables_for_schema_odbc(self, connection, schema, table_type=None, table_name=None, **kw): + def _tables_for_schema(self, connection, schema, table_type=None, table_name=None): schema = self._get_schema_for_input_or_current(connection, schema) table_name = self.denormalize_name(table_name) conn = connection.engine.raw_connection() @@ -190,12 +189,7 @@ def get_view_definition(self, connection, view_name, schema=None, **kw): if view_name is None: return None - tables = self._get_tables_for_schema_odbc( - connection, schema, - table_type="VIEW", - table_name=view_name, - **kw - ) + tables = self._tables_for_schema(connection, schema, table_type="VIEW", table_name=view_name) if len(tables) != 1: return None @@ -213,24 +207,14 @@ def get_view_definition(self, connection, view_name, schema=None, **kw): def get_table_names(self, connection, schema, **kw): if self._is_sql_fallback_requested(**kw): return super().get_table_names(connection, schema, **kw) - tables = self._get_tables_for_schema_odbc( - connection, - - schema, table_type="TABLE", - **kw - ) + tables = self._tables_for_schema(connection, schema, table_type="TABLE") return [self.normalize_name(row.table_name) for row in tables] @reflection.cache def get_view_names(self, connection, schema=None, 
**kw): if self._is_sql_fallback_requested(**kw): return super().get_view_names(connection, schema, **kw) - tables = self._get_tables_for_schema_odbc( - connection, - schema, - table_type="VIEW", - **kw - ) + tables = self._tables_for_schema(connection, schema, table_type="VIEW") return [self.normalize_name(row.table_name) for row in tables] def has_table(self, connection, table_name, schema=None, **kw): @@ -254,13 +238,7 @@ def _get_columns(self, connection, table_name, schema=None, **kw): if self._is_sql_fallback_requested(**kw): return super()._get_columns(connection, table_name, schema, **kw) - tables = self._get_tables_for_schema_odbc( - connection, - schema=schema, - table_name=table_name, - **kw - ) - + tables = self._tables_for_schema(connection, schema=schema, table_name=table_name) if len(tables) != 1: return [] @@ -301,12 +279,11 @@ def _get_foreign_keys(self, connection, table_name, schema=None, **kw): return super()._get_foreign_keys(connection, table_name, schema, **kw) # Need to use a workaround, because SQLForeignKeys functions doesn't work for an unknown reason - tables = self._get_tables_for_schema_odbc( + tables = self._tables_for_schema( connection=connection, schema=schema, table_name=table_name, - table_type="TABLE", - **kw + table_type="TABLE" ) if len(tables) == 0: return []