Skip to content

Commit

Permalink
put common context code in base
Browse files Browse the repository at this point in the history
  • Loading branch information
hblankenship committed Oct 15, 2024
1 parent 20eee47 commit b4af07c
Showing 1 changed file with 23 additions and 68 deletions.
91 changes: 23 additions & 68 deletions dojo/api_v2/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2228,24 +2228,7 @@ def validate_scan_date(self, value: str) -> None:
raise serializers.ValidationError(msg)
return value


class ImportScanSerializer(CommonImportScanSerializer):

engagement = serializers.PrimaryKeyRelatedField(
queryset=Engagement.objects.all(), required=False,
)

# extra fields populated in response
# need to use the _id suffix as without the serializer framework gets
# confused
test = serializers.IntegerField(
read_only=True,
) # left for backwards compatibility

def set_context(
self,
data: dict,
) -> dict:
def setup_common_context(self, data: dict) -> dict:
"""
Process all of the user supplied inputs to massage them into the correct
format the importer is expecting to see
Expand Down Expand Up @@ -2294,6 +2277,27 @@ def set_context(
if context.get("scan_date")
else None
)
return context


class ImportScanSerializer(CommonImportScanSerializer):

engagement = serializers.PrimaryKeyRelatedField(
queryset=Engagement.objects.all(), required=False,
)

# extra fields populated in response
# need to use the _id suffix as without the serializer framework gets
# confused
test = serializers.IntegerField(
read_only=True,
) # left for backwards compatibility

def set_context(
self,
data: dict,
) -> dict:
context = self.setup_common_context(dict)
# Process the auto create context inputs
self.process_auto_create_create_context(context)

Expand Down Expand Up @@ -2345,57 +2349,8 @@ def set_context(
self,
data: dict,
) -> dict:
"""
Process all of the user supplied inputs to massage them into the correct
format the importer is expecting to see
"""
context = dict(data)
# update some vars
context["scan"] = data.get("file", None)

if context.get("auto_create_context"):
environment = Development_Environment.objects.get_or_create(name=data.get("environment", "Development"))[0]
else:
try:
environment = Development_Environment.objects.get(name=data.get("environment", "Development"))
except:
msg = "Environment named " + data.get("environment") + " does not exist."
raise ValidationError(msg)

context["environment"] = environment

# Set the active/verified status based upon the overrides
if "active" in self.initial_data:
context["active"] = data.get("active")
else:
context["active"] = None
if "verified" in self.initial_data:
context["verified"] = data.get("verified")
else:
context["verified"] = None
# Change the way that endpoints are sent to the importer
if endpoints_to_add := data.get("endpoint_to_add"):
context["endpoints_to_add"] = [endpoints_to_add]
else:
context["endpoint_to_add"] = None
# Convert the tags to a list if needed. At this point, the
# TaggitListSerializer has already removed commas supplied
# by the user, so this operation will consistently return
# a list to be used by the importer
if tags := context.get("tags"):
if isinstance(tags, str):
context["tags"] = tags.split(", ")
# have to make the scan_date_time timezone aware otherwise uploads via
# the API would fail (but unit tests for api upload would pass...)
context["scan_date"] = (
timezone.make_aware(
datetime.combine(context.get("scan_date"), datetime.min.time()),
)
if context.get("scan_date")
else None
)

return context
return self.setup_common_context(data)

def process_auto_create_create_context(
self,
Expand Down

0 comments on commit b4af07c

Please sign in to comment.