diff --git a/code/Home.py b/code/Home.py index 14f2433..31b8ecb 100644 --- a/code/Home.py +++ b/code/Home.py @@ -280,7 +280,7 @@ def show_curriculums(): pass # ------- Layout starts here -------- # -def init(if_load_docDB=True): +def init(if_load_docDB_override=None): # Clear specific session state and all filters for key in st.session_state: @@ -454,7 +454,14 @@ def _get_data_source(rig): # --- Load data from docDB --- - if if_load_docDB: + if_load_docDb = if_load_docDB_override if if_load_docDB_override is not None else ( + st.query_params['if_load_docDB'].lower() == 'true' + if 'if_load_docDB' in st.query_params + else st.session_state.if_load_docDB + if 'if_load_docDB' in st.session_state + else False) + + if if_load_docDb: _df = merge_in_df_docDB(_df) # add docDB_status column @@ -525,25 +532,33 @@ def app(): # with col1: # -- 1. unit dataframe -- - cols = st.columns([2, 2, 4, 1]) + cols = st.columns([2, 4, 1]) cols[0].markdown(f'### Filter the sessions on the sidebar\n' f'##### {len(st.session_state.df_session_filtered)} sessions, ' f'{len(st.session_state.df_session_filtered.h2o.unique())} mice filtered') - - if_load_bpod_sessions = checkbox_wrapper_for_url_query( - st_prefix=cols[1], - label='Include old Bpod sessions (reload after change)', - key='if_load_bpod_sessions', - default=False, - ) - - with cols[1]: - if st.button(' Reload data ', type='primary'): - st.cache_data.clear() - init() - st.rerun() + with cols[1]: + with st.form(key='load_settings', clear_on_submit=False): + if_load_bpod_sessions = checkbox_wrapper_for_url_query( + st_prefix=st, + label='Include old Bpod sessions (reload after change)', + key='if_load_bpod_sessions', + default=False, + ) + if_load_docDB = checkbox_wrapper_for_url_query( + st_prefix=st, + label='Load metadata from docDB (reload after change)', + key='if_load_docDB', + default=False, + ) + + submitted = st.form_submit_button("Reload data! 🔄", type='primary') + if submitted: + st.cache_data.clear() + sync_session_state_to_URL() + init() + st.rerun() # Reload the page to apply the changes - table_height = slider_wrapper_for_url_query(st_prefix=cols[3], + table_height = slider_wrapper_for_url_query(st_prefix=cols[2], label='Table height', min_value=0, max_value=2000, diff --git a/code/pages/1_Basic behavior analysis.py b/code/pages/1_Basic behavior analysis.py index 9c3b752..1ecd2c8 100644 --- a/code/pages/1_Basic behavior analysis.py +++ b/code/pages/1_Basic behavior analysis.py @@ -377,6 +377,6 @@ def _plot_histograms(df, column, bins, use_kernel_smooth, use_density): return fig if "df" not in st.session_state or "sessions_bonsai" not in st.session_state.df.keys(): - init(if_load_docDB=False) + init(if_load_docDB_override=False) app() diff --git a/code/util/plot_autotrain_manager.py b/code/util/plot_autotrain_manager.py index addf0b9..f5e0209 100644 --- a/code/util/plot_autotrain_manager.py +++ b/code/util/plot_autotrain_manager.py @@ -8,29 +8,33 @@ from aind_auto_train.schema.curriculum import TrainingStage -def plot_manager_all_progress(manager: 'AutoTrainManager', - x_axis: ['session', 'date', - 'relative_date'] = 'session', # type: ignore - sort_by: ['subject_id', 'first_date', - 'last_date', 'progress_to_graduated'] = 'subject_id', - sort_order: ['ascending', - 'descending'] = 'descending', - recent_days: int=None, - marker_size=10, - marker_edge_width=2, - highlight_subjects=[], - if_show_fig=True - ): - - +@st.cache_data(ttl=3600 * 24) +def plot_manager_all_progress( + x_axis: ["session", "date", "relative_date"] = "session", # type: ignore + sort_by: [ + "subject_id", + "first_date", + "last_date", + "progress_to_graduated", + ] = "subject_id", + sort_order: ["ascending", "descending"] = "descending", + recent_days: int = None, + marker_size=10, + marker_edge_width=2, + highlight_subjects=[], + if_show_fig=True, +): + + manager = st.session_state.auto_train_manager + # %% # Set default order df_manager = manager.df_manager.sort_values(by=['subject_id', 'session'], ascending=[sort_order == 'ascending', False]) - + if not len(df_manager): return None - + # Get some additional metadata from the master table df_tmp_rig_user_name = st.session_state.df['sessions_bonsai'].loc[:, ['subject_id', 'session_date', 'rig', 'user_name']] df_tmp_rig_user_name.session_date = df_tmp_rig_user_name.session_date.astype(str) @@ -51,7 +55,7 @@ def plot_manager_all_progress(manager: 'AutoTrainManager', elif sort_by == 'progress_to_graduated': manager.compute_stats() df_stats = manager.df_manager_stats - + # Sort by 'first_entry' of GRADUATED subject_ids = df_stats.reset_index().set_index( 'subject_id' @@ -59,10 +63,10 @@ def plot_manager_all_progress(manager: 'AutoTrainManager', f'current_stage_actual == "GRADUATED"' )['first_entry'].sort_values( ascending=sort_order != 'ascending').index.to_list() - + # Append subjects that have not graduated subject_ids = subject_ids + [s for s in df_manager.subject_id.unique() if s not in subject_ids] - + else: raise ValueError( f'sort_by must be in {["subject_id", "first_date", "last_date", "progress"]}') @@ -71,17 +75,17 @@ def plot_manager_all_progress(manager: 'AutoTrainManager', traces = [] for n, subject_id in enumerate(subject_ids): df_subject = df_manager[df_manager['subject_id'] == subject_id] - + # Get stage_color_mapper stage_color_mapper = get_stage_color_mapper(stage_list=list(TrainingStage.__members__)) - + # Get h2o if available if 'h2o' in manager.df_behavior: h2o = manager.df_behavior[ manager.df_behavior['subject_id'] == subject_id]['h2o'].iloc[0] else: h2o = None - + df_subject = df_subject.merge( df_tmp_rig_user_name, on=['subject_id', 'session_date'], how='left') @@ -105,11 +109,11 @@ def plot_manager_all_progress(manager: 'AutoTrainManager', else: raise ValueError( f"x_axis can only be in ['session', 'date', 'relative_date']") - + # Cache x range xrange_min = x.min() if n == 0 else min(x.min(), xrange_min) xrange_max = x.max() if n == 0 else max(x.max(), xrange_max) - + y = len(subject_ids) - n # Y axis traces.append(go.Scattergl( @@ -159,7 +163,7 @@ def plot_manager_all_progress(manager: 'AutoTrainManager', showlegend=False ) ) - + # Add "x" for open loop sessions traces.append(go.Scattergl( x=x[open_loop_ids], @@ -197,14 +201,14 @@ def plot_manager_all_progress(manager: 'AutoTrainManager', ), yaxis_range=[-0.5, len(subject_ids) + 1], ) - + # Limit x range to recent days if x is "date" if x_axis == 'date' and recent_days is not None: # xrange_max = pd.Timestamp.today() # For unknown reasons, using this line will break both plotly_events and new st.plotly_chart callback... xrange_max = pd.to_datetime(df_manager.session_date).max() + pd.Timedelta(days=1) xrange_min = xrange_max - pd.Timedelta(days=recent_days) fig.update_layout(xaxis_range=[xrange_min, xrange_max]) - + # Highight the selected subject for n, subject_id in enumerate(subject_ids): y = len(subject_ids) - n # Y axis @@ -222,7 +226,6 @@ def plot_manager_all_progress(manager: 'AutoTrainManager', opacity=0.3, layer="below" ) - # Show the plot if if_show_fig: diff --git a/code/util/streamlit.py b/code/util/streamlit.py index 5ef2b06..d1a3bec 100644 --- a/code/util/streamlit.py +++ b/code/util/streamlit.py @@ -791,7 +791,7 @@ def add_auto_train_manager(): recent_weeks = slider_wrapper_for_url_query(cols[5], label="only recent weeks", min_value=1, - max_value=26, + max_value=52, step=1, key='auto_training_history_recent_weeks', default=8, @@ -808,7 +808,6 @@ def add_auto_train_manager(): highlight_subjects = [] fig_auto_train = plot_manager_all_progress( - st.session_state.auto_train_manager, x_axis=x_axis, recent_days=recent_weeks*7, sort_by=sort_by, diff --git a/code/util/url_query_helper.py b/code/util/url_query_helper.py index f9cc99f..e8fd43f 100644 --- a/code/util/url_query_helper.py +++ b/code/util/url_query_helper.py @@ -10,6 +10,7 @@ # Note: When creating the widget, add argument "value"/"index" as well as "key" for all widgets you want to sync with URL to_sync_with_url_query_default = { "if_load_bpod_sessions": False, + "if_load_docDB": False, "to_filter_columns": [ "subject_id", "task",