diff --git a/nipype/info.py b/nipype/info.py index bce47c3e3a..47d765b34e 100644 --- a/nipype/info.py +++ b/nipype/info.py @@ -153,6 +153,7 @@ def get_nipype_gitversion(): TESTS_REQUIRES = [ "coverage >= 5.2.1", + "pandas >= 1.5.0", "pytest >= 6", "pytest-cov >=2.11", "pytest-env", diff --git a/nipype/pipeline/plugins/tests/test_callback.py b/nipype/pipeline/plugins/tests/test_callback.py index f7606708c7..b10238ec4a 100644 --- a/nipype/pipeline/plugins/tests/test_callback.py +++ b/nipype/pipeline/plugins/tests/test_callback.py @@ -1,8 +1,9 @@ # emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- # vi: set ft=python sts=4 ts=4 sw=4 et: -"""Tests for workflow callbacks -""" +"""Tests for workflow callbacks.""" +from pathlib import Path from time import sleep +import json import pytest import nipype.interfaces.utility as niu import nipype.pipeline.engine as pe @@ -60,3 +61,51 @@ def test_callback_exception(tmpdir, plugin, stop_on_first_crash): sleep(0.5) # Wait for callback to be called (python 2.7) assert so.statuses == [("f_node", "start"), ("f_node", "exception")] + + +@pytest.mark.parametrize("plugin", ["Linear", "MultiProc", "LegacyMultiProc"]) +def test_callback_gantt(tmp_path: Path, plugin: str) -> None: + import logging + + from os import path + + from nipype.utils.profiler import log_nodes_cb + from nipype.utils.draw_gantt_chart import generate_gantt_chart + + log_filename = tmp_path / "callback.log" + logger = logging.getLogger("callback") + logger.setLevel(logging.DEBUG) + handler = logging.FileHandler(log_filename) + logger.addHandler(handler) + + # create workflow + wf = pe.Workflow(name="test", base_dir=str(tmp_path)) + f_node = pe.Node( + niu.Function(function=func, input_names=[], output_names=[]), name="f_node" + ) + wf.add_nodes([f_node]) + wf.config["execution"] = {"crashdump_dir": wf.base_dir, "poll_sleep_duration": 2} + + plugin_args = {"status_callback": log_nodes_cb} + if plugin != "Linear": + plugin_args["n_procs"] = 8 + wf.run(plugin=plugin, plugin_args=plugin_args) + + with open(log_filename, "r") as _f: + loglines = _f.readlines() + + # test missing duration + first_line = json.loads(loglines[0]) + if "duration" in first_line: + del first_line["duration"] + loglines[0] = f"{json.dumps(first_line)}\n" + + # test duplicate timestamp warning + loglines.append(loglines[-1]) + + with open(log_filename, "w") as _f: + _f.write("".join(loglines)) + + with pytest.warns(Warning): + generate_gantt_chart(str(log_filename), 1 if plugin == "Linear" else 8) + assert (tmp_path / "callback.log.html").exists() diff --git a/nipype/utils/draw_gantt_chart.py b/nipype/utils/draw_gantt_chart.py index 3ae4b77246..64a0d793db 100644 --- a/nipype/utils/draw_gantt_chart.py +++ b/nipype/utils/draw_gantt_chart.py @@ -8,8 +8,10 @@ import random import datetime import simplejson as json +from typing import Union from collections import OrderedDict +from warnings import warn # Pandas try: @@ -66,9 +68,9 @@ def create_event_dict(start_time, nodes_list): finish_delta = (node["finish"] - start_time).total_seconds() # Populate dictionary - if events.get(start_delta) or events.get(finish_delta): + if events.get(start_delta): err_msg = "Event logged twice or events started at exact same time!" - raise KeyError(err_msg) + warn(err_msg, category=Warning) events[start_delta] = start_node events[finish_delta] = finish_node @@ -101,15 +103,25 @@ def log_to_dict(logfile): nodes_list = [json.loads(l) for l in lines] - def _convert_string_to_datetime(datestring): - try: + def _convert_string_to_datetime( + datestring: Union[str, datetime.datetime], + ) -> datetime.datetime: + """Convert a date string to a datetime object.""" + if isinstance(datestring, datetime.datetime): + datetime_object = datestring + elif isinstance(datestring, str): + date_format = ( + "%Y-%m-%dT%H:%M:%S.%f%z" + if "+" in datestring + else "%Y-%m-%dT%H:%M:%S.%f" + ) datetime_object: datetime.datetime = datetime.datetime.strptime( - datestring, "%Y-%m-%dT%H:%M:%S.%f" + datestring, date_format ) - return datetime_object - except Exception as _: - pass - return datestring + else: + msg = f"{datestring} is not a string or datetime object." + raise TypeError(msg) + return datetime_object date_object_node_list: list = list() for n in nodes_list: @@ -154,12 +166,18 @@ def calculate_resource_timeseries(events, resource): # Iterate through the events for _, event in sorted(events.items()): if event["event"] == "start": - if resource in event and event[resource] != "Unknown": - all_res += float(event[resource]) + if resource in event: + try: + all_res += float(event[resource]) + except ValueError: + continue current_time = event["start"] elif event["event"] == "finish": - if resource in event and event[resource] != "Unknown": - all_res -= float(event[resource]) + if resource in event: + try: + all_res -= float(event[resource]) + except ValueError: + continue current_time = event["finish"] res[current_time] = all_res @@ -284,7 +302,14 @@ def draw_nodes(start, nodes_list, cores, minute_scale, space_between_minutes, co # Left left = 60 for core in range(len(end_times)): - if end_times[core] < node_start: + try: + end_time_condition = end_times[core] < node_start + except TypeError: + # if one has a timezone and one does not + end_time_condition = end_times[core].replace( + tzinfo=None + ) < node_start.replace(tzinfo=None) + if end_time_condition: left += core * 30 end_times[core] = datetime.datetime( node_finish.year, @@ -307,7 +332,7 @@ def draw_nodes(start, nodes_list, cores, minute_scale, space_between_minutes, co "offset": offset, "scale_duration": scale_duration, "color": color, - "node_name": node["name"], + "node_name": node.get("name", node.get("id", "")), "node_dur": node["duration"] / 60.0, "node_start": node_start.strftime("%Y-%m-%d %H:%M:%S"), "node_finish": node_finish.strftime("%Y-%m-%d %H:%M:%S"), @@ -527,6 +552,25 @@ def generate_gantt_chart( # Read in json-log to get list of node dicts nodes_list = log_to_dict(logfile) + # Only include nodes with timing information, and convert timestamps + # from strings to datetimes + nodes_list = [ + { + k: ( + datetime.datetime.strptime(i[k], "%Y-%m-%dT%H:%M:%S.%f") + if k in {"start", "finish"} and isinstance(i[k], str) + else i[k] + ) + for k in i + } + for i in nodes_list + if "start" in i and "finish" in i + ] + + for node in nodes_list: + if "duration" not in node: + node["duration"] = (node["finish"] - node["start"]).total_seconds() + # Create the header of the report with useful information start_node = nodes_list[0] last_node = nodes_list[-1]