Skip to content

Commit

Permalink
Merge pull request #106 from AllenNeuralDynamics/han_#105_optional_lo…
Browse files Browse the repository at this point in the history
…ad_docDB

Han #105 performance improvements
  • Loading branch information
hanhou authored Dec 24, 2024
2 parents 9ea3e81 + c8ae95b commit 42026c8
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 49 deletions.
49 changes: 32 additions & 17 deletions code/Home.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ def show_curriculums():
pass

# ------- Layout starts here -------- #
def init(if_load_docDB=True):
def init(if_load_docDB_override=None):

# Clear specific session state and all filters
for key in st.session_state:
Expand Down Expand Up @@ -454,7 +454,14 @@ def _get_data_source(rig):


# --- Load data from docDB ---
if if_load_docDB:
if_load_docDb = if_load_docDB_override if if_load_docDB_override is not None else (
st.query_params['if_load_docDB'].lower() == 'true'
if 'if_load_docDB' in st.query_params
else st.session_state.if_load_docDB
if 'if_load_docDB' in st.session_state
else False)

if if_load_docDb:
_df = merge_in_df_docDB(_df)

# add docDB_status column
Expand Down Expand Up @@ -525,25 +532,33 @@ def app():
# with col1:
# -- 1. unit dataframe --

cols = st.columns([2, 2, 4, 1])
cols = st.columns([2, 4, 1])
cols[0].markdown(f'### Filter the sessions on the sidebar\n'
f'##### {len(st.session_state.df_session_filtered)} sessions, '
f'{len(st.session_state.df_session_filtered.h2o.unique())} mice filtered')

if_load_bpod_sessions = checkbox_wrapper_for_url_query(
st_prefix=cols[1],
label='Include old Bpod sessions (reload after change)',
key='if_load_bpod_sessions',
default=False,
)

with cols[1]:
if st.button(' Reload data ', type='primary'):
st.cache_data.clear()
init()
st.rerun()
with cols[1]:
with st.form(key='load_settings', clear_on_submit=False):
if_load_bpod_sessions = checkbox_wrapper_for_url_query(
st_prefix=st,
label='Include old Bpod sessions (reload after change)',
key='if_load_bpod_sessions',
default=False,
)
if_load_docDB = checkbox_wrapper_for_url_query(
st_prefix=st,
label='Load metadata from docDB (reload after change)',
key='if_load_docDB',
default=False,
)

submitted = st.form_submit_button("Reload data! 🔄", type='primary')
if submitted:
st.cache_data.clear()
sync_session_state_to_URL()
init()
st.rerun() # Reload the page to apply the changes

