diff --git a/anyway/llm.py b/anyway/llm.py
new file mode 100644
index 00000000..f0209ec4
--- /dev/null
+++ b/anyway/llm.py
@@ -0,0 +1,83 @@
+from openai import OpenAI
+import json
+import tiktoken
+
+from langchain.output_parsers import EnumOutputParser
+from langchain_core.prompts import PromptTemplate
+from langchain_openai import ChatOpenAI
+import langchain
+from enum import Enum
+from anyway import secrets
+
+api_key = secrets.get("OPENAI_API_KEY")
+client = OpenAI(api_key=api_key)
+
+langchain.debug = True
+model = ChatOpenAI(api_key=api_key, temperature=0)
+
+
+def match_streets_with_langchain(street_names, location):
+    street_names.append("-")
+    Streets = Enum('Streets', {name: name for name in street_names})
+
+    parser = EnumOutputParser(enum=Streets)
+    print(parser.get_format_instructions())
+    prompt = PromptTemplate(
+        template="Return the street that is mentioned in the location string. If none matches, return '-'.\nstreets: {streets}\n"
+        + "location_string:{location}\n{format_instructions}\n",
+        input_variables=["streets", "location"],
+        partial_variables={"format_instructions": parser.get_format_instructions()},
+    )
+
+    chain = prompt | model | parser
+
+    res = chain.invoke({"streets": street_names, "location": location})
+    return res
+
+
+def count_tokens_for_prompt(messages, model):
+    tokenizer = tiktoken.encoding_for_model(model)
+    total_tokens = 0
+    for message in messages:
+        # Each message has a role and content
+        message_tokens = tokenizer.encode(f"{message['role']}: {message['content']}")
+        total_tokens += len(message_tokens)
+        # Additional tokens for formatting
+        total_tokens += 4  # approx overhead for each message (role + delimiters)
+
+    return total_tokens
+
+
+def count_tokens(text, model):
+    tokenizer = tiktoken.encoding_for_model(model)
+    tokens = tokenizer.encode(text)
+    return len(tokens)
+
+
+def ask_gpt(system_message, user_message, model="gpt-4o"):
+    messages = [
+        {"role": "system", "content": system_message},
+        {"role": "user", "content": user_message}
+    ]
+    completion = client.chat.completions.create(
+        response_format={"type": "json_object"},
+        model=model,
+        messages=messages
+    )
+    print(f"tokens for prompt: {count_tokens_for_prompt(messages, model)}")
+    return completion.choices[0].message
+
+
+def ask_ai_about_street_matching(streets, location_string, model="gpt-4-turbo"):
+    system_message = """
+    Given a list of streets, return the name of the street that is mentioned in the provided location string.
+    Return the name exactly as it appears in the list.
+    If no match is found, return "-".
+    Return JSON with the field "street" containing your answer.
+    Select one of the following options:
+    """ + json.dumps(streets + ["-"])
+    input = json.dumps({"streets": streets, "location": location_string})
+    reply = ask_gpt(system_message, input, model)
+    # print(f"tokens: {count_tokens(reply.content, model)}")
+    result = json.loads(reply.content)["street"]
+    return result, result in streets
diff --git a/anyway/parsers/location_extraction.py b/anyway/parsers/location_extraction.py
index a43bdf62..bb1d7cd3 100644
--- a/anyway/parsers/location_extraction.py
+++ b/anyway/parsers/location_extraction.py
@@ -10,6 +10,7 @@
 from anyway.parsers.resolution_fields import ResolutionFields as RF
 from anyway import secrets
 from anyway.models import AccidentMarkerView, RoadSegments
+from anyway.llm import ask_ai_about_street_matching
 from sqlalchemy import not_
 import pandas as pd
 from sqlalchemy.orm import load_only
@@ -176,19 +177,7 @@ def get_bounding_box(latitude, longitude, distance_in_km):
     return final_loc
 
 
-def get_db_matching_location(db, latitude, longitude, resolution, road_no=None):
-    """
-    extracts location from db by closest geo point to location found, using road number if provided and limits to
-    requested resolution
-    :param db: the DB
-    :param latitude: location latitude
-    :param longitude: location longitude
-    :param resolution: wanted resolution
-    :param road_no: road number if there is
-    :return: a dict containing all the geo fields stated in
-    resolution dict, with values filled according to resolution
-    """
-    # READ MARKERS FROM DB
+def read_markers_and_distance_from_location(db, latitude, longitude, resolution, road_no=None):
     geod = Geodesic.WGS84
     relevant_fields = RF.get_possible_fields(resolution)
     markers = db.get_markers_for_location_extraction()
@@ -222,6 +211,24 @@ def get_db_matching_location(db, latitude, longitude, resolution, road_no=None):
     markers["dist_point"] = markers.apply(
         lambda x: geod.Inverse(latitude, longitude, x["latitude"], x["longitude"])["s12"], axis=1
     ).replace({np.nan: None})
