diff --git a/psiflow/metrics.py b/psiflow/metrics.py
index cc3cae0..83eeb2d 100644
--- a/psiflow/metrics.py
+++ b/psiflow/metrics.py
@@ -5,7 +5,6 @@
from pathlib import Path
from typing import NamedTuple, Optional, Union
-import numpy as np
import typeguard
import wandb
from parsl.app.app import python_app
@@ -18,6 +17,25 @@
logger = logging.getLogger(__name__) # logging per module
+@typeguard.typechecked
+def _trace_identifier(
+ identifier_traces: dict,
+ state: FlowAtoms,
+ iteration: int,
+ walker_index: int,
+ nsteps: int,
+) -> dict:
+ if not state == NullState: # same checks as sampling.py:assign_identifier
+ if state.reference_status:
+ identifier = state.info["identifier"]
+ assert identifier not in identifier_traces
+ identifier_traces[identifier] = (iteration, walker_index, nsteps)
+ return identifier_traces
+
+
+trace_identifier = python_app(_trace_identifier, executors=["Default"])
+
+
@typeguard.typechecked
def _save_walker_logs(data: dict[str, list], path: Path) -> str:
from prettytable import PrettyTable
@@ -247,10 +265,13 @@ def _to_wandb(
wandb_project: str,
walker_logs: Optional[dict],
dataset_log: Optional[dict],
+ identifier_traces: Optional[dict],
):
import os
import tempfile
+ import numpy as np
+ import pandas as pd
import plotly.express as px
import wandb
@@ -275,14 +296,41 @@ def _to_wandb(
fix_plotly_layout(figure)
figure.update_layout(yaxis_title="forces RMSE [meV/A]")
figures[title] = figure
+
+ # convert dataset_log to pandas dataframe and add identifier tracing
+ customdata = []
+ identifiers = dataset_log["identifier"]
+ for index in identifiers:
+ customdata.append(identifier_traces.get(index, (np.nan,) * 3))
+ customdata = np.array(customdata, dtype=np.float32)
+ dataset_log["iteration"] = customdata[:, 0]
+ dataset_log["walker_index"] = customdata[:, 1]
+ dataset_log["nsteps"] = customdata[:, 2]
+ df = pd.DataFrame.from_dict(dataset_log)
+
if dataset_log is not None:
for x_axis in dataset_log:
if x_axis.startswith("CV") or (x_axis == "identifier"):
for y_axis in dataset_log:
if (y_axis == "e_rmse") or y_axis.startswith("f_rmse"):
- x = dataset_log[x_axis]
- y = dataset_log[y_axis]
- figure = px.scatter(x=x, y=y)
+ figure = px.scatter(
+ data_frame=df,
+ x=x_axis,
+ y=y_axis,
+ custom_data=["iteration", "walker_index", "nsteps"],
+ color="nsteps",
+ color_continuous_scale=px.colors.sequential.Turbo,
+ # color_continuous_scale=['darkcyan', 'darkgoldenrod', 'crimson'],
+ )
+ figure.update_traces(
+ marker={"size": 8},
+ hovertemplate=(
+ "iteration: %{customdata[0]}
"
+ + "walker index: %{customdata[1]}
"
+ + "steps: %{customdata[2]}
"
+ + ""
+ ),
+ )
figure.update_xaxes(type="linear")
title = "dataset_" + y_axis + "_" + x_axis
fix_plotly_layout(figure)
@@ -331,6 +379,8 @@ def __init__(
resume=resume,
)
self.walker_logs = []
+ self.identifier_traces = {}
+ self.iteration = 0
def as_dict(self):
return {
@@ -368,6 +418,13 @@ def log_walker(
**metadata_dict,
)
self.walker_logs.append(log)
+ self.identifier_traces = trace_identifier(
+ self.identifier_traces,
+ state,
+ self.iteration,
+ i,
+ metadata.counter,
+ )
def save(
self,
@@ -390,7 +447,11 @@ def save(
dataset_log = log_dataset(inputs=inputs)
save_dataset_log(dataset_log, path / "dataset.log")
if self.wandb_group is not None:
- # somehow, assignment is necessary to ensure execution of app
- f = to_wandb( # noqa: F841
- self.wandb_id, self.wandb_project, walker_logs, dataset_log
+ # typically needs a result() from caller
+ return to_wandb( # noqa: F841
+ self.wandb_id,
+ self.wandb_project,
+ walker_logs,
+ dataset_log,
+ self.identifier_traces,
)
diff --git a/tests/test_sampling.py b/tests/test_sampling.py
index 1cf881d..93c8db0 100644
--- a/tests/test_sampling.py
+++ b/tests/test_sampling.py
@@ -1,6 +1,6 @@
import numpy as np
-import parsl
+import psiflow
from psiflow.committee import Committee
from psiflow.data import FlowAtoms
from psiflow.metrics import Metrics, log_dataset
@@ -65,8 +65,7 @@ def test_sample_metrics(mace_model, dataset, tmp_path):
assert sum([a is None for a in dataset_log["CV1"]]) == 6 - 1
assert len(metrics.walker_logs) == len(walkers)
- metrics.save(tmp_path, model=mace_model, dataset=data)
- parsl.wait_for_current_tasks()
+ metrics.save(tmp_path, model=mace_model, dataset=data).result()
assert (tmp_path / "walkers.log").exists()
assert (tmp_path / "dataset.log").exists()
@@ -111,7 +110,7 @@ def test_sample_committee(gpu, mace_config, dataset, tmp_path):
assert data[i].result().info["identifier"] <= 2
assert len(metrics.walker_logs) == len(walkers)
metrics.save(tmp_path)
- parsl.wait_for_current_tasks()
+ psiflow.wait()
assert (tmp_path / "walkers.log").exists()
with open(tmp_path / "walkers.log", "r") as f:
print(f.read())