Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
donlapark authored Jan 1, 2023
1 parent bd4602c commit 7288f8a
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 124 deletions.
83 changes: 16 additions & 67 deletions XLabel.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,19 +93,18 @@ def main():
key="threshold"
)

form_cols = st.columns((2.2, 1, 4))
form_cols[1].form_submit_button("Sample", on_click=sample_and_predict)
form_cols = st.columns((2, 2, 2))
form_cols[1].form_submit_button(
"Sample",
on_click=sample_and_predict
)

if "pages" in _state:
st.sidebar.radio('Labels',
_state.pages,
key="label_page",
index=_state.next_clicked,
on_change=update_counter
)

label = _state.label_page
display_main_screen(label)
page_list = list(_state.pages)
tabs = st.tabs(page_list)
for i, tab in enumerate(tabs):
with tab:
display_main_screen(page_list[i])


def update_file():
Expand Down Expand Up @@ -133,7 +132,6 @@ def init_state_params():
previous_: The index of the previous page.
next_: The index of the next page.
next_clicked: The index of the current page.
counter: A dummy used to go to the top of a new screen.
local_results: A dict of outputs of EBM
used to write predictions and plot heatmaps on screen.
models: A dict of EBM models to predict the labels.
Expand Down Expand Up @@ -178,14 +176,9 @@ def create_pages():
for label in _state.pages}

_state.update({
'counter': 1,
'local_results': {},
'next_clicked': 0,
'local_results': {}
})

_state["next_"] = False
_state["previous_"] = False

_state["predictions"] = pd.DataFrame(index=_state.database.index,
columns=_state.pages)

Expand Down Expand Up @@ -285,7 +278,7 @@ def display_main_screen(label):
if _state.unlabeled_index[label].empty:
main_cols[1].write("All "+label+" data are labeled.")
else:
with st.form("Label form"):
with st.form(label + " label form"):
if _state.local_results[label] == {}:
main_cols[1].write("""There are some unlabeled data left. \n \
This means that the confidences of the remaining data are \
Expand All @@ -300,17 +293,12 @@ def display_main_screen(label):
num_features = len(input_features[label])
num_heatmap_rows = math.ceil(num_features/_NUM_FEAT_PER_ROW)

for page in _state.local_results[label]:
for i, page in enumerate(_state.local_results[label]):
current_plot = plot_all_features(_state.local_results[label][page]['data'],
title=str(page),
height=50,
num_rows=num_heatmap_rows)
cols = st.columns((6, 1))
#with cols[0]:
# if _state.text1 is not None:
# st.write(_state.data[_state.text1][page])
# if _state.text2 is not None:
# st.write(_state.data[_state.text2][page])

cols[0].altair_chart(current_plot, use_container_width=True)

Expand All @@ -331,28 +319,14 @@ def display_main_screen(label):
"Automatically label the remaining data?",
("Yes", "No"),
index=1,
key="auto"
key=label+"_auto"
)

label_from_cols[1].form_submit_button("Submit Labels",
on_click=update_and_save,
args=(label,)
)

button_cols = st.columns((3, 1, 1, 4))
button_cols[1].button("Previous", on_click=update_previous_click)
button_cols[2].button("Next", on_click=update_next_click)

components.html(
f"""
<p>{_state.counter}</p>
<script>
window.parent.document.querySelector('section.main').scrollTo(0, 0);
</script>
""",
height=0
)


@st.experimental_memo
def plot_all_features(data, title, height, num_rows):
Expand Down Expand Up @@ -456,29 +430,6 @@ def report_results(idx, col_name):
return results


def update_previous_click():
"""Track the index of the previous page."""
_state.next_clicked -= 1
if _state.next_clicked == -1:
_state.next_clicked = len(_state.pages)-1
_state.counter += 1


def update_next_click():
"""Track the index of the next page."""
_state.next_clicked += 1
if _state.next_clicked == len(_state.pages):
_state.next_clicked = 0
_state.counter += 1


def update_counter():
"""Update the counter after changing to a new screen."""
if not (_state.next_ or _state.previous_):
_state.next_clicked = _state.pages.get_loc(_state.label_page)
_state.counter += 1


def sample_and_predict():
"""Sample data and make a dict of predictions and explanations.
Expand Down Expand Up @@ -535,7 +486,7 @@ def update_and_save(label):
for ix in new_labeled_index]
compute_unlabeled_index(new_labeled_index, label)

if _state.auto == "Yes":
if _state[label+"_auto"] == "Yes":
unlabeled_idx = _state.unlabeled_index[label]
class_pred = _state.predictions.loc[unlabeled_idx, label]
_state.database.loc[unlabeled_idx, label] = class_pred
Expand Down Expand Up @@ -563,12 +514,10 @@ def update_and_save(label):
pkle.dump(_state.models_params, _file, protocol=pkle.HIGHEST_PROTOCOL)

