Skip to content

Commit

Permalink
added tests for explicit and auto-superclass casting
Browse files Browse the repository at this point in the history
  • Loading branch information
tclose committed Sep 8, 2023
1 parent 6bca1db commit d6c3e2f
Show file tree
Hide file tree
Showing 4 changed files with 182 additions and 22 deletions.
2 changes: 1 addition & 1 deletion pydra/engine/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ def make_klass(spec):
)
checker_label = f"'{name}' field of {spec.name}"
type_checker = TypeParser[newfield.type](
newfield.type, label=checker_label, allow_lazy_super=True
newfield.type, label=checker_label, superclass_auto_cast=True
)
if newfield.type in (MultiInputObj, MultiInputFile):
converter = attr.converters.pipe(ensure_list, type_checker)
Expand Down
94 changes: 80 additions & 14 deletions pydra/utils/tests/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@
GenericShellTask,
specific_func_task,
SpecificShellTask,
other_specific_func_task,
OtherSpecificShellTask,
MyFormatX,
MyOtherFormatX,
MyHeader,
)

Expand Down Expand Up @@ -168,12 +171,12 @@ def test_type_check_permit_superclass():
# Typical case as Json is subclass of File
TypeParser(ty.List[File])(lz(ty.List[Json]))
# Permissive super class, as File is superclass of Json
TypeParser(ty.List[Json], allow_lazy_super=True)(lz(ty.List[File]))
TypeParser(ty.List[Json], superclass_auto_cast=True)(lz(ty.List[File]))
with pytest.raises(TypeError, match="Cannot coerce"):
TypeParser(ty.List[Json], allow_lazy_super=False)(lz(ty.List[File]))
TypeParser(ty.List[Json], superclass_auto_cast=False)(lz(ty.List[File]))
# Fails because Yaml is neither sub or super class of Json
with pytest.raises(TypeError, match="Cannot coerce"):
TypeParser(ty.List[Json], allow_lazy_super=True)(lz(ty.List[Yaml]))
TypeParser(ty.List[Json], superclass_auto_cast=True)(lz(ty.List[Yaml]))


def test_type_check_fail1():
Expand Down Expand Up @@ -550,7 +553,17 @@ def specific_task(request):
assert False


def test_typing_cast(tmp_path, generic_task, specific_task):
@pytest.fixture(params=["func", "shell"])
def other_specific_task(request):
    """Parametrised fixture returning one flavour of the "other specific" task.

    Returns ``other_specific_func_task`` for the ``"func"`` param and
    ``OtherSpecificShellTask`` for the ``"shell"`` param.
    """
    flavour = request.param
    if flavour == "func":
        return other_specific_func_task
    if flavour == "shell":
        return OtherSpecificShellTask
    # Only reachable if a param is added above without a matching branch
    assert False


def test_typing_implicit_cast_from_super(tmp_path, generic_task, specific_task):
"""Check the casting of lazy fields and whether specific file-sets can be recovered
from generic `File` classes"""

Expand All @@ -574,33 +587,86 @@ def test_typing_cast(tmp_path, generic_task, specific_task):
)
)

wf.add(
specific_task(
in_file=wf.generic.lzout.out,
name="specific2",
)
)

wf.set_output(
[
("out_file", wf.specific2.lzout.out),
]
)

in_file = MyFormatX.sample()

result = wf(in_file=in_file, plugin="serial")

out_file: MyFormatX = result.output.out_file
assert type(out_file) is MyFormatX
assert out_file.parent != in_file.parent
assert type(out_file.header) is MyHeader
assert out_file.header.parent != in_file.header.parent


def test_typing_cast(tmp_path, specific_task, other_specific_task):
"""Check the casting of lazy fields and whether specific file-sets can be recovered
from generic `File` classes"""

wf = Workflow(
name="test",
input_spec={"in_file": MyFormatX},
output_spec={"out_file": MyFormatX},
)

wf.add(
specific_task(
in_file=wf.lzin.in_file,
name="entry",
)
)

with pytest.raises(TypeError, match="Cannot coerce"):
# No cast of generic task output to MyFormatX
wf.add( # Generic task
other_specific_task(
in_file=wf.entry.lzout.out,
name="inner",
)
)

wf.add( # Generic task
other_specific_task(
in_file=wf.entry.lzout.out.cast(MyOtherFormatX),
name="inner",
)
)

