From 961e3a946c4b69a441020a45661dd0e624b1c851 Mon Sep 17 00:00:00 2001 From: Zehua Zou <41586196+HuaHuaY@users.noreply.github.com> Date: Tue, 29 Nov 2022 00:26:17 +0800 Subject: [PATCH 1/2] add get_indexes --- sqlalchemy_risingwave/base.py | 33 ++++++++++++++++++++++++++++++++- test/test_schema.py | 14 ++++++++++++++ 2 files changed, 46 insertions(+), 1 deletion(-) diff --git a/sqlalchemy_risingwave/base.py b/sqlalchemy_risingwave/base.py index b2fc16f..5499cc8 100644 --- a/sqlalchemy_risingwave/base.py +++ b/sqlalchemy_risingwave/base.py @@ -113,7 +113,38 @@ def get_columns(self, conn, table_name, schema=None, **kw): return res def get_indexes(self, conn, table_name, schema=None, **kw): - return [] + table_oid = self.get_table_oid( + conn, table_name, schema, info_cache=kw.get("info_cache") + ) + + sql = ( + "select i.relname, a.attname, a.attnotnull from pg_catalog.pg_class t " + "join pg_catalog.pg_index ix on t.oid = ix.indrelid " + "join pg_catalog.pg_class i on i.oid = ix.indexrelid " + "join pg_catalog.pg_attribute a on t.oid = a.attrelid " + "where t.oid = :table_oid" + ) + rows = conn.execute( + text(sql), + {"table_oid": table_oid}, + ) + + indexes = {} + for row in rows: + if not row.relname in indexes: + indexes[row.relname] = [] + indexes[row.relname].append(row.attname) + + res = [] + for index in indexes: + res.append( + { + "name": index, + "column_names": indexes[index], + "unique": False, + } + ) + return res def get_foreign_keys_v1(self, conn, table_name, schema=None, **kw): raise [] diff --git a/test/test_schema.py b/test/test_schema.py index ee7577e..b09b886 100644 --- a/test/test_schema.py +++ b/test/test_schema.py @@ -20,6 +20,11 @@ def setup_method(self): """ ) ) + conn.execute( + text( + "CREATE index users_idx on users(name)" + ) + ) self.meta = MetaData(schema="public") def test_get_columns_indexes_across_schema(self): @@ -34,3 +39,12 @@ def test_returning_clause(self): for t in table_names: assert t == str("users") + + def test_get_indexes(self): + with testing.db.begin() as conn: + insp = inspect(testing.db) + indexes = insp.get_indexes("users") + + for index in indexes: + assert index["name"] == str("users_idx") + assert index["column_names"] == ["name"] From 26a14fd5ba5ba4e25b55fa5e06803da016084cb5 Mon Sep 17 00:00:00 2001 From: Zehua Zou <41586196+HuaHuaY@users.noreply.github.com> Date: Sat, 3 Dec 2022 12:49:07 +0800 Subject: [PATCH 2/2] fix wrong sql --- sqlalchemy_risingwave/base.py | 4 ++-- test/test_schema.py | 30 +++++++++++------------------- 2 files changed, 13 insertions(+), 21 deletions(-) diff --git a/sqlalchemy_risingwave/base.py b/sqlalchemy_risingwave/base.py index 5499cc8..91fe021 100644 --- a/sqlalchemy_risingwave/base.py +++ b/sqlalchemy_risingwave/base.py @@ -118,10 +118,10 @@ def get_indexes(self, conn, table_name, schema=None, **kw): ) sql = ( - "select i.relname, a.attname, a.attnotnull from pg_catalog.pg_class t " + "select i.relname, a.attname from pg_catalog.pg_class t " "join pg_catalog.pg_index ix on t.oid = ix.indrelid " "join pg_catalog.pg_class i on i.oid = ix.indexrelid " - "join pg_catalog.pg_attribute a on t.oid = a.attrelid " + "join pg_catalog.pg_attribute a on t.oid = a.attrelid and a.attnum = ANY(ix.indkey)" "where t.oid = :table_oid" ) rows = conn.execute( diff --git a/test/test_schema.py b/test/test_schema.py index b09b886..ff97ebb 100644 --- a/test/test_schema.py +++ b/test/test_schema.py @@ -7,24 +7,11 @@ class SchemaTest(fixtures.TestBase): def teardown_method(self, method): with testing.db.begin() as conn: - conn.execute(text("DROP TABLE IF EXISTS users")) + conn.execute("DROP TABLE IF EXISTS users") def setup_method(self): with testing.db.begin() as conn: - conn.execute( - text( - """ - CREATE TABLE users ( - name STRING PRIMARY KEY - ) - """ - ) - ) - conn.execute( - text( - "CREATE index users_idx on users(name)" - ) - ) + conn.execute("CREATE TABLE users (name STRING PRIMARY KEY)") self.meta = MetaData(schema="public") def test_get_columns_indexes_across_schema(self): @@ -42,9 +29,14 @@ def test_returning_clause(self): def test_get_indexes(self): with testing.db.begin() as conn: + conn.execute("CREATE TABLE three_columns (id1 INT, id2 INT, id3 INT)") + conn.execute("CREATE INDEX three_columns_idx ON three_columns(id2) INCLUDE(id1)") + insp = inspect(testing.db) - indexes = insp.get_indexes("users") + indexes = insp.get_indexes("three_columns") + + assert len(indexes) == 1 + assert indexes[0]["name"] == "three_columns_idx" + assert indexes[0]["column_names"] == ["id2", "id1"] - for index in indexes: - assert index["name"] == str("users_idx") - assert index["column_names"] == ["name"] + conn.execute("DROP TABLE three_columns")