From 418fdad1303f40e44729423395662721f108acc3 Mon Sep 17 00:00:00 2001 From: ziv Date: Sun, 28 Apr 2024 15:29:37 +0300 Subject: [PATCH] Adding location accuracy filter. Add resolution parameter to get_accidents_stats() to add the location accuracy filter. In widgets that do not use it, handle the filter dicts that is passed to get_quary, or the filter directly. --- .../accident_count_by_accident_type_widget.py | 6 +- .../accident_count_by_accident_year_widget.py | 1 + .../accident_count_by_day_night_widget.py | 1 + .../accident_count_by_severity_widget.py | 12 +- .../accidents_heat_map_widget.py | 11 +- .../injured_count_by_accident_year_widget.py | 1 + .../injured_count_by_severity_widget.py | 2 +- ...ured_count_per_age_group_stacked_widget.py | 4 +- ...njured_count_per_age_group_widget_utils.py | 34 +++--- .../most_severe_accidents_table_widget.py | 5 +- .../vision_zero_bike_widget.py | 2 +- .../accident_count_by_car_type_widget.py | 3 + .../accident_count_by_driver_type_widget.py | 3 +- .../accident_count_by_hour_widget.py | 1 + .../accident_count_by_road_light_widget.py | 1 + ...ype_vehicle_type_road_comparison_widget.py | 19 ++- .../fatal_accident_yoy_same_month.py | 1 + .../front_to_side_accidents_by_severity.py | 11 +- .../head_on_collisions_comparison_widget.py | 3 +- ...cycle_accidents_vs_all_accidents_widget.py | 27 +++-- .../road2_plus1_widget.py | 1 + .../suburban_crosswalk_widget.py | 8 +- ...p_road_segments_accidents_per_km_widget.py | 6 +- ...jured_accidents_with_pedestrians_widget.py | 16 ++- ...e_fatal_count_by_vehicle_by_year_widget.py | 6 +- ...motor_severe_fatal_count_by_year_widget.py | 4 +- .../urban_widgets/urban_crosswalk_widget.py | 5 +- anyway/widgets/widget_utils.py | 109 +++++++++++------ tests/test_infographic_api.py | 9 +- tests/test_infographics_utils.py | 112 +++++++++++++++++- tests/test_request_params.py | 22 ++-- 31 files changed, 333 insertions(+), 113 deletions(-) diff --git a/anyway/widgets/all_locations_widgets/accident_count_by_accident_type_widget.py b/anyway/widgets/all_locations_widgets/accident_count_by_accident_type_widget.py index 2ddefca59..89e85d749 100644 --- a/anyway/widgets/all_locations_widgets/accident_count_by_accident_type_widget.py +++ b/anyway/widgets/all_locations_widgets/accident_count_by_accident_type_widget.py @@ -25,14 +25,15 @@ def __init__(self, request_params: RequestParams): def generate_items(self) -> None: # noinspection PyUnresolvedReferences - self.items = AccidentCountByAccidentTypeWidget.get_accident_count_by_accident_type( + self.items = self.get_accident_count_by_accident_type( location_info=self.request_params.location_info, start_time=self.request_params.start_time, end_time=self.request_params.end_time, + resolution=self.request_params.resolution ) @staticmethod - def get_accident_count_by_accident_type(location_info, start_time, end_time): + def get_accident_count_by_accident_type(location_info, start_time, end_time, resolution): all_accident_type_count = get_accidents_stats( table_obj=AccidentMarkerView, filters=location_info, @@ -40,6 +41,7 @@ def get_accident_count_by_accident_type(location_info, start_time, end_time): count="accident_type", start_time=start_time, end_time=end_time, + resolution=resolution ) merged_accident_type_count = [{"accident_type": "Collision", "count": 0}] for item in all_accident_type_count: diff --git a/anyway/widgets/all_locations_widgets/accident_count_by_accident_year_widget.py b/anyway/widgets/all_locations_widgets/accident_count_by_accident_year_widget.py index 8846f8139..9bc229a0d 100644 --- a/anyway/widgets/all_locations_widgets/accident_count_by_accident_year_widget.py +++ b/anyway/widgets/all_locations_widgets/accident_count_by_accident_year_widget.py @@ -37,6 +37,7 @@ def generate_items(self) -> None: count="accident_severity", start_time=self.request_params.start_time, end_time=self.request_params.end_time, + resolution=self.request_params.resolution, ) res2 = sort_and_fill_gaps_for_stacked_bar( res1, diff --git a/anyway/widgets/all_locations_widgets/accident_count_by_day_night_widget.py b/anyway/widgets/all_locations_widgets/accident_count_by_day_night_widget.py index ecbe03ba8..74682a8d7 100644 --- a/anyway/widgets/all_locations_widgets/accident_count_by_day_night_widget.py +++ b/anyway/widgets/all_locations_widgets/accident_count_by_day_night_widget.py @@ -28,6 +28,7 @@ def generate_items(self) -> None: count="day_night_hebrew", start_time=self.request_params.start_time, end_time=self.request_params.end_time, + resolution=self.request_params.resolution, ) @staticmethod diff --git a/anyway/widgets/all_locations_widgets/accident_count_by_severity_widget.py b/anyway/widgets/all_locations_widgets/accident_count_by_severity_widget.py index dd73e8758..de7846c2e 100644 --- a/anyway/widgets/all_locations_widgets/accident_count_by_severity_widget.py +++ b/anyway/widgets/all_locations_widgets/accident_count_by_severity_widget.py @@ -19,14 +19,15 @@ def __init__(self, request_params: RequestParams): self.rank = 1 def generate_items(self) -> None: - self.items = AccidentCountBySeverityWidget.get_accident_count_by_severity( - self.request_params.location_info, - self.request_params.start_time, - self.request_params.end_time, + self.items = self.get_accident_count_by_severity( + location_info=self.request_params.location_info, + start_time=self.request_params.start_time, + end_time=self.request_params.end_time, + resolution=self.request_params.resolution, ) @staticmethod - def get_accident_count_by_severity(location_info, start_time, end_time): + def get_accident_count_by_severity(location_info, start_time, end_time, resolution): count_by_severity = get_accidents_stats( table_obj=AccidentMarkerView, filters=location_info, @@ -34,6 +35,7 @@ def get_accident_count_by_severity(location_info, start_time, end_time): count="accident_severity", start_time=start_time, end_time=end_time, + resolution=resolution ) found_severities = [d["accident_severity"] for d in count_by_severity] items = {} diff --git a/anyway/widgets/all_locations_widgets/accidents_heat_map_widget.py b/anyway/widgets/all_locations_widgets/accidents_heat_map_widget.py index 7e7a7f0c0..6d9f5a058 100644 --- a/anyway/widgets/all_locations_widgets/accidents_heat_map_widget.py +++ b/anyway/widgets/all_locations_widgets/accidents_heat_map_widget.py @@ -5,7 +5,9 @@ from anyway.request_params import RequestParams from anyway.backend_constants import AccidentSeverity, BE_CONST -from anyway.widgets.widget_utils import get_query, get_location_text +from anyway.widgets.widget_utils import ( + get_query, get_location_text, add_resolution_location_accuracy_filter +) from anyway.models import AccidentMarkerView from anyway.widgets.widget import register from anyway.widgets.all_locations_widgets.all_locations_widget import AllLocationsWidget @@ -21,14 +23,17 @@ def __init__(self, request_params: RequestParams): self.rank = 7 def generate_items(self) -> None: - accidents_heat_map_filters = self.request_params.location_info.copy() + accidents_heat_map_filters = add_resolution_location_accuracy_filter( + self.request_params.location_info.copy(), + self.request_params.resolution + ) accidents_heat_map_filters["accident_severity"] = [ # pylint: disable=no-member AccidentSeverity.FATAL.value, # pylint: disable=no-member AccidentSeverity.SEVERE.value, ] - self.items = AccidentsHeatMapWidget.get_accidents_heat_map( + self.items = self.get_accidents_heat_map( filters=accidents_heat_map_filters, start_time=self.request_params.start_time, end_time=self.request_params.end_time, diff --git a/anyway/widgets/all_locations_widgets/injured_count_by_accident_year_widget.py b/anyway/widgets/all_locations_widgets/injured_count_by_accident_year_widget.py index 8307559fe..16764a5da 100644 --- a/anyway/widgets/all_locations_widgets/injured_count_by_accident_year_widget.py +++ b/anyway/widgets/all_locations_widgets/injured_count_by_accident_year_widget.py @@ -37,6 +37,7 @@ def generate_items(self) -> None: count="injury_severity", start_time=self.request_params.start_time, end_time=self.request_params.end_time, + resolution=self.request_params.resolution, ) res2 = sort_and_fill_gaps_for_stacked_bar( res1, diff --git a/anyway/widgets/all_locations_widgets/injured_count_by_severity_widget.py b/anyway/widgets/all_locations_widgets/injured_count_by_severity_widget.py index 830615770..96bfedf86 100644 --- a/anyway/widgets/all_locations_widgets/injured_count_by_severity_widget.py +++ b/anyway/widgets/all_locations_widgets/injured_count_by_severity_widget.py @@ -46,7 +46,6 @@ def get_injured_count_by_severity( InjurySeverity.SEVERE_INJURED.value, InjurySeverity.LIGHT_INJURED.value, ] - count_by_severity = get_accidents_stats( table_obj=InvolvedMarkerView, filters=filters, @@ -54,6 +53,7 @@ def get_injured_count_by_severity( count="injury_severity", start_time=start_time, end_time=end_time, + resolution=resolution, ) found_severities = [d["injury_severity"] for d in count_by_severity] items = {} diff --git a/anyway/widgets/all_locations_widgets/killed_and_injured_count_per_age_group_stacked_widget.py b/anyway/widgets/all_locations_widgets/killed_and_injured_count_per_age_group_stacked_widget.py index 7ca1716cc..aa3090d36 100644 --- a/anyway/widgets/all_locations_widgets/killed_and_injured_count_per_age_group_stacked_widget.py +++ b/anyway/widgets/all_locations_widgets/killed_and_injured_count_per_age_group_stacked_widget.py @@ -1,7 +1,6 @@ from typing import Dict, List - +# noinspection PyProtectedMember from flask_babel import _ - from anyway.backend_constants import InjurySeverity, BE_CONST as BE from anyway.request_params import RequestParams from anyway.widgets.all_locations_widgets.killed_and_injured_count_per_age_group_widget_utils import ( @@ -9,7 +8,6 @@ AGE_RANGE_DICT, ) from anyway.widgets.all_locations_widgets import killed_and_injured_count_per_age_group_widget_utils - from anyway.widgets.all_locations_widgets.all_locations_widget import AllLocationsWidget from anyway.widgets.widget import register from anyway.widgets.widget_utils import ( diff --git a/anyway/widgets/all_locations_widgets/killed_and_injured_count_per_age_group_widget_utils.py b/anyway/widgets/all_locations_widgets/killed_and_injured_count_per_age_group_widget_utils.py index 811d292ac..a92dc562a 100644 --- a/anyway/widgets/all_locations_widgets/killed_and_injured_count_per_age_group_widget_utils.py +++ b/anyway/widgets/all_locations_widgets/killed_and_injured_count_per_age_group_widget_utils.py @@ -1,3 +1,4 @@ +import copy from collections import defaultdict, OrderedDict from typing import Dict, Tuple, Callable @@ -10,7 +11,10 @@ from anyway.models import InvolvedMarkerView from anyway.request_params import RequestParams from anyway.utilities import parse_age_from_range -from anyway.widgets.widget_utils import get_expression_for_road_segment_location_fields +from anyway.widgets.widget_utils import ( + get_expression_for_fields, + add_resolution_location_accuracy_filter, +) # RequestParams is not hashable, so we can't use functools.lru_cache cache_dict = OrderedDict() @@ -105,20 +109,10 @@ def defaultdict_int_factory() -> Callable: def create_query_for_killed_and_injured_count_per_age_group( end_time, start_time, location_info, resolution ) -> BaseQuery: - if resolution == BE_CONST.ResolutionCategories.SUBURBAN_ROAD: - location_filter = get_expression_for_road_segment_location_fields( - {"road_segment_id": location_info["road_segment_id"]}, InvolvedMarkerView - ) - # (InvolvedMarkerView.road1 == location_info["road1"]) - # | (InvolvedMarkerView.road2 == location_info["road1"]) - # ) & (InvolvedMarkerView.road_segment_name == location_info["road_segment_name"]) - elif resolution == BE_CONST.ResolutionCategories.STREET: - location_filter = ( - InvolvedMarkerView.involve_yishuv_name == location_info["yishuv_name"] - ) & ( - (InvolvedMarkerView.street1_hebrew == location_info["street1_hebrew"]) - | (InvolvedMarkerView.street2_hebrew == location_info["street1_hebrew"]) - ) + loc_filter = adapt_location_fields_to_involve_table(location_info) + loc_filter = add_resolution_location_accuracy_filter(loc_filter, + resolution) + loc_ex = get_expression_for_fields(loc_filter, InvolvedMarkerView) query = ( db.session.query(InvolvedMarkerView) @@ -138,7 +132,7 @@ def create_query_for_killed_and_injured_count_per_age_group( ] ) ) - .filter(location_filter) + .filter(loc_ex) .group_by(InvolvedMarkerView.age_group, InvolvedMarkerView.injury_severity) .with_entities( InvolvedMarkerView.age_group, @@ -148,3 +142,11 @@ def create_query_for_killed_and_injured_count_per_age_group( .order_by(asc(InvolvedMarkerView.age_group)) ) return query + + +def adapt_location_fields_to_involve_table(filter: dict) -> dict: + res = copy.copy(filter) + for field in ["yishuv_name", "yishuv_symbol"]: + if field in res: + res[f"involve_{field}"] = res.pop(field) + return res diff --git a/anyway/widgets/all_locations_widgets/most_severe_accidents_table_widget.py b/anyway/widgets/all_locations_widgets/most_severe_accidents_table_widget.py index 2fd9369aa..06e71cf47 100644 --- a/anyway/widgets/all_locations_widgets/most_severe_accidents_table_widget.py +++ b/anyway/widgets/all_locations_widgets/most_severe_accidents_table_widget.py @@ -6,7 +6,9 @@ from anyway.request_params import RequestParams from anyway.backend_constants import BE_CONST, AccidentSeverity, AccidentType, InjurySeverity from anyway.infographics_dictionaries import segment_dictionary -from anyway.widgets.widget_utils import get_query, get_accidents_stats +from anyway.widgets.widget_utils import ( + get_query, get_accidents_stats, add_resolution_location_accuracy_filter, +) from anyway.models import AccidentMarkerView, InvolvedMarkerView from anyway.widgets.all_locations_widgets.all_locations_widget import AllLocationsWidget from anyway.widgets.widget import register @@ -28,6 +30,7 @@ def get_most_severe_accidents_with_entities( ] # pylint: disable=no-member filters["accident_severity"] = [AccidentSeverity.FATAL.value, AccidentSeverity.SEVERE.value] + filters = add_resolution_location_accuracy_filter(filters, resolution) query = get_query(table_obj, filters, start_time, end_time) query = query.with_entities(*entities) query = query.order_by( diff --git a/anyway/widgets/no_location_widgets/vision_zero_bike_widget.py b/anyway/widgets/no_location_widgets/vision_zero_bike_widget.py index 4dc8a85e8..0bcadfc5e 100644 --- a/anyway/widgets/no_location_widgets/vision_zero_bike_widget.py +++ b/anyway/widgets/no_location_widgets/vision_zero_bike_widget.py @@ -1,8 +1,8 @@ -from typing import Dict from anyway.widgets.widget import Widget from anyway.widgets.widget import register from anyway.request_params import RequestParams from typing import Dict, Optional +# noinspection PyProtectedMember from flask_babel import _ import logging diff --git a/anyway/widgets/road_segment_widgets/accident_count_by_car_type_widget.py b/anyway/widgets/road_segment_widgets/accident_count_by_car_type_widget.py index 68d4f7f03..d504708ee 100644 --- a/anyway/widgets/road_segment_widgets/accident_count_by_car_type_widget.py +++ b/anyway/widgets/road_segment_widgets/accident_count_by_car_type_widget.py @@ -8,6 +8,7 @@ import anyway.widgets.widget_utils as widget_utils from anyway.backend_constants import BE_CONST +RC = BE_CONST.ResolutionCategories from anyway.infographics_dictionaries import segment_dictionary from anyway.models import VehicleMarkerView from anyway.request_params import RequestParams @@ -44,6 +45,7 @@ def get_stats_accidents_by_car_type_with_national_data( count="provider_and_id", start_time=request_params.start_time, end_time=request_params.end_time, + resolution=request_params.resolution, ) start_time = request_params.start_time @@ -127,6 +129,7 @@ def percentage_accidents_by_car_type_national_data_cache( count="provider_and_id", start_time=start_time, end_time=end_time, + resolution=RC.SUBURBAN_ROAD, ) return AccidentCountByCarTypeWidget.percentage_accidents_by_car_type( vehicle_grouped_by_type_count_unique diff --git a/anyway/widgets/road_segment_widgets/accident_count_by_driver_type_widget.py b/anyway/widgets/road_segment_widgets/accident_count_by_driver_type_widget.py index 10ab09ff5..250533b44 100644 --- a/anyway/widgets/road_segment_widgets/accident_count_by_driver_type_widget.py +++ b/anyway/widgets/road_segment_widgets/accident_count_by_driver_type_widget.py @@ -28,7 +28,7 @@ def generate_items(self) -> None: ) @staticmethod - def count_accidents_by_driver_type(request_params): + def count_accidents_by_driver_type(request_params: RequestParams): filters = get_injured_filters(request_params) filters["involved_type"] = [ consts.InvolvedType.DRIVER.value, @@ -42,6 +42,7 @@ def count_accidents_by_driver_type(request_params): cnt_distinct=True, start_time=request_params.start_time, end_time=request_params.end_time, + resolution=request_params.resolution, ) driver_types = defaultdict(int) for item in involved_by_vehicle_type_data: diff --git a/anyway/widgets/road_segment_widgets/accident_count_by_hour_widget.py b/anyway/widgets/road_segment_widgets/accident_count_by_hour_widget.py index 50b28c883..977e88691 100644 --- a/anyway/widgets/road_segment_widgets/accident_count_by_hour_widget.py +++ b/anyway/widgets/road_segment_widgets/accident_count_by_hour_widget.py @@ -22,6 +22,7 @@ def generate_items(self) -> None: count="accident_hour", start_time=self.request_params.start_time, end_time=self.request_params.end_time, + resolution=self.request_params.resolution, ) @staticmethod diff --git a/anyway/widgets/road_segment_widgets/accident_count_by_road_light_widget.py b/anyway/widgets/road_segment_widgets/accident_count_by_road_light_widget.py index 84b205a28..d71382254 100644 --- a/anyway/widgets/road_segment_widgets/accident_count_by_road_light_widget.py +++ b/anyway/widgets/road_segment_widgets/accident_count_by_road_light_widget.py @@ -24,6 +24,7 @@ def generate_items(self) -> None: count="road_light_hebrew", start_time=self.request_params.start_time, end_time=self.request_params.end_time, + resolution=self.request_params.resolution, ) @staticmethod diff --git a/anyway/widgets/road_segment_widgets/accident_type_vehicle_type_road_comparison_widget.py b/anyway/widgets/road_segment_widgets/accident_type_vehicle_type_road_comparison_widget.py index 2252bdea5..51b5aa622 100644 --- a/anyway/widgets/road_segment_widgets/accident_type_vehicle_type_road_comparison_widget.py +++ b/anyway/widgets/road_segment_widgets/accident_type_vehicle_type_road_comparison_widget.py @@ -4,10 +4,16 @@ from flask_babel import _ from sqlalchemy import func, distinct, desc - +from anyway.backend_constants import BE_CONST +RC = BE_CONST.ResolutionCategories from anyway.request_params import RequestParams from anyway.app_and_db import db -from anyway.widgets.widget_utils import get_query, run_query +from anyway.widgets.widget_utils import ( + get_query, + run_query, + add_resolution_location_accuracy_filter, + get_expression_for_fields, +) from anyway.models import VehicleMarkerView, AccidentType from anyway.vehicle_type import VehicleCategory from anyway.widgets.road_segment_widgets.road_segment_widget import RoadSegmentWidget @@ -55,6 +61,9 @@ def accident_type_road_vs_all_count( road_query = all_roads_query.filter( (VehicleMarkerView.road1 == road_number) | (VehicleMarkerView.road2 == road_number) ) + loc_filter = add_resolution_location_accuracy_filter(None, RC.SUBURBAN_ROAD) + loc_ex = get_expression_for_fields(loc_filter, VehicleMarkerView) + road_query = road_query.filter(loc_ex) road_query_result = run_query(road_query) road_sum_accidents = 0 types_to_report = [] @@ -84,12 +93,16 @@ def get_accident_count_by_vehicle_type_query( num_accidents_label: str, vehicle_types: List[int], ) -> db.session.query: + filters = add_resolution_location_accuracy_filter( + {VehicleMarkerView.vehicle_type.name: vehicle_types}, + RC.SUBURBAN_ROAD + ) return ( get_query( table_obj=VehicleMarkerView, start_time=start_time, end_time=end_time, - filters={VehicleMarkerView.vehicle_type.name: vehicle_types}, + filters=filters, ) .with_entities( VehicleMarkerView.accident_type, diff --git a/anyway/widgets/road_segment_widgets/fatal_accident_yoy_same_month.py b/anyway/widgets/road_segment_widgets/fatal_accident_yoy_same_month.py index 2be732d49..699853a44 100644 --- a/anyway/widgets/road_segment_widgets/fatal_accident_yoy_same_month.py +++ b/anyway/widgets/road_segment_widgets/fatal_accident_yoy_same_month.py @@ -32,6 +32,7 @@ def generate_items(self) -> None: count=InvolvedMarkerView.injury_severity.name, start_time=self.request_params.start_time, end_time=self.request_params.end_time, + resolution=self.request_params.resolution, ): structured_data_list.append( { diff --git a/anyway/widgets/road_segment_widgets/front_to_side_accidents_by_severity.py b/anyway/widgets/road_segment_widgets/front_to_side_accidents_by_severity.py index a61316022..5d2a0a5b9 100644 --- a/anyway/widgets/road_segment_widgets/front_to_side_accidents_by_severity.py +++ b/anyway/widgets/road_segment_widgets/front_to_side_accidents_by_severity.py @@ -5,12 +5,13 @@ from flask_babel import _ from sqlalchemy import case, func, distinct -from anyway.backend_constants import AccidentType, AccidentSeverity +from anyway.backend_constants import AccidentType, AccidentSeverity, BE_CONST +RC = BE_CONST.ResolutionCategories from anyway.models import AccidentMarkerView from anyway.request_params import RequestParams from anyway.widgets.road_segment_widgets.road_segment_widget import RoadSegmentWidget from anyway.widgets.widget import register -from anyway.widgets.widget_utils import get_query +from anyway.widgets.widget_utils import get_query, add_resolution_location_accuracy_filter ROAD_SEGMENT_ACCIDENTS = "specific_road_segment_accidents" @@ -100,9 +101,13 @@ def _get_raw_front_to_side_accidents( ) ] ) + filters = add_resolution_location_accuracy_filter( + {"road_segment_id": road_segment_id}, + RC.SUBURBAN_ROAD, + ) query = get_query( table_obj=AccidentMarkerView, - filters={"road_segment_id": road_segment_id}, + filters=filters, start_time=start_date, end_time=end_date, ) diff --git a/anyway/widgets/road_segment_widgets/head_on_collisions_comparison_widget.py b/anyway/widgets/road_segment_widgets/head_on_collisions_comparison_widget.py index f0cadc852..17c4fdd38 100644 --- a/anyway/widgets/road_segment_widgets/head_on_collisions_comparison_widget.py +++ b/anyway/widgets/road_segment_widgets/head_on_collisions_comparison_widget.py @@ -39,13 +39,13 @@ def get_head_to_head_stat(self) -> Dict: count="accident_type", start_time=self.request_params.start_time, end_time=self.request_params.end_time, + resolution=self.request_params.resolution, ) if location_info["road1"] and location_info["road_segment_name"]: filter_dict.update( { "road1": location_info["road1"], - "road_segment_name": location_info["road_segment_name"], "road_segment_id": location_info["road_segment_id"], } ) @@ -56,6 +56,7 @@ def get_head_to_head_stat(self) -> Dict: count="accident_type", start_time=self.request_params.start_time, end_time=self.request_params.end_time, + resolution=self.request_params.resolution, ) road_sums = self.sum_count_of_accident_type( diff --git a/anyway/widgets/road_segment_widgets/motorcycle_accidents_vs_all_accidents_widget.py b/anyway/widgets/road_segment_widgets/motorcycle_accidents_vs_all_accidents_widget.py index 55d872a32..5ca95490d 100644 --- a/anyway/widgets/road_segment_widgets/motorcycle_accidents_vs_all_accidents_widget.py +++ b/anyway/widgets/road_segment_widgets/motorcycle_accidents_vs_all_accidents_widget.py @@ -7,11 +7,16 @@ from anyway.widgets.widget import register from anyway.request_params import RequestParams from anyway.backend_constants import BE_CONST, AccidentSeverity -from anyway.widgets.widget_utils import get_query +RC = BE_CONST.ResolutionCategories +from anyway.widgets.widget_utils import ( + get_query, + add_resolution_location_accuracy_filter, +) from anyway.models import InvolvedMarkerView from anyway.vehicle_type import VehicleCategory from anyway.widgets.road_segment_widgets.road_segment_widget import RoadSegmentWidget from typing import Dict +# noinspection PyProtectedMember from flask_babel import _ @@ -34,6 +39,7 @@ def __init__(self, request_params: RequestParams): "Percentage of serious and fatal motorcycle accidents in the selected section compared to the average percentage of accidents in other road sections throughout the country" ) + # noinspection PyAttributeOutsideInit def generate_items(self) -> None: # noinspection PyUnresolvedReferences res = MotorcycleAccidentsVsAllAccidentsWidget.motorcycle_accidents_vs_all_accidents( @@ -44,7 +50,7 @@ def generate_items(self) -> None: @staticmethod def motorcycle_accidents_vs_all_accidents( start_time: datetime.date, end_time: datetime.date, road_number: str - ) -> Tuple: + ) -> Tuple: location_label = "location" case_location = case( [ @@ -60,6 +66,7 @@ def motorcycle_accidents_vs_all_accidents( vehicle_label = "vehicle" vehicle_other = VehicleCategory.OTHER.get_english_display_name() vehicle_motorcycle = VehicleCategory.MOTORCYCLE.get_english_display_name() + # noinspection PyUnresolvedReferences case_vehicle = case( [ ( @@ -72,8 +79,14 @@ def motorcycle_accidents_vs_all_accidents( else_=literal_column(f"'{vehicle_other}'"), ).label(vehicle_label) + filters = {"road_type": BE_CONST.NON_CITY_ROAD_TYPES, + "accident_severity": [AccidentSeverity.FATAL.value, AccidentSeverity.SEVERE.value]} + filters = add_resolution_location_accuracy_filter( + filters, + RC.SUBURBAN_ROAD, + ) query = get_query( - table_obj=InvolvedMarkerView, filters={}, start_time=start_time, end_time=end_time + table_obj=InvolvedMarkerView, filters=filters, start_time=start_time, end_time=end_time ) num_accidents_label = "num_of_accidents" @@ -83,13 +96,6 @@ def motorcycle_accidents_vs_all_accidents( case_vehicle, func.count(distinct(InvolvedMarkerView.provider_and_id)).label(num_accidents_label), ) - .filter(InvolvedMarkerView.road_type.in_(BE_CONST.NON_CITY_ROAD_TYPES)) - .filter( - InvolvedMarkerView.accident_severity.in_( - # pylint: disable=no-member - [AccidentSeverity.FATAL.value, AccidentSeverity.SEVERE.value] - ) - ) .group_by(location_label, vehicle_label) .order_by(desc(num_accidents_label)) ) @@ -179,5 +185,6 @@ def localize_items(request_params: RequestParams, items: Dict) -> Dict: def is_included(self) -> bool: return self.counter_road_motorcycle >= 3 and self.motorcycle_road_percentage >= 2*self.motorcycle_all_roads_percentage + _("road") _("Percentage of serious and fatal motorcycle accidents in the selected section compared to the average percentage of accidents in other road sections throughout the country") diff --git a/anyway/widgets/road_segment_widgets/road2_plus1_widget.py b/anyway/widgets/road_segment_widgets/road2_plus1_widget.py index 336e56e5c..0bdac5ce2 100644 --- a/anyway/widgets/road_segment_widgets/road2_plus1_widget.py +++ b/anyway/widgets/road_segment_widgets/road2_plus1_widget.py @@ -43,6 +43,7 @@ def get_frontal_accidents_in_past_year(self) -> Optional[int]: count="accident_type", start_time=self.request_params.end_time - datetime.timedelta(days=365), end_time=self.request_params.end_time, + resolution=self.request_params.resolution ) road_sums = self.sum_count_of_accident_type( diff --git a/anyway/widgets/road_segment_widgets/suburban_crosswalk_widget.py b/anyway/widgets/road_segment_widgets/suburban_crosswalk_widget.py index 597233a57..0d8fb64f5 100644 --- a/anyway/widgets/road_segment_widgets/suburban_crosswalk_widget.py +++ b/anyway/widgets/road_segment_widgets/suburban_crosswalk_widget.py @@ -1,7 +1,8 @@ from typing import Dict, Any from anyway.request_params import RequestParams -from anyway.backend_constants import InjurySeverity +from anyway.backend_constants import InjurySeverity, BE_CONST +RC = BE_CONST.ResolutionCategories from anyway.models import InvolvedMarkerView from anyway.widgets.road_segment_widgets.road_segment_widget import RoadSegmentWidget from anyway.widgets.widget_utils import get_accidents_stats @@ -24,10 +25,11 @@ def generate_items(self) -> None: self.request_params.location_info["road_segment_name"], self.request_params.start_time, self.request_params.end_time, + self.request_params.resolution, ) @staticmethod - def get_crosswalk(road, start_time, end_time) -> Dict[str, Any]: + def get_crosswalk(road, start_time, end_time, resolution: RC) -> Dict[str, Any]: cross_output = { "with_crosswalk": get_accidents_stats( table_obj=InvolvedMarkerView, @@ -43,6 +45,7 @@ def get_crosswalk(road, start_time, end_time) -> Dict[str, Any]: count="road_segment_name", start_time=start_time, end_time=end_time, + resolution=resolution ), "without_crosswalk": get_accidents_stats( table_obj=InvolvedMarkerView, @@ -58,6 +61,7 @@ def get_crosswalk(road, start_time, end_time) -> Dict[str, Any]: count="road_segment_name", start_time=start_time, end_time=end_time, + resolution=resolution, ), } if not cross_output["with_crosswalk"]: diff --git a/anyway/widgets/road_segment_widgets/top_road_segments_accidents_per_km_widget.py b/anyway/widgets/road_segment_widgets/top_road_segments_accidents_per_km_widget.py index 96ba41ceb..4237cba9a 100644 --- a/anyway/widgets/road_segment_widgets/top_road_segments_accidents_per_km_widget.py +++ b/anyway/widgets/road_segment_widgets/top_road_segments_accidents_per_km_widget.py @@ -9,7 +9,7 @@ from anyway.request_params import RequestParams from anyway.backend_constants import AccidentSeverity from anyway.widgets.widget import register -from anyway.widgets.widget_utils import get_query +from anyway.widgets.widget_utils import get_query, add_resolution_location_accuracy_filter from anyway.models import AccidentMarkerView from anyway.widgets.road_segment_widgets.road_segment_widget import RoadSegmentWidget @@ -36,9 +36,11 @@ def generate_items(self) -> None: def get_top_road_segments_accidents_per_km( resolution, location_info, start_time=None, end_time=None, limit=3 ): + filters = {"road1": location_info["road1"]} + filters = add_resolution_location_accuracy_filter(filters, resolution) query = get_query( table_obj=AccidentMarkerView, - filters={"road1": location_info["road1"]}, + filters=filters, start_time=start_time, end_time=end_time, ) diff --git a/anyway/widgets/urban_widgets/injured_accidents_with_pedestrians_widget.py b/anyway/widgets/urban_widgets/injured_accidents_with_pedestrians_widget.py index 645b8cbce..3a390c0a9 100644 --- a/anyway/widgets/urban_widgets/injured_accidents_with_pedestrians_widget.py +++ b/anyway/widgets/urban_widgets/injured_accidents_with_pedestrians_widget.py @@ -12,6 +12,8 @@ add_empty_keys_to_gen_two_level_dict, gen_entity_labels, format_2_level_items, + add_resolution_location_accuracy_filter, + get_expression_for_fields, ) from anyway.models import InvolvedMarkerView from anyway.widgets.widget import register @@ -54,11 +56,16 @@ def generate_items(self) -> None: yishuv_name = self.request_params.location_info.get("yishuv_name") street1_hebrew = self.request_params.location_info.get("street1_hebrew") - if not self.validate_parameters(yishuv_name, street1_hebrew): - # TODO: this will fail since there is no news_flash_obj in request_params - logging.exception(f"Could not validate parameters yishuv_name + street1_hebrew in widget : {self.name}") - return None + # if not self.validate_parameters(yishuv_name, street1_hebrew): + # # TODO: this will fail since there is no news_flash_obj in request_params + # logging.exception(f"Could not validate parameters yishuv_name + street1_hebrew in widget : {self.name}") + # return None + loc_accuracy = add_resolution_location_accuracy_filter( + None, + self.request_params.resolution + ) + loc_ex = get_expression_for_fields(loc_accuracy, InvolvedMarkerView) query = ( db.session.query(InvolvedMarkerView) .with_entities( @@ -66,6 +73,7 @@ def generate_items(self) -> None: InvolvedMarkerView.injury_severity, func.count().label("count"), ) + .filter(loc_ex) .filter(InvolvedMarkerView.accident_yishuv_name == yishuv_name) .filter( InvolvedMarkerView.injury_severity.in_( diff --git a/anyway/widgets/urban_widgets/severe_fatal_count_by_vehicle_by_year_widget.py b/anyway/widgets/urban_widgets/severe_fatal_count_by_vehicle_by_year_widget.py index 866c436d5..b60f064c0 100644 --- a/anyway/widgets/urban_widgets/severe_fatal_count_by_vehicle_by_year_widget.py +++ b/anyway/widgets/urban_widgets/severe_fatal_count_by_vehicle_by_year_widget.py @@ -25,10 +25,11 @@ def generate_items(self) -> None: self.request_params.location_info["yishuv_name"], self.request_params.start_time, self.request_params.end_time, + self.request_params.resolution, ) @staticmethod - def separate_data(yishuv, start_time, end_time) -> Dict[str, Any]: + def separate_data(yishuv, start_time, end_time, resolution) -> Dict[str, Any]: output = { "e_bikes": get_accidents_stats( table_obj=InvolvedMarkerView, @@ -44,6 +45,7 @@ def separate_data(yishuv, start_time, end_time) -> Dict[str, Any]: count="accident_year", start_time=start_time, end_time=end_time, + resolution=resolution, ), "bikes": get_accidents_stats( table_obj=InvolvedMarkerView, @@ -59,6 +61,7 @@ def separate_data(yishuv, start_time, end_time) -> Dict[str, Any]: count="accident_year", start_time=start_time, end_time=end_time, + resolution=resolution, ), "e_scooters": get_accidents_stats( table_obj=InvolvedMarkerView, @@ -74,6 +77,7 @@ def separate_data(yishuv, start_time, end_time) -> Dict[str, Any]: count="accident_year", start_time=start_time, end_time=end_time, + resolution=resolution, ), } bike_accidents = [d["accident_year"] for d in output["bikes"] if "accident_year" in d] diff --git a/anyway/widgets/urban_widgets/small_motor_severe_fatal_count_by_year_widget.py b/anyway/widgets/urban_widgets/small_motor_severe_fatal_count_by_year_widget.py index f6834cb0a..38a6ac861 100644 --- a/anyway/widgets/urban_widgets/small_motor_severe_fatal_count_by_year_widget.py +++ b/anyway/widgets/urban_widgets/small_motor_severe_fatal_count_by_year_widget.py @@ -24,10 +24,11 @@ def generate_items(self) -> None: self.request_params.location_info["yishuv_name"], self.request_params.start_time, self.request_params.end_time, + self.request_params.resolution, ) @staticmethod - def get_motor_stats(location_info, start_time, end_time): + def get_motor_stats(location_info, start_time, end_time, resolution): count_by_year = get_accidents_stats( table_obj=InvolvedMarkerView, filters={ @@ -42,6 +43,7 @@ def get_motor_stats(location_info, start_time, end_time): count="accident_year", start_time=start_time, end_time=end_time, + resolution=resolution, ) found_accidents = [d["accident_year"] for d in count_by_year if "accident_year" in d] start_year = start_time.year diff --git a/anyway/widgets/urban_widgets/urban_crosswalk_widget.py b/anyway/widgets/urban_widgets/urban_crosswalk_widget.py index 0d6eb6973..c83e08747 100644 --- a/anyway/widgets/urban_widgets/urban_crosswalk_widget.py +++ b/anyway/widgets/urban_widgets/urban_crosswalk_widget.py @@ -25,10 +25,11 @@ def generate_items(self) -> None: self.request_params.location_info["street1_hebrew"], self.request_params.start_time, self.request_params.end_time, + self.request_params.resolution, ) @staticmethod - def get_crosswalk(yishuv, street, start_time, end_time) -> Dict[str, Any]: + def get_crosswalk(yishuv, street, start_time, end_time, resolution) -> Dict[str, Any]: cross_output = { "with_crosswalk": get_accidents_stats( table_obj=InvolvedMarkerView, @@ -45,6 +46,7 @@ def get_crosswalk(yishuv, street, start_time, end_time) -> Dict[str, Any]: count="street1_hebrew", start_time=start_time, end_time=end_time, + resolution=resolution, ), "without_crosswalk": get_accidents_stats( table_obj=InvolvedMarkerView, @@ -61,6 +63,7 @@ def get_crosswalk(yishuv, street, start_time, end_time) -> Dict[str, Any]: count="street1_hebrew", start_time=start_time, end_time=end_time, + resolution=resolution, ), } if not cross_output["with_crosswalk"]: diff --git a/anyway/widgets/widget_utils.py b/anyway/widgets/widget_utils.py index 4e29b0ded..409f1c4b6 100644 --- a/anyway/widgets/widget_utils.py +++ b/anyway/widgets/widget_utils.py @@ -5,6 +5,8 @@ from typing import Dict, Any, List, Type, Optional, Sequence, Tuple import pandas as pd + +# noinspection PyProtectedMember from flask_babel import _ from sqlalchemy import func, distinct, between, or_, and_ @@ -18,6 +20,8 @@ from anyway.request_params import RequestParams from anyway.widgets.segment_junctions import SegmentJunctions +RC = BE_CONST.ResolutionCategories + def get_query(table_obj, filters, start_time, end_time): if "road_segment_name" in filters and "road_segment_id" in filters: @@ -31,28 +35,45 @@ def get_query(table_obj, filters, start_time, end_time): if not filters: return query if "road_segment_id" not in filters.keys(): - query = query.filter(get_expression_for_fields(filters, table_obj, and_)) + query = query.filter(get_expression_for_non_road_segment_fields(filters, table_obj, and_)) return query location_fields, other_fields = split_location_fields_and_others(filters) if other_fields: - query = query.filter(get_expression_for_fields(other_fields, table_obj, and_)) + query = query.filter( + get_expression_for_non_road_segment_fields(other_fields, table_obj, and_) + ) query = query.filter( get_expression_for_road_segment_location_fields(location_fields, table_obj) ) return query -def get_expression_for_fields(filters, table_obj, op): - inv_val = op == and_ - ex = op(inv_val, inv_val) - for field_name, value in filters.items(): - ex = op(ex, get_filter_expression(table_obj, field_name, value)) +def get_expression_for_fields(filters: dict, table_obj): + op_other, op_segment = None, None + if "road_segment_id" not in filters.keys(): + return get_expression_for_non_road_segment_fields(filters, table_obj, and_) + location_fields, other_fields = split_location_fields_and_others(filters) + if other_fields: + op_other = get_expression_for_non_road_segment_fields(other_fields, table_obj, and_) + op_segment = get_expression_for_road_segment_location_fields(location_fields, table_obj) + return op_segment if op_other is None else and_(op_segment, op_other) + + +def get_expression_for_non_road_segment_fields(filters, table_obj, op): + items = list(filters.items()) + if len(items) == 0: + return True + field_name, value = items[0] + ex = get_filter_expression(table_obj, field_name, value) + for field_name, value in items[1:]: + ex2 = get_filter_expression(table_obj, field_name, value) + ex = op(ex, ex2) return ex # todo: remove road_segment_name if road_segment_id exists. def get_expression_for_road_segment_location_fields(filters, table_obj): - ex = get_expression_for_fields(filters, table_obj, and_) + ex = get_expression_for_non_road_segment_fields(filters, table_obj, and_) segment_id = filters["road_segment_id"] junctions_ex = get_expression_for_segment_junctions(segment_id, table_obj) res = or_(ex, junctions_ex) @@ -67,22 +88,19 @@ def get_expression_for_segment_junctions(segment_id: int, table_obj): def get_filter_expression(table_obj, field_name, value): if field_name == "street1_hebrew" or field_name == "street1": - if isinstance(value, list): - values = value - else: - values = [value] - o = or_( - (getattr(table_obj, field_name)).in_(values), - (getattr(table_obj, field_name.replace("1", "2"))).in_(values), - # (getattr(table_obj, "street1_hebrew")).in_(values), - # (getattr(table_obj, "street2_hebrew")).in_(values), + return or_( + get_filter_expression_raw(table_obj, field_name, value), + get_filter_expression_raw(table_obj, field_name.replace("1", "2"), value), ) else: - if isinstance(value, list): - o = (getattr(table_obj, field_name)).in_(value) - else: - o = (getattr(table_obj, field_name)) == value - return o + return get_filter_expression_raw(table_obj, field_name, value) + + +def get_filter_expression_raw(table_obj, field_name, value): + if isinstance(value, list): + return (getattr(table_obj, field_name)).in_(value) + else: + return (getattr(table_obj, field_name)) == value def split_location_fields_and_others(filters: dict) -> Tuple[dict, dict]: @@ -102,8 +120,10 @@ def get_accidents_stats( cnt_distinct=False, start_time=None, end_time=None, + resolution: Optional[RC] = None, ): filters = filters or {} + filters = add_resolution_location_accuracy_filter(filters, resolution) provider_code_filters = [ BE_CONST.CBS_ACCIDENT_TYPE_1_CODE, BE_CONST.CBS_ACCIDENT_TYPE_3_CODE, @@ -141,9 +161,9 @@ def get_accidents_stats( # noinspection Mypy -def retro_dictify(indexable) -> Dict[Any, Dict[Any, Any]]: +def retro_dictify(iterable) -> Dict[Any, Dict[Any, Any]]: d = defaultdict(dict) - for row in indexable: + for row in iterable: here = d for elem in row[:-2]: if elem not in here: @@ -273,16 +293,12 @@ def get_involved_counts( .filter(between(table.accident_year, start_year, end_year)) .order_by(table.accident_year) ) - + filters = add_resolution_location_accuracy_filter(location_info, table) if "yishuv_symbol" in location_info: - query = query.filter( - table.accident_yishuv_symbol == location_info["yishuv_symbol"] - ).group_by(table.accident_year) - elif "road_segment_id" in location_info: - ex = get_expression_for_road_segment_location_fields( - {"road_segment_id": location_info["road_segment_id"]}, table - ) - query = query.filter(ex).group_by(table.accident_year) + filters["accident_yishuv_symbol"] = filters["yishuv_symbol"] + filters.pop("yishuv_symbol") + ex = get_expression_for_fields(filters, table) + query = query.filter(ex).group_by(table.accident_year) if severities: query = query.filter(table.injury_severity.in_([severity.value for severity in severities])) @@ -319,3 +335,30 @@ def get_location_text(request_params: RequestParams) -> str: return f'{_("in segment")} {_(request_params.location_info["road_segment_name"])}' elif request_params.resolution == BE_CONST.ResolutionCategories.STREET: return f'{_("in street")} {request_params.location_info["street1_hebrew"]} {in_str}{request_params.location_info["yishuv_name"]}' + + +__RESOLUTION_ACCURACY_VALUES: dict = { + RC.SUBURBAN_JUNCTION: [1, 4], + RC.SUBURBAN_ROAD: [1, 4], + RC.URBAN_JUNCTION: [1, 3], + RC.STREET: [1, 3], +} + + +def get_resolution_location_accuracy_filter(rc: RC) -> Optional[dict]: + vals = __RESOLUTION_ACCURACY_VALUES.get(rc) + return {"location_accuracy": vals} if vals else None + + +def add_resolution_location_accuracy_filter( + filters: Optional[dict], resolution: RC +) -> Optional[dict]: + la = get_resolution_location_accuracy_filter(resolution) + if la is None: + return filters + elif filters is None: + return la + else: + res = copy.copy(filters) + res.update(la) + return res diff --git a/tests/test_infographic_api.py b/tests/test_infographic_api.py index 2ec8decf9..05e5d165e 100644 --- a/tests/test_infographic_api.py +++ b/tests/test_infographic_api.py @@ -2,7 +2,6 @@ import pytest import anyway.request_params import anyway.widgets.widget_utils as widget_utils - from numpy import nan from six.moves import http_client from anyway import app as flask_app @@ -10,7 +9,9 @@ from anyway.app_and_db import db from anyway.vehicle_type import VehicleCategory from anyway.widgets.road_segment_widgets.accident_count_by_car_type_widget import AccidentCountByCarTypeWidget -from anyway.backend_constants import NewsflashLocationQualification +from anyway.backend_constants import NewsflashLocationQualification, BE_CONST +RC = BE_CONST.ResolutionCategories + def insert_infographic_mock_data(app): @@ -104,7 +105,7 @@ def test_accident_count_by_car_type(self, app): assert output_tmp[VehicleCategory.BICYCLE_AND_SMALL_MOTOR.value] == 16 def mock_get_accidents_stats(table_obj, filters=None, group_by=None, count=None, start_time=None, - end_time=None): + end_time=None, resolution=None): return [{'vehicle_type': nan, 'count': 2329}, {'vehicle_type': 14.0, 'count': 112}, {'vehicle_type': 25.0, 'count': 86}, {'vehicle_type': 17.0, 'count': 1852}, {'vehicle_type': 12.0, 'count': 797}, {'vehicle_type': 8.0, 'count': 186}, @@ -269,7 +270,7 @@ def test_fatal_yoy_monthly(self): "type": "object", "properties": {"label_key": {"type": "number"}, "value": {"type": "number"}, }, } - assert widget["data"]["items"][0] == {'label_key': 2014, 'value': 32} + assert widget["data"]["items"][0] == {'label_key': 2014, 'value': 24} validate(widget["data"]["items"][0], schema) assert widget["data"]["text"]["title"] == "כמות ההרוגים בתאונות דרכים בחודש הנוכחי בהשוואה לשנים קודמות" diff --git a/tests/test_infographics_utils.py b/tests/test_infographics_utils.py index 2d6122d1f..dfe289e51 100644 --- a/tests/test_infographics_utils.py +++ b/tests/test_infographics_utils.py @@ -1,10 +1,20 @@ import unittest from unittest.mock import patch -from anyway.widgets.widget_utils import (format_2_level_items, get_filter_expression, - get_expression_for_segment_junctions) +from sqlalchemy import and_ +from anyway.widgets.widget_utils import (format_2_level_items, + get_expression_for_segment_junctions, + add_resolution_location_accuracy_filter, + get_expression_for_fields, + get_filter_expression_raw, + get_filter_expression, + get_expression_for_non_road_segment_fields, + + ) from anyway.backend_constants import AccidentSeverity -from anyway.models import AccidentMarkerView, RoadJunctionKM, RoadSegments +from anyway.models import AccidentMarkerView, RoadJunctionKM, RoadSegments, InvolvedMarkerView from anyway.widgets.segment_junctions import SegmentJunctions +from anyway.backend_constants import BE_CONST +RC = BE_CONST.ResolutionCategories class TestInfographicsUtilsCase(unittest.TestCase): @@ -40,6 +50,7 @@ class TestInfographicsUtilsCase(unittest.TestCase): ] t = RoadJunctionKM() + # noinspection SpellCheckingInspection rjks = [ RoadJunctionKM(road=1, non_urban_intersection=1, km=1.0), RoadJunctionKM(road=1, non_urban_intersection=2, km=2.0), @@ -94,9 +105,9 @@ def test_get_filter_expression(self): actual = get_filter_expression(AccidentMarkerView, "street1", "1") self.assertEqual(2, len(actual.expression.clauses), "8") self.assertEqual('markers_hebrew.street1', str(actual.expression.clauses[0].left), "9") - self.assertEqual('1', actual.clauses[0].right.element.clauses[0].value, "10") + self.assertEqual('1', actual.clauses[0].right.effective_value, "10") self.assertEqual('markers_hebrew.street2', str(actual.expression.clauses[1].left), "11") - self.assertEqual('1', actual.clauses[1].right.element.clauses[0].value, "12") + self.assertEqual('1', actual.clauses[1].right.effective_value, "12") @patch("anyway.widgets.widget_utils.SegmentJunctions") def test_get_expression_for_segment_junctions(self, sg): @@ -105,6 +116,97 @@ def test_get_expression_for_segment_junctions(self, sg): actual = get_expression_for_segment_junctions(17, AccidentMarkerView) self.assertEqual('1 != 1', str(actual.expression), "1") + def test_add_resolution_location_accuracy_filter(self): + f = {"1": 1} + actual = add_resolution_location_accuracy_filter(f, RC.STREET) + self.assertEqual({'1': 1, 'location_accuracy': [1, 3]}, actual, "2") + actual = add_resolution_location_accuracy_filter(None, RC.STREET) + self.assertEqual({'location_accuracy': [1, 3]}, actual, "3") + actual = add_resolution_location_accuracy_filter(f, RC.SUBURBAN_JUNCTION) + self.assertEqual({'1': 1, 'location_accuracy': [1, 4]}, actual, "4") + actual = add_resolution_location_accuracy_filter(None, RC.SUBURBAN_ROAD) + self.assertEqual({'location_accuracy': [1, 4]}, actual, "5" + ) + actual = add_resolution_location_accuracy_filter(f, RC.SUBURBAN_JUNCTION) + self.assertEqual({'1': 1, 'location_accuracy': [1, 4]}, actual, "6",) + actual = add_resolution_location_accuracy_filter(None, RC.STREET) + self.assertEqual({'location_accuracy': [1, 3]}, actual, "7") + actual = add_resolution_location_accuracy_filter(f, RC.OTHER) + self.assertEqual(f, actual, "8") + actual = add_resolution_location_accuracy_filter(None, RC.OTHER) + self.assertIsNone(actual, "9") + + @patch("anyway.widgets.widget_utils.and_") + @patch("anyway.widgets.widget_utils.split_location_fields_and_others") + @patch("anyway.widgets.widget_utils.get_expression_for_road_segment_location_fields") + @patch("anyway.widgets.widget_utils.get_expression_for_non_road_segment_fields") + def test_get_expression_for_fields(self, non_segment_ex, segment_ex, split, sql_and): + sql_and_return_val = "sql_and_return_val" + sql_and.return_value = sql_and_return_val + non_segment_ex_return_val = "non_segment_ex_return_val" + segment_ex_return_val = "segment_ex_return_val" + segment_ex.return_value = segment_ex_return_val + f = {"road1": 1} + o = {"location_accuracy": 1} + non_segment_ex.return_value = non_segment_ex_return_val + actual = get_expression_for_fields(f, InvolvedMarkerView) + non_segment_ex.assert_called_with(f, InvolvedMarkerView, sql_and) + self.assertEqual(non_segment_ex_return_val, actual, "1") + + non_segment_ex.reset_mock() + f = {"road_segment_id": 1} + split.return_value = (f, {}) + actual = get_expression_for_fields(f, InvolvedMarkerView) + self.assertEqual(segment_ex_return_val, actual, "2") + non_segment_ex.assert_not_called() + segment_ex.assert_called_with(f, InvolvedMarkerView) + + non_segment_ex.reset_mock() + segment_ex.reset_mock() + f = {"road_segment_id": 1} + split.return_value = (f, o) + actual = get_expression_for_fields(f, InvolvedMarkerView) + self.assertEqual(sql_and_return_val, actual, "3") + non_segment_ex.assert_called_with(o, InvolvedMarkerView, sql_and) + segment_ex.assert_called_with(f, InvolvedMarkerView) + sql_and.assert_called_with(segment_ex_return_val, non_segment_ex_return_val) + + @patch("anyway.widgets.widget_utils.or_") + def test_get_filter_expression_1(self, sql_or): + sql_or_return_val = "sql_or_return_val" + sql_or.return_value = sql_or_return_val + actual = get_filter_expression(InvolvedMarkerView, "street1_hebrew", "name") + self.assertEqual(sql_or_return_val, actual, "1") + for i in [0, 1]: + arg = str(sql_or.call_args.args[i]) + self.assertTrue(arg.startswith(f'involved_markers_hebrew.street{i+1}_hebrew ='), f"2.{i}") + + def test_get_filter_expression_raw(self): + actual = get_filter_expression_raw(InvolvedMarkerView, "location_accuracy", "1") + self.assertIn("involved_markers_hebrew.location_accuracy =", + str(actual), "1") + + actual = get_filter_expression_raw(InvolvedMarkerView, "location_accuracy", [1, 3]) + self.assertIn("involved_markers_hebrew.location_accuracy IN", + str(actual), "2") + + def test_get_expression_for_non_road_segment_fields(self): + actual = get_expression_for_non_road_segment_fields({"location_accuracy": "1"}, + InvolvedMarkerView, + and_) + self.assertIn("involved_markers_hebrew.location_accuracy =", + str(actual), "1") + actual = get_expression_for_non_road_segment_fields({"location_accuracy": "1", + "road1": 1}, + InvolvedMarkerView, + and_) + self.assertIn("involved_markers_hebrew.location_accuracy =", + str(actual), "2") + self.assertIn("involved_markers_hebrew.road1 =", + str(actual), "3") + self.assertIn(" AND ", + str(actual), "4") + if __name__ == '__main__': unittest.main() diff --git a/tests/test_request_params.py b/tests/test_request_params.py index d719da732..ec0b35a65 100644 --- a/tests/test_request_params.py +++ b/tests/test_request_params.py @@ -1,7 +1,8 @@ import unittest from unittest.mock import patch from datetime import date -from pandas._libs.tslibs.timestamps import Timestamp +# noinspection PyProtectedMember +from pandas._libs.tslibs.timestamps import Timestamp # pylint: disable=E0611 from anyway.request_params import ( extract_non_urban_intersection_location, get_request_params_from_request_values, @@ -22,16 +23,16 @@ class TestRequestParams(unittest.TestCase): } junction_1277_roads = {"non_urban_intersection": 1277, "non_urban_intersection_hebrew": "צומת השיטה", - "roads": set([669, 71]), + "roads": {669, 71}, } loc_1 = {'data': {'non_urban_intersection': 1277, - 'non_urban_intersection_hebrew': 'צומת השיטה', - 'resolution': BE_CONST.ResolutionCategories.SUBURBAN_JUNCTION, - 'road1': 669, - 'road2': 71}, - 'gps': {'lat': 32.82561, 'lon': 35.165395}, - 'name': 'location', - 'text': 'צומת השיטה'} + 'non_urban_intersection_hebrew': 'צומת השיטה', + 'resolution': BE_CONST.ResolutionCategories.SUBURBAN_JUNCTION, + 'road1': 669, + 'road2': 71}, + 'gps': {'lat': 32.82561, 'lon': 35.165395}, + 'name': 'location', + 'text': 'צומת השיטה'} nf = NewsFlash() nf.description = "description" nf.title = "title" @@ -69,7 +70,7 @@ def test_fill_missing_non_urban_intersection_values(self, from_roads, from_key): actual = fill_missing_non_urban_intersection_values(input_params) self.assertEqual(self.junction_1277, actual, "2") # add assertion here - input_params = {"non_urban_intersection_hebrew": "צומת השיטה",} + input_params = {"non_urban_intersection_hebrew": "צומת השיטה"} from_key.return_value = self.junction_1277_roads actual = fill_missing_non_urban_intersection_values(input_params) self.assertEqual(self.junction_1277, actual, "2") # add assertion here @@ -79,6 +80,7 @@ def test_fill_missing_non_urban_intersection_values(self, from_roads, from_key): @patch("anyway.request_params.get_location_from_news_flash_or_request_values") def test_get_request_params_from_request_values(self, get_location, extract_nf, get_date): get_location.return_value = self.loc_1 + # noinspection PyTypeChecker get_date.return_value = Timestamp("2018-01-02 01:15:16") extract_nf.return_value = self.nf input_params = {"road1": 669, "road2": 71}