with pytest.raises(TypeError, match="Cannot coerce"):
# No cast of generic task output to MyFormatX
wf.add(
specific_task(
in_file=wf.generic.lzout.out,
name="specific2",
in_file=wf.inner.lzout.out,
name="exit",
)
)

wf.add(
specific_task(
in_file=wf.generic.lzout.out.cast(MyFormatX),
name="specific2",
in_file=wf.inner.lzout.out.cast(MyFormatX),
name="exit",
)
)

wf.set_output(
[
("out_file", wf.specific2.lzout.out),
("out_file", wf.exit.lzout.out),
]
)

my_fspath = tmp_path / "in_file.my"
hdr_fspath = tmp_path / "in_file.hdr"
my_fspath.write_text("my-format")
hdr_fspath.write_text("my-header")
in_file = MyFormatX([my_fspath, hdr_fspath])
in_file = MyFormatX.sample()

result = wf(in_file=in_file, plugin="serial")

Expand Down
89 changes: 87 additions & 2 deletions pydra/utils/tests/utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from pathlib import Path
import typing as ty
from fileformats.generic import File
from fileformats.core.mixin import WithSeparateHeader
from fileformats.core.mixin import WithSeparateHeader, WithMagicNumber
from pydra import mark
from pydra.engine.task import ShellCommandTask
from pydra.engine import specs


class MyFormat(File):
class MyFormat(WithMagicNumber, File):
ext = ".my"
magic_number = b"MYFORMAT"


class MyHeader(File):
Expand All @@ -17,6 +20,34 @@ class MyFormatX(WithSeparateHeader, MyFormat):
header_type = MyHeader


class MyOtherFormatX(WithMagicNumber, WithSeparateHeader, File):
    """Test format with the same extension, magic number and header type as
    ``MyFormatX``.

    It derives directly from ``File`` rather than from ``MyFormat``, so it is
    neither a subclass nor a superclass of ``MyFormatX``.
    """

    ext = ".my"
    magic_number = b"MYFORMAT"
    header_type = MyHeader


@File.generate_sample_data.register
def my_format_x_generate_sample_data(
    my_format_x: MyFormatX, dest_dir: Path
) -> ty.List[Path]:
    """Write a sample ``MyFormatX`` file-set into ``dest_dir``.

    Creates ``file.my`` (starting with the ``MYFORMAT`` magic number) and a
    companion ``file.hdr`` header file.

    Returns
    -------
    list of Path
        the primary file path followed by the header file path
    """
    primary = dest_dir / "file.my"
    primary.write_bytes(b"MYFORMAT\nsome data goes here")
    header = dest_dir / "file.hdr"
    header.write_text("a: 1\nb: 2\nc: 3\n")
    return [primary, header]


@File.generate_sample_data.register
def my_other_format_generate_sample_data(
    my_other_format: MyOtherFormatX, dest_dir: Path
) -> ty.List[Path]:
    """Write a sample ``MyOtherFormatX`` file-set into ``dest_dir``.

    Creates ``file.my`` (starting with the ``MYFORMAT`` magic number) and a
    companion ``file.hdr`` header file.

    Returns
    -------
    list of Path
        the primary file path followed by the header file path
    """
    fspath = dest_dir / "file.my"
    fspath.write_bytes(b"MYFORMAT\nsome data goes here")
    # MyOtherFormatX mixes in WithSeparateHeader (header_type = MyHeader), so the
    # sample file-set needs a header next to the primary file. The name and
    # contents mirror the MyFormatX sample generator.
    # NOTE(review): assumes MyHeader uses the ".hdr" extension -- confirm
    header_fspath = dest_dir / "file.hdr"
    header_fspath.write_text("a: 1\nb: 2\nc: 3\n")
    return [fspath, header_fspath]


@mark.task
def generic_func_task(in_file: File) -> File:
    """Identity task typed on the generic ``File`` class: returns ``in_file`` as is."""
    return in_file
Expand Down Expand Up @@ -118,3 +149,57 @@ class SpecificShellTask(ShellCommandTask):
input_spec = specific_shell_input_spec
output_spec = specific_shelloutput_spec
executable = "echo"


@mark.task
def other_specific_func_task(in_file: MyOtherFormatX) -> MyOtherFormatX:
    """Identity task typed on ``MyOtherFormatX``: returns ``in_file`` as is."""
    return in_file


