Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Caching for attributes that allow it #123

Merged
merged 7 commits into from
Dec 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ Change log
**Other changes**

- The validation of Node attributes has been improved and more consistent exceptions are raised if needed.
- ONNX node attributes are now computed only once and then cached so that the values are reused for validation and building the model.
cbourjau marked this conversation as resolved.
Show resolved Hide resolved


0.9.3 (2023-10-23)
Expand Down
89 changes: 54 additions & 35 deletions src/spox/_attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,25 @@

class Attr(ABC, Generic[T]):
_value: Union[T, "_Ref[T]"]
_name: str
_cached_onnx: Optional[AttributeProto]

def __init__(self, value: Union[T, "_Ref[T]"]):
def __init__(self, value: Union[T, "_Ref[T]"], name: str):
self._value = value
self._name = name
self._cached_onnx = None

self._validate()

def deref(self) -> "Attr":
if isinstance(self._value, _Ref):
return type(self)(self.value, self._name)
else:
return self

@classmethod
def maybe(cls: Type[AttrT], value: Optional[T]) -> Optional[AttrT]:
return cls(value) if value is not None else None
def maybe(cls: Type[AttrT], value: Optional[T], name: str) -> Optional[AttrT]:
return cls(value, name) if value is not None else None

@property
def value(self) -> T:
Expand All @@ -41,7 +52,7 @@ def value(self) -> T:

def _validate(self):
try:
type_in_onnx = self._to_onnx_deref("dummy").type
type_in_onnx = self._to_onnx().type
except Exception as e:
# Likely an error from within onnx/protobuf, such as:
# 1) AttributeError: 'int' object has no attribute 'encode'
Expand All @@ -52,18 +63,19 @@ def _validate(self):
if type_in_onnx != self._attribute_proto_type:
raise self._get_pretty_type_exception()

def _to_onnx(self, key: str) -> AttributeProto:
def _to_onnx(self) -> AttributeProto:
if isinstance(self._value, _Ref):
return self._value._to_onnx(key)
return self._to_onnx_deref(key)
return self._value._to_onnx()
if self._cached_onnx is None:
self._cached_onnx = self._to_onnx_deref()
return self._cached_onnx

@property
@abc.abstractmethod
def _attribute_proto_type(self) -> int:
raise NotImplementedError()

@abc.abstractmethod
def _to_onnx_deref(self, key: str) -> AttributeProto:
def _to_onnx_deref(self) -> AttributeProto:
"""Conversion method for the dereferenced case."""
raise NotImplementedError()

Expand All @@ -87,57 +99,58 @@ class _Ref(Generic[T]):

_concrete: Attr[T]

def __init__(self, concrete: Attr[T], outer_name: str):
def __init__(self, concrete: Attr[T], outer_name: str, name: str):
self._concrete = concrete
self._outer_name = outer_name
self._name = name

def copy(self) -> "_Ref[T]":
return self

def _to_onnx(self, key: str) -> AttributeProto:
parent_type = self._concrete._to_onnx(key).type
def _to_onnx(self) -> AttributeProto:
parent_type = self._concrete._to_onnx().type
return AttributeProto(
name=key, ref_attr_name=self._outer_name, type=parent_type
name=self._name, ref_attr_name=self._outer_name, type=parent_type
)


class AttrFloat32(Attr[float]):
_attribute_proto_type = AttributeProto.FLOAT

def _to_onnx_deref(self, key: str) -> AttributeProto:
def _to_onnx_deref(self) -> AttributeProto:
if isinstance(self.value, int):
return make_attribute(key, float(self.value))
return make_attribute(key, self.value)
return make_attribute(self._name, float(self.value))
return make_attribute(self._name, self.value)


class AttrInt64(Attr[int]):
_attribute_proto_type = AttributeProto.INT

def _to_onnx_deref(self, key: str) -> AttributeProto:
return make_attribute(key, self.value)
def _to_onnx_deref(self) -> AttributeProto:
return make_attribute(self._name, self.value)


class AttrString(Attr[str]):
_attribute_proto_type = AttributeProto.STRING

def _to_onnx_deref(self, key: str) -> AttributeProto:
return make_attribute(key, self.value)
def _to_onnx_deref(self) -> AttributeProto:
return make_attribute(self._name, self.value)


class AttrTensor(Attr[np.ndarray]):
_attribute_proto_type = AttributeProto.TENSOR

def __init__(self, value: Union[np.ndarray, _Ref[np.ndarray]]):
super().__init__(value.copy())
def __init__(self, value: Union[np.ndarray, _Ref[np.ndarray]], name: str):
super().__init__(value.copy(), name)

def _to_onnx_deref(self, key: str) -> AttributeProto:
return make_attribute(key, from_array(self.value))
def _to_onnx_deref(self) -> AttributeProto:
return make_attribute(self._name, from_array(self.value))


class AttrType(Attr[_type_system.Type]):
_attribute_proto_type = AttributeProto.TYPE_PROTO

def _to_onnx_deref(self, key: str) -> AttributeProto:
def _to_onnx_deref(self) -> AttributeProto:
value = self.value # for type-checkers with limited property support
if isinstance(value, _type_system.Tensor):
type_proto = make_tensor_type_proto(
Expand All @@ -150,7 +163,7 @@ def _to_onnx_deref(self, key: str) -> AttributeProto:
type_proto = make_optional_type_proto(value.elem_type._to_onnx())
else:
raise NotImplementedError()
return make_attribute(key, type_proto)
return make_attribute(self._name, type_proto)


class AttrDtype(Attr[npt.DTypeLike]):
Expand All @@ -161,8 +174,8 @@ class AttrDtype(Attr[npt.DTypeLike]):
def _validate(self):
dtype_to_tensor_type(self.value)

def _to_onnx_deref(self, key: str) -> AttributeProto:
return make_attribute(key, dtype_to_tensor_type(self.value))
def _to_onnx_deref(self) -> AttributeProto:
return make_attribute(self._name, dtype_to_tensor_type(self.value))


class AttrGraph(Attr[Any]):
Expand All @@ -176,24 +189,30 @@ def _validate(self):
f"Expected value of type `spox.graph.Graph found `{type(self.value)}`"
)

