Skip to content

Commit

Permalink
Use self-defined Inference Error rather than a re-export
Browse files Browse the repository at this point in the history
  • Loading branch information
cbourjau committed Oct 7, 2023
1 parent 624eaef commit d560731
Show file tree
Hide file tree
Showing 14 changed files with 57 additions and 21 deletions.
9 changes: 8 additions & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,14 @@
Change log
==========

0.9.1 (2023-xx-xx)
0.10.0 (2023-xx-xx)
------------------

**Breaking change**

- Failures during type/shape inference and model validation now use exceptions defined in ``spox`` rather than passing through exceptions from ``onnx``. This addresses an issues where the type of the exceptions raised from ``onnx`` depends on the host platform.

0.9.1 (2023-10-05)
------------------

**Bug fix**
Expand Down
8 changes: 3 additions & 5 deletions src/spox/_exceptions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import onnx.shape_inference

InferenceError = onnx.shape_inference.InferenceError
ValidationError = onnx.checker.ValidationError
class InferenceError(ValueError):
"""An error originating from the type or shape inference."""


class InferenceWarning(Warning):
Expand All @@ -27,4 +25,4 @@ class BuildError(Exception):
pass


__all__ = ["InferenceWarning", "InferenceError", "ValidationError"]
__all__ = ["InferenceWarning", "InferenceError"]
5 changes: 4 additions & 1 deletion src/spox/_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from ._schemas import max_opset_policy
from ._type_system import Tensor, Type
from ._utils import from_array
from ._utils import infer_shapes as _infer_shapes
from ._var import Var


Expand Down Expand Up @@ -418,7 +419,9 @@ def to_onnx_model(
)

if infer_shapes:
model = onnx.shape_inference.infer_shapes(model)
model = _infer_shapes(
model, check_type=False, strict_mode=False, data_prop=False
)
if check_model:
onnx.checker.check_model(model, full_check=check_model >= 2)
return model
Expand Down
6 changes: 3 additions & 3 deletions src/spox/_standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from ._scope import Scope
from ._shape import SimpleShape
from ._type_system import Optional, Sequence, Tensor, Type
from ._utils import from_array
from ._utils import from_array, infer_shapes
from ._value_prop import PropValueType

if TYPE_CHECKING:
Expand Down Expand Up @@ -134,11 +134,11 @@ def infer_output_types_onnx(self) -> Dict[str, Type]:

# Attempt to do shape inference - if an error is caught, we extend the traceback a bit
try:
typed_model = onnx.shape_inference.infer_shapes(
typed_model = infer_shapes(
model, check_type=True, strict_mode=True, data_prop=True
)
except InferenceError as e:
raise type(e)(
raise InferenceError(
f"{str(e)} -- for {self.schema.name}: {self.signature}"
) from e

Expand Down
25 changes: 24 additions & 1 deletion src/spox/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
import numpy as np
import numpy.typing as npt
import onnx
from onnx import TensorProto
from onnx import ModelProto, TensorProto

from ._exceptions import InferenceError


def tensor_type_to_dtype(ttype: int) -> np.dtype:
Expand Down Expand Up @@ -61,3 +63,24 @@ def from_array(arr: np.ndarray, name: Optional[str] = None) -> TensorProto:
).flatten(),
raw=False,
)


