Skip to content

Commit

Permalink
Added support for semantic searches based on strings
Browse files Browse the repository at this point in the history
- Fixed an issue when running the MoreLikeThisQuery
- Added a method to query the semantic_embedding API with a string.
- run_knn_topic_inferance -> run_knn_similarity_search, to better reflect the purpose of the script.
  • Loading branch information
tfnribeiro committed Oct 31, 2024
1 parent 6a9f414 commit 8a1c7fc
Show file tree
Hide file tree
Showing 6 changed files with 115 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,13 @@
articles_like_this_semantic,
add_topics_based_on_semantic_hood_search,
articles_like_this_tfidf,
find_articles_based_on_text,
)

from zeeguu.core.model.article import Article
from zeeguu.core.model.language import Language

from zeeguu.core.elastic.settings import ES_CONN_STRING, ES_ZINDEX
from zeeguu.core.elastic.settings import ES_CONN_STRING
from elasticsearch import Elasticsearch
from collections import Counter
from zeeguu.core.elastic.elastic_query_builder import build_elastic_recommender_query

from zeeguu.api.app import create_app
import argparse
Expand All @@ -24,7 +22,8 @@
parser = argparse.ArgumentParser(
description="Utilizes the various similar document queries in ES, to analyze the results."
)
parser.add_argument("article_id", type=int, help="article id to search with")
parser.add_argument("-a", "--article_id", type=int, help="article id to search with")
parser.add_argument("-k", "--keyword", type=str, help="keyword to search with")


def search_similar_to_article(article_id):
Expand Down Expand Up @@ -82,10 +81,10 @@ def search_similar_to_article(article_id):
neighbouring_topics = [t.new_topic for a in a_found_t for t in a.new_topics]
TOPICS_TO_NOT_COUNT = set(["news", "aktuell", "nyheder", "nieuws", "article"])
neighbouring_keywords = [
t.url_keywords
t.url_keyword
for a in a_found_t
for t in a.url_keywords
if t.url_keywords.keyword not in TOPICS_TO_NOT_COUNT
if t.url_keyword.keyword not in TOPICS_TO_NOT_COUNT
]

print()
Expand All @@ -105,7 +104,37 @@ def search_similar_to_article(article_id):
print(a_found[0].content[:100])


def search_similar_to_keyword(keyword):
app = create_app()
app.app_context().push()

es = Elasticsearch(ES_CONN_STRING)

a_found, hits = find_articles_based_on_text(keyword)
print("------------------------------------------------")

print("Keyword Searched: ", keyword)
print()
print("Similar articles:")
for hit in hits:
print(
hit["_id"],
hit["_source"]["old_topics"],
hit["_source"]["language"],
f"New Topics: {hit['_source']['topics']}",
hit["_source"].get("url_keywords", []),
hit["_source"].get("url", ""),
hit["_score"],
)
print("Article list: ")
print(a_found)


if __name__ == "__main__":
args = parser.parse_args()
article_id = args.article_id
search_similar_to_article(article_id)
keyword = args.keyword
if article_id:
search_similar_to_article(article_id)
if keyword:
search_similar_to_keyword(keyword)
33 changes: 32 additions & 1 deletion zeeguu/core/elastic/elastic_query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def more_like_this_query(count, article_text, language, page=0):
.filter("term", language=language.name.lower())
)

return {"from": page * count, "size": count, "query": s.to_dict()}
return {"from": page * count, "size": count, "query": s.to_dict()["query"]}


def build_elastic_recommender_query(
Expand Down Expand Up @@ -326,6 +326,37 @@ def build_elastic_semantic_sim_query(
return query


def build_elastic_semantic_sim_query_for_text(
count,
text_embedding,
n_candidates=100,
language=None,
):
"""
Similar to build_elastic_semantic_sim_query, but taking a text embedding
"""
s = Search()
# s = s.exclude("match", id=article.id)
if language:
s = s.knn(
field="sem_vec",
k=count,
num_candidates=n_candidates,
query_vector=text_embedding,
filter=(Q("match", language__keyword=language.name)),
)
else:
s = s.knn(
field="sem_vec",
k=count,
num_candidates=n_candidates,
query_vector=text_embedding,
)

query = s.to_dict()
return query


def build_elastic_semantic_sim_query_for_topic_cls(
k_count,
article,
Expand Down
1 change: 1 addition & 0 deletions zeeguu/core/semantic_search/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@
articles_like_this_semantic,
add_topics_based_on_semantic_hood_search,
articles_like_this_tfidf,
find_articles_based_on_text,
)
30 changes: 29 additions & 1 deletion zeeguu/core/semantic_search/elastic_semantic_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,15 @@
from zeeguu.core.elastic.elastic_query_builder import (
build_elastic_semantic_sim_query,
build_elastic_semantic_sim_query_for_topic_cls,
build_elastic_semantic_sim_query_for_text,
more_like_this_query,
)
from zeeguu.core.util.timer_logging_decorator import time_this
from zeeguu.core.elastic.settings import ES_CONN_STRING, ES_ZINDEX
from zeeguu.core.semantic_vector_api import get_embedding_from_article
from zeeguu.core.semantic_vector_api import (
get_embedding_from_article,
get_embedding_from_text,
)


@time_this
Expand Down Expand Up @@ -86,6 +90,30 @@ def add_topics_based_on_semantic_hood_search(
return [], []


@time_this
def find_articles_based_on_text(text, k: int = 9): # hood = (slang) neighborhood
query_body = build_elastic_semantic_sim_query_for_text(
k, get_embedding_from_text(text)
)
final_article_mix = []

try:
es = Elasticsearch(ES_CONN_STRING)
res = es.search(index=ES_ZINDEX, body=query_body)

hit_list = res["hits"].get("hits")
final_article_mix.extend(_to_articles_from_ES_hits(hit_list))

return [
a for a in final_article_mix if a is not None and not a.broken
], hit_list
except ConnectionError:
print("Could not connect to ES server.")
except Exception as e:
print(f"Error encountered: {e}")
return [], []


def _to_articles_from_ES_hits(hits):
articles = []
for hit in hits:
Expand Down
6 changes: 5 additions & 1 deletion zeeguu/core/semantic_vector_api/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,5 @@
from .retrieve_embeddings import get_embedding_from_article, EMB_API_CONN_STRING
from .retrieve_embeddings import (
get_embedding_from_article,
get_embedding_from_text,
EMB_API_CONN_STRING,
)
11 changes: 11 additions & 0 deletions zeeguu/core/semantic_vector_api/retrieve_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"ZEEGUU_EMB_API_CONN_STRING", "http://127.0.0.1:8000"
)


def get_embedding_from_article(a: Article):
r = requests.post(
url=f"{EMB_API_CONN_STRING}/get_article_embedding",
Expand All @@ -15,3 +16,13 @@ def get_embedding_from_article(a: Article):
},
)
return r.json()


def get_embedding_from_text(text: str, language: str = None):
data = {
"article_content": text,
}
if language:
data["article_language"] = language
r = requests.post(url=f"{EMB_API_CONN_STRING}/get_article_embedding", json=data)
return r.json()

0 comments on commit 8a1c7fc

Please sign in to comment.