From 9942c9fab60a096f12f61851494476af697ee013 Mon Sep 17 00:00:00 2001
From: Donny Peeters <46660228+Donnype@users.noreply.github.com>
Date: Tue, 10 Sep 2024 11:05:20 +0200
Subject: [PATCH 1/7] Handle empty normalizer results (#3482)
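Normalizers that produced no output previously gave Octopoes no signal at
all, so objects inferred from earlier runs of the same boefje were never
cleaned up. When the task has an input OOI and yields no observations,
declarations or affirmations, we now save an empty Observation so Octopoes
can propagate deletions. A minimal sketch of the observation this produces
(the method, source and source_method values below are hypothetical
examples, not taken from this patch):

    observation = Observation(
        method="kat_nmap_normalize",  # example normalizer id (hypothetical)
        source=Reference.from_str("IPAddressV4|internet|198.51.100.1"),
        source_method="nmap",  # example boefje id (hypothetical)
        task_id=normalizer_meta.id,
        valid_time=normalizer_meta.raw_data.boefje_meta.ended_at,
        result=[],  # empty result set: "this run observed nothing"
    )

Tasks without an input OOI are skipped, as there is nothing to propagate
deletions from.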
Signed-off-by: Donny Peeters <46660228+Donnype@users.noreply.github.com>
Co-authored-by: Jan Klopper
---
boefjes/boefjes/job_handler.py | 17 +++++++++++++++++
1 file changed, 17 insertions(+)
diff --git a/boefjes/boefjes/job_handler.py b/boefjes/boefjes/job_handler.py
index 8b0d979d256..7ad0f7247ad 100644
--- a/boefjes/boefjes/job_handler.py
+++ b/boefjes/boefjes/job_handler.py
@@ -222,6 +222,23 @@ def handle(self, normalizer_meta: NormalizerMeta) -> None:
)
)
+ if (
+ normalizer_meta.raw_data.boefje_meta.input_ooi # No input OOI means no deletion propagation
+ and not (results.observations or results.declarations or results.affirmations)
+ ):
+ # There were no results found, which we still need to signal to Octopoes for deletion propagation
+
+ connector.save_observation(
+ Observation(
+ method=normalizer_meta.normalizer.id,
+ source=Reference.from_str(normalizer_meta.raw_data.boefje_meta.input_ooi),
+ source_method=normalizer_meta.raw_data.boefje_meta.boefje.id,
+ task_id=normalizer_meta.id,
+ valid_time=normalizer_meta.raw_data.boefje_meta.ended_at,
+ result=[],
+ )
+ )
+
corrected_scan_profiles = []
for profile in results.scan_profiles:
profile.level = ScanLevel(
From b098d8d2ee2f7bdbb4c7d98e7125ae3c16f43bd7 Mon Sep 17 00:00:00 2001
From: Donny Peeters <46660228+Donnype@users.noreply.github.com>
Date: Tue, 10 Sep 2024 11:13:48 +0200
Subject: [PATCH 2/7] Fix enabling normalizers from Rocky (#3481)
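The KATalogus client API was written around boefjes: enable_boefje() and
_patch_boefje_state() were also the methods called for normalizers, and the
enable view ran a pre-flight required-settings check (with a redirect to
plugin_settings_add) that got in the way of enabling normalizers. The
methods are renamed to enable_plugin()/disable_plugin() and the pre-flight
check is removed from the enable/disable view. A sketch of the intended
usage, assuming a client instance and a hypothetical plugin id:

    plugin = katalogus_client.get_plugin("kat_nmap_normalize")
    katalogus_client.enable_plugin(plugin)  # same call for boefjes and normalizers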
Signed-off-by: Donny Peeters <46660228+Donnype@users.noreply.github.com>
Co-authored-by: Jan Klopper
---
rocky/katalogus/client.py | 12 ++---
.../katalogus/views/plugin_enable_disable.py | 49 ++-----------------
rocky/katalogus/views/plugin_settings_add.py | 2 +-
rocky/rocky/locale/django.pot | 10 +---
4 files changed, 11 insertions(+), 62 deletions(-)
diff --git a/rocky/katalogus/client.py b/rocky/katalogus/client.py
index 8bb9f7fc463..ebb614b0b9e 100644
--- a/rocky/katalogus/client.py
+++ b/rocky/katalogus/client.py
@@ -169,14 +169,14 @@ def get_normalizers(self) -> list[Plugin]:
def get_boefjes(self) -> list[Plugin]:
return self.get_plugins(plugin_type="boefje")
- def enable_boefje(self, plugin: Plugin) -> None:
- self._patch_boefje_state(plugin.id, True)
+ def enable_plugin(self, plugin: Plugin) -> None:
+ self._patch_plugin_state(plugin.id, True)
def enable_boefje_by_id(self, boefje_id: str) -> None:
- self.enable_boefje(self.get_plugin(boefje_id))
+ self.enable_plugin(self.get_plugin(boefje_id))
- def disable_boefje(self, plugin: Plugin) -> None:
- self._patch_boefje_state(plugin.id, False)
+ def disable_plugin(self, plugin: Plugin) -> None:
+ self._patch_plugin_state(plugin.id, False)
def get_enabled_boefjes(self) -> list[Plugin]:
return [plugin for plugin in self.get_boefjes() if plugin.enabled]
@@ -184,7 +184,7 @@ def get_enabled_boefjes(self) -> list[Plugin]:
def get_enabled_normalizers(self) -> list[Plugin]:
return [plugin for plugin in self.get_normalizers() if plugin.enabled]
- def _patch_boefje_state(self, boefje_id: str, enabled: bool) -> None:
+ def _patch_plugin_state(self, boefje_id: str, enabled: bool) -> None:
logger.info("Toggle plugin state", plugin_id=boefje_id, enabled=enabled)
response = self.session.patch(
diff --git a/rocky/katalogus/views/plugin_enable_disable.py b/rocky/katalogus/views/plugin_enable_disable.py
index 36db033e530..81b3d7ce36f 100644
--- a/rocky/katalogus/views/plugin_enable_disable.py
+++ b/rocky/katalogus/views/plugin_enable_disable.py
@@ -1,24 +1,16 @@
from django.contrib import messages
from django.http import HttpResponseRedirect
-from django.shortcuts import redirect
-from django.urls import reverse
from django.utils.translation import gettext_lazy as _
-from httpx import HTTPError
from katalogus.views.mixins import SinglePluginView
class PluginEnableDisableView(SinglePluginView):
- def check_required_settings(self, settings: dict):
- if self.plugin_schema is None or "required" not in self.plugin_schema:
- return True
-
- return all([field in settings for field in self.plugin_schema["required"]])
-
def post(self, request, *args, **kwargs):
plugin_state = kwargs["plugin_state"]
+
if plugin_state == "True":
- self.katalogus_client.disable_boefje(self.plugin)
+ self.katalogus_client.disable_plugin(self.plugin)
messages.add_message(
self.request,
messages.WARNING,
@@ -26,43 +18,8 @@ def post(self, request, *args, **kwargs):
)
return HttpResponseRedirect(request.POST.get("current_url"))
- try:
- plugin_settings = self.katalogus_client.get_plugin_settings(self.plugin.id)
- except HTTPError:
- messages.add_message(
- self.request,
- messages.ERROR,
- _("Failed fetching settings for {}. Is the Katalogus up?").format(self.plugin.name),
- )
- return redirect(
- reverse(
- "boefje_detail",
- kwargs={
- "organization_code": self.organization.code,
- "plugin_id": self.plugin.id,
- },
- )
- )
-
- if not self.check_required_settings(plugin_settings):
- messages.add_message(
- self.request,
- messages.INFO,
- _("Before enabling, please set the required settings for '{}'.").format(self.plugin.name),
- )
- return redirect(
- reverse(
- "plugin_settings_add",
- kwargs={
- "organization_code": self.organization.code,
- "plugin_id": self.plugin.id,
- "plugin_type": self.plugin.type,
- },
- )
- )
-
if self.plugin.can_scan(self.organization_member):
- self.katalogus_client.enable_boefje(self.plugin)
+ self.katalogus_client.enable_plugin(self.plugin)
messages.add_message(
self.request,
messages.SUCCESS,
diff --git a/rocky/katalogus/views/plugin_settings_add.py b/rocky/katalogus/views/plugin_settings_add.py
index 47842b284a5..f89921025cd 100644
--- a/rocky/katalogus/views/plugin_settings_add.py
+++ b/rocky/katalogus/views/plugin_settings_add.py
@@ -54,7 +54,7 @@ def form_valid(self, form):
if "add-enable" in self.request.POST:
try:
- self.katalogus_client.enable_boefje(self.plugin)
+ self.katalogus_client.enable_plugin(self.plugin)
except HTTPError:
messages.add_message(self.request, messages.ERROR, _("Enabling {} failed").format(self.plugin.name))
return redirect(self.get_success_url())
diff --git a/rocky/rocky/locale/django.pot b/rocky/rocky/locale/django.pot
index 8d7603aad8a..84889204c77 100644
--- a/rocky/rocky/locale/django.pot
+++ b/rocky/rocky/locale/django.pot
@@ -8,7 +8,7 @@ msgid ""
msgstr ""
"Project-Id-Version: PACKAGE VERSION\n"
"Report-Msgid-Bugs-To: \n"
-"POT-Creation-Date: 2024-09-05 08:44+0000\n"
+"POT-Creation-Date: 2024-09-06 08:27+0000\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME \n"
"Language-Team: LANGUAGE \n"
@@ -1198,14 +1198,6 @@ msgstr ""
msgid "{} '{}' disabled."
msgstr ""
-#: katalogus/views/plugin_enable_disable.py
-msgid "Failed fetching settings for {}. Is the Katalogus up?"
-msgstr ""
-
-#: katalogus/views/plugin_enable_disable.py
-msgid "Before enabling, please set the required settings for '{}'."
-msgstr ""
-
#: katalogus/views/plugin_enable_disable.py
msgid "{} '{}' enabled."
msgstr ""
From 60283ac02e142055e5ceadb858cf4d3c398e983e Mon Sep 17 00:00:00 2001
From: Donny Peeters <46660228+Donnype@users.noreply.github.com>
Date: Tue, 10 Sep 2024 14:13:18 +0200
Subject: [PATCH 3/7] Feature/upload multiple files at once to bytes (#3476)
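Bytes now accepts a JSON body with any number of base64-encoded files per
boefje run instead of a single octet-stream body per request. Each file has
a name (the key under which its raw id is returned) and a list of tags that
become the raw's mime types. A sketch of the new request shape against an
authenticated httpx session (the file names here are illustrative):

    from base64 import b64encode

    response = session.post(
        "/bytes/raw",
        json={
            "files": [
                {"name": "stdout", "content": b64encode(b"first output").decode(), "tags": ["text/plain"]},
                {"name": "report", "content": b64encode(b"{}").decode(), "tags": ["application/json"]},
            ]
        },
        params={"boefje_meta_id": str(boefje_meta_id)},
    )
    raw_ids = response.json()  # e.g. {"stdout": "<uuid>", "report": "<uuid>"}

Within one boefje_meta the tag sets act as identifiers: a file whose
mime-type set matches a raw that is already stored gets that raw's existing
id back, and two new files with identical tag sets are rejected with a 400.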
Signed-off-by: Donny Peeters <46660228+Donnype@users.noreply.github.com>
Co-authored-by: ammar92
---
boefjes/boefjes/clients/bytes_client.py | 23 +++--
bytes/bytes/api/models.py | 14 +++-
bytes/bytes/api/router.py | 97 +++++++++++++---------
bytes/bytes/models.py | 10 ++-
bytes/tests/client.py | 31 +++++--
bytes/tests/integration/test_bytes_api.py | 49 +++++++++--
bytes/tests/integration/test_migrations.py | 6 +-
bytes/tests/unit/test_context_mapping.py | 2 +-
rocky/rocky/bytes_client.py | 19 ++++-
9 files changed, 179 insertions(+), 72 deletions(-)
diff --git a/boefjes/boefjes/clients/bytes_client.py b/boefjes/boefjes/clients/bytes_client.py
index c2698523183..b7b66bbc272 100644
--- a/boefjes/boefjes/clients/bytes_client.py
+++ b/boefjes/boefjes/clients/bytes_client.py
@@ -1,5 +1,6 @@
import typing
import uuid
+from base64 import b64encode
from collections.abc import Callable, Set
from functools import wraps
from typing import Any
@@ -99,17 +100,25 @@ def get_normalizer_meta(self, normalizer_meta_id: uuid.UUID) -> NormalizerMeta:
@retry_with_login
def save_raw(self, boefje_meta_id: str, raw: str | bytes, mime_types: Set[str] = frozenset()) -> UUID:
- headers = {"content-type": "application/octet-stream"}
- headers.update(self.headers)
+ file_name = "raw" # The name provides a key for all ids returned, so this is arbitrary as we only upload 1 file
+
response = self._session.post(
"/bytes/raw",
- content=raw,
- headers=headers,
- params={"mime_types": list(mime_types), "boefje_meta_id": boefje_meta_id},
+ json={
+ "files": [
+ {
+ "name": file_name,
+ "content": b64encode(raw if isinstance(raw, bytes) else raw.encode()).decode(),
+ "tags": list(mime_types),
+ }
+ ]
+ },
+ headers=self.headers,
+ params={"boefje_meta_id": str(boefje_meta_id)},
)
-
self._verify_response(response)
- return UUID(response.json()["id"])
+
+ return UUID(response.json()[file_name])
@retry_with_login
def get_raw(self, raw_data_id: str) -> bytes:
diff --git a/bytes/bytes/api/models.py b/bytes/bytes/api/models.py
index af77cb82fc6..fddfc0dc4a3 100644
--- a/bytes/bytes/api/models.py
+++ b/bytes/bytes/api/models.py
@@ -1,7 +1,17 @@
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
class RawResponse(BaseModel):
status: str
message: str
- id: str | None = None
+ ids: list[str] | None = None
+
+
+class File(BaseModel):
+ name: str
+ content: str = Field(..., contentEncoding="base64")
+ tags: list[str] = Field(default_factory=list)
+
+
+class BoefjeOutput(BaseModel):
+ files: list[File] = Field(default_factory=list)
diff --git a/bytes/bytes/api/router.py b/bytes/bytes/api/router.py
index e88efc6cc06..394f0e959b4 100644
--- a/bytes/bytes/api/router.py
+++ b/bytes/bytes/api/router.py
@@ -1,13 +1,14 @@
+from base64 import b64decode
from uuid import UUID
import structlog
-from asgiref.sync import async_to_sync
from cachetools import TTLCache, cached
-from fastapi import APIRouter, Depends, HTTPException, Query, Request
+from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import Response
+from httpx import codes
from starlette.responses import JSONResponse
-from bytes.api.models import RawResponse
+from bytes.api.models import BoefjeOutput
from bytes.auth import authenticate_token
from bytes.config import get_settings
from bytes.database.sql_meta_repository import MetaIntegrityError, ObjectNotFoundException, create_meta_data_repository
@@ -34,10 +35,11 @@ def create_boefje_meta(
meta_repository.save_boefje_meta(boefje_meta)
except MetaIntegrityError:
return JSONResponse(
- {"status": "failed", "message": "Integrity error: object might already exist"}, status_code=400
+ {"status": "failed", "message": "Integrity error: object might already exist"},
+ status_code=codes.BAD_REQUEST,
)
- return JSONResponse({"status": "success"}, status_code=201)
+ return JSONResponse({"status": "success"}, status_code=codes.CREATED)
@router.get("/boefje_meta/{boefje_meta_id}", response_model=BoefjeMeta, tags=[BOEFJE_META_TAG])
@@ -95,10 +97,11 @@ def create_normalizer_meta(
meta_repository.save_normalizer_meta(normalizer_meta)
except MetaIntegrityError:
return JSONResponse(
- {"status": "failed", "message": "Integrity error: object might already exist"}, status_code=400
+ {"status": "failed", "message": "Integrity error: object might already exist"},
+ status_code=codes.BAD_REQUEST,
)
- return JSONResponse({"status": "success"}, status_code=201)
+ return JSONResponse({"status": "success"}, status_code=codes.CREATED)
@router.get("/normalizer_meta/{normalizer_meta_id}", response_model=NormalizerMeta, tags=[NORMALIZER_META_TAG])
@@ -109,7 +112,7 @@ def get_normalizer_meta_by_id(
try:
return meta_repository.get_normalizer_meta_by_id(normalizer_meta_id)
except ObjectNotFoundException as error:
- raise HTTPException(status_code=404, detail="Normalizer meta not found") from error
+ raise HTTPException(status_code=codes.NOT_FOUND, detail="Normalizer meta not found") from error
@router.get("/normalizer_meta", response_model=list[NormalizerMeta], tags=[NORMALIZER_META_TAG])
@@ -148,42 +151,60 @@ def get_normalizer_meta(
@router.post("/raw", tags=[RAW_TAG])
def create_raw(
- request: Request,
boefje_meta_id: UUID,
- mime_types: list[str] | None = Query(None),
+ boefje_output: BoefjeOutput,
meta_repository: MetaDataRepository = Depends(create_meta_data_repository),
event_manager: EventManager = Depends(create_event_manager),
-) -> RawResponse:
- parsed_mime_types = [] if mime_types is None else [MimeType(value=mime_type) for mime_type in mime_types]
+) -> dict[str, UUID]:
+ """Parse all the raw files from the request and return the ids. The ids are ordered according to the order from the
+ request data, but we assume the `name` field is unique, and hence return a mapping of the file name to the id."""
- try:
- meta = meta_repository.get_boefje_meta_by_id(boefje_meta_id)
+ raw_ids = {}
+ mime_types_by_id = {
+ raw.id: set(raw.mime_types) for raw in meta_repository.get_raw(RawDataFilter(boefje_meta_id=boefje_meta_id))
+ }
+ all_parsed_mime_types = list(mime_types_by_id.values())
- if meta_repository.has_raw(meta, parsed_mime_types):
- return RawResponse(status="success", message="Raw data already present")
+ for raw in boefje_output.files:
+ parsed_mime_types = {MimeType(value=x) for x in raw.tags}
- # FastAPI/starlette only has async versions of the Request methods, but
- # all our code is sync, so we wrap it in async_to_sync.
- data = async_to_sync(request.body)()
+ if parsed_mime_types in mime_types_by_id.values():
+ # Set the id for this file using the precomputed dict that maps existing primary keys to the mime-type set.
+ raw_ids[raw.name] = list(mime_types_by_id.keys())[list(mime_types_by_id.values()).index(parsed_mime_types)]
- raw_data = RawData(value=data, boefje_meta=meta, mime_types=parsed_mime_types)
- with meta_repository:
- raw_id = meta_repository.save_raw(raw_data)
-
- event = RawFileReceived(
- organization=meta.organization,
- raw_data=RawDataMeta(
- id=raw_id,
- boefje_meta=raw_data.boefje_meta,
- mime_types=raw_data.mime_types,
- ),
- )
- event_manager.publish(event)
- except Exception as error:
- logger.exception("Error saving raw data")
- raise HTTPException(status_code=500, detail="Could not save raw data") from error
+ continue
+
+ if parsed_mime_types in all_parsed_mime_types:
+ raise HTTPException(
+ status_code=codes.BAD_REQUEST, detail="Content types do not define unique sets of mime types."
+ )
+
+ try:
+ meta = meta_repository.get_boefje_meta_by_id(boefje_meta_id)
+ raw_data = RawData(value=b64decode(raw.content.encode()), boefje_meta=meta, mime_types=parsed_mime_types)
+
+ with meta_repository:
+ raw_id = meta_repository.save_raw(raw_data)
+ raw_ids[raw.name] = raw_id
+
+ all_parsed_mime_types.append(parsed_mime_types)
+
+ event = RawFileReceived(
+ organization=meta.organization,
+ raw_data=RawDataMeta(
+ id=raw_id,
+ boefje_meta=raw_data.boefje_meta,
+ mime_types=raw_data.mime_types,
+ ),
+ )
+ event_manager.publish(event)
+ except Exception as error:
+ logger.exception("Error saving raw data")
+ raise HTTPException(status_code=codes.INTERNAL_SERVER_ERROR, detail="Could not save raw data") from error
- return RawResponse(status="success", message="Raw data saved", id=raw_id)
+ return raw_ids
@router.get("/raw/{raw_id}", tags=[RAW_TAG])
@@ -194,7 +215,7 @@ def get_raw_by_id(
try:
raw_data = meta_repository.get_raw_by_id(raw_id)
except ObjectNotFoundException as error:
- raise HTTPException(status_code=404, detail="No raw data found") from error
+ raise HTTPException(status_code=codes.NOT_FOUND, detail="No raw data found") from error
return Response(raw_data.value, media_type="application/octet-stream")
@@ -207,7 +228,7 @@ def get_raw_meta_by_id(
try:
raw_meta = meta_repository.get_raw_meta_by_id(raw_id)
except ObjectNotFoundException as error:
- raise HTTPException(status_code=404, detail="No raw data found") from error
+ raise HTTPException(status_code=codes.NOT_FOUND, detail="No raw data found") from error
return raw_meta
diff --git a/bytes/bytes/models.py b/bytes/bytes/models.py
index ce8cfcd08be..03ae39506aa 100644
--- a/bytes/bytes/models.py
+++ b/bytes/bytes/models.py
@@ -38,6 +38,12 @@ def _validate_timezone_aware_datetime(value: datetime) -> datetime:
class MimeType(BaseModel):
value: str
+ def __hash__(self):
+ return hash(self.value)
+
+ def __lt__(self, other: MimeType):
+ return self.value < other.value
+
class Job(BaseModel):
id: UUID
@@ -69,7 +75,7 @@ class RawDataMeta(BaseModel):
id: UUID
boefje_meta: BoefjeMeta
- mime_types: list[MimeType] = Field(default_factory=list)
+ mime_types: set[MimeType] = Field(default_factory=set)
# These are set once the raw is saved
secure_hash: SecureHash | None = None
@@ -80,7 +86,7 @@ class RawDataMeta(BaseModel):
class RawData(BaseModel):
value: bytes
boefje_meta: BoefjeMeta
- mime_types: list[MimeType] = Field(default_factory=list)
+ mime_types: set[MimeType] = Field(default_factory=set)
# These are set once the raw is saved
secure_hash: SecureHash | None = None
diff --git a/bytes/tests/client.py b/bytes/tests/client.py
index 6996d1e821b..22405c75ffc 100644
--- a/bytes/tests/client.py
+++ b/bytes/tests/client.py
@@ -1,4 +1,5 @@
import typing
+from base64 import b64encode
from collections.abc import Callable
from functools import wraps
from typing import Any
@@ -7,6 +8,7 @@
import httpx
from httpx import HTTPError
+from bytes.api.models import BoefjeOutput
from bytes.models import BoefjeMeta, NormalizerMeta
from bytes.repositories.meta_repository import BoefjeMetaFilter, NormalizerMetaFilter, RawDataFilter
@@ -126,19 +128,34 @@ def save_raw(self, boefje_meta_id: UUID, raw: bytes, mime_types: list[str] | Non
if not mime_types:
mime_types = []
- headers = {"content-type": "application/octet-stream"}
-
+ file_name = "raw" # The name provides a key for all ids returned, so this is arbitrary as we only upload 1 file
response = self.client.post(
"/bytes/raw",
- content=raw,
- headers=headers,
- params={"mime_types": mime_types, "boefje_meta_id": str(boefje_meta_id)},
+ json={
+ "files": [
+ {
+ "name": file_name,
+ "content": b64encode(raw).decode(),
+ "tags": mime_types,
+ }
+ ],
+ },
+ params={"boefje_meta_id": str(boefje_meta_id)},
)
+ self._verify_response(response)
+ return response.json()[file_name]
+
+ @retry_with_login
+ def save_raws(self, boefje_meta_id: UUID, boefje_output: BoefjeOutput) -> dict[str, str]:
+ response = self.client.post(
+ "/bytes/raw",
+ content=boefje_output.model_dump_json(),
+ params={"boefje_meta_id": str(boefje_meta_id)},
+ )
self._verify_response(response)
- raw_id = response.json()["id"]
- return str(raw_id)
+ return response.json()
@retry_with_login
def get_raw(self, raw_id: UUID) -> bytes:
diff --git a/bytes/tests/integration/test_bytes_api.py b/bytes/tests/integration/test_bytes_api.py
index 95122aaeea0..afc34ec75bd 100644
--- a/bytes/tests/integration/test_bytes_api.py
+++ b/bytes/tests/integration/test_bytes_api.py
@@ -1,10 +1,12 @@
import uuid
+from base64 import b64encode
import httpx
import pytest
from httpx import HTTPError
from prometheus_client.parser import text_string_to_metric_families
+from bytes.api.models import BoefjeOutput, File
from bytes.models import MimeType
from bytes.rabbitmq import RabbitMQEventManager
from bytes.repositories.meta_repository import BoefjeMetaFilter, NormalizerMetaFilter, RawDataFilter
@@ -147,7 +149,10 @@ def test_normalizer_meta(bytes_api_client: BytesAPIClient, event_manager: Rabbit
normalizer_meta.raw_data.hash_retrieval_link = retrieved_normalizer_meta.raw_data.hash_retrieval_link
normalizer_meta.raw_data.signing_provider_url = retrieved_normalizer_meta.raw_data.signing_provider_url
- assert normalizer_meta.dict() == retrieved_normalizer_meta.dict()
+ normalizer_meta.raw_data.mime_types = sorted(normalizer_meta.raw_data.mime_types)
+ retrieved_normalizer_meta.raw_data.mime_types = sorted(retrieved_normalizer_meta.raw_data.mime_types)
+
+ assert normalizer_meta.model_dump_json() == retrieved_normalizer_meta.model_dump_json()
def test_filtered_normalizer_meta(bytes_api_client: BytesAPIClient) -> None:
@@ -255,21 +260,30 @@ def test_save_raw_no_mime_types(bytes_api_client: BytesAPIClient) -> None:
boefje_meta = get_boefje_meta(meta_id=uuid.uuid4())
bytes_api_client.save_boefje_meta(boefje_meta)
- headers = {"content-type": "application/octet-stream"}
bytes_api_client.login()
- headers.update(bytes_api_client.client.headers)
raw_url = f"{bytes_api_client.client.base_url}/bytes/raw"
raw = b"second test 123456"
+ file_name = "raw"
response = httpx.post(
- raw_url, content=raw, headers=headers, params={"boefje_meta_id": str(boefje_meta.id)}, timeout=30
+ raw_url,
+ json={
+ "files": [
+ {
+ "name": file_name,
+ "content": b64encode(raw).decode(),
+ "tags": [],
+ }
+ ]
+ },
+ headers=bytes_api_client.client.headers,
+ params={"boefje_meta_id": str(boefje_meta.id)},
)
-
assert response.status_code == 200
get_raw_without_mime_type_response = httpx.get(
- f"{raw_url}/{response.json().get('id')}", headers=bytes_api_client.client.headers, timeout=30
+ f"{raw_url}/{response.json()[file_name]}", headers=bytes_api_client.client.headers, timeout=30
)
assert get_raw_without_mime_type_response.status_code == 200
@@ -293,13 +307,13 @@ def test_raw_mimes(bytes_api_client: BytesAPIClient) -> None:
)
)
assert len(retrieved_raws) == 1
- assert retrieved_raws[0]["mime_types"] == [{"value": value} for value in mime_types]
+ assert {x["value"] for x in retrieved_raws[0]["mime_types"]} == set(mime_types)
retrieved_raws = bytes_api_client.get_raws(
RawDataFilter(boefje_meta_id=boefje_meta.id, normalized=False, mime_types=[MimeType(value="text/html")])
)
assert len(retrieved_raws) == 1
- assert retrieved_raws[0]["mime_types"] == [{"value": value} for value in mime_types]
+ assert {x["value"] for x in retrieved_raws[0]["mime_types"]} == set(mime_types)
retrieved_raws = bytes_api_client.get_raws(
RawDataFilter(boefje_meta_id=boefje_meta.id, normalized=False, mime_types=[MimeType(value="bad/mime")])
@@ -336,3 +350,22 @@ def test_cannot_overwrite_raw(bytes_api_client: BytesAPIClient) -> None:
retrieved_raw = bytes_api_client.get_raw(first_raw_id)
assert retrieved_raw == right_raw
+
+
+def test_save_multiple_raw_files(bytes_api_client: BytesAPIClient) -> None:
+ boefje_meta = get_boefje_meta()
+ bytes_api_client.save_boefje_meta(boefje_meta)
+
+ first_raw = b"first"
+ second_raw = b"second"
+ boefje_output = BoefjeOutput(
+ files=[
+ File(name="first", content=b64encode(first_raw).decode(), tags=[]),
+ File(name="second", content=b64encode(second_raw).decode(), tags=["mime", "type"]),
+ ]
+ )
+
+ ids = bytes_api_client.save_raws(boefje_meta.id, boefje_output)
+
+ assert bytes_api_client.get_raw(ids["first"]) == first_raw
+ assert bytes_api_client.get_raw(ids["second"]) == second_raw
diff --git a/bytes/tests/integration/test_migrations.py b/bytes/tests/integration/test_migrations.py
index eb3a29c0c4f..4967d847470 100644
--- a/bytes/tests/integration/test_migrations.py
+++ b/bytes/tests/integration/test_migrations.py
@@ -13,15 +13,15 @@ def test_clean_mime_types(meta_repository: SQLMetaDataRepository) -> None:
meta_repository.save_boefje_meta(boefje_meta)
raw = get_raw_data()
- raw.mime_types.append(MimeType(value=raw.boefje_meta.boefje.id))
+ raw.mime_types.add(MimeType(value=raw.boefje_meta.boefje.id))
raw_id_1 = meta_repository.save_raw(raw)
- raw.mime_types.append(
+ raw.mime_types.add(
MimeType(value=f"boefje/{raw.boefje_meta.boefje.id}-ce293f79fd3c809a300a2837bb1da4f7115fc034a1f78")
)
raw_id_2 = meta_repository.save_raw(raw)
- raw.mime_types.append(
+ raw.mime_types.add(
MimeType(value=f"boefje/{raw.boefje_meta.boefje.id}-ba293f79fd3c809a300a2837bb1da4f7115fc034a1f78")
)
raw_id_3 = meta_repository.save_raw(raw)
diff --git a/bytes/tests/unit/test_context_mapping.py b/bytes/tests/unit/test_context_mapping.py
index 62f303f2c46..61147f84414 100644
--- a/bytes/tests/unit/test_context_mapping.py
+++ b/bytes/tests/unit/test_context_mapping.py
@@ -62,7 +62,7 @@ def test_context_mapping_raw() -> None:
assert raw_data.hash_retrieval_link == raw_data_in_db.hash_retrieval_link
assert raw_data.secure_hash == raw_data_in_db.secure_hash
assert raw_data.signing_provider_url is None
- assert raw_data.mime_types == [to_mime_type(mime_type) for mime_type in raw_data_in_db.mime_types]
+ assert raw_data.mime_types == {to_mime_type(mime_type) for mime_type in raw_data_in_db.mime_types}
raw_data_new = to_raw_data(raw_data_in_db, raw_data.value)
diff --git a/rocky/rocky/bytes_client.py b/rocky/rocky/bytes_client.py
index 837d586ebc9..9ddd97ff9d7 100644
--- a/rocky/rocky/bytes_client.py
+++ b/rocky/rocky/bytes_client.py
@@ -1,4 +1,5 @@
import uuid
+from base64 import b64encode
from collections.abc import Set
from datetime import datetime, timezone
@@ -113,15 +114,25 @@ def _save_normalizer_meta(self, normalizer_meta: NormalizerMeta) -> None:
response.raise_for_status()
def _save_raw(self, boefje_meta_id: uuid.UUID, raw: bytes, mime_types: Set[str] = frozenset()) -> str:
+ file_name = "raw" # The name provides a key for all ids returned, so this is arbitrary as we only upload 1 file
+
response = self.session.post(
"/bytes/raw",
- content=raw,
- headers={"content-type": "application/octet-stream"},
- params={"mime_types": list(mime_types), "boefje_meta_id": str(boefje_meta_id)},
+ json={
+ "files": [
+ {
+ "name": file_name,
+ "content": b64encode(raw).decode(),
+ "tags": list(mime_types),
+ }
+ ]
+ },
+ params={"boefje_meta_id": str(boefje_meta_id)},
)
response.raise_for_status()
- return response.json()["id"]
+
+ return response.json()[file_name]
def get_raw(self, raw_id: str) -> bytes:
# Note: we assume organization permissions are handled before requesting raw data.
From ba809d0ce61c3a0f28940c97bd048582c657481a Mon Sep 17 00:00:00 2001
From: JP Bruins Slot
Date: Tue, 10 Sep 2024 14:19:56 +0200
Subject: [PATCH 4/7] Add report scheduler functionality to scheduler (#3352)
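Next to the boefje and normalizer schedulers, every organisation now gets a
ReportScheduler. It polls the schedule store every 60 seconds for enabled
schedules whose deadline_at has passed and pushes a ReportTask per report
recipe onto its priority queue. Deduplication works through the task hash,
which depends only on the recipe and the organisation, so re-running the
reschedule loop never queues the same recipe twice:

    task = ReportTask(organisation_id="org-1", report_recipe_id="123")
    same = ReportTask(organisation_id="org-1", report_recipe_id="123")
    assert task.hash == same.hash  # identical recipe and org map to one queue slot

The /schedules endpoint also gains a scheduler_id filter, so the schedules
of a single scheduler can be listed in isolation.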
Co-authored-by: ammar92
Co-authored-by: stephanie0x00 <9821756+stephanie0x00@users.noreply.github.com>
Co-authored-by: Jan Klopper
---
mula/scheduler/app.py | 22 ++-
mula/scheduler/models/__init__.py | 2 +-
mula/scheduler/models/task.py | 13 +-
mula/scheduler/schedulers/__init__.py | 1 +
mula/scheduler/schedulers/report.py | 162 ++++++++++++++++++
mula/scheduler/server/handlers/schedules.py | 2 +
mula/tests/integration/test_api.py | 11 ++
mula/tests/integration/test_app.py | 22 +--
.../integration/test_report_scheduler.py | 162 ++++++++++++++++++
9 files changed, 383 insertions(+), 14 deletions(-)
create mode 100644 mula/scheduler/schedulers/report.py
create mode 100644 mula/tests/integration/test_report_scheduler.py
diff --git a/mula/scheduler/app.py b/mula/scheduler/app.py
index 4e0836bed12..833cb3ec2f5 100644
--- a/mula/scheduler/app.py
+++ b/mula/scheduler/app.py
@@ -60,7 +60,10 @@ def __init__(self, ctx: context.AppContext) -> None:
self.schedulers: dict[
str,
- schedulers.Scheduler | schedulers.BoefjeScheduler | schedulers.NormalizerScheduler,
+ schedulers.Scheduler
+ | schedulers.BoefjeScheduler
+ | schedulers.NormalizerScheduler
+ | schedulers.ReportScheduler,
] = {}
self.server: server.Server | None = None
@@ -136,12 +139,21 @@ def monitor_organisations(self) -> None:
callback=self.remove_scheduler,
)
+ scheduler_report = schedulers.ReportScheduler(
+ ctx=self.ctx,
+ scheduler_id=f"report-{org.id}",
+ organisation=org,
+ callback=self.remove_scheduler,
+ )
+
with self.lock:
self.schedulers[scheduler_boefje.scheduler_id] = scheduler_boefje
self.schedulers[scheduler_normalizer.scheduler_id] = scheduler_normalizer
+ self.schedulers[scheduler_report.scheduler_id] = scheduler_report
scheduler_normalizer.run()
scheduler_boefje.run()
+ scheduler_report.run()
if additions:
# Flush katalogus caches when new organisations are added
@@ -201,6 +213,14 @@ def start_schedulers(self) -> None:
)
self.schedulers[normalizer_scheduler.scheduler_id] = normalizer_scheduler
+ report_scheduler = schedulers.ReportScheduler(
+ ctx=self.ctx,
+ scheduler_id=f"report-{org.id}",
+ organisation=org,
+ callback=self.remove_scheduler,
+ )
+ self.schedulers[report_scheduler.scheduler_id] = report_scheduler
+
# Start schedulers
for scheduler in self.schedulers.values():
scheduler.run()
diff --git a/mula/scheduler/models/__init__.py b/mula/scheduler/models/__init__.py
index ed1a7fa177a..a5390ad6ede 100644
--- a/mula/scheduler/models/__init__.py
+++ b/mula/scheduler/models/__init__.py
@@ -9,4 +9,4 @@
from .queue import Queue
from .schedule import Schedule, ScheduleDB
from .scheduler import Scheduler
-from .task import BoefjeTask, NormalizerTask, Task, TaskDB, TaskStatus
+from .task import BoefjeTask, NormalizerTask, ReportTask, Task, TaskDB, TaskStatus
diff --git a/mula/scheduler/models/task.py b/mula/scheduler/models/task.py
index cf0c9a95834..76c51ad2e79 100644
--- a/mula/scheduler/models/task.py
+++ b/mula/scheduler/models/task.py
@@ -59,7 +59,7 @@ class Task(BaseModel):
hash: str | None = Field(None, max_length=32)
- data: dict | None = None
+ data: dict = Field(default_factory=dict)
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
modified_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
@@ -143,3 +143,14 @@ def hash(self) -> str:
return mmh3.hash_bytes(f"{self.input_ooi}-{self.boefje.id}-{self.organization}").hex()
return mmh3.hash_bytes(f"{self.boefje.id}-{self.organization}").hex()
+
+
+class ReportTask(BaseModel):
+ type: ClassVar[str] = "report"
+
+ organisation_id: str
+ report_recipe_id: str
+
+ @property
+ def hash(self) -> str:
+ return mmh3.hash_bytes(f"{self.report_recipe_id}-{self.organisation_id}").hex()
diff --git a/mula/scheduler/schedulers/__init__.py b/mula/scheduler/schedulers/__init__.py
index 5614508b532..4c82914aee5 100644
--- a/mula/scheduler/schedulers/__init__.py
+++ b/mula/scheduler/schedulers/__init__.py
@@ -1,3 +1,4 @@
from .boefje import BoefjeScheduler
from .normalizer import NormalizerScheduler
+from .report import ReportScheduler
from .scheduler import Scheduler
diff --git a/mula/scheduler/schedulers/report.py b/mula/scheduler/schedulers/report.py
new file mode 100644
index 00000000000..b76d79a13b7
--- /dev/null
+++ b/mula/scheduler/schedulers/report.py
@@ -0,0 +1,162 @@
+from collections.abc import Callable
+from concurrent import futures
+from datetime import datetime, timezone
+from typing import Any
+
+import structlog
+from opentelemetry import trace
+
+from scheduler import context, queues, storage
+from scheduler.models import Organisation, ReportTask, Task
+from scheduler.storage import filters
+
+from .scheduler import Scheduler
+
+tracer = trace.get_tracer(__name__)
+
+
+class ReportScheduler(Scheduler):
+ ITEM_TYPE: Any = ReportTask
+
+ def __init__(
+ self,
+ ctx: context.AppContext,
+ scheduler_id: str,
+ organisation: Organisation,
+ queue: queues.PriorityQueue | None = None,
+ callback: Callable[..., None] | None = None,
+ ):
+ self.logger: structlog.BoundLogger = structlog.get_logger(__name__)
+ self.organisation = organisation
+ self.create_schedule = False
+
+ self.queue = queue or queues.PriorityQueue(
+ pq_id=scheduler_id,
+ maxsize=ctx.config.pq_maxsize,
+ item_type=self.ITEM_TYPE,
+ allow_priority_updates=True,
+ pq_store=ctx.datastores.pq_store,
+ )
+
+ super().__init__(
+ ctx=ctx,
+ queue=self.queue,
+ scheduler_id=scheduler_id,
+ callback=callback,
+ )
+
+ def run(self) -> None:
+ # Rescheduling
+ self.run_in_thread(
+ name=f"scheduler-{self.scheduler_id}-reschedule",
+ target=self.push_tasks_for_rescheduling,
+ interval=60.0,
+ )
+
+ @tracer.start_as_current_span(name="report_push_tasks_for_rescheduling")
+ def push_tasks_for_rescheduling(self):
+ if self.queue.full():
+ self.logger.warning(
+ "Report queue is full, not populating with new tasks",
+ queue_qsize=self.queue.qsize(),
+ organisation_id=self.organisation.id,
+ scheduler_id=self.scheduler_id,
+ )
+ return
+
+ try:
+ schedules, _ = self.ctx.datastores.schedule_store.get_schedules(
+ filters=filters.FilterRequest(
+ filters=[
+ filters.Filter(
+ column="scheduler_id",
+ operator="eq",
+ value=self.scheduler_id,
+ ),
+ filters.Filter(
+ column="deadline_at",
+ operator="lt",
+ value=datetime.now(timezone.utc),
+ ),
+ filters.Filter(
+ column="enabled",
+ operator="eq",
+ value=True,
+ ),
+ ]
+ )
+ )
+ except storage.errors.StorageError as exc_db:
+ self.logger.error(
+ "Could not get schedules for rescheduling %s",
+ self.scheduler_id,
+ scheduler_id=self.scheduler_id,
+ organisation_id=self.organisation.id,
+ exc_info=exc_db,
+ )
+ raise exc_db
+
+ with futures.ThreadPoolExecutor(
+ thread_name_prefix=f"ReportScheduler-TPE-{self.scheduler_id}-rescheduling"
+ ) as executor:
+ for schedule in schedules:
+ report_task = ReportTask.model_validate(schedule.data)
+ executor.submit(
+ self.push_report_task,
+ report_task,
+ self.push_tasks_for_rescheduling.__name__,
+ )
+
+ def push_report_task(self, report_task: ReportTask, caller: str = "") -> None:
+ self.logger.debug(
+ "Pushing report task",
+ task_hash=report_task.hash,
+ organisation_id=self.organisation.id,
+ scheduler_id=self.scheduler_id,
+ caller=caller,
+ )
+
+ if self.is_item_on_queue_by_hash(report_task.hash):
+ self.logger.debug(
+ "Report task already on queue",
+ task_hash=report_task.hash,
+ organisation_id=self.organisation.id,
+ scheduler_id=self.scheduler_id,
+ caller=caller,
+ )
+ return
+
+ task = Task(
+ scheduler_id=self.scheduler_id,
+ priority=int(datetime.now().timestamp()),
+ type=self.ITEM_TYPE.type,
+ hash=report_task.hash,
+ data=report_task.model_dump(),
+ )
+
+ try:
+ self.push_item_to_queue_with_timeout(
+ task,
+ self.max_tries,
+ )
+ except queues.QueueFullError:
+ self.logger.warning(
+ "Could not add task %s to queue, queue was full",
+ report_task.hash,
+ task_hash=report_task.hash,
+ queue_qsize=self.queue.qsize(),
+ queue_maxsize=self.queue.maxsize,
+ organisation_id=self.organisation.id,
+ scheduler_id=self.scheduler_id,
+ caller=caller,
+ )
+ return
+
+ self.logger.info(
+ "Report task pushed to queue",
+ task_id=task.id,
+ task_hash=report_task.hash,
+ organisation_id=self.organisation.id,
+ scheduler_id=self.scheduler_id,
+ caller=caller,
+ )
diff --git a/mula/scheduler/server/handlers/schedules.py b/mula/scheduler/server/handlers/schedules.py
index 33bd17b5aa3..23e44c638b3 100644
--- a/mula/scheduler/server/handlers/schedules.py
+++ b/mula/scheduler/server/handlers/schedules.py
@@ -63,6 +63,7 @@ def __init__(
def list(
self,
request: fastapi.Request,
+ scheduler_id: str | None = None,
schedule_hash: str | None = None,
enabled: bool | None = None,
offset: int = 0,
@@ -86,6 +87,7 @@ def list(
try:
results, count = self.ctx.datastores.schedule_store.get_schedules(
+ scheduler_id=scheduler_id,
schedule_hash=schedule_hash,
enabled=enabled,
min_deadline_at=min_deadline_at,
diff --git a/mula/tests/integration/test_api.py b/mula/tests/integration/test_api.py
index 9f6440c5801..f72c67705c4 100644
--- a/mula/tests/integration/test_api.py
+++ b/mula/tests/integration/test_api.py
@@ -837,6 +837,17 @@ def test_list_schedules(self):
self.assertEqual(2, response.json()["count"])
self.assertEqual(2, len(response.json()["results"]))
+ def test_list_schedules_scheduler_id(self):
+ response = self.client.get(f"/schedules?scheduler_id={self.scheduler.scheduler_id}")
+ self.assertEqual(200, response.status_code)
+ self.assertEqual(2, response.json()["count"])
+ self.assertEqual(2, len(response.json()["results"]))
+
+ response = self.client.get(f"/schedules?scheduler_id={uuid.uuid4()}")
+ self.assertEqual(200, response.status_code)
+ self.assertEqual(0, response.json()["count"])
+ self.assertEqual(0, len(response.json()["results"]))
+
def test_list_schedules_enabled(self):
response = self.client.get("/schedules?enabled=true")
self.assertEqual(200, response.status_code)
diff --git a/mula/tests/integration/test_app.py b/mula/tests/integration/test_app.py
index 4fd380ca3ee..0c5ae739c76 100644
--- a/mula/tests/integration/test_app.py
+++ b/mula/tests/integration/test_app.py
@@ -50,9 +50,9 @@ def test_monitor_orgs_add(self):
# Act
self.app.monitor_organisations()
- # Assert: four schedulers should have been created for two organisations
- self.assertEqual(4, len(self.app.schedulers.keys()))
- self.assertEqual(4, len(self.app.server.schedulers.keys()))
+ # Assert: six schedulers should have been created for two organisations
+ self.assertEqual(6, len(self.app.schedulers.keys()))
+ self.assertEqual(6, len(self.app.server.schedulers.keys()))
scheduler_org_ids = {s.organisation.id for s in self.app.schedulers.values()}
self.assertEqual({"org-1", "org-2"}, scheduler_org_ids)
@@ -68,9 +68,9 @@ def test_monitor_orgs_remove(self):
# Act
self.app.monitor_organisations()
- # Assert: four schedulers should have been created for two organisations
- self.assertEqual(4, len(self.app.schedulers.keys()))
- self.assertEqual(4, len(self.app.server.schedulers.keys()))
+ # Assert: six schedulers should have been created for two organisations
+ self.assertEqual(6, len(self.app.schedulers.keys()))
+ self.assertEqual(6, len(self.app.server.schedulers.keys()))
scheduler_org_ids = {s.organisation.id for s in self.app.schedulers.values()}
self.assertEqual({"org-1", "org-2"}, scheduler_org_ids)
@@ -100,9 +100,9 @@ def test_monitor_orgs_add_and_remove(self):
# Act
self.app.monitor_organisations()
- # Assert: four schedulers should have been created for two organisations
- self.assertEqual(4, len(self.app.schedulers.keys()))
- self.assertEqual(4, len(self.app.server.schedulers.keys()))
+ # Assert: six schedulers should have been created for two organisations
+ self.assertEqual(6, len(self.app.schedulers.keys()))
+ self.assertEqual(6, len(self.app.server.schedulers.keys()))
scheduler_org_ids = {s.organisation.id for s in self.app.schedulers.values()}
self.assertEqual({"org-1", "org-2"}, scheduler_org_ids)
@@ -117,8 +117,8 @@ def test_monitor_orgs_add_and_remove(self):
self.app.monitor_organisations()
# Assert
- self.assertEqual(4, len(self.app.schedulers.keys()))
- self.assertEqual(4, len(self.app.server.schedulers.keys()))
+ self.assertEqual(6, len(self.app.schedulers.keys()))
+ self.assertEqual(6, len(self.app.server.schedulers.keys()))
scheduler_org_ids = {s.organisation.id for s in self.app.schedulers.values()}
self.assertEqual({"org-1", "org-3"}, scheduler_org_ids)
diff --git a/mula/tests/integration/test_report_scheduler.py b/mula/tests/integration/test_report_scheduler.py
new file mode 100644
index 00000000000..f0d93232a47
--- /dev/null
+++ b/mula/tests/integration/test_report_scheduler.py
@@ -0,0 +1,162 @@
+import unittest
+from types import SimpleNamespace
+from unittest import mock
+
+from scheduler import config, models, schedulers, storage
+
+from tests.factories import OrganisationFactory
+
+
+class ReportSchedulerBaseTestCase(unittest.TestCase):
+ def setUp(self):
+ # Application Context
+ self.mock_ctx = mock.patch("scheduler.context.AppContext").start()
+ self.mock_ctx.config = config.settings.Settings()
+
+ # Database
+ self.dbconn = storage.DBConn(str(self.mock_ctx.config.db_uri))
+ self.dbconn.connect()
+ models.Base.metadata.drop_all(self.dbconn.engine)
+ models.Base.metadata.create_all(self.dbconn.engine)
+
+ self.mock_ctx.datastores = SimpleNamespace(
+ **{
+ storage.TaskStore.name: storage.TaskStore(self.dbconn),
+ storage.PriorityQueueStore.name: storage.PriorityQueueStore(self.dbconn),
+ storage.ScheduleStore.name: storage.ScheduleStore(self.dbconn),
+ }
+ )
+
+ # Scheduler
+ self.organisation = OrganisationFactory()
+ self.scheduler = schedulers.ReportScheduler(
+ ctx=self.mock_ctx,
+ scheduler_id=self.organisation.id,
+ organisation=self.organisation,
+ )
+
+ def tearDown(self):
+ self.scheduler.stop()
+ models.Base.metadata.drop_all(self.dbconn.engine)
+ self.dbconn.engine.dispose()
+
+
+class ReportSchedulerTestCase(ReportSchedulerBaseTestCase):
+ def setUp(self):
+ super().setUp()
+
+ self.mock_get_schedules = mock.patch(
+ "scheduler.context.AppContext.datastores.schedule_store.get_schedules",
+ ).start()
+
+ def tearDown(self):
+ mock.patch.stopall()
+
+ def test_enable_scheduler(self):
+ # Disable scheduler first
+ self.scheduler.disable()
+
+ # Threads should be stopped
+ self.assertEqual(0, len(self.scheduler.threads))
+
+ # Queue should be empty
+ self.assertEqual(0, self.scheduler.queue.qsize())
+
+ # Re-enable scheduler
+ self.scheduler.enable()
+
+ # Threads should be started
+ self.assertGreater(len(self.scheduler.threads), 0)
+
+ # Scheduler should be enabled
+ self.assertTrue(self.scheduler.is_enabled())
+
+ # Stop the scheduler
+ self.scheduler.stop()
+
+ def test_disable_scheduler(self):
+ # Disable scheduler
+ self.scheduler.disable()
+
+ # Threads should be stopped
+ self.assertEqual(0, len(self.scheduler.threads))
+
+ # Queue should be empty
+ self.assertEqual(0, self.scheduler.queue.qsize())
+
+ # Scheduler should be disabled
+ self.assertFalse(self.scheduler.is_enabled())
+
+ def test_push_tasks_for_rescheduling(self):
+        """When a schedule's deadline has passed, the resulting task should be added to the queue"""
+ # Arrange
+ report_task = models.ReportTask(
+ organisation_id=self.organisation.id,
+ report_recipe_id="123",
+ )
+
+ schedule = models.Schedule(
+ scheduler_id=self.scheduler.scheduler_id,
+ hash=report_task.hash,
+ data=report_task.dict(),
+ )
+
+ schedule_db = self.mock_ctx.datastores.schedule_store.create_schedule(schedule)
+
+ # Mocks
+ self.mock_get_schedules.return_value = ([schedule_db], 1)
+
+ # Act
+ self.scheduler.push_tasks_for_rescheduling()
+
+ # Assert: new item should be on queue
+ self.assertEqual(1, self.scheduler.queue.qsize())
+
+ # Assert: new item is created with a similar task
+ peek = self.scheduler.queue.peek(0)
+ self.assertEqual(schedule.hash, peek.hash)
+
+ # Assert: task should be created, and should be the one that is queued
+ task_db = self.mock_ctx.datastores.task_store.get_task(peek.id)
+ self.assertIsNotNone(task_db)
+ self.assertEqual(peek.id, task_db.id)
+
+ def test_push_tasks_for_rescheduling_item_on_queue(self):
+        """A schedule whose deadline has passed should not queue its task again if it is already on the queue"""
+ # Arrange
+ report_task = models.ReportTask(
+ organisation_id=self.organisation.id,
+ report_recipe_id="123",
+ )
+
+ schedule = models.Schedule(
+ scheduler_id=self.scheduler.scheduler_id,
+ hash=report_task.hash,
+ data=report_task.dict(),
+ )
+
+ schedule_db = self.mock_ctx.datastores.schedule_store.create_schedule(schedule)
+
+ # Mocks
+ self.mock_get_schedules.return_value = ([schedule_db], 1)
+
+ # Act
+ self.scheduler.push_tasks_for_rescheduling()
+
+ # Assert: new item should be on queue
+ self.assertEqual(1, self.scheduler.queue.qsize())
+
+ # Assert: new item is created with a similar task
+ peek = self.scheduler.queue.peek(0)
+ self.assertEqual(schedule.hash, peek.hash)
+
+ # Assert: task should be created, and should be the one that is queued
+ task_db = self.mock_ctx.datastores.task_store.get_task(peek.id)
+ self.assertIsNotNone(task_db)
+ self.assertEqual(peek.id, task_db.id)
+
+ # Act: push again
+ self.scheduler.push_tasks_for_rescheduling()
+
+ # Should only be one task on queue
+ self.assertEqual(1, self.scheduler.queue.qsize())
From 4958776f498b2edc9c54df45293315eea0c1da40 Mon Sep 17 00:00:00 2001
From: Rieven
Date: Tue, 10 Sep 2024 16:58:50 +0200
Subject: [PATCH 5/7] Fix report types selection not being overridden (#3436)
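Django's {% include %} inherits the parent context, so a partial that reads
selected_report_types silently picks up a stale value from the surrounding
view unless it is explicitly overridden. The include sites now pass
selected_report_types=None to clear it. A standalone sketch of the leak,
assuming Django is installed (the template names are hypothetical):

    import django
    from django.conf import settings

    settings.configure(TEMPLATES=[{
        "BACKEND": "django.template.backends.django.DjangoTemplates",
        "OPTIONS": {"loaders": [("django.template.loaders.locmem.Loader", {
            "partial.html": "{{ selected_report_types|default:'unset' }}",
        })]},
    }])
    django.setup()

    from django.template import engines

    page = engines["django"].from_string(
        "{% include 'partial.html' %} vs "
        "{% include 'partial.html' with selected_report_types=None %}"
    )
    print(page.render({"selected_report_types": "web-report"}))
    # -> "web-report vs unset": the explicit override clears the inherited value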
Co-authored-by: Peter-Paul van Gemerden
Co-authored-by: ammar92
Co-authored-by: Jan Klopper
---
.../templates/partials/report_types_selection.html | 9 +++++----
rocky/reports/views/aggregate_report.py | 4 ++++
rocky/reports/views/generate_report.py | 4 +++-
3 files changed, 12 insertions(+), 5 deletions(-)
diff --git a/rocky/reports/templates/partials/report_types_selection.html b/rocky/reports/templates/partials/report_types_selection.html
index 3c401605847..066add7c640 100644
--- a/rocky/reports/templates/partials/report_types_selection.html
+++ b/rocky/reports/templates/partials/report_types_selection.html
@@ -14,7 +14,7 @@
{% translate "Choose report types" %}
{% endblocktranslate %}
{% if not selected_oois %}
- {% include "partials/return_button.html" with btn_text="Go back" %}
+ {% include "partials/return_button.html" with btn_text="Go back" selected_report_types=None %}
{% else %}
@@ -40,7 +40,7 @@
action="{{ previous }}"
class="inline layout-wide">
{% csrf_token %}
- {% include "forms/report_form_fields.html" %}
+ {% include "forms/report_form_fields.html" with selected_report_types=None %}