Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce _VarInfo internally to reduce memory footprint in value propagation #189

Merged
merged 46 commits into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
539ca29
Init
neNasko1 Oct 27, 2024
b6aa765
Run linter
neNasko1 Oct 29, 2024
1d6bcbc
Changes
neNasko1 Oct 31, 2024
3c8e9b1
Fix ml
neNasko1 Nov 1, 2024
096a4f4
Fix some tests
neNasko1 Nov 4, 2024
e29a920
Fix some tests
neNasko1 Nov 4, 2024
bfaed79
Fix more tests
neNasko1 Nov 4, 2024
49e366c
Add some proper typing
neNasko1 Nov 4, 2024
f5af5d9
More initializers
neNasko1 Nov 5, 2024
27c562e
Fix passing
neNasko1 Nov 5, 2024
50e828b
Minor fixes and linter
neNasko1 Nov 5, 2024
d6b59ca
Change initializers name to input_prop_values
neNasko1 Nov 6, 2024
340d4c2
Make tests passing
neNasko1 Nov 6, 2024
3d77a87
Hacky fix mypy
neNasko1 Nov 6, 2024
9060d12
Correctly codegen
neNasko1 Nov 6, 2024
6aa1bf4
Comments after code review
neNasko1 Nov 15, 2024
07f9676
Improve type checking
neNasko1 Nov 18, 2024
df0bee9
Pre-commit enable
neNasko1 Nov 18, 2024
02e36ac
Update documentation
neNasko1 Nov 18, 2024
9488f6a
Fix adapt node
neNasko1 Nov 19, 2024
c23086b
Fix function inputs passing
neNasko1 Nov 19, 2024
1aebc1a
Fix opset generation
neNasko1 Nov 19, 2024
67d5b3b
Hint that _VarInfo is private
neNasko1 Nov 20, 2024
c9de7ec
Merge branch 'main' into split-value-prop
neNasko1 Nov 22, 2024
3f25d7e
Fix jinja
neNasko1 Nov 22, 2024
a8cebe2
Fix variadic input value propagation
neNasko1 Nov 26, 2024
63be89b
Comments after code review
neNasko1 Nov 28, 2024
0d5e2c8
Move validation to after propagation
neNasko1 Nov 28, 2024
d14e300
Fix diff
neNasko1 Nov 28, 2024
e33b00e
Fix diffs
neNasko1 Nov 28, 2024
b9f922c
Merge branch 'main' into split-value-prop
neNasko1 Dec 4, 2024
fdf81a3
Improve type-hinting information
neNasko1 Dec 4, 2024
cfae394
Remove unneeded functions
neNasko1 Dec 4, 2024
19d7ebb
Init
neNasko1 Dec 6, 2024
b9cb099
fix
neNasko1 Dec 6, 2024
756f274
Final fixes
neNasko1 Dec 9, 2024
ac8807e
Add test for propagation of optional var
neNasko1 Dec 9, 2024
e5c311f
Unify logic around VarInfos -> Var
neNasko1 Dec 9, 2024
7f87559
Add comment
neNasko1 Dec 9, 2024
1b03440
Merge with main
neNasko1 Dec 9, 2024
8e5d25a
Improve qol
neNasko1 Dec 9, 2024
4215495
Update CHANGELOG.rst
neNasko1 Dec 9, 2024
fdb89cf
Fix tools/generate
neNasko1 Dec 9, 2024
07a8f91
Update CHANGELOG.rst
neNasko1 Dec 10, 2024
6360d61
Merge branch 'main' into split-value-prop
neNasko1 Dec 10, 2024
40b8b87
Comments after code-review
neNasko1 Dec 10, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@
Change log
==========

0.14.0 (unreleased)
-------------------

**Other changes**

- Propagated values may now be garbage collected if their associated `Var` object goes out of scope.

0.13.0 (2024-12-06)
-------------------