def _to_onnx_deref(self, key: str) -> AttributeProto:
def _to_onnx_deref(self) -> AttributeProto:
raise TypeError(
"Graph attributes must be built using the `build_subgraph` callback in `Node.to_onnx`."
)


class _AttrIterable(Attr[Tuple[S, ...]], ABC):
def __init__(self, value: Union[Iterable[S], _Ref[Tuple[S, ...]]]):
super().__init__(value if isinstance(value, _Ref) else tuple(value))
def __init__(self, value: Union[Iterable[S], _Ref[Tuple[S, ...]]], name: str):
super().__init__(
value=value if isinstance(value, _Ref) else tuple(value), name=name
)

@classmethod
def maybe(
cls: Type[AttrIterableT], value: Optional[Iterable[S]]
cls: Type[AttrIterableT],
value: Optional[Iterable[S]],
name: str,
) -> Optional[AttrIterableT]:
return cls(tuple(value)) if value is not None else None
return cls(tuple(value), name) if value is not None else None

def _to_onnx_deref(self, key: str) -> AttributeProto:
return make_attribute(key, self.value, attr_type=self._attribute_proto_type)
def _to_onnx_deref(self) -> AttributeProto:
return make_attribute(
self._name, self.value, attr_type=self._attribute_proto_type
)


class AttrFloat32s(_AttrIterable[float]):
Expand Down
7 changes: 3 additions & 4 deletions src/spox/_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class Function(_InternalNode):
"""

func_args: Dict[str, Var]
func_attrs: Dict[str, _attributes._Ref]
func_attrs: Dict[str, _attributes.Attr]
func_inputs: BaseInputs
func_outputs: BaseOutputs
func_graph: "_graph.Graph"
Expand All @@ -64,14 +64,13 @@ def infer_output_types(self) -> Dict[str, Type]:
**{name: var.type for name, var in self.inputs.get_vars().items()}
)

func_attrs = {}
self.func_attrs = {}
for name, attr in self.attrs.get_fields().items():
if attr is None:
raise TypeError(
f"Function attributes is not optional, but {name} is None."
)
func_attrs[name] = _attributes._Ref(concrete=attr, outer_name=name)
self.func_attrs = func_attrs
self.func_attrs[name] = attr

self.func_inputs = self.Inputs(**self.func_args) # type: ignore
self.func_outputs = self.constructor(self.func_attrs, self.func_inputs)
Expand Down
14 changes: 10 additions & 4 deletions src/spox/_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,17 +37,23 @@ def arguments_dict(**kwargs: Optional[Union[Type, numpy.ndarray]]) -> Dict[str,
"""
result = {}
for name, info in kwargs.items():
attr_name = AttrString(name)
attr_name = AttrString(value=name, name="dummy")
if isinstance(info, Type):
result[name] = Argument(
Argument.Attributes(name=attr_name, type=AttrType(info), default=None),
Argument.Attributes(
name=attr_name,
type=AttrType(value=info, name="dummy"),
default=None,
),
BaseInputs(),
).outputs.arg
elif isinstance(info, numpy.ndarray):
ty = Tensor(info.dtype, info.shape)
result[name] = Argument(
Argument.Attributes(
name=attr_name, type=AttrType(ty), default=AttrTensor(info)
name=attr_name,
type=AttrType(value=ty, name="dummy"),
default=AttrTensor(value=info, name="dummy"),
),
BaseInputs(),
).outputs.arg
Expand Down Expand Up @@ -101,7 +107,7 @@ def initializer(arr: numpy.ndarray) -> Var:
Var which is always equal to the respective value provided by `arr`.
"""
return _Initializer(
_Initializer.Attributes(value=AttrTensor(arr)),
_Initializer.Attributes(value=AttrTensor(value=arr, name="dummy")),
BaseInputs(),
).outputs.arg

Expand Down
2 changes: 1 addition & 1 deletion src/spox/_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ def to_onnx(
subgraph = build_subgraph(self, key, attr.value)
attr_proto = onnx.helper.make_attribute(key, subgraph)
else:
attr_proto = attr._to_onnx(key)
attr_proto = attr._to_onnx()
node_proto.attribute.append(attr_proto)

return [node_proto]
2 changes: 1 addition & 1 deletion src/spox/_public.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def argument(typ: Type) -> Var:
a model input to build a graph.
"""
return _internal_op.Argument(
_internal_op.Argument.Attributes(type=AttrType(typ), default=None)
_internal_op.Argument.Attributes(type=AttrType(typ, "dummy"), default=None)
).outputs.arg


Expand Down
8 changes: 2 additions & 6 deletions src/spox/_standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,9 @@ def to_singleton_onnx_model(
# We inject the evaluated attribute values here and then substitute back
self_attrs = self.attrs
try:
# Get exact attribute values to run inference (as
# otherwise refs aren't handled properly).
current_fields = self_attrs.get_fields().items()
SimeonStoykovQC marked this conversation as resolved.
Show resolved Hide resolved
self.attrs = self.Attributes(
**{
k: type(v)(v.value) if v is not None else v
for k, v in self.attrs.get_fields().items()
}
**{k: v.deref() if v is not None else None for k, v in current_fields}
)
node_proto: onnx.NodeProto
# Subgraphs are not fully built for possibly significant performance gains.
Expand Down
Loading