Skip to content

Commit

Permalink
More initializers
Browse files Browse the repository at this point in the history
  • Loading branch information
neNasko1 committed Nov 5, 2024
1 parent 49e366c commit f5af5d9
Show file tree
Hide file tree
Showing 17 changed files with 17,799 additions and 22,946 deletions.
2 changes: 1 addition & 1 deletion src/spox/_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from collections.abc import Iterable, Iterator, Sequence
from dataclasses import dataclass
from typing import Any, Optional, Union
from typing_extensions import Self

from ._attributes import Attr
from ._exceptions import InferenceWarning
Expand Down Expand Up @@ -190,7 +191,6 @@ def vars(self, prop_values) -> Vars:

return self.Vars(**vars_structure)


@dataclass
class BaseOutputs(BaseVarInfos, metaclass=BaseVarsMeta):
@dataclass
Expand Down
11 changes: 6 additions & 5 deletions src/spox/_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def constructor(self, attrs, inputs):
f"Function {type(self).__name__} does not implement a constructor."
)

def infer_output_types(self) -> dict[str, Type]:
def infer_output_types(self, initializers={}) -> dict[str, Type]:
from . import _graph

func_args_var = _graph.arguments_dict(
Expand Down Expand Up @@ -147,7 +147,7 @@ class Attributes(BaseAttributes):
op_type = OpType(name, domain, version)

def constructor(self, attrs, inputs):
return self.Outputs(*fun(*inputs.get_fields().values()))
return self.Outputs(*unwrap_vars(fun(*wrap_vars(inputs.get_fields().values()))))

return _Func

Expand Down Expand Up @@ -192,11 +192,12 @@ def init(*args: Var):

def alt_fun(*args: Var) -> Iterable[Var]:
cls = init(*args)
return wrap_vars(
cls(cls.Attributes(), cls.Inputs(*unwrap_vars(args)))
return [
Var(var_info)
for var_info in cls(cls.Attributes(), cls.Inputs(*unwrap_vars(args)))
.outputs.get_fields()
.values()
)
]

return alt_fun # type: ignore

Expand Down
2 changes: 1 addition & 1 deletion src/spox/_inline.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def opset_req(self) -> set[tuple[str, int]]:
("", INTERNAL_MIN_OPSET)
}