Expand Down
20 changes: 10 additions & 10 deletions src/spox/_adapt.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@
from ._node import Node
from ._schemas import SCHEMAS
from ._scope import Scope
from ._var import Var
from ._var import _VarInfo


def adapt_node(
node: Node,
proto: onnx.NodeProto,
source_version: int,
target_version: int,
var_names: dict[Var, str],
var_names: dict[_VarInfo, str],
) -> Optional[list[onnx.NodeProto]]:
if source_version == target_version:
return None
Expand All @@ -30,16 +30,16 @@ def adapt_node(
# By using a dictionary we ensure that we only have a single
# ValueInfo per (possibly repeated) input name.
input_info = {
var_names[var]: var.unwrap_type()._to_onnx_value_info(
var_names[var], _traceback_name=f"adapt-input {key}"
var_names[var_info]: var_info.unwrap_type()._to_onnx_value_info(
var_names[var_info], _traceback_name=f"adapt-input {key}"
)
for key, var in node.inputs.get_vars().items()
for key, var_info in node.inputs.get_var_infos().items()
}
output_info = [
var.unwrap_type()._to_onnx_value_info(
var_names[var], _traceback_name=f"adapt-output {key}"
var_info.unwrap_type()._to_onnx_value_info(
var_names[var_info], _traceback_name=f"adapt-output {key}"
)
for key, var in node.outputs.get_vars().items()
for key, var_info in node.outputs.get_var_infos().items()
]
except ValueError:
return None
Expand All @@ -63,7 +63,7 @@ def adapt_inline(
node: _Inline,
protos: list[onnx.NodeProto],
target_opsets: dict[str, int],
var_names: dict[Var, str],
var_names: dict[_VarInfo, str],
node_name: str,
) -> list[onnx.NodeProto]:
source_version = max({v for d, v in node.opset_req if d in ("", "ai.onnx")})
Expand Down Expand Up @@ -91,7 +91,7 @@ def adapt_best_effort(
node: Node,
protos: list[onnx.NodeProto],
opsets: dict[str, int],
var_names: dict[Var, str],
var_names: dict[_VarInfo, str],
node_names: dict[Node, str],
) -> Optional[list[onnx.NodeProto]]:
if isinstance(node, _Inline):
Expand Down
30 changes: 15 additions & 15 deletions src/spox/_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from ._node import Node
from ._scope import Scope
from ._traverse import iterative_dfs
from ._var import Var
from ._var import Var, _VarInfo, unwrap_vars

if TYPE_CHECKING:
from ._graph import Graph
Expand Down Expand Up @@ -60,11 +60,11 @@ class BuildResult:

scope: Scope
nodes: dict[Node, tuple[onnx.NodeProto, ...]]
arguments: tuple[Var, ...]
results: tuple[Var, ...]
arguments: tuple[_VarInfo, ...]
results: tuple[_VarInfo, ...]
opset_req: set[tuple[str, int]]
functions: tuple[_function.Function, ...]
initializers: dict[Var, np.ndarray]
initializers: dict[_VarInfo, np.ndarray]


class Builder:
Expand Down Expand Up @@ -95,7 +95,7 @@ class ScopeTree:
"""
Structure representing the tree of scopes, which are identified with the respective graphs.

This structure is the base of the least-enclosing-scope algorithm. Every value (Var), and hence
This structure is the base of the least-enclosing-scope algorithm. Every value (VarInfo), and hence
the responsible Node - up to its (Python object) identity may appear in multiple scopes, but it should
best-cased be computed only once in the ONNX graph, same as in the Python source code.

Expand Down Expand Up @@ -166,12 +166,12 @@ def lca(self, a: Graph, b: Graph) -> Graph:
graphs: set[Graph]
graph_topo: list[Graph]
# Arguments, results
arguments_of: dict[Graph, list[Var]]
results_of: dict[Graph, list[Var]]
arguments_of: dict[Graph, list[_VarInfo]]
results_of: dict[Graph, list[_VarInfo]]
source_of: dict[Graph, Node]
# Arguments found by traversal
all_arguments_in: dict[Graph, set[Var]]
claimed_arguments_in: dict[Graph, set[Var]]
all_arguments_in: dict[Graph, set[_VarInfo]]
claimed_arguments_in: dict[Graph, set[_VarInfo]]
# Scopes
scope_tree: ScopeTree
scope_own: dict[Graph, list[Node]]
Expand Down Expand Up @@ -220,7 +220,7 @@ def get_intro_results(
var._rename(key)
return vars

def discover(self, graph: Graph) -> tuple[set[Var], set[Var]]:
def discover(self, graph: Graph) -> tuple[set[_VarInfo], set[_VarInfo]]:
"""
Run the discovery step of the build process. Resolves arguments and results for the involved graphs.
Finds the topological ordering between (sub)graphs and sets their owners (nodes of which they are attributes).
Expand All @@ -246,8 +246,8 @@ def discover(self, graph: Graph) -> tuple[set[Var], set[Var]]:
# Create and set the source & results of this graph
if not graph.requested_results:
raise BuildError(f"Graph {graph} has no results.")
self.results_of[graph] = self.get_intro_results(
graph.requested_results, graph is self.main
self.results_of[graph] = unwrap_vars(
self.get_intro_results(graph.requested_results, graph is self.main)
)
self.source_of[graph] = self.results_of[graph][0]._op

Expand Down Expand Up @@ -291,8 +291,8 @@ def collect_arguments(nd: Node) -> None:
self.arguments_of[graph] = list(all_arguments - claimed_arguments)
else:
# If there is a request, we may not have found it by traversal if an argument was unused.
all_arguments |= set(graph.requested_arguments)
self.arguments_of[graph] = list(graph.requested_arguments)
all_arguments |= set(unwrap_vars(graph.requested_arguments))
self.arguments_of[graph] = unwrap_vars(graph.requested_arguments)

if set(self.arguments_of[graph]) & claimed_arguments:
raise BuildError(
Expand Down Expand Up @@ -434,7 +434,7 @@ def compile_graph(
# A bunch of model metadata we're collecting
opset_req: set[tuple[str, int]] = set()
functions: list[_function.Function] = []
initializers: dict[Var, np.ndarray] = {}
initializers: dict[_VarInfo, np.ndarray] = {}

# Add arguments to our scope
for arg in self.arguments_of[graph]:
Expand Down
4 changes: 2 additions & 2 deletions src/spox/_debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from contextlib import contextmanager
from typing import Any

from spox._var import Var
from spox._var import _VarInfo

# If `STORE_TRACEBACK` is `True` any node created will store a traceback for its point of creation.
STORE_TRACEBACK = False
Expand Down Expand Up @@ -40,7 +40,7 @@ def show_construction_tracebacks(
if -1 in found:
del found[-1]
for name, obj in reversed(found.values()):
if isinstance(obj, Var):
if isinstance(obj, _VarInfo):
if not obj:
continue
node = obj._op
Expand Down
132 changes: 107 additions & 25 deletions src/spox/_fields.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
# Copyright (c) QuantCo 2023-2024
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

import dataclasses
import enum
import warnings
from collections.abc import Iterable, Iterator, Sequence
from dataclasses import Field, dataclass
from typing import Any, Optional, Union, get_type_hints
from typing import Optional, Union, get_type_hints

from ._attributes import Attr
from ._var import Var
from ._exceptions import InferenceWarning
from ._type_system import Optional as tOptional
from ._value_prop import PropDict, PropValue
from ._var import Var, _VarInfo


@dataclass
Expand All @@ -31,20 +37,63 @@ class VarFieldKind(enum.Enum):
VARIADIC = 2


class BaseVars:
    """A collection of `Var`-s used to carry around inputs/outputs of nodes."""

    # Maps field name -> a single Var, an optional Var (may be None),
    # or a variadic sequence of Vars.
    vars: dict[str, Union[Var, Optional[Var], Sequence[Var]]]

    def __init__(self, vars: dict[str, Union[Var, Optional[Var], Sequence[Var]]]):
        self.vars = vars

    def _unpack_to_any(self) -> tuple[Union[Var, Optional[Var], Sequence[Var]], ...]:
        """Unpack the stored fields into a tuple of appropriate length, typed as Any."""
        return tuple(self.vars.values())

    def _flatten(self) -> Iterator[tuple[str, Optional[Var]]]:
        """Iterate over the pairs of names and values of fields in this object.

        Variadic fields are flattened into ``name_0``, ``name_1``, ... entries.
        """
        for key, value in self.vars.items():
            if value is None or isinstance(value, Var):
                yield key, value
            else:
                # Variadic field: emit one entry per element, suffixed by position.
                yield from ((f"{key}_{i}", v) for i, v in enumerate(value))

    def flatten_vars(self) -> dict[str, Var]:
        """Return a flat mapping by name of all the Vars in this object.

        ``None`` (absent optional) entries are dropped.
        """
        return {key: var for key, var in self._flatten() if var is not None}

    def __getattr__(self, attr: str) -> Union[Var, Optional[Var], Sequence[Var]]:
        """Retrieves the attribute if present in the stored variables."""
        try:
            return self.vars[attr]
        except KeyError:
            # Re-raise as AttributeError so attribute protocol users
            # (hasattr, getattr with default) behave as expected.
            raise AttributeError(
                f"{self.__class__.__name__!r} object has no attribute {attr!r}"
            ) from None

    def __setattr__(
        self, attr: str, value: Union[Var, Optional[Var], Sequence[Var]]
    ) -> None:
        """Sets the attribute to a value if the attribute is present in the stored variables."""
        # 'vars' itself is stored normally; everything else is routed into the dict.
        if attr == "vars":
            super().__setattr__(attr, value)
        else:
            self.vars[attr] = value


@dataclass
class BaseVars(BaseFields):
class BaseVarInfos(BaseFields):
neNasko1 marked this conversation as resolved.
Show resolved Hide resolved
def __post_init__(self) -> None:
# Check if passed fields are of the appropriate types based on field kinds
for field in dataclasses.fields(self):
value = getattr(self, field.name)
field_type = self._get_field_type(field)
if field_type == VarFieldKind.SINGLE:
if not isinstance(value, Var):
raise TypeError(f"Field expected Var, got: {type(value)}.")
if not isinstance(value, _VarInfo):
raise TypeError(f"Field expected VarInfo, got: {type(value)}.")
elif field_type == VarFieldKind.OPTIONAL:
if value is not None and not isinstance(value, Var):
if value is not None and not isinstance(value, _VarInfo):
raise TypeError(
f"Optional must be Var or None, got: {type(value)}."
f"Optional must be VarInfo or None, got: {type(value)}."
)
elif field_type == VarFieldKind.VARIADIC:
if not isinstance(value, Iterable):
Expand All @@ -53,9 +102,9 @@ def __post_init__(self) -> None:
)
# Cast to tuple to avoid accidental mutation
setattr(self, field.name, tuple(value))
if bad := {type(var) for var in value} - {Var}:
if bad := {type(var) for var in value} - {_VarInfo}:
raise TypeError(
f"Variadic field must only consist of Vars, got: {bad}."
f"Variadic field must only consist of VarInfos, got: {bad}."
)

@classmethod
Expand All @@ -64,56 +113,89 @@ def _get_field_type(cls, field: Field) -> VarFieldKind:
# The field.type may be unannotated as per
# from __future__ import annotations
field_type = get_type_hints(cls)[field.name]
if field_type == Var:
if field_type == _VarInfo:
return VarFieldKind.SINGLE
elif field_type == Optional[Var]:
elif field_type == Optional[_VarInfo]:
return VarFieldKind.OPTIONAL
elif field_type == Sequence[Var]:
elif field_type == Sequence[_VarInfo]:
return VarFieldKind.VARIADIC
raise ValueError(f"Bad field type: '{field.type}'.")

def _flatten(self) -> Iterable[tuple[str, Optional[Var]]]:
def _flatten(self) -> Iterable[tuple[str, Optional[_VarInfo]]]:
"""Iterate over the pairs of names and values of fields in this object."""
for key, value in self.__dict__.items():
if value is None or isinstance(value, Var):
if value is None or isinstance(value, _VarInfo):
yield key, value
else:
yield from ((f"{key}_{i}", v) for i, v in enumerate(value))

def __iter__(self) -> Iterator[Optional[Var]]:
def __iter__(self) -> Iterator[Optional[_VarInfo]]:
"""Iterate over the values of fields in this object."""
yield from (v for _, v in self._flatten())

def __len__(self) -> int:
"""Count the number of fields in this object (should be same as declared in the class)."""
return sum(1 for _ in self)

def get_vars(self) -> dict[str, Var]:
"""Return a flat mapping by name of all the Vars in this object."""
def get_var_infos(self) -> dict[str, _VarInfo]:
"""Return a flat mapping by name of all the VarInfos in this object."""
return {key: var for key, var in self._flatten() if var is not None}

def get_fields(self) -> dict[str, Union[None, Var, Sequence[Var]]]:
def get_fields(self) -> dict[str, Union[None, _VarInfo, Sequence[_VarInfo]]]:
"""Return a mapping of all fields stored in this object by name."""
return self.__dict__.copy()

def _unpack_to_any(self) -> Any:
"""Unpack the stored fields into a tuple of appropriate length, typed as Any."""
return tuple(self.__dict__.values())

@property
def fully_typed(self) -> bool:
"""Check if all stored variables have a concrete type."""
return all(
var.type is not None and var.type._is_concrete
for var in self.get_vars().values()
for var in self.get_var_infos().values()
)

def into_vars(self, prop_values: PropDict) -> BaseVars:
    """Populate a `BaseVars` object with the propagated values and this object's var_infos.

    Parameters
    ----------
    prop_values
        Mapping from (flattened) field name to a propagated value;
        variadic elements are looked up under ``{name}_{i}`` keys.

    Returns
    -------
    BaseVars
        Fresh `Var` wrappers around this object's `_VarInfo`s, each carrying
        its type-checked propagated value (if one was supplied).
    """

    def _create_var(key: str, var_info: _VarInfo) -> Var:
        # Wrap the _VarInfo in a new Var with no value attached yet.
        ret = Var(var_info, None)

        # No type or no propagated value for this field -> plain Var.
        if var_info.type is None or key not in prop_values:
            return ret

        # A None propagated value is only meaningful for Optional-typed vars.
        if not isinstance(var_info.type, tOptional) and prop_values[key] is None:
            return ret

        prop = PropValue(var_info.type, prop_values[key])
        if prop.check():
            ret._value = prop
        else:
            # Drop values that fail the type check rather than propagate bad data.
            warnings.warn(
                InferenceWarning(
                    f"Propagated value {prop} does not type-check, dropping. "
                    f"Hint: this indicates a bug with the current value prop backend or type inference."
                )
            )

        return ret

    ret_dict: dict[str, Union[Var, Optional[Var], Sequence[Var]]] = {}

    # Fields hold either a single _VarInfo or a variadic sequence of them.
    for key, var_info in self.__dict__.items():
        if isinstance(var_info, _VarInfo):
            ret_dict[key] = _create_var(key, var_info)
        else:
            # Variadic field: propagated values are keyed by position suffix.
            ret_dict[key] = [
                _create_var(f"{key}_{i}", v) for i, v in enumerate(var_info)
            ]

    return BaseVars(ret_dict)


@dataclass
class BaseInputs(BaseVarInfos):
    """Concrete base for a node's input fields (see `BaseVarInfos`)."""

@dataclass
class BaseOutputs(BaseVarInfos):
    """Concrete base for a node's output fields (see `BaseVarInfos`)."""
Loading
Loading