diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..10c72da --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,12 @@ +{ + "[python]": { + "editor.formatOnSave": true, + "editor.codeActionsOnSave": { + "source.fixAll": "explicit" + }, + "editor.defaultFormatter": "ms-python.black-formatter" + }, + "python.testing.pytestArgs": ["tests"], + "python.testing.unittestEnabled": false, + "python.testing.pytestEnabled": true +} diff --git a/poetry.lock b/poetry.lock index e88bbc7..282d66e 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1703,8 +1703,8 @@ files = [ [package.dependencies] numpy = [ {version = ">=1.22.4", markers = "python_version < \"3.11\""}, - {version = ">=1.23.2", markers = "python_version == \"3.11\""}, {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, + {version = ">=1.23.2", markers = "python_version == \"3.11\""}, ] python-dateutil = ">=2.8.2" pytz = ">=2020.1" @@ -2185,8 +2185,8 @@ astroid = ">=3.2.4,<=3.3.0-dev0" colorama = {version = ">=0.4.5", markers = "sys_platform == \"win32\""} dill = [ {version = ">=0.2", markers = "python_version < \"3.11\""}, - {version = ">=0.3.6", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, {version = ">=0.3.7", markers = "python_version >= \"3.12\""}, + {version = ">=0.3.6", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, ] isort = ">=4.2.5,<5.13.0 || >5.13.0,<6" mccabe = ">=0.6,<0.8" @@ -2808,6 +2808,20 @@ watchdog = {version = ">=2.1.5,<5", markers = "platform_system != \"Darwin\""} [package.extras] snowflake = ["snowflake-connector-python (>=2.8.0)", "snowflake-snowpark-python (>=0.9.0)"] +[[package]] +name = "streamlit-pills" +version = "0.3.0" +description = "💊 A Streamlit component to show clickable pills/badges" +optional = false +python-versions = ">=3.8,<4.0" +files = [ + {file = "streamlit-pills-0.3.0.tar.gz", hash = "sha256:47668ad4fd8c137b203ee1aec9d9d44ed8b2ff7ded9f586984f204be2eac772f"}, + {file = "streamlit_pills-0.3.0-py3-none-any.whl", hash = "sha256:b66fdf7b1820c09b751a76ef1ae01ab93221d0c2c2d1cd489b711b9afaae0513"}, +] + +[package.dependencies] +streamlit = ">=1.12.0,<2.0.0" + [[package]] name = "tenacity" version = "8.5.0" @@ -3331,4 +3345,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "6d8b3ebf94a8d14be3a6f59799596207e85a5c100ee2a01672a00cdd6ef144e3" +content-hash = "e0c7497eb8245d019a6d71b936614c96e0b458e9b52054d1b06c54affab2bc0c" diff --git a/pyproject.toml b/pyproject.toml index 67edeab..2214cfc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,7 @@ llama-index-llms-openai = "^0.1.27" llama-index-agent-openai = "^0.2.9" streamlit = "^1.36.0" llama-agents = {version = "^0.0.12", extras = ["kafka"]} +streamlit-pills = "^0.3.0" [tool.poetry.group.dev.dependencies] mypy = "^1.10.1" diff --git a/snowflake_cybersyn_demo/apps/streamlit.py b/snowflake_cybersyn_demo/apps/streamlit.py index 61be4ee..3e5bc56 100644 --- a/snowflake_cybersyn_demo/apps/streamlit.py +++ b/snowflake_cybersyn_demo/apps/streamlit.py @@ -3,6 +3,7 @@ import streamlit as st from llama_index.core.llms import ChatMessage, ChatResponseGen from llama_index.llms.openai import OpenAI +from streamlit_pills import pills def _llama_index_stream_wrapper( @@ -12,14 +13,30 @@ def _llama_index_stream_wrapper( yield chunk.delta +def _handle_task_submission() -> None: + st.session_state.submitted_pills.append(st.session_state.task_input) + + llm = OpenAI(model="gpt-4o-mini") st.set_page_config(layout="wide") st.title("Human In The Loop W/ LlamaAgents") +# state management +if "submitted_pills" not in st.session_state: + st.session_state["submitted_pills"] = [] +st.session_state["human_required_pills"] = [] +st.session_state["completed_pills"] = [] + + left, middle, right = st.columns([1, 2, 1], vertical_alignment="bottom") with left: - task_input = st.text_input("Task input", placeholder="Enter a task input.") + task_input = st.text_input( + "Task input", + placeholder="Enter a task input.", + key="task_input", + on_change=_handle_task_submission, + ) with middle: if "messages" not in st.session_state: @@ -45,3 +62,26 @@ def _llama_index_stream_wrapper( st.session_state.messages.append( {"role": "assistant", "content": response} ) + +with right: + if st.session_state.submitted_pills: + submitted_pills = st.session_state.submitted_pills + submitted = pills( + "Submitted", + options=submitted_pills, + key="selected_submitted", + ) + + if st.session_state.human_required_pills: + human_required = pills( + "Human Required", + st.session_state.human_required_pills, + key="selected_human_required", + ) + + if st.session_state.completed_pills: + completed = pills( + "Completed", + st.session_state.completed_pills, + key="selected_completed", + )