diff --git a/src/spox/_attributes.py b/src/spox/_attributes.py index f1223a05..7049c09d 100644 --- a/src/spox/_attributes.py +++ b/src/spox/_attributes.py @@ -48,13 +48,12 @@ def value(self) -> T: def _validate(self): try: type_in_onnx = self._to_onnx_deref("dummy").type - except TypeError as e: - # TypeError: 'a' has type str, but expected one of: int -- this is raised from within protobuf + except Exception as e: + # Likely an error from within onnx/protobuf, such as: + # 1) AttributeError: 'int' object has no attribute 'encode' + # 2) TypeError: 'a' has type str, but expected one of: int -- this is raised from within protobuf # when .extend()-ing with the wrong type. - if "but expected one of" in str(e): - raise self._get_pretty_type_exception() from e - else: - raise + raise self._get_pretty_type_exception() from e if type_in_onnx != self._attribute_proto_type: raise self._get_pretty_type_exception() @@ -207,13 +206,6 @@ def _to_onnx_deref(self, key: str) -> AttributeProto: class AttrFloat32s(_AttrIterable[float]): _attribute_proto_type = AttributeProto.FLOATS - def _to_onnx_deref(self, key: str) -> AttributeProto: - try: - transformed = [float(v) for v in self.value] - except (ValueError, TypeError) as e: - raise AttributeTypeError("Attribute values don't seem to be floats.") from e - return make_attribute(key, transformed, attr_type=self._attribute_proto_type) - class AttrInt64s(_AttrIterable[int]): _attribute_proto_type = AttributeProto.INTS @@ -222,29 +214,10 @@ class AttrInt64s(_AttrIterable[int]): class AttrStrings(_AttrIterable[str]): _attribute_proto_type = AttributeProto.STRINGS - def _to_onnx_deref(self, key: str) -> AttributeProto: - try: - transformed = [v.encode() for v in self.value] - except AttributeError as e: - raise AttributeTypeError( - "Attribute values don't seem to be strings." - ) from e - return make_attribute(key, transformed, attr_type=self._attribute_proto_type) - class AttrTensors(_AttrIterable[np.ndarray]): _attribute_proto_type = AttributeProto.TENSORS - def _to_onnx_deref(self, key: str) -> AttributeProto: - try: - transformed = [from_array(v) for v in self.value] - except AttributeError as e: - raise AttributeTypeError( - "Attribute values don't seem to be numpy arrays." - ) from e - - return make_attribute(key, transformed, attr_type=self._attribute_proto_type) - def _deref(ref: _Ref[T]) -> T: if isinstance(ref._concrete._value, _Ref): diff --git a/tests/test_standard_op.py b/tests/test_standard_op.py index 50b553a7..c4dc77ce 100644 --- a/tests/test_standard_op.py +++ b/tests/test_standard_op.py @@ -75,8 +75,16 @@ def test_multiple_outputs(): ["a"], "Unable to instantiate `AttrInt64s` with value of type `tuple[str, ...]`", ), - ("value_strings", [1], "Attribute values don't seem to be strings."), - ("value_floats", ["a"], "Attribute values don't seem to be floats."), + ( + "value_strings", + [1], + "Unable to instantiate `AttrStrings` with value of type `tuple[int, ...]", + ), + ( + "value_floats", + ["a"], + "Unable to instantiate `AttrFloat32s` with value of type `tuple[str, ...]", + ), ( "value_int", "a", @@ -92,6 +100,8 @@ def test_passing_wrong_type(key: str, values: Any, match: str): def test_passing_wrong_type_tensors(): with pytest.raises( AttributeTypeError, - match=re.escape("Attribute values don't seem to be numpy arrays."), + match=re.escape( + "Unable to instantiate `AttrTensors` with value of type `tuple[int, ...]" + ), ): AttrTensors([1]) # type: ignore