diff --git a/psiflow/metrics.py b/psiflow/metrics.py
index cc3cae0..83eeb2d 100644
--- a/psiflow/metrics.py
+++ b/psiflow/metrics.py
@@ -5,7 +5,6 @@
 from pathlib import Path
 from typing import NamedTuple, Optional, Union
 
-import numpy as np
 import typeguard
 import wandb
 from parsl.app.app import python_app
@@ -18,6 +17,31 @@
 logger = logging.getLogger(__name__)  # logging per module
 
 
+@typeguard.typechecked
+def _trace_identifier(
+    identifier_traces: dict,
+    state: FlowAtoms,
+    iteration: int,
+    walker_index: int,
+    nsteps: int,
+) -> dict:
+    """Store the (iteration, walker_index, nsteps) trace of an identified state.
+
+    If `state` is not the NullState and its `reference_status` is truthy,
+    `(iteration, walker_index, nsteps)` is stored in `identifier_traces` under
+    `state.info["identifier"]`. The dict is updated in place and returned.
+    """
+    # same checks as sampling.py:assign_identifier
+    if not state == NullState and state.reference_status:
+        identifier = state.info["identifier"]
+        assert identifier not in identifier_traces
+        identifier_traces[identifier] = (iteration, walker_index, nsteps)
+    return identifier_traces
+
+
+trace_identifier = python_app(_trace_identifier, executors=["Default"])
+
+
 @typeguard.typechecked
 def _save_walker_logs(data: dict[str, list], path: Path) -> str:
     from prettytable import PrettyTable
@@ -247,10 +271,13 @@ def _to_wandb(
     wandb_project: str,
     walker_logs: Optional[dict],
     dataset_log: Optional[dict],
+    identifier_traces: Optional[dict],
 ):
     import os
     import tempfile
 
+    import numpy as np
+    import pandas as pd
     import plotly.express as px
     import wandb
 
@@ -275,14 +302,44 @@ def _to_wandb(
                 fix_plotly_layout(figure)
                 figure.update_layout(yaxis_title="forces RMSE [meV/A]")
                 figures[title] = figure
+
     if dataset_log is not None:
+        # Build a dataframe from dataset_log and add, per identifier, the
+        # (iteration, walker_index, nsteps) trace; identifiers without a trace
+        # get NaN. Done inside the None guard and without mutating dataset_log.
+        traces = {} if identifier_traces is None else identifier_traces
+        missing = (np.nan,) * 3
+        customdata = np.array(
+            [traces.get(i, missing) for i in dataset_log["identifier"]],
+            dtype=np.float32,
+        ).reshape(-1, 3)  # keeps the column slices valid for an empty log
+        df = pd.DataFrame.from_dict(dataset_log)
+        df["iteration"] = customdata[:, 0]
+        df["walker_index"] = customdata[:, 1]
+        df["nsteps"] = customdata[:, 2]
+
         for x_axis in dataset_log:
             if x_axis.startswith("CV") or (x_axis == "identifier"):
                 for y_axis in dataset_log:
                     if (y_axis == "e_rmse") or y_axis.startswith("f_rmse"):
-                        x = dataset_log[x_axis]
-                        y = dataset_log[y_axis]
-                        figure = px.scatter(x=x, y=y)
+                        figure = px.scatter(
+                            data_frame=df,
+                            x=x_axis,
+                            y=y_axis,
+                            custom_data=["iteration", "walker_index", "nsteps"],
+                            color="nsteps",
+                            color_continuous_scale=px.colors.sequential.Turbo,
+                            # color_continuous_scale=['darkcyan', 'darkgoldenrod', 'crimson'],
+                        )
+                        figure.update_traces(
+                            marker={"size": 8},
+                            hovertemplate=(
+                                "iteration: %{customdata[0]}<br>"
+                                + "walker index: %{customdata[1]}<br>"
+                                + "steps: %{customdata[2]}<br>"
+                                + "<extra></extra>"
+                            ),
+                        )
                         figure.update_xaxes(type="linear")
                         title = "dataset_" + y_axis + "_" + x_axis
                         fix_plotly_layout(figure)
@@ -331,6 +388,8 @@ def __init__(
                 resume=resume,
             )
         self.walker_logs = []
+        self.identifier_traces = {}
+        self.iteration = 0
 
     def as_dict(self):
         return {
@@ -368,6 +427,13 @@ def log_walker(
             **metadata_dict,
         )
         self.walker_logs.append(log)
+        self.identifier_traces = trace_identifier(
+            self.identifier_traces,
+            state,
+            self.iteration,
+            i,
+            metadata.counter,
+        )
 
     def save(
         self,
@@ -390,7 +456,11 @@ def save(
             dataset_log = log_dataset(inputs=inputs)
             save_dataset_log(dataset_log, path / "dataset.log")
         if self.wandb_group is not None:
-            # somehow, assignment is necessary to ensure execution of app
-            f = to_wandb(  # noqa: F841
-                self.wandb_id, self.wandb_project, walker_logs, dataset_log
+            # returned so the caller can wait on it; typically needs a result()
+            return to_wandb(
+                self.wandb_id,
+                self.wandb_project,
+                walker_logs,
+                dataset_log,
+                self.identifier_traces,
             )
diff --git a/tests/test_sampling.py b/tests/test_sampling.py
index 1cf881d..93c8db0 100644
--- a/tests/test_sampling.py
+++ b/tests/test_sampling.py
@@ -1,6 +1,6 @@
 import numpy as np
-import parsl
 
+import psiflow
 from psiflow.committee import Committee
 from psiflow.data import FlowAtoms
 from psiflow.metrics import Metrics, log_dataset
@@ -65,8 +65,7 @@ def test_sample_metrics(mace_model, dataset, tmp_path):
     assert sum([a is None for a in dataset_log["CV1"]]) == 6 - 1
 
     assert len(metrics.walker_logs) == len(walkers)
-    metrics.save(tmp_path, model=mace_model, dataset=data)
-    parsl.wait_for_current_tasks()
+    metrics.save(tmp_path, model=mace_model, dataset=data).result()
     assert (tmp_path / "walkers.log").exists()
     assert (tmp_path / "dataset.log").exists()
 
@@ -111,7 +110,7 @@ def test_sample_committee(gpu, mace_config, dataset, tmp_path):
         assert data[i].result().info["identifier"] <= 2
     assert len(metrics.walker_logs) == len(walkers)
     metrics.save(tmp_path)
-    parsl.wait_for_current_tasks()
+    psiflow.wait()
     assert (tmp_path / "walkers.log").exists()
     with open(tmp_path / "walkers.log", "r") as f:
         print(f.read())