# Input field definitions for the shell-command variant of the "other specific"
# task, given as (name, type, metadata) tuples
other_specific_shell_input_fields = [
    (
        # The input file, typed as MyOtherFormatX
        "in_file",
        MyOtherFormatX,
        {
            "help_string": "the input file",
            "argstr": "",
            "copyfile": "copy",
            "sep": " ",
        },
    ),
    (
        # Name of the output file, templated from the "in_file" field
        "out",
        str,
        {
            "help_string": "output file name",
            "argstr": "",
            "position": -1,
            "output_file_template": "{in_file}",  # Pass through un-altered
        },
    ),
]

# Input spec assembled from the fields above, with specs.ShellSpec as its base
other_specific_shell_input_spec = specs.SpecInfo(
    name="Input", fields=other_specific_shell_input_fields, bases=(specs.ShellSpec,)
)

# Output field definitions for the shell-command variant of the "other specific"
# task: a single "out" field typed as MyOtherFormatX
other_specific_shell_output_fields = [
    (
        "out",
        MyOtherFormatX,
        {
            "help_string": "output file",
        },
    ),
]
# Output spec assembled from the fields above, with specs.ShellOutSpec as its base
other_specific_shelloutput_spec = specs.SpecInfo(
    name="Output",
    fields=other_specific_shell_output_fields,
    bases=(specs.ShellOutSpec,),
)


class OtherSpecificShellTask(ShellCommandTask):
    """Shell-command task wrapping ``echo`` whose input and output are typed as
    ``MyOtherFormatX``."""

    executable = "echo"
    input_spec = other_specific_shell_input_spec
    output_spec = other_specific_shelloutput_spec
19 changes: 14 additions & 5 deletions pydra/utils/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class TypeParser(ty.Generic[T]):
the tree of more complex nested container types. Overrides 'coercible' to enable
you to carve out exceptions, such as TypeParser(list, coercible=[(ty.Iterable, list)],
not_coercible=[(str, list)])
allow_lazy_super : bool
superclass_auto_cast : bool
Allow lazy fields to pass the type check if their types are superclasses of the
specified pattern (instead of matching or being subclasses of the pattern)
label : str
Expand All @@ -69,7 +69,7 @@ class TypeParser(ty.Generic[T]):
tp: ty.Type[T]
coercible: ty.List[ty.Tuple[TypeOrAny, TypeOrAny]]
not_coercible: ty.List[ty.Tuple[TypeOrAny, TypeOrAny]]
allow_lazy_super: bool
superclass_auto_cast: bool
label: str

COERCIBLE_DEFAULT: ty.Tuple[ty.Tuple[type, type], ...] = (
Expand Down Expand Up @@ -113,7 +113,7 @@ def __init__(
not_coercible: ty.Optional[
ty.Iterable[ty.Tuple[TypeOrAny, TypeOrAny]]
] = NOT_COERCIBLE_DEFAULT,
allow_lazy_super: bool = False,
superclass_auto_cast: bool = False,
label: str = "",
):
def expand_pattern(t):
Expand Down Expand Up @@ -142,7 +142,7 @@ def expand_pattern(t):
)
self.not_coercible = list(not_coercible) if not_coercible is not None else []
self.pattern = expand_pattern(tp)
self.allow_lazy_super = allow_lazy_super
self.superclass_auto_cast = superclass_auto_cast

def __call__(self, obj: ty.Any) -> ty.Union[T, LazyField[T]]:
"""Attempts to coerce the object to the specified type, unless the value is
Expand Down Expand Up @@ -172,7 +172,7 @@ def __call__(self, obj: ty.Any) -> ty.Union[T, LazyField[T]]:
try:
self.check_type(obj.type)
except TypeError as e:
if self.allow_lazy_super:
if self.superclass_auto_cast:
try:
# Check whether the type of the lazy field isn't a superclass of
# the type to check against, and if so, allow it due to permissive
Expand Down Expand Up @@ -492,8 +492,17 @@ def check_coercible(
explicit inclusions and exclusions set in the `coercible` and `not_coercible`
member attrs
"""
# Short-circuit the basic cases where the source and target are the same
if source is target:
return
if self.superclass_auto_cast and self.is_subclass(target, type(source)):
logger.info(
"Attempting to coerce %s into %s due to super-to-sub class coercion "
"being permitted",
source,
target,
)
return
source_origin = get_origin(source)
if source_origin is not None:
source = source_origin
Expand Down

0 comments on commit d6c3e2f

Please sign in to comment.