Skip to content

Commit

Permalink
Merge branch 'main' into add_operating_directory
Browse files Browse the repository at this point in the history
  • Loading branch information
PythonFZ committed Mar 28, 2023
2 parents 2dcf7ba + 44deaa9 commit eba4873
Show file tree
Hide file tree
Showing 12 changed files with 324 additions and 103 deletions.
139 changes: 73 additions & 66 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ typer = "^0.7.0"

dot4dict = "^0.1.1"
zninit = "^0.1.9"
znflow = "^0.1.4"
znjson = "^0.2.2"
znflow = "^0.1.6"


[tool.poetry.urls]
Expand Down
59 changes: 59 additions & 0 deletions tests/integration/test_combine_lists.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import pytest

import zntrack


class GenerateList(zntrack.Node):
size = zntrack.zn.params(10)
outs = zntrack.zn.outs()

def run(self):
self.outs = list(range(self.size))


class AddOneToList(zntrack.Node):
data = zntrack.zn.deps()
outs = zntrack.zn.outs()

def run(self) -> None:
self.outs = [x + 1 for x in self.data]


class AddOneToDict(zntrack.Node):
data = zntrack.zn.deps()
outs = zntrack.zn.outs()

def run(self) -> None:
self.outs = {k: [x + 1 for x in v] for k, v in self.data.items()}


@pytest.mark.parametrize("eager", [True, False])
def test_combine(proj_path, eager):
with zntrack.Project() as proj:
a = GenerateList(size=1, name="a")
b = GenerateList(size=2, name="b")
c = GenerateList(size=3, name="c")

added = AddOneToList(data=a.outs + b.outs + c.outs)

proj.run(eager=eager)
if not eager:
added.load()

assert added.outs == [1] + [1, 2] + [1, 2, 3]


@pytest.mark.parametrize("eager", [True, False])
def test_combine_dict(proj_path, eager):
with zntrack.Project() as proj:
a = GenerateList(size=1, name="a")
b = GenerateList(size=2, name="b")
c = GenerateList(size=3, name="c")

added = AddOneToDict(data={x.name: x.outs for x in [a, b, c]})

proj.run(eager=eager)
if not eager:
added.load()

assert added.outs == {"a": [1], "b": [1, 2], "c": [1, 2, 3]}
26 changes: 26 additions & 0 deletions tests/integration/test_misc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import znflow

import zntrack


class NodeWithProperty(zntrack.Node):
params = zntrack.zn.params(None)

@property
def calc(self):
"""This should not change the params if not called."""
self.params = 42
return "calc"

def run(self):
pass


def test_NodeWithProperty(proj_path):
with zntrack.Project() as proj:
node = NodeWithProperty()

proj.run()

node.load()
assert node.params is None
16 changes: 14 additions & 2 deletions tests/integration/test_none_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,24 @@ def test_from_dvc_deps(proj_path, eager):


class EmptyNodesNode(zntrack.Node):
# we use dvc.outs to generate zntrack.json
file = zntrack.dvc.outs(zntrack.nwd / "file.txt")
nodes = zntrack.zn.nodes(None)
outs = zntrack.zn.outs()

def run(self):
pass
if self.nodes is None:
self.outs = 42
else:
self.outs = self.nodes.value
self.file.write_text("Hello World")


def test_EmptyNode(proj_path):
@pytest.mark.parametrize("eager", [True, False])
def test_EmptyNode(proj_path, eager):
with zntrack.Project() as project:
node = EmptyNodesNode()
project.run(eager=eager)
if not eager:
node.load()
assert node.outs == 42
16 changes: 16 additions & 0 deletions tests/integration/test_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,19 @@ def test_WriteIO_no_name(tmp_path_2, assert_before_exp):

assert exp2["WriteIO"].inputs == "Lorem Ipsum"
assert exp2["WriteIO"].outputs == "Lorem Ipsum"


def test_project_remove_graph(proj_path):
with zntrack.Project() as project:
node = WriteIO(inputs="Hello World")
project.run()
node.load()
assert node.outputs == "Hello World"

with zntrack.Project(remove_existing_graph=True) as project:
node2 = WriteIO(inputs="Lorem Ipsum", name="node2")
project.run()
node2.load()
assert node2.outputs == "Lorem Ipsum"
with pytest.raises(zntrack.exceptions.NodeNotAvailableError):
node.load()
3 changes: 2 additions & 1 deletion zntrack/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"""
import importlib.metadata

from zntrack import tools
from zntrack import exceptions, tools
from zntrack.core.node import Node
from zntrack.core.nodify import NodeConfig, nodify
from zntrack.fields import Field, FieldGroup, LazyField, dvc, meta, zn
Expand All @@ -28,4 +28,5 @@
"nodify",
"NodeConfig",
"tools",
"exceptions",
]
14 changes: 10 additions & 4 deletions zntrack/core/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import zninit
import znjson

from zntrack import exceptions
from zntrack.notebooks.jupyter import jupyter_class_to_file
from zntrack.utils import (
NodeStatusResults,
Expand Down Expand Up @@ -101,6 +102,8 @@ class Node(zninit.ZnInit, znflow.Node):
name: str = _NameDescriptor(None)
_name_ = None

_protected_ = znflow.Node._protected_ + ["name"]

def _post_load_(self) -> None:
"""Post load hook.
Expand Down Expand Up @@ -176,10 +179,13 @@ def load(self, lazy: bool = None) -> None:

kwargs = {} if lazy is None else {"lazy": lazy}
self.state.loaded = True # we assume loading will be successful.
with config.updated_config(**kwargs):
# TODO: it would be much nicer not to use a global config object here.
for attr in zninit.get_descriptors(Field, self=self):
attr.load(self)
try:
with config.updated_config(**kwargs):
# TODO: it would be much nicer not to use a global config object here.
for attr in zninit.get_descriptors(Field, self=self):
attr.load(self)
except KeyError as err:
raise exceptions.NodeNotAvailableError(self) from err

# TODO: documentation about _post_init and _post_load_ and when they are called
self._post_load_()
Expand Down
19 changes: 19 additions & 0 deletions zntrack/exceptions/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
"""All ZnTrack exceptions."""


class NodeNotAvailableError(Exception):
"""Raised when a node is not available."""

def __init__(self, arg):
"""Initialize the exception.
Parameters
----------
arg : str|Node
Custom Error message or Node that is not available.
"""
if isinstance(arg, str):
super().__init__(arg)
else:
# assume arg is a Node
super().__init__(f"Node {arg.name} is not available.")
21 changes: 21 additions & 0 deletions zntrack/fields/field.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""The base class for all fields."""
import abc
import contextlib
import enum
import json
import logging
Expand Down Expand Up @@ -149,9 +150,29 @@ def _write_value_to_config(self, value, instance: "Node", encoder=None):
json.dump(zntrack_dict, f, indent=4, cls=encoder)


class DataIsLazyError(Exception):
"""Exception to raise when a field is accessed that contains lazy data."""


class LazyField(Field):
"""Base class for fields that are loaded lazily."""

def get_value_except_lazy(self, instance):
"""Get the value of the field.
If the value is lazy, raise an Error.
Raises
------
DataIsLazyError
If the value is lazy.
"""
with contextlib.suppress(KeyError):
if instance.__dict__[self.name] is LazyOption:
raise DataIsLazyError()

return getattr(instance, self.name, None)

def __get__(self, instance, owner=None):
"""Load the field from disk if it is not already loaded."""
if instance is None:
Expand Down
Loading

0 comments on commit eba4873

Please sign in to comment.