Skip to content

Commit

Permalink
Merge pull request #544 from gpetretto/devel
Browse files Browse the repository at this point in the history
jsanitize fireworks Task
  • Loading branch information
utf authored Feb 14, 2024
2 parents c211030 + 0debb38 commit eda2a65
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 0 deletions.
15 changes: 15 additions & 0 deletions src/jobflow/managers/fireworks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import typing

from fireworks import FiretaskBase, Firework, FWAction, Workflow, explicit_serialize
from fireworks.utilities.fw_serializers import recursive_serialize, serialize_fw
from monty.json import jsanitize

if typing.TYPE_CHECKING:
from collections.abc import Sequence
Expand Down Expand Up @@ -197,3 +199,16 @@ def run_task(self, fw_spec):
defuse_workflow=response.stop_jobflow,
defuse_children=response.stop_children,
)

@serialize_fw
@recursive_serialize
def to_dict(self) -> dict:
"""
Serialize version of the FireTask.
Overrides the original method to explicitly jsanitize the Job
to handle cases not properly handled by fireworks, like a Callable.
"""
d = dict(self)
d["job"] = jsanitize(d["job"].as_dict())
return d
22 changes: 22 additions & 0 deletions tests/managers/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,3 +399,25 @@ def _gen():
return Flow([replace, simple], simple.output, order=JobOrder.LINEAR)

return _gen


@pytest.fixture(scope="session")
def maker_with_callable():
from dataclasses import dataclass
from typing import Callable

from jobflow.core.job import job
from jobflow.core.maker import Maker

global TestCallableMaker

@dataclass
class TestCallableMaker(Maker):
f: Callable
name: str = "TestCallableMaker"

@job
def make(self, a, b):
return self.f([a, b])

return TestCallableMaker
33 changes: 33 additions & 0 deletions tests/managers/test_fireworks.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,3 +659,36 @@ def test_external_reference(lpad, mongo_jobstore, fw_dir, simple_job, capsys):
# check response
result2 = mongo_jobstore.query_one({"uuid": uuid2})
assert result2["output"] == "12345_end_end"


def test_maker_flow(lpad, mongo_jobstore, fw_dir, maker_with_callable, capsys):
from fireworks.core.rocket_launcher import rapidfire

from jobflow.core.flow import Flow
from jobflow.managers.fireworks import flow_to_workflow

j = maker_with_callable(f=sum).make(a=1, b=2)

flow = Flow([j])
uuid = flow[0].uuid

wf = flow_to_workflow(flow, mongo_jobstore)
fw_ids = lpad.add_wf(wf)

# run the workflow
rapidfire(lpad)

# check workflow completed
fw_id = next(iter(fw_ids.values()))
wf = lpad.get_wf_by_fw_id(fw_id)

assert all(s == "COMPLETED" for s in wf.fw_states.values())

# check store has the activity output
result = mongo_jobstore.query_one({"uuid": uuid})
assert result["output"] == 3

# check logs printed
captured = capsys.readouterr()
assert "INFO Starting job - TestCallableMaker" in captured.out
assert "INFO Finished job - TestCallableMaker" in captured.out

0 comments on commit eda2a65

Please sign in to comment.