_state.local_results[label] = {}
if _state.auto == "No":
if _state[label+"_auto"] == "No":
X = X.loc[_state.unlabeled_index[label], :]
generate_explanation(X, label, ebm)

_state.counter += 1


def generate_explanation(X, label, model):
"""Create a dict of predictions and explanations of a sample.
Expand Down
68 changes: 11 additions & 57 deletions XLabelDL.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,18 +82,15 @@ def main():
format="%.2f",
key="threshold")

form_cols = st.columns((2.2, 1, 4))
form_cols = st.columns((2, 2, 2))
form_cols[1].form_submit_button("Sample", on_click=sample_and_predict)

if "pages" in _state:
st.sidebar.radio('Labels',
_state.pages,
key="label_page",
index=_state.next_clicked,
on_change=update_counter)

label = _state.label_page
display_main_screen(label)
page_list = list(_state.pages)
tabs = st.tabs(page_list)
for i, tab in enumerate(tabs):
with tab:
display_main_screen(page_list[i])

filename = _state.configs["db_filename"]
file_pre, file_ext = os.path.splitext(filename)
Expand Down Expand Up @@ -131,7 +128,6 @@ def init_state_params():
previous_: The index of the previous page.
next_: The index of the next page.
next_clicked: The index of the current page.
counter: A dummy used to go to the top of a new screen.
local_results: A dict of outputs of EBM
used to write predictions and plot heatmaps on screen.
models: A dict of EBM models to predict the labels.
Expand Down Expand Up @@ -184,14 +180,9 @@ def create_pages():
}

_state.update({
'counter': 1,
'local_results': {},
'next_clicked': 0,
'local_results': {}
})

_state["next_"] = False
_state["previous_"] = False

_state["predictions"] = pd.DataFrame(index=_state.database.index,
columns=_state.pages)

Expand Down Expand Up @@ -318,7 +309,7 @@ def display_main_screen(label):
if _state.unlabeled_index[label].empty:
main_cols[1].write("All " + label + " data are labeled.")
else:
with st.form("Label form"):
with st.form(label + " label form"):
if _state.local_results[label] == {}:
main_cols[1].write("""There are some unlabeled data left. \n \
This means that the confidences of the remaining data are \
Expand Down Expand Up @@ -365,24 +356,12 @@ def display_main_screen(label):
label_from_cols[1].radio("Automatically label the remaining data?",
("Yes", "No"),
index=1,
key="auto")
key=label+"_auto")

label_from_cols[1].form_submit_button("Submit Labels",
on_click=update_and_save,
args=(label, ))

button_cols = st.columns((3, 1, 1, 4))
button_cols[1].button("Previous", on_click=update_previous_click)
button_cols[2].button("Next", on_click=update_next_click)

components.html(f"""
<p>{_state.counter}</p>
<script>
window.parent.document.querySelector('section.main').scrollTo(0, 0);
</script>
""",
height=0)


@st.experimental_memo
def plot_all_features(data, title, height, num_rows):
Expand Down Expand Up @@ -468,29 +447,6 @@ def report_results(idx, col_name):
return results


def update_previous_click():
"""Track the index of the previous page."""
_state.next_clicked -= 1
if _state.next_clicked == -1:
_state.next_clicked = len(_state.pages) - 1
_state.counter += 1


def update_next_click():
"""Track the index of the next page."""
_state.next_clicked += 1
if _state.next_clicked == len(_state.pages):
_state.next_clicked = 0
_state.counter += 1


def update_counter():
"""Update the counter after changing to a new screen."""
if not (_state.next_ or _state.previous_):
_state.next_clicked = _state.pages.get_loc(_state.label_page)
_state.counter += 1


def sample_and_predict():
"""Sample data and make a dict of predictions and explanations.
Expand Down Expand Up @@ -548,7 +504,7 @@ def update_and_save(label):
]
compute_unlabeled_index(new_labeled_index, label)

if _state.auto == "Yes":
if _state[label + "_auto"] == "Yes":
unlabeled_idx = _state.unlabeled_index[label]
class_pred = _state.predictions.loc[unlabeled_idx, label]
_state.database.loc[unlabeled_idx, label] = class_pred
Expand All @@ -572,12 +528,10 @@ def update_and_save(label):
pkle.dump(_state.models_params, _file, protocol=pkle.HIGHEST_PROTOCOL)

_state.local_results[label] = {}
if _state.auto == "No":
if _state[label + "_auto"] == "No":
X = X.loc[_state.unlabeled_index[label], :]
generate_explanation(X, label, ebm)

_state.counter += 1


def generate_explanation(X, label, model):
"""Create a dict of predictions and explanations of a sample.
Expand Down

0 comments on commit 7288f8a

Please sign in to comment.