Skip to content

Commit

Permalink
add iteration and walker info when hovering over data in wandb
Browse files Browse the repository at this point in the history
  • Loading branch information
svandenhaute committed Oct 28, 2023
1 parent d1fce7c commit bd314ae
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 11 deletions.
75 changes: 68 additions & 7 deletions psiflow/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from pathlib import Path
from typing import NamedTuple, Optional, Union

import numpy as np
import typeguard
import wandb
from parsl.app.app import python_app
Expand All @@ -18,6 +17,25 @@
logger = logging.getLogger(__name__) # logging per module


@typeguard.typechecked
def _trace_identifier(
identifier_traces: dict,
state: FlowAtoms,
iteration: int,
walker_index: int,
nsteps: int,
) -> dict:
if not state == NullState: # same checks as sampling.py:assign_identifier
if state.reference_status:
identifier = state.info["identifier"]
assert identifier not in identifier_traces
identifier_traces[identifier] = (iteration, walker_index, nsteps)
return identifier_traces


trace_identifier = python_app(_trace_identifier, executors=["Default"])


@typeguard.typechecked
def _save_walker_logs(data: dict[str, list], path: Path) -> str:
from prettytable import PrettyTable
Expand Down Expand Up @@ -247,10 +265,13 @@ def _to_wandb(
wandb_project: str,
walker_logs: Optional[dict],
dataset_log: Optional[dict],
identifier_traces: Optional[dict],
):
import os
import tempfile

import numpy as np
import pandas as pd
import plotly.express as px
import wandb

Expand All @@ -275,14 +296,41 @@ def _to_wandb(
fix_plotly_layout(figure)
figure.update_layout(yaxis_title="forces RMSE [meV/A]")
figures[title] = figure

# convert dataset_log to pandas dataframe and add identifier tracing
customdata = []
identifiers = dataset_log["identifier"]
for index in identifiers:
customdata.append(identifier_traces.get(index, (np.nan,) * 3))
customdata = np.array(customdata, dtype=np.float32)
dataset_log["iteration"] = customdata[:, 0]
dataset_log["walker_index"] = customdata[:, 1]
dataset_log["nsteps"] = customdata[:, 2]
df = pd.DataFrame.from_dict(dataset_log)

if dataset_log is not None:
for x_axis in dataset_log:
if x_axis.startswith("CV") or (x_axis == "identifier"):
for y_axis in dataset_log:
if (y_axis == "e_rmse") or y_axis.startswith("f_rmse"):
x = dataset_log[x_axis]
y = dataset_log[y_axis]
figure = px.scatter(x=x, y=y)
figure = px.scatter(
data_frame=df,
x=x_axis,
y=y_axis,
custom_data=["iteration", "walker_index", "nsteps"],
color="nsteps",
color_continuous_scale=px.colors.sequential.Turbo,
# color_continuous_scale=['darkcyan', 'darkgoldenrod', 'crimson'],
)
figure.update_traces(
marker={"size": 8},
hovertemplate=(
"<b>iteration</b>: %{customdata[0]}<br>"
+ "<b>walker index</b>: %{customdata[1]}<br>"
+ "<b>steps</b>: %{customdata[2]}<br>"
+ "<extra></extra>"
),
)
figure.update_xaxes(type="linear")
title = "dataset_" + y_axis + "_" + x_axis
fix_plotly_layout(figure)
Expand Down Expand Up @@ -331,6 +379,8 @@ def __init__(
resume=resume,
)
self.walker_logs = []
self.identifier_traces = {}
self.iteration = 0

def as_dict(self):
return {
Expand Down Expand Up @@ -368,6 +418,13 @@ def log_walker(
**metadata_dict,
)
self.walker_logs.append(log)
self.identifier_traces = trace_identifier(
self.identifier_traces,
state,
self.iteration,
i,
metadata.counter,
)

def save(
self,
Expand All @@ -390,7 +447,11 @@ def save(
dataset_log = log_dataset(inputs=inputs)
save_dataset_log(dataset_log, path / "dataset.log")
if self.wandb_group is not None:
# somehow, assignment is necessary to ensure execution of app
f = to_wandb( # noqa: F841
self.wandb_id, self.wandb_project, walker_logs, dataset_log
# typically needs a result() from caller
return to_wandb( # noqa: F841
self.wandb_id,
self.wandb_project,
walker_logs,
dataset_log,
self.identifier_traces,
)
7 changes: 3 additions & 4 deletions tests/test_sampling.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np
import parsl

import psiflow
from psiflow.committee import Committee
from psiflow.data import FlowAtoms
from psiflow.metrics import Metrics, log_dataset
Expand Down Expand Up @@ -65,8 +65,7 @@ def test_sample_metrics(mace_model, dataset, tmp_path):
assert sum([a is None for a in dataset_log["CV1"]]) == 6 - 1

assert len(metrics.walker_logs) == len(walkers)
metrics.save(tmp_path, model=mace_model, dataset=data)
parsl.wait_for_current_tasks()
metrics.save(tmp_path, model=mace_model, dataset=data).result()
assert (tmp_path / "walkers.log").exists()
assert (tmp_path / "dataset.log").exists()

Expand Down Expand Up @@ -111,7 +110,7 @@ def test_sample_committee(gpu, mace_config, dataset, tmp_path):
assert data[i].result().info["identifier"] <= 2
assert len(metrics.walker_logs) == len(walkers)
metrics.save(tmp_path)
parsl.wait_for_current_tasks()
psiflow.wait()
assert (tmp_path / "walkers.log").exists()
with open(tmp_path / "walkers.log", "r") as f:
print(f.read())
Expand Down

0 comments on commit bd314ae

Please sign in to comment.