Skip to content

Commit

Permalink
Christian's comments
Browse files Browse the repository at this point in the history
  • Loading branch information
SimeonStoykovQC committed Dec 5, 2023
1 parent 7e64651 commit 09df28f
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 35 deletions.
37 changes: 5 additions & 32 deletions src/spox/_attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,12 @@ def value(self) -> T:
def _validate(self):
try:
type_in_onnx = self._to_onnx_deref("dummy").type
except TypeError as e:
# TypeError: 'a' has type str, but expected one of: int -- this is raised from within protobuf
except Exception as e:
# Likely an error from within onnx/protobuf, such as:
# 1) AttributeError: 'int' object has no attribute 'encode'
# 2) TypeError: 'a' has type str, but expected one of: int -- this is raised from within protobuf
# when .extend()-ing with the wrong type.
if "but expected one of" in str(e):
raise self._get_pretty_type_exception() from e
else:
raise
raise self._get_pretty_type_exception() from e

if type_in_onnx != self._attribute_proto_type:
raise self._get_pretty_type_exception()
Expand Down Expand Up @@ -207,13 +206,6 @@ def _to_onnx_deref(self, key: str) -> AttributeProto:
class AttrFloat32s(_AttrIterable[float]):
_attribute_proto_type = AttributeProto.FLOATS

def _to_onnx_deref(self, key: str) -> AttributeProto:
try:
transformed = [float(v) for v in self.value]
except (ValueError, TypeError) as e:
raise AttributeTypeError("Attribute values don't seem to be floats.") from e
return make_attribute(key, transformed, attr_type=self._attribute_proto_type)


class AttrInt64s(_AttrIterable[int]):
_attribute_proto_type = AttributeProto.INTS
Expand All @@ -222,29 +214,10 @@ class AttrInt64s(_AttrIterable[int]):
class AttrStrings(_AttrIterable[str]):
_attribute_proto_type = AttributeProto.STRINGS

def _to_onnx_deref(self, key: str) -> AttributeProto:
try:
transformed = [v.encode() for v in self.value]
except AttributeError as e:
raise AttributeTypeError(
"Attribute values don't seem to be strings."
) from e
return make_attribute(key, transformed, attr_type=self._attribute_proto_type)


class AttrTensors(_AttrIterable[np.ndarray]):
_attribute_proto_type = AttributeProto.TENSORS

def _to_onnx_deref(self, key: str) -> AttributeProto:
try:
transformed = [from_array(v) for v in self.value]
except AttributeError as e:
raise AttributeTypeError(
"Attribute values don't seem to be numpy arrays."
) from e

return make_attribute(key, transformed, attr_type=self._attribute_proto_type)


def _deref(ref: _Ref[T]) -> T:
if isinstance(ref._concrete._value, _Ref):
Expand Down
16 changes: 13 additions & 3 deletions tests/test_standard_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,16 @@ def test_multiple_outputs():
["a"],
"Unable to instantiate `AttrInt64s` with value of type `tuple[str, ...]`",
),
("value_strings", [1], "Attribute values don't seem to be strings."),
("value_floats", ["a"], "Attribute values don't seem to be floats."),
(
"value_strings",
[1],
"Unable to instantiate `AttrStrings` with value of type `tuple[int, ...]",
),
(
"value_floats",
["a"],
"Unable to instantiate `AttrFloat32s` with value of type `tuple[str, ...]",
),
(
"value_int",
"a",
Expand All @@ -92,6 +100,8 @@ def test_passing_wrong_type(key: str, values: Any, match: str):
def test_passing_wrong_type_tensors():
with pytest.raises(
AttributeTypeError,
match=re.escape("Attribute values don't seem to be numpy arrays."),
match=re.escape(
"Unable to instantiate `AttrTensors` with value of type `tuple[int, ...]"
),
):
AttrTensors([1]) # type: ignore

0 comments on commit 09df28f

Please sign in to comment.