From 8c8c05a01db829a5873ed460fef7aa23da788387 Mon Sep 17 00:00:00 2001 From: Atanas Dimitrov Date: Mon, 2 Dec 2024 18:50:47 +0200 Subject: [PATCH] Left future --- src/spox/_adapt.py | 2 +- src/spox/_build.py | 8 ++++---- src/spox/_function.py | 40 +++++++++++++++++++++++++++++----------- src/spox/_graph.py | 8 ++++---- src/spox/_inline.py | 7 +++++-- src/spox/_internal_op.py | 37 ++++++++++++++++++++++++++++++------- src/spox/_node.py | 15 +++++++++++---- src/spox/_public.py | 5 +++-- src/spox/_schemas.py | 19 ++++++++++--------- src/spox/_scope.py | 21 ++++++++++++--------- src/spox/_standard.py | 3 ++- src/spox/_var.py | 2 +- 12 files changed, 112 insertions(+), 55 deletions(-) diff --git a/src/spox/_adapt.py b/src/spox/_adapt.py index bdd3cb28..35aeec19 100644 --- a/src/spox/_adapt.py +++ b/src/spox/_adapt.py @@ -80,7 +80,7 @@ def adapt_inline( base_model = node.model try: node.model = target_model - target_nodes = node.to_onnx(Scope.of((node, node_name), *var_names.items())) + target_nodes = node.to_onnx(Scope.of((node_name, node), *var_names.items())) finally: node.model = base_model return target_nodes diff --git a/src/spox/_build.py b/src/spox/_build.py index 60c544a0..3fe739d4 100644 --- a/src/spox/_build.py +++ b/src/spox/_build.py @@ -44,7 +44,7 @@ def value(self) -> T: return self._value @value.setter - def value(self, to: T): + def value(self, to: T) -> None: self._value = to @@ -124,7 +124,7 @@ class ScopeTree: subgraph_owner: dict["Graph", Node] scope_of: dict[Node, "Graph"] - def __init__(self): + def __init__(self) -> None: self.subgraph_owner = {} self.scope_of = {} @@ -255,7 +255,7 @@ def discover(self, graph: "Graph") -> tuple[set[Var], set[Var]]: claimed_arguments = self.claimed_arguments_in[graph] = set() used_arguments = set() - def collect_arguments(nd: Node): + def collect_arguments(nd: Node) -> None: nonlocal all_arguments, claimed_arguments, used_arguments if isinstance(nd, Argument): all_arguments.add(nd.outputs.arg) @@ -329,7 +329,7 @@ def update_scope_tree(self, graph: "Graph") -> None: is completed 'bottom-up'. """ - def satisfy_constraints(node): + def satisfy_constraints(node: Node) -> None: # By default, a node is bound to the scope it is found in. self.scope_tree.scope_of.setdefault(node, graph) # Bring up the scope of its node to its ancestors if it is too low to be accessible in the current graph. diff --git a/src/spox/_function.py b/src/spox/_function.py index 344d9eec..7d47d531 100644 --- a/src/spox/_function.py +++ b/src/spox/_function.py @@ -3,10 +3,11 @@ import inspect import itertools -from collections.abc import Iterable +from collections.abc import Iterable, Sequence from dataclasses import dataclass, make_dataclass -from typing import TYPE_CHECKING, Callable, TypeVar +from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union +import numpy as np import onnx from . import _attributes @@ -48,7 +49,9 @@ class Function(_InternalNode): func_outputs: BaseOutputs func_graph: "_graph.Graph" - def constructor(self, attrs, inputs): + def constructor( + self, attrs: dict[str, _attributes.Attr], inputs: BaseInputs + ) -> BaseOutputs: """ Abstract method for functions. @@ -89,11 +92,16 @@ def infer_output_types(self) -> dict[str, Type]: } @property - def opset_req(self): + def opset_req(self) -> set[tuple[str, int]]: node_opset_req = Node.opset_req.fget(self) # type: ignore return node_opset_req | self.func_graph._get_build_result().opset_req - def update_metadata(self, opset_req, initializers, functions): + def update_metadata( + self, + opset_req: set[tuple[str, int]], + initializers: dict[Var, np.ndarray], + functions: list["Function"], + ) -> None: super().update_metadata(opset_req, initializers, functions) functions.append(self) functions.extend(self.func_graph._get_build_result().functions) @@ -123,10 +131,18 @@ def to_onnx_function( ) -def _make_function_cls(fun, num_inputs, num_outputs, domain, version, name): +def _make_function_cls( + fun: Callable[..., Any], + num_inputs: int, + num_outputs: int, + domain: str, + version: int, + name: str, +) -> type[Function]: _FuncInputs = make_dataclass( "_FuncInputs", ((f"in{i}", Var) for i in range(num_inputs)), bases=(BaseInputs,) ) + _FuncOutputs = make_dataclass( "_FuncOutputs", ((f"out{i}", Var) for i in range(num_outputs)), @@ -142,13 +158,15 @@ class Attributes(BaseAttributes): Outputs = _FuncOutputs op_type = OpType(name, domain, version) - def constructor(self, attrs, inputs): + def constructor(self, attrs: dict[str, _attributes.Attr], inputs: Any) -> Any: return self.Outputs(*fun(*inputs.get_fields().values())) return _Func -def to_function(name: str, domain: str = "spox.function", *, _version: int = 0): +def to_function( + name: str, domain: str = "spox.function", *, _version: int = 0 +) -> Callable: """ Decorate a given function to make the operation performed by it add a Spox function to the graph. @@ -176,7 +194,7 @@ def get_num_outputs(*args: Var) -> int: _num_outputs = sum(1 for _ in fun(*args)) return _num_outputs - def init(*args: Var): + def init(*args: Var) -> type[Function]: nonlocal _cls if _cls is not None: return _cls @@ -186,9 +204,9 @@ def init(*args: Var): ) return _cls - def alt_fun(*args: Var) -> Iterable[Var]: + def alt_fun(*args: Var) -> Iterable[Union[Var, Optional[Var], Sequence[Var]]]: cls = init(*args) - return ( + return list( cls(cls.Attributes(), cls.Inputs(*args)).outputs.get_fields().values() ) diff --git a/src/spox/_graph.py b/src/spox/_graph.py index 407f24b0..2281215b 100644 --- a/src/spox/_graph.py +++ b/src/spox/_graph.py @@ -143,7 +143,7 @@ class Graph: default_factory=_build.Cached ) - def __repr__(self): + def __repr__(self) -> str: name_repr = self._name if self._name is not None else "?" args_repr = ( f"{', '.join(str(a) for a in self._arguments)}" @@ -158,7 +158,7 @@ def __repr__(self): comments.append(f"+{len(self._extra_opset_req)} opset req") return f" ({res_repr}){': ' if comments else ''}{', '.join(comments)}>" - def __post_init__(self): + def __post_init__(self) -> None: if any(not isinstance(var, Var) for var in self._results.values()): seen_types = {type(obj) for obj in self._results.values()} raise TypeError(f"Graph results must be Vars, not {seen_types - {Var}}.") @@ -362,7 +362,7 @@ def to_onnx_model( model_doc_string: str = "", infer_shapes: bool = False, check_model: Union[Literal[0], Literal[1], Literal[2]] = 1, - ir_version=8, + ir_version: int = 8, concrete: bool = True, ) -> onnx.ModelProto: """ @@ -448,7 +448,7 @@ def results(**kwargs: Var) -> Graph: return Graph(kwargs) -def enum_results(*vars: Var, prefix="out") -> Graph: +def enum_results(*vars: Var, prefix: str = "out") -> Graph: """ Use this function to construct a ``Graph`` object, whenever the exact names are not important. Useful when creating subgraphs. diff --git a/src/spox/_inline.py b/src/spox/_inline.py index 21b74c11..678f9777 100644 --- a/src/spox/_inline.py +++ b/src/spox/_inline.py @@ -26,7 +26,7 @@ def rename_in_graph( rename_node: Optional[Callable[[str], str]] = None, rename_op: Optional[Callable[[str, str], tuple[str, str]]] = None, ) -> onnx.GraphProto: - def rename_in_subgraph(subgraph): + def rename_in_subgraph(subgraph: onnx.GraphProto) -> onnx.GraphProto: return rename_in_graph( subgraph, rename, @@ -146,7 +146,10 @@ def propagate_values(self) -> dict[str, _value_prop.PropValueType]: } def to_onnx( - self, scope: Scope, doc_string: Optional[str] = None, build_subgraph=None + self, + scope: Scope, + doc_string: Optional[str] = None, + build_subgraph: Optional[Callable] = None, ) -> list[onnx.NodeProto]: input_names: dict[str, int] = { p.name: i for i, p in enumerate(self.graph.input) diff --git a/src/spox/_internal_op.py b/src/spox/_internal_op.py index f51f2579..901c694c 100644 --- a/src/spox/_internal_op.py +++ b/src/spox/_internal_op.py @@ -9,8 +9,9 @@ from abc import ABC from collections.abc import Sequence from dataclasses import dataclass -from typing import Optional +from typing import TYPE_CHECKING, Callable, Optional +import numpy as np import onnx from ._attributes import AttrString, AttrTensor, AttrType @@ -22,6 +23,9 @@ from ._value_prop import PropValueType from ._var import Var +if TYPE_CHECKING: + from ._function import Function + # This is a default used for internal operators that # require the default domain. The most common of these # is Introduce, which is effectively used in every graph. @@ -84,7 +88,7 @@ class Outputs(BaseOutputs): inputs: Inputs outputs: Outputs - def post_init(self, **kwargs): + def post_init(self, **kwargs) -> None: if self.attrs.name is not None: self.outputs.arg._rename(self.attrs.name.value) @@ -92,14 +96,22 @@ def infer_output_types(self) -> dict[str, Type]: # Output type is based on the value of the type attribute return {"arg": self.attrs.type.value} - def update_metadata(self, opset_req, initializers, functions): + def update_metadata( + self, + opset_req: set[tuple[str, int]], + initializers: dict[Var, np.ndarray], + functions: list["Function"], + ) -> None: super().update_metadata(opset_req, initializers, functions) var = self.outputs.arg if self.attrs.default is not None: initializers[var] = self.attrs.default.value def to_onnx( - self, scope: "Scope", doc_string: Optional[str] = None, build_subgraph=None + self, + scope: "Scope", + doc_string: Optional[str] = None, + build_subgraph: Optional[Callable] = None, ) -> list[onnx.NodeProto]: return [] @@ -129,12 +141,20 @@ def infer_output_types(self) -> dict[str, Type]: def propagate_values(self) -> dict[str, PropValueType]: return {"arg": self.attrs.value.value} - def update_metadata(self, opset_req, initializers, functions): + def update_metadata( + self, + opset_req: set[tuple[str, int]], + initializers: dict[Var, np.ndarray], + functions: list["Function"], + ) -> None: super().update_metadata(opset_req, initializers, functions) initializers[self.outputs.arg] = self.attrs.value.value def to_onnx( - self, scope: "Scope", doc_string: Optional[str] = None, build_subgraph=None + self, + scope: "Scope", + doc_string: Optional[str] = None, + build_subgraph: Optional[Callable] = None, ) -> list[onnx.NodeProto]: # Initializers are added via update_metadata and don't affect the nodes proto list return [] @@ -173,7 +193,10 @@ def opset_req(self) -> set[tuple[str, int]]: return {("", INTERNAL_MIN_OPSET)} def to_onnx( - self, scope: Scope, doc_string: Optional[str] = None, build_subgraph=None + self, + scope: Scope, + doc_string: Optional[str] = None, + build_subgraph: Optional[Callable] = None, ) -> list[onnx.NodeProto]: assert len(self.inputs.inputs) == len(self.outputs.outputs) # Just create a renaming identity from what we forwarded into our actual output diff --git a/src/spox/_node.py b/src/spox/_node.py index e8cb31ad..e081afad 100644 --- a/src/spox/_node.py +++ b/src/spox/_node.py @@ -12,6 +12,7 @@ from dataclasses import dataclass from typing import ClassVar, Optional, Union +import numpy as np import onnx from ._attributes import AttrGraph @@ -23,6 +24,7 @@ from ._var import Var if typing.TYPE_CHECKING: + from ._function import Function from ._graph import Graph from ._scope import Scope @@ -206,10 +208,10 @@ def get_op_repr(cls) -> str: domain = cls.op_type.domain if cls.op_type.domain != "" else "ai.onnx" return f"{domain}@{cls.op_type.version}::{cls.op_type.identifier}" - def pre_init(self, **_) -> None: + def pre_init(self, **kwargs) -> None: """Pre-initialization hook. Called during ``__init__`` before any field on the object is set.""" - def post_init(self, **_) -> None: + def post_init(self, **kwargs) -> None: """Post-initialization hook. Called at the end of ``__init__`` after other default fields are set.""" def propagate_values(self) -> dict[str, PropValueType]: @@ -284,7 +286,7 @@ def validate_types(self) -> None: stacklevel=4, ) - def _check_concrete_type(self, value_type: Type) -> Optional[str]: + def _check_concrete_type(self, value_type: Optional[Type]) -> Optional[str]: if value_type is None: return "type is None" try: @@ -346,7 +348,12 @@ def subgraphs(self) -> Iterable["Graph"]: if isinstance(attr, AttrGraph): yield attr.value - def update_metadata(self, opset_req, initializers, functions): + def update_metadata( + self, + opset_req: set[tuple[str, int]], + initializers: dict[Var, np.ndarray], + functions: list["Function"], + ) -> None: opset_req.update(self.opset_req) def to_onnx( diff --git a/src/spox/_public.py b/src/spox/_public.py index 101d8d40..173089df 100644 --- a/src/spox/_public.py +++ b/src/spox/_public.py @@ -5,6 +5,7 @@ import contextlib import itertools +from collections.abc import Generator from typing import Optional, Protocol import numpy as np @@ -41,7 +42,7 @@ def argument(typ: Type) -> Var: @contextlib.contextmanager -def _temporary_renames(**kwargs: Var): +def _temporary_renames(**kwargs: Var) -> Generator[None, None, None]: # The build code can't really special-case variable names that are # not just ``Var._name``. So we set names here and reset them # afterwards. @@ -58,7 +59,7 @@ def _temporary_renames(**kwargs: Var): def build( - inputs: dict[str, Var], outputs: dict[str, Var], *, drop_unused_inputs=False + inputs: dict[str, Var], outputs: dict[str, Var], *, drop_unused_inputs: bool = False ) -> onnx.ModelProto: """ Builds an ONNX Model with given model inputs and outputs. diff --git a/src/spox/_schemas.py b/src/spox/_schemas.py index 8bcffa93..7fbb5bf3 100644 --- a/src/spox/_schemas.py +++ b/src/spox/_schemas.py @@ -14,20 +14,21 @@ from onnx.defs import OpSchema, get_all_schemas_with_history - -class _Comparable(Protocol): - def __lt__(self, other) -> bool: ... - - def __gt__(self, other) -> bool: ... - - S = TypeVar("S") K = TypeVar("K") V = TypeVar("V") -T = TypeVar("T", bound=_Comparable) +T = TypeVar("T", bound="_Comparable", contravariant=True) + + +class _Comparable(Protocol[T]): + def __lt__(self: T, other: T) -> bool: ... + + def __gt__(self: T, other: T) -> bool: ... -def _key_groups(seq: Iterable[S], key: Callable[[S], T]): +def _key_groups( + seq: Iterable[S], key: Callable[[S], T] +) -> Iterable[tuple[T, Iterable[S]]]: """Group a sequence by a given key.""" return itertools.groupby(sorted(seq, key=key), key) diff --git a/src/spox/_scope.py b/src/spox/_scope.py index 50b1ad52..33144f62 100644 --- a/src/spox/_scope.py +++ b/src/spox/_scope.py @@ -76,7 +76,7 @@ def __getitem__(self, item: H) -> str: ... @overload def __getitem__(self, item: str) -> H: ... - def __getitem__(self, item: Union[str, H]): + def __getitem__(self, item: Union[str, H]) -> Union[str, H]: """Access the name of an object or an object with a given name in this (or outer) namespace.""" if self.parent is not None and item in self.parent: return self.parent[item] @@ -86,18 +86,18 @@ def __getitem__(self, item: Union[str, H]): return self.name_of[item] @overload - def __setitem__(self, key: str, value: H): ... + def __setitem__(self, key: str, value: H) -> None: ... @overload - def __setitem__(self, key: H, value: str): ... + def __setitem__(self, key: H, value: str) -> None: ... - def __setitem__(self, _key, _value): + def __setitem__(self, _key: Union[str, H], _value: Union[H, str]) -> None: """Set the name of an object in exactly this namespace. Both ``[name] = obj`` and ``[obj] = name`` work.""" if isinstance(_value, str): _key, _value = _value, _key + assert isinstance(_key, str) key: str = _key - value: H = _value - assert isinstance(key, str) + value: H = _value # type: ignore if key in self and self[key] != value: raise ScopeError( f"Failed to name {value}, as its name {key} " @@ -113,7 +113,7 @@ def __setitem__(self, _key, _value): self.of_name[key] = value self.name_of[value] = key - def __delitem__(self, item: Union[str, H]): + def __delitem__(self, item: Union[str, H]) -> None: """Delete a both the name and object from exactly this namespace.""" if isinstance(item, str): key, value = item, self.of_name[item] @@ -173,7 +173,10 @@ def __init__( ) @classmethod - def of(cls, *what): + def of( + cls, + *what: Union[tuple[str, Union[Var, Node]], tuple[Union[Var, Node], str]], + ) -> "Scope": """Convenience constructor for filling a Scope with known names.""" scope = cls() for key, value in what: @@ -188,7 +191,7 @@ def of(cls, *what): raise TypeError(f"Unknown value type for Scope.of: {type(value)}") return scope - def update(self, node: Node, prefix: str = "", force: bool = True): + def update(self, node: Node, prefix: str = "", force: bool = True) -> None: """ Function used for introducing a Node and its outputs into the scope in the build routine. diff --git a/src/spox/_standard.py b/src/spox/_standard.py index ac519875..c32a221b 100644 --- a/src/spox/_standard.py +++ b/src/spox/_standard.py @@ -20,6 +20,7 @@ from ._type_system import Optional, Sequence, Tensor, Type from ._utils import from_array from ._value_prop import PropValueType +from ._var import Var if TYPE_CHECKING: from ._graph import Graph @@ -89,7 +90,7 @@ def to_singleton_onnx_model( ] # Output types with placeholder empty TypeProto (or actual type if not using dummies) - def out_value_info(curr_key, curr_var): + def out_value_info(curr_key: str, curr_var: Var) -> onnx.ValueInfoProto: if dummy_outputs or curr_var.type is None or not curr_var.type._is_concrete: return onnx.helper.make_value_info(curr_key, onnx.TypeProto()) return curr_var.unwrap_type()._to_onnx_value_info(curr_key) diff --git a/src/spox/_var.py b/src/spox/_var.py index e18b2fbb..e1f00b2a 100644 --- a/src/spox/_var.py +++ b/src/spox/_var.py @@ -15,7 +15,7 @@ class NotImplementedOperatorDispatcher: - def _not_impl(self, *_): + def _not_impl(self, *args): return NotImplemented add = sub = mul = truediv = floordiv = neg = and_ = or_ = xor = not_ = _not_impl