From ad45c424336d30ed62a4393e2c30b638bfc86d2d Mon Sep 17 00:00:00 2001 From: Avery Date: Thu, 12 Dec 2024 10:07:19 -0800 Subject: [PATCH] Add tags and tickets fields to get cases request (#5564) * Add tags and tickets fields to get cases request * Attempt to move include parameter in GET requests to Cases into database service. * Adding tests * Adding tests * Add tags and ticket fields back to CaseReadMinimal * Switch test from Incident to Case. * remove duplicates * Fix unused variable assignment * Fixing tests --------- Co-authored-by: kevgliss Co-authored-by: Kevin Glisson --- .vscode/settings.json | 20 +-- src/dispatch/case/models.py | 2 + src/dispatch/case/views.py | 1 + src/dispatch/database/service.py | 33 +++-- tests/conftest.py | 42 +++++- tests/database/test_service.py | 225 +++++++++++++++++++++++++++++++ tests/factories.py | 35 ++++- 7 files changed, 323 insertions(+), 35 deletions(-) create mode 100644 tests/database/test_service.py diff --git a/.vscode/settings.json b/.vscode/settings.json index b8d00b662cbd..22c7f0c901d5 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -17,29 +17,11 @@ "files.trimFinalNewlines": true, "files.insertFinalNewline": true, "vetur.format.enable": false, - "python.formatting.provider": "black", - "python.formatting.blackArgs": [ - "--line-length", - "100" - ], - "python.linting.enabled": true, - "python.linting.flake8Enabled": true, - "python.linting.flake8Args": [ - "--ignore=E24,W504,E501", - "--verbose" - ], "python.testing.pytestEnabled": true, - "python.sortImports.args": [ - "--settings-path", - "${workspaceFolder}/setup.cfg" - ], - "python.linting.pylintArgs": [ - "--rcfile", - "${workspaceFolder}/setup.cfg" - ], "[python]": { "editor.codeActionsOnSave": { "source.organizeImports": "never" } }, + "codeQL.githubDatabase.update": "never", } diff --git a/src/dispatch/case/models.py b/src/dispatch/case/models.py index 9065db786413..7405bd194282 100644 --- a/src/dispatch/case/models.py +++ b/src/dispatch/case/models.py @@ -288,6 +288,8 @@ class CaseReadMinimal(CaseBase): project: ProjectRead reporter: Optional[ParticipantReadMinimal] reported_at: Optional[datetime] = None + tags: Optional[List[TagRead]] = [] + ticket: Optional[TicketRead] = None total_cost: float | None triage_at: Optional[datetime] = None diff --git a/src/dispatch/case/views.py b/src/dispatch/case/views.py index 34de7640e621..bc70665bf1d0 100644 --- a/src/dispatch/case/views.py +++ b/src/dispatch/case/views.py @@ -117,6 +117,7 @@ def get_cases( expand: bool = Query(default=False), ): """Retrieves all cases.""" + common["include_keys"] = include pagination = search_filter_sort_paginate(model="Case", **common) if expand: diff --git a/src/dispatch/database/service.py b/src/dispatch/database/service.py index ba326ff4edea..d966be4c4cce 100644 --- a/src/dispatch/database/service.py +++ b/src/dispatch/database/service.py @@ -341,8 +341,7 @@ def apply_filters(query, filter_spec, model_cls=None, do_auto_join=True): return query -def apply_filter_specific_joins(model: Base, filter_spec: dict, query: orm.query): - """Applies any model specific implicitly joins.""" +def get_model_map(filters: dict) -> dict: # this is required because by default sqlalchemy-filter's auto-join # knows nothing about how to join many-many relationships. model_map = { @@ -371,19 +370,21 @@ def apply_filter_specific_joins(model: Base, filter_spec: dict, query: orm.query (SignalInstance, "EntityType"): (SignalInstance.entities, True), (Tag, "TagType"): (Tag.tag_type, False), } - filters = build_filters(filter_spec) - # Replace mapping if looking for commander - if "Commander" in str(filter_spec): + if "Commander" in filters: model_map.update({(Incident, "IndividualContact"): (Incident.commander, True)}) - if "Assignee" in str(filter_spec): + if "Assignee" in filters: model_map.update({(Case, "IndividualContact"): (Case.assignee, True)}) + return model_map - filter_models = get_named_models(filters) + +def apply_model_specific_joins(model: Base, models: List[str], query: orm.query): + model_map = get_model_map(models) joined_models = [] - for filter_model in filter_models: - if model_map.get((model, filter_model)): - joined_model, is_outer = model_map[(model, filter_model)] + + for include_model in models: + if model_map.get((model, include_model)): + joined_model, is_outer = model_map[(model, include_model)] try: if joined_model not in joined_models: query = query.join(joined_model, isouter=is_outer) @@ -394,6 +395,14 @@ def apply_filter_specific_joins(model: Base, filter_spec: dict, query: orm.query return query +def apply_filter_specific_joins(model: Base, filter_spec: dict, query: orm.query): + """Applies any model specific implicitly joins.""" + filters = build_filters(filter_spec) + filter_models = get_named_models(filters) + + return apply_model_specific_joins(model, filter_models, query) + + def composite_search(*, db_session, query_str: str, models: List[Base], current_user: DispatchUser): """Perform a multi-table search based on the supplied query.""" s = CompositeSearch(db_session, models) @@ -537,6 +546,7 @@ def search_filter_sort_paginate( model, query_str: str = None, filter_spec: str | dict | None = None, + include_keys: List[str] = None, page: int = 1, items_per_page: int = 5, sort_by: List[str] = None, @@ -574,6 +584,9 @@ def search_filter_sort_paginate( else: query = apply_filters(query, filter_spec, model_cls) + if include_keys: + query = apply_model_specific_joins(model_cls, include_keys, query) + if model == "Incident": query = query.intersect(query_restricted) for filter in tag_all_filters: diff --git a/tests/conftest.py b/tests/conftest.py index 9f36dc7b37fa..5ef624e5f258 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,7 @@ import pytest from sqlalchemy_utils import drop_database, database_exists from starlette.config import environ -from starlette.testclient import TestClient +from fastapi.testclient import TestClient # set test config environ["DATABASE_CREDENTIALS"] = "postgres:dispatch" @@ -19,6 +19,7 @@ from dispatch import config from dispatch.database.core import engine from dispatch.database.manage import init_database +from dispatch.enums import Visibility, UserRoles from .database import Session from .factories import ( @@ -32,6 +33,7 @@ ConversationFactory, DefinitionFactory, DispatchUserFactory, + DispatchUserOrganizationFactory, DocumentFactory, EmailTemplateFactory, EntityFactory, @@ -98,13 +100,13 @@ def pytest_runtest_makereport(item, call): @pytest.fixture(scope="session") def testapp(): # we only want to use test plugins so unregister everybody else - from dispatch.main import app + from dispatch.main import api from dispatch.plugins.base import plugins, unregister for p in plugins.all(): - unregister(p) + unregister(p.__class__) - yield app + yield api @pytest.fixture(scope="session") @@ -137,7 +139,7 @@ def session(db): @pytest.fixture(scope="function") -def client(testapp, session, client): +def client(testapp, session): yield TestClient(testapp) @@ -272,6 +274,18 @@ def user(session): return DispatchUserFactory() +@pytest.fixture +def admin_user(session): + # we need to create a new user with the admin role + user = DispatchUserFactory() + organization = OrganizationFactory() + DispatchUserOrganizationFactory( + dispatch_user=user, organization=organization, role=UserRoles.admin + ) + + return user + + @pytest.fixture def tag(session): return TagFactory() @@ -532,6 +546,24 @@ def incident(session): return IncidentFactory() +@pytest.fixture() +def incidents(session): + return [ + IncidentFactory( + title="Test Incident 1", + description="Description 1", + visibility=Visibility.open, + tags=[TagFactory()], + ), + IncidentFactory( + title="Test Incident 2", description="Description 2", visibility=Visibility.restricted + ), + IncidentFactory( + title="Another Incident", description="Description 3", visibility=Visibility.open + ), + ] + + @pytest.fixture def participant_activity(session): return ParticipantActivityFactory() diff --git a/tests/database/test_service.py b/tests/database/test_service.py new file mode 100644 index 000000000000..250dc0c951a0 --- /dev/null +++ b/tests/database/test_service.py @@ -0,0 +1,225 @@ +import pytest +import json +from json.decoder import JSONDecodeError +from sqlalchemy_filters.exceptions import BadFilterFormat + +from dispatch.database.service import ( + Operator, + Filter, + search_filter_sort_paginate, + restricted_incident_filter, + apply_filters, +) +from dispatch.incident.models import Incident +from dispatch.enums import UserRoles, Visibility + + +# Test the Filter class and related functions +def test_operator_invalid(): + """Tests that invalid operators raise BadFilterFormat.""" + with pytest.raises(BadFilterFormat): + Operator("invalid_operator") + + +def test_filter_missing_field(): + """Tests that missing field raises BadFilterFormat.""" + with pytest.raises(BadFilterFormat): + Filter({}) + + +def test_filter_invalid_spec(): + """Tests that invalid filter spec raises BadFilterFormat.""" + with pytest.raises(BadFilterFormat): + Filter(None) + + +# Test search_filter_sort_paginate +def test_search_filter_sort_paginate_basic(session, user): + """Tests basic functionality of search_filter_sort_paginate.""" + result = search_filter_sort_paginate( + db_session=session, model="Incident", current_user=user, role=UserRoles.member + ) + + assert isinstance(result, dict) + assert "items" in result + assert "itemsPerPage" in result + assert "page" in result + assert "total" in result + + +def test_basic_pagination(session, incidents, admin_user): + """Test basic pagination functionality.""" + result = search_filter_sort_paginate( + db_session=session, + model="Incident", + page=1, + items_per_page=2, + current_user=admin_user, + role=UserRoles.admin, + ) + + assert result["page"] == 1 + assert result["itemsPerPage"] == 2 + assert len(result["items"]) == 2 + + +def test_simple_filter_specification(session, incidents, admin_user): + """Test filtering with simple filter specification.""" + filter_spec = {"field": "visibility", "op": "==", "value": "open"} + + result = search_filter_sort_paginate( + db_session=session, + model="Incident", + filter_spec=json.dumps(filter_spec), + current_user=admin_user, + role=UserRoles.admin, + ) + + assert all(incident.visibility == Visibility.open for incident in result["items"]) + + +def test_sorting_functionality(session, incidents, user): + """Test sorting functionality.""" + result = search_filter_sort_paginate( + db_session=session, + model="Incident", + sort_by=["title"], + descending=[True], + current_user=user, + ) + + titles = [incident.title for incident in result["items"]] + assert titles == sorted(titles, reverse=True) + + +def test_unlimited_pagination(session, incidents, admin_user): + """Test pagination with unlimited items per page.""" + result = search_filter_sort_paginate( + db_session=session, + model="Incident", + items_per_page=-1, + current_user=admin_user, + role=UserRoles.admin, + ) + + assert len(result["items"]) == result["total"] # All items + + +def test_empty_query_string(session, incidents, admin_user): + """Test behavior with empty query string.""" + result = search_filter_sort_paginate( + db_session=session, + model="Incident", + query_str="", + current_user=admin_user, + role=UserRoles.admin, + ) + + assert len(result["items"]) > 0 # Should return all items + + +def test_invalid_filter_spec(session, incidents, user): + """Test behavior with invalid filter specification.""" + with pytest.raises(JSONDecodeError): # Adjust exception type as needed + search_filter_sort_paginate( + db_session=session, + model="Incident", + filter_spec="invalid_json", + current_user=user, + ) + + +def test_pagination_out_of_bounds(session, incidents, user): + """Test pagination when page number is out of bounds.""" + result = search_filter_sort_paginate( + db_session=session, model="Incident", page=999, items_per_page=5, current_user=user + ) + + assert len(result["items"]) == 0 + assert result["page"] == 999 + + +def test_role_based_filtering(session, incidents, user, admin_user): + """Test filtering based on user role.""" + # Test admin access + admin_result = search_filter_sort_paginate( + db_session=session, model="Incident", current_user=admin_user, role=UserRoles.admin + ) + + # Test member access + member_result = search_filter_sort_paginate( + db_session=session, model="Incident", current_user=user, role=UserRoles.member + ) + + assert len(admin_result["items"]) >= len(member_result["items"]) + + +def test_include_keys_functionality(session, case, admin_user): + """Test functionality of include_keys parameter.""" + from dispatch.common.utils.views import create_pydantic_include + from dispatch.case.models import CasePagination + + result = search_filter_sort_paginate( + db_session=session, + model="Case", + include_keys=["tags"], + current_user=admin_user, + role=UserRoles.admin, + ) + + # make sure they are renderable + include_sets = create_pydantic_include(["tags", "title"]) + + include_fields = { + "items": {"__all__": include_sets}, + "itemsPerPage": ..., + "page": ..., + "total": ..., + } + marshalled = json.loads(CasePagination(**result).json(include=include_fields)) + assert "tags" in marshalled["items"][0].keys() + + +# Test restricted filters +def test_restricted_incident_filter_member(session, user): + """Tests incident filtering for member role.""" + query = session.query(Incident) + filtered_query = restricted_incident_filter( + query=query, current_user=user, role=UserRoles.member + ) + + assert filtered_query is not None + + +def test_restricted_incident_filter_admin(session, user): + """Tests incident filtering for admin role.""" + query = session.query(Incident) + filtered_query = restricted_incident_filter( + query=query, current_user=user, role=UserRoles.admin + ) + + assert filtered_query is not None + + +# Test apply_filters +def test_apply_filters_basic(session): + """Tests basic filter application.""" + query = session.query(Incident) + filter_spec = {"field": "title", "op": "==", "value": "Test"} + + filtered_query = apply_filters(query, filter_spec) + assert filtered_query is not None + + +def test_apply_filters_complex(session): + """Tests complex filter application with boolean operations.""" + query = session.query(Incident) + filter_spec = { + "and": [ + {"field": "title", "op": "==", "value": "Test"}, + {"field": "visibility", "op": "==", "value": "open"}, + ] + } + + filtered_query = apply_filters(query, filter_spec) + assert filtered_query is not None diff --git a/tests/factories.py b/tests/factories.py index eb225343be60..cfc84a74c89f 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -15,7 +15,7 @@ from faker.providers import misc from pytz import UTC -from dispatch.auth.models import DispatchUser, hash_password # noqa +from dispatch.auth.models import DispatchUser, DispatchUserOrganization, hash_password # noqa from dispatch.case.models import Case, CaseRead from dispatch.case.priority.models import CasePriority from dispatch.case.severity.models import CaseSeverity @@ -64,6 +64,7 @@ from dispatch.term.models import Term from dispatch.ticket.models import Ticket from dispatch.workflow.models import Workflow, WorkflowInstance +from dispatch.enums import UserRoles, Visibility from .database import Session @@ -126,6 +127,19 @@ def projects(self, create, extracted, **kwargs): self.projects.append(project) +class DispatchUserOrganizationFactory(BaseFactory): + """Dispatch User Organization Factory.""" + + dispatch_user = SubFactory(DispatchUserFactory) + organization = SubFactory(OrganizationFactory) + role = UserRoles.member + + class Meta: + """Factory Configuration.""" + + model = DispatchUserOrganization + + class ProjectFactory(BaseFactory): """Project Factory.""" @@ -770,6 +784,15 @@ class Meta: model = Case + @post_generation + def tags(self, create, extracted, **kwargs): + if not create: + return + + if extracted: + for tag in extracted: + self.tags.append(tag) + class Params: status = "New" @@ -919,6 +942,7 @@ class IncidentFactory(BaseFactory): incident_severity = SubFactory(IncidentSeverityFactory) project = SubFactory(ProjectFactory) conversation = SubFactory(ConversationFactory) + visibility = Visibility.open class Meta: """Factory Configuration.""" @@ -934,6 +958,15 @@ def participants(self, create, extracted, **kwargs): for participant in extracted: self.participants.append(participant) + @post_generation + def tags(self, create, extracted, **kwargs): + if not create: + return + + if extracted: + for tag in extracted: + self.tags.append(tag) + class TaskFactory(ResourceBaseFactory): """Task Factory."""