Commit 1d6bcbc

neNasko1 committed Oct 31, 2024
1 parent b6aa765 commit 1d6bcbc
Showing 10 changed files with 185 additions and 101 deletions.
10 changes: 5 additions & 5 deletions src/spox/_build.py

@@ -21,7 +21,7 @@
 from ._node import Node
 from ._scope import Scope
 from ._traverse import iterative_dfs
-from ._var import VarInfo
+from ._var import Var, VarInfo, unwrap_vars
 
 if TYPE_CHECKING:
     from ._graph import Graph
@@ -203,16 +203,16 @@ def build_main(self) -> BuildResult:
 
     @staticmethod
     def get_intro_results(
-        request_results: dict[str, VarInfo], set_names: bool
-    ) -> list[VarInfo]:
+        request_results: dict[str, Var], set_names: bool
+    ) -> list[Var]:
         """
         Helper method for wrapping all requested results into a single Introduce and possibly naming them.
         By default, only the main graph's results are named (and subgraphs get somewhat autogenerated names),
         as usually only ONNX subgraph input/output ordering is significant.
         """
         # Created vars all have the same op
-        vars = list(intros(*request_results.values()))
+        vars = list(intros(*unwrap_vars(request_results.values())))
         for key, var in zip(request_results, vars):
             if set_names:
                 var._rename(key)
@@ -290,7 +290,7 @@ def collect_arguments(nd: Node):
         else:
            # If there is a request, we may not have found it by traversal if an argument was unused.
            all_arguments |= set(graph.requested_arguments)
-           self.arguments_of[graph] = list(graph.requested_arguments)
+           self.arguments_of[graph] = unwrap_vars(graph.requested_arguments)
 
        if set(self.arguments_of[graph]) & claimed_arguments:
            raise BuildError(
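For orientation (not part of the diff): `unwrap_vars` is new in this import and, judging from its use in `get_intro_results` and `collect_arguments`, is assumed to strip the public `Var` wrappers back down to the internal `VarInfo`s. A minimal sketch of that assumed behavior, with stub classes standing in for the real ones (the `_var_info` attribute name is an assumption):

from typing import Iterable

class VarInfo:
    """Internal handle to a node output (type/metadata only) -- stub."""

class Var:
    """Public wrapper over a VarInfo -- stub; attribute name assumed."""
    def __init__(self, var_info: VarInfo, value=None):
        self._var_info = var_info
        self._value = value

def unwrap_vars(vars_: Iterable[Var]) -> list[VarInfo]:
    # Assumed semantics: Var -> underlying VarInfo, order preserved,
    # so builder internals keep operating on plain VarInfo objects.
    return [var._var_info for var in vars_]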
44 changes: 42 additions & 2 deletions src/spox/_fields.py

@@ -3,12 +3,15 @@
 import dataclasses
 import enum
+import warnings
 from collections.abc import Iterable, Iterator, Sequence
 from dataclasses import dataclass
 from typing import Any, Optional, Union
 
 from ._attributes import Attr
-from ._var import VarInfo
+from ._exceptions import InferenceWarning
+from ._value_prop import PropValue
+from ._var import Var, VarInfo
 
 
 @dataclass
@@ -113,4 +116,41 @@ class BaseInputs(BaseVarInfos):
 
 @dataclass
 class BaseOutputs(BaseVarInfos):
-    pass
+    def _propagate_vars(
+        self,
+        prop_values={},
+        flatten_variadic=False,
+    ):
+        def _create_var(key, var_info):
+            ret = Var(var_info, None)
+
+            if var_info.type is None or key not in prop_values:
+                return ret
+
+            prop = PropValue(var_info.type, prop_values.get(key))
+            if prop.check():
+                ret._value = prop
+            else:
+                warnings.warn(
+                    InferenceWarning(
+                        f"Propagated value {prop} does not type-check, dropping. "
+                        f"Hint: this indicates a bug with the current value prop backend or type inference."
+                    )
+                )
+
+            return ret
+
+        ret_dict = {}
+
+        for key, var_info in self.__dict__.items():
+            if var_info is None or isinstance(var_info, VarInfo):
+                ret_dict[key] = _create_var(key, var_info)
+            elif flatten_variadic:
+                for i, v in enumerate(var_info):
+                    ret_dict[f"{key}_{i}"] = _create_var(f"{key}_{i}", v)
+            else:
+                ret_dict[key] = [
+                    _create_var(f"{key}_{i}", v) for i, v in enumerate(var_info)
+                ]
+
+        return ret_dict
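For illustration only (not part of the commit): a self-contained sketch of the key-flattening rule `_propagate_vars` applies, with stub classes standing in for the spox internals. Scalar fields stay keyed by name; variadic fields become one `Var` per element, flattened to `name_0`, `name_1`, ... when `flatten_variadic=True`.

from dataclasses import dataclass

# Stand-in stubs for the spox internals above -- illustration only.
@dataclass
class VarInfo:
    type: object          # None means "no type, skip value propagation"

class Var:
    def __init__(self, var_info, value=None):
        self._var_info, self._value = var_info, value

# A node's outputs: one plain field and one variadic field.
fields = {
    "output": VarInfo(type="tensor(int64)"),
    "extra": [VarInfo(type=None), VarInfo(type=None)],
}

# Mimic _propagate_vars(flatten_variadic=True):
flat = {}
for key, vi in fields.items():
    if isinstance(vi, VarInfo):
        flat[key] = Var(vi)
    else:                          # variadic: one Var per element
        for i, v in enumerate(vi):
            flat[f"{key}_{i}"] = Var(v)

print(sorted(flat))                # ['extra_0', 'extra_1', 'output']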
4 changes: 2 additions & 2 deletions src/spox/_function.py

@@ -14,7 +14,7 @@
 from ._internal_op import _InternalNode
 from ._node import Node, OpType
 from ._type_system import Type
-from ._var import VarInfo
+from ._var import Var, VarInfo
 
 if TYPE_CHECKING:
     from . import _graph
@@ -42,7 +42,7 @@ class Function(_InternalNode):
     via the ``to_onnx_function`` method.
     """
 
-    func_args: dict[str, VarInfo]
+    func_args: dict[str, Var]
     func_attrs: dict[str, _attributes.Attr]
     func_inputs: BaseInputs
     func_outputs: BaseOutputs
52 changes: 26 additions & 26 deletions src/spox/_future.py

@@ -14,7 +14,7 @@
 import spox._value_prop
 from spox._graph import initializer as _initializer
 from spox._type_system import Tensor
-from spox._var import VarInfo
+from spox._var import Var
 
 TypeWarningLevel = spox._node.TypeWarningLevel
 
@@ -46,7 +46,7 @@ def value_prop_backend(backend: ValuePropBackend):
     set_value_prop_backend(prev_backend)
 
 
-def initializer(value: npt.ArrayLike, dtype: npt.DTypeLike = None) -> VarInfo:
+def initializer(value: npt.ArrayLike, dtype: npt.DTypeLike = None) -> Var:
     """
     Create a VarInfo with a constant value.
@@ -79,14 +79,14 @@ def __init__(self, op, type_promotion: bool, constant_promotion: bool):
         self.constant_promotion = constant_promotion
 
     def _promote(
-        self, *args: Union[VarInfo, np.generic, int, float], to_floating: bool = False
-    ) -> Iterable[Optional[VarInfo]]:
+        self, *args: Union[Var, np.generic, int, float], to_floating: bool = False
+    ) -> Iterable[Optional[Var]]:
         """
         Apply constant promotion and type promotion to given parameters,
         creating constants and/or casting.
         """
         targets: list[Union[np.dtype, np.generic, int, float]] = [
-            x.type.dtype if isinstance(x, VarInfo) and isinstance(x.type, Tensor) else x  # type: ignore
+            x.type.dtype if isinstance(x, Var) and isinstance(x.type, Tensor) else x  # type: ignore
             for x in args
         ]
         if self.type_promotion:
@@ -97,7 +97,7 @@ def _promote(
             dtypes = {dtype for dtype in targets if isinstance(dtype, np.dtype)}
             if len(dtypes) > 1:
                 raise TypeError(
-                    f"Inconsistent types for VarInfo operator with no type promotion: {dtypes}."
+                    f"Inconsistent types for Var operator with no type promotion: {dtypes}."
                 )
             (target_type,) = dtypes
             if issubclass(target_type.type, np.integer):
@@ -113,63 +113,63 @@ def _promote(
         # TODO: Handle more constant-target inconsistencies here?
 
         def _promote_target(
-            obj: Union[VarInfo, np.generic, int, float],
-        ) -> Optional[VarInfo]:
+            obj: Union[Var, np.generic, int, float],
+        ) -> Optional[Var]:
             if self.constant_promotion and isinstance(obj, (np.generic, int, float)):
                 return self.op.const(np.array(obj, dtype=target_type))
-            elif isinstance(obj, VarInfo):
+            elif isinstance(obj, Var):
                 return self.op.cast(obj, to=target_type) if self.type_promotion else obj
             raise TypeError(
-                f"Bad value '{obj!r}' of type {type(obj).__name__!r} for operator overloading with VarInfo. "
+                f"Bad value '{obj!r}' of type {type(obj).__name__!r} for operator overloading with Var. "
                 f"({self.type_promotion=}, {self.constant_promotion=})"
             )
 
         return tuple(var for var in map(_promote_target, args))
 
-    def add(self, a, b) -> VarInfo:
+    def add(self, a, b) -> Var:
         a, b = self._promote(a, b)
         return self.op.add(a, b)
 
-    def sub(self, a, b) -> VarInfo:
+    def sub(self, a, b) -> Var:
         a, b = self._promote(a, b)
         return self.op.sub(a, b)
 
-    def mul(self, a, b) -> VarInfo:
+    def mul(self, a, b) -> Var:
         a, b = self._promote(a, b)
         return self.op.mul(a, b)
 
-    def truediv(self, a, b) -> VarInfo:
+    def truediv(self, a, b) -> Var:
         a, b = self._promote(a, b, to_floating=True)
         return self.op.div(a, b)
 
-    def floordiv(self, a, b) -> VarInfo:
+    def floordiv(self, a, b) -> Var:
         a, b = self._promote(a, b)
         c = self.op.div(a, b)
         if isinstance(c.type, Tensor) and not issubclass(c.type._elem_type, np.integer):
             c = self.op.floor(c)
         return c
 
-    def neg(self, a: VarInfo) -> VarInfo:
+    def neg(self, a: Var) -> Var:
         return self.op.neg(a)
 
-    def and_(self, a: VarInfo, b: VarInfo) -> VarInfo:
+    def and_(self, a: Var, b: Var) -> Var:
         return self.op.and_(a, b)
 
-    def or_(self, a: VarInfo, b: VarInfo) -> VarInfo:
+    def or_(self, a: Var, b: Var) -> Var:
         return self.op.or_(a, b)
 
-    def xor(self, a: VarInfo, b: VarInfo) -> VarInfo:
+    def xor(self, a: Var, b: Var) -> Var:
         return self.op.xor(a, b)
 
-    def not_(self, a: VarInfo) -> VarInfo:
+    def not_(self, a: Var) -> Var:
         return self.op.not_(a)
 
 
 @contextmanager
 def operator_overloading(
     op, type_promotion: bool = False, constant_promotion: bool = True
 ):
-    """Enable operator overloading on VarInfo for this block.
+    """Enable operator overloading on Var for this block.
 
     May be used either as a context manager, or a decorator.
@@ -187,7 +187,7 @@ def operator_overloading(
         if the type was not conclusively floating (as in numpy).
         False by default.
     constant_promotion
-        Whether operator overloading should implicitly promote primitive scalar constants to VarInfo.
+        Whether operator overloading should implicitly promote primitive scalar constants to Var.
         True by default.
 
     Examples
@@ -203,12 +203,12 @@ def operator_overloading(
     ...     return x * y
     >>> assert foo()._get_value() == np.array(6)
     """
-    prev_dispatcher = VarInfo._operator_dispatcher
-    VarInfo._operator_dispatcher = _NumpyLikeOperatorDispatcher(
+    prev_dispatcher = Var._operator_dispatcher
+    Var._operator_dispatcher = _NumpyLikeOperatorDispatcher(
        op, type_promotion, constant_promotion
    )
    yield
-    VarInfo._operator_dispatcher = prev_dispatcher
+    Var._operator_dispatcher = prev_dispatcher
 
 
 __all__ = [
@@ -222,6 +222,6 @@ def operator_overloading(
     "ValuePropBackend",
     "set_value_prop_backend",
     "value_prop_backend",
-    # Operator overloading on VarInfo
+    # Operator overloading on Var
     "operator_overloading",
 ]
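With the dispatcher retargeted from VarInfo to the public Var, operator overloading now acts at the user-facing layer. A hedged usage sketch, separate from the doctest in the diff above -- the opset module path and the `argument`/`Tensor` imports reflect spox's public API as generally documented, not anything this commit shows:

import numpy as np
from spox import Tensor, argument
from spox._future import operator_overloading
import spox.opset.ai.onnx.v17 as op

x = argument(Tensor(np.float32, ("N",)))

# constant_promotion (default True) lifts the Python scalar 2 to a
# float32 constant before dispatching to op.mul via Var's overloads.
with operator_overloading(op):
    y = x * 2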
[6 more changed files not shown]