Skip to content

Commit

Permalink
add operating_directory
Browse files Browse the repository at this point in the history
  • Loading branch information
PythonFZ committed Mar 20, 2023
1 parent 16f895e commit 2dcf7ba
Show file tree
Hide file tree
Showing 2 changed files with 146 additions and 1 deletion.
122 changes: 121 additions & 1 deletion zntrack/core/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
import dataclasses
import importlib
import logging
import os
import pathlib
import shutil
import typing

import dvc.api
Expand All @@ -15,11 +17,23 @@
import znjson

from zntrack.notebooks.jupyter import jupyter_class_to_file
from zntrack.utils import NodeStatusResults, deprecated, module_handler, run_dvc_cmd
from zntrack.utils import (
NodeStatusResults,
convert_to_list,
deprecated,
module_handler,
run_dvc_cmd,
update_gitignore,
)
from zntrack.utils.config import config
from zntrack.utils.node_wd import move_nwd

# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# Accepted type for the 'remove_on' / 'move_on' arguments: a single exception
# class, a collection of exception classes, or None.
EXCEPTION_OR_LST_EXCEPTIONS = typing.Optional[
    typing.Union[
        typing.Type[Exception],
        typing.Collection[typing.Type[Exception]],
    ]
]


@dataclasses.dataclass
class NodeStatus:
Expand Down Expand Up @@ -210,6 +224,112 @@ def write_graph(self, run: bool = False, **kwargs):
if run:
run_dvc_cmd(["repro", self.name])

@contextlib.contextmanager
def operating_directory(
    self,
    prefix: str = "ckpt",
    remove_on: EXCEPTION_OR_LST_EXCEPTIONS = None,
    move_on: EXCEPTION_OR_LST_EXCEPTIONS = Exception,
    disable: typing.Optional[bool] = None,
) -> typing.Iterator[bool]:
    """Work in a temporary operating directory until successfully finished.

    This context manager will replace the path of the node working
    directory $nwd$ with a temporary operating directory 'prefix_$nwd$'
    and move the files to $nwd$, when successfully finished.

    This can be useful, when you are running, e.g., on hardware
    with limited execution time and can't use 'dvc checkpoints'.
    When successfully finished, all files will be moved from 'prefix_$nwd$'
    to $nwd$. You can call 'dvc repro' multiple times to continue from
    'prefix_$nwd$'.

    If used properly this will result in reproducible data, but:

    - checkpoints will not be removed if parameters change. Always remove a
      checkpoint, when running with new parameters!
    - checkpoints are not versioned. If you want to checkpoint, e.g., model
      training, use 'dvc checkpoints'.

    Parameters
    ----------
    prefix: str, default = 'ckpt'
        Prefix for the temporary directory.
    remove_on: Exception or list of Exceptions, default = None
        If one of the exceptions in 'remove_on' is raised, the temporary
        operating directory will be removed. Otherwise, it will remain
        and be reused upon the next run. Takes precedence over 'move_on'.
    move_on: Exception or list of Exceptions, default = Exception
        If one of the exceptions in 'move_on' is raised, the content of the
        temporary operating directory is moved to $nwd$. This helps, in the
        case of an error, to not restart from an already failed data point.
        The default is set to move the files if any Exception occurs.
    disable: bool, default = None
        Disable the temporary operating directory. Yields True.
        If None, the value of 'config.disable_operating_directory' is used.

    Yields
    ------
    new_ckpt: bool
        True if creating a new checkpoint. False if the checkpoint already
        existed.

    Notes
    -----
    Only subclasses of 'Exception' are matched against 'remove_on' and
    'move_on'. For anything else (e.g. 'KeyboardInterrupt') the operating
    directory is left in place, so the next run can continue from it.
    """
    if disable is None:
        disable = config.disable_operating_directory
    if disable:
        yield True
        return

    nwd = self.nwd
    nwd_new = nwd.with_name(f"{prefix}_{nwd.name}")
    nwd_is_new = not nwd_new.exists()

    if not self._run_and_save:
        # if not inside 'run_and_save' no directory should be created. ?!?!?!
        # Only the path is swapped for the duration of the 'with' body.
        self.nwd = nwd_new
        try:
            yield nwd_is_new
        finally:
            self.nwd = nwd
        return

    # isinstance() needs a tuple of classes; an empty tuple never matches.
    remove_on = tuple(convert_to_list(remove_on))
    move_on = tuple(convert_to_list(move_on))
    remove = False
    move = False

    update_gitignore(prefix=prefix)

    if nwd_is_new:
        log.info(f"Creating new operating directory: {nwd_new}")
        log.warning(
            "Experimental Feature: operating directory is currently not"
            " compatible with 'dvc exp --temp' or 'dvc exp --queue'"
        )
        # TODO add a unique path per node.
        # TODO check on windows!
        # Files are hard-linked (os.link) instead of copied.
        shutil.copytree(nwd, nwd_new, copy_function=os.link)
    else:
        log.info(f"Continuing inside operating directory: {nwd_new}.")

    self.nwd = nwd_new
    try:
        yield nwd_is_new
    except Exception as err:
        log.warning("Node execution was interrupted.")
        remove = isinstance(err, remove_on)
        move = isinstance(err, move_on)
        # The 'finally' clause below acts on 'remove' / 'move' before the
        # exception propagates. A bare 'raise' keeps the original traceback.
        raise
    finally:
        try:
            # Save e.g. `zn.outs` before stopping.
            self.save(results=True)
        finally:
            # Always restore the real nwd, even if saving fails, so the node
            # is never left pointing at the temporary directory.
            self.nwd = nwd
        if remove:
            log.info(f"Removing operating directory: {nwd_new}")
            shutil.rmtree(nwd_new)
        elif move:
            log.info(f"Moving files from '{nwd_new}' to {nwd}")
            move_nwd(nwd_new, nwd)

    # Only reached when the 'with' body finished without raising.
    log.info(f"Finished successfully. Moving files from {nwd_new} to {nwd}")
    move_nwd(nwd_new, nwd)


