Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PATCH] Circular Inheritance Support #8

Merged
merged 6 commits into from
Nov 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions data.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
fake = faker.Faker()
number_of_fields = 1
excluded_tables = ["system_setting"]
tables_to_fill = ["user"]
tables_to_fill = []
graph = True

fields = [
Expand Down Expand Up @@ -81,7 +81,18 @@
"table": None,
"generator": lambda: fake.word().capitalize(),
},
{"name": None, "type": "date", "table": None, "generator": lambda: fake.date()},
{
"name": None,
"type": "float",
"table": None,
"generator": lambda: fake.random_element(elements=(1.0, 10.0)),
},
{
"name": None,
"type": "date",
"table": None,
"generator": lambda: fake.date(),
},
{
"name": None,
"type": "datetime",
Expand Down
27 changes: 25 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,31 @@
black==23.7.0
click==8.1.7
colorama==0.4.6
contourpy==1.1.0
cycler==0.11.0
Faker==18.9.0
fonttools==4.42.0
greenlet==2.0.2
keyboard==0.13.5
kiwisolver==1.4.4
markdown-it-py==3.0.0
matplotlib==3.7.2
mdurl==0.1.2
mypy-extensions==1.0.0
mysql-connector-python==8.1.0
networkx==3.1
numpy==1.25.2
packaging==23.1
pathspec==0.11.2
Pillow==10.0.0
platformdirs==3.10.0
protobuf==4.21.12
Pygments==2.16.1
pyparsing==3.0.9
python-dateutil==2.8.2
python-decouple==3.8
rich==13.5.2
six==1.16.0
SQLAlchemy==2.0.20
mysql-connector-python==8.1.0
SQLAlchemy-Utils==0.41.1
SQLAlchemy-Utils==0.41.1
typing_extensions==4.7.1
4 changes: 4 additions & 0 deletions src/enums.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
import enum

class Nothing(enum.Enum):
    """Marker values for operations that produce no result."""

    # Human-readable text carried by the "nothing was returned" marker.
    Nada = "This operation returned nothing"
120 changes: 83 additions & 37 deletions src/populate.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import contextlib
import random
import re
from collections import OrderedDict
import time
from collections import OrderedDict

import keyboard
import matplotlib.pyplot as plt
import networkx as nx
import sqlalchemy
Expand All @@ -16,7 +17,10 @@
from rich.table import Table
from sqlalchemy import create_engine, inspect
from sqlalchemy_utils import has_unique_index
from rich import print

from .enums import Nothing

Nada = Nothing.Nada.value


class DatabasePopulator:
Expand Down Expand Up @@ -49,7 +53,6 @@ def __init__(
graph: bool = True,
special_fields: list[dict] = None,
) -> None:

db_url = f"mysql+mysqlconnector://{user}:{password}@{host}/{database}"

self.completed_tables_list = []
Expand All @@ -60,7 +63,7 @@ def __init__(
self.engine = create_engine(db_url, echo=False)
self.rows = rows
inspector = inspect(self.engine)

# If no tables are specified, fill all tables in the database
# Otherwise, fill the specified tables
tables_to_fill = tables_to_fill or inspector.get_table_names()
Expand Down Expand Up @@ -105,13 +108,15 @@ def show_end_banner(self):
banner = f.readlines()

print()
[ print(
Align(
bane.strip(),
align="center",
[
print(
Align(
bane.strip(),
align="center",
)
)

) for bane in banner ]
for bane in banner
]

success = random.choice(
[
Expand Down Expand Up @@ -312,28 +317,50 @@ def draw_graph(self):
plt.title("Database Inheritance Relationships")
plt.axis("off")
plt.show()

def remove_cycles(self, graph):
    """
    The function `remove_cycles` breaks every cycle in a directed graph by
    deleting one edge from each cycle that networkx finds.

    The graph is modified in place and the same object is returned once
    `nx.find_cycle` reports that no cycle is left.
    """
    # Loop instead of recursing: one edge is removed per pass, so a graph
    # with many cycles would otherwise add one stack frame per removed edge.
    while True:
        try:
            # Find a cycle in the graph
            cycle = nx.find_cycle(graph, orientation="original")
        except nx.NetworkXNoCycle:
            # No cycle found, return the graph as is
            return graph

        # The first two items of the cycle's first entry identify the
        # edge to drop (source node, target node).
        graph.remove_edge(*cycle[0][:2])

def arrange_graph(self):
"""
The function arranges identified inheritance relations in a directed graph and orders them
topologically.
topologically. It involves the user in resolving circular dependencies.
"""
graph = nx.DiGraph()

# Populate the graph
for table, inherited_tables in self.inheritance_relations.items():
if inherited_tables:
for inherited_table in inherited_tables:
if table != inherited_table: # Skip self-references
graph.add_edge(inherited_table, table)
adjacency = dict(graph.adjacency())
if table in adjacency:
adjacency = adjacency[table]
if inherited_table not in adjacency:
graph.add_edge(inherited_table, table)
else:
graph.add_edge(inherited_table, table)

else:
graph.add_node(table)

self.job_progress.advance(self.identifying_relations)

graph = self.remove_cycles(graph)
ordered_tables = list(nx.topological_sort(graph))


# Order the tables based on topological sort
ordered_inheritance_relations = OrderedDict()

for table in ordered_tables:
if table in self.inheritance_relations:
ordered_inheritance_relations[table] = self.inheritance_relations[table]
Expand All @@ -345,7 +372,7 @@ def arrange_graph(self):

def populate_fields(self, column, table):
"""
The function `populate_fields` populates a
The function `populate_fields` populates a
column with a value based on the column's name, type, and
table name.
"""
Expand Down Expand Up @@ -410,8 +437,14 @@ def handle_column_population(self, table, column):
value = self.populate_fields(column, table)
count -= 1
if count <= 0:
if column.nullable:
return None
raise ValueError(
f"I can't find a unique value to insert into column '{column.name}' in table '{table.name}'"
(
f"I can't find a unique value "
f"to insert into column '{column.name}' in "
f"table '{table.name}'"
)
)

return value
Expand All @@ -429,53 +462,61 @@ def get_unique_column_values(self, column, unique_columns, table):

conn = self.engine.connect()
s = sqlalchemy.select(table.c[column.name])

# Cache the column's unique values
self.cached_unique_column_values[column] = {
row[0] for row in conn.execute(s).fetchall()
}
conn.close()

return self.cached_unique_column_values[column]
return set()

def get_value(self, column, foreign_columns, unique_columns, table):
    """
    The function `get_value` returns a value for a column in a table.

    Order of evaluation:
      1. A nullable column has a 1 in 300 chance of getting `None`.
      2. `process_foreign` is asked for a value; anything other than the
         `Nada` marker is returned as-is (this includes `None`).
      3. `handle_column_population` is asked for a value; anything other
         than the `Nada` marker is returned as-is.

    As a side effect, the result of `get_unique_column_values` for this
    column is stored on `self.existing_values` before step 2.

    Raises:
        NotImplementedError: if both helpers return the `Nada` marker.
    """
    # Check if the column is nullable with a 1 in 300 chance of returning None
    if column.nullable and random.random() < 1 / 300:
        return None

    # Fetch the values already present for this column so the helpers
    # below can avoid them when the column is unique.
    self.existing_values = self.get_unique_column_values(
        column=column, unique_columns=unique_columns, table=table
    )

    # Ask `process_foreign` first: if the column is a foreign key it
    # returns a value from the related table, otherwise the `Nada` marker.
    value = self.process_foreign(
        column=column,
        foreign_columns=foreign_columns,
        table=table,
    )
    if value is not Nada:
        return value

    # If the column is not a foreign key, populate it with a value based
    # on the definition from the `data.py` file.
    value = self.handle_column_population(table=table, column=column)
    if value is not Nada:
        return value

    # Adjacent literals are concatenated into ONE message argument; the
    # previous version had commas between them, which split the message
    # into two separate exception arguments.
    raise NotImplementedError(
        "I have no idea what value to assign "
        f"to the field '{column.name}' of type "
        f"{column.type} in '{table}'. "
        "Maybe updating my `data.py` will help?"
    )

def get_related_table_fields(self, column, foreign_columns):
"""
The function `get_related_table_fields` returns a set of values from a related table
"""
# desc is a tuple containing the
# desc is a tuple containing the
# (name of the column, the name of the related table)
desc = foreign_columns[column.name]
# If the related table fields have already been cached, return them
Expand Down Expand Up @@ -504,17 +545,22 @@ def process_foreign(self, foreign_columns, table, column):
it returns a value from the related table.
"""
if column.name not in foreign_columns:
return None
return Nada
# Gets the related table fields from the `get_related_table_fields` function
related_table_fields = self.get_related_table_fields(column, foreign_columns)

# self.existing_values only gets populated if the column only accepts to unique values
if selectable_fields := related_table_fields - self.existing_values:
return random.choice(list(selectable_fields))
else:
raise ValueError(
f"Can't find a unique value to insert into column '{column.name}' in table '{table.name}'"
elif column.nullable:
return None
raise ValueError(
(
f"I can't find a unique value "
f"to insert into column '{column.name}' in "
f"table '{table.name}'"
)
)

def get_unique_columns(self, table):
    """
    The function `get_unique_columns` returns a list with the name of every
    column in `table.columns` for which `has_unique_index` is truthy.
    """
    unique_names = []
    for candidate in table.columns:
        if has_unique_index(candidate):
            unique_names.append(candidate.name)
    return unique_names
Expand Down Expand Up @@ -577,10 +623,10 @@ def fill_table(self, inspector):

# Update the table panel with the current table being filled's name
self.handle_table_panel(self.inheritance_relations_list)

# Call the `handle_database_insertion` function to fill the current table
self.handle_database_insertion(table_name, inspector)

# Logic for how to display the table after it has been filled
self.inheritance_relations_list.remove(f"[yellow]{table_name}")
self.completed_tables_list.append(f"[green]{table_name}")
Expand Down