From be74eb76197abc2230ba4c528e5bf6a3c574ce42 Mon Sep 17 00:00:00 2001 From: Avery Lee Date: Fri, 28 Jun 2024 15:17:56 -0700 Subject: [PATCH] Calculates case costs. --- src/dispatch/case/type/service.py | 14 + src/dispatch/case_cost/service.py | 416 ++++++++++++++++++ tests/case_cost/test_case_cost_service.py | 245 +++++++++++ tests/case_type/test_case_type_service.py | 6 +- tests/conftest.py | 5 + .../test_incident_cost_service.py | 5 +- 6 files changed, 686 insertions(+), 5 deletions(-) create mode 100644 src/dispatch/case_cost/service.py create mode 100644 tests/case_cost/test_case_cost_service.py diff --git a/src/dispatch/case/type/service.py b/src/dispatch/case/type/service.py index 401f8e865bdf..c4b78956893b 100644 --- a/src/dispatch/case/type/service.py +++ b/src/dispatch/case/type/service.py @@ -3,6 +3,8 @@ from sqlalchemy.sql.expression import true +from dispatch.case import service as case_service +from dispatch.case_cost import service as case_cost_service from dispatch.cost_model import service as cost_model_service from dispatch.document import service as document_service from dispatch.exceptions import NotFoundError @@ -165,8 +167,14 @@ def update(*, db_session, case_type: CaseType, case_type_in: CaseTypeUpdate) -> cost_model = cost_model_service.get_cost_model_by_id( db_session=db_session, cost_model_id=case_type_in.cost_model.id ) + should_update_case_cost = case_type.cost_model != cost_model case_type.cost_model = cost_model + # Calculate the cost of all non-closed cases associated with this case type + cases = case_service.get_all_open_by_case_type(db_session=db_session, case_type_id=case_type.id) + for case in cases: + case_cost_service.calculate_case_response_cost(case_id=case.id, db_session=db_session) + if case_type_in.case_template_document: case_template_document = document_service.get( db_session=db_session, document_id=case_type_in.case_template_document.id @@ -202,6 +210,12 @@ def update(*, db_session, case_type: CaseType, case_type_in: CaseTypeUpdate) -> setattr(case_type, field, update_data[field]) db_session.commit() + + if should_update_case_cost: + case_cost_service.update_case_response_cost_for_case_type( + db_session=db_session, case_type=case_type + ) + return case_type diff --git a/src/dispatch/case_cost/service.py b/src/dispatch/case_cost/service.py new file mode 100644 index 000000000000..b909fccaf1ef --- /dev/null +++ b/src/dispatch/case_cost/service.py @@ -0,0 +1,416 @@ +from datetime import datetime, timedelta, timezone +import logging +import math +from typing import List, Optional + +from dispatch.database.core import SessionLocal +from dispatch.cost_model.models import CostModelActivity +from dispatch.case import service as case_service +from dispatch.case.enums import CaseStatus +from dispatch.case.models import Case +from dispatch.case.type.models import CaseType +from dispatch.case_cost_type import service as case_cost_type_service +from dispatch.case_cost_type.models import CaseCostTypeRead +from dispatch.participant import service as participant_service +from dispatch.participant.models import ParticipantRead +from dispatch.participant_activity import service as participant_activity_service +from dispatch.participant_activity.models import ParticipantActivityCreate +from dispatch.participant_role.models import ParticipantRoleType, ParticipantRole +from dispatch.plugin import service as plugin_service + +from .models import CaseCost, CaseCostCreate, CaseCostUpdate + + +HOURS_IN_DAY = 24 +SECONDS_IN_HOUR = 3600 +log = logging.getLogger(__name__) + + +def get(*, db_session, case_cost_id: int) -> Optional[CaseCost]: + """Gets an case cost by its id.""" + return db_session.query(CaseCost).filter(CaseCost.id == case_cost_id).one_or_none() + + +def get_by_case_id(*, db_session, case_id: int) -> List[Optional[CaseCost]]: + """Gets case costs by their case id.""" + return db_session.query(CaseCost).filter(CaseCost.case_id == case_id).all() + + +def get_by_case_id_and_case_cost_type_id( + *, db_session, case_id: int, case_cost_type_id: int +) -> Optional[CaseCost]: + """Gets case costs by their case id and case cost type id.""" + return ( + db_session.query(CaseCost) + .filter(CaseCost.case_id == case_id) + .filter(CaseCost.case_cost_type_id == case_cost_type_id) + .one_or_none() + ) + + +def get_all(*, db_session) -> List[Optional[CaseCost]]: + """Gets all case costs.""" + return db_session.query(CaseCost) + + +def get_or_create(*, db_session, case_cost_in: CaseCostCreate | CaseCostUpdate) -> CaseCost: + """Gets or creates an case cost object.""" + if type(case_cost_in) is CaseCostUpdate and case_cost_in.id: + case_cost = get(db_session=db_session, case_cost_id=case_cost_in.id) + else: + case_cost = create(db_session=db_session, case_cost_in=case_cost_in) + + return case_cost + + +def create(*, db_session, case_cost_in: CaseCostCreate) -> CaseCost: + """Creates a new case cost.""" + case_cost_type = case_cost_type_service.get( + db_session=db_session, case_cost_type_id=case_cost_in.case_cost_type.id + ) + case_cost = CaseCost( + **case_cost_in.dict(exclude={"case_cost_type", "project"}), + case_cost_type=case_cost_type, + project=case_cost_type.project, + ) + db_session.add(case_cost) + db_session.commit() + + return case_cost + + +def update(*, db_session, case_cost: CaseCost, case_cost_in: CaseCostUpdate) -> CaseCost: + """Updates an case cost.""" + case_cost_data = case_cost.dict() + update_data = case_cost_in.dict(skip_defaults=True) + + for field in case_cost_data: + if field in update_data: + setattr(case_cost, field, update_data[field]) + + db_session.commit() + return case_cost + + +def delete(*, db_session, case_cost_id: int): + """Deletes an existing case cost.""" + db_session.query(CaseCost).filter(CaseCost.id == case_cost_id).delete() + db_session.commit() + + +def get_hourly_rate(project) -> int: + """Calculates and rounds up the employee hourly rate within a project.""" + return math.ceil(project.annual_employee_cost / project.business_year_hours) + + +def update_case_response_cost_for_case_type(db_session, case_type: CaseType) -> None: + """Calculate the response cost of all non-closed cases associated with this case type.""" + cases = case_service.get_all_open_by_case_type(db_session=db_session, case_type_id=case_type.id) + for case in cases: + update_case_response_cost(case_id=case.id, db_session=db_session) + + +def calculate_response_cost(hourly_rate, total_response_time_seconds) -> int: + """Calculates and rounds up the case response cost.""" + return math.ceil(((total_response_time_seconds / SECONDS_IN_HOUR)) * hourly_rate) + + +def get_default_case_response_cost(case: Case, db_session: SessionLocal) -> Optional[CaseCost]: + response_cost_type = case_cost_type_service.get_default( + db_session=db_session, project_id=case.project.id + ) + + if not response_cost_type: + log.warning( + f"A default cost type for response cost doesn't exist in the {case.project.name} project and organization {case.project.organization.name}. Response costs for case {case.name} won't be calculated." + ) + return None + + return get_by_case_id_and_case_cost_type_id( + db_session=db_session, + case_id=case.id, + case_cost_type_id=response_cost_type.id, + ) + + +def get_or_create_default_case_response_cost( + case: Case, db_session: SessionLocal +) -> Optional[CaseCost]: + """Gets or creates the default case cost for an case. + + The default case cost is the cost associated with the participant effort in an case's response. + """ + response_cost_type = case_cost_type_service.get_default( + db_session=db_session, project_id=case.project.id + ) + + if not response_cost_type: + log.warning( + f"A default cost type for response cost doesn't exist in the {case.project.name} project and organization {case.project.organization.name}. Response costs for case {case.name} won't be calculated." + ) + return None + + case_response_cost = get_by_case_id_and_case_cost_type_id( + db_session=db_session, + case_id=case.id, + case_cost_type_id=response_cost_type.id, + ) + + if not case_response_cost: + # we create the response cost if it doesn't exist + case_cost_type = CaseCostTypeRead.from_orm(response_cost_type) + case_cost_in = CaseCostCreate(case_cost_type=case_cost_type, project=case.project) + case_response_cost = create(db_session=db_session, case_cost_in=case_cost_in) + case.case_costs.append(case_response_cost) + db_session.add(case) + db_session.commit() + + return case_response_cost + + +def fetch_case_events( + case: Case, activity: CostModelActivity, oldest: str, db_session: SessionLocal +) -> List[Optional[tuple[datetime.timestamp, str]]]: + plugin_instance = plugin_service.get_active_instance_by_slug( + db_session=db_session, + slug=activity.plugin_event.plugin.slug, + project_id=case.project.id, + ) + if not plugin_instance: + log.warning( + f"Cannot fetch cost model activity. Its associated plugin {activity.plugin_event.plugin.title} is not enabled." + ) + return [] + + # Array of sorted (timestamp, user_id) tuples. + return plugin_instance.instance.fetch_events( + db_session=db_session, + subject=case, + plugin_event_id=activity.plugin_event.id, + oldest=oldest, + ) + + +def calculate_case_response_cost_with_cost_model(case: Case, db_session: SessionLocal) -> int: + """Calculates the cost of an case using the case's cost model. + + This function aggregates all new case costs based on plugin activity since the last case cost update. + If this is the first time performing cost calculation for this case, it computes the total costs from the case's creation. + + Args: + case: The case to calculate the case response cost for. + db_session: The database session. + + Returns: + int: The case response cost in dollars. + """ + + participants_total_response_time_seconds = 0 + oldest = case.created_at.replace(tzinfo=timezone.utc).timestamp() + + # Used for determining whether we've previously calculated the case cost. + current_time = datetime.now(tz=timezone.utc).replace(tzinfo=None) + + case_response_cost = get_or_create_default_case_response_cost(case=case, db_session=db_session) + if not case_response_cost: + log.warning(f"Cannot calculate case response cost for case {case.name}.") + return 0 + + # Ignore events that happened before the last case cost update. + if case_response_cost.updated_at < current_time: + oldest = case_response_cost.updated_at.replace(tzinfo=timezone.utc).timestamp() + + if case.case_type.cost_model: + # Get the cost model. Iterate through all the listed activities we want to record. + for activity in case.case_type.cost_model.activities: + + # Array of sorted (timestamp, user_id) tuples. + case_events = fetch_case_events( + case=case, activity=activity, oldest=oldest, db_session=db_session + ) + + for ts, user_id in case_events: + participant = participant_service.get_by_case_id_and_conversation_id( + db_session=db_session, + case_id=case.id, + user_conversation_id=user_id, + ) + if not participant: + log.warning("Cannot resolve participant.") + continue + + activity_in = ParticipantActivityCreate( + plugin_event=activity.plugin_event, + started_at=ts, + ended_at=ts + timedelta(seconds=activity.response_time_seconds), + participant=ParticipantRead(id=participant.id), + case=case, + ) + + if participant_response_time := participant_activity_service.create_or_update( + db_session=db_session, activity_in=activity_in + ): + participants_total_response_time_seconds += ( + participant_response_time.total_seconds() + ) + + hourly_rate = get_hourly_rate(case.project) + amount = calculate_response_cost( + hourly_rate=hourly_rate, + total_response_time_seconds=participants_total_response_time_seconds, + ) + + return case.total_cost + amount + + +def get_participant_role_time_seconds( + case: Case, participant_role: ParticipantRole, start_at: datetime +) -> int: + """Calculates the time spent by a participant in an case role starting from a given time. + + Args: + case: The case the participant is part of. + participant_role: The role of the participant and the time they assumed and renounced the role. + start_at: Only time spent after this will be considered. + + Returns: + int: The time spent by the participant in the case role in seconds. + """ + if participant_role.renounced_at and participant_role.renounced_at < start_at: + # skip calculating already-recorded activity + return 0 + + if participant_role.role == ParticipantRoleType.observer: + # skip calculating cost for participants with the observer role + return 0 + + if participant_role.activity == 0: + # skip calculating cost for roles that have no activity + return 0 + + participant_role_assumed_at = participant_role.assumed_at + + # we set the renounced_at default time to the current time + participant_role_renounced_at = datetime.now(tz=timezone.utc).replace(tzinfo=None) + + if case.status in [CaseStatus.new, CaseStatus.triage]: + if participant_role.renounced_at: + # the participant left the conversation or got assigned another role + # we use the role's renounced_at time + participant_role_renounced_at = participant_role.renounced_at + else: + # we set the renounced_at default time to when the case was marked as escalated or closed + if case.escalated_at: + participant_role_renounced_at = case.escalated_at + + if case.closed_at: + participant_role_renounced_at = case.closed_at + + if participant_role.renounced_at: + # the participant left the conversation or got assigned another role + if participant_role.renounced_at < participant_role_renounced_at: + # we use the role's renounced_at time if it happened before the + # case was marked as stable or closed + participant_role_renounced_at = participant_role.renounced_at + + # the time the participant has spent in the case role since the last case cost update + participant_role_time = participant_role_renounced_at - max( + participant_role_assumed_at, start_at + ) + if participant_role_time.total_seconds() < 0: + # the participant was added after the case was marked as stable + return 0 + + # we calculate the number of hours the participant has spent in the case role + participant_role_time_hours = participant_role_time.total_seconds() / SECONDS_IN_HOUR + + # we make the assumption that participants only spend 8 hours a day working on the case, + # if the case goes past 24hrs + # TODO(mvilanova): adjust based on case priority + if participant_role_time_hours > HOURS_IN_DAY: + days, hours = divmod(participant_role_time_hours, HOURS_IN_DAY) + participant_role_time_hours = math.ceil(((days * HOURS_IN_DAY) / 3) + hours) + + # we make the assumption that participants spend more or less time based on their role + # and we adjust the time spent based on that + return participant_role_time_hours * SECONDS_IN_HOUR + + +def get_total_participant_roles_time_seconds(case: Case, start_at: datetime) -> int: + """Calculates the time spent by all participants in this case starting from a given time. + + Args: + case: The case the participant is part of. + participant_role: The role of the participant and the time they assumed and renounced the role. + start_at: Only time spent after this will be considered. + + Returns: + int: The total time spent by all participants in the case roles in seconds. + + """ + total_participants_roles_time_seconds = 0 + for participant in case.participants: + for participant_role in participant.participant_roles: + total_participants_roles_time_seconds += get_participant_role_time_seconds( + case=case, + participant_role=participant_role, + start_at=start_at, + ) + return total_participants_roles_time_seconds + + +def calculate_case_response_cost(case_id: int, db_session: SessionLocal) -> int: + """Calculates the response cost of a given case. + + If there is no cost model, the case cost will not be calculated. + """ + case = case_service.get(db_session=db_session, case_id=case_id) + if not case: + log.warning(f"Case with id {case_id} not found.") + return 0 + + case_type = case.case_type + if not case_type: + print(f"Case type for case {case.name} not found.") + return case.total_cost + + if not case_type.cost_model: + log.debug("No case cost model found. Skipping this case.") + return case.total_cost + + if not case_type.cost_model.enabled: + log.debug("Case cost model is not enabled. Skipping this case.") + return case.total_cost + + log.debug(f"Calculating {case.name} case cost with model {case_type.cost_model}.") + return calculate_case_response_cost_with_cost_model(case=case, db_session=db_session) + + +def update_case_response_cost(case_id: int, db_session: SessionLocal) -> int: + """Updates the response cost of a given case. + + Args: + case_id: The case id. + db_session: The database session. + + Returns: + int: The case response cost in dollars. + """ + case = case_service.get(db_session=db_session, case_id=case_id) + + amount = calculate_case_response_cost(case_id=case.id, db_session=db_session) + + case_response_cost = get_default_case_response_cost(case=case, db_session=db_session) + + if not case_response_cost: + log.warning(f"Cannot calculate case response cost for case {case.name}.") + return 0 + + # we update the cost amount only if the case cost has changed + if case_response_cost.amount != amount: + case_response_cost.amount = amount + case.case_costs.append(case_response_cost) + db_session.add(case) + db_session.commit() + + return case_response_cost.amount diff --git a/tests/case_cost/test_case_cost_service.py b/tests/case_cost/test_case_cost_service.py new file mode 100644 index 000000000000..4ac017e71b56 --- /dev/null +++ b/tests/case_cost/test_case_cost_service.py @@ -0,0 +1,245 @@ +def test_get(session, case_cost): + from dispatch.case_cost.service import get + + t_case_cost = get(db_session=session, case_cost_id=case_cost.id) + assert t_case_cost.id == case_cost.id + + +def test_get_by_case_id(session, case_cost): + from dispatch.case_cost.service import get_by_case_id + + assert get_by_case_id(db_session=session, case_id=case_cost.case_id) + + +def test_get_all(session, case_costs): + from dispatch.case_cost.service import get_all + + t_case_costs = get_all(db_session=session).all() + assert t_case_costs + + +def test_create(session, case_cost_type, project): + from dispatch.case_cost.service import create + from dispatch.case_cost.models import CaseCostCreate + + amount = 10000 + + case_cost_in = CaseCostCreate( + amount=amount, + case_cost_type=case_cost_type, + project=project, + ) + case_cost = create(db_session=session, case_cost_in=case_cost_in) + assert case_cost + + +def test_get_or_create__create(session, case_cost_type, project): + from dispatch.case_cost.service import get_or_create + from dispatch.case_cost.models import CaseCostCreate + + amount = 10000 + + case_cost_in = CaseCostCreate( + amount=amount, + case_cost_type=case_cost_type, + project=project, + ) + case_cost = get_or_create(db_session=session, case_cost_in=case_cost_in) + assert case_cost + + +def test_update(session, case_cost, case_cost_type): + from dispatch.case_cost.service import update + from dispatch.case_cost.models import CaseCostUpdate + + amount = 10001 + + case_cost_in = CaseCostUpdate(amount=amount, case_cost_type=case_cost_type) + case_cost = update(db_session=session, case_cost=case_cost, case_cost_in=case_cost_in) + assert case_cost.amount == amount + + +def test_delete(session, case_cost): + from dispatch.case_cost.service import delete, get + + delete(db_session=session, case_cost_id=case_cost.id) + assert not get(db_session=session, case_cost_id=case_cost.id) + + +def test_fetch_case_event__enabled_plugins( + case, + cost_model_activity, + session, + conversation_plugin_instance, +): + from dispatch.case_cost.service import fetch_case_events + + conversation_plugin_instance.project_id = case.project.id + cost_model_activity.plugin_event.plugin = conversation_plugin_instance.plugin + conversation_plugin_instance.enabled = True + + assert fetch_case_events(case, cost_model_activity, oldest="0", db_session=session) + + +def test_fetch_case_event__no_enabled_plugins( + case, + cost_model_activity, + session, + conversation_plugin_instance, +): + from dispatch.case_cost.service import fetch_case_events + + conversation_plugin_instance.project_id = case.project.id + cost_model_activity.plugin_event.plugin = conversation_plugin_instance.plugin + conversation_plugin_instance.enabled = False + + assert not fetch_case_events(case, cost_model_activity, oldest="0", db_session=session) + + +def test_calculate_case_response_cost( + session, + case, + case_cost_type, + cost_model_activity, + conversation_plugin_instance, + conversation, + participant, +): + """Tests that the case cost is calculated correctly when a cost model is enabled.""" + from datetime import timedelta + import math + from dispatch.case_cost.service import update_case_response_cost, get_hourly_rate + from dispatch.case_cost_type import service as case_cost_type_service + from dispatch.participant_activity.service import ( + get_all_case_participant_activities_for_case, + ) + + SECONDS_IN_HOUR = 3600 + orig_total_case_cost = case.total_cost + + # Set incoming plugin events. + conversation_plugin_instance.project_id = case.project.id + cost_model_activity.plugin_event.plugin = conversation_plugin_instance.plugin + participant.user_conversation_id = "0XDECAFBAD" + participant.case = case + + # Set up a default case costs type. + for cost_type in case_cost_type_service.get_all(db_session=session): + cost_type.default = False + case_cost_type.default = True + case_cost_type.project = case.project + + # Set up case. + case.case_type.cost_model.enabled = True + case.case_type.cost_model.activities = [cost_model_activity] + + case.conversation = conversation + case.dedicated_channel = True + + # Calculates and updates the case cost. + cost = update_case_response_cost(case_id=case.id, db_session=session) + activities = get_all_case_participant_activities_for_case(db_session=session, case_id=case.id) + assert activities + + # Evaluate expected case cost. + participants_total_response_time_seconds = timedelta(seconds=0) + for activity in activities: + participants_total_response_time_seconds += activity.ended_at - activity.started_at + hourly_rate = get_hourly_rate(case.project) + expected_case_cost = ( + math.ceil( + (participants_total_response_time_seconds.seconds / SECONDS_IN_HOUR) * hourly_rate + ) + + orig_total_case_cost + ) + + assert cost + assert cost == expected_case_cost + assert cost == case.total_cost + + +def test_calculate_case_response_cost__no_enabled_plugins( + session, + case, + case_cost_type, + cost_model_activity, + plugin_instance, + conversation, + participant, +): + """Tests that the case cost is calculated correctly when a cost model is enabled.""" + from dispatch.case.service import get + from dispatch.case_cost.service import update_case_response_cost + from dispatch.case_cost_type import service as case_cost_type_service + from dispatch.participant_activity.service import ( + get_all_case_participant_activities_for_case, + ) + + # Disable the plugin instance for the cost model plugin event. + plugin_instance.project_id = case.project.id + plugin_instance.enabled = False + cost_model_activity.plugin_event.plugin = plugin_instance.plugin + participant.user_conversation_id = "0XDECAFBAD" + participant.case = case + + # Set up a default case costs type. + for cost_type in case_cost_type_service.get_all(db_session=session): + cost_type.default = False + case_cost_type.default = True + case_cost_type.project = case.project + + # Set up case. + case = get(db_session=session, case_id=case.id) + case.case_type.cost_model.enabled = True + case.case_type.cost_model.activities = [cost_model_activity] + case.conversation = conversation + + # Calculates and updates the case cost. + cost = update_case_response_cost(case_id=case.id, db_session=session) + activities = get_all_case_participant_activities_for_case(db_session=session, case_id=case.id) + assert not activities + assert not cost + assert not case.total_cost + + +def test_update_case_response_cost__no_cost_model(case, session, case_cost_type): + """Tests that the case response cost is not created if the case type has no cost model.""" + from dispatch.case import service as case_service + from dispatch.case_cost.service import update_case_response_cost + from dispatch.case_cost_type import service as case_cost_type_service + + # Set up a default case costs type. + for cost_type in case_cost_type_service.get_all(db_session=session): + cost_type.default = False + case_cost_type.default = True + case_cost_type.project = case.project + + case = case_service.get(db_session=session, case_id=case.id) + case.case_type.cost_model = None + + # The case response cost should not be created without a cost model. + case_response_cost_amount = update_case_response_cost(case_id=case.id, db_session=session) + + assert not case_response_cost_amount + + +def test_update_case_response_cost__fail(case, session): + """Tests that the case response cost is not created if the project has no default cost_type.""" + from dispatch.case import service as case_service + from dispatch.case_cost.service import ( + update_case_response_cost, + get_by_case_id, + ) + from dispatch.case_cost_type import service as case_cost_type_service + + case = case_service.get(db_session=session, case_id=case.id) + + # Ensure there is no default cost type for this project. + for cost_type in case_cost_type_service.get_all(db_session=session): + cost_type.default = False + + # Fail to create the inital case response cost. + assert not update_case_response_cost(case_id=case.id, db_session=session) + + # Validate that the case cost was not created nor saved in the database. + assert not get_by_case_id(db_session=session, case_id=case.id) diff --git a/tests/case_type/test_case_type_service.py b/tests/case_type/test_case_type_service.py index 8c9d519eb894..a8eff1c8af45 100644 --- a/tests/case_type/test_case_type_service.py +++ b/tests/case_type/test_case_type_service.py @@ -52,7 +52,7 @@ def test_update(session, case_type): assert case_type.name == name -def test_update_cost_model(session, case, case_type, cost_model, case_cost_type): +def test_update_cost_model(session, case, case_type, cost_model, case_cost, case_cost_type): """Updating the cost model field should immediately update the case cost of all cases with this case type.""" from dispatch.case.models import CaseStatus from dispatch.case.type.service import update @@ -77,6 +77,10 @@ def test_update_cost_model(session, case, case_type, cost_model, case_cost_type) case_cost_type.project = case_type.project case_cost_type.default = True + case_cost.project = case_type.project + case_cost.case_id = case.id + case_cost.case_cost_type = case_cost_type + cost_model.project = case_type.project case_type_in.cost_model = cost_model diff --git a/tests/conftest.py b/tests/conftest.py index 82f1cc596aa8..621cf45ce3b3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -520,6 +520,11 @@ def case_type(session): return CaseTypeFactory() +@pytest.fixture +def case_types(session): + return [CaseTypeFactory(), CaseTypeFactory()] + + @pytest.fixture def incident(session): return IncidentFactory() diff --git a/tests/incident_cost/test_incident_cost_service.py b/tests/incident_cost/test_incident_cost_service.py index aef8f2afb562..046dd94c1370 100644 --- a/tests/incident_cost/test_incident_cost_service.py +++ b/tests/incident_cost/test_incident_cost_service.py @@ -110,13 +110,11 @@ def test_calculate_incident_response_cost_with_cost_model( """Tests that the incident cost is calculated correctly when a cost model is enabled.""" from datetime import timedelta import math - from dispatch.incident.service import get from dispatch.incident_cost.service import update_incident_response_cost, get_hourly_rate from dispatch.incident_cost_type import service as incident_cost_type_service from dispatch.participant_activity.service import ( get_all_incident_participant_activities_for_incident, ) - from dispatch.plugins.dispatch_slack.events import ChannelActivityEvent SECONDS_IN_HOUR = 3600 orig_total_incident_cost = incident.total_cost @@ -134,10 +132,9 @@ def test_calculate_incident_response_cost_with_cost_model( incident_cost_type.project = incident.project # Set up incident. - incident = get(db_session=session, incident_id=incident.id) - cost_model_activity.plugin_event.slug = ChannelActivityEvent.slug incident.incident_type.cost_model.enabled = True incident.incident_type.cost_model.activities = [cost_model_activity] + incident.conversation = conversation # Calculates and updates the incident cost.