def get_dvc_cmd(
node: Node,
Expand Down
25 changes: 25 additions & 0 deletions zntrack/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import unittest.mock

import dvc.cli
import git

from zntrack.utils import cli
from zntrack.utils.config import config
Expand Down Expand Up @@ -204,3 +205,27 @@ def cwd_temp_dir(required_files=None) -> tempfile.TemporaryDirectory:
os.chdir(temp_dir.name)

return temp_dir


def convert_to_list(value) -> list:
    """Convert value to a list if it is not already a sequence of values.

    Parameters
    ----------
    value:
        The object to normalise.

    Returns
    -------
    list or tuple
        - an empty list, if 'value is None'
        - 'value' itself (unchanged, not copied), if it is a list or a tuple
        - a new list with the elements of 'value', if it is a set or frozenset
        - '[value]' for everything else (including strings)
    """
    if value is None:
        return []
    if isinstance(value, (list, tuple)):
        return value
    if isinstance(value, (set, frozenset)):
        # A set of values is a collection of values, not a single value;
        # wrapping it as '[{a, b}]' would break callers that iterate over
        # the result and e.g. pass every element to 'isinstance'.
        return list(value)
    return [value]


def update_gitignore(prefix) -> None:
    """Make sure 'nodes/<prefix>_*' is listed in the '.gitignore' file.

    The git repository in the current working directory is asked whether
    the pattern is already ignored. If it is, nothing is written; otherwise
    a comment line and the pattern are appended to '.gitignore'.

    Parameters
    ----------
    prefix:
        Prefix of the operating directories that should be ignored.
    """
    pattern = f"nodes/{prefix}_*"

    already_ignored = git.Repo(".").ignored(pattern)
    if not already_ignored:
        # Append mode: existing entries of the file are kept as they are.
        with pathlib.Path(".gitignore").open("a", encoding="utf-8") as handle:
            handle.write("\n# ZnTrack operating directory \n")
            handle.write(f"{pattern}\n")

0 comments on commit 2dcf7ba

Please sign in to comment.