From 4be27678fb076d0eb323e66f9d4cf5a570eae5c8 Mon Sep 17 00:00:00 2001 From: Reagan Lee Date: Tue, 15 Aug 2023 19:12:44 +0000 Subject: [PATCH 01/10] add metrics arg in command --- elk/plotting/command.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/elk/plotting/command.py b/elk/plotting/command.py index e79dc165..27126b6b 100644 --- a/elk/plotting/command.py +++ b/elk/plotting/command.py @@ -22,6 +22,9 @@ class Plot: overwrite: bool = False """Whether to overwrite existing plots.""" + metrics: list[str] = field(default_factory=list) + """Name of metric to plot""" + def execute(self): root_dir = sweeps_dir() @@ -35,6 +38,9 @@ def execute(self): else: sweep_paths = [root_dir / sweep for sweep in self.sweeps] + if not self.metrics: + self.metrics = ["auroc_estimate"] + for sweep_path in sweep_paths: if not sweep_path.exists(): pretty_error(f"No sweep with name {{{sweep_path}}} found in {root_dir}") @@ -47,4 +53,6 @@ def execute(self): if self.overwrite: shutil.rmtree(sweep_path / "viz") - visualize_sweep(sweep_path) + assert len(self.metrics) == 1, "Multiple metrics at a time aren't supported yet. Re-run plot for each metric separately." + + visualize_sweep(sweep_path, self.metrics) From f5864d8b8dacc614c023a8a36dd7f1d9156d7cef Mon Sep 17 00:00:00 2001 From: Reagan Lee Date: Thu, 17 Aug 2023 10:15:42 +0000 Subject: [PATCH 02/10] rename var --- elk/plotting/command.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/elk/plotting/command.py b/elk/plotting/command.py index 27126b6b..16661257 100644 --- a/elk/plotting/command.py +++ b/elk/plotting/command.py @@ -22,7 +22,7 @@ class Plot: overwrite: bool = False """Whether to overwrite existing plots.""" - metrics: list[str] = field(default_factory=list) + metric_types: list[str] = field(default_factory=list) """Name of metric to plot""" def execute(self): @@ -38,8 +38,8 @@ def execute(self): else: sweep_paths = [root_dir / sweep for sweep in self.sweeps] - if not self.metrics: - self.metrics = ["auroc_estimate"] + if not self.metric_types: + self.metric_types = ["auroc_estimate"] for sweep_path in sweep_paths: if not sweep_path.exists(): @@ -53,6 +53,6 @@ def execute(self): if self.overwrite: shutil.rmtree(sweep_path / "viz") - assert len(self.metrics) == 1, "Multiple metrics at a time aren't supported yet. Re-run plot for each metric separately." + assert len(self.metric_types) == 1, "Multiple metrics at a time aren't supported yet. Re-run plot for each metric separately." 
- visualize_sweep(sweep_path, self.metrics) + visualize_sweep(sweep_path, self.metric_types) From 2553c6e3755a3f3e81f1acacb26545c51daaac23 Mon Sep 17 00:00:00 2001 From: Reagan Lee Date: Thu, 17 Aug 2023 11:13:31 +0000 Subject: [PATCH 03/10] allow just one metric at a time --- elk/plotting/command.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/elk/plotting/command.py b/elk/plotting/command.py index 16661257..3adba5f9 100644 --- a/elk/plotting/command.py +++ b/elk/plotting/command.py @@ -22,7 +22,7 @@ class Plot: overwrite: bool = False """Whether to overwrite existing plots.""" - metric_types: list[str] = field(default_factory=list) + metric_type: str = None """Name of metric to plot""" def execute(self): @@ -38,8 +38,9 @@ def execute(self): else: sweep_paths = [root_dir / sweep for sweep in self.sweeps] - if not self.metric_types: - self.metric_types = ["auroc_estimate"] + if not self.metric_type: + # ArgumentParser maps cli input --metric to metric_type + self.metric_type = "auroc_estimate" for sweep_path in sweep_paths: if not sweep_path.exists(): @@ -53,6 +54,4 @@ def execute(self): if self.overwrite: shutil.rmtree(sweep_path / "viz") - assert len(self.metric_types) == 1, "Multiple metrics at a time aren't supported yet. Re-run plot for each metric separately." - - visualize_sweep(sweep_path, self.metric_types) + visualize_sweep(sweep_path, self.metric_type) From 8ef905a2b6d6895628094d6f9066ef550058430d Mon Sep 17 00:00:00 2001 From: Reagan Lee Date: Thu, 17 Aug 2023 11:22:51 +0000 Subject: [PATCH 04/10] rename --- elk/plotting/visualize.py | 40 ++++++++++++++++++++------------------- 1 file changed, 21 insertions(+), 19 deletions(-) diff --git a/elk/plotting/visualize.py b/elk/plotting/visualize.py index fa183e5a..52e13cd1 100644 --- a/elk/plotting/visualize.py +++ b/elk/plotting/visualize.py @@ -114,7 +114,7 @@ class TransferEvalHeatmap: """Class for generating heatmaps for transfer evaluation results.""" layer: int - score_type: str = "auroc_estimate" + metric_type: str = "auroc_estimate" ensembling: str = "full" def render(self, df: pd.DataFrame) -> go.Figure: @@ -129,7 +129,7 @@ def render(self, df: pd.DataFrame) -> go.Figure: model_name = df["eval_dataset"].iloc[0] # infer model name # TODO: validate pivot = pd.pivot_table( - df, values=self.score_type, index="eval_dataset", columns="train_dataset" + df, values=self.metric_type, index="eval_dataset", columns="train_dataset" ) fig = px.imshow(pivot, color_continuous_scale="Viridis", text_auto=True) @@ -137,7 +137,7 @@ def render(self, df: pd.DataFrame) -> go.Figure: fig.update_layout( xaxis_title="Train Dataset", yaxis_title="Transfer Dataset", - title=f"AUROC Score Heatmap: {model_name} | Layer {self.layer}", + title=f"{self.metric_type} Score Heatmap: {model_name} | Layer {self.layer}", ) return fig @@ -149,7 +149,7 @@ class TransferEvalTrend: evaluation.""" dataset_names: list[str] | None - score_type: str = "auroc_estimate" + metric_type: str = "auroc_estimate" def render(self, df: pd.DataFrame) -> go.Figure: """Render the trend plot visualization. 
@@ -164,14 +164,14 @@ def render(self, df: pd.DataFrame) -> go.Figure: if self.dataset_names is not None: df = self._filter_transfer_datasets(df, self.dataset_names) pivot = pd.pivot_table( - df, values=self.score_type, index="layer", columns="eval_dataset" + df, values=self.metric_type, index="layer", columns="eval_dataset" ) fig = px.line(pivot, color_discrete_sequence=px.colors.qualitative.Plotly) fig.update_layout( xaxis_title="Layer", - yaxis_title="AUROC Score", - title=f"AUROC Score Trend: {model_name}", + yaxis_title=f"{self.metric_type} Score", + title=f"{self.metric_type} Score Trend: {model_name}", ) avg = pivot.mean(axis=1) @@ -244,7 +244,7 @@ def render_and_save( self, sweep: "SweepVisualization", dataset_names: list[str] | None = None, - score_type="auroc_estimate", + metric_type="auroc_estimate", ensembling="full", ) -> None: """Render and save the visualization for the model. @@ -252,9 +252,10 @@ def render_and_save( Args: sweep: The SweepVisualization instance. dataset_names: List of dataset names to include in the visualization. - score_type: The type of score to display. + metric_type: The type of score to display. ensembling: The ensembling option to consider. """ + metric_type = sweep.metric_type df = self.df model_name = self.model_name layer_min, layer_max = df["layer"].min(), df["layer"].max() @@ -264,10 +265,10 @@ def render_and_save( for layer in range(layer_min, layer_max + 1): filtered = df[(df["layer"] == layer) & (df["ensembling"] == ensembling)] fig = TransferEvalHeatmap( - layer, score_type=score_type, ensembling=ensembling + layer, metric_type=metric_type, ensembling=ensembling ).render(filtered) fig.write_image(file=model_path / f"{layer}.png") - fig = TransferEvalTrend(dataset_names).render(df) + fig = TransferEvalTrend(dataset_names, metric_type=metric_type).render(df) fig.write_image(file=model_path / "transfer_eval_trend.png") @staticmethod @@ -288,6 +289,7 @@ class SweepVisualization: path: Path datasets: list[str] models: dict[str, ModelVisualization] + metric_type: str def model_names(self) -> list[str]: """Get the names of all models in the sweep. @@ -323,7 +325,7 @@ def _get_model_paths(sweep_path: Path) -> list[Path]: return folders @classmethod - def collect(cls, sweep_path: Path) -> "SweepVisualization": + def collect(cls, sweep_path: Path, metric_type: str) -> "SweepVisualization": """Collect the evaluation data for a sweep. Args: @@ -348,7 +350,7 @@ def collect(cls, sweep_path: Path) -> "SweepVisualization": } df = pd.concat([model.df for model in models.values()], ignore_index=True) datasets = list(df["eval_dataset"].unique()) - return cls(sweep_name, df, sweep_viz_path, datasets, models) + return cls(sweep_name, df, sweep_viz_path, datasets, models, metric_type=metric_type) def render_and_save(self): """Render and save all visualizations for the sweep.""" @@ -369,13 +371,13 @@ def render_multiplots(self, write=False): ] def render_table( - self, score_type="auroc_estimate", display=True, write=False + self, metric_type="auroc_estimate", display=True, write=False ) -> pd.DataFrame: """Render and optionally write the score table. Args: layer: The layer number (from last layer) to include in the score table. - score_type: The type of score to include in the table. + metric_type: The type of score to include in the table. display: Flag indicating whether to display the table to stdout. write: Flag indicating whether to write the table to a file. 
@@ -395,7 +397,7 @@ def render_table( pivot_table = pd.concat(model_dfs).pivot_table( index="eval_dataset", columns="model_name", - values=score_type, + values=metric_type, margins=True, margins_name="Mean", ) @@ -416,14 +418,14 @@ def render_table( console.print(table) if write: - pivot_table.to_csv(f"score_table_{score_type}.csv") + pivot_table.to_csv(f"score_table_{metric_type}.csv") return pivot_table -def visualize_sweep(sweep_path: Path): +def visualize_sweep(sweep_path: Path, metric_type: str): """Visualize a sweep by generating and saving the visualizations. Args: sweep_path: The path to the sweep data directory. """ - SweepVisualization.collect(sweep_path).render_and_save() + SweepVisualization.collect(sweep_path, metric_type).render_and_save() From 54a9310827afb38480349f1d26167574ff4acd00 Mon Sep 17 00:00:00 2001 From: Reagan Lee Date: Thu, 17 Aug 2023 11:32:50 +0000 Subject: [PATCH 05/10] selectable metric_type support for layer viz --- elk/plotting/visualize.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/elk/plotting/visualize.py b/elk/plotting/visualize.py index 52e13cd1..3a5b4af1 100644 --- a/elk/plotting/visualize.py +++ b/elk/plotting/visualize.py @@ -47,7 +47,7 @@ def render( shared_yaxes=True, vertical_spacing=0.1, x_title="Layer", - y_title="AUROC", + y_title=f"{sweep.metric_type}", ) color_map = dict(zip(ensembles, qualitative.Plotly)) @@ -371,13 +371,12 @@ def render_multiplots(self, write=False): ] def render_table( - self, metric_type="auroc_estimate", display=True, write=False + self, display=True, write=False ) -> pd.DataFrame: """Render and optionally write the score table. Args: layer: The layer number (from last layer) to include in the score table. - metric_type: The type of score to include in the table. display: Flag indicating whether to display the table to stdout. write: Flag indicating whether to write the table to a file. 
@@ -389,7 +388,7 @@ def render_table( # For each model, we use the layer whose mean AUROC is the highest best_layers, model_dfs = [], [] for _, model_df in df.groupby("model_name"): - best_layer = model_df.groupby("layer").auroc_estimate.mean().argmax() + best_layer = model_df.groupby("layer")[self.metric_type].mean().argmax() best_layers.append(best_layer) model_dfs.append(model_df[model_df["layer"] == best_layer]) @@ -397,7 +396,7 @@ def render_table( pivot_table = pd.concat(model_dfs).pivot_table( index="eval_dataset", columns="model_name", - values=metric_type, + values=self.metric_type, margins=True, margins_name="Mean", ) @@ -418,7 +417,7 @@ def render_table( console.print(table) if write: - pivot_table.to_csv(f"score_table_{metric_type}.csv") + pivot_table.to_csv(f"score_table_{self.metric_type}.csv") return pivot_table From 4deb51bfcad2822512d24839d33049e60f6283c3 Mon Sep 17 00:00:00 2001 From: Reagan Lee Date: Thu, 17 Aug 2023 12:01:12 +0000 Subject: [PATCH 06/10] switch out rest of auroc with metric_type --- elk/plotting/visualize.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/elk/plotting/visualize.py b/elk/plotting/visualize.py index 3a5b4af1..1a5c3d5e 100644 --- a/elk/plotting/visualize.py +++ b/elk/plotting/visualize.py @@ -56,7 +56,7 @@ def render( if with_transfer: # TODO write tests ensemble_data = ensemble_data.groupby( ["eval_dataset", "layer", "ensembling"], as_index=False - ).agg({"auroc_estimate": "mean"}) + ).agg({f"{sweep.metric_type}": "mean"}) else: ensemble_data = ensemble_data[ ensemble_data["eval_dataset"] == ensemble_data["train_dataset"] @@ -75,7 +75,7 @@ def render( fig.add_trace( go.Scatter( x=dataset_data["layer"], - y=dataset_data["auroc_estimate"], + y=dataset_data[f"{sweep.metric_type}"], mode="lines", name=ensemble, showlegend=False @@ -95,7 +95,7 @@ def render( legend=dict( title="Ensembling", ), - title=f"AUROC Trend: {self.model_name}", + title=f"{sweep.metric_type} Trend: {self.model_name}", ) if write: fig.write_image( @@ -114,7 +114,7 @@ class TransferEvalHeatmap: """Class for generating heatmaps for transfer evaluation results.""" layer: int - metric_type: str = "auroc_estimate" + metric_type: str = None ensembling: str = "full" def render(self, df: pd.DataFrame) -> go.Figure: @@ -145,11 +145,11 @@ def render(self, df: pd.DataFrame) -> go.Figure: @dataclass class TransferEvalTrend: - """Class for generating line plots for the trend of AUROC scores in transfer + """Class for generating line plots for the trend of metric scores in transfer evaluation.""" dataset_names: list[str] | None - metric_type: str = "auroc_estimate" + metric_type: str = None def render(self, df: pd.DataFrame) -> go.Figure: """Render the trend plot visualization. @@ -244,7 +244,6 @@ def render_and_save( self, sweep: "SweepVisualization", dataset_names: list[str] | None = None, - metric_type="auroc_estimate", ensembling="full", ) -> None: """Render and save the visualization for the model. @@ -252,7 +251,6 @@ def render_and_save( Args: sweep: The SweepVisualization instance. dataset_names: List of dataset names to include in the visualization. - metric_type: The type of score to display. ensembling: The ensembling option to consider. 
""" metric_type = sweep.metric_type From 0b6abf49574443999ef98c684d514af72a5bc299 Mon Sep 17 00:00:00 2001 From: Reagan Lee Date: Thu, 17 Aug 2023 12:01:28 +0000 Subject: [PATCH 07/10] add functionality to readme --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index fa74d900..f5f3bf4c 100644 --- a/README.md +++ b/README.md @@ -57,11 +57,11 @@ together. You can also add a `--visualize` flag to visualize the results of the elk sweep --models gpt2-{medium,large,xl} --datasets imdb amazon_polarity --add_pooled ``` -If you just do `elk plot`, it will plot the results from the most recent sweep. -If you want to plot a specific sweep, you can do so with: +If you just do `elk plot`, it will plot the results of AUROC from the most recent sweep. +If you want to plot a specific sweep, with a specific metric type, you can do so with: ```bash -elk plot {sweep_name} +elk plot {sweep_name} --metric acc_estimate ``` ## Caching From 179066cac16daa1d2a8fc17e81c79ea9e49d53b8 Mon Sep 17 00:00:00 2001 From: Reagan Lee Date: Fri, 18 Aug 2023 07:53:43 +0000 Subject: [PATCH 08/10] add functionality to sweep / cleanup --- elk/plotting/command.py | 6 +----- elk/plotting/visualize.py | 4 ++-- elk/training/sweep.py | 5 ++++- 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/elk/plotting/command.py b/elk/plotting/command.py index 3adba5f9..db80319d 100644 --- a/elk/plotting/command.py +++ b/elk/plotting/command.py @@ -22,7 +22,7 @@ class Plot: overwrite: bool = False """Whether to overwrite existing plots.""" - metric_type: str = None + metric_type: str = "auroc_estimate" """Name of metric to plot""" def execute(self): @@ -38,10 +38,6 @@ def execute(self): else: sweep_paths = [root_dir / sweep for sweep in self.sweeps] - if not self.metric_type: - # ArgumentParser maps cli input --metric to metric_type - self.metric_type = "auroc_estimate" - for sweep_path in sweep_paths: if not sweep_path.exists(): pretty_error(f"No sweep with name {{{sweep_path}}} found in {root_dir}") diff --git a/elk/plotting/visualize.py b/elk/plotting/visualize.py index 1a5c3d5e..ce8e864d 100644 --- a/elk/plotting/visualize.py +++ b/elk/plotting/visualize.py @@ -114,7 +114,7 @@ class TransferEvalHeatmap: """Class for generating heatmaps for transfer evaluation results.""" layer: int - metric_type: str = None + metric_type: str = "" ensembling: str = "full" def render(self, df: pd.DataFrame) -> go.Figure: @@ -149,7 +149,7 @@ class TransferEvalTrend: evaluation.""" dataset_names: list[str] | None - metric_type: str = None + metric_type: str = "" def render(self, df: pd.DataFrame) -> go.Figure: """Render the trend plot visualization. 
diff --git a/elk/training/sweep.py b/elk/training/sweep.py
index a4e5c97a..6a58f643 100755
--- a/elk/training/sweep.py
+++ b/elk/training/sweep.py
@@ -50,6 +50,9 @@ class Sweep:
     visualize: bool = False
     """Whether to generate visualizations of the results of the sweep."""
 
+    metric_type: str = "auroc_estimate"
+    """Name of metric to plot"""
+
     name: str | None = None
 
     # A bit of a hack to add all the command line arguments from Elicit
@@ -176,4 +179,4 @@ def execute(self):
             eval.execute(highlight_color="green")
 
         if self.visualize:
-            visualize_sweep(sweep_dir)
+            visualize_sweep(sweep_dir, self.metric_type)

From b0cf33d0678ae644d4a5ca01c288c00613e22485 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 18 Aug 2023 08:30:36 +0000
Subject: [PATCH 09/10] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 elk/plotting/visualize.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/elk/plotting/visualize.py b/elk/plotting/visualize.py
index ce8e864d..e3e44f38 100644
--- a/elk/plotting/visualize.py
+++ b/elk/plotting/visualize.py
@@ -253,7 +253,7 @@ def render_and_save(
             dataset_names: List of dataset names to include in the visualization.
             ensembling: The ensembling option to consider.
         """
-        metric_type = sweep.metric_type 
+        metric_type = sweep.metric_type
         df = self.df
         model_name = self.model_name
         layer_min, layer_max = df["layer"].min(), df["layer"].max()
@@ -348,7 +348,9 @@ def collect(cls, sweep_path: Path, metric_type: str) -> "SweepVisualization":
         }
         df = pd.concat([model.df for model in models.values()], ignore_index=True)
         datasets = list(df["eval_dataset"].unique())
-        return cls(sweep_name, df, sweep_viz_path, datasets, models, metric_type=metric_type)
+        return cls(
+            sweep_name, df, sweep_viz_path, datasets, models, metric_type=metric_type
+        )
 
     def render_and_save(self):
         """Render and save all visualizations for the sweep."""
@@ -368,9 +370,7 @@ def render_multiplots(self, write=False):
             for model in self.models
         ]
 
-    def render_table(
-        self, display=True, write=False
-    ) -> pd.DataFrame:
+    def render_table(self, display=True, write=False) -> pd.DataFrame:
         """Render and optionally write the score table.
 
         Args:

From 989f00f7d57a4a597a9abb8bb0b4f53c86673e24 Mon Sep 17 00:00:00 2001
From: Reagan Lee
Date: Fri, 18 Aug 2023 08:33:18 +0000
Subject: [PATCH 10/10] line length

---
 elk/plotting/visualize.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/elk/plotting/visualize.py b/elk/plotting/visualize.py
index ce8e864d..e36985ad 100644
--- a/elk/plotting/visualize.py
+++ b/elk/plotting/visualize.py
@@ -137,7 +137,8 @@ def render(self, df: pd.DataFrame) -> go.Figure:
         fig.update_layout(
             xaxis_title="Train Dataset",
             yaxis_title="Transfer Dataset",
-            title=f"{self.metric_type} Score Heatmap: {model_name} | Layer {self.layer}",
+            title=f"{self.metric_type} Score Heatmap: {model_name}"
+            f" | Layer {self.layer}",
         )
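
How the `--metric` flag reaches `Plot.metric_type` is not shown in these patches (elk derives its CLI from the dataclass fields, and the comment removed in PATCH 08 only alludes to the mapping). The following is a hypothetical argparse stand-in, not the project's actual parser wiring, sketching the behavior the series converges on: one metric per invocation, defaulting to `auroc_estimate`.

```python
# Hypothetical stand-in for elk's CLI wiring; the real project builds
# its parser from the Plot dataclass, so every name here is illustrative.
import argparse

parser = argparse.ArgumentParser(prog="elk plot")
parser.add_argument("sweeps", nargs="*", help="sweep names to plot")
parser.add_argument(
    "--metric",
    dest="metric_type",  # mirrors the --metric -> metric_type mapping
    default="auroc_estimate",  # mirrors the field default from PATCH 08
    help="name of the single metric to plot",
)

args = parser.parse_args(["my_sweep", "--metric", "acc_estimate"])
print(args.metric_type)  # acc_estimate
```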
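
`TransferEvalHeatmap.render` boils down to a pandas pivot over the chosen metric column followed by `plotly.express.imshow`. Below is a minimal self-contained sketch of that pattern: the column names are the ones the patches use, but the DataFrame values are fabricated for illustration.

```python
# Minimal sketch of the pivot-then-imshow pattern behind
# TransferEvalHeatmap.render; the DataFrame below is toy data.
import pandas as pd
import plotly.express as px

df = pd.DataFrame(
    {
        "train_dataset": ["imdb", "imdb", "amazon_polarity", "amazon_polarity"],
        "eval_dataset": ["imdb", "amazon_polarity", "imdb", "amazon_polarity"],
        "auroc_estimate": [0.91, 0.78, 0.80, 0.88],
    }
)
metric_type = "auroc_estimate"

# Rows = transfer (eval) datasets, columns = train datasets,
# cells = the chosen metric (pivot_table averages any duplicates).
pivot = pd.pivot_table(
    df, values=metric_type, index="eval_dataset", columns="train_dataset"
)
fig = px.imshow(pivot, color_continuous_scale="Viridis", text_auto=True)
fig.update_layout(xaxis_title="Train Dataset", yaxis_title="Transfer Dataset")
fig.write_image("heatmap.png")  # static export needs the kaleido package
```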
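
One subtlety in the best-layer selection from PATCH 05: `model_df.groupby("layer")[self.metric_type].mean().argmax()` returns the *position* of the maximum within the grouped result, while the code then filters with `model_df["layer"] == best_layer`, which compares against layer *labels*. The two agree only when the layer labels happen to run contiguously from 0; `idxmax()` returns the label directly. A small demonstration on toy data:

```python
# argmax() is positional, idxmax() returns the index label; with layers
# that do not start at 0 the difference matters. Toy data for illustration.
import pandas as pd

model_df = pd.DataFrame(
    {
        "layer": [5, 5, 6, 6, 7, 7],
        "auroc_estimate": [0.70, 0.72, 0.90, 0.92, 0.80, 0.82],
    }
)
means = model_df.groupby("layer")["auroc_estimate"].mean()

print(means.argmax())  # 1 -- a position, not a layer number
print(means.idxmax())  # 6 -- the actual best layer label

best_layer = means.idxmax()
best_rows = model_df[model_df["layer"] == best_layer]
print(best_rows)
```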