Skip to content

Commit

Permalink
fix: improve legend and fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
hanhou committed Aug 31, 2024
1 parent 1590aa6 commit 93381f3
Showing 1 changed file with 25 additions and 7 deletions.
32 changes: 25 additions & 7 deletions code/pages/4_RL model playground.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,14 +171,20 @@ def select_task(task_family, reward_baiting, n_trials, seed):
],
index=0,
)
max_block_tally = st.slider("max block tally", 1, 10, 4)
persev_add = st.checkbox("anti-perseveration", value=True)
perseverative_limit = 3
if persev_add:
perseverative_limit = st.slider("perseverative limit", 1, 10, 3)
return UncoupledBlockTask(
rwd_prob_array=rwd_prob_array,
block_min=block_min,
block_max=block_max,
persev_add=True,
perseverative_limit=4,
max_block_tally=4,
persev_add=persev_add,
perseverative_limit=perseverative_limit,
max_block_tally=max_block_tally,
num_trials=n_trials,
reward_baiting=reward_baiting,
seed=seed,
)

Expand All @@ -191,6 +197,7 @@ def select_task(task_family, reward_baiting, n_trials, seed):
sigma=[sigma, sigma],
mean=[0, 0],
num_trials=n_trials,
reward_baiting=reward_baiting,
seed=seed,
)

Expand All @@ -204,12 +211,13 @@ def app():
col0 = st.columns([1, 1])

with col0[0]:
with st.expander("Agent", expanded=True):
st.markdown("#### Select agent ([🤖 aind-dynamic-foraging-models](https://github.com/AllenNeuralDynamics/aind-dynamic-foraging-models/blob/develop/src/aind_dynamic_foraging_models/generative_model/foragers.py))")
with st.expander("", expanded=True):
col1 = st.columns([1, 2])
with col1[0]:
# -- Select forager family --
agent_family = st.selectbox(
"Select agent family",
"Agent type",
list(model_families.keys())
)
# -- Select forager --
Expand All @@ -224,12 +232,13 @@ def app():
# all_presets = forager_collection.FORAGER_PRESETS.keys()

with col0[1]:
with st.expander("Task", expanded=True):
st.markdown("#### Select dynamic foraging task ([🏋️aind-behavior-gym](https://github.com/AllenNeuralDynamics/aind-behavior-gym/tree/develop/src/aind_behavior_gym/dynamic_foraging/task))")
with st.expander("", expanded=True):
col1 = st.columns([1, 2])
with col1[0]:
# -- Select task family --
task_family = st.selectbox(
"Select task family",
"Task type",
list(task_families.keys()),
index=0,
)
Expand All @@ -256,5 +265,14 @@ def app():
fig, axes = forager.plot_session(if_plot_latent=if_plot_latent)
with st.columns([1, 0.5])[0]:
st.pyplot(fig)

# Plot block logic
if task_family == "Uncoupled block task":
if_show_block_logic = st.checkbox("Show uncoupled block logic", value=False)
if if_show_block_logic:
fig, ax = task.plot_reward_schedule()
ax[0].legend()
fig.suptitle("Reward schedule")
st.pyplot(fig)

app()

0 comments on commit 93381f3

Please sign in to comment.