table_height = slider_wrapper_for_url_query(st_prefix=cols[3],
table_height = slider_wrapper_for_url_query(st_prefix=cols[2],
label='Table height',
min_value=0,
max_value=2000,
Expand Down
2 changes: 1 addition & 1 deletion code/pages/1_Basic behavior analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,6 @@ def _plot_histograms(df, column, bins, use_kernel_smooth, use_density):
return fig

if "df" not in st.session_state or "sessions_bonsai" not in st.session_state.df.keys():
init(if_load_docDB=False)
init(if_load_docDB_override=False)

app()
61 changes: 32 additions & 29 deletions code/util/plot_autotrain_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,29 +8,33 @@
from aind_auto_train.schema.curriculum import TrainingStage


def plot_manager_all_progress(manager: 'AutoTrainManager',
x_axis: ['session', 'date',
'relative_date'] = 'session', # type: ignore
sort_by: ['subject_id', 'first_date',
'last_date', 'progress_to_graduated'] = 'subject_id',
sort_order: ['ascending',
'descending'] = 'descending',
recent_days: int=None,
marker_size=10,
marker_edge_width=2,
highlight_subjects=[],
if_show_fig=True
):


@st.cache_data(ttl=3600 * 24)
def plot_manager_all_progress(
x_axis: ["session", "date", "relative_date"] = "session", # type: ignore
sort_by: [
"subject_id",
"first_date",
"last_date",
"progress_to_graduated",
] = "subject_id",
sort_order: ["ascending", "descending"] = "descending",
recent_days: int = None,
marker_size=10,
marker_edge_width=2,
highlight_subjects=[],
if_show_fig=True,
):

manager = st.session_state.auto_train_manager

# %%
# Set default order
df_manager = manager.df_manager.sort_values(by=['subject_id', 'session'],
ascending=[sort_order == 'ascending', False])

if not len(df_manager):
return None

# Get some additional metadata from the master table
df_tmp_rig_user_name = st.session_state.df['sessions_bonsai'].loc[:, ['subject_id', 'session_date', 'rig', 'user_name']]
df_tmp_rig_user_name.session_date = df_tmp_rig_user_name.session_date.astype(str)
Expand All @@ -51,18 +55,18 @@ def plot_manager_all_progress(manager: 'AutoTrainManager',
elif sort_by == 'progress_to_graduated':
manager.compute_stats()
df_stats = manager.df_manager_stats

# Sort by 'first_entry' of GRADUATED
subject_ids = df_stats.reset_index().set_index(
'subject_id'
).query(
f'current_stage_actual == "GRADUATED"'
)['first_entry'].sort_values(
ascending=sort_order != 'ascending').index.to_list()

# Append subjects that have not graduated
subject_ids = subject_ids + [s for s in df_manager.subject_id.unique() if s not in subject_ids]

else:
raise ValueError(
f'sort_by must be in {["subject_id", "first_date", "last_date", "progress"]}')
Expand All @@ -71,17 +75,17 @@ def plot_manager_all_progress(manager: 'AutoTrainManager',
traces = []
for n, subject_id in enumerate(subject_ids):
df_subject = df_manager[df_manager['subject_id'] == subject_id]

# Get stage_color_mapper
stage_color_mapper = get_stage_color_mapper(stage_list=list(TrainingStage.__members__))

# Get h2o if available
if 'h2o' in manager.df_behavior:
h2o = manager.df_behavior[
manager.df_behavior['subject_id'] == subject_id]['h2o'].iloc[0]
else:
h2o = None

df_subject = df_subject.merge(
df_tmp_rig_user_name,
on=['subject_id', 'session_date'], how='left')
Expand All @@ -105,11 +109,11 @@ def plot_manager_all_progress(manager: 'AutoTrainManager',
else:
raise ValueError(
f"x_axis can only be in ['session', 'date', 'relative_date']")

# Cache x range
xrange_min = x.min() if n == 0 else min(x.min(), xrange_min)
xrange_max = x.max() if n == 0 else max(x.max(), xrange_max)

y = len(subject_ids) - n # Y axis

traces.append(go.Scattergl(
Expand Down Expand Up @@ -159,7 +163,7 @@ def plot_manager_all_progress(manager: 'AutoTrainManager',
showlegend=False
)
)

# Add "x" for open loop sessions
traces.append(go.Scattergl(
x=x[open_loop_ids],
Expand Down Expand Up @@ -197,14 +201,14 @@ def plot_manager_all_progress(manager: 'AutoTrainManager',
),
yaxis_range=[-0.5, len(subject_ids) + 1],
)

# Limit x range to recent days if x is "date"
if x_axis == 'date' and recent_days is not None:
# xrange_max = pd.Timestamp.today() # For unknown reasons, using this line will break both plotly_events and new st.plotly_chart callback...
xrange_max = pd.to_datetime(df_manager.session_date).max() + pd.Timedelta(days=1)
xrange_min = xrange_max - pd.Timedelta(days=recent_days)
fig.update_layout(xaxis_range=[xrange_min, xrange_max])

# Highight the selected subject
for n, subject_id in enumerate(subject_ids):
y = len(subject_ids) - n # Y axis
Expand All @@ -222,7 +226,6 @@ def plot_manager_all_progress(manager: 'AutoTrainManager',
opacity=0.3,
layer="below"
)


# Show the plot
if if_show_fig:
Expand Down
3 changes: 1 addition & 2 deletions code/util/streamlit.py
Original file line number Diff line number Diff line change
Expand Up @@ -791,7 +791,7 @@ def add_auto_train_manager():
recent_weeks = slider_wrapper_for_url_query(cols[5],
label="only recent weeks",
min_value=1,
max_value=26,
max_value=52,
step=1,
key='auto_training_history_recent_weeks',
default=8,
Expand All @@ -808,7 +808,6 @@ def add_auto_train_manager():
highlight_subjects = []

fig_auto_train = plot_manager_all_progress(
st.session_state.auto_train_manager,
x_axis=x_axis,
recent_days=recent_weeks*7,
sort_by=sort_by,
Expand Down
1 change: 1 addition & 0 deletions code/util/url_query_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# Note: When creating the widget, add argument "value"/"index" as well as "key" for all widgets you want to sync with URL
to_sync_with_url_query_default = {
"if_load_bpod_sessions": False,
"if_load_docDB": False,
"to_filter_columns": [
"subject_id",
"task",
Expand Down

0 comments on commit 42026c8

Please sign in to comment.