diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 27d8b7f8..c0e3db09 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -7,7 +7,14 @@ Change log ========== -0.9.1 (2023-xx-xx) +0.10.0 (2023-xx-xx) +------------------ + +**Breaking change** + +- Failures during type/shape inference and model validation now use exceptions defined in ``spox`` rather than passing through exceptions from ``onnx``. This addresses an issues where the type of the exceptions raised from ``onnx`` depends on the host platform. + +0.9.1 (2023-10-05) ------------------ **Bug fix** diff --git a/src/spox/_exceptions.py b/src/spox/_exceptions.py index 91fd34d8..95cd6bbc 100644 --- a/src/spox/_exceptions.py +++ b/src/spox/_exceptions.py @@ -1,7 +1,5 @@ -import onnx.shape_inference - -InferenceError = onnx.shape_inference.InferenceError -ValidationError = onnx.checker.ValidationError +class InferenceError(ValueError): + """An error originating from the type or shape inference.""" class InferenceWarning(Warning): @@ -27,4 +25,4 @@ class BuildError(Exception): pass -__all__ = ["InferenceWarning", "InferenceError", "ValidationError"] +__all__ = ["InferenceWarning", "InferenceError"] diff --git a/src/spox/_graph.py b/src/spox/_graph.py index 11d2c5ee..edc8f9ab 100644 --- a/src/spox/_graph.py +++ b/src/spox/_graph.py @@ -18,6 +18,7 @@ from ._schemas import max_opset_policy from ._type_system import Tensor, Type from ._utils import from_array +from ._utils import infer_shapes as _infer_shapes from ._var import Var @@ -418,7 +419,9 @@ def to_onnx_model( ) if infer_shapes: - model = onnx.shape_inference.infer_shapes(model) + model = _infer_shapes( + model, check_type=False, strict_mode=False, data_prop=False + ) if check_model: onnx.checker.check_model(model, full_check=check_model >= 2) return model diff --git a/src/spox/_standard.py b/src/spox/_standard.py index aca55a41..0184ca4b 100644 --- a/src/spox/_standard.py +++ b/src/spox/_standard.py @@ -15,7 +15,7 @@ from ._scope import Scope from ._shape import SimpleShape from ._type_system import Optional, Sequence, Tensor, Type -from ._utils import from_array +from ._utils import from_array, infer_shapes from ._value_prop import PropValueType if TYPE_CHECKING: @@ -134,11 +134,11 @@ def infer_output_types_onnx(self) -> Dict[str, Type]: # Attempt to do shape inference - if an error is caught, we extend the traceback a bit try: - typed_model = onnx.shape_inference.infer_shapes( + typed_model = infer_shapes( model, check_type=True, strict_mode=True, data_prop=True ) except InferenceError as e: - raise type(e)( + raise InferenceError( f"{str(e)} -- for {self.schema.name}: {self.signature}" ) from e diff --git a/src/spox/_utils.py b/src/spox/_utils.py index ab4ccfb4..5c344cb1 100644 --- a/src/spox/_utils.py +++ b/src/spox/_utils.py @@ -3,7 +3,9 @@ import numpy as np import numpy.typing as npt import onnx -from onnx import TensorProto +from onnx import ModelProto, TensorProto + +from ._exceptions import InferenceError def tensor_type_to_dtype(ttype: int) -> np.dtype: @@ -61,3 +63,24 @@ def from_array(arr: np.ndarray, name: Optional[str] = None) -> TensorProto: ).flatten(), raw=False, ) + + +def infer_shapes( + model: ModelProto, *, check_type: bool, strict_mode: bool, data_prop: bool +): + """Infer the types and shapes of this model. + + This function normalized the exception raised by onnx which appears to be platform dependent. + + Raises + ------ + InferenceError : + If the type or shape inference failed. + """ + try: + return onnx.shape_inference.infer_shapes( + model, check_type=True, strict_mode=True, data_prop=True + ) + except (onnx.shape_inference.InferenceError, RuntimeError) as e: + # onnx to raises a less descriptive `RuntimeError`s on MacOS out of needless cruelty + raise InferenceError(str(e)) from e diff --git a/src/spox/opset/ai/onnx/ml/v3.py b/src/spox/opset/ai/onnx/ml/v3.py index 5ef6d835..767c3cd2 100644 --- a/src/spox/opset/ai/onnx/ml/v3.py +++ b/src/spox/opset/ai/onnx/ml/v3.py @@ -30,11 +30,12 @@ AttrTensor, AttrType, ) +from spox._exceptions import InferenceError # noqa: F401 from spox._fields import BaseAttributes, BaseInputs, BaseOutputs # noqa: F401 from spox._graph import Graph, subgraph # noqa: F401 from spox._internal_op import intro # noqa: F401 from spox._node import OpType # noqa: F401 -from spox._standard import InferenceError, StandardNode # noqa: F401 +from spox._standard import StandardNode # noqa: F401 from spox._type_system import Sequence as SpoxSequence # noqa: F401 from spox._type_system import Tensor, Type from spox._value_prop import PropValueType # noqa: F401 diff --git a/src/spox/opset/ai/onnx/v17.py b/src/spox/opset/ai/onnx/v17.py index 896e3dfc..e4df1069 100644 --- a/src/spox/opset/ai/onnx/v17.py +++ b/src/spox/opset/ai/onnx/v17.py @@ -30,11 +30,12 @@ AttrTensor, AttrType, ) +from spox._exceptions import InferenceError # noqa: F401 from spox._fields import BaseAttributes, BaseInputs, BaseOutputs # noqa: F401 from spox._graph import Graph, subgraph # noqa: F401 from spox._internal_op import intro # noqa: F401 from spox._node import OpType # noqa: F401 -from spox._standard import InferenceError, StandardNode # noqa: F401 +from spox._standard import StandardNode # noqa: F401 from spox._type_system import Sequence as SpoxSequence # noqa: F401 from spox._type_system import Tensor, Type from spox._value_prop import PropValueType # noqa: F401 diff --git a/src/spox/opset/ai/onnx/v18.py b/src/spox/opset/ai/onnx/v18.py index 0cd9e3ce..196055a1 100644 --- a/src/spox/opset/ai/onnx/v18.py +++ b/src/spox/opset/ai/onnx/v18.py @@ -30,11 +30,12 @@ AttrTensor, AttrType, ) +from spox._exceptions import InferenceError # noqa: F401 from spox._fields import BaseAttributes, BaseInputs, BaseOutputs # noqa: F401 from spox._graph import Graph, subgraph # noqa: F401 from spox._internal_op import intro # noqa: F401 from spox._node import OpType # noqa: F401 -from spox._standard import InferenceError, StandardNode # noqa: F401 +from spox._standard import StandardNode # noqa: F401 from spox._type_system import Sequence as SpoxSequence # noqa: F401 from spox._type_system import Tensor, Type from spox._value_prop import PropValueType # noqa: F401 diff --git a/src/spox/opset/ai/onnx/v19.py b/src/spox/opset/ai/onnx/v19.py index 6695deb6..f1d829a0 100644 --- a/src/spox/opset/ai/onnx/v19.py +++ b/src/spox/opset/ai/onnx/v19.py @@ -30,11 +30,12 @@ AttrTensor, AttrType, ) +from spox._exceptions import InferenceError # noqa: F401 from spox._fields import BaseAttributes, BaseInputs, BaseOutputs # noqa: F401 from spox._graph import Graph, subgraph # noqa: F401 from spox._internal_op import intro # noqa: F401 from spox._node import OpType # noqa: F401 -from spox._standard import InferenceError, StandardNode # noqa: F401 +from spox._standard import StandardNode # noqa: F401 from spox._type_system import Sequence as SpoxSequence # noqa: F401 from spox._type_system import Tensor, Type from spox._value_prop import PropValueType # noqa: F401 diff --git a/tests/type_inference/test_compress.py b/tests/type_inference/test_compress.py index 9ec5acd6..6f4ee9c0 100644 --- a/tests/type_inference/test_compress.py +++ b/tests/type_inference/test_compress.py @@ -1,8 +1,8 @@ import pytest import spox.opset.ai.onnx.v17 as op +from spox._exceptions import InferenceError from spox._graph import arguments -from spox._standard import InferenceError from spox._type_system import Tensor diff --git a/tests/type_inference/test_imputer.py b/tests/type_inference/test_imputer.py index 0686b33d..217ccd50 100644 --- a/tests/type_inference/test_imputer.py +++ b/tests/type_inference/test_imputer.py @@ -2,8 +2,8 @@ import pytest import spox.opset.ai.onnx.ml.v3 as op_ml +from spox._exceptions import InferenceError from spox._graph import arguments -from spox._standard import InferenceError from spox._type_system import Tensor diff --git a/tests/type_inference/test_one_hot.py b/tests/type_inference/test_one_hot.py index b7a80460..4f27d51c 100644 --- a/tests/type_inference/test_one_hot.py +++ b/tests/type_inference/test_one_hot.py @@ -1,8 +1,8 @@ import pytest import spox.opset.ai.onnx.v17 as op +from spox._exceptions import InferenceError from spox._graph import arguments -from spox._standard import InferenceError from spox._type_system import Tensor @@ -29,7 +29,7 @@ def test_one_hot_inference_checks_axis_in_range(): x, y, z = arguments( x=Tensor(int, ("N", "M")), y=Tensor(int, ()), z=Tensor(float, (2,)) ) - with pytest.raises(InferenceError): + with pytest.raises(ValueError): assert op.one_hot(x, y, z, axis=-4) with pytest.raises(InferenceError): assert op.one_hot(x, y, z, axis=3) diff --git a/tests/type_inference/test_scaler.py b/tests/type_inference/test_scaler.py index e41a4638..57e7ebfb 100644 --- a/tests/type_inference/test_scaler.py +++ b/tests/type_inference/test_scaler.py @@ -2,8 +2,8 @@ import pytest import spox.opset.ai.onnx.ml.v3 as op_ml +from spox._exceptions import InferenceError from spox._graph import arguments -from spox._standard import InferenceError from spox._type_system import Tensor diff --git a/tools/templates/preamble.jinja2 b/tools/templates/preamble.jinja2 index 03f8480b..10edc40e 100644 --- a/tools/templates/preamble.jinja2 +++ b/tools/templates/preamble.jinja2 @@ -32,9 +32,10 @@ from spox._attributes import ( AttrTensor, AttrType, ) +from spox._exceptions import InferenceError # noqa: F401 from spox._graph import Graph, subgraph # noqa: F401 from spox._internal_op import intro # noqa: F401 from spox._node import OpType # noqa: F401 -from spox._standard import InferenceError, StandardNode # noqa: F401 +from spox._standard import StandardNode # noqa: F401 from spox._type_system import Tensor, Type, Sequence as SpoxSequence # noqa: F401 from spox._value_prop import PropValueType # noqa: F401