Skip to content

Commit

Permalink
Fix incorrect attempt to overwrite object in update_or_create due to …
Browse files Browse the repository at this point in the history
…uppercase URNs (#1081)
  • Loading branch information
eric-intuitem authored Nov 23, 2024
2 parents dad1bb5 + 0aaaf26 commit 16db8bc
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 13 deletions.
42 changes: 42 additions & 0 deletions backend/core/migrations/0039_make_urn_lowercase.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Generated by Django 5.1.1 on 2024-11-23 07:58

# Explain this bug: django.db.transaction.TransactionManagementError: An error occurred in the current transaction. You can't execute queries until the end of the 'atomic' block.

from django.db import migrations
from django.db.models.functions import Lower


def make_urn_lowercase(apps, schema_editor):
Threat = apps.get_model("core", "Threat")
ReferenceControl = apps.get_model("core", "ReferenceControl")
RiskMatrix = apps.get_model("core", "RiskMatrix")
Framework = apps.get_model("core", "Framework")
RequirementNode = apps.get_model("core", "RequirementNode")
RequirementMappingSet = apps.get_model("core", "RequirementMappingSet")
StoredLibrary = apps.get_model("core", "StoredLibrary")
LoadedLibrary = apps.get_model("core", "LoadedLibrary")

models = [
Threat,
ReferenceControl,
RiskMatrix,
Framework,
RequirementMappingSet,
StoredLibrary,
LoadedLibrary,
]
for model in models:
model.objects.filter(urn__isnull=False).update(urn=Lower("urn"))

RequirementNode.objects.filter(urn__isnull=False).update(urn=Lower("urn"))
RequirementNode.objects.filter(parent_urn__isnull=False).update(
parent_urn=Lower("parent_urn")
)


class Migration(migrations.Migration):
dependencies = [
("core", "0038_asset_disaster_recovery_objectives_and_more"),
]

operations = [migrations.RunPython(make_urn_lowercase)]
32 changes: 19 additions & 13 deletions backend/library/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ def preview_library(framework: dict) -> dict[str, list]:
if framework.get("requirement_nodes"):
index = 0
for requirement_node in framework["requirement_nodes"]:
parent_urn = requirement_node.get("parent_urn")
if parent_urn:
parent_urn = parent_urn.lower()
index += 1
requirement_nodes_list.append(
RequirementNode(
Expand All @@ -49,8 +52,8 @@ def preview_library(framework: dict) -> dict[str, list]:
),
ref_id=requirement_node.get("ref_id"),
name=get_referential_translation(requirement_node, "name"),
urn=requirement_node["urn"],
parent_urn=requirement_node.get("parent_urn"),
urn=requirement_node["urn"].lower(),
parent_urn=parent_urn,
order_id=index,
)
)
Expand All @@ -70,12 +73,15 @@ def is_valid(self) -> Union[str, None]:
return "Missing the following fields : {}".format(", ".join(missing_fields))

def import_requirement_node(self, framework_object: Framework):
parent_urn = self.requirement_data.get("parent_urn")
if parent_urn:
parent_urn = parent_urn.lower()
requirement_node = RequirementNode.objects.create(
# Should i just inherit the folder from Framework or this is useless ?
folder=Folder.get_root_folder(),
framework=framework_object,
urn=self.requirement_data["urn"],
parent_urn=self.requirement_data.get("parent_urn"),
urn=self.requirement_data["urn"].lower(),
parent_urn=parent_urn,
assessable=self.requirement_data.get("assessable"),
ref_id=self.requirement_data.get("ref_id"),
annotation=self.requirement_data.get("annotation"),
Expand Down Expand Up @@ -126,15 +132,15 @@ def load(
):
try:
target_requirement = RequirementNode.objects.get(
urn=self.data["target_requirement_urn"], default_locale=True
urn=self.data["target_requirement_urn"].lower(), default_locale=True
)
except RequirementNode.DoesNotExist:
err_msg = f"ERROR: target requirement with URN {self.data['target_requirement_urn']} does not exist"
print(err_msg)
raise Http404(err_msg)
try:
source_requirement = RequirementNode.objects.get(
urn=self.data["source_requirement_urn"], default_locale=True
urn=self.data["source_requirement_urn"].lower(), default_locale=True
)
except RequirementNode.DoesNotExist:
err_msg = f"ERROR: source requirement with URN {self.data['source_requirement_urn']} does not exist"
Expand Down Expand Up @@ -179,14 +185,14 @@ def load(
):
self.init_requirement_mappings(self.data["requirement_mappings"])
_target_framework = Framework.objects.get(
urn=self.data["target_framework_urn"], default_locale=True
urn=self.data["target_framework_urn"].lower(), default_locale=True
)
_source_framework = Framework.objects.get(
urn=self.data["source_framework_urn"], default_locale=True
urn=self.data["source_framework_urn"].lower(), default_locale=True
)
mapping_set = RequirementMappingSet.objects.create(
name=self.data["name"],
urn=self.data["urn"],
urn=self.data["urn"].lower(),
target_framework=_target_framework,
source_framework=_source_framework,
library=library_object,
Expand Down Expand Up @@ -284,7 +290,7 @@ def import_framework(self, library_object: LoadedLibrary):
framework_object = Framework.objects.create(
folder=Folder.get_root_folder(),
library=library_object,
urn=self.framework_data["urn"],
urn=self.framework_data["urn"].lower(),
ref_id=self.framework_data["ref_id"],
name=self.framework_data.get("name"),
description=self.framework_data.get("description"),
Expand Down Expand Up @@ -318,7 +324,7 @@ def is_valid(self) -> Union[str, None]:
def import_threat(self, library_object: LoadedLibrary):
Threat.objects.create(
library=library_object,
urn=self.threat_data.get("urn"),
urn=self.threat_data["urn"].lower(),
ref_id=self.threat_data["ref_id"],
name=self.threat_data.get("name"),
description=self.threat_data.get("description"),
Expand Down Expand Up @@ -364,7 +370,7 @@ def is_valid(self) -> Union[str, None]:
def import_reference_control(self, library_object: LoadedLibrary):
ReferenceControl.objects.create(
library=library_object,
urn=self.reference_control_data.get("urn"),
urn=self.reference_control_data["urn"].lower(),
ref_id=self.reference_control_data["ref_id"],
name=self.reference_control_data.get("name"),
description=self.reference_control_data.get("description"),
Expand Down Expand Up @@ -410,7 +416,7 @@ def import_risk_matrix(self, library_object: LoadedLibrary):
folder=Folder.get_root_folder(),
name=self.risk_matrix_data.get("name"),
description=self.risk_matrix_data.get("description"),
urn=self.risk_matrix_data.get("urn"),
urn=self.risk_matrix_data["urn"].lower(),
provider=library_object.provider,
ref_id=self.risk_matrix_data.get("ref_id"),
json_definition=json.dumps(matrix_data),
Expand Down

0 comments on commit 16db8bc

Please sign in to comment.