Skip to content

Commit

Permalink
Ruff: fix some SIM
Browse files Browse the repository at this point in the history
  • Loading branch information
kiblik committed Nov 18, 2024
1 parent 7080a25 commit 125200d
Show file tree
Hide file tree
Showing 150 changed files with 401 additions and 937 deletions.
21 changes: 6 additions & 15 deletions dojo/api_v2/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,10 +369,7 @@ def to_representation(self, value):
if not isinstance(value, RequestResponseDict):
if not isinstance(value, list):
# this will trigger when a queryset is found...
if self.order_by:
burps = value.all().order_by(*self.order_by)
else:
burps = value.all()
burps = value.all().order_by(*self.order_by) if self.order_by else value.all()
value = [
{
"request": burp.get_request(),
Expand Down Expand Up @@ -552,10 +549,7 @@ def update(self, instance, validated_data):
return instance

def create(self, validated_data):
if "password" in validated_data:
password = validated_data.pop("password")
else:
password = None
password = validated_data.pop("password", None)

new_configuration_permissions = None
if (
Expand All @@ -581,10 +575,7 @@ def create(self, validated_data):
return user

def validate(self, data):
if self.instance is not None:
instance_is_superuser = self.instance.is_superuser
else:
instance_is_superuser = False
instance_is_superuser = self.instance.is_superuser if self.instance is not None else False
data_is_superuser = data.get("is_superuser", False)
if not self.context["request"].user.is_superuser and (
instance_is_superuser or data_is_superuser
Expand Down Expand Up @@ -1217,7 +1208,7 @@ class Meta:

def validate(self, data):

if not self.context["request"].method == "PATCH":
if self.context["request"].method != "PATCH":
if "product" not in data:
msg = "Product is required"
raise serializers.ValidationError(msg)
Expand Down Expand Up @@ -2248,7 +2239,7 @@ def setup_common_context(self, data: dict) -> dict:
"""
context = dict(data)
# update some vars
context["scan"] = data.pop("file", None)
context["scan"] = data.pop("file")

if context.get("auto_create_context"):
environment = Development_Environment.objects.get_or_create(name=data.get("environment", "Development"))[0]
Expand Down Expand Up @@ -2293,7 +2284,7 @@ def setup_common_context(self, data: dict) -> dict:

# engagement end date was not being used at all and so target_end would also turn into None
# in this case, do not want to change target_end unless engagement_end exists
eng_end_date = context.get("engagement_end_date", None)
eng_end_date = context.get("engagement_end_date")
if eng_end_date:
context["target_end"] = context.get("engagement_end_date")

Expand Down
49 changes: 10 additions & 39 deletions dojo/api_v2/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -1481,9 +1481,7 @@ def metadata(self, request, pk=None):
return self._get_metadata(request, finding)
if request.method == "POST":
return self._add_metadata(request, finding)
if request.method == "PUT":
return self._edit_metadata(request, finding)
if request.method == "PATCH":
if request.method in ["PUT", "PATCH"]:
return self._edit_metadata(request, finding)
if request.method == "DELETE":
return self._remove_metadata(request, finding)
Expand Down Expand Up @@ -2892,24 +2890,15 @@ def report_generate(request, obj, options):
if eng.name:
engagement_name = eng.name
engagement_target_start = eng.target_start
if eng.target_end:
engagement_target_end = eng.target_end
else:
engagement_target_end = "ongoing"
engagement_target_end = eng.target_end or "ongoing"
if eng.test_set.all():
for t in eng.test_set.all():
test_type_name = t.test_type.name
if t.environment:
test_environment_name = t.environment.name
test_target_start = t.target_start
if t.target_end:
test_target_end = t.target_end
else:
test_target_end = "ongoing"
if eng.test_strategy:
test_strategy_ref = eng.test_strategy
else:
test_strategy_ref = ""
test_target_end = t.target_end or "ongoing"
test_strategy_ref = eng.test_strategy or ""
total_findings = len(findings.qs.all())

elif type(obj).__name__ == "Product":
Expand All @@ -2919,59 +2908,41 @@ def report_generate(request, obj, options):
if eng.name:
engagement_name = eng.name
engagement_target_start = eng.target_start
if eng.target_end:
engagement_target_end = eng.target_end
else:
engagement_target_end = "ongoing"
engagement_target_end = eng.target_end or "ongoing"

if eng.test_set.all():
for t in eng.test_set.all():
test_type_name = t.test_type.name
if t.environment:
test_environment_name = t.environment.name
if eng.test_strategy:
test_strategy_ref = eng.test_strategy
else:
test_strategy_ref = ""
test_strategy_ref = eng.test_strategy or ""
total_findings = len(findings.qs.all())

elif type(obj).__name__ == "Engagement":
eng = obj
if eng.name:
engagement_name = eng.name
engagement_target_start = eng.target_start
if eng.target_end:
engagement_target_end = eng.target_end
else:
engagement_target_end = "ongoing"
engagement_target_end = eng.target_end or "ongoing"

if eng.test_set.all():
for t in eng.test_set.all():
test_type_name = t.test_type.name
if t.environment:
test_environment_name = t.environment.name
if eng.test_strategy:
test_strategy_ref = eng.test_strategy
else:
test_strategy_ref = ""
test_strategy_ref = eng.test_strategy or ""
total_findings = len(findings.qs.all())

elif type(obj).__name__ == "Test":
t = obj
test_type_name = t.test_type.name
test_target_start = t.target_start
if t.target_end:
test_target_end = t.target_end
else:
test_target_end = "ongoing"
test_target_end = t.target_end or "ongoing"
total_findings = len(findings.qs.all())
if t.engagement.name:
engagement_name = t.engagement.name
engagement_target_start = t.engagement.target_start
if t.engagement.target_end:
engagement_target_end = t.engagement.target_end
else:
engagement_target_end = "ongoing"
engagement_target_end = t.engagement.target_end or "ongoing"
else:
pass # do nothing

Expand Down
31 changes: 9 additions & 22 deletions dojo/authorization/authorization.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def user_has_permission(user, obj, permission):
if user.is_superuser:
return True

if isinstance(obj, Product_Type) or isinstance(obj, Product):
if isinstance(obj, Product_Type | Product):
# Global roles are only relevant for product types, products and their
# dependent objects
if user_has_global_permission(user, permission):
Expand Down Expand Up @@ -97,13 +97,9 @@ def user_has_permission(user, obj, permission):
and permission in Permissions.get_test_permissions()
):
return user_has_permission(user, obj.engagement.product, permission)
if (
isinstance(obj, Finding) or isinstance(obj, Stub_Finding)
) and permission in Permissions.get_finding_permissions():
return user_has_permission(
user, obj.test.engagement.product, permission,
)
if (
if ((
isinstance(obj, Finding | Stub_Finding)
) and permission in Permissions.get_finding_permissions()) or (
isinstance(obj, Finding_Group)
and permission in Permissions.get_finding_group_permissions()
):
Expand All @@ -113,23 +109,17 @@ def user_has_permission(user, obj, permission):
if (
isinstance(obj, Endpoint)
and permission in Permissions.get_endpoint_permissions()
):
return user_has_permission(user, obj.product, permission)
if (
) or (
isinstance(obj, Languages)
and permission in Permissions.get_language_permissions()
):
return user_has_permission(user, obj.product, permission)
if (
) or ((
isinstance(obj, App_Analysis)
and permission in Permissions.get_technology_permissions()
):
return user_has_permission(user, obj.product, permission)
if (
) or (
isinstance(obj, Product_API_Scan_Configuration)
and permission
in Permissions.get_product_api_scan_configuration_permissions()
):
)):
return user_has_permission(user, obj.product, permission)
if (
isinstance(obj, Product_Type_Member)
Expand Down Expand Up @@ -351,10 +341,7 @@ def get_product_groups_dict(user):
.select_related("role")
.filter(group__users=user)
):
if pg_dict.get(product_group.product.id) is None:
pgu_list = []
else:
pgu_list = pg_dict[product_group.product.id]
pgu_list = [] if pg_dict.get(product_group.product.id) is None else pg_dict[product_group.product.id]
pgu_list.append(product_group)
pg_dict[product_group.product.id] = pgu_list
return pg_dict
Expand Down
5 changes: 2 additions & 3 deletions dojo/benchmark/views.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import contextlib
import logging

from crum import get_current_user
Expand Down Expand Up @@ -37,10 +38,8 @@ def add_benchmark(queryset, product):
benchmark_product.control = requirement
requirements.append(benchmark_product)

try:
with contextlib.suppress(Exception):
Benchmark_Product.objects.bulk_create(requirements)
except Exception:
pass


@user_is_authorized(Product, Permissions.Benchmark_Edit, "pid")
Expand Down
5 changes: 1 addition & 4 deletions dojo/cred/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,7 @@ def get_authorized_cred_mappings(permission, queryset=None):
if user is None:
return Cred_Mapping.objects.none()

if queryset is None:
cred_mappings = Cred_Mapping.objects.all().order_by("id")
else:
cred_mappings = queryset
cred_mappings = Cred_Mapping.objects.all().order_by("id") if queryset is None else queryset

if user.is_superuser:
return cred_mappings
Expand Down
5 changes: 2 additions & 3 deletions dojo/cred/views.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import contextlib
import logging

from django.contrib import messages
Expand Down Expand Up @@ -585,10 +586,8 @@ def new_cred_finding(request, fid):
@user_is_authorized(Cred_User, Permissions.Credential_Delete, "ttid")
def delete_cred_controller(request, destination_url, id, ttid):
cred = None
try:
with contextlib.suppress(Exception):
cred = Cred_Mapping.objects.get(pk=ttid)
except:
pass
if request.method == "POST":
tform = CredMappingForm(request.POST, instance=cred)
message = ""
Expand Down
10 changes: 2 additions & 8 deletions dojo/endpoint/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,7 @@ def get_authorized_endpoints(permission, queryset=None, user=None):
if user is None:
return Endpoint.objects.none()

if queryset is None:
endpoints = Endpoint.objects.all().order_by("id")
else:
endpoints = queryset
endpoints = Endpoint.objects.all().order_by("id") if queryset is None else queryset

if user.is_superuser:
return endpoints
Expand Down Expand Up @@ -66,10 +63,7 @@ def get_authorized_endpoint_status(permission, queryset=None, user=None):
if user is None:
return Endpoint_Status.objects.none()

if queryset is None:
endpoint_status = Endpoint_Status.objects.all().order_by("id")
else:
endpoint_status = queryset
endpoint_status = Endpoint_Status.objects.all().order_by("id") if queryset is None else queryset

if user.is_superuser:
return endpoint_status
Expand Down
41 changes: 11 additions & 30 deletions dojo/endpoint/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,11 @@
def endpoint_filter(**kwargs):
qs = Endpoint.objects.all()

if kwargs.get("protocol"):
qs = qs.filter(protocol__iexact=kwargs["protocol"])
else:
qs = qs.filter(protocol__isnull=True)
qs = qs.filter(protocol__iexact=kwargs["protocol"]) if kwargs.get("protocol") else qs.filter(protocol__isnull=True)

if kwargs.get("userinfo"):
qs = qs.filter(userinfo__exact=kwargs["userinfo"])
else:
qs = qs.filter(userinfo__isnull=True)
qs = qs.filter(userinfo__exact=kwargs["userinfo"]) if kwargs.get("userinfo") else qs.filter(userinfo__isnull=True)

if kwargs.get("host"):
qs = qs.filter(host__iexact=kwargs["host"])
else:
qs = qs.filter(host__isnull=True)
qs = qs.filter(host__iexact=kwargs["host"]) if kwargs.get("host") else qs.filter(host__isnull=True)

if kwargs.get("port"):
if (kwargs.get("protocol")) and \
Expand All @@ -48,20 +39,11 @@ def endpoint_filter(**kwargs):
else:
qs = qs.filter(port__isnull=True)

if kwargs.get("path"):
qs = qs.filter(path__exact=kwargs["path"])
else:
qs = qs.filter(path__isnull=True)
qs = qs.filter(path__exact=kwargs["path"]) if kwargs.get("path") else qs.filter(path__isnull=True)

if kwargs.get("query"):
qs = qs.filter(query__exact=kwargs["query"])
else:
qs = qs.filter(query__isnull=True)
qs = qs.filter(query__exact=kwargs["query"]) if kwargs.get("query") else qs.filter(query__isnull=True)

if kwargs.get("fragment"):
qs = qs.filter(fragment__exact=kwargs["fragment"])
else:
qs = qs.filter(fragment__isnull=True)
qs = qs.filter(fragment__exact=kwargs["fragment"]) if kwargs.get("fragment") else qs.filter(fragment__isnull=True)

if kwargs.get("product"):
qs = qs.filter(product__exact=kwargs["product"])
Expand Down Expand Up @@ -267,12 +249,11 @@ def validate_endpoints_to_add(endpoints_to_add):
endpoints = endpoints_to_add.split()
for endpoint in endpoints:
try:
if "://" in endpoint: # is it full uri?
endpoint_ins = Endpoint.from_uri(endpoint) # from_uri validate URI format + split to components
else:
# from_uri parse any '//localhost', '//127.0.0.1:80', '//foo.bar/path' correctly
# format doesn't follow RFC 3986 but users use it
endpoint_ins = Endpoint.from_uri("//" + endpoint)
# is it full uri?
# 1. from_uri validate URI format + split to components
# 2. from_uri parse any '//localhost', '//127.0.0.1:80', '//foo.bar/path' correctly
# format doesn't follow RFC 3986 but users use it
endpoint_ins = Endpoint.from_uri(endpoint) if "://" in endpoint else Endpoint.from_uri("//" + endpoint)
endpoint_ins.clean()
endpoint_list.append([
endpoint_ins.protocol,
Expand Down
5 changes: 1 addition & 4 deletions dojo/endpoint/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,7 @@ def process_endpoints_view(request, host_view=False, vulnerable=False):

paged_endpoints = get_page_items(request, endpoints.qs, 25)

if vulnerable:
view_name = "Vulnerable"
else:
view_name = "All"
view_name = "Vulnerable" if vulnerable else "All"

if host_view:
view_name += " Hosts"
Expand Down
Loading

0 comments on commit 125200d

Please sign in to comment.