Skip to content

Commit

Permalink
Left future
Browse files Browse the repository at this point in the history
  • Loading branch information
neNasko1 committed Dec 2, 2024
1 parent 1a462a8 commit 8c8c05a
Show file tree
Hide file tree
Showing 12 changed files with 112 additions and 55 deletions.
2 changes: 1 addition & 1 deletion src/spox/_adapt.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def adapt_inline(
base_model = node.model
try:
node.model = target_model
target_nodes = node.to_onnx(Scope.of((node, node_name), *var_names.items()))
target_nodes = node.to_onnx(Scope.of((node_name, node), *var_names.items()))
finally:
node.model = base_model
return target_nodes
Expand Down
8 changes: 4 additions & 4 deletions src/spox/_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def value(self) -> T:
return self._value

@value.setter
def value(self, to: T):
def value(self, to: T) -> None:
self._value = to


Expand Down Expand Up @@ -124,7 +124,7 @@ class ScopeTree:
subgraph_owner: dict["Graph", Node]
scope_of: dict[Node, "Graph"]

def __init__(self):
def __init__(self) -> None:
self.subgraph_owner = {}
self.scope_of = {}

Expand Down Expand Up @@ -255,7 +255,7 @@ def discover(self, graph: "Graph") -> tuple[set[Var], set[Var]]:
claimed_arguments = self.claimed_arguments_in[graph] = set()
used_arguments = set()

def collect_arguments(nd: Node):
def collect_arguments(nd: Node) -> None:
nonlocal all_arguments, claimed_arguments, used_arguments
if isinstance(nd, Argument):
all_arguments.add(nd.outputs.arg)
Expand Down Expand Up @@ -329,7 +329,7 @@ def update_scope_tree(self, graph: "Graph") -> None:
is completed 'bottom-up'.
"""

def satisfy_constraints(node):
def satisfy_constraints(node: Node) -> None:
# By default, a node is bound to the scope it is found in.
self.scope_tree.scope_of.setdefault(node, graph)
# Bring up the scope of its node to its ancestors if it is too low to be accessible in the current graph.
Expand Down
40 changes: 29 additions & 11 deletions src/spox/_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@

import inspect
import itertools
from collections.abc import Iterable
from collections.abc import Iterable, Sequence
from dataclasses import dataclass, make_dataclass
from typing import TYPE_CHECKING, Callable, TypeVar
from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union

import numpy as np
import onnx

from . import _attributes
Expand Down Expand Up @@ -48,7 +49,9 @@ class Function(_InternalNode):
func_outputs: BaseOutputs
func_graph: "_graph.Graph"

def constructor(self, attrs, inputs):
def constructor(
self, attrs: dict[str, _attributes.Attr], inputs: BaseInputs
) -> BaseOutputs:
"""
Abstract method for functions.
Expand Down Expand Up @@ -89,11 +92,16 @@ def infer_output_types(self) -> dict[str, Type]:
}

@property
def opset_req(self):
def opset_req(self) -> set[tuple[str, int]]:
node_opset_req = Node.opset_req.fget(self) # type: ignore
return node_opset_req | self.func_graph._get_build_result().opset_req

def update_metadata(self, opset_req, initializers, functions):
def update_metadata(
self,
opset_req: set[tuple[str, int]],
initializers: dict[Var, np.ndarray],
functions: list["Function"],
) -> None:
super().update_metadata(opset_req, initializers, functions)
functions.append(self)
functions.extend(self.func_graph._get_build_result().functions)
Expand Down Expand Up @@ -123,10 +131,18 @@ def to_onnx_function(
)


def _make_function_cls(fun, num_inputs, num_outputs, domain, version, name):
def _make_function_cls(
fun: Callable[..., Any],
num_inputs: int,
num_outputs: int,
domain: str,
version: int,
name: str,
) -> type[Function]:
_FuncInputs = make_dataclass(
"_FuncInputs", ((f"in{i}", Var) for i in range(num_inputs)), bases=(BaseInputs,)
)

