From dd3b814fb09c6db73e4d2a94869c36265850d145 Mon Sep 17 00:00:00 2001 From: Tom Close Date: Wed, 22 May 2024 07:44:26 +1000 Subject: [PATCH 1/9] streamlines typing error messages and fixes bug in super->subclass "auto casting" when the sub-classes are part of a union. If the super class is a super class of any of the union args then it is ok. --- pydra/utils/tests/test_typing.py | 146 ++++++++++++++++++++++--------- pydra/utils/typing.py | 60 +++++++++++-- 2 files changed, 156 insertions(+), 50 deletions(-) diff --git a/pydra/utils/tests/test_typing.py b/pydra/utils/tests/test_typing.py index b41aefd2a8..5e78fde9b0 100644 --- a/pydra/utils/tests/test_typing.py +++ b/pydra/utils/tests/test_typing.py @@ -1,6 +1,7 @@ import os import itertools import sys +import re import typing as ty from pathlib import Path import tempfile @@ -28,6 +29,17 @@ def lz(tp: ty.Type): return LazyOutField(name="foo", field="boo", type=tp) +def exc_info_matches(exc_info, match, regex=False): + if exc_info.value.__cause__ is not None: + msg = str(exc_info.value.__cause__) + else: + msg = str(exc_info.value) + if regex: + return re.match(".*" + match, msg) + else: + return match in msg + + PathTypes = ty.Union[str, os.PathLike] @@ -36,8 +48,9 @@ def test_type_check_basic1(): def test_type_check_basic2(): - with pytest.raises(TypeError, match="doesn't match any of the explicit inclusion"): + with pytest.raises(TypeError) as exc_info: TypeParser(int, coercible=[(int, float)])(lz(float)) + assert exc_info_matches(exc_info, "doesn't match any of the explicit inclusion") def test_type_check_basic3(): @@ -45,8 +58,9 @@ def test_type_check_basic3(): def test_type_check_basic4(): - with pytest.raises(TypeError, match="doesn't match any of the explicit inclusion"): + with pytest.raises(TypeError) as exc_info: TypeParser(int, coercible=[(ty.Any, float)])(lz(float)) + assert exc_info_matches(exc_info, "doesn't match any of the explicit inclusion") def test_type_check_basic5(): @@ -54,8 +68,9 @@ def test_type_check_basic5(): def test_type_check_basic6(): - with pytest.raises(TypeError, match="explicitly excluded"): + with pytest.raises(TypeError) as exc_info: TypeParser(int, coercible=None, not_coercible=[(float, int)])(lz(float)) + assert exc_info_matches(exc_info, "explicitly excluded") def test_type_check_basic7(): @@ -63,9 +78,11 @@ def test_type_check_basic7(): path_coercer(lz(Path)) - with pytest.raises(TypeError, match="doesn't match any of the explicit inclusion"): + with pytest.raises(TypeError) as exc_info: path_coercer(lz(str)) + assert exc_info_matches(exc_info, "doesn't match any of the explicit inclusion") + def test_type_check_basic8(): TypeParser(Path, coercible=[(PathTypes, PathTypes)])(lz(str)) @@ -74,7 +91,6 @@ def test_type_check_basic8(): def test_type_check_basic9(): file_coercer = TypeParser(File, coercible=[(PathTypes, File)]) - file_coercer(lz(Path)) file_coercer(lz(str)) @@ -82,8 +98,9 @@ def test_type_check_basic9(): def test_type_check_basic10(): impotent_str_coercer = TypeParser(str, coercible=[(PathTypes, File)]) - with pytest.raises(TypeError, match="doesn't match any of the explicit inclusion"): + with pytest.raises(TypeError) as exc_info: impotent_str_coercer(lz(File)) + assert exc_info_matches(exc_info, "doesn't match any of the explicit inclusion") def test_type_check_basic11(): @@ -108,12 +125,13 @@ def test_type_check_basic13(): def test_type_check_basic14(): - with pytest.raises(TypeError, match="explicitly excluded"): + with pytest.raises(TypeError) as exc_info: TypeParser( list, coercible=[(ty.Sequence, ty.Sequence)], not_coercible=[(str, ty.Sequence)], )(lz(str)) + assert exc_info_matches(exc_info, match="explicitly excluded") def test_type_check_basic15(): @@ -126,10 +144,11 @@ def test_type_check_basic15a(): def test_type_check_basic16(): - with pytest.raises( - TypeError, match="Cannot coerce to any of the union types" - ): + with pytest.raises(TypeError) as exc_info: TypeParser(ty.Union[Path, File, bool, int])(lz(float)) + assert exc_info_matches( + exc_info, match="Cannot coerce to any of the union types" + ) @pytest.mark.skipif(sys.version_info < (3, 10), reason="No UnionType < Py3.10") @@ -173,16 +192,18 @@ def test_type_check_nested7(): def test_type_check_nested7a(): - with pytest.raises(TypeError, match="Wrong number of type arguments"): + with pytest.raises(TypeError) as exc_info: TypeParser(ty.Tuple[float, float, float])(lz(ty.Tuple[int])) + assert exc_info_matches(exc_info, "Wrong number of type arguments") def test_type_check_nested8(): - with pytest.raises(TypeError, match="explicitly excluded"): + with pytest.raises(TypeError) as exc_info: TypeParser( ty.Tuple[int, ...], not_coercible=[(ty.Sequence, ty.Tuple)], )(lz(ty.List[float])) + assert exc_info_matches(exc_info, "explicitly excluded") def test_type_check_permit_superclass(): @@ -190,21 +211,25 @@ def test_type_check_permit_superclass(): TypeParser(ty.List[File])(lz(ty.List[Json])) # Permissive super class, as File is superclass of Json TypeParser(ty.List[Json], superclass_auto_cast=True)(lz(ty.List[File])) - with pytest.raises(TypeError, match="Cannot coerce"): + with pytest.raises(TypeError) as exc_info: TypeParser(ty.List[Json], superclass_auto_cast=False)(lz(ty.List[File])) + assert exc_info_matches(exc_info, "Cannot coerce") # Fails because Yaml is neither sub or super class of Json - with pytest.raises(TypeError, match="Cannot coerce"): + with pytest.raises(TypeError) as exc_info: TypeParser(ty.List[Json], superclass_auto_cast=True)(lz(ty.List[Yaml])) + assert exc_info_matches(exc_info, "Cannot coerce") def test_type_check_fail1(): - with pytest.raises(TypeError, match="Wrong number of type arguments in tuple"): + with pytest.raises(TypeError) as exc_info: TypeParser(ty.Tuple[int, int, int])(lz(ty.Tuple[float, float, float, float])) + assert exc_info_matches(exc_info, "Wrong number of type arguments in tuple") def test_type_check_fail2(): - with pytest.raises(TypeError, match="to any of the union types"): + with pytest.raises(TypeError) as exc_info: TypeParser(ty.Union[Path, File])(lz(int)) + assert exc_info_matches(exc_info, "to any of the union types") @pytest.mark.skipif(sys.version_info < (3, 10), reason="No UnionType < Py3.10") @@ -214,25 +239,29 @@ def test_type_check_fail2a(): def test_type_check_fail3(): - with pytest.raises(TypeError, match="doesn't match any of the explicit inclusion"): + with pytest.raises(TypeError) as exc_info: TypeParser(ty.Sequence, coercible=[(ty.Sequence, ty.Sequence)])( lz(ty.Dict[str, int]) ) + assert exc_info_matches(exc_info, "doesn't match any of the explicit inclusion") def test_type_check_fail4(): - with pytest.raises(TypeError, match="Cannot coerce into"): + with pytest.raises(TypeError) as exc_info: TypeParser(ty.Sequence)(lz(ty.Dict[str, int])) + assert exc_info_matches(exc_info, "Cannot coerce into") def test_type_check_fail5(): - with pytest.raises(TypeError, match=" doesn't match pattern"): + with pytest.raises(TypeError) as exc_info: TypeParser(ty.List[int])(lz(int)) + assert exc_info_matches(exc_info, " doesn't match pattern") def test_type_check_fail6(): - with pytest.raises(TypeError, match=" doesn't match pattern"): + with pytest.raises(TypeError) as exc_info: TypeParser(ty.List[ty.Dict[str, str]])(lz(ty.Tuple[int, int, int])) + assert exc_info_matches(exc_info, " doesn't match pattern") def test_type_coercion_basic(): @@ -240,8 +269,9 @@ def test_type_coercion_basic(): def test_type_coercion_basic1(): - with pytest.raises(TypeError, match="doesn't match any of the explicit inclusion"): + with pytest.raises(TypeError) as exc_info: TypeParser(float, coercible=[(ty.Any, int)])(1) + assert exc_info_matches(exc_info, "doesn't match any of the explicit inclusion") def test_type_coercion_basic2(): @@ -254,8 +284,9 @@ def test_type_coercion_basic2(): def test_type_coercion_basic3(): - with pytest.raises(TypeError, match="explicitly excluded"): + with pytest.raises(TypeError) as exc_info: TypeParser(int, coercible=[(ty.Any, ty.Any)], not_coercible=[(float, int)])(1.0) + assert exc_info_matches(exc_info, "explicitly excluded") def test_type_coercion_basic4(): @@ -263,8 +294,9 @@ def test_type_coercion_basic4(): assert path_coercer(Path("/a/path")) == Path("/a/path") - with pytest.raises(TypeError, match="doesn't match any of the explicit inclusion"): + with pytest.raises(TypeError) as exc_info: path_coercer("/a/path") + assert exc_info_matches(exc_info, "doesn't match any of the explicit inclusion") def test_type_coercion_basic5(): @@ -296,8 +328,9 @@ def test_type_coercion_basic7(a_file): def test_type_coercion_basic8(a_file): impotent_str_coercer = TypeParser(str, coercible=[(PathTypes, File)]) - with pytest.raises(TypeError, match="doesn't match any of the explicit inclusion"): + with pytest.raises(TypeError) as exc_info: impotent_str_coercer(File(a_file)) + assert exc_info_matches(exc_info, "doesn't match any of the explicit inclusion") def test_type_coercion_basic9(a_file): @@ -321,13 +354,13 @@ def test_type_coercion_basic11(): def test_type_coercion_basic12(): - with pytest.raises(TypeError, match="explicitly excluded"): + with pytest.raises(TypeError) as exc_info: TypeParser( list, coercible=[(ty.Sequence, ty.Sequence)], not_coercible=[(str, ty.Sequence)], )("a-string") - + assert exc_info_matches(exc_info, "explicitly excluded") assert TypeParser(ty.Union[Path, File, int], coercible=[(ty.Any, ty.Any)])(1.0) == 1 @@ -422,24 +455,27 @@ def test_type_coercion_nested7(): def test_type_coercion_nested8(): - with pytest.raises(TypeError, match="explicitly excluded"): + with pytest.raises(TypeError) as exc_info: TypeParser( ty.Tuple[int, ...], coercible=[(ty.Any, ty.Any)], not_coercible=[(ty.Sequence, ty.Tuple)], )([1.0, 2.0, 3.0]) + assert exc_info_matches(exc_info, "explicitly excluded") def test_type_coercion_fail1(): - with pytest.raises(TypeError, match="Incorrect number of items"): + with pytest.raises(TypeError) as exc_info: TypeParser(ty.Tuple[int, int, int], coercible=[(ty.Any, ty.Any)])( [1.0, 2.0, 3.0, 4.0] ) + assert exc_info_matches(exc_info, "Incorrect number of items") def test_type_coercion_fail2(): - with pytest.raises(TypeError, match="to any of the union types"): + with pytest.raises(TypeError) as exc_info: TypeParser(ty.Union[Path, File], coercible=[(ty.Any, ty.Any)])(1) + assert exc_info_matches(exc_info, "to any of the union types") @pytest.mark.skipif(sys.version_info < (3, 10), reason="No UnionType < Py3.10") @@ -449,25 +485,29 @@ def test_type_coercion_fail2a(): def test_type_coercion_fail3(): - with pytest.raises(TypeError, match="doesn't match any of the explicit inclusion"): + with pytest.raises(TypeError) as exc_info: TypeParser(ty.Sequence, coercible=[(ty.Sequence, ty.Sequence)])( {"a": 1, "b": 2} ) + assert exc_info_matches(exc_info, "doesn't match any of the explicit inclusion") def test_type_coercion_fail4(): - with pytest.raises(TypeError, match="Cannot coerce {'a': 1} into"): + with pytest.raises(TypeError) as exc_info: TypeParser(ty.Sequence, coercible=[(ty.Any, ty.Any)])({"a": 1}) + assert exc_info_matches(exc_info, "Cannot coerce {'a': 1} into") def test_type_coercion_fail5(): - with pytest.raises(TypeError, match="as 1 is not iterable"): + with pytest.raises(TypeError) as exc_info: TypeParser(ty.List[int], coercible=[(ty.Any, ty.Any)])(1) + assert exc_info_matches(exc_info, "as 1 is not iterable") def test_type_coercion_fail6(): - with pytest.raises(TypeError, match="is not a mapping type"): + with pytest.raises(TypeError) as exc_info: TypeParser(ty.List[ty.Dict[str, str]], coercible=[(ty.Any, ty.Any)])((1, 2, 3)) + assert exc_info_matches(exc_info, "is not a mapping type") def test_type_coercion_realistic(): @@ -490,21 +530,29 @@ def f(x: ty.List[File], y: ty.Dict[str, ty.List[File]]): TypeParser(ty.List[str])(task.lzout.a) # pylint: disable=no-member with pytest.raises( TypeError, - match="Cannot coerce into ", - ): + ) as exc_info: TypeParser(ty.List[int])(task.lzout.a) # pylint: disable=no-member + assert exc_info_matches( + exc_info, + match=r"Cannot coerce into ", + regex=True, + ) - with pytest.raises( - TypeError, match="Cannot coerce 'bad-value' into " - ): + with pytest.raises(TypeError) as exc_info: task.inputs.x = "bad-value" + assert exc_info_matches( + exc_info, match="Cannot coerce 'bad-value' into " + ) def test_check_missing_type_args(): - with pytest.raises(TypeError, match="wasn't declared with type args required"): + with pytest.raises(TypeError) as exc_info: TypeParser(ty.List[int]).check_type(list) - with pytest.raises(TypeError, match="doesn't match pattern"): + assert exc_info_matches(exc_info, "wasn't declared with type args required") + + with pytest.raises(TypeError) as exc_info: TypeParser(ty.List[int]).check_type(dict) + assert exc_info_matches(exc_info, "doesn't match pattern") def test_matches_type_union(): @@ -610,6 +658,18 @@ def test_contains_type_in_dict(): ) +def test_any_union(): + """Check that the superclass auto-cast matches if any of the union args match instead + of all""" + TypeParser(File, match_any_of_union=True).check_type(ty.Union[ty.List[File], Json]) + + +def test_union_superclass_check_type(): + """Check that the superclass auto-cast matches if any of the union args match instead + of all""" + TypeParser(ty.Union[ty.List[File], Json], superclass_auto_cast=True)(lz(File)) + + def test_type_matches(): assert TypeParser.matches([1, 2, 3], ty.List[int]) assert TypeParser.matches((1, 2, 3), ty.Tuple[int, ...]) @@ -713,7 +773,7 @@ def test_typing_cast(tmp_path, specific_task, other_specific_task): ) ) - with pytest.raises(TypeError, match="Cannot coerce"): + with pytest.raises(TypeError) as exc_info: # No cast of generic task output to MyFormatX wf.add( # Generic task other_specific_task( @@ -721,6 +781,7 @@ def test_typing_cast(tmp_path, specific_task, other_specific_task): name="inner", ) ) + assert exc_info_matches(exc_info, "Cannot coerce") wf.add( # Generic task other_specific_task( @@ -729,7 +790,7 @@ def test_typing_cast(tmp_path, specific_task, other_specific_task): ) ) - with pytest.raises(TypeError, match="Cannot coerce"): + with pytest.raises(TypeError) as exc_info: # No cast of generic task output to MyFormatX wf.add( specific_task( @@ -737,6 +798,7 @@ def test_typing_cast(tmp_path, specific_task, other_specific_task): name="exit", ) ) + assert exc_info_matches(exc_info, "Cannot coerce") wf.add( specific_task( diff --git a/pydra/utils/typing.py b/pydra/utils/typing.py index c765b1339c..531c7ecf9d 100644 --- a/pydra/utils/typing.py +++ b/pydra/utils/typing.py @@ -70,6 +70,8 @@ class TypeParser(ty.Generic[T]): label : str the label to be used to identify the type parser in error messages. Especially useful when TypeParser is used as a converter in attrs.fields + match_any_of_union : bool + match if any of the options in the union are a subclass (but not necessarily all) """ tp: ty.Type[T] @@ -77,6 +79,7 @@ class TypeParser(ty.Generic[T]): not_coercible: ty.List[ty.Tuple[TypeOrAny, TypeOrAny]] superclass_auto_cast: bool label: str + match_any_of_union: bool COERCIBLE_DEFAULT: ty.Tuple[ty.Tuple[type, type], ...] = ( ( @@ -121,6 +124,7 @@ def __init__( ] = NOT_COERCIBLE_DEFAULT, superclass_auto_cast: bool = False, label: str = "", + match_any_of_union: bool = False, ): def expand_pattern(t): """Recursively expand the type arguments of the target type in nested tuples""" @@ -151,6 +155,7 @@ def expand_pattern(t): self.not_coercible = list(not_coercible) if not_coercible is not None else [] self.pattern = expand_pattern(tp) self.superclass_auto_cast = superclass_auto_cast + self.match_any_of_union = match_any_of_union def __call__(self, obj: ty.Any) -> ty.Union[T, LazyField[T]]: """Attempts to coerce the object to the specified type, unless the value is @@ -185,9 +190,15 @@ def __call__(self, obj: ty.Any) -> ty.Union[T, LazyField[T]]: # Check whether the type of the lazy field isn't a superclass of # the type to check against, and if so, allow it due to permissive # typing rules. - TypeParser(obj.type).check_type(self.tp) + TypeParser(obj.type, match_any_of_union=True).check_type( + self.tp + ) except TypeError: - raise e + raise TypeError( + f"Incorrect type for lazy field{self.label_str}: {obj.type!r} " + f"is not a subclass or superclass of {self.tp} (and will not " + "be able to be coerced to one that is)" + ) from e else: logger.info( "Connecting lazy field %s to %s%s via permissive typing that " @@ -197,12 +208,22 @@ def __call__(self, obj: ty.Any) -> ty.Union[T, LazyField[T]]: self.label_str, ) else: - raise e + raise TypeError( + f"Incorrect type for lazy field{self.label_str}: {obj.type!r} " + f"is not a subclass of {self.tp} (and will not be able to be " + "coerced to one that is)" + ) from e coerced = obj # type: ignore elif isinstance(obj, StateArray): coerced = StateArray(self(o) for o in obj) # type: ignore[assignment] else: - coerced = self.coerce(obj) + try: + coerced = self.coerce(obj) + except TypeError as e: + raise TypeError( + f"Incorrect type for field{self.label_str}: {obj!r} is not of type " + f"{self.tp} (and cannot be coerced to it)" + ) from e return coerced def coerce(self, object_: ty.Any) -> T: @@ -406,12 +427,31 @@ def check_basic(tp, target): # Note that we are deliberately more permissive than typical type-checking # here, allowing parents of the target type as well as children, # to avoid users having to cast from loosely typed tasks to strict ones + if self.match_any_of_union and get_origin(tp) is ty.Union: + reasons = [] + tp_args = get_args(tp) + for tp_arg in tp_args: + if self.is_subclass(tp_arg, target): + return + try: + self.check_coercible(tp_arg, target) + except TypeError as e: + reasons.append(e) + else: + return + if reasons: + raise TypeError( + f"Cannot coerce any union args {tp_arg} to {target}" + f"{self.label_str}:\n\n" + + "\n\n".join(f"{a} -> {e}" for a, e in zip(tp_args, reasons)) + ) if not self.is_subclass(tp, target): self.check_coercible(tp, target) def check_union(tp, pattern_args): if get_origin(tp) in UNION_TYPES: - for tp_arg in get_args(tp): + tp_args = get_args(tp) + for tp_arg in tp_args: reasons = [] for pattern_arg in pattern_args: try: @@ -421,11 +461,15 @@ def check_union(tp, pattern_args): else: reasons = None break + if self.match_any_of_union and len(reasons) < len(tp_args): + # Just need one of the union args to match + return if reasons: + determiner = "any" if self.match_any_of_union else "all" raise TypeError( - f"Cannot coerce {tp} to " - f"ty.Union[{', '.join(str(a) for a in pattern_args)}]{self.label_str}, " - f"because {tp_arg} cannot be coerced to any of its args:\n\n" + f"Cannot coerce {tp} to ty.Union[" + f"{', '.join(str(a) for a in pattern_args)}]{self.label_str}, " + f"because {tp_arg} cannot be coerced to {determiner} of its args:\n\n" + "\n\n".join( f"{a} -> {e}" for a, e in zip(pattern_args, reasons) ) From 2d840be81e90f8abcb37cf374449f012e4af5692 Mon Sep 17 00:00:00 2001 From: Tom Close Date: Wed, 22 May 2024 07:44:26 +1000 Subject: [PATCH 2/9] touched up error message --- pydra/engine/specs.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pydra/engine/specs.py b/pydra/engine/specs.py index bbfbd57941..a2e3651779 100644 --- a/pydra/engine/specs.py +++ b/pydra/engine/specs.py @@ -694,7 +694,8 @@ def __getattr__(self, name): raise AttributeError(f"{name} hasn't been set yet") if name not in self._field_names: raise AttributeError( - f"Task {self._task.name} has no {self._attr_type} attribute {name}" + f"Task '{self._task.name}' has no {self._attr_type} attribute '{name}', " + "available: '" + "', '".join(self._field_names) + "'" ) type_ = self._get_type(name) splits = self._get_task_splits() From 5a7955e6306d6f15c82ee6642cf3f15af43cac15 Mon Sep 17 00:00:00 2001 From: Tom Close Date: Wed, 22 May 2024 07:44:26 +1000 Subject: [PATCH 3/9] fixes issues with super->sub-class auto-cast and handles MultiInputObj coercion --- pydra/utils/typing.py | 72 +++++++++++++++++++++++++++++++++---------- 1 file changed, 55 insertions(+), 17 deletions(-) diff --git a/pydra/utils/typing.py b/pydra/utils/typing.py index 531c7ecf9d..32968001c9 100644 --- a/pydra/utils/typing.py +++ b/pydra/utils/typing.py @@ -366,7 +366,18 @@ def coerce_obj(obj, type_): f"Cannot coerce {obj!r} into {type_}{msg}{self.label_str}" ) from e - return expand_and_coerce(object_, self.pattern) + # Special handling for MultiInputObjects (which are annoying) + if isinstance(self.pattern, tuple) and self.pattern[0] == MultiInputObj: + try: + self.check_coercible(object_, self.pattern[1][0]) + except TypeError: + pass + else: + obj = [object_] + else: + obj = object_ + + return expand_and_coerce(obj, self.pattern) def check_type(self, type_: ty.Type[ty.Any]): """Checks the given type to see whether it matches or is a subtype of the @@ -413,7 +424,7 @@ def expand_and_check(tp, pattern: ty.Union[type, tuple]): f"{self.pattern}{self.label_str}" ) tp_args = get_args(tp) - self.check_coercible(tp_origin, pattern_origin) + self.check_type_coercible(tp_origin, pattern_origin) if issubclass(pattern_origin, ty.Mapping): return check_mapping(tp_args, pattern_args) if issubclass(pattern_origin, tuple): @@ -446,7 +457,7 @@ def check_basic(tp, target): + "\n\n".join(f"{a} -> {e}" for a, e in zip(tp_args, reasons)) ) if not self.is_subclass(tp, target): - self.check_coercible(tp, target) + self.check_type_coercible(tp, target) def check_union(tp, pattern_args): if get_origin(tp) in UNION_TYPES: @@ -526,19 +537,46 @@ def check_sequence(tp_args, pattern_args): for arg in tp_args: expand_and_check(arg, pattern_args[0]) - return expand_and_check(type_, self.pattern) + # Special handling for MultiInputObjects (which are annoying) + if isinstance(self.pattern, tuple) and self.pattern[0] == MultiInputObj: + pattern = (ty.Union, [self.pattern[1][0], (ty.List, self.pattern[1])]) + else: + pattern = self.pattern + return expand_and_check(type_, pattern) + + def check_coercible(self, source: ty.Any, target: ty.Union[type, ty.Any]): + """Checks whether the source object is coercible to the target type given the coercion + rules defined in the `coercible` and `not_coercible` attrs + + Parameters + ---------- + source : object + the object to be coerced + target : type or typing.Any + the target type for the object to be coerced to + + Raises + ------ + TypeError + If the object cannot be coerced into the target type depending on the explicit + inclusions and exclusions set in the `coercible` and `not_coercible` member attrs + """ + self.check_type_coercible(type(source), target, source_repr=repr(source)) - def check_coercible( - self, source: ty.Union[object, type], target: ty.Union[type, ty.Any] + def check_type_coercible( + self, + source: ty.Union[type, ty.Any], + target: ty.Union[type, ty.Any], + source_repr: ty.Optional[str] = None, ): - """Checks whether the source object or type is coercible to the target type + """Checks whether the source type is coercible to the target type given the coercion rules defined in the `coercible` and `not_coercible` attrs Parameters ---------- - source : object or type - source object or type to be coerced - target : type or ty.Any + source : type or typing.Any + source type to be coerced + target : type or typing.Any target type for the source to be coerced to Raises @@ -548,10 +586,12 @@ def check_coercible( explicit inclusions and exclusions set in the `coercible` and `not_coercible` member attrs """ + if source_repr is None: + source_repr = repr(source) # Short-circuit the basic cases where the source and target are the same if source is target: return - if self.superclass_auto_cast and self.is_subclass(target, type(source)): + if self.superclass_auto_cast and self.is_subclass(target, source): logger.info( "Attempting to coerce %s into %s due to super-to-sub class coercion " "being permitted", @@ -563,13 +603,11 @@ def check_coercible( if source_origin is not None: source = source_origin - source_check = self.is_subclass if inspect.isclass(source) else self.is_instance - def matches_criteria(criteria): return [ (src, tgt) for src, tgt in criteria - if source_check(source, src) and self.is_subclass(target, tgt) + if self.is_subclass(source, src) and self.is_subclass(target, tgt) ] def type_name(t): @@ -580,7 +618,7 @@ def type_name(t): if not matches_criteria(self.coercible): raise TypeError( - f"Cannot coerce {repr(source)} into {target}{self.label_str} as the " + f"Cannot coerce {source_repr} into {target}{self.label_str} as the " "coercion doesn't match any of the explicit inclusion criteria: " + ", ".join( f"{type_name(s)} -> {type_name(t)}" for s, t in self.coercible @@ -589,7 +627,7 @@ def type_name(t): matches_not_coercible = matches_criteria(self.not_coercible) if matches_not_coercible: raise TypeError( - f"Cannot coerce {repr(source)} into {target}{self.label_str} as it is explicitly " + f"Cannot coerce {source_repr} into {target}{self.label_str} as it is explicitly " "excluded by the following coercion criteria: " + ", ".join( f"{type_name(s)} -> {type_name(t)}" @@ -683,7 +721,7 @@ def is_instance( if inspect.isclass(obj): return candidate is type if issubtype(type(obj), candidate) or ( - type(obj) is dict and candidate is ty.Mapping + type(obj) is dict and candidate is ty.Mapping # noqa: E721 ): return True else: From 0eed44a95c5f8e3c1e152aaee394005e2190a662 Mon Sep 17 00:00:00 2001 From: Tom Close Date: Wed, 22 May 2024 07:45:02 +1000 Subject: [PATCH 4/9] added unittests for multi_input_obj coercion --- pydra/utils/tests/test_typing.py | 64 +++++++++++++++++++++++++++++++- pydra/utils/typing.py | 58 ++++++++++++++++++++--------- 2 files changed, 103 insertions(+), 19 deletions(-) diff --git a/pydra/utils/tests/test_typing.py b/pydra/utils/tests/test_typing.py index 5e78fde9b0..df87a87f2c 100644 --- a/pydra/utils/tests/test_typing.py +++ b/pydra/utils/tests/test_typing.py @@ -7,7 +7,7 @@ import tempfile import pytest from pydra import mark -from ...engine.specs import File, LazyOutField +from ...engine.specs import File, LazyOutField, MultiInputObj from ..typing import TypeParser from pydra import Workflow from fileformats.application import Json, Yaml, Xml @@ -249,7 +249,7 @@ def test_type_check_fail3(): def test_type_check_fail4(): with pytest.raises(TypeError) as exc_info: TypeParser(ty.Sequence)(lz(ty.Dict[str, int])) - assert exc_info_matches(exc_info, "Cannot coerce into") + assert exc_info_matches(exc_info, "Cannot coerce .*(d|D)ict.* into") def test_type_check_fail5(): @@ -1043,3 +1043,63 @@ def test_type_is_instance11(): @pytest.mark.skipif(sys.version_info < (3, 10), reason="No UnionType < Py3.10") def test_type_is_instance11a(): assert not TypeParser.is_instance(None, int | str) + + +def test_multi_input_obj_coerce1(): + assert TypeParser(MultiInputObj[str])("a") == ["a"] + + +def test_multi_input_obj_coerce2(): + assert TypeParser(MultiInputObj[str])(["a"]) == ["a"] + + +def test_multi_input_obj_coerce3(): + assert TypeParser(MultiInputObj[ty.List[str]])(["a"]) == [["a"]] + + +def test_multi_input_obj_coerce3a(): + assert TypeParser(MultiInputObj[ty.Union[int, ty.List[str]]])(["a"]) == [["a"]] + + +def test_multi_input_obj_coerce3b(): + assert TypeParser(MultiInputObj[ty.Union[int, ty.List[str]]])([["a"]]) == [["a"]] + + +def test_multi_input_obj_coerce4(): + assert TypeParser(MultiInputObj[ty.Union[int, ty.List[str]]])([1]) == [1] + + +def test_multi_input_obj_coerce4a(): + with pytest.raises(TypeError): + TypeParser(MultiInputObj[ty.Union[int, ty.List[str]]])([[1]]) + + +def test_multi_input_obj_check_type1(): + TypeParser(MultiInputObj[str])(lz(str)) + + +def test_multi_input_obj_check_type2(): + TypeParser(MultiInputObj[str])(lz(ty.List[str])) + + +def test_multi_input_obj_check_type3(): + TypeParser(MultiInputObj[ty.List[str]])(lz(ty.List[str])) + + +def test_multi_input_obj_check_type3a(): + TypeParser(MultiInputObj[ty.Union[int, ty.List[str]]])(lz(ty.List[str])) + + +def test_multi_input_obj_check_type3b(): + TypeParser(MultiInputObj[ty.Union[int, ty.List[str]]])(lz(ty.List[ty.List[str]])) + + +def test_multi_input_obj_check_type4(): + TypeParser(MultiInputObj[ty.Union[int, ty.List[str]]])(lz(ty.List[int])) + + +def test_multi_input_obj_check_type4a(): + with pytest.raises(TypeError): + TypeParser(MultiInputObj[ty.Union[int, ty.List[str]]])( + lz(ty.List[ty.List[int]]) + ) diff --git a/pydra/utils/typing.py b/pydra/utils/typing.py index 32968001c9..136bfd443e 100644 --- a/pydra/utils/typing.py +++ b/pydra/utils/typing.py @@ -2,6 +2,7 @@ import inspect from pathlib import Path import os +from copy import copy import sys import types import typing as ty @@ -13,6 +14,7 @@ MultiInputObj, MultiOutputObj, ) +from ..utils import add_exc_note from fileformats import field try: @@ -366,18 +368,26 @@ def coerce_obj(obj, type_): f"Cannot coerce {obj!r} into {type_}{msg}{self.label_str}" ) from e - # Special handling for MultiInputObjects (which are annoying) - if isinstance(self.pattern, tuple) and self.pattern[0] == MultiInputObj: - try: - self.check_coercible(object_, self.pattern[1][0]) - except TypeError: - pass + try: + return expand_and_coerce(object_, self.pattern) + except TypeError as e: + # Special handling for MultiInputObjects (which are annoying) + if isinstance(self.pattern, tuple) and self.pattern[0] == MultiInputObj: + # Attempt to coerce the object into arg type of the MultiInputObj first, + # and if that fails, try to coerce it into a list of the arg type + inner_type_parser = copy(self) + inner_type_parser.pattern = self.pattern[1][0] + try: + return [inner_type_parser.coerce(object_)] + except TypeError: + add_exc_note( + e, + "Also failed to coerce to the arg-type of the MultiInputObj " + f"({self.pattern[1][0]})", + ) + raise e else: - obj = [object_] - else: - obj = object_ - - return expand_and_coerce(obj, self.pattern) + raise e def check_type(self, type_: ty.Type[ty.Any]): """Checks the given type to see whether it matches or is a subtype of the @@ -537,12 +547,26 @@ def check_sequence(tp_args, pattern_args): for arg in tp_args: expand_and_check(arg, pattern_args[0]) - # Special handling for MultiInputObjects (which are annoying) - if isinstance(self.pattern, tuple) and self.pattern[0] == MultiInputObj: - pattern = (ty.Union, [self.pattern[1][0], (ty.List, self.pattern[1])]) - else: - pattern = self.pattern - return expand_and_check(type_, pattern) + try: + return expand_and_check(type_, self.pattern) + except TypeError as e: + # Special handling for MultiInputObjects (which are annoying) + if isinstance(self.pattern, tuple) and self.pattern[0] == MultiInputObj: + # Attempt to coerce the object into arg type of the MultiInputObj first, + # and if that fails, try to coerce it into a list of the arg type + inner_type_parser = copy(self) + inner_type_parser.pattern = self.pattern[1][0] + try: + inner_type_parser.check_type(type_) + except TypeError: + add_exc_note( + e, + "Also failed to coerce to the arg-type of the MultiInputObj " + f"({self.pattern[1][0]})", + ) + raise e + else: + raise e def check_coercible(self, source: ty.Any, target: ty.Union[type, ty.Any]): """Checks whether the source object is coercible to the target type given the coercion From 1bec68ccc0d311b182e49465e4f8f38bd85ab18f Mon Sep 17 00:00:00 2001 From: Tom Close Date: Wed, 22 May 2024 09:06:33 +1000 Subject: [PATCH 5/9] fixed up unittests --- pydra/engine/tests/test_specs.py | 5 ++++- pydra/engine/tests/test_workflow.py | 18 +++++++++-------- pydra/utils/__init__.py | 2 +- pydra/utils/misc.py | 12 ++++++++++++ pydra/utils/tests/test_typing.py | 30 ++++++++++++----------------- 5 files changed, 39 insertions(+), 28 deletions(-) diff --git a/pydra/engine/tests/test_specs.py b/pydra/engine/tests/test_specs.py index 77a0f690b7..8221751d01 100644 --- a/pydra/engine/tests/test_specs.py +++ b/pydra/engine/tests/test_specs.py @@ -124,7 +124,10 @@ def test_lazy_getvale(): lf = LazyIn(task=tn) with pytest.raises(Exception) as excinfo: lf.inp_c - assert str(excinfo.value) == "Task tn has no input attribute inp_c" + assert ( + str(excinfo.value) + == "Task 'tn' has no input attribute 'inp_c', available: 'inp_a', 'inp_b'" + ) def test_input_file_hash_1(tmp_path): diff --git a/pydra/engine/tests/test_workflow.py b/pydra/engine/tests/test_workflow.py index 598021c832..c6aab6544f 100644 --- a/pydra/engine/tests/test_workflow.py +++ b/pydra/engine/tests/test_workflow.py @@ -37,6 +37,7 @@ from ..core import Workflow from ... import mark from ..specs import SpecInfo, BaseSpec, ShellSpec +from pydra.utils import exc_info_matches def test_wf_no_input_spec(): @@ -102,13 +103,15 @@ def test_wf_dict_input_and_output_spec(): wf.inputs.a = "any-string" wf.inputs.b = {"foo": 1, "bar": False} - with pytest.raises(TypeError, match="Cannot coerce 1.0 into "): + with pytest.raises(TypeError) as exc_info: wf.inputs.a = 1.0 - with pytest.raises( - TypeError, - match=("Could not coerce object, 'bad-value', to any of the union types "), - ): + assert exc_info_matches(exc_info, "Cannot coerce 1.0 into ") + + with pytest.raises(TypeError) as exc_info: wf.inputs.b = {"foo": 1, "bar": "bad-value"} + assert exc_info_matches( + exc_info, "Could not coerce object, 'bad-value', to any of the union types" + ) result = wf() assert result.output.a == "any-string" @@ -5002,14 +5005,13 @@ def test_wf_input_output_typing(): output_spec={"alpha": int, "beta": ty.List[int]}, ) - with pytest.raises( - TypeError, match="Cannot coerce into " - ): + with pytest.raises(TypeError) as exc_info: list_mult_sum( scalar=wf.lzin.y, in_list=wf.lzin.y, name="A", ) + exc_info_matches(exc_info, "Cannot coerce into ") wf.add( # Split over workflow input "x" on "scalar" input list_mult_sum( diff --git a/pydra/utils/__init__.py b/pydra/utils/__init__.py index 9008779e27..cfde94dbf8 100644 --- a/pydra/utils/__init__.py +++ b/pydra/utils/__init__.py @@ -1 +1 @@ -from .misc import user_cache_dir, add_exc_note # noqa: F401 +from .misc import user_cache_dir, add_exc_note, exc_info_matches # noqa: F401 diff --git a/pydra/utils/misc.py b/pydra/utils/misc.py index 9a40769c9d..45b6a5c3ba 100644 --- a/pydra/utils/misc.py +++ b/pydra/utils/misc.py @@ -1,4 +1,5 @@ from pathlib import Path +import re import platformdirs from pydra._version import __version__ @@ -31,3 +32,14 @@ def add_exc_note(e: Exception, note: str) -> Exception: else: e.args = (e.args[0] + "\n" + note,) return e + + +def exc_info_matches(exc_info, match, regex=False): + if exc_info.value.__cause__ is not None: + msg = str(exc_info.value.__cause__) + else: + msg = str(exc_info.value) + if regex: + return re.match(".*" + match, msg) + else: + return match in msg diff --git a/pydra/utils/tests/test_typing.py b/pydra/utils/tests/test_typing.py index df87a87f2c..45d1ef46fe 100644 --- a/pydra/utils/tests/test_typing.py +++ b/pydra/utils/tests/test_typing.py @@ -1,7 +1,6 @@ import os import itertools import sys -import re import typing as ty from pathlib import Path import tempfile @@ -22,6 +21,7 @@ MyOtherFormatX, MyHeader, ) +from pydra.utils import exc_info_matches def lz(tp: ty.Type): @@ -29,17 +29,6 @@ def lz(tp: ty.Type): return LazyOutField(name="foo", field="boo", type=tp) -def exc_info_matches(exc_info, match, regex=False): - if exc_info.value.__cause__ is not None: - msg = str(exc_info.value.__cause__) - else: - msg = str(exc_info.value) - if regex: - return re.match(".*" + match, msg) - else: - return match in msg - - PathTypes = ty.Union[str, os.PathLike] @@ -154,7 +143,8 @@ def test_type_check_basic16(): @pytest.mark.skipif(sys.version_info < (3, 10), reason="No UnionType < Py3.10") def test_type_check_basic16a(): with pytest.raises( - TypeError, match="Cannot coerce to any of the union types" + TypeError, + match="Incorrect type for lazy field: is not a subclass of", ): TypeParser(Path | File | bool | int)(lz(float)) @@ -234,7 +224,7 @@ def test_type_check_fail2(): @pytest.mark.skipif(sys.version_info < (3, 10), reason="No UnionType < Py3.10") def test_type_check_fail2a(): - with pytest.raises(TypeError, match="to any of the union types"): + with pytest.raises(TypeError, match="Incorrect type for lazy field: "): TypeParser(Path | File)(lz(int)) @@ -249,7 +239,10 @@ def test_type_check_fail3(): def test_type_check_fail4(): with pytest.raises(TypeError) as exc_info: TypeParser(ty.Sequence)(lz(ty.Dict[str, int])) - assert exc_info_matches(exc_info, "Cannot coerce .*(d|D)ict.* into") + assert exc_info_matches( + exc_info, + "Cannot coerce typing.Dict[str, int] into ", + ) def test_type_check_fail5(): @@ -366,13 +359,13 @@ def test_type_coercion_basic12(): @pytest.mark.skipif(sys.version_info < (3, 10), reason="No UnionType < Py3.10") def test_type_coercion_basic12a(): - with pytest.raises(TypeError, match="explicitly excluded"): + with pytest.raises(TypeError) as exc_info: TypeParser( list, coercible=[(ty.Sequence, ty.Sequence)], not_coercible=[(str, ty.Sequence)], )("a-string") - + assert exc_info_matches(exc_info, "explicitly excluded") assert TypeParser(Path | File | int, coercible=[(ty.Any, ty.Any)])(1.0) == 1 @@ -480,8 +473,9 @@ def test_type_coercion_fail2(): @pytest.mark.skipif(sys.version_info < (3, 10), reason="No UnionType < Py3.10") def test_type_coercion_fail2a(): - with pytest.raises(TypeError, match="to any of the union types"): + with pytest.raises(TypeError) as exc_info: TypeParser(Path | File, coercible=[(ty.Any, ty.Any)])(1) + assert exc_info_matches(exc_info, "to any of the union types") def test_type_coercion_fail3(): From ca0833cb6ab403c93dd2aaa0686840639f45821a Mon Sep 17 00:00:00 2001 From: Tom Close Date: Thu, 30 May 2024 10:13:32 +0930 Subject: [PATCH 6/9] Update pydra/utils/typing.py Co-authored-by: Chris Markiewicz --- pydra/utils/typing.py | 29 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/pydra/utils/typing.py b/pydra/utils/typing.py index 136bfd443e..e40f928047 100644 --- a/pydra/utils/typing.py +++ b/pydra/utils/typing.py @@ -551,21 +551,20 @@ def check_sequence(tp_args, pattern_args): return expand_and_check(type_, self.pattern) except TypeError as e: # Special handling for MultiInputObjects (which are annoying) - if isinstance(self.pattern, tuple) and self.pattern[0] == MultiInputObj: - # Attempt to coerce the object into arg type of the MultiInputObj first, - # and if that fails, try to coerce it into a list of the arg type - inner_type_parser = copy(self) - inner_type_parser.pattern = self.pattern[1][0] - try: - inner_type_parser.check_type(type_) - except TypeError: - add_exc_note( - e, - "Also failed to coerce to the arg-type of the MultiInputObj " - f"({self.pattern[1][0]})", - ) - raise e - else: + if not isinstance(self.pattern, tuple) or self.pattern[0] != MultiInputObj: + raise e + # Attempt to coerce the object into arg type of the MultiInputObj first, + # and if that fails, try to coerce it into a list of the arg type + inner_type_parser = copy(self) + inner_type_parser.pattern = self.pattern[1][0] + try: + inner_type_parser.check_type(type_) + except TypeError: + add_exc_note( + e, + "Also failed to coerce to the arg-type of the MultiInputObj " + f"({self.pattern[1][0]})", + ) raise e def check_coercible(self, source: ty.Any, target: ty.Union[type, ty.Any]): From ae0422a2f7451bf7378634054821fed96f519046 Mon Sep 17 00:00:00 2001 From: Tom Close Date: Thu, 30 May 2024 10:13:45 +0930 Subject: [PATCH 7/9] Update pydra/utils/tests/test_typing.py Co-authored-by: Chris Markiewicz --- pydra/utils/tests/test_typing.py | 35 ++++++++++++-------------------- 1 file changed, 13 insertions(+), 22 deletions(-) diff --git a/pydra/utils/tests/test_typing.py b/pydra/utils/tests/test_typing.py index 45d1ef46fe..057c44091b 100644 --- a/pydra/utils/tests/test_typing.py +++ b/pydra/utils/tests/test_typing.py @@ -1039,28 +1039,19 @@ def test_type_is_instance11a(): assert not TypeParser.is_instance(None, int | str) -def test_multi_input_obj_coerce1(): - assert TypeParser(MultiInputObj[str])("a") == ["a"] - - -def test_multi_input_obj_coerce2(): - assert TypeParser(MultiInputObj[str])(["a"]) == ["a"] - - -def test_multi_input_obj_coerce3(): - assert TypeParser(MultiInputObj[ty.List[str]])(["a"]) == [["a"]] - - -def test_multi_input_obj_coerce3a(): - assert TypeParser(MultiInputObj[ty.Union[int, ty.List[str]]])(["a"]) == [["a"]] - - -def test_multi_input_obj_coerce3b(): - assert TypeParser(MultiInputObj[ty.Union[int, ty.List[str]]])([["a"]]) == [["a"]] - - -def test_multi_input_obj_coerce4(): - assert TypeParser(MultiInputObj[ty.Union[int, ty.List[str]]])([1]) == [1] +@pytest.mark.parametrize( + ("typ", "obj", "result"), + [ + (MultiInputObj[str], "a", ["a"]), + (MultiInputObj[str], ["a"], ["a"]), + (MultiInputObj[ty.List[str]], ["a"], [["a"]]), + (MultiInputObj[ty.Union[int, ty.List[str]]], ["a"], [["a"]]), + (MultiInputObj[ty.Union[int, ty.List[str]]], [["a"]], [["a"]]), + (MultiInputObj[ty.Union[int, ty.List[str]]], [1], [1]), + ] +) +def test_multi_input_obj_coerce(typ, obj, result): + assert TypeParser(typ)(obj) == result def test_multi_input_obj_coerce4a(): From 4c00389cdc02aca3607804180fbcc208d733476b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 30 May 2024 00:43:58 +0000 Subject: [PATCH 8/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pydra/utils/tests/test_typing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pydra/utils/tests/test_typing.py b/pydra/utils/tests/test_typing.py index 057c44091b..ec6ed7ca66 100644 --- a/pydra/utils/tests/test_typing.py +++ b/pydra/utils/tests/test_typing.py @@ -1044,11 +1044,11 @@ def test_type_is_instance11a(): [ (MultiInputObj[str], "a", ["a"]), (MultiInputObj[str], ["a"], ["a"]), - (MultiInputObj[ty.List[str]], ["a"], [["a"]]), + (MultiInputObj[ty.List[str]], ["a"], [["a"]]), (MultiInputObj[ty.Union[int, ty.List[str]]], ["a"], [["a"]]), (MultiInputObj[ty.Union[int, ty.List[str]]], [["a"]], [["a"]]), (MultiInputObj[ty.Union[int, ty.List[str]]], [1], [1]), - ] + ], ) def test_multi_input_obj_coerce(typ, obj, result): assert TypeParser(typ)(obj) == result From ad28ae57f9423601991a7491e1ffe175476d2c54 Mon Sep 17 00:00:00 2001 From: Tom Close Date: Thu, 30 May 2024 11:45:51 +0930 Subject: [PATCH 9/9] parametrized a few ranges of unittests --- pydra/utils/tests/test_typing.py | 234 ++++++++++++------------------- 1 file changed, 89 insertions(+), 145 deletions(-) diff --git a/pydra/utils/tests/test_typing.py b/pydra/utils/tests/test_typing.py index ec6ed7ca66..f83eedbd8c 100644 --- a/pydra/utils/tests/test_typing.py +++ b/pydra/utils/tests/test_typing.py @@ -75,6 +75,9 @@ def test_type_check_basic7(): def test_type_check_basic8(): TypeParser(Path, coercible=[(PathTypes, PathTypes)])(lz(str)) + + +def test_type_check_basic8a(): TypeParser(str, coercible=[(PathTypes, PathTypes)])(lz(Path)) @@ -94,6 +97,9 @@ def test_type_check_basic10(): def test_type_check_basic11(): TypeParser(str, coercible=[(PathTypes, PathTypes)])(lz(File)) + + +def test_type_check_basic11a(): TypeParser(File, coercible=[(PathTypes, PathTypes)])(lz(str)) @@ -655,12 +661,15 @@ def test_contains_type_in_dict(): def test_any_union(): """Check that the superclass auto-cast matches if any of the union args match instead of all""" + # The Json type within the Union matches File as it is a subclass as `match_any_of_union` + # is set to True. Otherwise, all types within the Union would have to match TypeParser(File, match_any_of_union=True).check_type(ty.Union[ty.List[File], Json]) def test_union_superclass_check_type(): """Check that the superclass auto-cast matches if any of the union args match instead of all""" + # In this case, File matches Json due to the `superclass_auto_cast=True` flag being set TypeParser(ty.Union[ty.List[File], Json], superclass_auto_cast=True)(lz(File)) @@ -818,20 +827,42 @@ def test_typing_cast(tmp_path, specific_task, other_specific_task): assert out_file.header.parent != in_file.header.parent -def test_type_is_subclass1(): - assert TypeParser.is_subclass(ty.Type[File], type) - - -def test_type_is_subclass2(): - assert not TypeParser.is_subclass(ty.Type[File], ty.Type[Json]) - - -def test_type_is_subclass3(): - assert TypeParser.is_subclass(ty.Type[Json], ty.Type[File]) +@pytest.mark.parametrize( + ("sub", "super"), + [ + (ty.Type[File], type), + (ty.Type[Json], ty.Type[File]), + (ty.Union[Json, Yaml], ty.Union[Json, Yaml, Xml]), + (Json, ty.Union[Json, Yaml]), + (ty.List[int], list), + (None, ty.Union[int, None]), + (ty.Tuple[int, None], ty.Tuple[int, None]), + (None, None), + (None, type(None)), + (type(None), None), + (type(None), type(None)), + (type(None), type(None)), + ], +) +def test_subclass(sub, super): + assert TypeParser.is_subclass(sub, super) -def test_union_is_subclass1(): - assert TypeParser.is_subclass(ty.Union[Json, Yaml], ty.Union[Json, Yaml, Xml]) +@pytest.mark.parametrize( + ("sub", "super"), + [ + (ty.Type[File], ty.Type[Json]), + (ty.Union[Json, Yaml, Xml], ty.Union[Json, Yaml]), + (ty.Union[Json, Yaml], Json), + (list, ty.List[int]), + (ty.List[float], ty.List[int]), + (None, ty.Union[int, float]), + (None, int), + (int, None), + ], +) +def test_not_subclass(sub, super): + assert not TypeParser.is_subclass(sub, super) @pytest.mark.skipif(sys.version_info < (3, 10), reason="No UnionType < Py3.10") @@ -844,18 +875,11 @@ def test_union_is_subclass1b(): assert TypeParser.is_subclass(Json | Yaml, ty.Union[Json, Yaml, Xml]) -## Up to here! - - @pytest.mark.skipif(sys.version_info < (3, 10), reason="No UnionType < Py3.10") def test_union_is_subclass1c(): assert TypeParser.is_subclass(ty.Union[Json, Yaml], Json | Yaml | Xml) -def test_union_is_subclass2(): - assert not TypeParser.is_subclass(ty.Union[Json, Yaml, Xml], ty.Union[Json, Yaml]) - - @pytest.mark.skipif(sys.version_info < (3, 10), reason="No UnionType < Py3.10") def test_union_is_subclass2a(): assert not TypeParser.is_subclass(Json | Yaml | Xml, Json | Yaml) @@ -871,86 +895,26 @@ def test_union_is_subclass2c(): assert not TypeParser.is_subclass(Json | Yaml | Xml, ty.Union[Json, Yaml]) -def test_union_is_subclass3(): - assert TypeParser.is_subclass(Json, ty.Union[Json, Yaml]) - - @pytest.mark.skipif(sys.version_info < (3, 10), reason="No UnionType < Py3.10") def test_union_is_subclass3a(): assert TypeParser.is_subclass(Json, Json | Yaml) -def test_union_is_subclass4(): - assert not TypeParser.is_subclass(ty.Union[Json, Yaml], Json) - - @pytest.mark.skipif(sys.version_info < (3, 10), reason="No UnionType < Py3.10") def test_union_is_subclass4a(): assert not TypeParser.is_subclass(Json | Yaml, Json) -def test_generic_is_subclass1(): - assert TypeParser.is_subclass(ty.List[int], list) - - -def test_generic_is_subclass2(): - assert not TypeParser.is_subclass(list, ty.List[int]) - - -def test_generic_is_subclass3(): - assert not TypeParser.is_subclass(ty.List[float], ty.List[int]) - - -def test_none_is_subclass1(): - assert TypeParser.is_subclass(None, ty.Union[int, None]) - - @pytest.mark.skipif(sys.version_info < (3, 10), reason="No UnionType < Py3.10") def test_none_is_subclass1a(): assert TypeParser.is_subclass(None, int | None) -def test_none_is_subclass2(): - assert not TypeParser.is_subclass(None, ty.Union[int, float]) - - @pytest.mark.skipif(sys.version_info < (3, 10), reason="No UnionType < Py3.10") def test_none_is_subclass2a(): assert not TypeParser.is_subclass(None, int | float) -def test_none_is_subclass3(): - assert TypeParser.is_subclass(ty.Tuple[int, None], ty.Tuple[int, None]) - - -def test_none_is_subclass4(): - assert TypeParser.is_subclass(None, None) - - -def test_none_is_subclass5(): - assert not TypeParser.is_subclass(None, int) - - -def test_none_is_subclass6(): - assert not TypeParser.is_subclass(int, None) - - -def test_none_is_subclass7(): - assert TypeParser.is_subclass(None, type(None)) - - -def test_none_is_subclass8(): - assert TypeParser.is_subclass(type(None), None) - - -def test_none_is_subclass9(): - assert TypeParser.is_subclass(type(None), type(None)) - - -def test_none_is_subclass10(): - assert TypeParser.is_subclass(type(None), type(None)) - - @pytest.mark.skipif( sys.version_info < (3, 9), reason="Cannot subscript tuple in < Py3.9" ) @@ -980,40 +944,33 @@ class B(A): assert not TypeParser.is_subclass(MyTuple[B], ty.Tuple[A, int]) -def test_type_is_instance1(): - assert TypeParser.is_instance(File, ty.Type[File]) - - -def test_type_is_instance2(): - assert not TypeParser.is_instance(File, ty.Type[Json]) - - -def test_type_is_instance3(): - assert TypeParser.is_instance(Json, ty.Type[File]) - - -def test_type_is_instance4(): - assert TypeParser.is_instance(Json, type) - - -def test_type_is_instance5(): - assert TypeParser.is_instance(None, None) - - -def test_type_is_instance6(): - assert TypeParser.is_instance(None, type(None)) - - -def test_type_is_instance7(): - assert not TypeParser.is_instance(None, int) - - -def test_type_is_instance8(): - assert not TypeParser.is_instance(1, None) +@pytest.mark.parametrize( + ("tp", "obj"), + [ + (File, ty.Type[File]), + (Json, ty.Type[File]), + (Json, type), + (None, None), + (None, type(None)), + (None, ty.Union[int, None]), + (1, ty.Union[int, None]), + ], +) +def test_type_is_instance(tp, obj): + assert TypeParser.is_instance(tp, obj) -def test_type_is_instance9(): - assert TypeParser.is_instance(None, ty.Union[int, None]) +@pytest.mark.parametrize( + ("tp", "obj"), + [ + (File, ty.Type[Json]), + (None, int), + (1, None), + (None, ty.Union[int, str]), + ], +) +def test_type_is_not_instance(tp, obj): + assert not TypeParser.is_instance(tp, obj) @pytest.mark.skipif(sys.version_info < (3, 10), reason="No UnionType < Py3.10") @@ -1021,19 +978,11 @@ def test_type_is_instance9a(): assert TypeParser.is_instance(None, int | None) -def test_type_is_instance10(): - assert TypeParser.is_instance(1, ty.Union[int, None]) - - @pytest.mark.skipif(sys.version_info < (3, 10), reason="No UnionType < Py3.10") def test_type_is_instance10a(): assert TypeParser.is_instance(1, int | None) -def test_type_is_instance11(): - assert not TypeParser.is_instance(None, ty.Union[int, str]) - - @pytest.mark.skipif(sys.version_info < (3, 10), reason="No UnionType < Py3.10") def test_type_is_instance11a(): assert not TypeParser.is_instance(None, int | str) @@ -1059,32 +1008,27 @@ def test_multi_input_obj_coerce4a(): TypeParser(MultiInputObj[ty.Union[int, ty.List[str]]])([[1]]) -def test_multi_input_obj_check_type1(): - TypeParser(MultiInputObj[str])(lz(str)) - - -def test_multi_input_obj_check_type2(): - TypeParser(MultiInputObj[str])(lz(ty.List[str])) - - -def test_multi_input_obj_check_type3(): - TypeParser(MultiInputObj[ty.List[str]])(lz(ty.List[str])) - - -def test_multi_input_obj_check_type3a(): - TypeParser(MultiInputObj[ty.Union[int, ty.List[str]]])(lz(ty.List[str])) - - -def test_multi_input_obj_check_type3b(): - TypeParser(MultiInputObj[ty.Union[int, ty.List[str]]])(lz(ty.List[ty.List[str]])) - - -def test_multi_input_obj_check_type4(): - TypeParser(MultiInputObj[ty.Union[int, ty.List[str]]])(lz(ty.List[int])) +@pytest.mark.parametrize( + ("reference", "to_be_checked"), + [ + (MultiInputObj[str], str), + (MultiInputObj[str], ty.List[str]), + (MultiInputObj[ty.List[str]], ty.List[str]), + (MultiInputObj[ty.Union[int, ty.List[str]]], ty.List[str]), + (MultiInputObj[ty.Union[int, ty.List[str]]], ty.List[ty.List[str]]), + (MultiInputObj[ty.Union[int, ty.List[str]]], ty.List[int]), + ], +) +def test_multi_input_obj_check_type(reference, to_be_checked): + TypeParser(reference)(lz(to_be_checked)) -def test_multi_input_obj_check_type4a(): +@pytest.mark.parametrize( + ("reference", "to_be_checked"), + [ + (MultiInputObj[ty.Union[int, ty.List[str]]], ty.List[ty.List[int]]), + ], +) +def test_multi_input_obj_check_type_fail(reference, to_be_checked): with pytest.raises(TypeError): - TypeParser(MultiInputObj[ty.Union[int, ty.List[str]]])( - lz(ty.List[ty.List[int]]) - ) + TypeParser(reference)(lz(to_be_checked))