diff --git a/src/jobflow/core/store.py b/src/jobflow/core/store.py index fc68e223..3b3bb770 100644 --- a/src/jobflow/core/store.py +++ b/src/jobflow/core/store.py @@ -444,6 +444,26 @@ def remove_docs(self, criteria: dict): store.remove_docs({"job_uuid": doc["uuid"], "job_index": doc["index"]}) self.docs_store.remove_docs(criteria) + + def __eq__(self, other: object) -> bool: + """ + Check equality for JobStore. + + Args: + other: other JobStore to compare with. + """ + if not isinstance(other, JobStore): + return False + + fields = ["docs_store", "save", "load"] + + # Check equality of all additional_stores + if self.additional_stores == other.additional_stores: + return all(getattr(self, f) == getattr(other, f) for f in fields) + + return False + + def get_output( self, uuid: str, diff --git a/src/jobflow/managers/fireworks.py b/src/jobflow/managers/fireworks.py index 5cc32129..9407c02e 100644 --- a/src/jobflow/managers/fireworks.py +++ b/src/jobflow/managers/fireworks.py @@ -2,9 +2,13 @@ from __future__ import annotations +import os import typing from fireworks import FiretaskBase, Firework, FWAction, Workflow, explicit_serialize +from fireworks.fw_config import DS_PASSWORD +from fireworks.utilities.fw_utilities import DataServer +from maggma.stores.shared_stores import StoreFacade if typing.TYPE_CHECKING: from typing import Sequence @@ -151,7 +155,16 @@ def run_task(self, fw_spec): if store is None: store = SETTINGS.JOB_STORE - store.connect() + if "FW_DATASERVER_PORT" in os.environ: + ds = DataServer( + address=("127.0.0.1", int(os.environ["FW_DATASERVER_PORT"])), + authkey=DS_PASSWORD, + ) + ds.connect() + multistore = ds.MultiStore() + store = StoreFacade(store, multistore) + else: + store.connect() if hasattr(self, "fw_id"): job.metadata.update({"fw_id": self.fw_id})