-
Notifications
You must be signed in to change notification settings - Fork 0
/
app.py
95 lines (77 loc) · 3.07 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
"""Simple streamlit app to explore talk python to me podcasts. """
import streamlit as st
import components.sidebar as sb
import utils as ut
from embeddings import (
get_chroma,
get_embedding_function,
query_db,
_match_aggregating_function,
)
# URL of the Talk Python To Me logo image.
TALK_PYTHON_LOGO = r"https://cdn-podcast.talkpython.fm/static/img/talk_python_logo_mic.png?cache_id=dd08157a0f56a88381ec34afe167db21"

# Streamlit Cloud has to download the chroma data itself, so check for it
# here on start-up. When the app is started from Docker the data is fetched
# before deployment, and this check finds the directory already in place.
import talking_python.release as rel

if not rel.get_chroma_dir().exists():
    print("chroma database wasn't found, downloading!")
    # Imported only on this path: the helper is needed solely when the
    # chroma directory is missing.
    from talking_python.utilities import prepare_chroma

    prepare_chroma()
def clear_submit() -> None:
    """Set the ``"submit"`` entry of the Streamlit session state to ``False``.

    Passed as the ``on_change`` callback of the question text area in this
    script.
    """
    st.session_state["submit"] = False
# Page-level settings: tab title, tab icon and a wide layout.
st.set_page_config(
    layout="wide",
    page_icon=":snake:",
    page_title="Explore Talk Python To Me",
)

# Page heading, followed by a short intro that links to background
# material on semantic search.
st.header(":snake: Explore Talk Python To Me")
st.markdown(
    "##### Use [semantic search](https://www.sbert.net/examples/applications/semantic-search/README.html#background)"
    " to find episodes related to your queries."
)

# Render the sidebar provided by the components.sidebar module.
sb.sidebar()
# Free-text question. Editing it invokes clear_submit as the on_change
# callback.
query = st.text_area("Ask a question", on_change=clear_submit)

with st.expander("Advanced Options"):
    # TODO: Explain this field is not the same as the number of episodes?
    # number_input is wrapped in int() so max_episodes is always an int.
    max_episodes = int(
        st.number_input(
            "Set the maximum numbers of episodes to suggest. Defaults to 20",
            min_value=1,
            max_value=60,
            value=20,
        )
    )
    # Name of the aggregating function; it is resolved to a callable with
    # _match_aggregating_function when the query is run.
    aggregating_function_name = st.selectbox(
        "Aggregating_function",
        options=("sum_weighted", "minimum", "raw"),
    )

button = st.button("Submit")
# Run the search when "Submit" was clicked in this run, or when the "submit"
# flag is still set in the session state (it is set to True below after a
# successful query and reset to False by clear_submit).
if button or st.session_state.get("submit"):
    # NOTE(review): "api_key_configured" is not set anywhere in this script;
    # presumably the sidebar sets it -- confirm.
    if not st.session_state.get("api_key_configured"):
        st.error("Please configure your HuggingFace API Token!")
    elif not query or not query.strip():
        # A blank or whitespace-only question has nothing to search for, so
        # reject it here instead of embedding it and querying the database.
        st.error("Please enter a question!")
    else:
        try:
            emb_fn = get_embedding_function()
            chroma = get_chroma(emb_fn)
            with st.spinner("Querying the database... This may take a while⏳"):
                # TODO: The results should be aggregated outside to avoid
                # hitting the model again to embed the query and searching
                # in chroma.
                episode_titles = query_db(
                    chroma,
                    query=query,
                    n_results=max_episodes,
                    aggregating_function=_match_aggregating_function(
                        aggregating_function_name
                    ),
                )
            # Record that a query completed, so later reruns of the script
            # take this branch again without a new click on "Submit".
            st.session_state["submit"] = True

            # Output column to show the results
            podcasts_col = st.columns(1)[0]
            with podcasts_col:
                st.markdown("#### Suggested podcasts")
                with st.spinner(
                    "Extracting episodes table... The first time may take a while⏳"
                ):
                    ut.show_episodes_table(episode_titles=episode_titles)
        except Exception as e:
            # Outermost boundary of the script: show any unexpected error on
            # the page rather than letting the run fail.
            st.error(e)