From 11f685d6f90d7db566d13fa9cb33797e7dbf447d Mon Sep 17 00:00:00 2001 From: aloftus23 Date: Thu, 21 Nov 2024 09:45:20 -0500 Subject: [PATCH] Fix syncdb naming and dangerouslyforcs and make new users/me to have relations --- .../xfd_django/xfd_api/api_methods/user.py | 32 ++++++++++- .../xfd_api/management/commands/syncdb.py | 38 ++----------- .../xfd_django/xfd_api/tasks/run_syncdb.py | 31 +++++++++-- .../{syndb_helpers.py => syncdb_helpers.py} | 54 ++++++++++++++++++- backend/src/xfd_django/xfd_api/views.py | 4 +- .../xfd_django/middleware/middleware.py | 10 ++-- 6 files changed, 124 insertions(+), 45 deletions(-) rename backend/src/xfd_django/xfd_api/tasks/{syndb_helpers.py => syncdb_helpers.py} (71%) diff --git a/backend/src/xfd_django/xfd_api/api_methods/user.py b/backend/src/xfd_django/xfd_api/api_methods/user.py index fdb0f07a..de03b2c3 100644 --- a/backend/src/xfd_django/xfd_api/api_methods/user.py +++ b/backend/src/xfd_django/xfd_api/api_methods/user.py @@ -7,11 +7,41 @@ # Third-Party Libraries from fastapi import HTTPException, Query +from django.db.models import Prefetch +from django.forms.models import model_to_dict -from ..models import User +from ..models import User, Role from ..schema_models.user import User as UserSchema +def get_me(current_user): + """Get current user.""" + # Fetch the user and related objects from the database + user = User.objects.prefetch_related( + Prefetch('roles', queryset=Role.objects.select_related('organization')), + Prefetch('apiKeys') + ).get(id=str(current_user.id)) + + # Convert the user object to a dictionary + user_dict = model_to_dict(user, exclude=['password']) + + # Include roles with their related organization + user_dict['roles'] = [ + { + "id": role.id, + "role": role.role, + "approved": role.approved, + "organization": model_to_dict(role.organization) if role.organization else None + } + for role in user.roles.all() + ] + + # Include API keys + user_dict['apiKeys'] = list(user.apiKeys.values('id', 'createdAt', 'updatedAt', 'lastUsed', 'hashedKey', 'lastFour')) + + return user_dict + + def get_users(regionId): """ Retrieve a list of users based on optional filter parameters. diff --git a/backend/src/xfd_django/xfd_api/management/commands/syncdb.py b/backend/src/xfd_django/xfd_api/management/commands/syncdb.py index 5249bf01..fea2861a 100644 --- a/backend/src/xfd_django/xfd_api/management/commands/syncdb.py +++ b/backend/src/xfd_django/xfd_api/management/commands/syncdb.py @@ -1,35 +1,9 @@ -# Standard Python Libraries -import json -import os -import random # Third-Party Libraries -from django.conf import settings from django.core.management import call_command from django.core.management.base import BaseCommand -from xfd_api.tasks.es_client import ESClient -from xfd_api.tasks.syndb_helpers import manage_elasticsearch_indices, populate_sample_data - -# Sample data and helper data for random generation -SAMPLE_TAG_NAME = "Sample Data" -NUM_SAMPLE_ORGS = 10 -NUM_SAMPLE_DOMAINS = 10 -PROB_SAMPLE_SERVICES = 0.5 -PROB_SAMPLE_VULNERABILITIES = 0.5 -SAMPLE_STATES = ["VA", "CA", "CO"] -SAMPLE_REGION_IDS = ["1", "2", "3"] - -SAMPLE_DATA_DIR = os.path.join(settings.BASE_DIR, "xfd_api", "tasks", "sample_data") -services = json.load(open(os.path.join(SAMPLE_DATA_DIR, "services.json"))) -cpes = json.load(open(os.path.join(SAMPLE_DATA_DIR, "cpes.json"))) -vulnerabilities = json.load(open(os.path.join(SAMPLE_DATA_DIR, "vulnerabilities.json"))) -cves = json.load(open(os.path.join(SAMPLE_DATA_DIR, "cves.json"))) -nouns = json.load(open(os.path.join(SAMPLE_DATA_DIR, "nouns.json"))) -adjectives = json.load(open(os.path.join(SAMPLE_DATA_DIR, "adjectives.json"))) - -# Initialize Elasticsearch client -es_client = ESClient() - +from backend.src.xfd_django.xfd_api.tasks.syncdb_helpers import manage_elasticsearch_indices, populate_sample_data +from xfd_api.tasks.run_syncdb import drop_all_tables, synchronize class Command(BaseCommand): help = "Synchronizes and populates the database with optional sample data, and manages Elasticsearch indices." @@ -55,13 +29,11 @@ def handle(self, *args, **options): # Step 1: Database Reset and Migration if dangerouslyforce: self.stdout.write("Dropping and recreating the database...") - call_command("flush", "--noinput") - call_command("makemigrations") - call_command("migrate") + drop_all_tables() + synchronize() else: self.stdout.write("Applying migrations...") - call_command("makemigrations") - call_command("migrate") + synchronize() # Step 2: Elasticsearch Index Management manage_elasticsearch_indices(dangerouslyforce) diff --git a/backend/src/xfd_django/xfd_api/tasks/run_syncdb.py b/backend/src/xfd_django/xfd_api/tasks/run_syncdb.py index ecc89fc1..72f6ccd5 100644 --- a/backend/src/xfd_django/xfd_api/tasks/run_syncdb.py +++ b/backend/src/xfd_django/xfd_api/tasks/run_syncdb.py @@ -15,7 +15,7 @@ # Initialize Django django.setup() -from xfd_api.tasks.syndb_helpers import manage_elasticsearch_indices, populate_sample_data +from backend.src.xfd_django.xfd_api.tasks.syncdb_helpers import manage_elasticsearch_indices, populate_sample_data def handler(event, context): @@ -29,7 +29,7 @@ def handler(event, context): # Drop and recreate the database if dangerouslyforce is true if dangerouslyforce: print("Dropping and recreating the database...") - call_command("flush", "--noinput") + drop_all_tables() # Generate and apply migrations dynamically print("Applying migrations dynamically...") @@ -154,7 +154,10 @@ def process_m2m_tables(schema_editor: BaseDatabaseSchemaEditor, cursor): if not table_exists: print(f"Creating Many-to-Many table: {m2m_table_name}") - schema_editor.create_model(model) + schema_editor.create_model(field.remote_field.through) + else: + print(f"Many-to-Many table {m2m_table_name} already exists. Skipping.") + def update_table(schema_editor: BaseDatabaseSchemaEditor, model): @@ -209,3 +212,25 @@ def cleanup_stale_tables(cursor): cursor.execute(f"DROP TABLE {table} CASCADE;") except Exception as e: print(f"Error dropping stale table {table}: {e}") + +def drop_all_tables(): + """ + Drops all tables in the database. Used with `dangerouslyforce`. + """ + with connection.cursor() as cursor: + cursor.execute( + """ + DO $$ DECLARE + r RECORD; + BEGIN + FOR r IN ( + SELECT tablename + FROM pg_tables + WHERE schemaname = 'public' + ) LOOP + EXECUTE 'DROP TABLE IF EXISTS ' || quote_ident(r.tablename) || ' CASCADE'; + END LOOP; + END $$; + """ + ) + print("All tables dropped successfully.") \ No newline at end of file diff --git a/backend/src/xfd_django/xfd_api/tasks/syndb_helpers.py b/backend/src/xfd_django/xfd_api/tasks/syncdb_helpers.py similarity index 71% rename from backend/src/xfd_django/xfd_api/tasks/syndb_helpers.py rename to backend/src/xfd_django/xfd_api/tasks/syncdb_helpers.py index c93d8346..8142771d 100644 --- a/backend/src/xfd_django/xfd_api/tasks/syndb_helpers.py +++ b/backend/src/xfd_django/xfd_api/tasks/syncdb_helpers.py @@ -2,10 +2,13 @@ import json import os import random +import hashlib +import secrets from django.conf import settings from django.db import transaction -from xfd_api.models import Domain, Organization, OrganizationTag, Service, Vulnerability +from xfd_api.models import ApiKey, Domain, Organization, OrganizationTag, Service, Vulnerability, UserType, User from xfd_api.tasks.es_client import ESClient +from datetime import datetime, timezone # Constants for sample data generation SAMPLE_TAG_NAME = "Sample Data" @@ -13,7 +16,7 @@ NUM_SAMPLE_DOMAINS = 10 PROB_SAMPLE_SERVICES = 0.5 PROB_SAMPLE_VULNERABILITIES = 0.5 -SAMPLE_STATES = ["VA", "CA", "CO"] +SAMPLE_STATES = ["Virginia", "California", "Colorado"] SAMPLE_REGION_IDS = ["1", "2", "3"] # Load sample data files @@ -46,6 +49,7 @@ def populate_sample_data(): with transaction.atomic(): tag, _ = OrganizationTag.objects.get_or_create(name=SAMPLE_TAG_NAME) for _ in range(NUM_SAMPLE_ORGS): + # Create organization org = Organization.objects.create( acronym="".join(random.choices("ABCDEFGHIJKLMNOPQRSTUVWXYZ", k=5)), name=generate_random_name(), @@ -57,9 +61,55 @@ def populate_sample_data(): ) org.tags.add(tag) + # Create sample domains, services, and vulnerabilities for _ in range(NUM_SAMPLE_DOMAINS): domain = create_sample_domain(org) create_sample_services_and_vulnerabilities(domain) + + # Create a user for the organization + user = create_sample_user(org) + + # Create an API key for the user + create_api_key_for_user(user) + + + +def create_sample_user(organization): + """Create a sample user linked to an organization.""" + user = User.objects.create( + firstName="Sample", + lastName="User", + email=f"user{random.randint(1, 1000)}@example.com", + userType=UserType.GLOBAL_ADMIN, + state=random.choice(SAMPLE_STATES), + regionId=random.choice(SAMPLE_REGION_IDS), + ) + # Set user as the creator of the organization (optional) + organization.createdBy = user + organization.save() + return user + + +def create_api_key_for_user(user): + """Create a sample API key linked to a user.""" + + # Generate a random 16-byte API key + key = secrets.token_hex(16) + + # Hash the API key + hashed_key = hashlib.sha256(key.encode()).hexdigest() + + # Create the API key record + ApiKey.objects.create( + hashedKey=hashed_key, + lastFour=key[-4:], + userId=user, + createdAt=datetime.utcnow(), + updatedAt=datetime.utcnow(), + ) + + # Print the raw key for debugging or manual testing + print(f"Created API key for user {user.email}: {key}") def generate_random_name(): diff --git a/backend/src/xfd_django/xfd_api/views.py b/backend/src/xfd_django/xfd_api/views.py index a245643c..f8e029e1 100644 --- a/backend/src/xfd_django/xfd_api/views.py +++ b/backend/src/xfd_django/xfd_api/views.py @@ -25,7 +25,7 @@ update_saved_search, ) from .api_methods.search import export, search_post -from .api_methods.user import get_users +from .api_methods.user import get_users, get_me from .api_methods.vulnerability import ( get_vulnerability_by_id, search_vulnerabilities, @@ -302,7 +302,7 @@ async def callback_route(request: Request): # GET Current User @api_router.get("/users/me", tags=["users"]) async def read_users_me(current_user: User = Depends(get_current_active_user)): - return current_user + return get_me(current_user) @api_router.get( diff --git a/backend/src/xfd_django/xfd_django/middleware/middleware.py b/backend/src/xfd_django/xfd_django/middleware/middleware.py index 96e959bf..80c0a0a4 100644 --- a/backend/src/xfd_django/xfd_django/middleware/middleware.py +++ b/backend/src/xfd_django/xfd_django/middleware/middleware.py @@ -36,6 +36,9 @@ async def dispatch(self, request: Request, call_next): # Default to "undefined" for userEmail if not provided user_email = request.state.user_email if hasattr(request.state, "user_email") else "undefined" + # Proceed with the request + response = await call_next(request) + # Prepare log details log_info = { "httpMethod": method, @@ -44,12 +47,11 @@ async def dispatch(self, request: Request, call_next): "path": path, "headers": headers, "userEmail": user_email, + "statusCode": response.status_code, "timestamp": datetime.utcnow().isoformat(), + "requestId": request_id } - # Log the request + # Log in JSON format self.logger.info(log_info) - - # Proceed with the request - response = await call_next(request) return response