diff --git a/alembic/versions/312c9eb92e40_add_cbs_locations_table.py b/alembic/versions/312c9eb92e40_add_cbs_locations_table.py index e0b02ec62..6c59c9c78 100644 --- a/alembic/versions/312c9eb92e40_add_cbs_locations_table.py +++ b/alembic/versions/312c9eb92e40_add_cbs_locations_table.py @@ -34,7 +34,7 @@ def upgrade(): sa.PrimaryKeyConstraint('id') ) conn = op.get_bind() - conn.execute("""INSERT INTO cbs_locations + conn.execute(sa.text("""INSERT INTO cbs_locations (SELECT ROW_NUMBER() OVER (ORDER BY road1) as id, LOCATIONS.* FROM (SELECT DISTINCT road1, @@ -52,7 +52,7 @@ def upgrade(): WHERE (provider_code=1 OR provider_code=3) AND (longitude is not null - AND latitude is not null)) LOCATIONS)""") + AND latitude is not null)) LOCATIONS)""")) # ### end Alembic commands ### diff --git a/alembic/versions/4c4b79f8c4a_adding_geom_gix_to_markers.py b/alembic/versions/4c4b79f8c4a_adding_geom_gix_to_markers.py index 22690606e..6d2a64159 100644 --- a/alembic/versions/4c4b79f8c4a_adding_geom_gix_to_markers.py +++ b/alembic/versions/4c4b79f8c4a_adding_geom_gix_to_markers.py @@ -12,13 +12,14 @@ depends_on = None from alembic import op +import sqlalchemy as sa def upgrade(): ### commands auto generated by Alembic - please adjust! ### conn = op.get_bind() - conn.execute('CREATE INDEX geom_gix ON markers USING GIST (geography(geom));') - conn.execute('CREATE INDEX discussions_gix ON discussions USING GIST (geography(geom));') + conn.execute(sa.text('CREATE INDEX geom_gix ON markers USING GIST (geography(geom));')) + conn.execute(sa.text('CREATE INDEX discussions_gix ON discussions USING GIST (geography(geom));')) ### end Alembic commands ### @@ -26,6 +27,6 @@ def upgrade(): def downgrade(): ### commands auto generated by Alembic - please adjust! ### conn = op.get_bind() - conn.execute('DROP INDEX geom_gix;') - conn.execute('DROP INDEX discussions_gix;') + conn.execute(sa.text('DROP INDEX geom_gix;')) + conn.execute(sa.text('DROP INDEX discussions_gix;')) ### end Alembic commands ### diff --git a/alembic/versions/5a5ffe56bb7_adding_geom_table_to_markers_and_discussions.py b/alembic/versions/5a5ffe56bb7_adding_geom_table_to_markers_and_discussions.py index e87c87532..97b370b0c 100644 --- a/alembic/versions/5a5ffe56bb7_adding_geom_table_to_markers_and_discussions.py +++ b/alembic/versions/5a5ffe56bb7_adding_geom_table_to_markers_and_discussions.py @@ -12,19 +12,20 @@ depends_on = None from alembic import op +import sqlalchemy as sa def upgrade(): ### commands auto generated by Alembic - please adjust! ### conn = op.get_bind() - conn.execute('CREATE EXTENSION IF NOT EXISTS postgis;') - conn.execute('CREATE EXTENSION IF NOT EXISTS postgis_topology;') - conn.execute("SELECT AddGeometryColumn('public','markers','geom',4326,'POINT',2);") - conn.execute('UPDATE markers SET geom = ST_SetSRID(ST_MakePoint(longitude,latitude),4326);') - conn.execute('CREATE INDEX idx_markers_geom ON markers USING GIST(geom);') - conn.execute("SELECT AddGeometryColumn('public','discussions','geom',4326,'POINT',2);") - conn.execute('UPDATE discussions SET geom = ST_SetSRID(ST_MakePoint(longitude,latitude),4326);') - conn.execute('CREATE INDEX idx_discussions_geom ON discussions USING GIST(geom);') + conn.execute(sa.text('CREATE EXTENSION IF NOT EXISTS postgis;')) + conn.execute(sa.text('CREATE EXTENSION IF NOT EXISTS postgis_topology;')) + conn.execute(sa.text("SELECT AddGeometryColumn('public','markers','geom',4326,'POINT',2);")) + conn.execute(sa.text('UPDATE markers SET geom = ST_SetSRID(ST_MakePoint(longitude,latitude),4326);')) + conn.execute(sa.text('CREATE INDEX idx_markers_geom ON markers USING GIST(geom);')) + conn.execute(sa.text("SELECT AddGeometryColumn('public','discussions','geom',4326,'POINT',2);")) + conn.execute(sa.text('UPDATE discussions SET geom = ST_SetSRID(ST_MakePoint(longitude,latitude),4326);')) + conn.execute(sa.text('CREATE INDEX idx_discussions_geom ON discussions USING GIST(geom);')) ### end Alembic commands ### @@ -32,11 +33,11 @@ def upgrade(): def downgrade(): ### commands auto generated by Alembic - please adjust! ### conn = op.get_bind() - conn.execute('DROP INDEX idx_markers_geom;') + conn.execute(sa.text('DROP INDEX idx_markers_geom;')) op.drop_column('markers', 'geom') - conn.execute('DROP INDEX idx_discussions_geom;') + conn.execute(sa.text('DROP INDEX idx_discussions_geom;')) op.drop_column('discussions', 'geom') - conn.execute('DROP EXTENSION postgis_topology;') - conn.execute('DROP EXTENSION postgis;') - conn.execute('DROP SCHEMA IF EXISTS topology CASCADE;') + conn.execute(sa.text('DROP EXTENSION postgis_topology;')) + conn.execute(sa.text('DROP EXTENSION postgis;')) + conn.execute(sa.text('DROP SCHEMA IF EXISTS topology CASCADE;')) ### end Alembic commands ### diff --git a/alembic/versions/7574885e1fed_remove_unecessary_table_index.py b/alembic/versions/7574885e1fed_remove_unecessary_table_index.py index 40ec44ba7..768afb1be 100644 --- a/alembic/versions/7574885e1fed_remove_unecessary_table_index.py +++ b/alembic/versions/7574885e1fed_remove_unecessary_table_index.py @@ -34,6 +34,6 @@ def downgrade(): op.create_index('provider_and_id_idx_involved', 'involved', ['provider_and_id'], unique=False) op.create_index('provider_and_id_idx_vehicles', 'vehicles', ['provider_and_id'], unique=False) conn = op.get_bind() - conn.execute('CREATE INDEX geom_gix ON markers USING GIST (geography(geom));') - conn.execute('CREATE INDEX discussions_gix ON discussions USING GIST (geography(geom));') + conn.execute(sa.text('CREATE INDEX geom_gix ON markers USING GIST (geography(geom));')) + conn.execute(sa.text('CREATE INDEX discussions_gix ON discussions USING GIST (geography(geom));')) # ### end Alembic commands ### diff --git a/alembic/versions/7f629b4c8891_add_news_flash_fields.py b/alembic/versions/7f629b4c8891_add_news_flash_fields.py index da55c81eb..0a3a16af4 100644 --- a/alembic/versions/7f629b4c8891_add_news_flash_fields.py +++ b/alembic/versions/7f629b4c8891_add_news_flash_fields.py @@ -19,9 +19,9 @@ def upgrade(): # ### commands auto generated by Alembic - please adjust! ### conn = op.get_bind() - conn.execute("ALTER TABLE news_flash ALTER COLUMN id SET DEFAULT nextval('news_flash_id_seq');") - conn.execute("ALTER SEQUENCE news_flash_id_seq OWNED BY news_flash.id;") - conn.execute("SELECT setval('news_flash_id_seq', COALESCE(max(id), 1)) FROM news_flash;") + conn.execute(sa.text("ALTER TABLE news_flash ALTER COLUMN id SET DEFAULT nextval('news_flash_id_seq');")) + conn.execute(sa.text("ALTER SEQUENCE news_flash_id_seq OWNED BY news_flash.id;")) + conn.execute(sa.text("SELECT setval('news_flash_id_seq', COALESCE(max(id), 1)) FROM news_flash;")) op.add_column('news_flash', sa.Column('district_hebrew', sa.Text(), nullable=True)) op.add_column('news_flash', sa.Column('non_urban_intersection_hebrew', sa.Text(), nullable=True)) op.add_column('news_flash', sa.Column('region_hebrew', sa.Text(), nullable=True)) diff --git a/anyway/accidents_around_schools.py b/anyway/accidents_around_schools.py index f166e6ad4..dddcb1f97 100644 --- a/anyway/accidents_around_schools.py +++ b/anyway/accidents_around_schools.py @@ -7,7 +7,7 @@ from anyway.backend_constants import BE_CONST from anyway.models import AccidentMarker, Involved, School -from anyway.app_and_db import db +from anyway.app_and_db import db, app SUBTYPE_ACCIDENT_WITH_PEDESTRIAN = 1 LOCATION_ACCURACY_PRECISE = True @@ -48,27 +48,26 @@ def acc_inv_query(longitude, latitude, distance, start_date, end_date, school): pol_str = "POLYGON(({0} {1},{0} {3},{2} {3},{2} {1},{0} {1}))".format( base_x, base_y, distance_x, distance_y ) - - query_obj = ( - db.session.query(Involved, AccidentMarker) - .join(AccidentMarker, AccidentMarker.provider_and_id == Involved.provider_and_id) - .filter(AccidentMarker.geom.intersects(pol_str)) - .filter(Involved.injured_type == INJURED_TYPE_PEDESTRIAN) - .filter(AccidentMarker.provider_and_id == Involved.provider_and_id) - .filter( - or_( - (AccidentMarker.provider_code == BE_CONST.CBS_ACCIDENT_TYPE_1_CODE), - (AccidentMarker.provider_code == BE_CONST.CBS_ACCIDENT_TYPE_3_CODE), + with app.app_context(): + query_obj = ( + db.session.query(Involved, AccidentMarker) + .join(AccidentMarker, AccidentMarker.provider_and_id == Involved.provider_and_id) + .filter(AccidentMarker.geom.intersects(pol_str)) + .filter(Involved.injured_type == INJURED_TYPE_PEDESTRIAN) + .filter(AccidentMarker.provider_and_id == Involved.provider_and_id) + .filter( + or_( + (AccidentMarker.provider_code == BE_CONST.CBS_ACCIDENT_TYPE_1_CODE), + (AccidentMarker.provider_code == BE_CONST.CBS_ACCIDENT_TYPE_3_CODE), + ) ) - ) - .filter(AccidentMarker.created >= start_date) - .filter(AccidentMarker.created < end_date) - .filter(AccidentMarker.location_accuracy == LOCATION_ACCURACY_PRECISE_INT) - .filter(AccidentMarker.yishuv_symbol != YISHUV_SYMBOL_NOT_EXIST) - .filter(Involved.age_group.in_([1, 2, 3, 4])) - ) # ages 0-19 - - df = pd.read_sql_query(query_obj.with_labels().statement, db.get_engine()) + .filter(AccidentMarker.created >= start_date) + .filter(AccidentMarker.created < end_date) + .filter(AccidentMarker.location_accuracy == LOCATION_ACCURACY_PRECISE_INT) + .filter(AccidentMarker.yishuv_symbol != YISHUV_SYMBOL_NOT_EXIST) + .filter(Involved.age_group.in_([1, 2, 3, 4])) + ) # ages 0-19 + df = pd.read_sql_query(query_obj.with_labels().statement, db.get_engine()) if LOCATION_ACCURACY_PRECISE: location_accurate = 1 @@ -110,7 +109,8 @@ def acc_inv_query(longitude, latitude, distance, start_date, end_date, school): def main(start_date, end_date, distance, output_path): schools_query = sa.select([School]) - df_schools = pd.read_sql_query(schools_query, db.get_engine()) + with app.app_context(): + df_schools = pd.read_sql_query(schools_query, db.get_engine()) df_total = pd.DataFrame() df_schools = df_schools.drop_duplicates( # pylint: disable=no-member ["yishuv_name", "longitude", "latitude"] diff --git a/anyway/database.py b/anyway/database.py index b26c6f6de..5848bf9e4 100644 --- a/anyway/database.py +++ b/anyway/database.py @@ -1,7 +1,7 @@ try: - from anyway.app_and_db import db - - Base = db.Model + from anyway.app_and_db import db, app + with app.app_context(): + Base = db.Model except ModuleNotFoundError: from sqlalchemy.ext.declarative.api import declarative_base diff --git a/anyway/models.py b/anyway/models.py index 04aced1fa..baa9bd351 100755 --- a/anyway/models.py +++ b/anyway/models.py @@ -49,7 +49,7 @@ class UserMixin: from anyway.utilities import decode_hebrew try: - from anyway.app_and_db import db + from anyway.app_and_db import db, app except ModuleNotFoundError: pass @@ -333,15 +333,16 @@ class AccidentMarker(MarkerMixin, Base): @staticmethod def get_latest_marker_created_date(): - latest_created_date = ( - db.session.query(func.max(AccidentMarker.created)) - .filter( - AccidentMarker.provider_code.in_( - [BE_CONST.CBS_ACCIDENT_TYPE_1_CODE, BE_CONST.CBS_ACCIDENT_TYPE_3_CODE] + with app.app_context(): + latest_created_date = ( + db.session.query(func.max(AccidentMarker.created)) + .filter( + AccidentMarker.provider_code.in_( + [BE_CONST.CBS_ACCIDENT_TYPE_1_CODE, BE_CONST.CBS_ACCIDENT_TYPE_3_CODE] + ) ) + .first() ) - .first() - ) if latest_created_date is None: return None @@ -422,278 +423,280 @@ def serialize(self, is_thin=False): def bounding_box_query( is_thin=False, yield_per=None, involved_and_vehicles=False, query_entities=None, **kwargs ): - approx = kwargs.get("approx", True) - accurate = kwargs.get("accurate", True) - page = kwargs.get("page") - per_page = kwargs.get("per_page") - - if not kwargs.get("show_markers", True): - return MarkerResult( - accident_markers=db.session.query(AccidentMarker).filter(sql.false()), - rsa_markers=db.session.query(AccidentMarker).filter(sql.false()), - total_records=0, - ) - - sw_lat = float(kwargs["sw_lat"]) - sw_lng = float(kwargs["sw_lng"]) - ne_lat = float(kwargs["ne_lat"]) - ne_lng = float(kwargs["ne_lng"]) - polygon_str = "POLYGON(({0} {1},{0} {3},{2} {3},{2} {1},{0} {1}))".format( - sw_lng, sw_lat, ne_lng, ne_lat - ) - - if query_entities is not None: - markers = ( - db.session.query(AccidentMarker) - .with_entities(*query_entities) - .filter(AccidentMarker.geom.intersects(polygon_str)) - .filter(AccidentMarker.created >= kwargs["start_date"]) - .filter(AccidentMarker.created <= kwargs["end_date"]) - .filter(AccidentMarker.provider_code != BE_CONST.RSA_PROVIDER_CODE) - .order_by(desc(AccidentMarker.created)) + with app.app_context(): + approx = kwargs.get("approx", True) + accurate = kwargs.get("accurate", True) + page = kwargs.get("page") + per_page = kwargs.get("per_page") + + if not kwargs.get("show_markers", True): + with app.app_context(): + return MarkerResult( + accident_markers=db.session.query(AccidentMarker).filter(sql.false()), + rsa_markers=db.session.query(AccidentMarker).filter(sql.false()), + total_records=0, + ) + + sw_lat = float(kwargs["sw_lat"]) + sw_lng = float(kwargs["sw_lng"]) + ne_lat = float(kwargs["ne_lat"]) + ne_lng = float(kwargs["ne_lng"]) + polygon_str = "POLYGON(({0} {1},{0} {3},{2} {3},{2} {1},{0} {1}))".format( + sw_lng, sw_lat, ne_lng, ne_lat ) - rsa_markers = ( - db.session.query(AccidentMarker) - .with_entities(*query_entities) - .filter(AccidentMarker.geom.intersects(polygon_str)) - .filter(AccidentMarker.created >= kwargs["start_date"]) - .filter(AccidentMarker.created <= kwargs["end_date"]) - .filter(AccidentMarker.provider_code == BE_CONST.RSA_PROVIDER_CODE) - .order_by(desc(AccidentMarker.created)) - ) - else: - markers = ( - db.session.query(AccidentMarker) - .filter(AccidentMarker.geom.intersects(polygon_str)) - .filter(AccidentMarker.created >= kwargs["start_date"]) - .filter(AccidentMarker.created <= kwargs["end_date"]) - .filter(AccidentMarker.provider_code != BE_CONST.RSA_PROVIDER_CODE) - .order_by(desc(AccidentMarker.created)) - ) - - rsa_markers = ( - db.session.query(AccidentMarker) - .filter(AccidentMarker.geom.intersects(polygon_str)) - .filter(AccidentMarker.created >= kwargs["start_date"]) - .filter(AccidentMarker.created <= kwargs["end_date"]) - .filter(AccidentMarker.provider_code == BE_CONST.RSA_PROVIDER_CODE) - .order_by(desc(AccidentMarker.created)) - ) - - if not kwargs["show_rsa"]: - rsa_markers = db.session.query(AccidentMarker).filter(sql.false()) - if not kwargs["show_accidents"]: - markers = markers.filter( - and_( - AccidentMarker.provider_code != BE_CONST.CBS_ACCIDENT_TYPE_1_CODE, - AccidentMarker.provider_code != BE_CONST.CBS_ACCIDENT_TYPE_3_CODE, - AccidentMarker.provider_code != BE_CONST.UNITED_HATZALA_CODE, - ) - ) - if yield_per: - markers = markers.yield_per(yield_per) - if accurate and not approx: - markers = markers.filter(AccidentMarker.location_accuracy == 1) - elif approx and not accurate: - markers = markers.filter(AccidentMarker.location_accuracy != 1) - elif not accurate and not approx: - return MarkerResult( - accident_markers=db.session.query(AccidentMarker).filter(sql.false()), - rsa_markers=db.session.query(AccidentMarker).filter(sql.false()), - total_records=0, - ) - if not kwargs.get("show_fatal", True): - markers = markers.filter(AccidentMarker.accident_severity != 1) - if not kwargs.get("show_severe", True): - markers = markers.filter(AccidentMarker.accident_severity != 2) - if not kwargs.get("show_light", True): - markers = markers.filter(AccidentMarker.accident_severity != 3) - if kwargs.get("show_urban", 3) != 3: - if kwargs["show_urban"] == 2: - markers = markers.filter(AccidentMarker.road_type >= 1).filter( - AccidentMarker.road_type <= 2 + if query_entities is not None: + markers = ( + db.session.query(AccidentMarker) + .with_entities(*query_entities) + .filter(AccidentMarker.geom.intersects(polygon_str)) + .filter(AccidentMarker.created >= kwargs["start_date"]) + .filter(AccidentMarker.created <= kwargs["end_date"]) + .filter(AccidentMarker.provider_code != BE_CONST.RSA_PROVIDER_CODE) + .order_by(desc(AccidentMarker.created)) ) - elif kwargs["show_urban"] == 1: - markers = markers.filter(AccidentMarker.road_type >= 3).filter( - AccidentMarker.road_type <= 4 + rsa_markers = ( + db.session.query(AccidentMarker) + .with_entities(*query_entities) + .filter(AccidentMarker.geom.intersects(polygon_str)) + .filter(AccidentMarker.created >= kwargs["start_date"]) + .filter(AccidentMarker.created <= kwargs["end_date"]) + .filter(AccidentMarker.provider_code == BE_CONST.RSA_PROVIDER_CODE) + .order_by(desc(AccidentMarker.created)) ) else: - return MarkerResult( - accident_markers=db.session.query(AccidentMarker).filter(sql.false()), - rsa_markers=rsa_markers, - total_records=None, + markers = ( + db.session.query(AccidentMarker) + .filter(AccidentMarker.geom.intersects(polygon_str)) + .filter(AccidentMarker.created >= kwargs["start_date"]) + .filter(AccidentMarker.created <= kwargs["end_date"]) + .filter(AccidentMarker.provider_code != BE_CONST.RSA_PROVIDER_CODE) + .order_by(desc(AccidentMarker.created)) ) - if kwargs.get("show_intersection", 3) != 3: - if kwargs["show_intersection"] == 2: - markers = markers.filter(AccidentMarker.road_type != 2).filter( - AccidentMarker.road_type != 4 + + rsa_markers = ( + db.session.query(AccidentMarker) + .filter(AccidentMarker.geom.intersects(polygon_str)) + .filter(AccidentMarker.created >= kwargs["start_date"]) + .filter(AccidentMarker.created <= kwargs["end_date"]) + .filter(AccidentMarker.provider_code == BE_CONST.RSA_PROVIDER_CODE) + .order_by(desc(AccidentMarker.created)) ) - elif kwargs["show_intersection"] == 1: - markers = markers.filter(AccidentMarker.road_type != 1).filter( - AccidentMarker.road_type != 3 + + if not kwargs["show_rsa"]: + rsa_markers = db.session.query(AccidentMarker).filter(sql.false()) + if not kwargs["show_accidents"]: + markers = markers.filter( + and_( + AccidentMarker.provider_code != BE_CONST.CBS_ACCIDENT_TYPE_1_CODE, + AccidentMarker.provider_code != BE_CONST.CBS_ACCIDENT_TYPE_3_CODE, + AccidentMarker.provider_code != BE_CONST.UNITED_HATZALA_CODE, + ) ) - else: + if yield_per: + markers = markers.yield_per(yield_per) + if accurate and not approx: + markers = markers.filter(AccidentMarker.location_accuracy == 1) + elif approx and not accurate: + markers = markers.filter(AccidentMarker.location_accuracy != 1) + elif not accurate and not approx: return MarkerResult( accident_markers=db.session.query(AccidentMarker).filter(sql.false()), - rsa_markers=rsa_markers, - total_records=None, + rsa_markers=db.session.query(AccidentMarker).filter(sql.false()), + total_records=0, ) - if kwargs.get("show_lane", 3) != 3: - if kwargs["show_lane"] == 2: - markers = markers.filter(AccidentMarker.one_lane >= 2).filter( - AccidentMarker.one_lane <= 3 + if not kwargs.get("show_fatal", True): + markers = markers.filter(AccidentMarker.accident_severity != 1) + if not kwargs.get("show_severe", True): + markers = markers.filter(AccidentMarker.accident_severity != 2) + if not kwargs.get("show_light", True): + markers = markers.filter(AccidentMarker.accident_severity != 3) + if kwargs.get("show_urban", 3) != 3: + if kwargs["show_urban"] == 2: + markers = markers.filter(AccidentMarker.road_type >= 1).filter( + AccidentMarker.road_type <= 2 + ) + elif kwargs["show_urban"] == 1: + markers = markers.filter(AccidentMarker.road_type >= 3).filter( + AccidentMarker.road_type <= 4 + ) + else: + return MarkerResult( + accident_markers=db.session.query(AccidentMarker).filter(sql.false()), + rsa_markers=rsa_markers, + total_records=None, + ) + if kwargs.get("show_intersection", 3) != 3: + if kwargs["show_intersection"] == 2: + markers = markers.filter(AccidentMarker.road_type != 2).filter( + AccidentMarker.road_type != 4 + ) + elif kwargs["show_intersection"] == 1: + markers = markers.filter(AccidentMarker.road_type != 1).filter( + AccidentMarker.road_type != 3 + ) + else: + return MarkerResult( + accident_markers=db.session.query(AccidentMarker).filter(sql.false()), + rsa_markers=rsa_markers, + total_records=None, + ) + if kwargs.get("show_lane", 3) != 3: + if kwargs["show_lane"] == 2: + markers = markers.filter(AccidentMarker.one_lane >= 2).filter( + AccidentMarker.one_lane <= 3 + ) + elif kwargs["show_lane"] == 1: + markers = markers.filter(AccidentMarker.one_lane == 1) + else: + return MarkerResult( + accident_markers=db.session.query(AccidentMarker).filter(sql.false()), + rsa_markers=rsa_markers, + total_records=None, + ) + + if kwargs.get("show_day", 7) != 7: + markers = markers.filter( + func.extract("dow", AccidentMarker.created) == kwargs["show_day"] ) - elif kwargs["show_lane"] == 1: - markers = markers.filter(AccidentMarker.one_lane == 1) + if kwargs.get("show_holiday", 0) != 0: + markers = markers.filter(AccidentMarker.day_type == kwargs["show_holiday"]) + + if kwargs.get("show_time", 24) != 24: + if kwargs["show_time"] == 25: # Daylight (6-18) + markers = markers.filter(func.extract("hour", AccidentMarker.created) >= 6).filter( + func.extract("hour", AccidentMarker.created) < 18 + ) + elif kwargs["show_time"] == 26: # Darktime (18-6) + markers = markers.filter( + (func.extract("hour", AccidentMarker.created) >= 18) + | (func.extract("hour", AccidentMarker.created) < 6) + ) + else: + markers = markers.filter( + func.extract("hour", AccidentMarker.created) >= kwargs["show_time"] + ).filter(func.extract("hour", AccidentMarker.created) < kwargs["show_time"] + 6) + elif kwargs["start_time"] != 25 and kwargs["end_time"] != 25: + markers = markers.filter( + func.extract("hour", AccidentMarker.created) >= kwargs["start_time"] + ).filter(func.extract("hour", AccidentMarker.created) < kwargs["end_time"]) + if kwargs.get("weather", 0) != 0: + markers = markers.filter(AccidentMarker.weather == kwargs["weather"]) + if kwargs.get("road", 0) != 0: + markers = markers.filter(AccidentMarker.road_shape == kwargs["road"]) + if kwargs.get("separation", 0) != 0: + markers = markers.filter(AccidentMarker.multi_lane == kwargs["separation"]) + if kwargs.get("surface", 0) != 0: + markers = markers.filter(AccidentMarker.road_surface == kwargs["surface"]) + if kwargs.get("acctype", 0) != 0: + if kwargs["acctype"] <= 20: + markers = markers.filter(AccidentMarker.accident_type == kwargs["acctype"]) + elif kwargs["acctype"] == BE_CONST.BIKE_ACCIDENTS: + markers = markers.filter( + AccidentMarker.vehicles.any(Vehicle.vehicle_type == BE_VehicleType.BIKE.value) + ) + if kwargs.get("controlmeasure", 0) != 0: + markers = markers.filter(AccidentMarker.road_control == kwargs["controlmeasure"]) + if kwargs.get("district", 0) != 0: + markers = markers.filter(AccidentMarker.police_unit == kwargs["district"]) + + if kwargs.get("case_type", 0) != 0: + markers = markers.filter(AccidentMarker.provider_code == kwargs["case_type"]) + + if is_thin: + markers = markers.options(load_only("id", "longitude", "latitude")) + + if kwargs.get("age_groups"): + age_groups_list = kwargs.get("age_groups").split(",") + if len(age_groups_list) < (BE_CONST.AGE_GROUPS_NUMBER + 1): + markers = markers.filter( + AccidentMarker.involved.any(Involved.age_group.in_(age_groups_list)) + ) else: - return MarkerResult( - accident_markers=db.session.query(AccidentMarker).filter(sql.false()), - rsa_markers=rsa_markers, - total_records=None, - ) + markers = db.session.query(AccidentMarker).filter(sql.false()) - if kwargs.get("show_day", 7) != 7: - markers = markers.filter( - func.extract("dow", AccidentMarker.created) == kwargs["show_day"] - ) - if kwargs.get("show_holiday", 0) != 0: - markers = markers.filter(AccidentMarker.day_type == kwargs["show_holiday"]) - - if kwargs.get("show_time", 24) != 24: - if kwargs["show_time"] == 25: # Daylight (6-18) - markers = markers.filter(func.extract("hour", AccidentMarker.created) >= 6).filter( - func.extract("hour", AccidentMarker.created) < 18 - ) - elif kwargs["show_time"] == 26: # Darktime (18-6) + if kwargs.get("light_transportation", False): + age_groups_list = kwargs.get("age_groups").split(",") + LOCATION_ACCURACY_PRECISE_LIST = [1, 3, 4] markers = markers.filter( - (func.extract("hour", AccidentMarker.created) >= 18) - | (func.extract("hour", AccidentMarker.created) < 6) + AccidentMarker.location_accuracy.in_(LOCATION_ACCURACY_PRECISE_LIST) ) - else: - markers = markers.filter( - func.extract("hour", AccidentMarker.created) >= kwargs["show_time"] - ).filter(func.extract("hour", AccidentMarker.created) < kwargs["show_time"] + 6) - elif kwargs["start_time"] != 25 and kwargs["end_time"] != 25: - markers = markers.filter( - func.extract("hour", AccidentMarker.created) >= kwargs["start_time"] - ).filter(func.extract("hour", AccidentMarker.created) < kwargs["end_time"]) - if kwargs.get("weather", 0) != 0: - markers = markers.filter(AccidentMarker.weather == kwargs["weather"]) - if kwargs.get("road", 0) != 0: - markers = markers.filter(AccidentMarker.road_shape == kwargs["road"]) - if kwargs.get("separation", 0) != 0: - markers = markers.filter(AccidentMarker.multi_lane == kwargs["separation"]) - if kwargs.get("surface", 0) != 0: - markers = markers.filter(AccidentMarker.road_surface == kwargs["surface"]) - if kwargs.get("acctype", 0) != 0: - if kwargs["acctype"] <= 20: - markers = markers.filter(AccidentMarker.accident_type == kwargs["acctype"]) - elif kwargs["acctype"] == BE_CONST.BIKE_ACCIDENTS: + INJURED_TYPES = [1, 6, 7] markers = markers.filter( - AccidentMarker.vehicles.any(Vehicle.vehicle_type == BE_VehicleType.BIKE.value) + or_( + AccidentMarker.involved.any( + and_( + Involved.injured_type.in_(INJURED_TYPES), + Involved.injury_severity >= 1, + Involved.injury_severity <= 3, + Involved.age_group.in_(age_groups_list), + ) + ), + AccidentMarker.involved.any( + and_( + Involved.vehicle_type == 15, + Involved.injury_severity >= 1, + Involved.injury_severity <= 3, + Involved.age_group.in_(age_groups_list), + ) + ), + AccidentMarker.involved.any( + and_( + Involved.vehicle_type == 21, + Involved.injury_severity >= 1, + Involved.injury_severity <= 3, + Involved.age_group.in_(age_groups_list), + ) + ), + AccidentMarker.involved.any( + and_( + Involved.vehicle_type == 23, + Involved.injury_severity >= 1, + Involved.injury_severity <= 3, + Involved.age_group.in_(age_groups_list), + ) + ), + ) ) - if kwargs.get("controlmeasure", 0) != 0: - markers = markers.filter(AccidentMarker.road_control == kwargs["controlmeasure"]) - if kwargs.get("district", 0) != 0: - markers = markers.filter(AccidentMarker.police_unit == kwargs["district"]) - - if kwargs.get("case_type", 0) != 0: - markers = markers.filter(AccidentMarker.provider_code == kwargs["case_type"]) - - if is_thin: - markers = markers.options(load_only("id", "longitude", "latitude")) - if kwargs.get("age_groups"): - age_groups_list = kwargs.get("age_groups").split(",") - if len(age_groups_list) < (BE_CONST.AGE_GROUPS_NUMBER + 1): - markers = markers.filter( - AccidentMarker.involved.any(Involved.age_group.in_(age_groups_list)) + if page and per_page: + markers = markers.offset((page - 1) * per_page).limit(per_page) + + if involved_and_vehicles: + fetch_markers = kwargs.get("fetch_markers", True) + fetch_vehicles = kwargs.get("fetch_vehicles", True) + fetch_involved = kwargs.get("fetch_involved", True) + markers_ids = [marker.id for marker in markers] + markers = None + vehicles = None + involved = None + if fetch_markers: + markers = db.session.query(AccidentMarker).filter( + AccidentMarker.id.in_(markers_ids) + ) + if fetch_vehicles: + vehicles = db.session.query(Vehicle).filter(Vehicle.accident_id.in_(markers_ids)) + if fetch_involved: + involved = db.session.query(Involved).filter(Involved.accident_id.in_(markers_ids)) + result = ( + markers.all() if markers is not None else [], + vehicles.all() if vehicles is not None else [], + involved.all() if involved is not None else [], ) - else: - markers = db.session.query(AccidentMarker).filter(sql.false()) - - if kwargs.get("light_transportation", False): - age_groups_list = kwargs.get("age_groups").split(",") - LOCATION_ACCURACY_PRECISE_LIST = [1, 3, 4] - markers = markers.filter( - AccidentMarker.location_accuracy.in_(LOCATION_ACCURACY_PRECISE_LIST) - ) - INJURED_TYPES = [1, 6, 7] - markers = markers.filter( - or_( - AccidentMarker.involved.any( - and_( - Involved.injured_type.in_(INJURED_TYPES), - Involved.injury_severity >= 1, - Involved.injury_severity <= 3, - Involved.age_group.in_(age_groups_list), - ) - ), - AccidentMarker.involved.any( - and_( - Involved.vehicle_type == 15, - Involved.injury_severity >= 1, - Involved.injury_severity <= 3, - Involved.age_group.in_(age_groups_list), - ) - ), - AccidentMarker.involved.any( - and_( - Involved.vehicle_type == 21, - Involved.injury_severity >= 1, - Involved.injury_severity <= 3, - Involved.age_group.in_(age_groups_list), - ) - ), - AccidentMarker.involved.any( - and_( - Involved.vehicle_type == 23, - Involved.injury_severity >= 1, - Involved.injury_severity <= 3, - Involved.age_group.in_(age_groups_list), - ) - ), + return MarkerResult( + accident_markers=result, + rsa_markers=db.session.query(AccidentMarker).filter(sql.false()), + total_records=len(result), ) - ) - - if page and per_page: - markers = markers.offset((page - 1) * per_page).limit(per_page) - - if involved_and_vehicles: - fetch_markers = kwargs.get("fetch_markers", True) - fetch_vehicles = kwargs.get("fetch_vehicles", True) - fetch_involved = kwargs.get("fetch_involved", True) - markers_ids = [marker.id for marker in markers] - markers = None - vehicles = None - involved = None - if fetch_markers: - markers = db.session.query(AccidentMarker).filter( - AccidentMarker.id.in_(markers_ids) + else: + return MarkerResult( + accident_markers=markers, rsa_markers=rsa_markers, total_records=None ) - if fetch_vehicles: - vehicles = db.session.query(Vehicle).filter(Vehicle.accident_id.in_(markers_ids)) - if fetch_involved: - involved = db.session.query(Involved).filter(Involved.accident_id.in_(markers_ids)) - result = ( - markers.all() if markers is not None else [], - vehicles.all() if vehicles is not None else [], - involved.all() if involved is not None else [], - ) - return MarkerResult( - accident_markers=result, - rsa_markers=db.session.query(AccidentMarker).filter(sql.false()), - total_records=len(result), - ) - else: - return MarkerResult( - accident_markers=markers, rsa_markers=rsa_markers, total_records=None - ) @staticmethod def get_marker(marker_id): - return db.session.query(AccidentMarker).filter_by(id=marker_id) + with app.app_context(): + return db.session.query(AccidentMarker).filter_by(id=marker_id) @classmethod def parse(cls, data): @@ -731,7 +734,8 @@ def serialize(self, is_thin=False): @staticmethod def get_by_identifier(identifier): - return db.session.query(DiscussionMarker).filter_by(identifier=identifier) + with app.app_context(): + return db.session.query(DiscussionMarker).filter_by(identifier=identifier) @classmethod def parse(cls, data): @@ -749,17 +753,18 @@ def parse(cls, data): @staticmethod def bounding_box_query(ne_lat, ne_lng, sw_lat, sw_lng, show_discussions): - if not show_discussions: - return db.session.query(AccidentMarker).filter(sql.false()) - markers = ( - db.session.query(DiscussionMarker) - .filter(DiscussionMarker.longitude <= ne_lng) - .filter(DiscussionMarker.longitude >= sw_lng) - .filter(DiscussionMarker.latitude <= ne_lat) - .filter(DiscussionMarker.latitude >= sw_lat) - .order_by(desc(DiscussionMarker.created)) - ) - return markers + with app.app_context(): + if not show_discussions: + return db.session.query(AccidentMarker).filter(sql.false()) + markers = ( + db.session.query(DiscussionMarker) + .filter(DiscussionMarker.longitude <= ne_lng) + .filter(DiscussionMarker.longitude >= sw_lng) + .filter(DiscussionMarker.latitude <= ne_lat) + .filter(DiscussionMarker.latitude >= sw_lat) + .order_by(desc(DiscussionMarker.created)) + ) + return markers class Involved(Base): @@ -1044,27 +1049,30 @@ class City(CityFields, Base): @staticmethod def get_name_from_symbol(symbol: int) -> str: - res = db.session.query(City.heb_name).filter(City.yishuv_symbol == symbol).first() - if res is None: - raise ValueError(f"{symbol}: could not find city with that symbol") - return res.heb_name + with app.app_context(): + res = db.session.query(City.heb_name).filter(City.yishuv_symbol == symbol).first() + if res is None: + raise ValueError(f"{symbol}: could not find city with that symbol") + return res.heb_name @staticmethod def get_symbol_from_name(name: str) -> int: - res: City = db.session.query(City.yishuv_symbol).filter(City.heb_name == name).first() - if res is None: - logging.error(f"City: no city with name:{name}.") - raise ValueError(f"City: no city with name:{name}.") - return res.yishuv_symbol + with app.app_context(): + res: City = db.session.query(City.yishuv_symbol).filter(City.heb_name == name).first() + if res is None: + logging.error(f"City: no city with name:{name}.") + raise ValueError(f"City: no city with name:{name}.") + return res.yishuv_symbol @staticmethod def get_all_cities() -> List[dict]: - res: City = db.session.query(City.yishuv_symbol, City.heb_name).all() - if res is None: - logging.error(f"Failed to get cities.") - raise RuntimeError(f"When retrieving all cities") - res1 = [{"yishuv_symbol": c.yishuv_symbol, "yishuv_name": c.heb_name} for c in res] - return res1 + with app.app_context(): + res: City = db.session.query(City.yishuv_symbol, City.heb_name).all() + if res is None: + logging.error(f"Failed to get cities.") + raise RuntimeError(f"When retrieving all cities") + res1 = [{"yishuv_symbol": c.yishuv_symbol, "yishuv_name": c.heb_name} for c in res] + return res1 class CityTemp(CityFields, Base): @@ -1092,35 +1100,38 @@ def serialize(self): @staticmethod def get_name_from_symbol(symbol: int) -> str: - res = ( - db.session.query(DeprecatedCity.search_heb) - .filter(DeprecatedCity.symbol_code == symbol) - .first() - ) - if res is None: - raise ValueError(f"{symbol}: could not find city with that symbol") - return res.search_heb + with app.app_context(): + res = ( + db.session.query(DeprecatedCity.search_heb) + .filter(DeprecatedCity.symbol_code == symbol) + .first() + ) + if res is None: + raise ValueError(f"{symbol}: could not find city with that symbol") + return res.search_heb @staticmethod def get_symbol_from_name(name: str) -> int: - res = ( - db.session.query(DeprecatedCity.symbol_code) - .filter(DeprecatedCity.search_heb == name) - .first() - ) - if res is None: - logging.error(f"DeprecatedCity: no city with name:{name}.") - raise ValueError(f"DeprecatedCity: no city with name:{name}.") - return res.symbol_code + with app.app_context(): + res = ( + db.session.query(DeprecatedCity.symbol_code) + .filter(DeprecatedCity.search_heb == name) + .first() + ) + if res is None: + logging.error(f"DeprecatedCity: no city with name:{name}.") + raise ValueError(f"DeprecatedCity: no city with name:{name}.") + return res.symbol_code @staticmethod def get_all_cities() -> List[dict]: - res = db.session.query(DeprecatedCity.symbol_code, DeprecatedCity.search_heb).all() - if res is None: - logging.error(f"Failed to get cities.") - raise RuntimeError(f"When retrieving all cities") - res1 = [{"yishuv_symbol": c.symbol_code, "yishuv_name": c.search_heb} for c in res] - return res1 + with app.app_context(): + res = db.session.query(DeprecatedCity.symbol_code, DeprecatedCity.search_heb).all() + if res is None: + logging.error(f"Failed to get cities.") + raise RuntimeError(f"When retrieving all cities") + res1 = [{"yishuv_symbol": c.symbol_code, "yishuv_name": c.search_heb} for c in res] + return res1 # Flask-Login integration def is_authenticated(self): @@ -1152,39 +1163,42 @@ def serialize(self): @staticmethod def get_street_name_by_street(yishuv_symbol: int, street: int) -> str: - res = ( - db.session.query(Streets.street_hebrew) - .filter(Streets.yishuv_symbol == yishuv_symbol) - .filter(Streets.street == street) - .first() - ) - if res is None: - raise ValueError(f"{street}: could not find street in yishuv:{yishuv_symbol}") - return res.street_hebrew + with app.app_context(): + res = ( + db.session.query(Streets.street_hebrew) + .filter(Streets.yishuv_symbol == yishuv_symbol) + .filter(Streets.street == street) + .first() + ) + if res is None: + raise ValueError(f"{street}: could not find street in yishuv:{yishuv_symbol}") + return res.street_hebrew @staticmethod def get_street_by_street_name(yishuv_symbol: int, name: str) -> int: - res = ( - db.session.query(Streets.street) - .filter(Streets.yishuv_symbol == yishuv_symbol) - .filter(Streets.street_hebrew == name) - .first() - ) - if res is None: - raise ValueError(f"{name}: could not find street in yishuv:{yishuv_symbol}") - return res.street + with app.app_context(): + res = ( + db.session.query(Streets.street) + .filter(Streets.yishuv_symbol == yishuv_symbol) + .filter(Streets.street_hebrew == name) + .first() + ) + if res is None: + raise ValueError(f"{name}: could not find street in yishuv:{yishuv_symbol}") + return res.street @staticmethod def get_streets_by_yishuv(yishuv_symbol: int) -> List[dict]: - res = ( - db.session.query(Streets.street, Streets.street_hebrew) - .filter(Streets.yishuv_symbol == yishuv_symbol) - .all() - ) - res1 = [{"street": s.street, "street_hebrew": s.street_hebrew} for s in res] - if res is None: - raise RuntimeError(f"When retrieving streets of {yishuv_symbol}") - return res1 + with app.app_context(): + res = ( + db.session.query(Streets.street, Streets.street_hebrew) + .filter(Streets.yishuv_symbol == yishuv_symbol) + .all() + ) + res1 = [{"street": s.street, "street_hebrew": s.street_hebrew} for s in res] + if res is None: + raise RuntimeError(f"When retrieving streets of {yishuv_symbol}") + return res1 class SuburbanJunction(Base): @@ -1197,42 +1211,46 @@ class SuburbanJunction(Base): @staticmethod def get_hebrew_name_from_id(non_urban_intersection: int) -> str: - res = db.session.query(SuburbanJunction.non_urban_intersection_hebrew).filter( - SuburbanJunction.non_urban_intersection == non_urban_intersection).first() - if res is None: - raise ValueError(f"{non_urban_intersection}: could not find " - f"SuburbanJunction with that symbol") - return res.non_urban_intersection_hebrew + with app.app_context(): + res = db.session.query(SuburbanJunction.non_urban_intersection_hebrew).filter( + SuburbanJunction.non_urban_intersection == non_urban_intersection).first() + if res is None: + raise ValueError(f"{non_urban_intersection}: could not find " + f"SuburbanJunction with that symbol") + return res.non_urban_intersection_hebrew @staticmethod def get_id_from_hebrew_name(non_urban_intersection_hebrew: str) -> int: - res = db.session.query(SuburbanJunction.non_urban_intersection).filter( - SuburbanJunction.non_urban_intersection == non_urban_intersection_hebrew).first() - if res is None: - raise ValueError(f"{non_urban_intersection_hebrew}: could not find " - f"SuburbanJunction with that name") - return res.non_urban_intersection + with app.app_context(): + res = db.session.query(SuburbanJunction.non_urban_intersection).filter( + SuburbanJunction.non_urban_intersection == non_urban_intersection_hebrew).first() + if res is None: + raise ValueError(f"{non_urban_intersection_hebrew}: could not find " + f"SuburbanJunction with that name") + return res.non_urban_intersection @staticmethod def get_intersection_from_roads(roads: Set[int]) -> dict: if not all([isinstance(x, int) for x in roads]): raise ValueError(f"{roads}: Should be integers") - res = db.session.query(SuburbanJunction).filter( - SuburbanJunction.roads.contains(roads)).first() - if res is None: - raise ValueError(f"{roads}: could not find " - f"SuburbanJunction with these roads") - return res.serialize() + with app.app_context(): + res = db.session.query(SuburbanJunction).filter( + SuburbanJunction.roads.contains(roads)).first() + if res is None: + raise ValueError(f"{roads}: could not find " + f"SuburbanJunction with these roads") + return res.serialize() @staticmethod def get_all_from_key_value(key: str, val: Iterable) -> dict: if not isinstance(val, Iterable): val = [val] - res = db.session.query(SuburbanJunction).filter( - (getattr(SuburbanJunction, key)).in_(val)).first() - if res is None: - raise ValueError(f"{key}:{val}: could not find SuburbanJunction") - return res.serialize() + with app.app_context(): + res = db.session.query(SuburbanJunction).filter( + (getattr(SuburbanJunction, key)).in_(val)).first() + if res is None: + raise ValueError(f"{key}:{val}: could not find SuburbanJunction") + return res.serialize() def serialize(self): return { diff --git a/anyway/parsers/casualties_costs.py b/anyway/parsers/casualties_costs.py index 70d885371..7936c419f 100755 --- a/anyway/parsers/casualties_costs.py +++ b/anyway/parsers/casualties_costs.py @@ -2,7 +2,7 @@ from anyway.models import CasualtiesCosts import pandas as pd import logging -from anyway.app_and_db import db +from anyway.app_and_db import db, app def _iter_rows(filename): @@ -21,14 +21,15 @@ def _iter_rows(filename): def parse(filename): - for row in _iter_rows(filename): - current_report = ( - db.session.query(CasualtiesCosts).filter(CasualtiesCosts.id == row["id"]).all() - ) - if not current_report: - logging.debug(f"adding line {row}") - db.session.bulk_insert_mappings(CasualtiesCosts, [row]) - else: - logging.debug(f"updating line {row}") - db.session.bulk_update_mappings(CasualtiesCosts, [row]) - db.session.commit() + with app.app_context(): + for row in _iter_rows(filename): + current_report = ( + db.session.query(CasualtiesCosts).filter(CasualtiesCosts.id == row["id"]).all() + ) + if not current_report: + logging.debug(f"adding line {row}") + db.session.bulk_insert_mappings(CasualtiesCosts, [row]) + else: + logging.debug(f"updating line {row}") + db.session.bulk_update_mappings(CasualtiesCosts, [row]) + db.session.commit() diff --git a/anyway/parsers/cbs/executor.py b/anyway/parsers/cbs/executor.py index edc08eed8..38e469d65 100644 --- a/anyway/parsers/cbs/executor.py +++ b/anyway/parsers/cbs/executor.py @@ -11,6 +11,7 @@ import math import pandas as pd from sqlalchemy import or_, event +import sqlalchemy as sa from typing import Tuple, Dict, List, Any from anyway.parsers.cbs import preprocessing_cbs_files @@ -88,7 +89,7 @@ from anyway.utilities import ItmToWGS84, time_delta, ImporterUI, truncate_tables, delete_all_rows_from_table, \ chunks, run_query_and_insert_to_table_in_chunks from anyway.db_views import VIEWS -from anyway.app_and_db import db +from anyway.app_and_db import db, app from anyway.parsers.cbs.s3 import S3DataRetriever street_map_type: Dict[int, List[dict]] @@ -571,8 +572,9 @@ def import_accidents(provider_code, accidents, streets, roads, non_urban_interse marker = create_marker(provider_code, accident, streets, roads, non_urban_intersection) add_suburban_junction_from_marker(marker) accidents_result.append(marker) - db.session.bulk_insert_mappings(AccidentMarker, accidents_result) - db.session.commit() + with app.app_context(): + db.session.bulk_insert_mappings(AccidentMarker, accidents_result) + db.session.commit() logging.debug("Finished Importing markers") logging.debug("Inserted " + str(len(accidents_result)) + " new accident markers") fill_db_geo_data() @@ -630,8 +632,9 @@ def import_involved(provider_code, involved, **kwargs): "accident_month": get_data_value(involve.get(field_names.accident_month)), } ) - db.session.bulk_insert_mappings(Involved, involved_result) - db.session.commit() + with app.app_context(): + db.session.bulk_insert_mappings(Involved, involved_result) + db.session.commit() logging.debug("Finished Importing involved") return len(involved_result) @@ -663,8 +666,9 @@ def import_vehicles(provider_code, vehicles, **kwargs): "vehicle_damage": get_data_value(vehicle.get(field_names.vehicle_damage)), } ) - db.session.bulk_insert_mappings(Vehicle, vehicles_result) - db.session.commit() + with app.app_context(): + db.session.bulk_insert_mappings(Vehicle, vehicles_result) + db.session.commit() logging.debug("Finished Importing vehicles") return len(vehicles_result) @@ -779,9 +783,10 @@ def import_streets_into_db(): logging.debug( f"Writing to db: {len(yishuv_street_dict)}:{len(yishuv_name_dict)} -> {len(items)} rows" ) - db.session.query(Streets).delete() - db.session.bulk_insert_mappings(Streets, items) - db.session.commit() + with app.app_context(): + db.session.query(Streets).delete() + db.session.bulk_insert_mappings(Streets, items) + db.session.commit() if max_name_len > Streets.MAX_NAME_LEN: logging.error( f"Importing streets table: Street hebrew name length exceeded: max name: {max_name_len}" @@ -798,15 +803,16 @@ def import_streets_into_db(): def load_existing_streets(): - streets = db.session.query(Streets).all() - for s in streets: - s_dict = { - "yishuv_symbol": s.yishuv_symbol, - "street": s.street, - "street_hebrew": s.street_hebrew, - } - add_street_remove_name_duplicates(s_dict) - add_street_remove_num_duplicates(s_dict) + with app.app_context(): + streets = db.session.query(Streets).all() + for s in streets: + s_dict = { + "yishuv_symbol": s.yishuv_symbol, + "street": s.street, + "street_hebrew": s.street_hebrew, + } + add_street_remove_name_duplicates(s_dict) + add_street_remove_num_duplicates(s_dict) logging.debug(f"Loaded streets: {len(yishuv_street_dict)}:{len(yishuv_name_dict)}") @@ -850,9 +856,10 @@ def import_suburban_junctions_into_db(): logging.debug( f"Writing to db: {len(items)} suburban junctions" ) - db.session.query(SuburbanJunction).delete() - db.session.bulk_insert_mappings(SuburbanJunction, items) - db.session.commit() + with app.app_context(): + db.session.query(SuburbanJunction).delete() + db.session.bulk_insert_mappings(SuburbanJunction, items) + db.session.commit() logging.debug(f"Done.") @@ -865,9 +872,10 @@ def fix_name_len(name: str) -> str: return name[: SuburbanJunction.MAX_NAME_LEN] def load_existing_suburban_junctions(): - junctions: List[SuburbanJunction] = db.session.query(SuburbanJunction).all() - for j in junctions: - add_suburban_junction(j) + with app.app_context(): + junctions: List[SuburbanJunction] = db.session.query(SuburbanJunction).all() + for j in junctions: + add_suburban_junction(j) logging.debug(f"Loaded suburban junctions: {len(suburban_junctions_dict)}.") @@ -908,12 +916,12 @@ def delete_invalid_entries(batch_size): deletes all markers in the database with null latitude or longitude first deletes from tables Involved and Vehicle, then from table AccidentMarker """ - - marker_ids_to_delete = ( - db.session.query(AccidentMarker.id) - .filter(or_((AccidentMarker.longitude == None), (AccidentMarker.latitude == None))) - .all() - ) + with app.app_context(): + marker_ids_to_delete = ( + db.session.query(AccidentMarker.id) + .filter(or_((AccidentMarker.longitude == None), (AccidentMarker.latitude == None))) + .all() + ) marker_ids_to_delete = [acc_id[0] for acc_id in marker_ids_to_delete] @@ -922,24 +930,24 @@ def delete_invalid_entries(batch_size): for ids_chunk in chunks(marker_ids_to_delete, batch_size): logging.debug("Deleting a chunk of " + str(len(ids_chunk))) - - q = db.session.query(Involved).filter(Involved.accident_id.in_(ids_chunk)) - if q.all(): - logging.debug("deleting invalid entries from Involved") - q.delete(synchronize_session="fetch") - db.session.commit() - - q = db.session.query(Vehicle).filter(Vehicle.accident_id.in_(ids_chunk)) - if q.all(): - logging.debug("deleting invalid entries from Vehicle") - q.delete(synchronize_session="fetch") - db.session.commit() - - q = db.session.query(AccidentMarker).filter(AccidentMarker.id.in_(ids_chunk)) - if q.all(): - logging.debug("deleting invalid entries from AccidentMarker") - q.delete(synchronize_session="fetch") - db.session.commit() + with app.app_context(): + q = db.session.query(Involved).filter(Involved.accident_id.in_(ids_chunk)) + if q.all(): + logging.debug("deleting invalid entries from Involved") + q.delete(synchronize_session="fetch") + db.session.commit() + + q = db.session.query(Vehicle).filter(Vehicle.accident_id.in_(ids_chunk)) + if q.all(): + logging.debug("deleting invalid entries from Vehicle") + q.delete(synchronize_session="fetch") + db.session.commit() + + q = db.session.query(AccidentMarker).filter(AccidentMarker.id.in_(ids_chunk)) + if q.all(): + logging.debug("deleting invalid entries from AccidentMarker") + q.delete(synchronize_session="fetch") + db.session.commit() def delete_cbs_entries(start_year, batch_size): @@ -948,17 +956,18 @@ def delete_cbs_entries(start_year, batch_size): first deletes from tables Involved and Vehicle, then from table AccidentMarker """ start_date = f"{start_year}-01-01" - marker_ids_to_delete = ( - db.session.query(AccidentMarker.id) - .filter(AccidentMarker.created >= datetime.strptime(start_date, "%Y-%m-%d")) - .filter( - or_( - (AccidentMarker.provider_code == BE_CONST.CBS_ACCIDENT_TYPE_1_CODE), - (AccidentMarker.provider_code == BE_CONST.CBS_ACCIDENT_TYPE_3_CODE), + with app.app_context(): + marker_ids_to_delete = ( + db.session.query(AccidentMarker.id) + .filter(AccidentMarker.created >= datetime.strptime(start_date, "%Y-%m-%d")) + .filter( + or_( + (AccidentMarker.provider_code == BE_CONST.CBS_ACCIDENT_TYPE_1_CODE), + (AccidentMarker.provider_code == BE_CONST.CBS_ACCIDENT_TYPE_3_CODE), + ) ) + .all() ) - .all() - ) marker_ids_to_delete = [acc_id[0] for acc_id in marker_ids_to_delete] @@ -972,24 +981,24 @@ def delete_cbs_entries(start_year, batch_size): for ids_chunk in chunks(marker_ids_to_delete, batch_size): logging.debug("Deleting a chunk of " + str(len(ids_chunk))) - - q = db.session.query(Involved).filter(Involved.accident_id.in_(ids_chunk)) - if q.all(): - logging.debug("deleting entries from Involved") - q.delete(synchronize_session=False) - db.session.commit() - - q = db.session.query(Vehicle).filter(Vehicle.accident_id.in_(ids_chunk)) - if q.all(): - logging.debug("deleting entries from Vehicle") - q.delete(synchronize_session=False) - db.session.commit() - - q = db.session.query(AccidentMarker).filter(AccidentMarker.id.in_(ids_chunk)) - if q.all(): - logging.debug("deleting entries from AccidentMarker") - q.delete(synchronize_session=False) - db.session.commit() + with app.app_context(): + q = db.session.query(Involved).filter(Involved.accident_id.in_(ids_chunk)) + if q.all(): + logging.debug("deleting entries from Involved") + q.delete(synchronize_session=False) + db.session.commit() + + q = db.session.query(Vehicle).filter(Vehicle.accident_id.in_(ids_chunk)) + if q.all(): + logging.debug("deleting entries from Vehicle") + q.delete(synchronize_session=False) + db.session.commit() + + q = db.session.query(AccidentMarker).filter(AccidentMarker.id.in_(ids_chunk)) + if q.all(): + logging.debug("deleting entries from AccidentMarker") + q.delete(synchronize_session=False) + db.session.commit() def fill_db_geo_data(): @@ -997,11 +1006,12 @@ def fill_db_geo_data(): Fills empty geometry object according to coordinates in database SRID = 4326 """ - db.session.execute( - "UPDATE markers SET geom = ST_SetSRID(ST_MakePoint(longitude,latitude),4326)\ - WHERE geom IS NULL;" - ) - db.session.commit() + with app.app_context(): + db.session.execute(sa.text( + "UPDATE markers SET geom = ST_SetSRID(ST_MakePoint(longitude,latitude),4326)\ + WHERE geom IS NULL;" + )) + db.session.commit() def get_provider_code(directory_name=None): @@ -1041,7 +1051,7 @@ def fill_dictionary_tables(cbs_dictionary, provider_code, year): for inner_k, inner_v in v.items(): if inner_v is None or (isinstance(inner_v, float) and math.isnan(inner_v)): continue - sql_delete = ( + sql_delete = (sa.text( "DELETE FROM " + curr_table + " WHERE provider_code=" @@ -1050,27 +1060,28 @@ def fill_dictionary_tables(cbs_dictionary, provider_code, year): + str(year) + " AND id=" + str(inner_k) - ) - db.session.execute(sql_delete) - db.session.commit() - sql_insert = ( - "INSERT INTO " - + curr_table - + " VALUES (" - + str(inner_k) - + "," - + str(year) - + "," - + str(provider_code) - + "," - + "'" - + inner_v.replace("'", "") - + "'" - + ")" - + " ON CONFLICT DO NOTHING" - ) - db.session.execute(sql_insert) - db.session.commit() + )) + with app.app_context(): + db.session.execute(sql_delete) + db.session.commit() + sql_insert = (sa.text( + "INSERT INTO " + + curr_table + + " VALUES (" + + str(inner_k) + + "," + + str(year) + + "," + + str(provider_code) + + "," + + "'" + + inner_v.replace("'", "") + + "'" + + ")" + + " ON CONFLICT DO NOTHING" + )) + db.session.execute(sql_insert) + db.session.commit() logging.debug("Inserted/Updated dictionary values into table " + curr_table) create_provider_code_table() @@ -1081,17 +1092,19 @@ def truncate_dictionary_tables(dictionary_file): if k == 97: continue curr_table = TABLES_DICT[k] - sql_truncate = "TRUNCATE TABLE " + curr_table - db.session.execute(sql_truncate) - db.session.commit() + sql_truncate = sa.text("TRUNCATE TABLE " + curr_table) + with app.app_context(): + db.session.execute(sql_truncate) + db.session.commit() logging.debug("Truncated table " + curr_table) def create_provider_code_table(): provider_code_table = "provider_code" provider_code_class = ProviderCode - table_entries = db.session.query(provider_code_class) - table_entries.delete() + with app.app_context(): + table_entries = db.session.query(provider_code_class) + table_entries.delete() provider_code_dict = { 1: "הלשכה המרכזית לסטטיסטיקה - סוג תיק 1", 2: "איחוד הצלה", @@ -1099,11 +1112,12 @@ def create_provider_code_table(): 4: "שומרי הדרך", } for k, v in provider_code_dict.items(): - sql_insert = ( + sql_insert = (sa.text( "INSERT INTO " + provider_code_table + " VALUES (" + str(k) + "," + "'" + v + "'" + ")" - ) - db.session.execute(sql_insert) - db.session.commit() + )) + with app.app_context(): + db.session.execute(sql_insert) + db.session.commit() def receive_rollback(conn, **kwargs): @@ -1115,33 +1129,34 @@ def receive_rollback(conn, **kwargs): def create_tables(): chunk_size = 5000 try: - with db.get_engine().begin() as conn: - event.listen(conn, "rollback", receive_rollback) - delete_all_rows_from_table(conn, AccidentMarkerView) - run_query_and_insert_to_table_in_chunks(VIEWS.create_markers_hebrew_view(), AccidentMarkerView, - AccidentMarker.id, chunk_size, conn) - logging.debug("after insertion to markers_hebrew ") - - delete_all_rows_from_table(conn, InvolvedView) - run_query_and_insert_to_table_in_chunks(VIEWS.create_involved_hebrew_view(), InvolvedView, - Involved.id, chunk_size, conn) - logging.debug("after insertion to involved_hebrew ") - - delete_all_rows_from_table(conn, VehiclesView) - run_query_and_insert_to_table_in_chunks(VIEWS.create_vehicles_hebrew_view(), - VehiclesView, Vehicle.id, chunk_size, conn) - logging.debug("after insertion to vehicles_hebrew ") - - delete_all_rows_from_table(conn, VehicleMarkerView) - run_query_and_insert_to_table_in_chunks(VIEWS.create_vehicles_markers_hebrew_view(), - VehicleMarkerView, VehiclesView.id, chunk_size, conn) - logging.debug("after insertion to vehicles_markers_hebrew ") - - delete_all_rows_from_table(conn, InvolvedMarkerView) - run_query_and_insert_to_table_in_chunks(VIEWS.create_involved_hebrew_markers_hebrew_view(), - InvolvedMarkerView, InvolvedView.accident_id, chunk_size, conn) - logging.debug("after insertion to involved_markers_hebrew") - logging.debug("Created DB Hebrew Tables") + with app.app_context(): + with db.get_engine().begin() as conn: + event.listen(conn, "rollback", receive_rollback) + delete_all_rows_from_table(conn, AccidentMarkerView) + run_query_and_insert_to_table_in_chunks(VIEWS.create_markers_hebrew_view(), AccidentMarkerView, + AccidentMarker.id, chunk_size, conn) + logging.debug("after insertion to markers_hebrew ") + + delete_all_rows_from_table(conn, InvolvedView) + run_query_and_insert_to_table_in_chunks(VIEWS.create_involved_hebrew_view(), InvolvedView, + Involved.id, chunk_size, conn) + logging.debug("after insertion to involved_hebrew ") + + delete_all_rows_from_table(conn, VehiclesView) + run_query_and_insert_to_table_in_chunks(VIEWS.create_vehicles_hebrew_view(), + VehiclesView, Vehicle.id, chunk_size, conn) + logging.debug("after insertion to vehicles_hebrew ") + + delete_all_rows_from_table(conn, VehicleMarkerView) + run_query_and_insert_to_table_in_chunks(VIEWS.create_vehicles_markers_hebrew_view(), + VehicleMarkerView, VehiclesView.id, chunk_size, conn) + logging.debug("after insertion to vehicles_markers_hebrew ") + + delete_all_rows_from_table(conn, InvolvedMarkerView) + run_query_and_insert_to_table_in_chunks(VIEWS.create_involved_hebrew_markers_hebrew_view(), + InvolvedMarkerView, InvolvedView.accident_id, chunk_size, conn) + logging.debug("after insertion to involved_markers_hebrew") + logging.debug("Created DB Hebrew Tables") except Exception as e: logging.exception(f"Exception while creating hebrew tables, {e}", e) raise e diff --git a/anyway/parsers/cbs/preprocessing_cbs_files.py b/anyway/parsers/cbs/preprocessing_cbs_files.py index 98728adf5..758294b6e 100644 --- a/anyway/parsers/cbs/preprocessing_cbs_files.py +++ b/anyway/parsers/cbs/preprocessing_cbs_files.py @@ -1,8 +1,9 @@ import os import logging import pandas as pd -from anyway.app_and_db import db +from anyway.app_and_db import db, app from anyway.models import City, CityTemp +import sqlalchemy as sa CBS_FILES_HEBREW = { "sadot": "Fields", @@ -88,15 +89,16 @@ def load_cities_data(file_name: str): ) cities_list = cities.to_dict(orient="records") logging.info(f"Read {len(cities_list)} from {file_name}") - db.session.commit() - db.session.execute("DROP table IF EXISTS cbs_cities_temp") - db.session.execute("CREATE TABLE cbs_cities_temp AS TABLE cbs_cities with NO DATA") - db.session.execute(CityTemp.__table__.insert(), cities_list) - db.session.execute("TRUNCATE table cbs_cities") - db.session.execute("INSERT INTO cbs_cities SELECT * FROM cbs_cities_temp") - db.session.execute("DROP table cbs_cities_temp") - db.session.commit() - num_items = db.session.query(City).count() + with app.app_context(): + db.session.commit() + db.session.execute(sa.text("DROP table IF EXISTS cbs_cities_temp")) + db.session.execute(sa.text("CREATE TABLE cbs_cities_temp AS TABLE cbs_cities with NO DATA")) + db.session.execute(CityTemp.__table__.insert(), cities_list) + db.session.execute(sa.text("TRUNCATE table cbs_cities")) + db.session.execute(sa.text("INSERT INTO cbs_cities SELECT * FROM cbs_cities_temp")) + db.session.execute(sa.text("DROP table cbs_cities_temp")) + db.session.commit() + num_items = db.session.query(City).count() logging.info(f"num items in cities: {num_items}.") diff --git a/anyway/parsers/embedded_reports.py b/anyway/parsers/embedded_reports.py index 30d79ba29..b1b356874 100644 --- a/anyway/parsers/embedded_reports.py +++ b/anyway/parsers/embedded_reports.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- from anyway.models import EmbeddedReports import pandas as pd -from anyway.app_and_db import db +from anyway.app_and_db import db, app def _iter_rows(filename): @@ -18,14 +18,15 @@ def _iter_rows(filename): def parse(filename): - for row in _iter_rows(filename): - current_report = ( - db.session.query(EmbeddedReports) - .filter(EmbeddedReports.report_name_english == row["report_name_english"]) - .all() - ) - if not current_report: - db.session.bulk_insert_mappings(EmbeddedReports, [row]) - else: - db.session.bulk_update_mappings(EmbeddedReports, [row]) - db.session.commit() + with app.app_context(): + for row in _iter_rows(filename): + current_report = ( + db.session.query(EmbeddedReports) + .filter(EmbeddedReports.report_name_english == row["report_name_english"]) + .all() + ) + if not current_report: + db.session.bulk_insert_mappings(EmbeddedReports, [row]) + else: + db.session.bulk_update_mappings(EmbeddedReports, [row]) + db.session.commit() diff --git a/anyway/parsers/infographics_data_cache_updater.py b/anyway/parsers/infographics_data_cache_updater.py index 9e9f96371..bc4c0fb8d 100755 --- a/anyway/parsers/infographics_data_cache_updater.py +++ b/anyway/parsers/infographics_data_cache_updater.py @@ -2,6 +2,7 @@ from datetime import datetime from sqlalchemy import not_ +import sqlalchemy as sa from anyway.models import ( Base, InfographicsDataCache, @@ -17,7 +18,7 @@ from typing import Dict, Iterable from anyway.constants import CONST from anyway.backend_constants import BE_CONST -from anyway.app_and_db import db +from anyway.app_and_db import db, app from anyway.request_params import RequestParams import anyway.infographics_utils from anyway.widgets.widget import widgets_dict @@ -42,12 +43,13 @@ def is_in_cache(nf): - return ( - len(CONST.INFOGRAPHICS_CACHE_YEARS_AGO) - == db.session.query(InfographicsDataCache) - .filter(InfographicsDataCache.news_flash_id == nf.get_id()) - .count() - ) + with app.app_context(): + return ( + len(CONST.INFOGRAPHICS_CACHE_YEARS_AGO) + == db.session.query(InfographicsDataCache) + .filter(InfographicsDataCache.news_flash_id == nf.get_id()) + .count() + ) # noinspection PyUnresolvedReferences @@ -214,20 +216,22 @@ def copy_temp_into_cache(table: Dict[str, Base]): db.session.commit() start = datetime.now() with db.get_engine().begin() as conn: - conn.execute("lock table infographics_data_cache in exclusive mode") + conn.execute(sa.text("lock table infographics_data_cache in exclusive mode")) logging.debug(f"in transaction, after lock") - conn.execute(f"delete from {table[CACHE].__tablename__}") + conn.execute(sa.text(f"delete from {table[CACHE].__tablename__}")) logging.debug(f"in transaction, after delete") conn.execute( - f"insert into {table[CACHE].__tablename__} " - f"SELECT * from {table[TEMP].__tablename__}" + sa.text( + f"insert into {table[CACHE].__tablename__} " + f"SELECT * from {table[TEMP].__tablename__}" + ) ) logging.debug(f"in transaction, after insert into") logging.info(f"cache unavailable time: {str(datetime.now() - start)}") num_items_cache = db.session.query(table[CACHE]).count() num_items_temp = db.session.query(table[TEMP]).count() logging.debug(f"num items in cache: {num_items_cache}, temp:{num_items_temp}") - db.session.execute(f"truncate table {table[TEMP].__tablename__}") + db.session.execute(sa.text(f"truncate table {table[TEMP].__tablename__}")) db.session.commit() num_items_cache = db.session.query(table[CACHE]).count() num_items_temp = db.session.query(table[TEMP]).count() diff --git a/anyway/parsers/injured_around_schools.py b/anyway/parsers/injured_around_schools.py index 3e21f550c..dc652ac86 100644 --- a/anyway/parsers/injured_around_schools.py +++ b/anyway/parsers/injured_around_schools.py @@ -7,6 +7,7 @@ import math import pandas as pd from sqlalchemy import or_, not_, and_ +import sqlalchemy as sa from anyway.backend_constants import BE_CONST from anyway.models import ( @@ -17,7 +18,7 @@ InjuredAroundSchoolAllData, ) from anyway.utilities import time_delta, chunks -from anyway.app_and_db import db +from anyway.app_and_db import db, app SUBTYPE_ACCIDENT_WITH_PEDESTRIAN = 1 LOCATION_ACCURACY_PRECISE = True @@ -58,27 +59,27 @@ def acc_inv_query(longitude, latitude, distance, start_date, end_date, school): pol_str = "POLYGON(({0} {1},{0} {3},{2} {3},{2} {1},{0} {1}))".format( base_x, base_y, distance_x, distance_y ) - - query_obj = ( - db.session.query(Involved, AccidentMarker) - .join(AccidentMarker, AccidentMarker.provider_and_id == Involved.provider_and_id) - .filter(AccidentMarker.geom.intersects(pol_str)) - .filter(Involved.injured_type == INJURED_TYPE_PEDESTRIAN) - .filter(AccidentMarker.provider_and_id == Involved.provider_and_id) - .filter( - or_( - (AccidentMarker.provider_code == BE_CONST.CBS_ACCIDENT_TYPE_1_CODE), - (AccidentMarker.provider_code == BE_CONST.CBS_ACCIDENT_TYPE_3_CODE), + with app.app_context(): + query_obj = ( + db.session.query(Involved, AccidentMarker) + .join(AccidentMarker, AccidentMarker.provider_and_id == Involved.provider_and_id) + .filter(AccidentMarker.geom.intersects(pol_str)) + .filter(Involved.injured_type == INJURED_TYPE_PEDESTRIAN) + .filter(AccidentMarker.provider_and_id == Involved.provider_and_id) + .filter( + or_( + (AccidentMarker.provider_code == BE_CONST.CBS_ACCIDENT_TYPE_1_CODE), + (AccidentMarker.provider_code == BE_CONST.CBS_ACCIDENT_TYPE_3_CODE), + ) ) - ) - .filter(AccidentMarker.created >= start_date) - .filter(AccidentMarker.created < end_date) - .filter(AccidentMarker.location_accuracy == LOCATION_ACCURACY_PRECISE_INT) - .filter(AccidentMarker.yishuv_symbol != YISHUV_SYMBOL_NOT_EXIST) - .filter(Involved.age_group.in_([1, 2, 3, 4])) - ) # ages 0-19 - - df = pd.read_sql_query(query_obj.with_labels().statement, db.get_engine()) + .filter(AccidentMarker.created >= start_date) + .filter(AccidentMarker.created < end_date) + .filter(AccidentMarker.location_accuracy == LOCATION_ACCURACY_PRECISE_INT) + .filter(AccidentMarker.yishuv_symbol != YISHUV_SYMBOL_NOT_EXIST) + .filter(Involved.age_group.in_([1, 2, 3, 4])) + ) # ages 0-19 + + df = pd.read_sql_query(query_obj.with_labels().statement, db.get_engine()) if LOCATION_ACCURACY_PRECISE: location_accurate = 1 location_approx = "" @@ -221,22 +222,26 @@ def select_columns_df_total(df): def get_injured_around_schools(start_date, end_date, distance): - schools = ( - db.session.query(SchoolWithDescription) - .filter( - not_(and_(SchoolWithDescription.latitude == 0, SchoolWithDescription.longitude == 0)), - not_( - and_( - SchoolWithDescription.latitude == None, SchoolWithDescription.longitude == None - ) - ), - or_( - SchoolWithDescription.school_type == "גן ילדים", - SchoolWithDescription.school_type == "בית ספר", - ), + with app.app_context(): + schools = ( + db.session.query(SchoolWithDescription) + .filter( + not_( + and_(SchoolWithDescription.latitude == 0, SchoolWithDescription.longitude == 0) + ), + not_( + and_( + SchoolWithDescription.latitude == None, + SchoolWithDescription.longitude == None, + ) + ), + or_( + SchoolWithDescription.school_type == "גן ילדים", + SchoolWithDescription.school_type == "בית ספר", + ), + ) + .all() ) - .all() - ) data_dir = "tmp_school_data" if os.path.exists(data_dir): shutil.rmtree(data_dir) @@ -432,15 +437,17 @@ def get_injured_around_schools(start_date, end_date, distance): def truncate_injured_around_schools(): curr_table = "injured_around_school" - sql_truncate = "TRUNCATE TABLE " + curr_table - db.session.execute(sql_truncate) - db.session.commit() + sql_truncate = sa.text("TRUNCATE TABLE " + curr_table) + with app.app_context(): + db.session.execute(sql_truncate) + db.session.commit() logging.info("Truncated table " + curr_table) curr_table = "injured_around_school_all_data" - sql_truncate = "TRUNCATE TABLE " + curr_table - db.session.execute(sql_truncate) - db.session.commit() + sql_truncate = sa.text("TRUNCATE TABLE " + curr_table) + with app.app_context(): + db.session.execute(sql_truncate) + db.session.commit() logging.info("Truncated table " + curr_table) @@ -460,15 +467,17 @@ def import_to_datastore(start_date, end_date, distance, batch_size): if chunk_idx % 10 == 0: logging_chunk = f"Chunk idx in injured_around_schools: {chunk_idx}" logging.info(logging_chunk) - db.session.bulk_insert_mappings(InjuredAroundSchool, schools_chunk) - db.session.commit() + with app.app_context(): + db.session.bulk_insert_mappings(InjuredAroundSchool, schools_chunk) + db.session.commit() logging.info(f"inserting {len(df_total)} new rows injured_around_school_all_data") for chunk_idx, schools_chunk in enumerate(chunks(df_total, batch_size)): if chunk_idx % 10 == 0: logging_chunk = f"Chunk idx in injured_around_school_all_data: {chunk_idx}" logging.info(logging_chunk) - db.session.bulk_insert_mappings(InjuredAroundSchoolAllData, schools_chunk) - db.session.commit() + with app.app_context(): + db.session.bulk_insert_mappings(InjuredAroundSchoolAllData, schools_chunk) + db.session.commit() new_items += len(injured_around_schools) + len(df_total) logging.info(f"\t{new_items} items in {time_delta(started)}") return new_items diff --git a/anyway/parsers/injured_around_schools_2022.py b/anyway/parsers/injured_around_schools_2022.py index a82d93123..cce30d2ab 100644 --- a/anyway/parsers/injured_around_schools_2022.py +++ b/anyway/parsers/injured_around_schools_2022.py @@ -15,7 +15,7 @@ from anyway.models import SchoolWithDescription2020, InvolvedMarkerView from anyway.utilities import time_delta -from anyway.app_and_db import db +from anyway.app_and_db import db, app class SchoolsJsonFile(Enum): @@ -72,30 +72,31 @@ def acc_inv_query(longitude, latitude, distance, start_date, end_date, school): pol_str_for_google_csv = "POLYGON(({0} {1},{0} {3},{2} {3},{2} {1},{0} {1}))".format( round(baseX, 6), round(baseY, 6), round(distanceX, 6), round(distanceY, 6) ) - query_obj = ( - db.session.query(InvolvedMarkerView) - .filter(InvolvedMarkerView.geom.intersects(pol_str)) - .filter( - or_( - (InvolvedMarkerView.provider_code == BE_CONST.CBS_ACCIDENT_TYPE_1_CODE), - (InvolvedMarkerView.provider_code == BE_CONST.CBS_ACCIDENT_TYPE_3_CODE), + with app.app_context(): + query_obj = ( + db.session.query(InvolvedMarkerView) + .filter(InvolvedMarkerView.geom.intersects(pol_str)) + .filter( + or_( + (InvolvedMarkerView.provider_code == BE_CONST.CBS_ACCIDENT_TYPE_1_CODE), + (InvolvedMarkerView.provider_code == BE_CONST.CBS_ACCIDENT_TYPE_3_CODE), + ) ) - ) - .filter(InvolvedMarkerView.accident_timestamp >= start_date) - .filter(InvolvedMarkerView.accident_timestamp <= end_date) - .filter(InvolvedMarkerView.location_accuracy.in_(LOCATION_ACCURACY_PRECISE_LIST)) - .filter(InvolvedMarkerView.age_group.in_(AGE_GROUPS)) - .filter(InvolvedMarkerView.injury_severity.in_(INJURY_SEVERITIES)) - .filter( - or_( - (InvolvedMarkerView.injured_type.in_(INJURED_TYPES)), - (InvolvedMarkerView.involve_vehicle_type.in_(VEHICLE_TYPES)), + .filter(InvolvedMarkerView.accident_timestamp >= start_date) + .filter(InvolvedMarkerView.accident_timestamp <= end_date) + .filter(InvolvedMarkerView.location_accuracy.in_(LOCATION_ACCURACY_PRECISE_LIST)) + .filter(InvolvedMarkerView.age_group.in_(AGE_GROUPS)) + .filter(InvolvedMarkerView.injury_severity.in_(INJURY_SEVERITIES)) + .filter( + or_( + (InvolvedMarkerView.injured_type.in_(INJURED_TYPES)), + (InvolvedMarkerView.involve_vehicle_type.in_(VEHICLE_TYPES)), + ) ) + .filter(InvolvedMarkerView.accident_hour_raw.between(SEVEN_AM_RAW, SEVEN_PM_RAW)) ) - .filter(InvolvedMarkerView.accident_hour_raw.between(SEVEN_AM_RAW, SEVEN_PM_RAW)) - ) - df = pd.read_sql_query(query_obj.statement, db.get_engine()) + df = pd.read_sql_query(query_obj.statement, db.get_engine()) if LOCATION_ACCURACY_PRECISE: location_accurate = 1 @@ -136,24 +137,25 @@ def acc_inv_query(longitude, latitude, distance, start_date, end_date, school): def calculate_injured_around_schools(start_date, end_date, distance): - schools = ( - db.session.query(SchoolWithDescription2020) - .filter( - not_( - and_( - SchoolWithDescription2020.latitude == 0, - SchoolWithDescription2020.longitude == 0, - ) - ), - not_( - and_( - SchoolWithDescription2020.latitude == None, - SchoolWithDescription2020.longitude == None, - ) - ), + with app.app_context(): + schools = ( + db.session.query(SchoolWithDescription2020) + .filter( + not_( + and_( + SchoolWithDescription2020.latitude == 0, + SchoolWithDescription2020.longitude == 0, + ) + ), + not_( + and_( + SchoolWithDescription2020.latitude == None, + SchoolWithDescription2020.longitude == None, + ) + ), + ) + .all() ) - .all() - ) if os.path.exists(ALL_SCHOOLS_DATA_DIR): shutil.rmtree(ALL_SCHOOLS_DATA_DIR) os.mkdir(ALL_SCHOOLS_DATA_DIR) diff --git a/anyway/parsers/location_extraction.py b/anyway/parsers/location_extraction.py index 7b283ab58..4efb09683 100644 --- a/anyway/parsers/location_extraction.py +++ b/anyway/parsers/location_extraction.py @@ -37,31 +37,33 @@ def extract_road_number(location): def get_road_segment_by_name(road_segment_name: str) -> RoadSegments: try: - from anyway.app_and_db import db + from anyway.app_and_db import db, app except ModuleNotFoundError: pass # TODO: maybe throw exception? from_name = road_segment_name.split(" - ")[0].strip() to_name = road_segment_name.split(" - ")[1].strip() - query_obj = ( - db.session.query(RoadSegments) - .filter(RoadSegments.from_name == from_name) - .filter(RoadSegments.to_name == to_name) - ) - segment = query_obj.first() + with app.app_context(): + query_obj = ( + db.session.query(RoadSegments) + .filter(RoadSegments.from_name == from_name) + .filter(RoadSegments.to_name == to_name) + ) + segment = query_obj.first() return segment def get_road_segment_by_name_and_road(road_segment_name: str, road: int) -> RoadSegments: try: - from anyway.app_and_db import db + from anyway.app_and_db import db, app except ModuleNotFoundError: pass # TODO: maybe throw exception? - segments = db.session.query(RoadSegments).filter(RoadSegments.road == road).all() - for segment in segments: - if road_segment_name.startswith(segment.from_name) and road_segment_name.endswith( - segment.to_name - ): - return segment + with app.app_context(): + segments = db.session.query(RoadSegments).filter(RoadSegments.road == road).all() + for segment in segments: + if road_segment_name.startswith(segment.from_name) and road_segment_name.endswith( + segment.to_name + ): + return segment err_msg = f"get_road_segment_by_name_and_road:{road_segment_name},{road}: not found" logging.error(err_msg) raise ValueError(err_msg) @@ -69,15 +71,18 @@ def get_road_segment_by_name_and_road(road_segment_name: str, road: int) -> Road def get_road_segment_name_and_number(road_segment_id) -> (float, str): try: - from anyway.app_and_db import db + from anyway.app_and_db import db, app except ModuleNotFoundError: pass # TODO: maybe throw exception? - query_obj = db.session.query(RoadSegments).filter(RoadSegments.segment_id == road_segment_id) - segment = query_obj.first() - from_name = segment.from_name # pylint: disable=maybe-no-member - to_name = segment.to_name # pylint: disable=maybe-no-member - road_segment_name = " - ".join([from_name, to_name]) - road = segment.road # pylint: disable=maybe-no-member + with app.app_context(): + query_obj = db.session.query(RoadSegments).filter( + RoadSegments.segment_id == road_segment_id + ) + segment = query_obj.first() + from_name = segment.from_name # pylint: disable=maybe-no-member + to_name = segment.to_name # pylint: disable=maybe-no-member + road_segment_name = " - ".join([from_name, to_name]) + road = segment.road # pylint: disable=maybe-no-member return float(road), road_segment_name @@ -106,7 +111,7 @@ def get_bounding_box(latitude, longitude, distance_in_km): return rad2deg(lat_min), rad2deg(lon_min), rad2deg(lat_max), rad2deg(lon_max) try: - from anyway.app_and_db import db + from anyway.app_and_db import db, app except ModuleNotFoundError: pass # TODO: maybe throw exception? @@ -121,26 +126,27 @@ def get_bounding_box(latitude, longitude, distance_in_km): ) cutoff_year = (date.today()).year - 6 - query_obj = ( - db.session.query(AccidentMarkerView) - .filter(AccidentMarkerView.geom.intersects(polygon_str)) - .filter(AccidentMarkerView.accident_year >= cutoff_year) - .filter(AccidentMarkerView.provider_code != BE_CONST.RSA_PROVIDER_CODE) - .filter(not_(AccidentMarkerView.road_segment_name == None)) - .options( - load_only( - "road1", - "road_segment_id", - "road_segment_name", - "latitude", - "longitude", - "geom", - "accident_year", - "provider_code", + with app.app_context(): + query_obj = ( + db.session.query(AccidentMarkerView) + .filter(AccidentMarkerView.geom.intersects(polygon_str)) + .filter(AccidentMarkerView.accident_year >= cutoff_year) + .filter(AccidentMarkerView.provider_code != BE_CONST.RSA_PROVIDER_CODE) + .filter(not_(AccidentMarkerView.road_segment_name == None)) + .options( + load_only( + "road1", + "road_segment_id", + "road_segment_name", + "latitude", + "longitude", + "geom", + "accident_year", + "provider_code", + ) ) ) - ) - markers = pd.read_sql_query(query_obj.statement, db.get_engine()) + markers = pd.read_sql_query(query_obj.statement, db.get_engine()) geod = Geodesic.WGS84 markers["geohash"] = markers.apply( # pylint: disable=maybe-no-member @@ -189,7 +195,10 @@ def get_db_matching_location(db, latitude, longitude, resolution, road_no=None): # READ MARKERS FROM DB geod = Geodesic.WGS84 relevant_fields = resolution_dict[resolution] - markers = db.get_markers_for_location_extraction() + from anyway.app_and_db import app + + with app.app_context(): + markers = db.get_markers_for_location_extraction() markers["geohash"] = markers.apply( lambda x: geohash.encode(x["latitude"], x["longitude"], precision=4), axis=1 ) diff --git a/anyway/parsers/news_flash.py b/anyway/parsers/news_flash.py index b05a39a1e..1970f3907 100644 --- a/anyway/parsers/news_flash.py +++ b/anyway/parsers/news_flash.py @@ -9,6 +9,7 @@ classify_organization, ) from anyway.parsers.location_extraction import extract_geo_features +from anyway.app_and_db import app # FIX: classifier should be chosen by source (screen name), so `twitter` should be `mda` news_flash_classifiers = {"ynet": classify_rss, "twitter": classify_tweets, "walla": classify_rss} @@ -20,50 +21,53 @@ def update_all_in_db(source=None, newsflash_id=None): Should be executed each time the classification or location-extraction are updated. """ - db = init_db() - if newsflash_id is not None: - newsflash_items = db.get_newsflash_by_id(newsflash_id) - elif source is not None: - newsflash_items = db.select_newsflash_where_source(source) - else: - newsflash_items = db.get_all_newsflash() + with app.app_context(): + db = init_db() + if newsflash_id is not None: + newsflash_items = db.get_newsflash_by_id(newsflash_id) + elif source is not None: + newsflash_items = db.select_newsflash_where_source(source) + else: + newsflash_items = db.get_all_newsflash() - for newsflash in newsflash_items: - classify = news_flash_classifiers[newsflash.source] - newsflash.organization = classify_organization(newsflash.source) - newsflash.accident = classify(newsflash.description or newsflash.title) - if newsflash.accident: - extract_geo_features(db, newsflash) - db.commit() + for newsflash in newsflash_items: + classify = news_flash_classifiers[newsflash.source] + newsflash.organization = classify_organization(newsflash.source) + newsflash.accident = classify(newsflash.description or newsflash.title) + if newsflash.accident: + extract_geo_features(db, newsflash) + db.commit() def scrape_extract_store_rss(site_name, db): - latest_date = db.get_latest_date_of_source(site_name) - for newsflash in rss_sites.scrape(site_name): - if newsflash.date <= latest_date: - break - # TODO: pass both title and description, leaving this choice to the classifier - newsflash.accident = classify_rss(newsflash.title or newsflash.description) - newsflash.organization = classify_organization(site_name) - if newsflash.accident: - # FIX: No accident-accurate date extracted - extract_geo_features(db, newsflash) - newsflash.set_critical() - db.insert_new_newsflash(newsflash) + with app.app_context(): + latest_date = db.get_latest_date_of_source(site_name) + for newsflash in rss_sites.scrape(site_name): + if newsflash.date <= latest_date: + break + # TODO: pass both title and description, leaving this choice to the classifier + newsflash.accident = classify_rss(newsflash.title or newsflash.description) + newsflash.organization = classify_organization(site_name) + if newsflash.accident: + # FIX: No accident-accurate date extracted + extract_geo_features(db, newsflash) + newsflash.set_critical() + db.insert_new_newsflash(newsflash) def scrape_extract_store_twitter(screen_name, db): - latest_date = db.get_latest_date_of_source("twitter") - for newsflash in twitter.scrape(screen_name, db.get_latest_tweet_id()): - if newsflash.date <= latest_date: - # We can break if we're guaranteed the order is descending - continue - newsflash.accident = classify_tweets(newsflash.description) - newsflash.organization = classify_organization("twitter") - if newsflash.accident: - extract_geo_features(db, newsflash) - newsflash.set_critical() - db.insert_new_newsflash(newsflash) + with app.app_context(): + latest_date = db.get_latest_date_of_source("twitter") + for newsflash in twitter.scrape(screen_name, db.get_latest_tweet_id()): + if newsflash.date <= latest_date: + # We can break if we're guaranteed the order is descending + continue + newsflash.accident = classify_tweets(newsflash.description) + newsflash.organization = classify_organization("twitter") + if newsflash.accident: + extract_geo_features(db, newsflash) + newsflash.set_critical() + db.insert_new_newsflash(newsflash) def scrape_all(): @@ -71,7 +75,8 @@ def scrape_all(): main function for newsflash scraping """ sys.path.append(os.path.dirname(os.path.realpath(__file__))) - db = init_db() - scrape_extract_store_rss("ynet", db) - scrape_extract_store_rss("walla", db) - # scrape_extract_store_twitter("mda_israel", db) + with app.app_context(): + db = init_db() + scrape_extract_store_rss("ynet", db) + scrape_extract_store_rss("walla", db) + # scrape_extract_store_twitter("mda_israel", db) diff --git a/anyway/parsers/news_flash_db_adapter.py b/anyway/parsers/news_flash_db_adapter.py index e2d5889ce..140e6631b 100644 --- a/anyway/parsers/news_flash_db_adapter.py +++ b/anyway/parsers/news_flash_db_adapter.py @@ -8,14 +8,15 @@ from anyway.parsers import timezones from anyway.models import NewsFlash from anyway.slack_accident_notifications import publish_notification - +from anyway.app_and_db import db, app +import sqlalchemy as sa # fmt: off def init_db() -> "DBAdapter": - from anyway.app_and_db import db - return DBAdapter(db) + with app.app_context(): + return DBAdapter(db) class DBAdapter: @@ -24,96 +25,107 @@ def __init__(self, db: SQLAlchemy): self.__null_types: set = {np.nan} def execute(self, *args, **kwargs): - return self.db.session.execute(*args, **kwargs) + with app.app_context(): + return self.db.session.execute(*args, **kwargs) def commit(self, *args, **kwargs): - return self.db.session.commit(*args, **kwargs) + with app.app_context(): + return self.db.session.commit(*args, **kwargs) def recreate_table_for_location_extraction(self): - with self.db.session.begin(): - self.db.session.execute("""TRUNCATE cbs_locations""") - self.db.session.execute("""INSERT INTO cbs_locations - (SELECT ROW_NUMBER() OVER (ORDER BY road1) as id, LOCATIONS.* - FROM - (SELECT DISTINCT road1, - road2, - non_urban_intersection_hebrew, - yishuv_name, - street1_hebrew, - street2_hebrew, - district_hebrew, - region_hebrew, - road_segment_name, - longitude, - latitude - FROM markers_hebrew - WHERE (provider_code=1 - OR provider_code=3) - AND (longitude is not null - AND latitude is not null)) LOCATIONS)""" - ) + with app.app_context(): + with self.db.session.begin(): + self.db.session.execute(sa.text("""TRUNCATE cbs_locations""")) + self.db.session.execute(sa.text("""INSERT INTO cbs_locations + (SELECT ROW_NUMBER() OVER (ORDER BY road1) as id, LOCATIONS.* + FROM + (SELECT DISTINCT road1, + road2, + non_urban_intersection_hebrew, + yishuv_name, + street1_hebrew, + street2_hebrew, + district_hebrew, + region_hebrew, + road_segment_name, + longitude, + latitude + FROM markers_hebrew + WHERE (provider_code=1 + OR provider_code=3) + AND (longitude is not null + AND latitude is not null)) LOCATIONS)""" + )) def get_markers_for_location_extraction(self): - query_res = self.execute( - """SELECT * FROM cbs_locations""" - ) - df = pd.DataFrame(query_res.fetchall()) - df.columns = query_res.keys() - return df + with app.app_context(): + query_res = self.execute(sa.text( + """SELECT * FROM cbs_locations""" + )) + df = pd.DataFrame(query_res.fetchall()) + df.columns = query_res.keys() + return df def remove_duplicate_rows(self): """ remove duplicate rows by link """ - self.execute( - """ - DELETE FROM news_flash T1 - USING news_flash T2 - WHERE T1.ctid < T2.ctid -- delete the older versions - AND T1.link = T2.link; -- add more columns if needed - """ - ) - self.commit() + with app.app_context(): + self.execute(sa.text( + """ + DELETE FROM news_flash T1 + USING news_flash T2 + WHERE T1.ctid < T2.ctid -- delete the older versions + AND T1.link = T2.link; -- add more columns if needed + """ + )) + self.commit() def insert_new_newsflash(self, newsflash: NewsFlash) -> None: - logging.info("Adding newsflash, is accident: {}, date: {}" - .format(newsflash.accident, newsflash.date)) - self.__fill_na(newsflash) - self.db.session.add(newsflash) - self.db.session.commit() - infographics_data_cache_updater.add_news_flash_to_cache(newsflash) - if os.environ.get("FLASK_ENV") == "production" and newsflash.accident: - publish_notification(newsflash) + with app.app_context(): + logging.info("Adding newsflash, is accident: {}, date: {}" + .format(newsflash.accident, newsflash.date)) + self.__fill_na(newsflash) + self.db.session.add(newsflash) + self.db.session.commit() + infographics_data_cache_updater.add_news_flash_to_cache(newsflash) + if os.environ.get("FLASK_ENV") == "production" and newsflash.accident: + publish_notification(newsflash) def get_newsflash_by_id(self, id): - return self.db.session.query(NewsFlash).filter(NewsFlash.id == id) + with app.app_context(): + return self.db.session.query(NewsFlash).filter(NewsFlash.id == id) def select_newsflash_where_source(self, source): - return self.db.session.query(NewsFlash).filter(NewsFlash.source == source) + with app.app_context(): + return self.db.session.query(NewsFlash).filter(NewsFlash.source == source) def get_all_newsflash(self): - return self.db.session.query(NewsFlash) + with app.app_context(): + return self.db.session.query(NewsFlash) def get_latest_date_of_source(self, source): """ :return: latest date of news flash """ - latest_date = self.execute( - "SELECT max(date) FROM news_flash WHERE source=:source", - {"source": source}, - ).fetchone()[0] or datetime.datetime(1900, 1, 1, 0, 0, 0) - res = timezones.from_db(latest_date) - logging.info('Latest time fetched for source {} is {}' - .format(source, res)) - return res + with app.app_context(): + latest_date = self.execute(sa.text( + "SELECT max(date) FROM news_flash WHERE source=:source", + {"source": source}, + )).fetchone()[0] or datetime.datetime(1900, 1, 1, 0, 0, 0) + res = timezones.from_db(latest_date) + logging.info('Latest time fetched for source {} is {}' + .format(source, res)) + return res def get_latest_tweet_id(self): """ :return: latest tweet id """ - latest_id = self.execute( - "SELECT tweet_id FROM news_flash where source='twitter' ORDER BY date DESC LIMIT 1" - ).fetchone() + with app.app_context(): + latest_id = self.execute(sa.text( + "SELECT tweet_id FROM news_flash where source='twitter' ORDER BY date DESC LIMIT 1" + )).fetchone() if latest_id: return latest_id[0] return None diff --git a/anyway/parsers/registered.py b/anyway/parsers/registered.py index 3764915b5..6a9ab1442 100644 --- a/anyway/parsers/registered.py +++ b/anyway/parsers/registered.py @@ -7,7 +7,8 @@ from datetime import datetime from anyway.models import RegisteredVehicle, DeprecatedCity from anyway.utilities import time_delta, CsvReader, ImporterUI, truncate_tables, decode_hebrew -from anyway.app_and_db import db +from anyway.app_and_db import db, app +import sqlalchemy as sa COLUMN_CITY_NAME_ENG = 0 @@ -71,8 +72,8 @@ def import_file(self, inputfile): else: self.header_row(row) row_count += 1 - - db.session.bulk_insert_mappings(RegisteredVehicle, inserts) + with app.app_context(): + db.session.bulk_insert_mappings(RegisteredVehicle, inserts) return total @staticmethod @@ -131,11 +132,14 @@ def main(specific_folder, delete_all, path): started = datetime.now() for fname in dir_files: total += importer.import_file(fname) - - db.session.commit() - db.engine.execute( - "UPDATE {0} SET city_id = (SELECT id FROM {1} WHERE {0}.search_name = {1}.search_heb) WHERE city_id IS NULL".format( - RegisteredVehicle.__tablename__, DeprecatedCity.__tablename__ - ) - ) - logging.info("Total: {0} items in {1}".format(total, time_delta(started))) + with app.app_context(): + db.session.commit() + with db.get_engine().begin() as conn: + conn.execute( + sa.text( + "UPDATE {0} SET city_id = (SELECT id FROM {1} WHERE {0}.search_name = {1}.search_heb) WHERE city_id IS NULL".format( + RegisteredVehicle.__tablename__, DeprecatedCity.__tablename__ + ) + ) + ) + logging.info("Total: {0} items in {1}".format(total, time_delta(started))) diff --git a/anyway/parsers/road_segments.py b/anyway/parsers/road_segments.py index 13a96343a..fcbe768a7 100644 --- a/anyway/parsers/road_segments.py +++ b/anyway/parsers/road_segments.py @@ -3,7 +3,7 @@ from anyway.parsers.utils import batch_iterator from anyway.models import RoadSegments -from anyway.app_and_db import db +from anyway.app_and_db import db, app def _iter_rows(filename): @@ -40,5 +40,6 @@ def _iter_rows(filename): def parse(filename): RoadSegments.query.delete() for batch in batch_iterator(_iter_rows(filename), batch_size=50): - db.session.bulk_insert_mappings(RoadSegments, batch) - db.session.commit() + with app.app_context(): + db.session.bulk_insert_mappings(RoadSegments, batch) + db.session.commit() diff --git a/anyway/parsers/rsa.py b/anyway/parsers/rsa.py index 09d084efb..8e5fa9455 100644 --- a/anyway/parsers/rsa.py +++ b/anyway/parsers/rsa.py @@ -5,7 +5,8 @@ from anyway.parsers.utils import batch_iterator from anyway.backend_constants import BE_CONST from anyway.models import AccidentMarker -from anyway.app_and_db import db +from anyway.app_and_db import db, app +import sqlalchemy as sa def _iter_rows(filename): @@ -68,16 +69,23 @@ def _iter_rows(filename): def parse(filename): - db.session.execute(f"DELETE from markers where provider_code = {BE_CONST.RSA_PROVIDER_CODE}") + with app.app_context(): + db.session.execute( + sa.text(f"DELETE from markers where provider_code = {BE_CONST.RSA_PROVIDER_CODE}") + ) for batch in batch_iterator(_iter_rows(filename), batch_size=50000): - db.session.bulk_insert_mappings(AccidentMarker, batch) - db.session.commit() + with app.app_context(): + db.session.bulk_insert_mappings(AccidentMarker, batch) + db.session.commit() """ Fills empty geometry object according to coordinates in database """ - db.session.execute( - "UPDATE markers SET geom = ST_SetSRID(ST_MakePoint(longitude,latitude),4326)\ - WHERE geom IS NULL;" - ) - db.session.commit() + with app.app_context(): + db.session.execute( + sa.text( + "UPDATE markers SET geom = ST_SetSRID(ST_MakePoint(longitude,latitude),4326)\ + WHERE geom IS NULL;" + ) + ) + db.session.commit() diff --git a/anyway/parsers/schools.py b/anyway/parsers/schools.py index d0ae3156c..38f76479c 100644 --- a/anyway/parsers/schools.py +++ b/anyway/parsers/schools.py @@ -6,7 +6,7 @@ from static.data.schools import school_fields from anyway.models import School from anyway.utilities import time_delta, chunks -from anyway.app_and_db import db +from anyway.app_and_db import db, app def get_data_value(value): @@ -57,12 +57,14 @@ def import_to_datastore(filepath, batch_size): started = datetime.now() schools = get_schools(filepath) new_items = 0 - all_existing_schools_ids = set(map(lambda x: x[0], db.session.query(School.id).all())) + with app.app_context(): + all_existing_schools_ids = set(map(lambda x: x[0], db.session.query(School.id).all())) schools = [school for school in schools if school["id"] not in all_existing_schools_ids] logging.info(f"inserting {len(schools)} new schools") for schools_chunk in chunks(schools, batch_size): - db.session.bulk_insert_mappings(School, schools_chunk) - db.session.commit() + with app.app_context(): + db.session.bulk_insert_mappings(School, schools_chunk) + db.session.commit() new_items += len(schools) logging.info(f"\t{new_items} items in {time_delta(started)}") return new_items diff --git a/anyway/parsers/schools_with_description.py b/anyway/parsers/schools_with_description.py index c14722b86..d17fc3312 100644 --- a/anyway/parsers/schools_with_description.py +++ b/anyway/parsers/schools_with_description.py @@ -7,7 +7,8 @@ from anyway.models import SchoolWithDescription from anyway.utilities import time_delta, chunks, ItmToWGS84 -from anyway.app_and_db import db +from anyway.app_and_db import db, app +import sqlalchemy as sa school_fields = { "data_year": "שנה", @@ -130,9 +131,10 @@ def get_schools_with_description(schools_description_filepath, schools_coordinat def truncate_schools_with_description(): curr_table = "schools_with_description" - sql_truncate = "TRUNCATE TABLE " + curr_table - db.session.execute(sql_truncate) - db.session.commit() + sql_truncate = sa.text("TRUNCATE TABLE " + curr_table) + with app.app_context(): + db.session.execute(sql_truncate) + db.session.commit() logging.info("Truncated table " + curr_table) @@ -147,8 +149,9 @@ def import_to_datastore(schools_description_filepath, schools_coordinates_filepa new_items = 0 logging.info(f"inserting {len(schools)} new schools") for schools_chunk in chunks(schools, batch_size): - db.session.bulk_insert_mappings(SchoolWithDescription, schools_chunk) - db.session.commit() + with app.app_context(): + db.session.bulk_insert_mappings(SchoolWithDescription, schools_chunk) + db.session.commit() new_items += len(schools) logging.info(f"\t{new_items} items in {time_delta(started)}") return new_items @@ -164,8 +167,11 @@ def parse(schools_description_filepath, schools_coordinates_filepath, batch_size schools_coordinates_filepath=schools_coordinates_filepath, batch_size=batch_size, ) - db.session.execute( - "UPDATE schools_with_description SET geom = ST_SetSRID(ST_MakePoint(longitude,latitude),4326)\ - WHERE geom IS NULL;" - ) + with app.app_context(): + db.session.execute( + sa.text( + "UPDATE schools_with_description SET geom = ST_SetSRID(ST_MakePoint(longitude,latitude),4326)\ + WHERE geom IS NULL;" + ) + ) logging.info("Total: {0} schools in {1}".format(total, time_delta(started))) diff --git a/anyway/parsers/schools_with_description_2020.py b/anyway/parsers/schools_with_description_2020.py index 270f44776..123871b76 100644 --- a/anyway/parsers/schools_with_description_2020.py +++ b/anyway/parsers/schools_with_description_2020.py @@ -5,6 +5,7 @@ import numpy as np import pandas as pd from flask_sqlalchemy import SQLAlchemy +import sqlalchemy as sa from ..models import SchoolWithDescription2020 from ..utilities import init_flask, time_delta, chunks, ItmToWGS84 @@ -123,7 +124,7 @@ def get_schools_with_description(schools_description_filepath, schools_coordinat def truncate_schools_with_description(): curr_table = "schools_with_description2020" - sql_truncate = "TRUNCATE TABLE " + curr_table + sql_truncate = sa.text("TRUNCATE TABLE " + curr_table) db.session.execute(sql_truncate) db.session.commit() logging.info("Truncated table " + curr_table) @@ -158,7 +159,9 @@ def parse(schools_description_filepath, schools_coordinates_filepath, batch_size batch_size=batch_size, ) db.session.execute( - "UPDATE schools_with_description SET geom = ST_SetSRID(ST_MakePoint(longitude,latitude),4326)\ + sa.text( + "UPDATE schools_with_description SET geom = ST_SetSRID(ST_MakePoint(longitude,latitude),4326)\ WHERE geom IS NULL;" + ) ) logging.info("Total: {0} schools in {1}".format(total, time_delta(started))) diff --git a/anyway/parsers/traffic_volume.py b/anyway/parsers/traffic_volume.py index 57c105662..2c1735329 100644 --- a/anyway/parsers/traffic_volume.py +++ b/anyway/parsers/traffic_volume.py @@ -9,7 +9,7 @@ from anyway.models import TrafficVolume from anyway.utilities import chunks from anyway.utilities import time_delta -from anyway.app_and_db import db +from anyway.app_and_db import db, app dictionary = { @@ -34,11 +34,12 @@ def get_value_or_none(param): def delete_traffic_volume_of_year(year): - q = db.session.query(TrafficVolume).filter_by(year=year) - if q.all(): - logging.info("Deleting traffic volume of year: " + str(int(year))) - q.delete(synchronize_session="fetch") - db.session.commit() + with app.app_context(): + q = db.session.query(TrafficVolume).filter_by(year=year) + if q.all(): + logging.info("Deleting traffic volume of year: " + str(int(year))) + q.delete(synchronize_session="fetch") + db.session.commit() def delete_traffic_volume_of_directory(path): @@ -85,18 +86,21 @@ def import_to_datastore(path, batch_size): try: assert batch_size > 0 dir_list = glob.glob("{0}/*".format(path)) - for directory in sorted(dir_list, reverse=False): - started = datetime.now() - delete_traffic_volume_of_directory(directory) - traffic_volume_rows = get_traffic_volume_rows(directory) - new_items = 0 - logging.info("inserting " + str(len(traffic_volume_rows)) + " new traffic data rows") - for traffic_volume_chunk in chunks(traffic_volume_rows, batch_size): - db.session.bulk_insert_mappings(TrafficVolume, traffic_volume_chunk) - db.session.commit() - new_items += len(traffic_volume_rows) - logging.info("\t{0} items in {1}".format(new_items, time_delta(started))) - db.session.commit() + with app.app_context(): + for directory in sorted(dir_list, reverse=False): + started = datetime.now() + delete_traffic_volume_of_directory(directory) + traffic_volume_rows = get_traffic_volume_rows(directory) + new_items = 0 + logging.info( + "inserting " + str(len(traffic_volume_rows)) + " new traffic data rows" + ) + for traffic_volume_chunk in chunks(traffic_volume_rows, batch_size): + db.session.bulk_insert_mappings(TrafficVolume, traffic_volume_chunk) + db.session.commit() + new_items += len(traffic_volume_rows) + logging.info("\t{0} items in {1}".format(new_items, time_delta(started))) + db.session.commit() return new_items except: error = ( diff --git a/anyway/parsers/waze/waze_db_functions.py b/anyway/parsers/waze/waze_db_functions.py index 8f3b41781..1c55881b2 100644 --- a/anyway/parsers/waze/waze_db_functions.py +++ b/anyway/parsers/waze/waze_db_functions.py @@ -30,39 +30,41 @@ def enrich_waze_traffic_jams_ended_at_timestamp(timestamp, latest_waze_objects, def _upsert_waze_objects_by_uuid(model, waze_objects): new_records = 0 - with db.session.no_autoflush: - for waze_object in waze_objects: - db.session.flush() - existing_objects = db.session.query(model).filter( - model.uuid == str(waze_object["uuid"]) - ) - object_count = existing_objects.count() - if object_count == 0: - new_object = model(**waze_object) - db.session.add(new_object) - new_records += 1 - elif object_count > 1: - - # sanity: as the uuid field is unique - this should never happen - raise RuntimeError("Too many waze objects with the same uuid") - else: - - # update the existing alert - existing_object = existing_objects[0] - for key, val in waze_object.items(): - setattr(existing_object, key, val) - - db.session.commit() + with app.app_context(): + with db.session.no_autoflush: + for waze_object in waze_objects: + db.session.flush() + existing_objects = db.session.query(model).filter( + model.uuid == str(waze_object["uuid"]) + ) + object_count = existing_objects.count() + if object_count == 0: + new_object = model(**waze_object) + db.session.add(new_object) + new_records += 1 + elif object_count > 1: + + # sanity: as the uuid field is unique - this should never happen + raise RuntimeError("Too many waze objects with the same uuid") + else: + + # update the existing alert + existing_object = existing_objects[0] + for key, val in waze_object.items(): + setattr(existing_object, key, val) + + db.session.commit() return new_records def _enrich_ended_at_timestamp(model, timestamp, latest_waze_objects, back_filled): latest_waze_objects_uuids = [waze_object["uuid"] for waze_object in latest_waze_objects] - query = db.session.query(model).filter( - model.ended_at_estimate.is_(None), - model.uuid.in_(latest_waze_objects_uuids), - model.back_filled.is_(back_filled), - ) - for waze_object in query: - waze_object.ended_at_estimate = timestamp - db.session.commit() + with app.app_context(): + query = db.session.query(model).filter( + model.ended_at_estimate.is_(None), + model.uuid.in_(latest_waze_objects_uuids), + model.back_filled.is_(back_filled), + ) + for waze_object in query: + waze_object.ended_at_estimate = timestamp + db.session.commit() diff --git a/anyway/request_params.py b/anyway/request_params.py index c5a93e417..4e626ad5e 100644 --- a/anyway/request_params.py +++ b/anyway/request_params.py @@ -12,7 +12,7 @@ get_road_segment_by_name_and_road, ) from anyway.backend_constants import BE_CONST -from anyway.app_and_db import db +from anyway.app_and_db import db, app from anyway.parsers import resolution_dict NON_URBAN_INTERSECTION_HEBREW = "non_urban_intersection_hebrew" @@ -307,7 +307,8 @@ def extract_news_flash_obj(vals) -> Optional[NewsFlash]: news_flash_id = vals.get("news_flash_id") if news_flash_id is None: return None - news_flash_obj = db.session.query(NewsFlash).filter(NewsFlash.id == news_flash_id).first() + with app.app_context(): + news_flash_obj = db.session.query(NewsFlash).filter(NewsFlash.id == news_flash_id).first() if not news_flash_obj: logging.warning(f"Could not find news flash id {news_flash_id}") @@ -323,8 +324,9 @@ def get_latest_accident_date(table_obj, filters): BE_CONST.CBS_ACCIDENT_TYPE_1_CODE, BE_CONST.CBS_ACCIDENT_TYPE_3_CODE, ] - query = db.session.query(func.max(table_obj.accident_timestamp)) - df = pd.read_sql_query(query.statement, db.get_engine()) + with app.app_context(): + query = db.session.query(func.max(table_obj.accident_timestamp)) + df = pd.read_sql_query(query.statement, db.get_engine()) return (df.to_dict(orient="records"))[0].get("max_1") # pylint: disable=no-member diff --git a/anyway/utilities.py b/anyway/utilities.py index a3692a39e..c79e5b2ff 100644 --- a/anyway/utilities.py +++ b/anyway/utilities.py @@ -12,6 +12,7 @@ from urllib.parse import urlparse from sqlalchemy import func, or_ from sqlalchemy.sql import select +import sqlalchemy as sa import phonenumbers from dateutil.relativedelta import relativedelta @@ -188,15 +189,17 @@ def fetch_first_and_every_nth_value_for_column(conn, column_to_fetch, n): def truncate_tables(db, tables): logging.info("Deleting tables: " + ", ".join(table.__name__ for table in tables)) - for table in tables: - db.session.query(table).delete() - db.session.commit() + from anyway.app_and_db import app + with app.app_context(): + for table in tables: + db.session.query(table).delete() + db.session.commit() def delete_all_rows_from_table(conn, table): table_name = table.__tablename__ logging.info("Deleting all rows from table " + table_name) - conn.execute("DELETE FROM " + table_name) + conn.execute(sa.text("DELETE FROM " + table_name)) def split_query_to_chunks_by_column(base_select, column_to_chunk_by, chunk_size, conn): diff --git a/anyway/views/news_flash/api.py b/anyway/views/news_flash/api.py index 5227c47b1..9e6e23a7e 100644 --- a/anyway/views/news_flash/api.py +++ b/anyway/views/news_flash/api.py @@ -14,7 +14,7 @@ from sqlalchemy import and_, not_, or_ -from anyway.app_and_db import db +from anyway.app_and_db import db, app from anyway.backend_constants import ( BE_CONST, NewsflashLocationQualification, @@ -71,8 +71,9 @@ def news_flash(): news_flash_id = request.values.get("id") if news_flash_id is not None: - query = db.session.query(NewsFlash) - news_flash_obj = query.filter(NewsFlash.id == news_flash_id).first() + with app.app_context(): + query = db.session.query(NewsFlash) + news_flash_obj = query.filter(NewsFlash.id == news_flash_id).first() if news_flash_obj is not None: if is_news_flash_resolution_supported(news_flash_obj): return Response( @@ -81,22 +82,21 @@ def news_flash(): else: return Response("News flash location not supported", 406) return Response(status=404) - - query = gen_news_flash_query( - db.session, - source=request.values.get("source"), - start_date=request.values.get("start_date"), - end_date=request.values.get("end_date"), - interurban_only=request.values.get("interurban_only"), - road_number=request.values.get("road_number"), - road_segment=request.values.get("road_segment_only"), - last_minutes=request.values.get("last_minutes"), - offset=request.values.get("offset", DEFAULT_OFFSET_REQ_PARAMETER), - limit=request.values.get("limit", DEFAULT_LIMIT_REQ_PARAMETER), - ) - news_flashes = query.all() - - news_flashes_jsons = [n.serialize() for n in news_flashes] + with app.app_context(): + query = gen_news_flash_query( + db.session, + source=request.values.get("source"), + start_date=request.values.get("start_date"), + end_date=request.values.get("end_date"), + interurban_only=request.values.get("interurban_only"), + road_number=request.values.get("road_number"), + road_segment=request.values.get("road_segment_only"), + last_minutes=request.values.get("last_minutes"), + offset=request.values.get("offset", DEFAULT_OFFSET_REQ_PARAMETER), + limit=request.values.get("limit", DEFAULT_LIMIT_REQ_PARAMETER), + ) + news_flashes = query.all() + news_flashes_jsons = [n.serialize() for n in news_flashes] for news_flash in news_flashes_jsons: set_display_source(news_flash) return Response(json.dumps(news_flashes_jsons, default=str), mimetype="application/json") @@ -111,11 +111,10 @@ def news_flash_v2(): if "id" in validated_query_params: return get_news_flash_by_id(validated_query_params["id"]) - - query = gen_news_flash_query_v2(db.session, validated_query_params) - news_flashes = query.all() - - news_flashes_jsons = [n.serialize() for n in news_flashes] + with app.app_context(): + query = gen_news_flash_query_v2(db.session, validated_query_params) + news_flashes = query.all() + news_flashes_jsons = [n.serialize() for n in news_flashes] for news_flash in news_flashes_jsons: set_display_source(news_flash) return Response(json.dumps(news_flashes_jsons, default=str), mimetype="application/json") @@ -126,22 +125,21 @@ def news_flash_new(args: dict) -> List[dict]: if news_flash_id is not None: return single_news_flash(news_flash_id) - - query = gen_news_flash_query( - db.session, - source=args.get("source"), - start_date=args.get("start_date"), - end_date=args.get("end_date"), - interurban_only=args.get("interurban_only"), - road_number=args.get("road_number"), - road_segment=args.get("road_segment_only"), - offset=args.get("offset"), - limit=args.get("limit"), - last_minutes=args.get("last_minutes"), - ) - news_flashes = query.all() - - news_flashes_jsons = [n.serialize() for n in news_flashes] + with app.app_context(): + query = gen_news_flash_query( + db.session, + source=args.get("source"), + start_date=args.get("start_date"), + end_date=args.get("end_date"), + interurban_only=args.get("interurban_only"), + road_number=args.get("road_number"), + road_segment=args.get("road_segment_only"), + offset=args.get("offset"), + limit=args.get("limit"), + last_minutes=args.get("last_minutes"), + ) + news_flashes = query.all() + news_flashes_jsons = [n.serialize() for n in news_flashes] for news_flash in news_flashes_jsons: set_display_source(news_flash) return news_flashes_jsons @@ -159,11 +157,12 @@ def gen_news_flash_query( limit=None, last_minutes=None ): - query = session.query(NewsFlash) - # get all possible sources - sources = [ - str(source_name[0]) for source_name in db.session.query(NewsFlash.source).distinct().all() - ] + with app.app_context(): + query = session.query(NewsFlash) + # get all possible sources + sources = [ + str(source_name[0]) for source_name in db.session.query(NewsFlash.source).distinct().all() + ] if source: if source not in sources: return Response( @@ -209,32 +208,33 @@ def gen_news_flash_query( def gen_news_flash_query_v2(session, valid_params: dict): - query = session.query(NewsFlash) - for param, value in valid_params.items(): - if param == "road_number": - query = query.filter(NewsFlash.road1 == value) - if param == "source": - sources = [source.value for source in value] - query = query.filter(NewsFlash.source.in_(sources)) - if param == "start_date": - query = query.filter(value <= NewsFlash.date <= valid_params["end_date"]) - if param == "resolution": - query = filter_by_resolutions(query, value) - if param == "critical": - query = query.filter(NewsFlash.critical == value) - if param == "last_minutes": - last_timestamp = datetime.datetime.now() - datetime.timedelta(minutes=value) - query = query.filter(NewsFlash.date >= last_timestamp) - query = query.filter( - and_( - NewsFlash.accident == True, - not_(and_(NewsFlash.lat == 0, NewsFlash.lon == 0)), - not_(and_(NewsFlash.lat == None, NewsFlash.lon == None)), - ) - ).order_by(NewsFlash.date.desc()) - query = query.offset(valid_params["offset"]) - query = query.limit(valid_params["limit"]) - return query + with app.app_context(): + query = session.query(NewsFlash) + for param, value in valid_params.items(): + if param == "road_number": + query = query.filter(NewsFlash.road1 == value) + if param == "source": + sources = [source.value for source in value] + query = query.filter(NewsFlash.source.in_(sources)) + if param == "start_date": + query = query.filter(value <= NewsFlash.date <= valid_params["end_date"]) + if param == "resolution": + query = filter_by_resolutions(query, value) + if param == "critical": + query = query.filter(NewsFlash.critical == value) + if param == "last_minutes": + last_timestamp = datetime.datetime.now() - datetime.timedelta(minutes=value) + query = query.filter(NewsFlash.date >= last_timestamp) + query = query.filter( + and_( + NewsFlash.accident == True, + not_(and_(NewsFlash.lat == 0, NewsFlash.lon == 0)), + not_(and_(NewsFlash.lat == None, NewsFlash.lon == None)), + ) + ).order_by(NewsFlash.date.desc()) + query = query.offset(valid_params["offset"]) + query = query.limit(valid_params["limit"]) + return query def set_display_source(news_flash): @@ -253,7 +253,8 @@ def filter_by_timeframe(end_date, news_flash_obj, start_date): def single_news_flash(news_flash_id: int): - news_flash_obj = db.session.query(NewsFlash).filter(NewsFlash.id == news_flash_id).first() + with app.app_context(): + news_flash_obj = db.session.query(NewsFlash).filter(NewsFlash.id == news_flash_id).first() if news_flash_obj is not None: return Response( json.dumps(news_flash_obj.serialize(), default=str), mimetype="application/json" @@ -266,8 +267,9 @@ def get_supported_resolutions() -> set: def get_news_flash_by_id(id: int): - query = db.session.query(NewsFlash) - news_flash_with_id = query.filter(NewsFlash.id == id).first() + with app.app_context(): + query = db.session.query(NewsFlash) + news_flash_with_id = query.filter(NewsFlash.id == id).first() if news_flash_with_id is None: return Response(status=404) if not is_news_flash_resolution_supported(news_flash_with_id): @@ -327,8 +329,9 @@ def update_location_verification_history( location_verification_after_change=new_qualification, location_after_change=new_location, ) - db.session.add(new_location_qualifiction_history) - db.session.commit() + with app.app_context(): + db.session.add(new_location_qualifiction_history) + db.session.commit() def extracted_location_and_qualification(news_flash_obj: NewsFlash): @@ -368,7 +371,8 @@ def update_news_flash_qualifying(id): if road_segment_name is not None or street1_hebrew is not None or yishuv_name is not None: logging.error("only manual update should contain location details.") return return_json_error(Es.BR_BAD_FIELD) - news_flash_obj = db.session.query(NewsFlash).filter(NewsFlash.id == id).first() + with app.app_context(): + news_flash_obj = db.session.query(NewsFlash).filter(NewsFlash.id == id).first() old_location, old_location_qualifiction = extracted_location_and_qualification(news_flash_obj) if news_flash_obj is not None: if manual_update: @@ -390,7 +394,8 @@ def update_news_flash_qualifying(id): news_flash_obj.newsflash_location_qualification = newsflash_location_qualification news_flash_obj.location_qualifying_user = current_user.id - db.session.commit() + with app.app_context(): + db.session.commit() new_location, new_location_qualifiction = extracted_location_and_qualification( news_flash_obj ) diff --git a/anyway/views/schools/api.py b/anyway/views/schools/api.py index a9b3fa19f..658315ce4 100644 --- a/anyway/views/schools/api.py +++ b/anyway/views/schools/api.py @@ -7,37 +7,38 @@ from flask import Response, request from sqlalchemy import and_, not_ -from anyway.app_and_db import db +from anyway.app_and_db import db, app from anyway.models import School, SchoolWithDescription2020 def schools_api(): logging.debug("getting schools") - schools = ( - db.session.query(School) - .filter( - not_(and_(School.latitude == 0, School.longitude == 0)), - not_(and_(School.latitude == None, School.longitude == None)), + with app.app_context(): + schools = ( + db.session.query(School) + .filter( + not_(and_(School.latitude == 0, School.longitude == 0)), + not_(and_(School.latitude == None, School.longitude == None)), + ) + .with_entities( + School.yishuv_symbol, + School.yishuv_name, + School.school_name, + School.longitude, + School.latitude, + ) + .all() ) - .with_entities( - School.yishuv_symbol, - School.yishuv_name, - School.school_name, - School.longitude, - School.latitude, - ) - .all() - ) - schools_list = [ - { - "yishuv_symbol": x.yishuv_symbol, - "yishuv_name": x.yishuv_name, - "school_name": x.school_name, - "longitude": x.longitude, - "latitude": x.latitude, - } - for x in schools - ] + schools_list = [ + { + "yishuv_symbol": x.yishuv_symbol, + "yishuv_name": x.yishuv_name, + "school_name": x.school_name, + "longitude": x.longitude, + "latitude": x.latitude, + } + for x in schools + ] response = Response(json.dumps(schools_list, default=str), mimetype="application/json") response.headers.add("Access-Control-Allow-Origin", "*") return response @@ -47,34 +48,35 @@ def schools_description_api(): # Disable all the no-member violations in this function # pylint: disable=no-member logging.debug("getting schools with description") - query_obj = ( - db.session.query(SchoolWithDescription2020) - .filter( - not_( - and_( - SchoolWithDescription2020.latitude == 0, - SchoolWithDescription2020.longitude == 0, - ) - ), - not_( - and_( - SchoolWithDescription2020.latitude == None, - SchoolWithDescription2020.longitude == None, - ) - ), - ) - .with_entities( - SchoolWithDescription2020.school_id, - SchoolWithDescription2020.school_name, - SchoolWithDescription2020.municipality_name, - SchoolWithDescription2020.yishuv_name, - SchoolWithDescription2020.institution_type, - SchoolWithDescription2020.location_accuracy, - SchoolWithDescription2020.longitude, - SchoolWithDescription2020.latitude, + with app.app_context(): + query_obj = ( + db.session.query(SchoolWithDescription2020) + .filter( + not_( + and_( + SchoolWithDescription2020.latitude == 0, + SchoolWithDescription2020.longitude == 0, + ) + ), + not_( + and_( + SchoolWithDescription2020.latitude == None, + SchoolWithDescription2020.longitude == None, + ) + ), + ) + .with_entities( + SchoolWithDescription2020.school_id, + SchoolWithDescription2020.school_name, + SchoolWithDescription2020.municipality_name, + SchoolWithDescription2020.yishuv_name, + SchoolWithDescription2020.institution_type, + SchoolWithDescription2020.location_accuracy, + SchoolWithDescription2020.longitude, + SchoolWithDescription2020.latitude, + ) ) - ) - df = pd.read_sql_query(query_obj.statement, db.get_engine()) + df = pd.read_sql_query(query_obj.statement, db.get_engine()) schools_list = df.to_dict(orient="records") response = Response(json.dumps(schools_list, default=str), mimetype="application/json") response.headers.add("Access-Control-Allow-Origin", "*") @@ -83,27 +85,28 @@ def schools_description_api(): def schools_yishuvs_api(): logging.debug("getting schools yishuvs") - schools_yishuvs = ( - db.session.query(SchoolWithDescription2020) - .filter( - not_( - and_( - SchoolWithDescription2020.latitude == 0, - SchoolWithDescription2020.longitude == 0, - ) - ), - not_( - and_( - SchoolWithDescription2020.latitude == None, - SchoolWithDescription2020.longitude == None, - ) - ), + with app.app_context(): + schools_yishuvs = ( + db.session.query(SchoolWithDescription2020) + .filter( + not_( + and_( + SchoolWithDescription2020.latitude == 0, + SchoolWithDescription2020.longitude == 0, + ) + ), + not_( + and_( + SchoolWithDescription2020.latitude == None, + SchoolWithDescription2020.longitude == None, + ) + ), + ) + .group_by(SchoolWithDescription2020.yishuv_name) + .with_entities(SchoolWithDescription2020.yishuv_name) + .all() ) - .group_by(SchoolWithDescription2020.yishuv_name) - .with_entities(SchoolWithDescription2020.yishuv_name) - .all() - ) - schools_yishuvs_list = sorted([x[0] for x in schools_yishuvs]) + schools_yishuvs_list = sorted([x[0] for x in schools_yishuvs]) response = Response(json.dumps(schools_yishuvs_list, default=str), mimetype="application/json") response.headers.add("Access-Control-Allow-Origin", "*") return response diff --git a/anyway/views/user_system/api.py b/anyway/views/user_system/api.py index 373f360a8..8ed19c5ad 100644 --- a/anyway/views/user_system/api.py +++ b/anyway/views/user_system/api.py @@ -89,7 +89,8 @@ def decorated_view(*args, **kwargs): @login_manager.user_loader def load_user(id: str) -> Users: - return db.session.query(Users).get(id) + with app.app_context(): + return db.session.query(Users).get(id) # noinspection PyUnusedLocal @@ -164,18 +165,20 @@ def oauth_callback(provider: str) -> Response: user = None try: - user = ( - db.session.query(Users) - .filter_by(oauth_provider=provider, oauth_provider_user_id=user_data.service_user_id) - .one() - ) - except (NoResultFound, MultipleResultsFound): - try: + with app.app_context(): user = ( db.session.query(Users) - .filter_by(oauth_provider=provider, email=user_data.email) + .filter_by(oauth_provider=provider, oauth_provider_user_id=user_data.service_user_id) .one() ) + except (NoResultFound, MultipleResultsFound): + try: + with app.app_context(): + user = ( + db.session.query(Users) + .filter_by(oauth_provider=provider, email=user_data.email) + .one() + ) except MultipleResultsFound as e: # Internal server error - this case should not exists raise e @@ -183,36 +186,36 @@ def oauth_callback(provider: str) -> Response: pass if not user: - user = Users( - user_register_date=datetime.datetime.now(), - user_last_login_date=datetime.datetime.now(), - email=user_data.email, - oauth_provider_user_name=user_data.name, - is_active=True, - oauth_provider=provider, - oauth_provider_user_id=user_data.service_user_id, - oauth_provider_user_domain=user_data.service_user_domain, - oauth_provider_user_picture_url=user_data.picture_url, - oauth_provider_user_locale=user_data.service_user_locale, - oauth_provider_user_profile_url=user_data.user_profile_url, - ) - db.session.add(user) + with app.app_context(): + user = Users( + user_register_date=datetime.datetime.now(), + user_last_login_date=datetime.datetime.now(), + email=user_data.email, + oauth_provider_user_name=user_data.name, + is_active=True, + oauth_provider=provider, + oauth_provider_user_id=user_data.service_user_id, + oauth_provider_user_domain=user_data.service_user_domain, + oauth_provider_user_picture_url=user_data.picture_url, + oauth_provider_user_locale=user_data.service_user_locale, + oauth_provider_user_profile_url=user_data.user_profile_url, + ) + db.session.add(user) else: if not user.is_active: return return_json_error(Es.BR_USER_NOT_ACTIVE) - - user.user_last_login_date = datetime.datetime.now() - if ( - user.oauth_provider_user_id == "unknown-manual-insert" - ): # Only for anyway@anyway.co.il first login - user.oauth_provider_user_id = user_data.service_user_id - user.oauth_provider_user_name = user_data.name - user.oauth_provider_user_domain = user_data.service_user_domain - user.oauth_provider_user_picture_url = user_data.picture_url - user.oauth_provider_user_locale = user_data.service_user_locale - user.oauth_provider_user_profile_url = user_data.user_profile_url - - db.session.commit() + with app.app_context(): + user.user_last_login_date = datetime.datetime.now() + if ( + user.oauth_provider_user_id == "unknown-manual-insert" + ): # Only for anyway@anyway.co.il first login + user.oauth_provider_user_id = user_data.service_user_id + user.oauth_provider_user_name = user_data.name + user.oauth_provider_user_domain = user_data.service_user_domain + user.oauth_provider_user_picture_url = user_data.picture_url + user.oauth_provider_user_locale = user_data.service_user_locale + user.oauth_provider_user_profile_url = user_data.user_profile_url + db.session.commit() redirect_url = BE_CONST.DEFAULT_REDIRECT_URL redirect_url_json_base64 = request.args.get("state", type=str) @@ -232,8 +235,9 @@ def oauth_callback(provider: str) -> Response: @roles_accepted(BE_CONST.Roles2Names.Admins.value) def get_all_users_info() -> Response: dict_ret = [] - for user_obj in db.session.query(Users).order_by(Users.user_register_date).all(): - dict_ret.append(user_obj.serialize_exposed_to_user()) + with app.app_context(): + for user_obj in db.session.query(Users).order_by(Users.user_register_date).all(): + dict_ret.append(user_obj.serialize_exposed_to_user()) return jsonify(dict_ret) @@ -269,51 +273,53 @@ def is_input_fields_malformed(request: Request, allowed_fields: typing.List[str] def change_user_roles(action: str) -> Response: - req_dict = request.json - if not req_dict: - return return_json_error(Es.BR_FIELD_MISSING) - - role_name = req_dict.get("role") - if not role_name: - return return_json_error(Es.BR_ROLE_NAME_MISSING) - role = get_role_object(role_name) - if role is None: - return return_json_error(Es.BR_ROLE_NOT_EXIST, role_name) - - email = req_dict.get("email") - user = get_user_by_email(db, email) - if user is None: - return return_json_error(Es.BR_USER_NOT_FOUND, email) - - if action == "add": - # Add user to role - for user_role in user.roles: - if role.name == user_role.name: - return return_json_error(Es.BR_USER_ALREADY_IN_ROLE, role_name) - user.roles.append(role) - # Add user to role in the current instance - if current_user.email == user.email: - # g is flask global data - g.identity.provides.add(RoleNeed(role.name)) - elif action == "remove": - # Remove user from role - removed = False - for user_role in user.roles: - if role.name == user_role.name: - d = users_to_roles.delete().where( # noqa pylint: disable=no-value-for-parameter - (users_to_roles.c.user_id == user.id) & (users_to_roles.c.role_id == role.id) - ) - db.session.execute(d) - removed = True - if not removed: - return return_json_error(Es.BR_USER_NOT_IN_ROLE, email, role_name) - db.session.commit() + with app.app_context(): + req_dict = request.json + if not req_dict: + return return_json_error(Es.BR_FIELD_MISSING) + + role_name = req_dict.get("role") + if not role_name: + return return_json_error(Es.BR_ROLE_NAME_MISSING) + role = get_role_object(role_name) + if role is None: + return return_json_error(Es.BR_ROLE_NOT_EXIST, role_name) + + email = req_dict.get("email") + user = get_user_by_email(db, email) + if user is None: + return return_json_error(Es.BR_USER_NOT_FOUND, email) + + if action == "add": + # Add user to role + for user_role in user.roles: + if role.name == user_role.name: + return return_json_error(Es.BR_USER_ALREADY_IN_ROLE, role_name) + user.roles.append(role) + # Add user to role in the current instance + if current_user.email == user.email: + # g is flask global data + g.identity.provides.add(RoleNeed(role.name)) + elif action == "remove": + # Remove user from role + removed = False + for user_role in user.roles: + if role.name == user_role.name: + d = users_to_roles.delete().where( # noqa pylint: disable=no-value-for-parameter + (users_to_roles.c.user_id == user.id) & (users_to_roles.c.role_id == role.id) + ) + db.session.execute(d) + removed = True + if not removed: + return return_json_error(Es.BR_USER_NOT_IN_ROLE, email, role_name) + db.session.commit() return Response(status=HTTPStatus.OK) def get_role_object(role_name): - role = db.session.query(Roles).filter(Roles.name == role_name).one() + with app.app_context(): + role = db.session.query(Roles).filter(Roles.name == role_name).one() return role @@ -441,67 +447,70 @@ def update_user_in_db( user_url: str, is_user_completed_registration: bool, ) -> None: - user.first_name = first_name - user.last_name = last_name - user.email = user_db_email - user.phone = phone - user.user_type = user_type - user.user_url = user_url - user.user_desc = user_desc - user.is_user_completed_registration = is_user_completed_registration - db.session.commit() + with app.app_context(): + user.first_name = first_name + user.last_name = last_name + user.email = user_db_email + user.phone = phone + user.user_type = user_type + user.user_url = user_url + user.user_desc = user_desc + user.is_user_completed_registration = is_user_completed_registration + db.session.commit() @roles_accepted(BE_CONST.Roles2Names.Admins.value) def change_user_active_mode() -> Response: - req_dict = request.json - if not req_dict: - return return_json_error(Es.BR_FIELD_MISSING) + with app.app_context(): + req_dict = request.json + if not req_dict: + return return_json_error(Es.BR_FIELD_MISSING) - email = req_dict.get("email") - user = get_user_by_email(db, email) - if user is None: - return return_json_error(Es.BR_USER_NOT_FOUND, email) + email = req_dict.get("email") + user = get_user_by_email(db, email) + if user is None: + return return_json_error(Es.BR_USER_NOT_FOUND, email) - mode = req_dict.get("mode") - if mode is None: - return return_json_error(Es.BR_NO_MODE) + mode = req_dict.get("mode") + if mode is None: + return return_json_error(Es.BR_NO_MODE) - if type(mode) != bool: - return return_json_error(Es.BR_BAD_MODE) + if type(mode) != bool: + return return_json_error(Es.BR_BAD_MODE) - user.is_active = mode - db.session.commit() + user.is_active = mode + db.session.commit() return Response(status=HTTPStatus.OK) @roles_accepted(BE_CONST.Roles2Names.Admins.value) def add_role() -> Response: - req_dict = request.json - if not req_dict: - return return_json_error(Es.BR_FIELD_MISSING) + with app.app_context(): + req_dict = request.json + if not req_dict: + return return_json_error(Es.BR_FIELD_MISSING) - name = req_dict.get("name") - if not name: - return return_json_error(Es.BR_ROLE_NAME_MISSING) + name = req_dict.get("name") + if not name: + return return_json_error(Es.BR_ROLE_NAME_MISSING) - if not is_a_valid_role_name(name): - return return_json_error(Es.BR_BAD_ROLE_NAME) + if not is_a_valid_role_name(name): + return return_json_error(Es.BR_BAD_ROLE_NAME) - role = db.session.query(Roles).filter(Roles.name == name).first() - if role: - return return_json_error(Es.BR_ROLE_EXIST) + role = db.session.query(Roles).filter(Roles.name == name).first() + if role: + return return_json_error(Es.BR_ROLE_EXIST) - description = req_dict.get("description") - if not description: - return return_json_error(Es.BR_ROLE_DESCRIPTION_MISSING) + description = req_dict.get("description") + if not description: + return return_json_error(Es.BR_ROLE_DESCRIPTION_MISSING) - if not is_a_valid_role_description(description): - return return_json_error(Es.BR_BAD_ROLE_DESCRIPTION) + if not is_a_valid_role_description(description): + return return_json_error(Es.BR_BAD_ROLE_DESCRIPTION) - role = Roles(name=name, description=description, create_date=datetime.datetime.now()) - db.session.add(role) - db.session.commit() + role = Roles(name=name, description=description, create_date=datetime.datetime.now()) + db.session.add(role) + db.session.commit() return Response(status=HTTPStatus.OK) @@ -525,10 +534,11 @@ def is_a_valid_role_description(name: str) -> bool: @roles_accepted(BE_CONST.Roles2Names.Admins.value) def get_roles_list() -> Response: - roles_list = db.session.query(Roles).all() - send_list = [ - {"id": role.id, "name": role.name, "description": role.description} for role in roles_list - ] + with app.app_context(): + roles_list = db.session.query(Roles).all() + send_list = [ + {"id": role.id, "name": role.name, "description": role.description} for role in roles_list + ] return Response( response=json.dumps(send_list), status=HTTPStatus.OK, mimetype="application/json" @@ -537,8 +547,9 @@ def get_roles_list() -> Response: @roles_accepted(BE_CONST.Roles2Names.Admins.value) def get_organization_list() -> Response: - orgs_list = db.session.query(Organization).all() - send_list = [org.name for org in orgs_list] + with app.app_context(): + orgs_list = db.session.query(Organization).all() + send_list = [org.name for org in orgs_list] return Response( response=json.dumps(send_list), status=HTTPStatus.OK, mimetype="application/json" @@ -549,47 +560,48 @@ def get_organization_list() -> Response: def add_organization(name: str) -> Response: if not name: return return_json_error(Es.BR_FIELD_MISSING) - - org = db.session.query(Organization).filter(Organization.name == name).first() - if not org: - org = Organization(name=name, create_date=datetime.datetime.now()) - db.session.add(org) - db.session.commit() + with app.app_context(): + org = db.session.query(Organization).filter(Organization.name == name).first() + if not org: + org = Organization(name=name, create_date=datetime.datetime.now()) + db.session.add(org) + db.session.commit() return Response(status=HTTPStatus.OK) @roles_accepted(BE_CONST.Roles2Names.Admins.value) def update_user_org(user_email: str, org_name: str) -> Response: - user = get_user_by_email(db, user_email) - if user is None: - return return_json_error(Es.BR_USER_NOT_FOUND, user_email) + with app.app_context(): + user = get_user_by_email(db, user_email) + if user is None: + return return_json_error(Es.BR_USER_NOT_FOUND, user_email) + + if org_name is not None: + try: + org_obj = db.session.query(Organization).filter(Organization.name == org_name).one() + except NoResultFound: + return return_json_error(Es.BR_ORG_NOT_FOUND) + user.organizations = [org_obj] + else: + user.organizations = [] - if org_name is not None: - try: - org_obj = db.session.query(Organization).filter(Organization.name == org_name).one() - except NoResultFound: - return return_json_error(Es.BR_ORG_NOT_FOUND) - user.organizations = [org_obj] - else: - user.organizations = [] - - db.session.commit() + db.session.commit() return Response(status=HTTPStatus.OK) @roles_accepted(BE_CONST.Roles2Names.Admins.value) def delete_user(email: str) -> Response: - user = get_user_by_email(db, email) - if user is None: - return return_json_error(Es.BR_USER_NOT_FOUND, email) - - # Delete user roles - user.roles = [] + with app.app_context(): + user = get_user_by_email(db, email) + if user is None: + return return_json_error(Es.BR_USER_NOT_FOUND, email) - # Delete user organizations membership - user.organizations = [] + # Delete user roles + user.roles = [] - db.session.delete(user) - db.session.commit() + # Delete user organizations membership + user.organizations = [] + db.session.delete(user) + db.session.commit() return Response(status=HTTPStatus.OK) diff --git a/anyway/views/user_system/user_functions.py b/anyway/views/user_system/user_functions.py index 578de265b..33cc2404d 100644 --- a/anyway/views/user_system/user_functions.py +++ b/anyway/views/user_system/user_functions.py @@ -2,14 +2,15 @@ from flask_login import current_user from flask_sqlalchemy import SQLAlchemy - +from anyway.app_and_db import app from anyway.models import Users def get_user_by_email(db: SQLAlchemy, email: str) -> Optional[Users]: if not email: return None - user = db.session.query(Users).filter(Users.email == email).first() + with app.app_context(): + user = db.session.query(Users).filter(Users.email == email).first() return user diff --git a/anyway/widgets/all_locations_widgets/most_severe_accidents_table_widget.py b/anyway/widgets/all_locations_widgets/most_severe_accidents_table_widget.py index ba72e116a..4c883774c 100644 --- a/anyway/widgets/all_locations_widgets/most_severe_accidents_table_widget.py +++ b/anyway/widgets/all_locations_widgets/most_severe_accidents_table_widget.py @@ -10,8 +10,8 @@ from anyway.models import AccidentMarkerView, InvolvedMarkerView from anyway.widgets.all_locations_widgets.all_locations_widget import AllLocationsWidget from anyway.widgets.widget import register -from anyway.app_and_db import db -from sqlalchemy.sql import text +from anyway.app_and_db import db, app +import sqlalchemy as sa def get_most_severe_accidents_with_entities( table_obj, @@ -29,15 +29,16 @@ def get_most_severe_accidents_with_entities( ] # pylint: disable=no-member filters["accident_severity"] = [AccidentSeverity.FATAL.value, AccidentSeverity.SEVERE.value] - query = get_query(table_obj, filters, start_time, end_time) - columns_with_entities = [text(e) for e in entities] - query = query.with_entities(*columns_with_entities) - query = query.order_by( - getattr(table_obj, "accident_timestamp").desc(), - getattr(table_obj, "accident_severity").asc(), - ) - query = query.limit(limit) - df = pd.read_sql_query(query.statement, db.get_engine()) + with app.app_context(): + query = get_query(table_obj, filters, start_time, end_time) + columns_with_entities = [sa.text(e) for e in entities] + query = query.with_entities(*columns_with_entities) + query = query.order_by( + getattr(table_obj, "accident_timestamp").desc(), + getattr(table_obj, "accident_severity").asc(), + ) + query = query.limit(limit) + df = pd.read_sql_query(query.statement, db.get_engine()) df.columns = [c.replace("_hebrew", "") for c in df.columns] return df.to_dict(orient="records") # pylint: disable=no-member diff --git a/anyway/widgets/road_segment_widgets/accident_type_vehicle_type_road_comparison_widget.py b/anyway/widgets/road_segment_widgets/accident_type_vehicle_type_road_comparison_widget.py index 2252bdea5..2a3b6ae04 100644 --- a/anyway/widgets/road_segment_widgets/accident_type_vehicle_type_road_comparison_widget.py +++ b/anyway/widgets/road_segment_widgets/accident_type_vehicle_type_road_comparison_widget.py @@ -6,7 +6,7 @@ from sqlalchemy import func, distinct, desc from anyway.request_params import RequestParams -from anyway.app_and_db import db +from anyway.app_and_db import db, app from anyway.widgets.widget_utils import get_query, run_query from anyway.models import VehicleMarkerView, AccidentType from anyway.vehicle_type import VehicleCategory @@ -84,20 +84,21 @@ def get_accident_count_by_vehicle_type_query( num_accidents_label: str, vehicle_types: List[int], ) -> db.session.query: - return ( - get_query( - table_obj=VehicleMarkerView, - start_time=start_time, - end_time=end_time, - filters={VehicleMarkerView.vehicle_type.name: vehicle_types}, - ) - .with_entities( - VehicleMarkerView.accident_type, - func.count(distinct(VehicleMarkerView.provider_and_id)).label(num_accidents_label), + with app.app_context(): + return ( + get_query( + table_obj=VehicleMarkerView, + start_time=start_time, + end_time=end_time, + filters={VehicleMarkerView.vehicle_type.name: vehicle_types}, + ) + .with_entities( + VehicleMarkerView.accident_type, + func.count(distinct(VehicleMarkerView.provider_and_id)).label(num_accidents_label), + ) + .group_by(VehicleMarkerView.accident_type) + .order_by(desc(num_accidents_label)) ) - .group_by(VehicleMarkerView.accident_type) - .order_by(desc(num_accidents_label)) - ) @staticmethod def localize_items(request_params: RequestParams, items: Dict) -> Dict: diff --git a/anyway/widgets/road_segment_widgets/accidents_heat_map_widget.py b/anyway/widgets/road_segment_widgets/accidents_heat_map_widget.py index cdfa37b78..6396dd5ad 100644 --- a/anyway/widgets/road_segment_widgets/accidents_heat_map_widget.py +++ b/anyway/widgets/road_segment_widgets/accidents_heat_map_widget.py @@ -10,7 +10,7 @@ from anyway.models import AccidentMarkerView from anyway.widgets.widget import register from anyway.widgets.road_segment_widgets.road_segment_widget import RoadSegmentWidget -from anyway.app_and_db import db +from anyway.app_and_db import db, app from sqlalchemy.sql import text @register @@ -43,9 +43,10 @@ def get_accidents_heat_map(filters, start_time, end_time): BE_CONST.CBS_ACCIDENT_TYPE_1_CODE, BE_CONST.CBS_ACCIDENT_TYPE_3_CODE, ] - query = get_query(AccidentMarkerView, filters, start_time, end_time) - query = query.with_entities(text("longitude"), text("latitude")) - df = pd.read_sql_query(query.statement, db.get_engine()) + with app.app_context(): + query = get_query(AccidentMarkerView, filters, start_time, end_time) + query = query.with_entities(text("longitude"), text("latitude")) + df = pd.read_sql_query(query.statement, db.get_engine()) return df.to_dict(orient="records") # pylint: disable=no-member @staticmethod diff --git a/anyway/widgets/road_segment_widgets/front_to_side_accidents_by_severity.py b/anyway/widgets/road_segment_widgets/front_to_side_accidents_by_severity.py index 244d0f708..b5970d5dc 100644 --- a/anyway/widgets/road_segment_widgets/front_to_side_accidents_by_severity.py +++ b/anyway/widgets/road_segment_widgets/front_to_side_accidents_by_severity.py @@ -11,7 +11,7 @@ from anyway.widgets.road_segment_widgets.road_segment_widget import RoadSegmentWidget from anyway.widgets.widget import register from anyway.widgets.widget_utils import get_query -from anyway.app_and_db import db +from anyway.app_and_db import db, app ROAD_SEGMENT_ACCIDENTS = "specific_road_segment_accidents" @@ -91,30 +91,31 @@ def _get_raw_front_to_side_accidents( == AccidentType.COLLISION_OF_FRONT_TO_SIDE.value, AccidentMarkerView.provider_and_id), ) - query = get_query( - table_obj=AccidentMarkerView, filters={}, start_time=start_date, end_time=end_date - ) - entities_query = query.with_entities( - AccidentMarkerView.accident_severity, - AccidentMarkerView.accident_severity_hebrew, - func.count(distinct(other_accidents)).label(OTHER_ACCIDENTS_LABEL), - func.count(distinct(front_side_accidents)).label(FRONT_SIDE_ACCIDENTS_LABEL), - ) - - if road_segment_id: - entities_query = entities_query.filter( - AccidentMarkerView.road_segment_id == road_segment_id + with app.app_context(): + query = get_query( + table_obj=AccidentMarkerView, filters={}, start_time=start_date, end_time=end_date ) + entities_query = query.with_entities( + AccidentMarkerView.accident_severity, + AccidentMarkerView.accident_severity_hebrew, + func.count(distinct(other_accidents)).label(OTHER_ACCIDENTS_LABEL), + func.count(distinct(front_side_accidents)).label(FRONT_SIDE_ACCIDENTS_LABEL), + ) + + if road_segment_id: + entities_query = entities_query.filter( + AccidentMarkerView.road_segment_id == road_segment_id + ) - query_filtered = entities_query.filter( - AccidentMarkerView.accident_severity.in_( - [AccidentSeverity.FATAL.value, AccidentSeverity.SEVERE.value] + query_filtered = entities_query.filter( + AccidentMarkerView.accident_severity.in_( + [AccidentSeverity.FATAL.value, AccidentSeverity.SEVERE.value] + ) ) - ) - query = query_filtered.group_by( - AccidentMarkerView.accident_severity, AccidentMarkerView.accident_severity_hebrew - ) - results = pd.read_sql_query(query.statement, db.get_engine()).to_dict(orient="records") + query = query_filtered.group_by( + AccidentMarkerView.accident_severity, AccidentMarkerView.accident_severity_hebrew + ) + results = pd.read_sql_query(query.statement, db.get_engine()).to_dict(orient="records") return results @staticmethod diff --git a/anyway/widgets/road_segment_widgets/killed_and_injured_count_per_age_group_widget_utils.py b/anyway/widgets/road_segment_widgets/killed_and_injured_count_per_age_group_widget_utils.py index 27c21ec56..bd5de48fd 100644 --- a/anyway/widgets/road_segment_widgets/killed_and_injured_count_per_age_group_widget_utils.py +++ b/anyway/widgets/road_segment_widgets/killed_and_injured_count_per_age_group_widget_utils.py @@ -5,7 +5,7 @@ from flask_sqlalchemy import BaseQuery from sqlalchemy import func, asc -from anyway.app_and_db import db +from anyway.app_and_db import db, app from anyway.backend_constants import BE_CONST, InjurySeverity from anyway.models import InvolvedMarkerView from anyway.request_params import RequestParams @@ -90,35 +90,36 @@ def defaultdict_int_factory() -> Callable: def create_query_for_killed_and_injured_count_per_age_group( end_time: datetime.date, road_number: int, road_segment: str, start_time: datetime.date ) -> BaseQuery: - query = ( - db.session.query(InvolvedMarkerView) - .filter(InvolvedMarkerView.accident_timestamp >= start_time) - .filter(InvolvedMarkerView.accident_timestamp <= end_time) - .filter( - InvolvedMarkerView.provider_code.in_( - [BE_CONST.CBS_ACCIDENT_TYPE_1_CODE, BE_CONST.CBS_ACCIDENT_TYPE_3_CODE] + with app.app_context(): + query = ( + db.session.query(InvolvedMarkerView) + .filter(InvolvedMarkerView.accident_timestamp >= start_time) + .filter(InvolvedMarkerView.accident_timestamp <= end_time) + .filter( + InvolvedMarkerView.provider_code.in_( + [BE_CONST.CBS_ACCIDENT_TYPE_1_CODE, BE_CONST.CBS_ACCIDENT_TYPE_3_CODE] + ) ) - ) - .filter( - InvolvedMarkerView.injury_severity.in_( - [ - InjurySeverity.KILLED.value, # pylint: disable=no-member - InjurySeverity.SEVERE_INJURED.value, # pylint: disable=no-member - InjurySeverity.LIGHT_INJURED.value, # pylint: disable=no-member - ] + .filter( + InvolvedMarkerView.injury_severity.in_( + [ + InjurySeverity.KILLED.value, # pylint: disable=no-member + InjurySeverity.SEVERE_INJURED.value, # pylint: disable=no-member + InjurySeverity.LIGHT_INJURED.value, # pylint: disable=no-member + ] + ) ) + .filter( + (InvolvedMarkerView.road1 == road_number) + | (InvolvedMarkerView.road2 == road_number) + ) + .filter(InvolvedMarkerView.road_segment_name == road_segment) + .group_by(InvolvedMarkerView.age_group, InvolvedMarkerView.injury_severity) + .with_entities( + InvolvedMarkerView.age_group, + InvolvedMarkerView.injury_severity, + func.count().label("count"), + ) + .order_by(asc(InvolvedMarkerView.age_group)) ) - .filter( - (InvolvedMarkerView.road1 == road_number) - | (InvolvedMarkerView.road2 == road_number) - ) - .filter(InvolvedMarkerView.road_segment_name == road_segment) - .group_by(InvolvedMarkerView.age_group, InvolvedMarkerView.injury_severity) - .with_entities( - InvolvedMarkerView.age_group, - InvolvedMarkerView.injury_severity, - func.count().label("count"), - ) - .order_by(asc(InvolvedMarkerView.age_group)) - ) return query diff --git a/anyway/widgets/road_segment_widgets/motorcycle_accidents_vs_all_accidents_widget.py b/anyway/widgets/road_segment_widgets/motorcycle_accidents_vs_all_accidents_widget.py index 6cb65ff7c..ad9e60972 100644 --- a/anyway/widgets/road_segment_widgets/motorcycle_accidents_vs_all_accidents_widget.py +++ b/anyway/widgets/road_segment_widgets/motorcycle_accidents_vs_all_accidents_widget.py @@ -11,7 +11,7 @@ from anyway.models import InvolvedMarkerView from anyway.vehicle_type import VehicleCategory from anyway.widgets.road_segment_widgets.road_segment_widget import RoadSegmentWidget -from anyway.app_and_db import db +from anyway.app_and_db import db, app from typing import Dict from flask_babel import _ @@ -63,32 +63,32 @@ def motorcycle_accidents_vs_all_accidents( ), else_=literal_column(f"'{vehicle_other}'"), ).label(vehicle_label) - - query = get_query( - table_obj=InvolvedMarkerView, filters={}, start_time=start_time, end_time=end_time - ) - - num_accidents_label = "num_of_accidents" - query = ( - query.with_entities( - case_location, - case_vehicle, - func.count(distinct(InvolvedMarkerView.provider_and_id)).label(num_accidents_label), + with app.app_context(): + query = get_query( + table_obj=InvolvedMarkerView, filters={}, start_time=start_time, end_time=end_time ) - .filter(InvolvedMarkerView.road_type.in_(BE_CONST.NON_CITY_ROAD_TYPES)) - .filter( - InvolvedMarkerView.accident_severity.in_( - # pylint: disable=no-member - [AccidentSeverity.FATAL.value, AccidentSeverity.SEVERE.value] + + num_accidents_label = "num_of_accidents" + query = ( + query.with_entities( + case_location, + case_vehicle, + func.count(distinct(InvolvedMarkerView.provider_and_id)).label(num_accidents_label), + ) + .filter(InvolvedMarkerView.road_type.in_(BE_CONST.NON_CITY_ROAD_TYPES)) + .filter( + InvolvedMarkerView.accident_severity.in_( + # pylint: disable=no-member + [AccidentSeverity.FATAL.value, AccidentSeverity.SEVERE.value] + ) ) + .group_by(location_label, vehicle_label) + .order_by(desc(num_accidents_label)) ) - .group_by(location_label, vehicle_label) - .order_by(desc(num_accidents_label)) - ) - # pylint: disable=no-member - results = pd.read_sql_query(query.statement, db.get_engine()).to_dict( - orient="records" - ) # pylint: disable=no-member + # pylint: disable=no-member + results = pd.read_sql_query(query.statement, db.get_engine()).to_dict( + orient="records" + ) # pylint: disable=no-member counter_road_motorcycle = 0 counter_other_motorcycle = 0 diff --git a/anyway/widgets/road_segment_widgets/top_road_segments_accidents_per_km_widget.py b/anyway/widgets/road_segment_widgets/top_road_segments_accidents_per_km_widget.py index b2c11b07f..d13670dec 100644 --- a/anyway/widgets/road_segment_widgets/top_road_segments_accidents_per_km_widget.py +++ b/anyway/widgets/road_segment_widgets/top_road_segments_accidents_per_km_widget.py @@ -12,7 +12,7 @@ from anyway.widgets.widget_utils import get_query from anyway.models import AccidentMarkerView from anyway.widgets.road_segment_widgets.road_segment_widget import RoadSegmentWidget -from anyway.app_and_db import db +from anyway.app_and_db import db, app @register class TopRoadSegmentsAccidentsPerKmWidget(RoadSegmentWidget): @@ -41,35 +41,35 @@ def get_top_road_segments_accidents_per_km( start_time=start_time, end_time=end_time, ) - try: - query = ( - query.with_entities( - AccidentMarkerView.road_segment_name, - AccidentMarkerView.road_segment_length_km.label("segment_length"), - cast( - ( - func.count(AccidentMarkerView.id) - / AccidentMarkerView.road_segment_length_km - ), - Numeric(10, 4), - ).label("accidents_per_km"), - func.count(AccidentMarkerView.id).label("total_accidents"), - ) - .filter(AccidentMarkerView.road_segment_name.isnot(None)) - .filter( - AccidentMarkerView.accident_severity.in_( - [AccidentSeverity.FATAL.value, AccidentSeverity.SEVERE.value] + with app.app_context(): + query = ( + query.with_entities( + AccidentMarkerView.road_segment_name, + AccidentMarkerView.road_segment_length_km.label("segment_length"), + cast( + ( + func.count(AccidentMarkerView.id) + / AccidentMarkerView.road_segment_length_km + ), + Numeric(10, 4), + ).label("accidents_per_km"), + func.count(AccidentMarkerView.id).label("total_accidents"), ) + .filter(AccidentMarkerView.road_segment_name.isnot(None)) + .filter( + AccidentMarkerView.accident_severity.in_( + [AccidentSeverity.FATAL.value, AccidentSeverity.SEVERE.value] + ) + ) + .group_by( + AccidentMarkerView.road_segment_name, AccidentMarkerView.road_segment_length_km + ) + .order_by(desc("accidents_per_km")) + .limit(limit) ) - .group_by( - AccidentMarkerView.road_segment_name, AccidentMarkerView.road_segment_length_km - ) - .order_by(desc("accidents_per_km")) - .limit(limit) - ) - result = pd.read_sql_query(query.statement, db.get_engine()) + result = pd.read_sql_query(query.statement, db.get_engine()) return result.to_dict(orient="records") # pylint: disable=no-member except Exception as exception: diff --git a/anyway/widgets/urban_widgets/injured_accidents_with_pedestrians_widget.py b/anyway/widgets/urban_widgets/injured_accidents_with_pedestrians_widget.py index 645b8cbce..f92fab121 100644 --- a/anyway/widgets/urban_widgets/injured_accidents_with_pedestrians_widget.py +++ b/anyway/widgets/urban_widgets/injured_accidents_with_pedestrians_widget.py @@ -6,7 +6,7 @@ from sqlalchemy.sql.elements import and_ from flask_babel import _ from anyway.request_params import RequestParams -from anyway.app_and_db import db +from anyway.app_and_db import db, app from anyway.backend_constants import InjurySeverity, InjuredType from anyway.widgets.widget_utils import ( add_empty_keys_to_gen_two_level_dict, @@ -58,51 +58,51 @@ def generate_items(self) -> None: # TODO: this will fail since there is no news_flash_obj in request_params logging.exception(f"Could not validate parameters yishuv_name + street1_hebrew in widget : {self.name}") return None - - query = ( - db.session.query(InvolvedMarkerView) - .with_entities( - InvolvedMarkerView.accident_year, - InvolvedMarkerView.injury_severity, - func.count().label("count"), - ) - .filter(InvolvedMarkerView.accident_yishuv_name == yishuv_name) - .filter( - InvolvedMarkerView.injury_severity.in_( - [ - InjurySeverity.KILLED.value, - InjurySeverity.SEVERE_INJURED.value, - InjurySeverity.LIGHT_INJURED.value, - ] + with app.app_context(): + query = ( + db.session.query(InvolvedMarkerView) + .with_entities( + InvolvedMarkerView.accident_year, + InvolvedMarkerView.injury_severity, + func.count().label("count"), ) - ) - .filter(InvolvedMarkerView.injured_type == InjuredType.PEDESTRIAN.value) - .filter( - or_( - InvolvedMarkerView.street1_hebrew == street1_hebrew, - InvolvedMarkerView.street2_hebrew == street1_hebrew, + .filter(InvolvedMarkerView.accident_yishuv_name == yishuv_name) + .filter( + InvolvedMarkerView.injury_severity.in_( + [ + InjurySeverity.KILLED.value, + InjurySeverity.SEVERE_INJURED.value, + InjurySeverity.LIGHT_INJURED.value, + ] + ) ) - ) - .filter( - and_( - InvolvedMarkerView.accident_timestamp >= self.request_params.start_time, - InvolvedMarkerView.accident_timestamp <= self.request_params.end_time, + .filter(InvolvedMarkerView.injured_type == InjuredType.PEDESTRIAN.value) + .filter( + or_( + InvolvedMarkerView.street1_hebrew == street1_hebrew, + InvolvedMarkerView.street2_hebrew == street1_hebrew, + ) ) + .filter( + and_( + InvolvedMarkerView.accident_timestamp >= self.request_params.start_time, + InvolvedMarkerView.accident_timestamp <= self.request_params.end_time, + ) + ) + .group_by(InvolvedMarkerView.accident_year, InvolvedMarkerView.injury_severity) ) - .group_by(InvolvedMarkerView.accident_year, InvolvedMarkerView.injury_severity) - ) - res = add_empty_keys_to_gen_two_level_dict( - self.convert_to_dict(query.all()), - [ - str(year) - for year in range( - self.request_params.start_time.year, self.request_params.end_time.year + 1 - ) - ], - InjurySeverity.codes(), - ) - self.items = format_2_level_items(res, None, InjurySeverity) + res = add_empty_keys_to_gen_two_level_dict( + self.convert_to_dict(query.all()), + [ + str(year) + for year in range( + self.request_params.start_time.year, self.request_params.end_time.year + 1 + ) + ], + InjurySeverity.codes(), + ) + self.items = format_2_level_items(res, None, InjurySeverity) except Exception as e: logging.error(f"InjuredAccidentsWithPedestriansWidget.generate_items(): {e}") diff --git a/anyway/widgets/widget_utils.py b/anyway/widgets/widget_utils.py index c9c122939..73eb4aef5 100644 --- a/anyway/widgets/widget_utils.py +++ b/anyway/widgets/widget_utils.py @@ -9,7 +9,7 @@ from sqlalchemy import func, distinct, between, or_ from sqlalchemy.sql import text -from anyway.app_and_db import db +from anyway.app_and_db import db, app from anyway.backend_constants import BE_CONST, LabeledCode, InjurySeverity from anyway.models import InvolvedMarkerView from anyway.request_params import LocationInfo @@ -17,29 +17,30 @@ def get_query(table_obj, filters, start_time, end_time): - query = db.session.query(table_obj) - if start_time: - query = query.filter(getattr(table_obj, "accident_timestamp") >= start_time) - if end_time: - query = query.filter(getattr(table_obj, "accident_timestamp") <= end_time) - if filters: - for field_name, value in filters.items(): - # TODO: why are we always doing a list check? wouldn't it be more efficient to do a single comparison if it's not a list? - if isinstance(value, list): - values = value - else: - values = [value] - - if field_name == "street1_hebrew": - query = query.filter( - or_( - (getattr(table_obj, "street1_hebrew")).in_(values), - (getattr(table_obj, "street2_hebrew")).in_(values), + with app.app_context(): + query = db.session.query(table_obj) + if start_time: + query = query.filter(getattr(table_obj, "accident_timestamp") >= start_time) + if end_time: + query = query.filter(getattr(table_obj, "accident_timestamp") <= end_time) + if filters: + for field_name, value in filters.items(): + # TODO: why are we always doing a list check? wouldn't it be more efficient to do a single comparison if it's not a list? + if isinstance(value, list): + values = value + else: + values = [value] + + if field_name == "street1_hebrew": + query = query.filter( + or_( + (getattr(table_obj, "street1_hebrew")).in_(values), + (getattr(table_obj, "street2_hebrew")).in_(values), + ) ) - ) - else: - query = query.filter((getattr(table_obj, field_name)).in_(values)) - return query + else: + query = query.filter((getattr(table_obj, field_name)).in_(values)) + return query def get_accidents_stats( @@ -55,32 +56,32 @@ def get_accidents_stats( filters = filters or {} provider_code_filters = [BE_CONST.CBS_ACCIDENT_TYPE_1_CODE, BE_CONST.CBS_ACCIDENT_TYPE_3_CODE] filters["provider_code"] = filters.get("provider_code", provider_code_filters) - - # get stats - query = get_query(table_obj, filters, start_time, end_time) - if columns: - columns_with_entities = [text(c) for c in columns] - query = query.with_entities(*columns_with_entities) - if group_by: - if isinstance(group_by, tuple): - if len(group_by) == 2: - query = query.group_by(*group_by) - group_by_with_entities = [text(gb) for gb in group_by] - query = query.with_entities(*group_by_with_entities, func.count(count)) - dd = query.all() - res = retro_dictify(dd) - return res + with app.app_context(): + # get stats + query = get_query(table_obj, filters, start_time, end_time) + if columns: + columns_with_entities = [text(c) for c in columns] + query = query.with_entities(*columns_with_entities) + if group_by: + if isinstance(group_by, tuple): + if len(group_by) == 2: + query = query.group_by(*group_by) + group_by_with_entities = [text(gb) for gb in group_by] + query = query.with_entities(*group_by_with_entities, func.count(count)) + dd = query.all() + res = retro_dictify(dd) + return res + else: + err_msg = f"get_accidents_stats: {group_by}: Only a string or a tuple of two are valid for group_by" + logging.error(err_msg) + raise Exception(err_msg) else: - err_msg = f"get_accidents_stats: {group_by}: Only a string or a tuple of two are valid for group_by" - logging.error(err_msg) - raise Exception(err_msg) - else: - query = query.group_by(group_by) - query = query.with_entities( - text(group_by), - func.count(count) if not cnt_distinct else func.count(distinct(count)), - ) - df = pd.read_sql_query(query.statement, db.get_engine()) + query = query.group_by(group_by) + query = query.with_entities( + text(group_by), + func.count(count) if not cnt_distinct else func.count(distinct(count)), + ) + df = pd.read_sql_query(query.statement, db.get_engine()) df.rename(columns={"count_1": "count"}, inplace=True) # pylint: disable=no-member df.columns = [c.replace("_hebrew", "") for c in df.columns] return ( # pylint: disable=no-member @@ -135,7 +136,8 @@ def get_injured_filters(location_info): def run_query(query: db.session.query) -> Dict: # pylint: disable=no-member - return pd.read_sql_query(query.statement, db.get_engine()).to_dict(orient="records") + with app.app_context(): + return pd.read_sql_query(query.statement, db.get_engine()).to_dict(orient="records") # TODO: Find a better way to deal with typing.Union[int, str] @@ -189,39 +191,41 @@ def get_involved_counts( vehicle_types: Sequence[VehicleType], location_info: LocationInfo, ) -> Dict[str, int]: - table = InvolvedMarkerView + with app.app_context(): + table = InvolvedMarkerView - selected_columns = ( - table.accident_year.label("label_key"), - func.count(distinct(table.involve_id)).label("value"), - ) - - query = ( - db.session.query() - .select_from(table) - .with_entities(*selected_columns) - .filter(between(table.accident_year, start_year, end_year)) - .order_by(table.accident_year) - ) - - if "yishuv_symbol" in location_info: - query = query.filter( - table.accident_yishuv_symbol == location_info["yishuv_symbol"] - ).group_by(table.accident_year) - elif "road_segment_id" in location_info: - query = query.filter(table.road_segment_id == location_info["road_segment_id"]).group_by( - table.accident_year + selected_columns = ( + table.accident_year.label("label_key"), + func.count(distinct(table.involve_id)).label("value"), + ) + query = ( + db.session.query() + .select_from(table) + .with_entities(*selected_columns) + .filter(between(table.accident_year, start_year, end_year)) + .order_by(table.accident_year) ) - if severities: - query = query.filter(table.injury_severity.in_([severity.value for severity in severities])) + if "yishuv_symbol" in location_info: + query = query.filter( + table.accident_yishuv_symbol == location_info["yishuv_symbol"] + ).group_by(table.accident_year) + elif "road_segment_id" in location_info: + query = query.filter( + table.road_segment_id == location_info["road_segment_id"] + ).group_by(table.accident_year) + + if severities: + query = query.filter( + table.injury_severity.in_([severity.value for severity in severities]) + ) - if vehicle_types: - query = query.filter( - table.involve_vehicle_type.in_([v_type.value for v_type in vehicle_types]) - ) + if vehicle_types: + query = query.filter( + table.involve_vehicle_type.in_([v_type.value for v_type in vehicle_types]) + ) - df = pd.read_sql_query(query.statement, db.get_engine()) + df = pd.read_sql_query(query.statement, db.get_engine()) return df.to_dict(orient="records") # pylint: disable=no-member diff --git a/main.py b/main.py index 9428b49ad..743d880be 100755 --- a/main.py +++ b/main.py @@ -80,8 +80,9 @@ def update(source, news_flash_id): @update_news_flash.command() def remove_duplicate_news_flash_rows(): from anyway.parsers import news_flash_db_adapter - - news_flash_db_adapter.init_db().remove_duplicate_rows() + from anyway.app_and_db import app + with app.app_context(): + news_flash_db_adapter.init_db().remove_duplicate_rows() @cli.group() @@ -387,7 +388,7 @@ def truncate_cbs(path): @click.argument("identifiers", nargs=-1) def load_discussions(identifiers): from anyway.models import DiscussionMarker - from anyway.app_and_db import db + from anyway.app_and_db import db, app identifiers = identifiers or sys.stdin @@ -407,12 +408,14 @@ def load_discussions(identifiers): } ) try: - db.session.add(marker) - db.session.commit() - logging.info(f"Added: {identifier}") + with app.app_context(): + db.session.add(marker) + db.session.commit() + logging.info(f"Added: {identifier}") except Exception as e: - db.session.rollback() - logging.warning(f"Failed: {identifier} {e}") + with app.app_context(): + db.session.rollback() + logging.warning(f"Failed: {identifier} {e}") @cli.group() diff --git a/tests/factories.py b/tests/factories.py index b9e2b7be0..fc9836aff 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -1,13 +1,14 @@ import factory from anyway import models -from anyway.app_and_db import db +from anyway.app_and_db import db, app class DefaultFactory(factory.alchemy.SQLAlchemyModelFactory): class Meta: abstract = True - sqlalchemy_session = db.session - sqlalchemy_session_persistence = 'flush' + with app.app_context(): + sqlalchemy_session = db.session + sqlalchemy_session_persistence = 'flush' class AccidentMarkerFactory(DefaultFactory): diff --git a/tests/test_infographic_api.py b/tests/test_infographic_api.py index c6463e368..8b8f848cf 100644 --- a/tests/test_infographic_api.py +++ b/tests/test_infographic_api.py @@ -7,14 +7,14 @@ from six.moves import http_client from anyway import app as flask_app from jsonschema import validate -from anyway.app_and_db import db +from anyway.app_and_db import db, app from anyway.vehicle_type import VehicleCategory from anyway.widgets.road_segment_widgets.accident_count_by_car_type_widget import AccidentCountByCarTypeWidget from anyway.backend_constants import NewsflashLocationQualification - +import sqlalchemy as sa def insert_infographic_mock_data(app): - sql_insert = f""" + sql_insert = sa.text(f""" insert into news_flash (accident, author, date, description, lat, link, lon, title, source, location, road1, road2, resolution, tweet_id, district_hebrew, non_urban_intersection_hebrew, region_hebrew, road_segment_name, street1_hebrew, street2_hebrew, yishuv_name, newsflash_location_qualification, location_qualifying_user) @@ -42,9 +42,10 @@ def insert_infographic_mock_data(app): null, {NewsflashLocationQualification.NOT_VERIFIED.value}, null) returning id; - """ - insert_id = db.session.execute(sql_insert).fetchone()[0] - db.session.commit() + """) + with app.app_context(): + insert_id = db.session.execute(sql_insert).fetchone()[0] + db.session.commit() return insert_id @@ -61,9 +62,10 @@ def get_infographic_data(): def delete_new_infographic_data(new_infographic_data_id): - sql_delete = f"DELETE FROM news_flash where id = {new_infographic_data_id}" - db.session.execute(sql_delete) - db.session.commit() + sql_delete = sa.text(f"DELETE FROM news_flash where id = {new_infographic_data_id}") + with app.app_context(): + db.session.execute(sql_delete) + db.session.commit() class TestInfographicApi: diff --git a/tests/test_news_flash.py b/tests/test_news_flash.py index 75f945df4..21983d6df 100755 --- a/tests/test_news_flash.py +++ b/tests/test_news_flash.py @@ -10,6 +10,7 @@ from anyway.parsers.news_flash_classifiers import classify_tweets, classify_rss from anyway import secrets from anyway.parsers.news_flash_db_adapter import init_db +from anyway.app_and_db import app from anyway.models import NewsFlash from anyway.parsers import timezones from anyway.infographics_utils import is_news_flash_resolution_supported @@ -98,9 +99,10 @@ def test_scrape_ynet(): def test_sanity_get_latest_date(): db = init_db() - db.get_latest_date_of_source("ynet") - db.get_latest_date_of_source("walla") - db.get_latest_date_of_source("twitter") + with app.app_context(): + db.get_latest_date_of_source("ynet") + db.get_latest_date_of_source("walla") + db.get_latest_date_of_source("twitter") @pytest.mark.slow @@ -249,6 +251,7 @@ def test_nan_becomes_none_before_insertion(monkeypatch): db_mock = MagicMock() monkeypatch.setattr('anyway.parsers.news_flash_db_adapter.infographics_data_cache_updater', MagicMock()) adapter = DBAdapter(db=db_mock) - adapter.insert_new_newsflash(newsflash) + with app.app_context(): + adapter.insert_new_newsflash(newsflash) assert newsflash.road1 is None \ No newline at end of file diff --git a/tests/test_news_flash_api.py b/tests/test_news_flash_api.py index 9e0d2de12..c5d0dd0ca 100644 --- a/tests/test_news_flash_api.py +++ b/tests/test_news_flash_api.py @@ -3,7 +3,7 @@ from unittest.mock import patch from http import HTTPStatus from sqlalchemy.orm import sessionmaker -from anyway.app_and_db import db +from anyway.app_and_db import db, app from anyway.views.news_flash.api import ( is_news_flash_resolution_supported, gen_news_flash_query, @@ -20,7 +20,8 @@ # pylint: disable=E1101 class NewsFlashApiTestCase(unittest.TestCase): def setUp(self) -> None: - self.connection = db.get_engine().connect() + with app.app_context(): + self.connection = db.get_engine().connect() # begin a non-ORM transaction self.trans = self.connection.begin() # bind an individual Session to the connection diff --git a/tests/test_queries.py b/tests/test_queries.py index c7e2ac4d2..62f038bf5 100644 --- a/tests/test_queries.py +++ b/tests/test_queries.py @@ -2,7 +2,7 @@ from factory import make_factory, Iterator from anyway import models -from anyway.app_and_db import db +from anyway.app_and_db import db, app from anyway.backend_constants import InjurySeverity from tests.factories import InvolvedFactory, UrbanAccidentMarkerFactory, \ SuburbanAccidentMarkerFactory, RoadSegmentFactory @@ -12,8 +12,9 @@ @pytest.fixture() def db_session(): - yield db.session - db.session.rollback() + with app.app_context(): + yield db.session + db.session.rollback() @pytest.mark.skip(reason="requires empty db")