From 0d1686835fe3fbb23009217a136fff6c42492b06 Mon Sep 17 00:00:00 2001
From: PythonFZ
Date: Tue, 28 Mar 2023 18:01:07 +0200
Subject: [PATCH] bring operating directory back

---
Notes (review):
    * Adds tests/integration/test_operating_directory.py and removes the
      `if self._run_and_save:` guard (and its `else` fallback) from the
      visible part of `operating_directory` in zntrack/core/node.py.
    * test_kill_process called `zntrack.utils.run_dvc_cmd(["repro"])`, but
      the test module only binds `Node, dvc, exceptions, utils, zn` via
      `from zntrack import ...`; the bare name `zntrack` is undefined and
      the call raised NameError. It now uses the imported `utils`, like
      every other call in the file.
    * Line counts are unchanged, so hunk headers and the diffstat still
      match. The blob id on the `index` line of the new test file is the
      one recorded for the original content.

 tests/integration/test_operating_directory.py | 196 ++++++++++++++++++
 zntrack/core/node.py                          |  78 ++++---
 2 files changed, 231 insertions(+), 43 deletions(-)
 create mode 100644 tests/integration/test_operating_directory.py

diff --git a/tests/integration/test_operating_directory.py b/tests/integration/test_operating_directory.py
new file mode 100644
index 00000000..d82c2b96
--- /dev/null
+++ b/tests/integration/test_operating_directory.py
@@ -0,0 +1,196 @@
+import pathlib
+import subprocess
+import time
+
+import pytest
+
+from zntrack import Node, dvc, exceptions, utils, zn
+
+
+class ListOfDataNode(Node):
+    data: list = zn.outs()
+
+    def run(self):
+        with self.operating_directory(move_on=None) as new:
+            if new:
+                self.data = list(range(5))
+                raise ValueError("Execution was interrupted")
+            else:
+                self.data += list(range(5, 10))
+
+
+class RestartFromCheckpoint(Node):
+    file: pathlib.Path = dvc.outs(utils.nwd / "out.txt")
+
+    def run(self):
+        with self.operating_directory(move_on=None) as new:
+            if new:
+                self.file.write_text("Hello")
+                raise ValueError("Execution was interrupted")
+            else:
+                text = self.file.read_text()
+                self.file.write_text(f"{text} there")
+
+
+class RemoveOnError(Node):
+    data: list = zn.outs()
+
+    def run(self):
+        with self.operating_directory(remove_on=(TypeError, ValueError)):
+            raise ValueError("Execution was interrupted")
+
+
+class MoveOnError(Node):
+    data: list = zn.outs()
+
+    def run(self):
+        with self.operating_directory(move_on=(TypeError, ValueError)):
+            self.data = list(range(5))
+            raise ValueError("Execution was interrupted")
+
+
+def test_remove_on_error(proj_path):
+    node = RemoveOnError()
+    node.write_graph()
+
+    node = RemoveOnError.load()
+    with pytest.raises(ValueError):
+        node.run()
+        node.save()
+
+    nwd_new = node.nwd.with_name(f"ckpt_{node.nwd.name}")
+    assert not nwd_new.exists()
+
+    node = node.load()
+    with pytest.raises(exceptions.DataNotAvailableError):
+        _ = node.data
+
+
+def test_move_on_error(proj_path):
+    node = MoveOnError()
+    node.write_graph()
+
+    node = MoveOnError.load()
+    with pytest.raises(ValueError):
+        node.run()
+        node.save()
+
+    nwd_new = node.nwd.with_name(f"ckpt_{node.nwd.name}")
+    assert not nwd_new.exists()
+    # the file exists even if 'run_and_save' fails, so the output can be investigated.
+    node = node.load()
+    assert node.data == list(range(5))
+
+
+def test_ListOfDataNode(proj_path):
+    ListOfDataNode().write_graph()
+
+    node = ListOfDataNode.load()
+
+    with pytest.raises(exceptions.DVCProcessError):
+        utils.run_dvc_cmd(["repro"])
+    nwd_new = node.nwd.with_name(f"ckpt_{node.nwd.name}")
+    assert nwd_new.exists()
+    utils.run_dvc_cmd(["repro"])
+    assert not nwd_new.exists()
+
+    assert ListOfDataNode.load().data == list(range(10))
+
+
+def test_ListOfDataNode2(proj_path):
+    node = ListOfDataNode()
+    node.write_graph()
+
+    with pytest.raises(ValueError):
+        node.run()
+        node.save()
+    node.run()
+    node.save()
+
+    assert node.load().data == list(range(10))
+
+
+def test_RestartFromCheckpoint(proj_path):
+    RestartFromCheckpoint().write_graph()
+
+    with pytest.raises(exceptions.DVCProcessError):
+        utils.run_dvc_cmd(["repro"])
+
+    assert not RestartFromCheckpoint.load().file.exists()
+    utils.run_dvc_cmd(["repro"])
+
+    assert RestartFromCheckpoint.load().file.read_text() == "Hello there"
+
+
+def test_RestartFromCheckpoint2(proj_path):
+    node = RestartFromCheckpoint()
+    node.write_graph()
+
+    with pytest.raises(ValueError):
+        node.run()
+        node.save()
+
+    node.run()
+    node.save()
+
+    assert node.load().file.read_text() == "Hello there"
+
+
+class WriteNumbersSlow(Node):
+    outs = zn.outs()
+    maximum = zn.params()
+
+    def run(self):
+        with self.operating_directory() as new:
+            if new:
+                self.outs = []
+            for x in range(len(self.outs), self.maximum):
+                print(x)
+                self.outs.append(x)
+                self.save(results=True)
+                time.sleep(0.1)
+
+
+def test_kill_process(proj_path):
+    node = WriteNumbersSlow(maximum=15)
+    node.write_graph()
+    nwd_new = node.nwd.with_name(f"ckpt_{node.nwd.name}")
+    # killing the DVC process will not kill the python process as it would on
+    # a cluster
+    proc = subprocess.Popen(
+        ["zntrack", "run", "test_operating_directory.WriteNumbersSlow"]
+    )
+    time.sleep(2.0)
+    proc.kill()
+    assert nwd_new.exists()
+
+    # zntrack.utils.run_dvc_cmd(["repro"])
+    proc = subprocess.Popen(
+        ["zntrack", "run", "test_operating_directory.WriteNumbersSlow"]
+    )
+    proc.wait()
+    assert not nwd_new.exists()
+    assert node.load().outs == list(range(15))
+
+    # and now check again without killing
+    node = WriteNumbersSlow(maximum=10)
+    node.write_graph()
+    utils.run_dvc_cmd(["repro"])
+    assert not nwd_new.exists()
+    assert node.load().outs == list(range(10))
+
+
+def test_disable_operating_directory(proj_path):
+    ListOfDataNode().write_graph()
+    with utils.config.updated_config(disable_operating_directory=True):
+        node = ListOfDataNode.load()
+
+        with pytest.raises(ValueError):
+            ListOfDataNode.load().run_and_save()
+        with pytest.raises(ValueError):  # running it twice does not change the outcome
+            ListOfDataNode.load().run_and_save()
+        assert node.nwd.exists()
+        assert not node.nwd.with_name(f"ckpt_{node.nwd.name}").exists()
+
+        with pytest.raises(exceptions.DataNotAvailableError):
+            _ = ListOfDataNode.load().data
diff --git a/zntrack/core/node.py b/zntrack/core/node.py
index 18687835..a8590408 100644
--- a/zntrack/core/node.py
+++ b/zntrack/core/node.py
@@ -291,50 +291,42 @@ def operating_directory(
         remove_on = convert_to_list(remove_on)
         move_on = convert_to_list(move_on)
 
-        if self._run_and_save:
-            update_gitignore(prefix=prefix)
-
-            if nwd_is_new:
-                log.info(f"Creating new operating directory: {nwd_new}")
-                log.warning(
-                    "Experimental Feature: operating directory is currently not"
-                    " compatible with 'dvc exp --temp' or 'dvc exp --queue'"
-                )
-                # TODO add a unique path per node.
-                # TODO check on windows!
-                shutil.copytree(nwd, nwd_new, copy_function=os.link)
-            else:
-                log.info(f"Continuing inside operating directory: {nwd_new}.")
-
-            self.nwd = nwd_new
-            try:
-                yield nwd_is_new
-            except Exception as err:
-                log.warning("Node execution was interrupted.")
-                remove = any(isinstance(err, e) for e in remove_on)
-                move = any(isinstance(err, e) for e in move_on)
-                # finally -> ...
-                raise err
-            finally:
-                # Save e.g. `zn.outs` before stopping.
-                self.save(results=True)
-                self.nwd = nwd
-                if remove:
-                    log.info(f"Removing operating directory: {nwd_new}")
-                    shutil.rmtree(nwd_new)
-                elif move:
-                    log.info(f"Moving files from '{nwd_new}' to {nwd}")
-                    move_nwd(nwd_new, nwd)
-
-            log.info(f"Finished successfully. Moving files from {nwd_new} to {nwd}")
-            move_nwd(nwd_new, nwd)
+        update_gitignore(prefix=prefix)
+
+        if nwd_is_new:
+            log.info(f"Creating new operating directory: {nwd_new}")
+            log.warning(
+                "Experimental Feature: operating directory is currently not"
+                " compatible with 'dvc exp --temp' or 'dvc exp --queue'"
+            )
+            # TODO add a unique path per node.
+            # TODO check on windows!
+            shutil.copytree(nwd, nwd_new, copy_function=os.link)
         else:
-            # if not inside 'run_and_save' no directory should be created. ?!?!?!
-            self.nwd = nwd_new
-            try:
-                yield nwd_is_new
-            finally:
-                self.nwd = nwd
+            log.info(f"Continuing inside operating directory: {nwd_new}.")
+
+        self.nwd = nwd_new
+        try:
+            yield nwd_is_new
+        except Exception as err:
+            log.warning("Node execution was interrupted.")
+            remove = any(isinstance(err, e) for e in remove_on)
+            move = any(isinstance(err, e) for e in move_on)
+            # finally -> ...
+            raise err
+        finally:
+            # Save e.g. `zn.outs` before stopping.
+            self.save(results=True)
+            self.nwd = nwd
+            if remove:
+                log.info(f"Removing operating directory: {nwd_new}")
+                shutil.rmtree(nwd_new)
+            elif move:
+                log.info(f"Moving files from '{nwd_new}' to {nwd}")
+                move_nwd(nwd_new, nwd)
+
+        log.info(f"Finished successfully. Moving files from {nwd_new} to {nwd}")
+        move_nwd(nwd_new, nwd)
 
 
 def get_dvc_cmd(