Skip to content

Commit

Permalink
Fix syncdb naming and dangerouslyforcs and make new users/me to have …
Browse files Browse the repository at this point in the history
…relations
  • Loading branch information
aloftus23 committed Nov 21, 2024
1 parent 202a181 commit 11f685d
Show file tree
Hide file tree
Showing 6 changed files with 124 additions and 45 deletions.
32 changes: 31 additions & 1 deletion backend/src/xfd_django/xfd_api/api_methods/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,41 @@

# Third-Party Libraries
from fastapi import HTTPException, Query
from django.db.models import Prefetch
from django.forms.models import model_to_dict

from ..models import User
from ..models import User, Role
from ..schema_models.user import User as UserSchema


def get_me(current_user):
"""Get current user."""
# Fetch the user and related objects from the database
user = User.objects.prefetch_related(
Prefetch('roles', queryset=Role.objects.select_related('organization')),
Prefetch('apiKeys')
).get(id=str(current_user.id))

# Convert the user object to a dictionary
user_dict = model_to_dict(user, exclude=['password'])

# Include roles with their related organization
user_dict['roles'] = [
{
"id": role.id,
"role": role.role,
"approved": role.approved,
"organization": model_to_dict(role.organization) if role.organization else None
}
for role in user.roles.all()
]

# Include API keys
user_dict['apiKeys'] = list(user.apiKeys.values('id', 'createdAt', 'updatedAt', 'lastUsed', 'hashedKey', 'lastFour'))

return user_dict


