diff --git a/sde_collections/management/commands/database_backup.py b/sde_collections/management/commands/database_backup.py
index edb93351..5f6551b3 100644
--- a/sde_collections/management/commands/database_backup.py
+++ b/sde_collections/management/commands/database_backup.py
@@ -4,6 +4,7 @@
 Usage:
     docker-compose -f local.yml run --rm django python manage.py database_backup
     docker-compose -f local.yml run --rm django python manage.py database_backup --no-compress
+    docker-compose -f local.yml run --rm django python manage.py database_backup --output /path/to/output.sql
     docker-compose -f production.yml run --rm django python manage.py database_backup
 """
 
@@ -54,19 +55,41 @@ def add_arguments(self, parser):
             action="store_true",
             help="Disable backup file compression (enabled by default)",
         )
+        parser.add_argument(
+            "--output",
+            type=str,
+            help="Output file path (default: auto-generated based on server name and date)",
+        )
 
-    def get_backup_filename(self, server: Server, compress: bool) -> tuple[str, str]:
+    def get_backup_filename(self, server: Server, compress: bool, custom_output: str | None = None) -> tuple[str, str]:
         """Generate backup filename and actual dump path.
 
+        Args:
+            server: Server enum indicating the environment
+            compress: Whether the output should be compressed
+            custom_output: Optional custom output path
+
         Returns:
             tuple[str, str]: A tuple containing (final_filename, temp_filename)
                 - final_filename: The name of the final backup file (with .gz if compressed)
                 - temp_filename: The name of the temporary dump file (always without .gz)
         """
-        date_str = datetime.now().strftime("%Y%m%d")
-        temp_filename = f"{server.value.lower()}_backup_{date_str}.sql"
-        final_filename = f"{temp_filename}.gz" if compress else temp_filename
-        return final_filename, temp_filename
+        if custom_output:
+            # Ensure the output directory exists
+            output_dir = os.path.dirname(custom_output)
+            if output_dir:
+                os.makedirs(output_dir, exist_ok=True)
+
+            if compress:
+                return custom_output + (".gz" if not custom_output.endswith(".gz") else ""), custom_output.removesuffix(
+                    ".gz"
+                )
+            return custom_output, custom_output
+        else:
+            date_str = datetime.now().strftime("%Y%m%d")
+            temp_filename = f"{server.value.lower()}_backup_{date_str}.sql"
+            final_filename = f"{temp_filename}.gz" if compress else temp_filename
+            return final_filename, temp_filename
 
     def run_pg_dump(self, output_file: str, env: dict) -> None:
         """Execute pg_dump with given parameters."""
@@ -95,7 +118,7 @@ def compress_file(self, input_file: str, output_file: str) -> None:
     def handle(self, *args, **options):
         server = detect_server()
         compress = not options["no_compress"]
-        backup_file, dump_file = self.get_backup_filename(server, compress)
+        backup_file, dump_file = self.get_backup_filename(server, compress, options.get("output"))
 
         env = os.environ.copy()
         env["PGPASSWORD"] = settings.DATABASES["default"]["PASSWORD"]
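
The interplay between --output and compression is the subtlest part of get_backup_filename; as a quick illustration, the branch logic above resolves (final_filename, temp_filename) as follows (paths are illustrative):

    # Resolution of (final_filename, temp_filename) for a custom --output path:
    #   compress=True,  --output backup.sql     -> ("backup.sql.gz", "backup.sql")
    #   compress=True,  --output backup.sql.gz  -> ("backup.sql.gz", "backup.sql")
    #   compress=False, --output backup.sql     -> ("backup.sql", "backup.sql")
    # Without --output, names fall back to "<server>_backup_<YYYYMMDD>.sql[.gz]".
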
diff --git a/sde_collections/management/commands/database_restore.py b/sde_collections/management/commands/database_restore.py
index 2779cf51..ece94cce 100644
--- a/sde_collections/management/commands/database_restore.py
+++ b/sde_collections/management/commands/database_restore.py
@@ -16,6 +16,7 @@
 
 from django.conf import settings
 from django.core.management.base import BaseCommand, CommandError
+from django.db import connections
 
 
 class Server(enum.Enum):
@@ -65,9 +66,32 @@ def run_psql_command(self, command: str, db_name: str = "postgres", env: dict =
         cmd = ["psql", "-h", db["host"], "-U", db["user"], "-d", db_name, "-c", command]
         subprocess.run(cmd, env=env, check=True)
 
+    def terminate_database_connections(self, env: dict) -> None:
+        """Terminate all connections to the database."""
+        db = self.get_db_settings()
+        # Close Django's connection first
+        connections.close_all()
+
+        # Terminate any remaining PostgreSQL connections
+        terminate_conn_sql = f"""
+            SELECT pg_terminate_backend(pid)
+            FROM pg_stat_activity
+            WHERE datname = '{db["name"]}'
+            AND pid <> pg_backend_pid();
+        """
+        try:
+            self.run_psql_command(terminate_conn_sql, env=env)
+        except subprocess.CalledProcessError:
+            # If this fails, it's usually because there are no connections to terminate
+            pass
+
     def reset_database(self, env: dict) -> None:
         """Drop and recreate the database."""
         db = self.get_db_settings()
+
+        self.stdout.write(f"Terminating connections to {db['name']}...")
+        self.terminate_database_connections(env)
+
         self.stdout.write(f"Dropping database {db['name']}...")
         self.run_psql_command(f"DROP DATABASE IF EXISTS {db['name']}", env=env)
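
The restore flow can also be driven programmatically, as the integration tests below do; a minimal sketch, assuming a configured Django environment ("backups/staging.sql.gz" is a hypothetical path):

    from django.core.management import call_command
    from django.db import connections

    # Mirror what terminate_database_connections() does for in-process
    # connections, then hand the (optionally gzipped) dump to the command.
    connections.close_all()
    call_command("database_restore", "backups/staging.sql.gz")
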
compression.""" + input_file = tmp_path / "test.sql" + output_file = tmp_path / "test.sql.gz" + test_content = b"Test database content" + + # Create test input file + input_file.write_bytes(test_content) + + # Compress the file + command.compress_file(str(input_file), str(output_file)) + + # Verify compression + assert output_file.exists() + with gzip.open(output_file, "rb") as f: + assert f.read() == test_content + + def test_temp_file_handler_cleanup(self, tmp_path): + """Test temporary file cleanup.""" + test_file = tmp_path / "temp.sql" + test_file.touch() + + with temp_file_handler(str(test_file)): + assert test_file.exists() + assert not test_file.exists() + + def test_temp_file_handler_cleanup_on_error(self, tmp_path): + """Test temporary file cleanup when an error occurs.""" + test_file = tmp_path / "temp.sql" + test_file.touch() + + with pytest.raises(ValueError): + with temp_file_handler(str(test_file)): + assert test_file.exists() + raise ValueError("Test error") + assert not test_file.exists() + + @patch("socket.gethostname") + def test_server_detection(self, mock_hostname): + """Test server environment detection.""" + test_cases = [ + ("PRODUCTION-SERVER", Server.PRODUCTION), + ("STAGING-DB", Server.STAGING), + ("DEV-HOST", Server.UNKNOWN), + ] + + for hostname, expected_server in test_cases: + mock_hostname.return_value = hostname + with patch("sde_collections.management.commands.database_backup.detect_server") as mock_detect: + mock_detect.return_value = expected_server + server = database_backup.detect_server() + assert server == expected_server + + @pytest.mark.parametrize( + "compress,hostname", + [ + (True, "PRODUCTION-SERVER"), + (False, "STAGING-SERVER"), + (True, "UNKNOWN-SERVER"), + ], + ) + def test_handle_integration(self, compress, hostname, mock_subprocess, mock_date, mock_settings): + """Test full backup process integration.""" + with patch("socket.gethostname", return_value=hostname): + call_command("database_backup", no_compress=not compress) + + # Verify correct command execution + mock_subprocess.assert_called_once() + + # Verify correct filename used + cmd_args = mock_subprocess.call_args[0][0] + date_str = "20240115" + server_type = hostname.split("-")[0].lower() + expected_base = f"{server_type}_backup_{date_str}.sql" + + if compress: + assert cmd_args[-1] == expected_base # Temporary file + # Verify cleanup attempted + assert not os.path.exists(expected_base) + else: + assert cmd_args[-1] == expected_base + + def test_handle_pg_dump_error(self, mock_subprocess, mock_date): + """Test error handling when pg_dump fails.""" + mock_subprocess.side_effect = subprocess.CalledProcessError(1, "pg_dump") + + with patch("socket.gethostname", return_value="STAGING-SERVER"): + call_command("database_backup") + + # Verify error handling and cleanup + date_str = "20240115" + temp_file = f"staging_backup_{date_str}.sql" + assert not os.path.exists(temp_file) + + def test_handle_compression_error(self, mock_subprocess, mock_date, command): + """Test error handling during compression.""" + # Mock compression to fail + command.compress_file = Mock(side_effect=Exception("Compression failed")) + + with patch("socket.gethostname", return_value="STAGING-SERVER"): + call_command("database_backup") + + # Verify cleanup + date_str = "20240115" + temp_file = f"staging_backup_{date_str}.sql" + assert not os.path.exists(temp_file) diff --git a/sde_collections/tests/test_database_restore.py b/sde_collections/tests/test_database_restore.py new file mode 100644 index 00000000..21088ad0 
diff --git a/sde_collections/tests/test_database_restore.py b/sde_collections/tests/test_database_restore.py
new file mode 100644
index 00000000..21088ad0
--- /dev/null
+++ b/sde_collections/tests/test_database_restore.py
@@ -0,0 +1,269 @@
+# docker-compose -f local.yml run --rm django pytest sde_collections/tests/test_database_restore.py
+import gzip
+from unittest.mock import patch
+
+import pytest
+from django.core.management import call_command
+from django.core.management.base import CommandError
+from django.db import connections
+
+from sde_collections.management.commands import database_restore
+from sde_collections.models.collection import Collection
+from sde_collections.models.delta_url import CuratedUrl, DeltaUrl, DumpUrl
+from sde_collections.tests.factories import (
+    CollectionFactory,
+    CuratedUrlFactory,
+    DeltaUrlFactory,
+    DumpUrlFactory,
+)
+
+# Alias: applying @pytest.mark.integration applies django_db(transaction=True)
+pytest.mark.integration = pytest.mark.django_db(transaction=True)
+
+
+@pytest.fixture
+def mock_subprocess():
+    with patch("subprocess.run") as mock_run:
+        mock_run.return_value.returncode = 0
+        yield mock_run
+
+
+@pytest.fixture
+def mock_settings(settings):
+    """Configure test database settings."""
+    settings.DATABASES = {
+        "default": {
+            "HOST": "test-db-host",
+            "NAME": "test_db",
+            "USER": "test_user",
+            "PASSWORD": "test_password",
+        }
+    }
+    return settings
+
+
+@pytest.fixture
+def command():
+    return database_restore.Command()
+
+
+@pytest.fixture
+def backup_file(tmp_path):
+    """Create a temporary backup file."""
+    backup_path = tmp_path / "test_backup.sql"
+    backup_path.write_text("-- Test backup content")
+    return str(backup_path)
+
+
+@pytest.fixture
+def compressed_backup_file(tmp_path):
+    """Create a temporary compressed backup file."""
+    backup_path = tmp_path / "test_backup.sql.gz"
+    with gzip.open(backup_path, "wt") as f:
+        f.write("-- Test backup content")
+    return str(backup_path)
+
+
+class TestRestoreCommand:
+    def test_get_db_settings(self, command, mock_settings):
+        """Test database settings retrieval."""
+        settings = command.get_db_settings()
+        assert settings == {
+            "host": "test-db-host",
+            "name": "test_db",
+            "user": "test_user",
+            "password": "test_password",
+        }
+
+    def test_run_psql_command(self, command, mock_subprocess, mock_settings):
+        """Test psql command execution."""
+        env = {"PGPASSWORD": "test_password"}
+        command.run_psql_command("SELECT 1;", "test_db", env)
+
+        mock_subprocess.assert_called_once()
+        cmd_args = mock_subprocess.call_args[0][0]
+        assert cmd_args == [
+            "psql",
+            "-h",
+            "test-db-host",
+            "-U",
+            "test_user",
+            "-d",
+            "test_db",
+            "-c",
+            "SELECT 1;",
+        ]
+
+    def test_reset_database(self, command, mock_subprocess, mock_settings):
+        """Test database reset process."""
+        env = {"PGPASSWORD": "test_password"}
+        command.reset_database(env)
+
+        # Verify terminate-connections, drop and create commands were executed
+        assert mock_subprocess.call_count >= 2
+        calls = mock_subprocess.call_args_list
+        assert any("DROP DATABASE" in call[0][0][-1] for call in calls)
+        assert any("CREATE DATABASE" in call[0][0][-1] for call in calls)
+
+    def test_restore_backup(self, command, mock_subprocess, mock_settings, backup_file):
+        """Test backup restoration."""
+        env = {"PGPASSWORD": "test_password"}
+        command.restore_backup(backup_file, env)
+
+        mock_subprocess.assert_called_once()
+        cmd_args = mock_subprocess.call_args[0][0]
+        assert cmd_args == [
+            "psql",
+            "-h",
+            "test-db-host",
+            "-U",
+            "test_user",
+            "-d",
+            "test_db",
+            "-f",
+            backup_file,
+        ]
+
+    def test_decompress_file(self, command, tmp_path, compressed_backup_file):
+        """Test backup file decompression."""
+        output_file = str(tmp_path / "decompressed.sql")
+        command.decompress_file(compressed_backup_file, output_file)
+
+        with open(output_file) as f:
+            content = f.read()
+        assert content == "-- Test backup content"
+
+    def test_handle_file_not_found(self, command):
+        """Test error handling for non-existent backup file."""
+        with pytest.raises(CommandError):
+            call_command("database_restore", "nonexistent.sql")
"decompressed.sql") + command.decompress_file(compressed_backup_file, output_file) + + with open(output_file) as f: + content = f.read() + assert content == "-- Test backup content" + + def test_handle_file_not_found(self, command): + """Test error handling for non-existent backup file.""" + with pytest.raises(CommandError): + call_command("database_restore", "nonexistent.sql") + + +@pytest.mark.django_db +class TestDatabaseIntegration: + """Integration tests for backup and restore functionality.""" + + def create_test_data(self): + """Create a set of test data using factories.""" + collection = CollectionFactory() + + # Create some URLs + dump_urls = DumpUrlFactory.create_batch(3, collection=collection) + curated_urls = CuratedUrlFactory.create_batch(3, collection=collection) + delta_urls = DeltaUrlFactory.create_batch(3, collection=collection) + + return { + "collection": collection, + "dump_urls": dump_urls, + "curated_urls": curated_urls, + "delta_urls": delta_urls, + } + + def verify_data_integrity(self, original_data): + """Verify that all data matches the original after restore.""" + # Close all existing database connections before verification + connections.close_all() + + # Verify collection + restored_collection = Collection.objects.get(pk=original_data["collection"].pk) + assert restored_collection.name == original_data["collection"].name + assert restored_collection.config_folder == original_data["collection"].config_folder + + # Verify URLs + for original_url in original_data["dump_urls"]: + restored_url = DumpUrl.objects.get(pk=original_url.pk) + assert restored_url.url == original_url.url + assert restored_url.scraped_title == original_url.scraped_title + + for original_url in original_data["curated_urls"]: + restored_url = CuratedUrl.objects.get(pk=original_url.pk) + assert restored_url.url == original_url.url + assert restored_url.scraped_title == original_url.scraped_title + + for original_url in original_data["delta_urls"]: + restored_url = DeltaUrl.objects.get(pk=original_url.pk) + assert restored_url.url == original_url.url + assert restored_url.scraped_title == original_url.scraped_title + + @pytest.mark.integration + def test_full_backup_restore_cycle(self, tmp_path): + """Test complete backup and restore cycle with actual data.""" + # Create test data + original_data = self.create_test_data() + + # Create backup + backup_file = str(tmp_path / "integration_test_backup.sql") + with patch("socket.gethostname", return_value="TEST-SERVER"): + connections.close_all() # Close connections before backup + call_command("database_backup", "--no-compress", output=backup_file) + + # Clear the database + for Model in [Collection, DumpUrl, CuratedUrl, DeltaUrl]: + Model.objects.all().delete() + + assert Collection.objects.count() == 0 + assert DumpUrl.objects.count() == 0 + assert CuratedUrl.objects.count() == 0 + assert DeltaUrl.objects.count() == 0 + + # Restore from backup + connections.close_all() # Close connections before restore + call_command("database_restore", backup_file) + + # Verify data integrity + self.verify_data_integrity(original_data) + + @pytest.mark.integration + def test_compressed_backup_restore_cycle(self, tmp_path): + """Test backup and restore cycle with compression.""" + # Create test data + original_data = self.create_test_data() + + # Create compressed backup + backup_file = str(tmp_path / "integration_test_backup.sql.gz") + with patch("socket.gethostname", return_value="TEST-SERVER"): + connections.close_all() # Close connections before backup + 
call_command("database_backup", output=backup_file) # Compression is enabled by default + + # Clear the database + connections.close_all() # Close connections before clearing + Collection.objects.all().delete() + + # Restore from compressed backup + connections.close_all() # Close connections before restore + call_command("database_restore", backup_file) + + # Verify data integrity + self.verify_data_integrity(original_data) + + @pytest.mark.integration + def test_partial_data_integrity(self, tmp_path): + """Test backup and restore with partial data modifications.""" + # Create initial data + original_data = self.create_test_data() + original_name = original_data["collection"].name + original_url_id = original_data["curated_urls"][0].id # Store the ID explicitly + + # Create backup + backup_file = str(tmp_path / "partial_test_backup.sql") + with patch("socket.gethostname", return_value="TEST-SERVER"): + connections.close_all() # Close connections before backup + call_command("database_backup", "--no-compress", output=backup_file) + + # Modify some data + collection = original_data["collection"] + collection.name = "Modified Name" + collection.save() + + new_curated_url = CuratedUrlFactory(collection=collection) + original_data["curated_urls"][0].delete() + + # Restore from backup + connections.close_all() # Close connections before restore + call_command("database_restore", backup_file) + + # Verify original state is restored + restored_collection = Collection.objects.get(pk=collection.pk) + assert restored_collection.name == original_name + assert not CuratedUrl.objects.filter(pk=new_curated_url.pk).exists() + assert CuratedUrl.objects.filter(pk=original_url_id).exists() # Use the stored ID