Skip to content

Commit

Permalink
Introduce _VarInfo internally to reduce memory footprint in value propagation (#189)

Browse files Browse the repository at this point in the history

* Init

* Run linter

* Changes

* Fix ml

* Fix some tests

* Fix some tests

* Fix more tests

* Add some proper typing

* More initializers

* Fix passing

* Minor fixes and linter

* Change initializers name to input_prop_values

* Make tests passing

* Hacky fix mypy

* Correctly codegen

* Comments after code review

* Improve type checking

* Pre-commit enable

* Update documentation

* Fix adapt node

* Fix function inputs passing

* Fix opset generation

* Hint that _VarInfo is private

* Fix jinja

* Fix variadic input value propagation

* Comments after code review

* Move validation to after propagation

* Fix diff

* Fix diffs

* Improve type-hinting information

* Remove unneeded functions

* Init

* fix

* Final fixes

* Add test for propagation of optional var

* Unify logic around VarInfos -> Var

* Add comment

* Improve qol

* Update CHANGELOG.rst

* Fix tools/generate

* Update CHANGELOG.rst

Co-authored-by: Christian Bourjau <[email protected]>

* Comments after code-review

---------

Co-authored-by: Christian Bourjau <[email protected]>
  • Loading branch information
neNasko1 and cbourjau authored Dec 10, 2024
1 parent 148a940 commit 320d57f
Show file tree
Hide file tree
Showing 34 changed files with 6,636 additions and 3,807 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@
Change log
==========

0.14.0 (unreleased)
-------------------

**Other changes**

- Propagated values may now be garbage collected if their associated `Var` object goes out of scope.

0.13.0 (2024-12-06)
-------------------

Expand Down
20 changes: 10 additions & 10 deletions src/spox/_adapt.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@
from ._node import Node
from ._schemas import SCHEMAS
from ._scope import Scope
from ._var import Var
from ._var import _VarInfo


def adapt_node(
node: Node,
proto: onnx.NodeProto,
source_version: int,
target_version: int,
var_names: dict[Var, str],
var_names: dict[_VarInfo, str],
) -> Optional[list[onnx.NodeProto]]:
if source_version == target_version:
return None
Expand All @@ -30,16 +30,16 @@ def adapt_node(
# By using a dictionary we ensure that we only have a single
# ValueInfo per (possibly repeated) input name.
input_info = {
var_names[var]: var.unwrap_type()._to_onnx_value_info(
var_names[var], _traceback_name=f"adapt-input {key}"
var_names[var_info]: var_info.unwrap_type()._to_onnx_value_info(
var_names[var_info], _traceback_name=f"adapt-input {key}"
)
for key, var in node.inputs.get_vars().items()
for key, var_info in node.inputs.get_var_infos().items()
}
output_info = [
var.unwrap_type()._to_onnx_value_info(
var_names[var], _traceback_name=f"adapt-output {key}"
var_info.unwrap_type()._to_onnx_value_info(
var_names[var_info], _traceback_name=f"adapt-output {key}"
)
for key, var in node.outputs.get_vars().items()
for key, var_info in node.outputs.get_var_infos().items()
]
except ValueError:
return None
Expand All @@ -63,7 +63,7 @@ def adapt_inline(
node: _Inline,
protos: list[onnx.NodeProto],
target_opsets: dict[str, int],
var_names: dict[Var, str],
var_names: dict[_VarInfo, str],
node_name: str,
) -> list[onnx.NodeProto]:
source_version = max({v for d, v in node.opset_req if d in ("", "ai.onnx")})
Expand Down Expand Up @@ -91,7 +91,7 @@ def adapt_best_effort(
node: Node,
protos: list[onnx.NodeProto],
opsets: dict[str, int],
var_names: dict[Var, str],
var_names: dict[_VarInfo, str],
node_names: dict[Node, str],
) -> Optional[list[onnx.NodeProto]]:
if isinstance(node, _Inline):
Expand Down
30 changes: 15 additions & 15 deletions src/spox/_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from ._node import Node
from ._scope import Scope
from ._traverse import iterative_dfs
from ._var import Var
from ._var import Var, _VarInfo, unwrap_vars

if TYPE_CHECKING:
from ._graph import Graph
Expand Down Expand Up @@ -60,11 +60,11 @@ class BuildResult:

scope: Scope
nodes: dict[Node, tuple[onnx.NodeProto, ...]]
arguments: tuple[Var, ...]
results: tuple[Var, ...]
arguments: tuple[_VarInfo, ...]
results: tuple[_VarInfo, ...]
opset_req: set[tuple[str, int]]
functions: tuple[_function.Function, ...]
initializers: dict[Var, np.ndarray]
initializers: dict[_VarInfo, np.ndarray]


class Builder:
Expand Down Expand Up @@ -95,7 +95,7 @@ class ScopeTree:
"""
Structure representing the tree of scopes, which are identified with the respective graphs.
This structure is the base of the least-enclosing-scope algorithm. Every value (Var), and hence
This structure is the base of the least-enclosing-scope algorithm. Every value (VarInfo), and hence
the responsible Node - up to its (Python object) identity may appear in multiple scopes, but it should
best-cased be computed only once in the ONNX graph, same as in the Python source code.
Expand Down Expand Up @@ -166,12 +166,12 @@ def lca(self, a: Graph, b: Graph) -> Graph:
graphs: set[Graph]
graph_topo: list[Graph]
# Arguments, results
arguments_of: dict[Graph, list[Var]]
results_of: dict[Graph, list[Var]]
arguments_of: dict[Graph, list[_VarInfo]]
results_of: dict[Graph, list[_VarInfo]]
source_of: dict[Graph, Node]
# Arguments found by traversal
all_arguments_in: dict[Graph, set[Var]]
claimed_arguments_in: dict[Graph, set[Var]]
all_arguments_in: dict[Graph, set[_VarInfo]]
claimed_arguments_in: dict[Graph, set[_VarInfo]]
# Scopes
scope_tree: ScopeTree
scope_own: dict[Graph, list[Node]]
Expand Down Expand Up @@ -220,7 +220,7 @@ def get_intro_results(
var._rename(key)
return vars

def discover(self, graph: Graph) -> tuple[set[Var], set[Var]]:
def discover(self, graph: Graph) -> tuple[set[_VarInfo], set[_VarInfo]]:
"""
Run the discovery step of the build process. Resolves arguments and results for the involved graphs.
Finds the topological ordering between (sub)graphs and sets their owners (nodes of which they are attributes).
Expand All @@ -246,8 +246,8 @@ def discover(self, graph: Graph) -> tuple[set[Var], set[Var]]:
# Create and set the source & results of this graph
if not graph.requested_results:
raise BuildError(f"Graph {graph} has no results.")
self.results_of[graph] = self.get_intro_results(
graph.requested_results, graph is self.main
self.results_of[graph] = unwrap_vars(
self.get_intro_results(graph.requested_results, graph is self.main)
)
self.source_of[graph] = self.results_of[graph][0]._op

Expand Down Expand Up @@ -291,8 +291,8 @@ def collect_arguments(nd: Node) -> None:
self.arguments_of[graph] = list(all_arguments - claimed_arguments)
else:
# If there is a request, we may not have found it by traversal if an argument was unused.
all_arguments |= set(graph.requested_arguments)
self.arguments_of[graph] = list(graph.requested_arguments)
all_arguments |= set(unwrap_vars(graph.requested_arguments))
self.arguments_of[graph] = unwrap_vars(graph.requested_arguments)

if set(self.arguments_of[graph]) & claimed_arguments:
raise BuildError(
Expand Down Expand Up @@ -434,7 +434,7 @@ def compile_graph(
# A bunch of model metadata we're collecting
opset_req: set[tuple[str, int]] = set()
functions: list[_function.Function] = []
initializers: dict[Var, np.ndarray] = {}
initializers: dict[_VarInfo, np.ndarray] = {}

# Add arguments to our scope
for arg in self.arguments_of[graph]:
Expand Down
4 changes: 2 additions & 2 deletions src/spox/_debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from contextlib import contextmanager
from typing import Any

from spox._var import Var
from spox._var import _VarInfo

# If `STORE_TRACEBACK` is `True` any node created will store a traceback for its point of creation.
STORE_TRACEBACK = False
Expand Down Expand Up @@ -40,7 +40,7 @@ def show_construction_tracebacks(
if -1 in found:
del found[-1]
for name, obj in reversed(found.values()):
if isinstance(obj, Var):
if isinstance(obj, _VarInfo):
if not obj:
continue
node = obj._op
Expand Down
132 changes: 107 additions & 25 deletions src/spox/_fields.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
# Copyright (c) QuantCo 2023-2024
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

import dataclasses
import enum
import warnings
from collections.abc import Iterable, Iterator, Sequence
from dataclasses import Field, dataclass
from typing import Any, Optional, Union, get_type_hints
from typing import Optional, Union, get_type_hints

from ._attributes import Attr
from ._var import Var
from ._exceptions import InferenceWarning
from ._type_system import Optional as tOptional
from ._value_prop import PropDict, PropValue
from ._var import Var, _VarInfo


@dataclass
Expand All @@ -31,20 +37,63 @@ class VarFieldKind(enum.Enum):
VARIADIC = 2


class BaseVars:
    """A collection of `Var`-s used to carry around inputs/outputs of nodes"""

    # Mapping from field name to its value: a single Var, an optional Var
    # (possibly None), or a sequence of Vars for variadic fields.
    vars: dict[str, Union[Var, Optional[Var], Sequence[Var]]]

    def __init__(self, vars: dict[str, Union[Var, Optional[Var], Sequence[Var]]]):
        self.vars = vars

    def _unpack_to_any(self) -> tuple[Union[Var, Optional[Var], Sequence[Var]], ...]:
        """Unpack the stored fields into a tuple of appropriate length, typed as Any."""
        return tuple(self.vars.values())

    def _flatten(self) -> Iterator[tuple[str, Optional[Var]]]:
        """Iterate over the pairs of names and values of fields in this object.

        Variadic (sequence) fields are expanded element-wise under indexed
        names of the form ``{key}_{i}``.
        """
        for key, value in self.vars.items():
            if value is None or isinstance(value, Var):
                yield key, value
            else:
                yield from ((f"{key}_{i}", v) for i, v in enumerate(value))

    def flatten_vars(self) -> dict[str, Var]:
        """Return a flat mapping by name of all the Vars in this object."""
        return {key: var for key, var in self._flatten() if var is not None}

    def __getattr__(self, attr: str) -> Union[Var, Optional[Var], Sequence[Var]]:
        """Retrieves the attribute if present in the stored variables."""
        try:
            return self.vars[attr]
        except KeyError:
            # Suppress the internal KeyError context - callers expect plain
            # attribute-lookup semantics, not a chained dict failure.
            raise AttributeError(
                f"{self.__class__.__name__!r} object has no attribute {attr!r}"
            ) from None

    def __setattr__(
        self, attr: str, value: Union[Var, Optional[Var], Sequence[Var]]
    ) -> None:
        """Sets the attribute to a value if the attribute is present in the stored variables."""
        if attr == "vars":
            # The backing dict itself is stored as a regular instance attribute.
            super().__setattr__(attr, value)
        else:
            self.vars[attr] = value


@dataclass
class BaseVars(BaseFields):
class BaseVarInfos(BaseFields):
def __post_init__(self) -> None:
# Check if passed fields are of the appropriate types based on field kinds
for field in dataclasses.fields(self):
value = getattr(self, field.name)
field_type = self._get_field_type(field)
if field_type == VarFieldKind.SINGLE:
if not isinstance(value, Var):
raise TypeError(f"Field expected Var, got: {type(value)}.")
if not isinstance(value, _VarInfo):
raise TypeError(f"Field expected VarInfo, got: {type(value)}.")
elif field_type == VarFieldKind.OPTIONAL:
if value is not None and not isinstance(value, Var):
if value is not None and not isinstance(value, _VarInfo):
raise TypeError(
f"Optional must be Var or None, got: {type(value)}."
f"Optional must be VarInfo or None, got: {type(value)}."
)
elif field_type == VarFieldKind.VARIADIC:
if not isinstance(value, Iterable):
Expand All @@ -53,9 +102,9 @@ def __post_init__(self) -> None:
)
# Cast to tuple to avoid accidental mutation
setattr(self, field.name, tuple(value))
if bad := {type(var) for var in value} - {Var}:
if bad := {type(var) for var in value} - {_VarInfo}:
raise TypeError(
f"Variadic field must only consist of Vars, got: {bad}."
f"Variadic field must only consist of VarInfos, got: {bad}."
)

@classmethod
Expand All @@ -64,56 +113,89 @@ def _get_field_type(cls, field: Field) -> VarFieldKind:
# The field.type may be unannotated as per
# from __future__ import annotations
field_type = get_type_hints(cls)[field.name]
if field_type == Var:
if field_type == _VarInfo:
return VarFieldKind.SINGLE
elif field_type == Optional[Var]:
elif field_type == Optional[_VarInfo]:
return VarFieldKind.OPTIONAL
elif field_type == Sequence[Var]:
elif field_type == Sequence[_VarInfo]:
return VarFieldKind.VARIADIC
raise ValueError(f"Bad field type: '{field.type}'.")

def _flatten(self) -> Iterable[tuple[str, Optional[Var]]]:
def _flatten(self) -> Iterable[tuple[str, Optional[_VarInfo]]]:
"""Iterate over the pairs of names and values of fields in this object."""
for key, value in self.__dict__.items():
if value is None or isinstance(value, Var):
if value is None or isinstance(value, _VarInfo):
yield key, value
else:
yield from ((f"{key}_{i}", v) for i, v in enumerate(value))

def __iter__(self) -> Iterator[Optional[Var]]:
def __iter__(self) -> Iterator[Optional[_VarInfo]]:
"""Iterate over the values of fields in this object."""
yield from (v for _, v in self._flatten())

def __len__(self) -> int:
"""Count the number of fields in this object (should be same as declared in the class)."""
return sum(1 for _ in self)

def get_vars(self) -> dict[str, Var]:
"""Return a flat mapping by name of all the Vars in this object."""
def get_var_infos(self) -> dict[str, _VarInfo]:
"""Return a flat mapping by name of all the VarInfos in this object."""
return {key: var for key, var in self._flatten() if var is not None}

def get_fields(self) -> dict[str, Union[None, Var, Sequence[Var]]]:
def get_fields(self) -> dict[str, Union[None, _VarInfo, Sequence[_VarInfo]]]:
"""Return a mapping of all fields stored in this object by name."""
return self.__dict__.copy()

def _unpack_to_any(self) -> Any:
"""Unpack the stored fields into a tuple of appropriate length, typed as Any."""
return tuple(self.__dict__.values())

@property
def fully_typed(self) -> bool:
"""Check if all stored variables have a concrete type."""
return all(
var.type is not None and var.type._is_concrete
for var in self.get_vars().values()
for var in self.get_var_infos().values()
)

    def into_vars(self, prop_values: PropDict) -> BaseVars:
        """Populate a `BaseVars` object with the propagated values and this object's var_infos.

        Parameters
        ----------
        prop_values
            Propagated values keyed by flattened field name (variadic
            elements use ``{key}_{i}`` keys).

        Returns
        -------
        BaseVars
            Public `Var` wrappers for every stored `_VarInfo`, carrying a
            `PropValue` where one is available and type-checks.
        """

        def _create_var(key: str, var_info: _VarInfo) -> Var:
            # Wrap the private VarInfo in a public Var, initially without a value.
            ret = Var(var_info, None)

            # Nothing to attach without a known type or a propagated value
            # for this name.
            if var_info.type is None or key not in prop_values:
                return ret

            # A propagated None is only meaningful for Optional-typed
            # variables; otherwise leave the Var value-less.
            if not isinstance(var_info.type, tOptional) and prop_values[key] is None:
                return ret

            prop = PropValue(var_info.type, prop_values[key])
            if prop.check():
                ret._value = prop
            else:
                # Drop values that fail the type-check instead of storing
                # inconsistent state; warn since this indicates a bug upstream.
                warnings.warn(
                    InferenceWarning(
                        f"Propagated value {prop} does not type-check, dropping. "
                        f"Hint: this indicates a bug with the current value prop backend or type inference."
                    )
                )

            return ret

        ret_dict: dict[str, Union[Var, Optional[Var], Sequence[Var]]] = {}

        # Single fields map directly; variadic fields (stored as tuples) are
        # expanded element-wise under indexed keys matching `prop_values`.
        for key, var_info in self.__dict__.items():
            if isinstance(var_info, _VarInfo):
                ret_dict[key] = _create_var(key, var_info)
            else:
                ret_dict[key] = [
                    _create_var(f"{key}_{i}", v) for i, v in enumerate(var_info)
                ]

        return BaseVars(ret_dict)


@dataclass
class BaseInputs(BaseVars):
class BaseInputs(BaseVarInfos):
pass


@dataclass
class BaseOutputs(BaseVars):
class BaseOutputs(BaseVarInfos):
pass
Loading

0 comments on commit 320d57f

Please sign in to comment.