Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix schema generation #21

Merged
merged 2 commits into from
Dec 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/houseplant/houseplant.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,9 @@ def update_schema(self):
materialized_views = self.db.get_database_materialized_views()
dictionaries = self.db.get_database_dictionaries()

# Track processed tables to ensure first migration takes precedence
processed_tables = set()

# Group statements by type
table_statements = []
mv_statements = []
Expand All @@ -282,13 +285,18 @@ def update_schema(self):
if not table_name:
continue

# Skip if we've already processed this table
if table_name in processed_tables:
continue

# Check tables first
for table in tables:
if table[0] == table_name:
create_stmt = self.db.client.execute(
f"SHOW CREATE TABLE {table_name}"
)[0][0]
table_statements.append(create_stmt)
processed_tables.add(table_name)

# Then materialized views
for mv in materialized_views:
Expand All @@ -298,6 +306,7 @@ def update_schema(self):
0
][0]
mv_statements.append(create_stmt)
processed_tables.add(table_name)

# Finally dictionaries
for dict in dictionaries:
Expand All @@ -307,6 +316,7 @@ def update_schema(self):
f"SHOW CREATE DICTIONARY {dict_name}"
)[0][0]
dict_statements.append(create_stmt)
processed_tables.add(table_name)

# Write schema file
with open("ch/schema.sql", "w") as f:
Expand Down
101 changes: 101 additions & 0 deletions tests/test_houseplant.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,65 @@ def test_migration_with_view(tmp_path):
return migration_content


@pytest.fixture
def duplicate_migrations(tmp_path):
# Set up test environment
migrations_dir = tmp_path / "ch/migrations"
migrations_dir.mkdir(parents=True)

# Create two migrations that modify the same table
migration1 = migrations_dir / "20240101000000_first_migration.yml"
migration2 = migrations_dir / "20240102000000_second_migration.yml"

migration1_content = """version: "{version}"
name: {name}
table: events

development: &development
up: |
CREATE TABLE events (
id UInt32,
name String
) ENGINE = MergeTree()
ORDER BY id
down: |
DROP TABLE events

test:
<<: *development

production:
<<: *development
"""

migration2_content = """version: "{version}"
name: {name}
table: events

development: &development
up: |
ALTER TABLE events ADD COLUMN description String
down: |
ALTER TABLE events DROP COLUMN description

test:
<<: *development

production:
<<: *development
"""

migration1.write_text(
migration1_content.format(version="20240101000000", name="first_migration")
)
migration2.write_text(
migration2_content.format(version="20240102000000", name="second_migration")
)

os.chdir(tmp_path)
return ("20240101000000", "20240102000000")


def test_migrate_up_development(houseplant, test_migration, mocker):
# Mock environment and database calls
houseplant.env = "development"
Expand Down Expand Up @@ -297,6 +356,48 @@ def test_migrate_up_with_production_view(houseplant, test_migration_with_view, m
mock_get_applied.assert_called_once()


def test_update_schema_no_duplicates(houseplant, duplicate_migrations, mocker):
versions = duplicate_migrations

# Mock database calls
mocker.patch.object(
houseplant.db,
"get_applied_migrations",
return_value=[(versions[0],), (versions[1],)],
)
mocker.patch.object(
houseplant.db, "get_database_tables", return_value=[("events",)]
)
mocker.patch.object(
houseplant.db, "get_database_materialized_views", return_value=[]
)
mocker.patch.object(houseplant.db, "get_database_dictionaries", return_value=[])

# Mock the SHOW CREATE TABLE call
mocker.patch.object(
houseplant.db.client,
"execute",
return_value=[
[
"CREATE TABLE events (id UInt32, name String) ENGINE = MergeTree() ORDER BY id"
]
],
)

# Update schema
houseplant.update_schema()

# Read the generated schema file
with open("ch/schema.sql", "r") as f:
schema_content = f.read()

# Verify the table appears only once in the schema
table_count = schema_content.count("CREATE TABLE events")
assert (
table_count == 1
), f"Table 'events' appears {table_count} times in schema, expected 1"


@pytest.mark.skip
def test_migrate_down(houseplant, test_migration, mocker):
# Mock database calls
Expand Down
Loading