Skip to content

Commit

Permalink
Improve type checking
Browse files Browse the repository at this point in the history
  • Loading branch information
neNasko1 committed Nov 18, 2024
1 parent 6aa1bf4 commit 07f9676
Show file tree
Hide file tree
Showing 21 changed files with 373 additions and 348 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ Change log
0.13.0 (unreleased)
-------------------

**Other changes**

- Split `Var`-s internally to help the garbage collector in collecting propagated values.

**Support change**

- Support for ``Python 3.8`` has been dropped.
Expand Down
2 changes: 1 addition & 1 deletion src/spox/_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ def compile_graph(
self, graph: "Graph", scope: Scope, prefix: str = ""
) -> BuildResult:
"""
Compile a given Graph into a BuildResult. Handles naming of all the VarInfos/Nodes and only adds Nodes to a
Compile a given Graph into a BuildResult. Handles naming of all the Vars/Nodes and only adds Nodes to a
Graph that should be present in the respective GraphProto. The passed Scope object is aware of values already
available in the outer scope and may be the source of errors if the build fails.
Expand Down
24 changes: 15 additions & 9 deletions src/spox/_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
import warnings
from collections.abc import Iterable, Iterator, Sequence
from dataclasses import dataclass
from typing import Any, Optional, Union
from typing import Any, Optional, Union, cast

from ._attributes import Attr
from ._exceptions import InferenceWarning
from ._value_prop import PropValue
from ._value_prop import PropDict, PropValue
from ._var import Var, VarInfo


Expand Down Expand Up @@ -156,27 +156,33 @@ def fully_typed(self) -> bool:

@dataclass
class BaseInputs(BaseVarInfos):
def vars(self, prop_values):
def vars(self, prop_values: Optional[PropDict] = None) -> BaseVars:
if prop_values is None:
prop_values = {}

vars_dict: dict[str, Union[Var, Sequence[Var]]] = {}

for field in dataclasses.fields(self):
field_type = self._get_field_type(field)
field_value = getattr(self, field.name)

if field_type == VarFieldKind.SINGLE:
vars_dict[field.name] = Var(field_value, prop_values[field.name])
prop_value = cast(PropValue, prop_values[field.name])
vars_dict[field.name] = Var(field_value, prop_value)

elif (
field_type == VarFieldKind.OPTIONAL
and prop_values.get(field.name, None) is not None
):
vars_dict[field.name] = Var(field_value, prop_values[field.name])
prop_value = cast(PropValue, prop_values[field.name])
vars_dict[field.name] = Var(field_value, prop_value)

elif field_type == VarFieldKind.VARIADIC:
vars = []

for i, var_info in enumerate(field_value):
var_value = prop_values.get(f"{field.name}_{i}", None)
assert isinstance(var_value, PropValue)
vars.append(Var(var_info, var_value))

vars_dict[field.name] = vars
Expand All @@ -186,10 +192,10 @@ def vars(self, prop_values):

@dataclass
class BaseOutputs(BaseVarInfos):
def _propagate_vars(
self,
prop_values={},
):
def _propagate_vars(self, prop_values: Optional[PropDict] = None) -> BaseVars:
if prop_values is None:
prop_values = {}

def _create_var(key, var_info):
ret = Var(var_info, None)

Expand Down
6 changes: 3 additions & 3 deletions src/spox/_internal_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from ._scope import Scope
from ._shape import SimpleShape
from ._type_system import Tensor, Type
from ._value_prop import PropValueType
from ._value_prop import PropDict, PropValueType
from ._var import Var, VarInfo, unwrap_vars

# This is a default used for internal operators that
Expand Down Expand Up @@ -121,12 +121,12 @@ class Outputs(BaseOutputs):
inputs: BaseInputs
outputs: Outputs

def infer_output_types(self, input_prop_values={}) -> dict[str, Type]:
def infer_output_types(self, input_prop_values: PropDict) -> dict[str, Type]:
# Output type is based on the value of the type attribute
arr = self.attrs.value.value
return {"arg": Tensor(arr.dtype, arr.shape)}

def propagate_values(self, input_prop_values) -> dict[str, PropValueType]:
def propagate_values(self, input_prop_values: PropDict) -> dict[str, PropValueType]:
return {"arg": self.attrs.value.value}

def update_metadata(self, opset_req, initializers, functions):
Expand Down
8 changes: 4 additions & 4 deletions src/spox/_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
from ._debug import STORE_TRACEBACK
from ._exceptions import InferenceWarning
from ._fields import BaseAttributes, BaseInputs, BaseOutputs, VarFieldKind
from ._type_system import PropDict, Type
from ._type_system import Type
from ._value_prop import PropDict
from ._var import VarInfo

if typing.TYPE_CHECKING:
Expand Down Expand Up @@ -94,7 +95,6 @@ def __init__(
out_variadic: Optional[int] = None,
infer_types: bool = True,
validate: bool = True,
input_prop_values: PropDict = {},
**kwargs,
):
"""
Expand Down Expand Up @@ -126,7 +126,7 @@ def __init__(
# As inference functions may access which output vars we initialized (e.g. variadics)
# we inject uninitialized vars first
self.outputs = self._init_output_vars()
self.inference(infer_types=infer_types, input_prop_values={})
self.inference(infer_types=infer_types)
else:
self.outputs = outputs

Expand Down Expand Up @@ -214,7 +214,7 @@ def propagate_values(self, input_prop_values: PropDict) -> PropDict:
"""
return {}

def infer_output_types(self, input_prop_values) -> dict[str, Type]:
def infer_output_types(self, input_prop_values: PropDict) -> dict[str, Type]:
"""
Inference routine for output types. Often overriden by inheriting Node types.
Expand Down
5 changes: 3 additions & 2 deletions src/spox/_public.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from ._inline import _Inline
from ._standard import _strip_dim_symbol
from ._type_system import Type
from ._value_prop import PropDict
from ._var import Var


Expand Down Expand Up @@ -307,9 +308,9 @@ def inline_inner(*args: Var, **kwargs: Var) -> dict[str, Var]:
model=model,
)