def get_users(regionId):
"""
Retrieve a list of users based on optional filter parameters.
Expand Down
38 changes: 5 additions & 33 deletions backend/src/xfd_django/xfd_api/management/commands/syncdb.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,9 @@
# Standard Python Libraries
import json
import os
import random

# Third-Party Libraries
from django.conf import settings
from django.core.management import call_command
from django.core.management.base import BaseCommand
from xfd_api.tasks.es_client import ESClient
from xfd_api.tasks.syndb_helpers import manage_elasticsearch_indices, populate_sample_data

# Sample data and helper data for random generation
SAMPLE_TAG_NAME = "Sample Data"
NUM_SAMPLE_ORGS = 10
NUM_SAMPLE_DOMAINS = 10
PROB_SAMPLE_SERVICES = 0.5
PROB_SAMPLE_VULNERABILITIES = 0.5
SAMPLE_STATES = ["VA", "CA", "CO"]
SAMPLE_REGION_IDS = ["1", "2", "3"]

SAMPLE_DATA_DIR = os.path.join(settings.BASE_DIR, "xfd_api", "tasks", "sample_data")
services = json.load(open(os.path.join(SAMPLE_DATA_DIR, "services.json")))
cpes = json.load(open(os.path.join(SAMPLE_DATA_DIR, "cpes.json")))
vulnerabilities = json.load(open(os.path.join(SAMPLE_DATA_DIR, "vulnerabilities.json")))
cves = json.load(open(os.path.join(SAMPLE_DATA_DIR, "cves.json")))
nouns = json.load(open(os.path.join(SAMPLE_DATA_DIR, "nouns.json")))
adjectives = json.load(open(os.path.join(SAMPLE_DATA_DIR, "adjectives.json")))

# Initialize Elasticsearch client
es_client = ESClient()

from backend.src.xfd_django.xfd_api.tasks.syncdb_helpers import manage_elasticsearch_indices, populate_sample_data
from xfd_api.tasks.run_syncdb import drop_all_tables, synchronize

class Command(BaseCommand):
help = "Synchronizes and populates the database with optional sample data, and manages Elasticsearch indices."
Expand All @@ -55,13 +29,11 @@ def handle(self, *args, **options):
# Step 1: Database Reset and Migration
if dangerouslyforce:
self.stdout.write("Dropping and recreating the database...")
call_command("flush", "--noinput")
call_command("makemigrations")
call_command("migrate")
drop_all_tables()
synchronize()
else:
self.stdout.write("Applying migrations...")
call_command("makemigrations")
call_command("migrate")
synchronize()

# Step 2: Elasticsearch Index Management
manage_elasticsearch_indices(dangerouslyforce)
Expand Down
31 changes: 28 additions & 3 deletions backend/src/xfd_django/xfd_api/tasks/run_syncdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# Initialize Django
django.setup()

from xfd_api.tasks.syndb_helpers import manage_elasticsearch_indices, populate_sample_data
from backend.src.xfd_django.xfd_api.tasks.syncdb_helpers import manage_elasticsearch_indices, populate_sample_data


def handler(event, context):
Expand All @@ -29,7 +29,7 @@ def handler(event, context):
# Drop and recreate the database if dangerouslyforce is true
if dangerouslyforce:
print("Dropping and recreating the database...")
call_command("flush", "--noinput")
drop_all_tables()

# Generate and apply migrations dynamically
print("Applying migrations dynamically...")
Expand Down Expand Up @@ -154,7 +154,10 @@ def process_m2m_tables(schema_editor: BaseDatabaseSchemaEditor, cursor):

if not table_exists:
print(f"Creating Many-to-Many table: {m2m_table_name}")
schema_editor.create_model(model)
schema_editor.create_model(field.remote_field.through)
else:
print(f"Many-to-Many table {m2m_table_name} already exists. Skipping.")



def update_table(schema_editor: BaseDatabaseSchemaEditor, model):
Expand Down Expand Up @@ -209,3 +212,25 @@ def cleanup_stale_tables(cursor):
cursor.execute(f"DROP TABLE {table} CASCADE;")
except Exception as e:
print(f"Error dropping stale table {table}: {e}")

def drop_all_tables():
"""
Drops all tables in the database. Used with `dangerouslyforce`.
"""
with connection.cursor() as cursor:
cursor.execute(
"""
DO $$ DECLARE
r RECORD;
BEGIN
FOR r IN (
SELECT tablename
FROM pg_tables
WHERE schemaname = 'public'
) LOOP
EXECUTE 'DROP TABLE IF EXISTS ' || quote_ident(r.tablename) || ' CASCADE';
END LOOP;
END $$;
"""
)
print("All tables dropped successfully.")
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,21 @@
import json
import os
import random
import hashlib
import secrets
from django.conf import settings
from django.db import transaction
from xfd_api.models import Domain, Organization, OrganizationTag, Service, Vulnerability
from xfd_api.models import ApiKey, Domain, Organization, OrganizationTag, Service, Vulnerability, UserType, User
from xfd_api.tasks.es_client import ESClient
from datetime import datetime, timezone

# Constants for sample data generation
SAMPLE_TAG_NAME = "Sample Data"
NUM_SAMPLE_ORGS = 10
NUM_SAMPLE_DOMAINS = 10
PROB_SAMPLE_SERVICES = 0.5
PROB_SAMPLE_VULNERABILITIES = 0.5
SAMPLE_STATES = ["VA", "CA", "CO"]
SAMPLE_STATES = ["Virginia", "California", "Colorado"]
SAMPLE_REGION_IDS = ["1", "2", "3"]

# Load sample data files
Expand Down Expand Up @@ -46,6 +49,7 @@ def populate_sample_data():
with transaction.atomic():
tag, _ = OrganizationTag.objects.get_or_create(name=SAMPLE_TAG_NAME)
for _ in range(NUM_SAMPLE_ORGS):
# Create organization
org = Organization.objects.create(
acronym="".join(random.choices("ABCDEFGHIJKLMNOPQRSTUVWXYZ", k=5)),
name=generate_random_name(),
Expand All @@ -57,9 +61,55 @@ def populate_sample_data():
)
org.tags.add(tag)

# Create sample domains, services, and vulnerabilities
for _ in range(NUM_SAMPLE_DOMAINS):
domain = create_sample_domain(org)
create_sample_services_and_vulnerabilities(domain)

# Create a user for the organization
user = create_sample_user(org)

# Create an API key for the user
create_api_key_for_user(user)



def create_sample_user(organization):
"""Create a sample user linked to an organization."""
user = User.objects.create(
firstName="Sample",
lastName="User",
email=f"user{random.randint(1, 1000)}@example.com",
userType=UserType.GLOBAL_ADMIN,
state=random.choice(SAMPLE_STATES),
regionId=random.choice(SAMPLE_REGION_IDS),
)
# Set user as the creator of the organization (optional)
organization.createdBy = user
organization.save()
return user


def create_api_key_for_user(user):
"""Create a sample API key linked to a user."""

# Generate a random 16-byte API key
key = secrets.token_hex(16)

# Hash the API key
hashed_key = hashlib.sha256(key.encode()).hexdigest()

# Create the API key record
ApiKey.objects.create(
hashedKey=hashed_key,
lastFour=key[-4:],
userId=user,
createdAt=datetime.utcnow(),
updatedAt=datetime.utcnow(),
)

# Print the raw key for debugging or manual testing
print(f"Created API key for user {user.email}: {key}")


def generate_random_name():
Expand Down
4 changes: 2 additions & 2 deletions backend/src/xfd_django/xfd_api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
update_saved_search,
)
from .api_methods.search import export, search_post
from .api_methods.user import get_users
from .api_methods.user import get_users, get_me
from .api_methods.vulnerability import (
get_vulnerability_by_id,
search_vulnerabilities,
Expand Down Expand Up @@ -302,7 +302,7 @@ async def callback_route(request: Request):
# GET Current User
@api_router.get("/users/me", tags=["users"])
async def read_users_me(current_user: User = Depends(get_current_active_user)):
return current_user
return get_me(current_user)


@api_router.get(
Expand Down
10 changes: 6 additions & 4 deletions backend/src/xfd_django/xfd_django/middleware/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ async def dispatch(self, request: Request, call_next):
# Default to "undefined" for userEmail if not provided
user_email = request.state.user_email if hasattr(request.state, "user_email") else "undefined"

# Proceed with the request
response = await call_next(request)

# Prepare log details
log_info = {
"httpMethod": method,
Expand All @@ -44,12 +47,11 @@ async def dispatch(self, request: Request, call_next):
"path": path,
"headers": headers,
"userEmail": user_email,
"statusCode": response.status_code,
"timestamp": datetime.utcnow().isoformat(),
"requestId": request_id
}

# Log the request
# Log in JSON format
self.logger.info(log_info)

# Proceed with the request
response = await call_next(request)
return response

0 comments on commit 11f685d

Please sign in to comment.