+    return markers
+
+
+def get_db_matching_location(db, latitude, longitude, resolution, road_no=None):
+    """
+    extracts location from db by closest geo point to location found, using road number if provided and limits to
+    requested resolution
+    :param db: the DB
+    :param latitude: location latitude
+    :param longitude: location longitude
+    :param resolution: wanted resolution
+    :param road_no: road number if there is
+    :return: a dict containing all the geo fields stated in
+    resolution dict, with values filled according to resolution
+    """
+    # READ MARKERS FROM DB
+    relevant_fields = RF.get_possible_fields(resolution)
+    markers = read_markers_and_distance_from_location(db, latitude, longitude, resolution, road_no)
 
     most_fit_loc = (
         markers.loc[markers["dist_point"] == markers["dist_point"].min()].iloc[0].to_dict()
@@ -240,6 +247,24 @@ def get_db_matching_location(db, latitude, longitude, resolution, road_no=None):
     return final_loc
 
 
+def read_n_closest_streets(db, n, latitude, longitude, road_no=None):
+    markers = read_markers_and_distance_from_location(
+        db, latitude, longitude, BE_CONST.ResolutionCategories.STREET, road_no
+    )
+    # Sort by distance
+    sorted_markers = markers.sort_values(by="dist_point")
+
+    # Drop duplicates to ensure unique street1_hebrew values
+    unique_street_markers = sorted_markers.drop_duplicates(subset="street1_hebrew")
+
+    # Select the top n entries
+    top_n_unique_streets = unique_street_markers.head(n)
+
+    # Convert to dictionary if needed
+    result_dicts = top_n_unique_streets.to_dict(orient="records")
+    return [result["street1_hebrew"] for result in result_dicts]
[result["street1_hebrew"] for result in result_dicts] + + def set_accident_resolution(accident_row): """ set the resolution of the accident @@ -282,11 +307,12 @@ def reverse_geocode_extract(latitude, longitude): try: gmaps = googlemaps.Client(key=secrets.get("GOOGLE_MAPS_KEY")) geocode_result = gmaps.reverse_geocode((latitude, longitude)) - + print(geocode_result) # if we got no results, move to next iteration of location string if not geocode_result: return None except Exception as _: + logging.info(_) logging.info("exception in gmaps") return None # logging.info(geocode_result) @@ -539,6 +565,42 @@ def extract_geo_features(db, newsflash: NewsFlash, use_existing_coordinates_only if location_from_db is not None: update_location_fields(newsflash, location_from_db) try_find_segment_id(newsflash) + logging.debug(newsflash.resolution) + if newsflash.resolution == BE_CONST.ResolutionCategories.STREET: + try_improve_street_identification(newsflash) + + +def try_improve_street_identification(newsflash): + from anyway.parsers import news_flash_db_adapter + + db = news_flash_db_adapter.init_db() + all_closest_streets = read_n_closest_streets(db, 20, newsflash.lat, newsflash.lon) + + num_of_streets_for_first_try = 5 + streets_for_first_try = all_closest_streets[:num_of_streets_for_first_try] + streets_for_second_try = all_closest_streets[num_of_streets_for_first_try:] + + result, result_in_input = ask_ai_about_street_matching( + streets_for_first_try, newsflash.location + ) + logging.debug(f"result of 1st try {result}") + if not result_in_input: + logging.debug(f"street matching failed first try for newsflash {newsflash.id}") + result, result_in_input = ask_ai_about_street_matching( + streets_for_second_try, newsflash.location + ) + logging.debug(f"result of 2nd try {result}") + if result_in_input: + if result == newsflash.street1_hebrew: + logging.debug("street matching succeeded, street not changed") + else: + logging.debug( + f"street matching succeeded, street updated for {newsflash.id} " + f"from {newsflash.street1_hebrew} to {result}" + ) + newsflash.street1_hebrew = result + else: + logging.debug(f"street matching failed second try for newsflash {newsflash.id}") def update_location_fields(newsflash, location_from_db): diff --git a/main.py b/main.py index 1db1d813..668cf637 100755 --- a/main.py +++ b/main.py @@ -333,6 +333,17 @@ def infographics_pictures(id): raise Exception("generation failed") +@process.command() +@click.option("--id", type=int) +def street_name(id): + from anyway.parsers import news_flash_db_adapter + from anyway.parsers.location_extraction import try_improve_street_identification + + db = news_flash_db_adapter.init_db() + newsflash = db.get_newsflash_by_id(id).first() + try_improve_street_identification(newsflash) + + @process.group() def cache(): pass diff --git a/requirements.txt b/requirements.txt index 2b63ecdc..a27fb42e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,7 +6,7 @@ Flask-Login==0.5.0 Flask-SQLAlchemy==2.4.1 flask-restx==0.5.1 Jinja2==3.1.4 -SQLAlchemy==1.3.17 +SQLAlchemy==1.4 Werkzeug==2.0.3 alembic==1.4.2 attrs==23.1.0 @@ -53,3 +53,7 @@ swifter==1.3.4 telebot==0.0.5 selenium==4.11.2 apache-airflow-client==2.6.2 +openai==1.45.0 +langchain==0.2.16 +langchain_openai==0.1.25 +python-dotenv \ No newline at end of file