Commit: Init
neNasko1 committed Oct 27, 2024
1 parent bb337b7 commit 539ca29
Showing 24 changed files with 2,072 additions and 1,333 deletions.
8 changes: 4 additions & 4 deletions src/spox/_adapt.py
@@ -15,15 +15,15 @@
from ._schemas import SCHEMAS
from ._scope import Scope
from ._utils import from_array
from ._var import Var
from ._var import VarInfo


def adapt_node(
node: Node,
proto: onnx.NodeProto,
source_version: int,
target_version: int,
var_names: dict[Var, str],
var_names: dict[VarInfo, str],
) -> Optional[list[onnx.NodeProto]]:
if source_version == target_version:
return None
@@ -71,7 +71,7 @@ def adapt_inline(
node: _Inline,
protos: list[onnx.NodeProto],
target_opsets: dict[str, int],
var_names: dict[Var, str],
var_names: dict[VarInfo, str],
node_name: str,
) -> list[onnx.NodeProto]:
source_version = max({v for d, v in node.opset_req if d in ("", "ai.onnx")})
@@ -99,7 +99,7 @@ def adapt_best_effort(
node: Node,
protos: list[onnx.NodeProto],
opsets: dict[str, int],
var_names: dict[Var, str],
var_names: dict[VarInfo, str],
node_names: dict[Node, str],
) -> Optional[list[onnx.NodeProto]]:
if isinstance(node, _Inline):
28 changes: 14 additions & 14 deletions src/spox/_build.py
@@ -21,7 +21,7 @@
from ._node import Node
from ._scope import Scope
from ._traverse import iterative_dfs
from ._var import Var
from ._var import VarInfo

if TYPE_CHECKING:
from ._graph import Graph
@@ -58,11 +58,11 @@ class BuildResult:

scope: Scope
nodes: dict[Node, tuple[onnx.NodeProto, ...]]
arguments: tuple[Var, ...]
results: tuple[Var, ...]
arguments: tuple[VarInfo, ...]
results: tuple[VarInfo, ...]
opset_req: set[tuple[str, int]]
functions: tuple["_function.Function", ...]
initializers: dict[Var, np.ndarray]
initializers: dict[VarInfo, np.ndarray]


class Builder:
@@ -93,7 +93,7 @@ class ScopeTree:
"""
Structure representing the tree of scopes, which are identified with the respective graphs.
This structure is the base of the least-enclosing-scope algorithm. Every value (Var), and hence
This structure is the base of the least-enclosing-scope algorithm. Every value (VarInfo), and hence
the responsible Node - up to its (Python object) identity may appear in multiple scopes, but it should
best-cased be computed only once in the ONNX graph, same as in the Python source code.
@@ -164,12 +164,12 @@ def lca(self, a: "Graph", b: "Graph") -> "Graph":
graphs: set["Graph"]
graph_topo: list["Graph"]
# Arguments, results
arguments_of: dict["Graph", list[Var]]
results_of: dict["Graph", list[Var]]
arguments_of: dict["Graph", list[VarInfo]]
results_of: dict["Graph", list[VarInfo]]
source_of: dict["Graph", Node]
# Arguments found by traversal
all_arguments_in: dict["Graph", set[Var]]
claimed_arguments_in: dict["Graph", set[Var]]
all_arguments_in: dict["Graph", set[VarInfo]]
claimed_arguments_in: dict["Graph", set[VarInfo]]
# Scopes
scope_tree: ScopeTree
scope_own: dict["Graph", list[Node]]
@@ -203,8 +203,8 @@ def build_main(self) -> BuildResult:

@staticmethod
def get_intro_results(
request_results: dict[str, Var], set_names: bool
) -> list[Var]:
request_results: dict[str, VarInfo], set_names: bool
) -> list[VarInfo]:
"""
Helper method for wrapping all requested results into a single Introduce and possibly naming them.
@@ -218,7 +218,7 @@ def get_intro_results(
var._rename(key)
return vars

def discover(self, graph: "Graph") -> tuple[set[Var], set[Var]]:
def discover(self, graph: "Graph") -> tuple[set[VarInfo], set[VarInfo]]:
"""
Run the discovery step of the build process. Resolves arguments and results for the involved graphs.
Finds the topological ordering between (sub)graphs and sets their owners (nodes of which they are attributes).
@@ -410,7 +410,7 @@ def compile_graph(
self, graph: "Graph", scope: Scope, prefix: str = ""
) -> BuildResult:
"""
Compile a given Graph into a BuildResult. Handles naming of all the Vars/Nodes and only adds Nodes to a
Compile a given Graph into a BuildResult. Handles naming of all the VarInfos/Nodes and only adds Nodes to a
Graph that should be present in the respective GraphProto. The passed Scope object is aware of values already
available in the outer scope and may be the source of errors if the build fails.
@@ -432,7 +432,7 @@
# A bunch of model metadata we're collecting
opset_req: set[tuple[str, int]] = set()
functions: list[_function.Function] = []
initializers: dict[Var, np.ndarray] = {}
initializers: dict[VarInfo, np.ndarray] = {}

# Add arguments to our scope
for arg in self.arguments_of[graph]:
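The ScopeTree docstring above describes a least-enclosing-scope rule: a value reachable from several (sub)graphs should be emitted once, in the innermost scope that encloses all of its uses. A minimal, generic sketch of that idea, assuming a plain parent mapping between scopes; it is illustrative only and not the implementation in _build.py:

# Illustrative sketch of the least-enclosing-scope idea (not spox's implementation).
# Scopes form a tree; a value used in scopes a and b belongs in their
# lowest common ancestor, so it is built once and is visible to both uses.
def lowest_common_ancestor(parent, a, b):
    # parent maps each scope to its enclosing scope; the root maps to None.
    ancestors_of_a = set()
    while a is not None:
        ancestors_of_a.add(a)
        a = parent.get(a)
    while b is not None and b not in ancestors_of_a:
        b = parent.get(b)
    return b  # innermost scope enclosing both, or None if they share no root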
4 changes: 2 additions & 2 deletions src/spox/_debug.py
@@ -4,7 +4,7 @@
import sys
from contextlib import contextmanager

from spox._var import Var
from spox._var import VarInfo

# If `STORE_TRACEBACK` is `True` any node created will store a traceback for its point of creation.
STORE_TRACEBACK = False
@@ -36,7 +36,7 @@ def show_construction_tracebacks(debug_index):
if -1 in found:
del found[-1]
for name, obj in reversed(found.values()):
if isinstance(obj, Var):
if isinstance(obj, VarInfo):
if not obj:
continue
node = obj._op
2 changes: 1 addition & 1 deletion src/spox/_exceptions.py
@@ -8,7 +8,7 @@


class InferenceWarning(Warning):
"""Warning related to partial typing of Variables.
"""Warning related to partial typing of VarInfoiables.
Incomplete type information may lead to reduced code safety or
failure to build the model. The most common underlying cause for
38 changes: 19 additions & 19 deletions src/spox/_fields.py
@@ -8,7 +8,7 @@
from typing import Any, Optional, Union

from ._attributes import Attr
from ._var import Var
from ._var import VarInfo


@dataclass
@@ -32,19 +32,19 @@ class VarFieldKind(enum.Enum):


@dataclass
class BaseVars(BaseFields):
class BaseVarInfos(BaseFields):
def __post_init__(self):
# Check if passed fields are of the appropriate types based on field kinds
for field in dataclasses.fields(self):
value = getattr(self, field.name)
field_type = self._get_field_type(field)
if field_type == VarFieldKind.SINGLE:
if not isinstance(value, Var):
raise TypeError(f"Field expected Var, got: {type(value)}.")
if not isinstance(value, VarInfo):
raise TypeError(f"Field expected VarInfo, got: {type(value)}.")
elif field_type == VarFieldKind.OPTIONAL:
if value is not None and not isinstance(value, Var):
if value is not None and not isinstance(value, VarInfo):
raise TypeError(
f"Optional must be Var or None, got: {type(value)}."
f"Optional must be VarInfo or None, got: {type(value)}."
)
elif field_type == VarFieldKind.VARIADIC:
if not isinstance(value, Iterable):
@@ -53,43 +53,43 @@ def __post_init__(self):
)
# Cast to tuple to avoid accidental mutation
setattr(self, field.name, tuple(value))
if bad := {type(var) for var in value} - {Var}:
if bad := {type(var) for var in value} - {VarInfo}:
raise TypeError(
f"Variadic field must only consist of Vars, got: {bad}."
f"Variadic field must only consist of VarInfos, got: {bad}."
)

@classmethod
def _get_field_type(cls, field) -> VarFieldKind:
"""Access the kind of the field (single, optional, variadic) based on its type annotation."""
if field.type == Var:
if field.type == VarInfo:
return VarFieldKind.SINGLE
elif field.type == Optional[Var]:
elif field.type == Optional[VarInfo]:
return VarFieldKind.OPTIONAL
elif field.type == Sequence[Var]:
elif field.type == Sequence[VarInfo]:
return VarFieldKind.VARIADIC
raise ValueError(f"Bad field type: '{field.type}'.")

def _flatten(self) -> Iterable[tuple[str, Optional[Var]]]:
def _flatten(self) -> Iterable[tuple[str, Optional[VarInfo]]]:
"""Iterate over the pairs of names and values of fields in this object."""
for key, value in self.__dict__.items():
if value is None or isinstance(value, Var):
if value is None or isinstance(value, VarInfo):
yield key, value
else:
yield from ((f"{key}_{i}", v) for i, v in enumerate(value))

def __iter__(self) -> Iterator[Optional[Var]]:
def __iter__(self) -> Iterator[Optional[VarInfo]]:
"""Iterate over the values of fields in this object."""
yield from (v for _, v in self._flatten())

def __len__(self) -> int:
"""Count the number of fields in this object (should be same as declared in the class)."""
return sum(1 for _ in self)

def get_vars(self) -> dict[str, Var]:
"""Return a flat mapping by name of all the Vars in this object."""
def get_vars(self) -> dict[str, VarInfo]:
"""Return a flat mapping by name of all the VarInfos in this object."""
return {key: var for key, var in self._flatten() if var is not None}

def get_fields(self) -> dict[str, Union[None, Var, Sequence[Var]]]:
def get_fields(self) -> dict[str, Union[None, VarInfo, Sequence[VarInfo]]]:
"""Return a mapping of all fields stored in this object by name."""
return self.__dict__.copy()

@@ -107,10 +107,10 @@ def fully_typed(self) -> bool:


@dataclass
class BaseInputs(BaseVars):
class BaseInputs(BaseVarInfos):
pass


@dataclass
class BaseOutputs(BaseVars):
class BaseOutputs(BaseVarInfos):
pass
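
The _get_field_type classifier above infers each field's kind purely from its dataclass annotation. A hypothetical operator declaration (not part of this diff; ExampleInputs is an assumption for illustration) showing how the three annotations map to VarFieldKind:

# Hypothetical inputs dataclass, illustrating the annotation-to-kind mapping.
from dataclasses import dataclass
from typing import Optional, Sequence

from spox._fields import BaseInputs
from spox._var import VarInfo

@dataclass
class ExampleInputs(BaseInputs):
    data: VarInfo                # VarFieldKind.SINGLE - exactly one value required
    axes: Optional[VarInfo]      # VarFieldKind.OPTIONAL - may be None
    extra: Sequence[VarInfo]     # VarFieldKind.VARIADIC - coerced to a tuple in __post_init__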
20 changes: 11 additions & 9 deletions src/spox/_function.py
@@ -14,14 +14,14 @@
from ._internal_op import _InternalNode
from ._node import Node, OpType
from ._type_system import Type
from ._var import Var
from ._var import VarInfo

if TYPE_CHECKING:
from . import _graph

DEFAULT_FUNCTION_DOMAIN = "spox.default"

ConstructorT = TypeVar("ConstructorT", bound=Callable[..., Iterable[Var]])
ConstructorT = TypeVar("ConstructorT", bound=Callable[..., Iterable[VarInfo]])


class Function(_InternalNode):
@@ -42,7 +42,7 @@ class Function(_InternalNode):
via the ``to_onnx_function`` method.
"""

func_args: dict[str, Var]
func_args: dict[str, VarInfo]
func_attrs: dict[str, _attributes.Attr]
func_inputs: BaseInputs
func_outputs: BaseOutputs
@@ -125,11 +125,13 @@ def to_onnx_function(

def _make_function_cls(fun, num_inputs, num_outputs, domain, version, name):
_FuncInputs = make_dataclass(
"_FuncInputs", ((f"in{i}", Var) for i in range(num_inputs)), bases=(BaseInputs,)
"_FuncInputs",
((f"in{i}", VarInfo) for i in range(num_inputs)),
bases=(BaseInputs,),
)
_FuncOutputs = make_dataclass(
"_FuncOutputs",
((f"out{i}", Var) for i in range(num_outputs)),
((f"out{i}", VarInfo) for i in range(num_outputs)),
bases=(BaseOutputs,),
)

@@ -155,7 +157,7 @@ def to_function(name: str, domain: str = "spox.function", *, _version: int = 0):
The function must be deterministic in the performed operations, as otherwise an error will be raised at build
due to inconsistent function bodies.
``fun`` is assumed to take only Var arguments and return an iterable of them. These will be used to generate the
``fun`` is assumed to take only VarInfo arguments and return an iterable of them. These will be used to generate the
function class signature.
Keep in mind that functions with the same name & domain will be merged together.
@@ -170,13 +172,13 @@ def inner(fun: ConstructorT) -> ConstructorT:
_num_outputs = None
_cls = None

def get_num_outputs(*args: Var) -> int:
def get_num_outputs(*args: VarInfo) -> int:
nonlocal _num_outputs
if _num_outputs is None:
_num_outputs = sum(1 for _ in fun(*args))
return _num_outputs

def init(*args: Var):
def init(*args: VarInfo):
nonlocal _cls
if _cls is not None:
return _cls
@@ -186,7 +188,7 @@ def init(*args: Var):
)
return _cls

def alt_fun(*args: Var) -> Iterable[Var]:
def alt_fun(*args: VarInfo) -> Iterable[VarInfo]:
cls = init(*args)
return (
cls(cls.Attributes(), cls.Inputs(*args)).outputs.get_fields().values()
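The to_function decorator documented above turns a constructor into a reusable ONNX function; the constructor must be deterministic and return an iterable of results. A hedged usage sketch (the v17 opset alias and the mul_add constructor are assumptions, not part of this diff):

# Hypothetical usage of the to_function decorator.
from spox._function import to_function
from spox.opset.ai.onnx import v17 as op  # assumed opset module path

@to_function("MulAdd", "my.domain")
def mul_add(x, a, b):
    # Deterministic body: the same operations are traced on every call.
    return [op.add(op.mul(x, a), b)]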