def infer_output_types(self) -> dict[str, Type]:
def infer_output_types(self, initializers={}) -> dict[str, Type]:
# First, type check that we match the ModelProto type requirements
for i, var in zip(self.graph.input, self.inputs.inputs):
if var.type is not None and not (
Expand Down
6 changes: 3 additions & 3 deletions src/spox/_internal_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def post_init(self, **kwargs):
if self.attrs.name is not None:
self.outputs.arg._rename(self.attrs.name.value)

def infer_output_types(self) -> dict[str, Type]:
def infer_output_types(self, initializers={}) -> dict[str, Type]:
# Output type is based on the value of the type attribute
return {"arg": self.attrs.type.value}

Expand Down Expand Up @@ -121,7 +121,7 @@ class Outputs(BaseOutputs):
inputs: BaseInputs
outputs: Outputs

def infer_output_types(self) -> dict[str, Type]:
def infer_output_types(self, initializers={}) -> dict[str, Type]:
# Output type is based on the value of the type attribute
arr = self.attrs.value.value
return {"arg": Tensor(arr.dtype, arr.shape)}
Expand Down Expand Up @@ -161,7 +161,7 @@ class Outputs(BaseOutputs):
inputs: Inputs
outputs: Outputs

def infer_output_types(self) -> dict[str, Type]:
def infer_output_types(self, initializers={}) -> dict[str, Type]:
return {
f"outputs_{i}": arr.type
for i, arr in enumerate(self.inputs.inputs)
Expand Down
10 changes: 5 additions & 5 deletions src/spox/_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def __init__(
out_variadic: Optional[int] = None,
infer_types: bool = True,
validate: bool = True,
initializers=[],
initializers={},
**kwargs,
):
"""
Expand Down Expand Up @@ -127,7 +127,7 @@ def __init__(
# As inference functions may access which output vars we initialized (e.g. variadics)
# we inject uninitialized vars first
self.outputs = self._init_output_vars()
self.inference(infer_types)
self.inference(infer_types, initializers)
else:
self.outputs = outputs

Expand Down Expand Up @@ -215,18 +215,18 @@ def propagate_values(self, initializers) -> dict[str, PropValueType]:
"""
return {}

def infer_output_types(self) -> dict[str, Type]:
def infer_output_types(self, initializers) -> dict[str, Type]:
"""
Inference routine for output types. Often overriden by inheriting Node types.
Returns a dictionary of output field names into Types for the respective VarInfos.
"""
return {}

def inference(self, infer_types: bool = True):
def inference(self, infer_types: bool = True, initializers={}):
# Type inference routine - call infer_output_types if required
# and check if it provides the expected outputs.
out_types = self.infer_output_types() if infer_types else {}
out_types = self.infer_output_types(initializers=initializers) if infer_types else {}

for key, var in self.outputs.get_vars().items():
if var.type is None: # If no existing type from init_output_vars
Expand Down
21 changes: 14 additions & 7 deletions src/spox/_standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
from typing import TYPE_CHECKING, Callable

import onnx
from onnx.numpy_helper import from_array
import onnx.reference
import onnx.shape_inference
from onnx.defs import OpSchema
import numpy as np

from . import _value_prop
from ._exceptions import InferenceError
Expand All @@ -18,6 +20,7 @@
from ._shape import SimpleShape
from ._type_system import Optional, Sequence, Tensor, Type
from ._value_prop import PropValueType
from ._utils import from_array

if TYPE_CHECKING:
from ._graph import Graph
Expand Down Expand Up @@ -48,7 +51,7 @@ def min_output(self) -> int:
return self.schema.min_output

def to_singleton_onnx_model(
self, *, dummy_outputs: bool = True, with_dummy_subgraphs: bool = True
self, *, dummy_outputs: bool = True, with_dummy_subgraphs: bool = True, prop_values={}
) -> tuple[onnx.ModelProto, Scope]:
"""
Build a singleton model consisting of just this StandardNode. Used for type inference.
Expand Down Expand Up @@ -97,7 +100,11 @@ def out_value_info(curr_key, curr_var):
]
# Initializers, passed in to allow partial data propagation
# - used so that operators like Reshape are aware of constant shapes
initializers = []
initializers = [
from_array(prop.value, name) # type: ignore
for name, prop in prop_values.items()
if prop is not None and isinstance(prop.value, np.ndarray)
]
# Graph and model
graph = onnx.helper.make_graph(
[node_proto],
Expand All @@ -117,13 +124,13 @@ def out_value_info(curr_key, curr_var):
)
return model, scope

def infer_output_types_onnx(self) -> dict[str, Type]:
def infer_output_types_onnx(self, initializers={}) -> dict[str, Type]:
"""Execute type & shape inference with ``onnx.shape_inference.infer_node_outputs``."""
# Check that all (specified) inputs have known types, as otherwise we fail
if any(var.type is None for var in self.inputs.get_vars().values()):
return {}

model, _ = self.to_singleton_onnx_model()
model, _ = self.to_singleton_onnx_model(prop_values=initializers)

# Attempt to do shape inference - if an error is caught, we extend the traceback a bit
try:
Expand Down Expand Up @@ -161,7 +168,7 @@ def propagate_values_onnx(self, initializers) -> dict[str, PropValueType]:
if next(iter(self.subgraphs), None) is not None:
# Cannot do propagation with subgraphs implicitly for performance - should be reimplemented
return {}
model, scope = self.to_singleton_onnx_model(with_dummy_subgraphs=False)
model, scope = self.to_singleton_onnx_model(with_dummy_subgraphs=False, prop_values=initializers)
wrap_feed, run, unwrap_feed = _value_prop.get_backend_calls()
input_feed = {
scope.var[var_info]: wrap_feed(initializers[name])
Expand All @@ -179,8 +186,8 @@ def propagate_values_onnx(self, initializers) -> dict[str, PropValueType]:
}
return {k: v for k, v in results.items() if k is not None}

def infer_output_types(self) -> dict[str, Type]:
return self.infer_output_types_onnx()
def infer_output_types(self, initializers={}) -> dict[str, Type]:
return self.infer_output_types_onnx(initializers)

def propagate_values(self, initializers) -> dict[str, PropValueType]:
if _value_prop._VALUE_PROP_BACKEND != _value_prop.ValuePropBackend.NONE:
Expand Down
Loading

0 comments on commit f5af5d9

Please sign in to comment.