Skip to content

Commit

Permalink
refactor and add tests for database restores
Browse files Browse the repository at this point in the history
  • Loading branch information
CarsonDavis committed Dec 10, 2024
1 parent 48c66b8 commit bbb0d4a
Show file tree
Hide file tree
Showing 4 changed files with 512 additions and 6 deletions.
35 changes: 29 additions & 6 deletions sde_collections/management/commands/database_backup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
Usage:
docker-compose -f local.yml run --rm django python manage.py database_backup
docker-compose -f local.yml run --rm django python manage.py database_backup --no-compress
docker-compose -f local.yml run --rm django python manage.py database_backup --output /path/to/output.sql
docker-compose -f production.yml run --rm django python manage.py database_backup
"""

Expand Down Expand Up @@ -54,19 +55,41 @@ def add_arguments(self, parser):
action="store_true",
help="Disable backup file compression (enabled by default)",
)
parser.add_argument(
"--output",
type=str,
help="Output file path (default: auto-generated based on server name and date)",
)

def get_backup_filename(self, server: Server, compress: bool) -> tuple[str, str]:
def get_backup_filename(self, server: Server, compress: bool, custom_output: str = None) -> tuple[str, str]:
"""Generate backup filename and actual dump path.
Args:
server: Server enum indicating the environment
compress: Whether the output should be compressed
custom_output: Optional custom output path
Returns:
tuple[str, str]: A tuple containing (final_filename, temp_filename)
- final_filename: The name of the final backup file (with .gz if compressed)
- temp_filename: The name of the temporary dump file (always without .gz)
"""
date_str = datetime.now().strftime("%Y%m%d")
temp_filename = f"{server.value.lower()}_backup_{date_str}.sql"
final_filename = f"{temp_filename}.gz" if compress else temp_filename
return final_filename, temp_filename
if custom_output:
# Ensure the output directory exists
output_dir = os.path.dirname(custom_output)
if output_dir:
os.makedirs(output_dir, exist_ok=True)

if compress:
return custom_output + (".gz" if not custom_output.endswith(".gz") else ""), custom_output.removesuffix(
".gz"
)
return custom_output, custom_output
else:
date_str = datetime.now().strftime("%Y%m%d")
temp_filename = f"{server.value.lower()}_backup_{date_str}.sql"
final_filename = f"{temp_filename}.gz" if compress else temp_filename
return final_filename, temp_filename

def run_pg_dump(self, output_file: str, env: dict) -> None:
"""Execute pg_dump with given parameters."""
Expand Down Expand Up @@ -95,7 +118,7 @@ def compress_file(self, input_file: str, output_file: str) -> None:
def handle(self, *args, **options):
server = detect_server()
compress = not options["no_compress"]
backup_file, dump_file = self.get_backup_filename(server, compress)
backup_file, dump_file = self.get_backup_filename(server, compress, options.get("output"))

env = os.environ.copy()
env["PGPASSWORD"] = settings.DATABASES["default"]["PASSWORD"]
Expand Down
24 changes: 24 additions & 0 deletions sde_collections/management/commands/database_restore.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from django.conf import settings
from django.core.management.base import BaseCommand, CommandError
from django.db import connections


class Server(enum.Enum):
Expand Down Expand Up @@ -65,9 +66,32 @@ def run_psql_command(self, command: str, db_name: str = "postgres", env: dict =
cmd = ["psql", "-h", db["host"], "-U", db["user"], "-d", db_name, "-c", command]
subprocess.run(cmd, env=env, check=True)

def terminate_database_connections(self, env: dict) -> None:
"""Terminate all connections to the database."""
db = self.get_db_settings()
# Close Django's connection first
connections.close_all()

# Terminate any remaining PostgreSQL connections
terminate_conn_sql = f"""
SELECT pg_terminate_backend(pid)
FROM pg_stat_activity
WHERE datname = '{db["name"]}'
AND pid <> pg_backend_pid();
"""
try:
self.run_psql_command(terminate_conn_sql, env=env)
except subprocess.CalledProcessError:
# If this fails, it's usually because there are no connections to terminate
pass

def reset_database(self, env: dict) -> None:
"""Drop and recreate the database."""
db = self.get_db_settings()

self.stdout.write(f"Terminating connections to {db['name']}...")
self.terminate_database_connections(env)

self.stdout.write(f"Dropping database {db['name']}...")
self.run_psql_command(f"DROP DATABASE IF EXISTS {db['name']}", env=env)

Expand Down
190 changes: 190 additions & 0 deletions sde_collections/tests/test_database_backup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
# docker-compose -f local.yml run --rm django pytest sde_collections/tests/test_database_backup.py
import gzip
import os
import subprocess
from datetime import datetime
from unittest.mock import Mock, patch

import pytest
from django.core.management import call_command

from sde_collections.management.commands import database_backup
from sde_collections.management.commands.database_backup import (
Server,
temp_file_handler,
)


@pytest.fixture
def mock_subprocess():
with patch("subprocess.run") as mock_run:
mock_run.return_value.returncode = 0
yield mock_run


@pytest.fixture
def mock_date():
with patch("sde_collections.management.commands.database_backup.datetime") as mock_dt:
mock_dt.now.return_value = datetime(2024, 1, 15)
yield mock_dt


@pytest.fixture
def mock_settings(settings):
"""Configure test database settings."""
settings.DATABASES = {
"default": {
"HOST": "test-db-host",
"NAME": "test_db",
"USER": "test_user",
"PASSWORD": "test_password",
}
}
return settings


@pytest.fixture
def command():
return database_backup.Command()


class TestBackupCommand:
def test_get_backup_filename_compressed(self, command, mock_date):
"""Test backup filename generation with compression."""
backup_file, dump_file = command.get_backup_filename(Server.STAGING, compress=True)
assert backup_file == "staging_backup_20240115.sql.gz"
assert dump_file == "staging_backup_20240115.sql"

def test_get_backup_filename_uncompressed(self, command, mock_date):
"""Test backup filename generation without compression."""
backup_file, dump_file = command.get_backup_filename(Server.PRODUCTION, compress=False)
assert backup_file == "production_backup_20240115.sql"
assert dump_file == backup_file

def test_run_pg_dump(self, command, mock_subprocess, mock_settings):
"""Test pg_dump command execution."""
env = {"PGPASSWORD": "test_password"}
command.run_pg_dump("test_output.sql", env)

mock_subprocess.assert_called_once()
cmd_args = mock_subprocess.call_args[0][0]
assert cmd_args == [
"pg_dump",
"-h",
"test-db-host",
"-U",
"test_user",
"-d",
"test_db",
"--no-owner",
"--no-privileges",
"-f",
"test_output.sql",
]

def test_compress_file(self, command, tmp_path):
"""Test file compression."""
input_file = tmp_path / "test.sql"
output_file = tmp_path / "test.sql.gz"
test_content = b"Test database content"

# Create test input file
input_file.write_bytes(test_content)

# Compress the file
command.compress_file(str(input_file), str(output_file))

# Verify compression
assert output_file.exists()
with gzip.open(output_file, "rb") as f:
assert f.read() == test_content

def test_temp_file_handler_cleanup(self, tmp_path):
"""Test temporary file cleanup."""
test_file = tmp_path / "temp.sql"
test_file.touch()

with temp_file_handler(str(test_file)):
assert test_file.exists()
assert not test_file.exists()

def test_temp_file_handler_cleanup_on_error(self, tmp_path):
"""Test temporary file cleanup when an error occurs."""
test_file = tmp_path / "temp.sql"
test_file.touch()

with pytest.raises(ValueError):
with temp_file_handler(str(test_file)):
assert test_file.exists()
raise ValueError("Test error")
assert not test_file.exists()

@patch("socket.gethostname")
def test_server_detection(self, mock_hostname):
"""Test server environment detection."""
test_cases = [
("PRODUCTION-SERVER", Server.PRODUCTION),
("STAGING-DB", Server.STAGING),
("DEV-HOST", Server.UNKNOWN),
]

for hostname, expected_server in test_cases:
mock_hostname.return_value = hostname
with patch("sde_collections.management.commands.database_backup.detect_server") as mock_detect:
mock_detect.return_value = expected_server
server = database_backup.detect_server()
assert server == expected_server

@pytest.mark.parametrize(
"compress,hostname",
[
(True, "PRODUCTION-SERVER"),
(False, "STAGING-SERVER"),
(True, "UNKNOWN-SERVER"),
],
)
def test_handle_integration(self, compress, hostname, mock_subprocess, mock_date, mock_settings):
"""Test full backup process integration."""
with patch("socket.gethostname", return_value=hostname):
call_command("database_backup", no_compress=not compress)

# Verify correct command execution
mock_subprocess.assert_called_once()

# Verify correct filename used
cmd_args = mock_subprocess.call_args[0][0]
date_str = "20240115"
server_type = hostname.split("-")[0].lower()
expected_base = f"{server_type}_backup_{date_str}.sql"

if compress:
assert cmd_args[-1] == expected_base # Temporary file
# Verify cleanup attempted
assert not os.path.exists(expected_base)
else:
assert cmd_args[-1] == expected_base

def test_handle_pg_dump_error(self, mock_subprocess, mock_date):
"""Test error handling when pg_dump fails."""
mock_subprocess.side_effect = subprocess.CalledProcessError(1, "pg_dump")

with patch("socket.gethostname", return_value="STAGING-SERVER"):
call_command("database_backup")

# Verify error handling and cleanup
date_str = "20240115"
temp_file = f"staging_backup_{date_str}.sql"
assert not os.path.exists(temp_file)

def test_handle_compression_error(self, mock_subprocess, mock_date, command):
"""Test error handling during compression."""
# Mock compression to fail
command.compress_file = Mock(side_effect=Exception("Compression failed"))

with patch("socket.gethostname", return_value="STAGING-SERVER"):
call_command("database_backup")

# Verify cleanup
date_str = "20240115"
temp_file = f"staging_backup_{date_str}.sql"
assert not os.path.exists(temp_file)
Loading

0 comments on commit bbb0d4a

Please sign in to comment.