def infer_shapes(
model: ModelProto, *, check_type: bool, strict_mode: bool, data_prop: bool
):
"""Infer the types and shapes of this model.
This function normalized the exception raised by onnx which appears to be platform dependent.
Raises
------
InferenceError :
If the type or shape inference failed.
"""
try:
return onnx.shape_inference.infer_shapes(
model, check_type=True, strict_mode=True, data_prop=True
)
except (onnx.shape_inference.InferenceError, RuntimeError) as e:
# onnx to raises a less descriptive `RuntimeError`s on MacOS out of needless cruelty
raise InferenceError(str(e)) from e
3 changes: 2 additions & 1 deletion src/spox/opset/ai/onnx/ml/v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,12 @@
AttrTensor,
AttrType,
)
from spox._exceptions import InferenceError # noqa: F401
from spox._fields import BaseAttributes, BaseInputs, BaseOutputs # noqa: F401
from spox._graph import Graph, subgraph # noqa: F401
from spox._internal_op import intro # noqa: F401
from spox._node import OpType # noqa: F401
from spox._standard import InferenceError, StandardNode # noqa: F401
from spox._standard import StandardNode # noqa: F401
from spox._type_system import Sequence as SpoxSequence # noqa: F401
from spox._type_system import Tensor, Type
from spox._value_prop import PropValueType # noqa: F401
Expand Down
3 changes: 2 additions & 1 deletion src/spox/opset/ai/onnx/v17.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,12 @@
AttrTensor,
AttrType,
)
from spox._exceptions import InferenceError # noqa: F401
from spox._fields import BaseAttributes, BaseInputs, BaseOutputs # noqa: F401
from spox._graph import Graph, subgraph # noqa: F401
from spox._internal_op import intro # noqa: F401
from spox._node import OpType # noqa: F401
from spox._standard import InferenceError, StandardNode # noqa: F401
from spox._standard import StandardNode # noqa: F401
from spox._type_system import Sequence as SpoxSequence # noqa: F401
from spox._type_system import Tensor, Type
from spox._value_prop import PropValueType # noqa: F401
Expand Down
3 changes: 2 additions & 1 deletion src/spox/opset/ai/onnx/v18.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,12 @@
AttrTensor,
AttrType,
)
from spox._exceptions import InferenceError # noqa: F401
from spox._fields import BaseAttributes, BaseInputs, BaseOutputs # noqa: F401
from spox._graph import Graph, subgraph # noqa: F401
from spox._internal_op import intro # noqa: F401
from spox._node import OpType # noqa: F401
from spox._standard import InferenceError, StandardNode # noqa: F401
from spox._standard import StandardNode # noqa: F401
from spox._type_system import Sequence as SpoxSequence # noqa: F401
from spox._type_system import Tensor, Type
from spox._value_prop import PropValueType # noqa: F401
Expand Down
3 changes: 2 additions & 1 deletion src/spox/opset/ai/onnx/v19.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,12 @@
AttrTensor,
AttrType,
)
from spox._exceptions import InferenceError # noqa: F401
from spox._fields import BaseAttributes, BaseInputs, BaseOutputs # noqa: F401
from spox._graph import Graph, subgraph # noqa: F401
from spox._internal_op import intro # noqa: F401
from spox._node import OpType # noqa: F401
from spox._standard import InferenceError, StandardNode # noqa: F401
from spox._standard import StandardNode # noqa: F401
from spox._type_system import Sequence as SpoxSequence # noqa: F401
from spox._type_system import Tensor, Type
from spox._value_prop import PropValueType # noqa: F401
Expand Down
2 changes: 1 addition & 1 deletion tests/type_inference/test_compress.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import pytest

import spox.opset.ai.onnx.v17 as op
from spox._exceptions import InferenceError
from spox._graph import arguments
from spox._standard import InferenceError
from spox._type_system import Tensor


Expand Down
2 changes: 1 addition & 1 deletion tests/type_inference/test_imputer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
import pytest

import spox.opset.ai.onnx.ml.v3 as op_ml
from spox._exceptions import InferenceError
from spox._graph import arguments
from spox._standard import InferenceError
from spox._type_system import Tensor


Expand Down
4 changes: 2 additions & 2 deletions tests/type_inference/test_one_hot.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import pytest

import spox.opset.ai.onnx.v17 as op
from spox._exceptions import InferenceError
from spox._graph import arguments
from spox._standard import InferenceError
from spox._type_system import Tensor


Expand All @@ -29,7 +29,7 @@ def test_one_hot_inference_checks_axis_in_range():
x, y, z = arguments(
x=Tensor(int, ("N", "M")), y=Tensor(int, ()), z=Tensor(float, (2,))
)
with pytest.raises(InferenceError):
with pytest.raises(ValueError):
assert op.one_hot(x, y, z, axis=-4)
with pytest.raises(InferenceError):
assert op.one_hot(x, y, z, axis=3)
2 changes: 1 addition & 1 deletion tests/type_inference/test_scaler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
import pytest

import spox.opset.ai.onnx.ml.v3 as op_ml
from spox._exceptions import InferenceError
from spox._graph import arguments
from spox._standard import InferenceError
from spox._type_system import Tensor


Expand Down
3 changes: 2 additions & 1 deletion tools/templates/preamble.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,10 @@ from spox._attributes import (
AttrTensor,
AttrType,
)
from spox._exceptions import InferenceError # noqa: F401
from spox._graph import Graph, subgraph # noqa: F401
from spox._internal_op import intro # noqa: F401
from spox._node import OpType # noqa: F401
from spox._standard import InferenceError, StandardNode # noqa: F401
from spox._standard import StandardNode # noqa: F401
from spox._type_system import Tensor, Type, Sequence as SpoxSequence # noqa: F401
from spox._value_prop import PropValueType # noqa: F401

0 comments on commit d560731

Please sign in to comment.