From 93381f399e0a453bc3e18ce69c98dfbf76c2e605 Mon Sep 17 00:00:00 2001 From: "houhan@gmail.com" Date: Sat, 31 Aug 2024 07:00:39 +0000 Subject: [PATCH] fix: improve legend and fix bug --- code/pages/4_RL model playground.py | 32 ++++++++++++++++++++++------- 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/code/pages/4_RL model playground.py b/code/pages/4_RL model playground.py index d62dd0f..6adf39c 100644 --- a/code/pages/4_RL model playground.py +++ b/code/pages/4_RL model playground.py @@ -171,14 +171,20 @@ def select_task(task_family, reward_baiting, n_trials, seed): ], index=0, ) + max_block_tally = st.slider("max block tally", 1, 10, 4) + persev_add = st.checkbox("anti-perseveration", value=True) + perseverative_limit = 3 + if persev_add: + perseverative_limit = st.slider("perseverative limit", 1, 10, 3) return UncoupledBlockTask( rwd_prob_array=rwd_prob_array, block_min=block_min, block_max=block_max, - persev_add=True, - perseverative_limit=4, - max_block_tally=4, + persev_add=persev_add, + perseverative_limit=perseverative_limit, + max_block_tally=max_block_tally, num_trials=n_trials, + reward_baiting=reward_baiting, seed=seed, ) @@ -191,6 +197,7 @@ def select_task(task_family, reward_baiting, n_trials, seed): sigma=[sigma, sigma], mean=[0, 0], num_trials=n_trials, + reward_baiting=reward_baiting, seed=seed, ) @@ -204,12 +211,13 @@ def app(): col0 = st.columns([1, 1]) with col0[0]: - with st.expander("Agent", expanded=True): + st.markdown("#### Select agent ([🤖 aind-dynamic-foraging-models](https://github.com/AllenNeuralDynamics/aind-dynamic-foraging-models/blob/develop/src/aind_dynamic_foraging_models/generative_model/foragers.py))") + with st.expander("", expanded=True): col1 = st.columns([1, 2]) with col1[0]: # -- Select forager family -- agent_family = st.selectbox( - "Select agent family", + "Agent type", list(model_families.keys()) ) # -- Select forager -- @@ -224,12 +232,13 @@ def app(): # all_presets = forager_collection.FORAGER_PRESETS.keys() with col0[1]: - with st.expander("Task", expanded=True): + st.markdown("#### Select dynamic foraging task ([🏋️aind-behavior-gym](https://github.com/AllenNeuralDynamics/aind-behavior-gym/tree/develop/src/aind_behavior_gym/dynamic_foraging/task))") + with st.expander("", expanded=True): col1 = st.columns([1, 2]) with col1[0]: # -- Select task family -- task_family = st.selectbox( - "Select task family", + "Task type", list(task_families.keys()), index=0, ) @@ -256,5 +265,14 @@ def app(): fig, axes = forager.plot_session(if_plot_latent=if_plot_latent) with st.columns([1, 0.5])[0]: st.pyplot(fig) + + # Plot block logic + if task_family == "Uncoupled block task": + if_show_block_logic = st.checkbox("Show uncoupled block logic", value=False) + if if_show_block_logic: + fig, ax = task.plot_reward_schedule() + ax[0].legend() + fig.suptitle("Reward schedule") + st.pyplot(fig) app() \ No newline at end of file