diff --git a/code/pages/4_RL model playground.py b/code/pages/4_RL model playground.py index 6adf39c..b5bb8a2 100644 --- a/code/pages/4_RL model playground.py +++ b/code/pages/4_RL model playground.py @@ -12,6 +12,7 @@ from aind_dynamic_foraging_models import generative_model from aind_dynamic_foraging_models.generative_model import ForagerCollection from aind_dynamic_foraging_models.generative_model.params import ParamsSymbols +from aind_dynamic_foraging_basic_analysis import compute_foraging_efficiency try: st.set_page_config(layout="wide", @@ -251,7 +252,14 @@ def app(): # -- Run the model -- forager.perform(task) - if_plot_latent = st.checkbox("Plot latent variables", value=False) + # Evaluate the foraging efficiency + foraging_eff, foraging_eff_random_seed = compute_foraging_efficiency( + baited=task.reward_baiting, + choice_history=forager.get_choice_history(), + reward_history=forager.get_reward_history(), + p_reward=forager.get_p_reward(), + random_number=task.random_numbers.T, + ) # Capture the results # ground_truth_params = forager.params.model_dump() @@ -262,8 +270,11 @@ def app(): # reward_history = forager.get_reward_history() # Plot the session results + if_plot_latent = st.checkbox("Plot latent variables", value=False) fig, axes = forager.plot_session(if_plot_latent=if_plot_latent) - with st.columns([1, 0.5])[0]: + + col0 = st.columns([1, 0.5]) + with col0[0]: st.pyplot(fig) # Plot block logic @@ -274,5 +285,9 @@ def app(): ax[0].legend() fig.suptitle("Reward schedule") st.pyplot(fig) + + with col0[1]: + st.write(f"#### **Foraging efficiency**:") + st.write(f"# {foraging_eff_random_seed:.3f}") app() \ No newline at end of file