prop_values = {
prop_values: PropDict = {
name: kwargs[name]._value
for i, name in enumerate(in_names)
for name in in_names
if kwargs[name]._value is not None
}

Expand Down
6 changes: 3 additions & 3 deletions src/spox/_standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
from ._schemas import SCHEMAS
from ._scope import Scope
from ._shape import SimpleShape
from ._type_system import Optional, PropDict, Sequence, Tensor, Type
from ._type_system import Optional, Sequence, Tensor, Type
from ._utils import from_array
from ._value_prop import PropValue, PropValueType
from ._value_prop import PropDict, PropValue, PropValueType

if TYPE_CHECKING:
from ._graph import Graph
Expand Down Expand Up @@ -209,7 +209,7 @@ def propagate_values_onnx(
}
return {k: v for k, v in results.items() if k is not None}

def infer_output_types(self, input_prop_values={}) -> dict[str, Type]:
def infer_output_types(self, input_prop_values: PropDict) -> dict[str, Type]:
return self.infer_output_types_onnx(input_prop_values)

def propagate_values(self, input_prop_values) -> dict[str, PropValueType]:
Expand Down
5 changes: 1 addition & 4 deletions src/spox/_type_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# SPDX-License-Identifier: BSD-3-Clause

from dataclasses import dataclass
from typing import Any, TypeVar
from typing import TypeVar

import numpy as np
import numpy.typing as npt
Expand All @@ -14,9 +14,6 @@
T = TypeVar("T")
S = TypeVar("S")

# TODO: Fix typing
PropDict = dict[str, Any]


@dataclass(frozen=True)
class Type:
Expand Down
9 changes: 6 additions & 3 deletions src/spox/_value_prop.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
import enum
import logging
import warnings
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Callable, Union
from typing import Optional as tOptional

import numpy as np
import numpy.typing as npt
Expand All @@ -24,9 +26,10 @@
- PropValue -> Optional, Some (has value)
- None -> Optional, Nothing (no value)
"""
PropValueType = Union[np.ndarray, list["PropValue"], "PropValue", None]
ORTValue = Union[np.ndarray, list, None]
RefValue = Union[np.ndarray, list, float, None]
PropValueType = Union[np.ndarray, Iterable[tOptional["PropValue"]], "PropValue", None]
PropDict = dict[str, Union[Iterable[tOptional["PropValue"]], "PropValue", None]]
ORTValue = Union[np.ndarray, Iterable, None]
RefValue = Union[np.ndarray, Iterable, float, None]

VALUE_PROP_STRICT_CHECK: bool = False

Expand Down
Loading

0 comments on commit 07f9676

Please sign in to comment.