Skip to content

Commit

Permalink
Importers: migrate to options class (#10254)
Browse files Browse the repository at this point in the history
  • Loading branch information
Maffooch authored Jun 3, 2024
1 parent 20176f2 commit fc32c13
Show file tree
Hide file tree
Showing 18 changed files with 1,379 additions and 1,057 deletions.
40 changes: 29 additions & 11 deletions dojo/api_v2/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2141,7 +2141,7 @@ def set_context(
"""
context = dict(data)
# update some vars
context["scan"] = data.get("file", None)
context["scan"] = data.pop("file", None)
context["environment"] = Development_Environment.objects.get(
name=data.get("environment", "Development")
)
Expand Down Expand Up @@ -2201,12 +2201,15 @@ def process_auto_create_create_context(
# Raise an explicit drf exception here
raise ValidationError(str(e))

def get_importer(self) -> BaseImporter:
def get_importer(
self,
**kwargs: dict,
) -> BaseImporter:
"""
Returns a new instance of an importer that extends
the BaseImporter class
"""
return DefaultImporter()
return DefaultImporter(**kwargs)

def process_scan(
self,
Expand All @@ -2220,8 +2223,9 @@ def process_scan(
Raises exceptions in the event of an error
"""
try:
context["test"], _, _, _, _, _, _ = self.get_importer().process_scan(
**context,
importer = self.get_importer(**context)
context["test"], _, _, _, _, _, _ = importer.process_scan(
context.pop("scan", None)
)
# Update the response body with some new data
if test := context.get("test"):
Expand Down Expand Up @@ -2472,19 +2476,25 @@ def process_auto_create_create_context(
# Raise an explicit drf exception here
raise ValidationError(str(e))

def get_importer(self) -> BaseImporter:
def get_importer(
self,
**kwargs: dict,
) -> BaseImporter:
"""
Returns a new instance of an importer that extends
the BaseImporter class
"""
return DefaultImporter()
return DefaultImporter(**kwargs)

def get_reimporter(self) -> BaseImporter:
def get_reimporter(
self,
**kwargs: dict,
) -> BaseImporter:
"""
Returns a new instance of a reimporter that extends
the BaseImporter class
"""
return DefaultReImporter()
return DefaultReImporter(**kwargs)

def process_scan(
self,
Expand All @@ -2502,14 +2512,22 @@ def process_scan(
try:
if test := context.get("test"):
statistics_before = test.statistics
context["test"], _, _, _, _, _, test_import = self.get_reimporter().process_scan(**context)
context["test"], _, _, _, _, _, test_import = self.get_reimporter(
**context
).process_scan(
context.pop("scan", None)
)
if test_import:
statistics_delta = test_import.statistics
elif context.get("auto_create_context"):
# Attempt to create an engagement
logger.debug("reimport for non-existing test, using import to create new test")
context["engagement"] = auto_create_manager.get_or_create_engagement(**context)
context["test"], _, _, _, _, _, _ = self.get_importer().process_scan(**context)
context["test"], _, _, _, _, _, _ = self.get_importer(
**context
).process_scan(
context.pop("scan", None)
)
else:
msg = "A test could not be found!"
raise NotFound(msg)
Expand Down
28 changes: 16 additions & 12 deletions dojo/endpoint/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
import re

from django.contrib import messages
from django.core.exceptions import MultipleObjectsReturned, ValidationError
from django.core.exceptions import ValidationError
from django.core.validators import validate_ipv46_address
from django.db import transaction
from django.db.models import Count, Q
from django.http import HttpResponseRedirect
from django.urls import reverse
Expand Down Expand Up @@ -73,17 +74,20 @@ def endpoint_filter(**kwargs):


def endpoint_get_or_create(**kwargs):

qs = endpoint_filter(**kwargs)

if qs.count() == 0:
return Endpoint.objects.get_or_create(**kwargs)

elif qs.count() == 1:
return qs.first(), False

else:
raise MultipleObjectsReturned()
with transaction.atomic():
qs = endpoint_filter(**kwargs)
count = qs.count()
if count == 0:
return Endpoint.objects.get_or_create(**kwargs)
else:
logger.warning(
f"Endpoints in your database are broken. "
f"Please access {reverse('endpoint_migrate')} and migrate them to new format or remove them."
)
# Get the oldest endpoint first, and return that instead
# a datetime is not captured on the endpoint model, so ID
# will have to work here instead
return qs.order_by("id").first(), False


def clean_hosts_run(apps, change):
Expand Down
7 changes: 2 additions & 5 deletions dojo/engagement/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -927,17 +927,14 @@ def import_findings(
Attempt to import with all the supplied information
"""
try:
importer_client = DefaultImporter()
importer_client = DefaultImporter(**context)
context["test"], _, finding_count, closed_finding_count, _, _, _ = importer_client.process_scan(
**context,
context.pop("scan", None)
)
# Add a message to the view for the user to see the results
add_success_message_to_response(importer_client.construct_imported_message(
context.get("scan_type"),
Test_Import.IMPORT_TYPE,
finding_count=finding_count,
closed_finding_count=closed_finding_count,
close_old_findings=context.get("close_old_findings"),
))
except Exception as e:
logger.exception(e)
Expand Down
56 changes: 30 additions & 26 deletions dojo/importers/auto_create_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Any

from crum import get_current_user
from django.db import transaction
from django.http.request import QueryDict
from django.utils import timezone

Expand Down Expand Up @@ -229,14 +230,15 @@ def get_or_create_product_type(
if product_type := self.get_target_product_type_if_exists(product_type_name=product_type_name):
return product_type
else:
product_type, created = Product_Type.objects.get_or_create(name=product_type_name)
if created:
Product_Type_Member.objects.create(
user=get_current_user(),
product_type=product_type,
role=Role.objects.get(is_owner=True),
)
return product_type
with transaction.atomic():
product_type, created = Product_Type.objects.select_for_update().get_or_create(name=product_type_name)
if created:
Product_Type_Member.objects.create(
user=get_current_user(),
product_type=product_type,
role=Role.objects.get(is_owner=True),
)
return product_type

def get_or_create_product(
self,
Expand All @@ -260,13 +262,14 @@ def get_or_create_product(
# Look for a product type first
product_type = self.get_or_create_product_type(product_type_name=product_type_name)
# Create the product
product, created = Product.objects.get_or_create(name=product_name, prod_type=product_type, description=product_name)
if created:
Product_Member.objects.create(
user=get_current_user(),
product=product,
role=Role.objects.get(is_owner=True),
)
with transaction.atomic():
product, created = Product.objects.select_for_update().get_or_create(name=product_name, prod_type=product_type, description=product_name)
if created:
Product_Member.objects.create(
user=get_current_user(),
product=product,
role=Role.objects.get(is_owner=True),
)

return product

Expand Down Expand Up @@ -313,17 +316,18 @@ def get_or_create_engagement(
if (target_end is None) or (target_start > target_end):
target_end = (timezone.now() + timedelta(days=365)).date()
# Create the engagement
return Engagement.objects.create(
engagement_type="CI/CD",
name=engagement_name,
product=product,
lead=get_current_user(),
target_start=target_start,
target_end=target_end,
status="In Progress",
deduplication_on_engagement=deduplication_on_engagement,
source_code_management_uri=source_code_management_uri,
)
with transaction.atomic():
return Engagement.objects.select_for_update().create(
engagement_type="CI/CD",
name=engagement_name,
product=product,
lead=get_current_user(),
target_start=target_start,
target_end=target_end,
status="In Progress",
deduplication_on_engagement=deduplication_on_engagement,
source_code_management_uri=source_code_management_uri,
)

"""
===================================
Expand Down
Loading

0 comments on commit fc32c13

Please sign in to comment.