_FuncOutputs = make_dataclass(
"_FuncOutputs",
((f"out{i}", Var) for i in range(num_outputs)),
Expand All @@ -142,13 +158,15 @@ class Attributes(BaseAttributes):
Outputs = _FuncOutputs
op_type = OpType(name, domain, version)

def constructor(self, attrs, inputs):
def constructor(self, attrs: dict[str, _attributes.Attr], inputs: Any) -> Any:
return self.Outputs(*fun(*inputs.get_fields().values()))

return _Func


def to_function(name: str, domain: str = "spox.function", *, _version: int = 0):
def to_function(
name: str, domain: str = "spox.function", *, _version: int = 0
) -> Callable:
"""
Decorate a given function to make the operation performed by it add a Spox function to the graph.
Expand Down Expand Up @@ -176,7 +194,7 @@ def get_num_outputs(*args: Var) -> int:
_num_outputs = sum(1 for _ in fun(*args))
return _num_outputs

def init(*args: Var):
def init(*args: Var) -> type[Function]:
nonlocal _cls
if _cls is not None:
return _cls
Expand All @@ -186,9 +204,9 @@ def init(*args: Var):
)
return _cls

def alt_fun(*args: Var) -> Iterable[Var]:
def alt_fun(*args: Var) -> Iterable[Union[Var, Optional[Var], Sequence[Var]]]:
cls = init(*args)
return (
return list(
cls(cls.Attributes(), cls.Inputs(*args)).outputs.get_fields().values()
)

Expand Down
8 changes: 4 additions & 4 deletions src/spox/_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ class Graph:
default_factory=_build.Cached
)

def __repr__(self):
def __repr__(self) -> str:
name_repr = self._name if self._name is not None else "?"
args_repr = (
f"{', '.join(str(a) for a in self._arguments)}"
Expand All @@ -158,7 +158,7 @@ def __repr__(self):
comments.append(f"+{len(self._extra_opset_req)} opset req")
return f"<Graph '{name_repr}' ({args_repr}) -> ({res_repr}){': ' if comments else ''}{', '.join(comments)}>"

def __post_init__(self):
def __post_init__(self) -> None:
if any(not isinstance(var, Var) for var in self._results.values()):
seen_types = {type(obj) for obj in self._results.values()}
raise TypeError(f"Graph results must be Vars, not {seen_types - {Var}}.")
Expand Down Expand Up @@ -362,7 +362,7 @@ def to_onnx_model(
model_doc_string: str = "",
infer_shapes: bool = False,
check_model: Union[Literal[0], Literal[1], Literal[2]] = 1,
ir_version=8,
ir_version: int = 8,
concrete: bool = True,
) -> onnx.ModelProto:
"""
Expand Down Expand Up @@ -448,7 +448,7 @@ def results(**kwargs: Var) -> Graph:
return Graph(kwargs)


def enum_results(*vars: Var, prefix="out") -> Graph:
def enum_results(*vars: Var, prefix: str = "out") -> Graph:
"""
Use this function to construct a ``Graph`` object, whenever the exact names are not important.
Useful when creating subgraphs.
Expand Down
7 changes: 5 additions & 2 deletions src/spox/_inline.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def rename_in_graph(
rename_node: Optional[Callable[[str], str]] = None,
rename_op: Optional[Callable[[str, str], tuple[str, str]]] = None,
) -> onnx.GraphProto:
def rename_in_subgraph(subgraph):
def rename_in_subgraph(subgraph: onnx.GraphProto) -> onnx.GraphProto:
return rename_in_graph(
subgraph,
rename,
Expand Down Expand Up @@ -146,7 +146,10 @@ def propagate_values(self) -> dict[str, _value_prop.PropValueType]:
}

def to_onnx(
self, scope: Scope, doc_string: Optional[str] = None, build_subgraph=None
self,
scope: Scope,
doc_string: Optional[str] = None,
build_subgraph: Optional[Callable] = None,
) -> list[onnx.NodeProto]:
input_names: dict[str, int] = {
p.name: i for i, p in enumerate(self.graph.input)
Expand Down
37 changes: 30 additions & 7 deletions src/spox/_internal_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
from abc import ABC
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Optional
from typing import TYPE_CHECKING, Callable, Optional

import numpy as np
import onnx

from ._attributes import AttrString, AttrTensor, AttrType
Expand All @@ -22,6 +23,9 @@
from ._value_prop import PropValueType
from ._var import Var

if TYPE_CHECKING:
from ._function import Function

# This is a default used for internal operators that
# require the default domain. The most common of these
# is Introduce, which is effectively used in every graph.
Expand Down Expand Up @@ -84,22 +88,30 @@ class Outputs(BaseOutputs):
inputs: Inputs
outputs: Outputs

def post_init(self, **kwargs):
def post_init(self, **kwargs) -> None:
if self.attrs.name is not None:
self.outputs.arg._rename(self.attrs.name.value)

def infer_output_types(self) -> dict[str, Type]:
# Output type is based on the value of the type attribute
return {"arg": self.attrs.type.value}

def update_metadata(self, opset_req, initializers, functions):
def update_metadata(
self,
opset_req: set[tuple[str, int]],
initializers: dict[Var, np.ndarray],
functions: list["Function"],
) -> None:
super().update_metadata(opset_req, initializers, functions)
var = self.outputs.arg
if self.attrs.default is not None:
initializers[var] = self.attrs.default.value

def to_onnx(
self, scope: "Scope", doc_string: Optional[str] = None, build_subgraph=None
self,
scope: "Scope",
doc_string: Optional[str] = None,
build_subgraph: Optional[Callable] = None,
) -> list[onnx.NodeProto]:
return []

Expand Down Expand Up @@ -129,12 +141,20 @@ def infer_output_types(self) -> dict[str, Type]:
def propagate_values(self) -> dict[str, PropValueType]:
return {"arg": self.attrs.value.value}

def update_metadata(self, opset_req, initializers, functions):
def update_metadata(
self,
opset_req: set[tuple[str, int]],
initializers: dict[Var, np.ndarray],
functions: list["Function"],
) -> None:
super().update_metadata(opset_req, initializers, functions)
initializers[self.outputs.arg] = self.attrs.value.value

def to_onnx(
self, scope: "Scope", doc_string: Optional[str] = None, build_subgraph=None
self,
scope: "Scope",
doc_string: Optional[str] = None,
build_subgraph: Optional[Callable] = None,
) -> list[onnx.NodeProto]:
# Initializers are added via update_metadata and don't affect the nodes proto list
return []
Expand Down Expand Up @@ -173,7 +193,10 @@ def opset_req(self) -> set[tuple[str, int]]:
return {("", INTERNAL_MIN_OPSET)}

def to_onnx(
self, scope: Scope, doc_string: Optional[str] = None, build_subgraph=None
self,
scope: Scope,
doc_string: Optional[str] = None,
build_subgraph: Optional[Callable] = None,
) -> list[onnx.NodeProto]:
assert len(self.inputs.inputs) == len(self.outputs.outputs)
# Just create a renaming identity from what we forwarded into our actual output
Expand Down
15 changes: 11 additions & 4 deletions src/spox/_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from dataclasses import dataclass
from typing import ClassVar, Optional, Union

import numpy as np
import onnx

from ._attributes import AttrGraph
Expand All @@ -23,6 +24,7 @@
from ._var import Var

if typing.TYPE_CHECKING:
from ._function import Function
from ._graph import Graph
from ._scope import Scope

Expand Down Expand Up @@ -206,10 +208,10 @@ def get_op_repr(cls) -> str:
domain = cls.op_type.domain if cls.op_type.domain != "" else "ai.onnx"
return f"{domain}@{cls.op_type.version}::{cls.op_type.identifier}"

def pre_init(self, **_) -> None:
def pre_init(self, **kwargs) -> None:
"""Pre-initialization hook. Called during ``__init__`` before any field on the object is set."""

def post_init(self, **_) -> None:
def post_init(self, **kwargs) -> None:
"""Post-initialization hook. Called at the end of ``__init__`` after other default fields are set."""

def propagate_values(self) -> dict[str, PropValueType]:
Expand Down Expand Up @@ -284,7 +286,7 @@ def validate_types(self) -> None:
stacklevel=4,
)

def _check_concrete_type(self, value_type: Type) -> Optional[str]:
def _check_concrete_type(self, value_type: Optional[Type]) -> Optional[str]:
if value_type is None:
return "type is None"
try:
Expand Down Expand Up @@ -346,7 +348,12 @@ def subgraphs(self) -> Iterable["Graph"]:
if isinstance(attr, AttrGraph):
yield attr.value

def update_metadata(self, opset_req, initializers, functions):
def update_metadata(
self,
opset_req: set[tuple[str, int]],
initializers: dict[Var, np.ndarray],
functions: list["Function"],
) -> None:
opset_req.update(self.opset_req)

def to_onnx(
Expand Down
5 changes: 3 additions & 2 deletions src/spox/_public.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import contextlib
import itertools
from collections.abc import Generator
from typing import Optional, Protocol

import numpy as np
Expand Down Expand Up @@ -41,7 +42,7 @@ def argument(typ: Type) -> Var:


@contextlib.contextmanager
def _temporary_renames(**kwargs: Var):
def _temporary_renames(**kwargs: Var) -> Generator[None, None, None]:
# The build code can't really special-case variable names that are
# not just ``Var._name``. So we set names here and reset them
# afterwards.
Expand All @@ -58,7 +59,7 @@ def _temporary_renames(**kwargs: Var):


def build(
inputs: dict[str, Var], outputs: dict[str, Var], *, drop_unused_inputs=False
inputs: dict[str, Var], outputs: dict[str, Var], *, drop_unused_inputs: bool = False
) -> onnx.ModelProto:
"""
Builds an ONNX Model with given model inputs and outputs.
Expand Down
Loading

0 comments on commit 8c8c05a

Please sign in to comment.