diff --git a/defog/__init__.py b/defog/__init__.py index 81aad61..0aea36e 100644 --- a/defog/__init__.py +++ b/defog/__init__.py @@ -227,22 +227,30 @@ def generate_postgres_schema(self, tables: list, upload: bool = True) -> str: cur.execute( "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public';" ) - tables = [row[0] for row in cur.fetchall()] + tables = [f"public.{row[0]}" for row in cur.fetchall()] + else: + for table in tables: + if not table or len(table.split("."))!=2: + raise ValueError(f"PostgreSQL table names should be of the following format . which is violated by '{table}`") + print("Retrieved the following tables:") + for t in tables: print(f"\t{t}") print("Getting schema for each table in your database...") # get the schema for each table - for table_name in tables: + for table in tables: + + schema,table_name = table.split(".") cur.execute( - "SELECT CAST(column_name AS TEXT), CAST(data_type AS TEXT) FROM information_schema.columns WHERE table_name::text = %s;", - (table_name,), + "SELECT CAST(column_name AS TEXT), CAST(data_type AS TEXT) FROM information_schema.columns WHERE table_name::text = %s and table_schema::text = %s;", + (table_name, schema), ) rows = cur.fetchall() rows = [row for row in rows] rows = [{"column_name": i[0], "data_type": i[1]} for i in rows] - schemas[table_name] = rows + schemas[table] = rows # get foreign key relationships print("Getting foreign keys for each table in your database...")