Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce _VarInfo internally to reduce memory footprint in value propagation #189

Merged
merged 46 commits into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from 30 commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
539ca29
Init
neNasko1 Oct 27, 2024
b6aa765
Run linter
neNasko1 Oct 29, 2024
1d6bcbc
Changes
neNasko1 Oct 31, 2024
3c8e9b1
Fix ml
neNasko1 Nov 1, 2024
096a4f4
Fix some tests
neNasko1 Nov 4, 2024
e29a920
Fix some tests
neNasko1 Nov 4, 2024
bfaed79
Fix more tests
neNasko1 Nov 4, 2024
49e366c
Add some proper typing
neNasko1 Nov 4, 2024
f5af5d9
More initializers
neNasko1 Nov 5, 2024
27c562e
Fix passing
neNasko1 Nov 5, 2024
50e828b
Minor fixes and linter
neNasko1 Nov 5, 2024
d6b59ca
Change initializers name to input_prop_values
neNasko1 Nov 6, 2024
340d4c2
Make tests passing
neNasko1 Nov 6, 2024
3d77a87
Hacky fix mypy
neNasko1 Nov 6, 2024
9060d12
Correctly codegen
neNasko1 Nov 6, 2024
6aa1bf4
Comments after code review
neNasko1 Nov 15, 2024
07f9676
Improve type checking
neNasko1 Nov 18, 2024
df0bee9
Pre-commit enable
neNasko1 Nov 18, 2024
02e36ac
Update documentation
neNasko1 Nov 18, 2024
9488f6a
Fix adapt node
neNasko1 Nov 19, 2024
c23086b
Fix function inputs passing
neNasko1 Nov 19, 2024
1aebc1a
Fix opset generation
neNasko1 Nov 19, 2024
67d5b3b
Hint that _VarInfo is private
neNasko1 Nov 20, 2024
c9de7ec
Merge branch 'main' into split-value-prop
neNasko1 Nov 22, 2024
3f25d7e
Fix jinja
neNasko1 Nov 22, 2024
a8cebe2
Fix variadic input value propagation
neNasko1 Nov 26, 2024
63be89b
Comments after code review
neNasko1 Nov 28, 2024
0d5e2c8
Move validation to after propagation
neNasko1 Nov 28, 2024
d14e300
Fix diff
neNasko1 Nov 28, 2024
e33b00e
Fix diffs
neNasko1 Nov 28, 2024
b9f922c
Merge branch 'main' into split-value-prop
neNasko1 Dec 4, 2024
fdf81a3
Improve type-hinting information
neNasko1 Dec 4, 2024
cfae394
Remove unneeded functions
neNasko1 Dec 4, 2024
19d7ebb
Init
neNasko1 Dec 6, 2024
b9cb099
fix
neNasko1 Dec 6, 2024
756f274
Final fixes
neNasko1 Dec 9, 2024
ac8807e
Add test for propagation of optional var
neNasko1 Dec 9, 2024
e5c311f
Unify logic around VarInfos -> Var
neNasko1 Dec 9, 2024
7f87559
Add comment
neNasko1 Dec 9, 2024
1b03440
Merge with main
neNasko1 Dec 9, 2024
8e5d25a
Improve qol
neNasko1 Dec 9, 2024
4215495
Update CHANGELOG.rst
neNasko1 Dec 9, 2024
fdb89cf
Fix tools/generate
neNasko1 Dec 9, 2024
07a8f91
Update CHANGELOG.rst
neNasko1 Dec 10, 2024
6360d61
Merge branch 'main' into split-value-prop
neNasko1 Dec 10, 2024
40b8b87
Comments after code-review
neNasko1 Dec 10, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ Change log
0.13.0 (unreleased)
-------------------

**Other changes**

- Split ``Var``-s internally to help the garbage collector collect propagated values.
neNasko1 marked this conversation as resolved.
Show resolved Hide resolved

**Support change**

- Support for ``Python 3.8`` has been dropped.
Expand Down
28 changes: 10 additions & 18 deletions src/spox/_adapt.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import warnings
from typing import Optional

import numpy as np
import onnx
import onnx.version_converter

Expand All @@ -14,16 +13,15 @@
from ._node import Node
from ._schemas import SCHEMAS
from ._scope import Scope
from ._utils import from_array
from ._var import Var
from ._var import _VarInfo


def adapt_node(
node: Node,
proto: onnx.NodeProto,
source_version: int,
target_version: int,
var_names: dict[Var, str],
var_names: dict[_VarInfo, str],
) -> Optional[list[onnx.NodeProto]]:
if source_version == target_version:
return None
Expand All @@ -32,21 +30,16 @@ def adapt_node(
# By using a dictionary we ensure that we only have a single
# ValueInfo per (possibly repeated) input name.
input_info = {
var_names[var]: var.unwrap_type()._to_onnx_value_info(
var_names[var], _traceback_name=f"adapt-input {key}"
var_names[var_info]: var_info.unwrap_type()._to_onnx_value_info(
var_names[var_info], _traceback_name=f"adapt-input {key}"
)
for key, var in node.inputs.get_vars().items()
for key, var_info in node.inputs.get_var_infos().items()
}
output_info = [
var.unwrap_type()._to_onnx_value_info(
var_names[var], _traceback_name=f"adapt-output {key}"
var_info.unwrap_type()._to_onnx_value_info(
var_names[var_info], _traceback_name=f"adapt-output {key}"
)
for key, var in node.outputs.get_vars().items()
]
initializers = [
from_array(var._value, name) # type: ignore
for name, var in node.inputs.get_vars().items()
if isinstance(var._value, np.ndarray)
for key, var_info in node.outputs.get_var_infos().items()
]
except ValueError:
return None
Expand All @@ -57,7 +50,6 @@ def adapt_node(
"spox__singleton_adapter_graph",
list(input_info.values()),
output_info,
initializers,
),
opset_imports=[onnx.helper.make_operatorsetid("", source_version)],
)
Expand All @@ -71,7 +63,7 @@ def adapt_inline(
node: _Inline,
protos: list[onnx.NodeProto],
target_opsets: dict[str, int],
var_names: dict[Var, str],
var_names: dict[_VarInfo, str],
node_name: str,
) -> list[onnx.NodeProto]:
source_version = max({v for d, v in node.opset_req if d in ("", "ai.onnx")})
Expand Down Expand Up @@ -99,7 +91,7 @@ def adapt_best_effort(
node: Node,
protos: list[onnx.NodeProto],
opsets: dict[str, int],
var_names: dict[Var, str],
var_names: dict[_VarInfo, str],
node_names: dict[Node, str],
) -> Optional[list[onnx.NodeProto]]:
if isinstance(node, _Inline):
Expand Down
30 changes: 15 additions & 15 deletions src/spox/_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from ._node import Node
from ._scope import Scope
from ._traverse import iterative_dfs
from ._var import Var
from ._var import Var, _VarInfo, unwrap_vars

if TYPE_CHECKING:
from ._graph import Graph
Expand Down Expand Up @@ -58,11 +58,11 @@ class BuildResult:

scope: Scope
nodes: dict[Node, tuple[onnx.NodeProto, ...]]
arguments: tuple[Var, ...]
results: tuple[Var, ...]
arguments: tuple[_VarInfo, ...]
results: tuple[_VarInfo, ...]
opset_req: set[tuple[str, int]]
functions: tuple["_function.Function", ...]
initializers: dict[Var, np.ndarray]
initializers: dict[_VarInfo, np.ndarray]


class Builder:
Expand Down Expand Up @@ -93,7 +93,7 @@ class ScopeTree:
"""
Structure representing the tree of scopes, which are identified with the respective graphs.

This structure is the base of the least-enclosing-scope algorithm. Every value (Var), and hence
This structure is the base of the least-enclosing-scope algorithm. Every value (VarInfo), and hence
the responsible Node - up to its (Python object) identity - may appear in multiple scopes, but it should
ideally be computed only once in the ONNX graph, same as in the Python source code.

Expand Down Expand Up @@ -164,12 +164,12 @@ def lca(self, a: "Graph", b: "Graph") -> "Graph":
graphs: set["Graph"]
graph_topo: list["Graph"]
# Arguments, results
arguments_of: dict["Graph", list[Var]]
results_of: dict["Graph", list[Var]]
arguments_of: dict["Graph", list[_VarInfo]]
results_of: dict["Graph", list[_VarInfo]]
source_of: dict["Graph", Node]
# Arguments found by traversal
all_arguments_in: dict["Graph", set[Var]]
claimed_arguments_in: dict["Graph", set[Var]]
all_arguments_in: dict["Graph", set[_VarInfo]]
claimed_arguments_in: dict["Graph", set[_VarInfo]]
# Scopes
scope_tree: ScopeTree
scope_own: dict["Graph", list[Node]]
Expand Down Expand Up @@ -218,7 +218,7 @@ def get_intro_results(
var._rename(key)
return vars

def discover(self, graph: "Graph") -> tuple[set[Var], set[Var]]:
def discover(self, graph: "Graph") -> tuple[set[_VarInfo], set[_VarInfo]]:
"""
Run the discovery step of the build process. Resolves arguments and results for the involved graphs.
Finds the topological ordering between (sub)graphs and sets their owners (nodes of which they are attributes).
Expand All @@ -244,8 +244,8 @@ def discover(self, graph: "Graph") -> tuple[set[Var], set[Var]]:
# Create and set the source & results of this graph
if not graph.requested_results:
raise BuildError(f"Graph {graph} has no results.")
self.results_of[graph] = self.get_intro_results(
graph.requested_results, graph is self.main
self.results_of[graph] = unwrap_vars(
self.get_intro_results(graph.requested_results, graph is self.main)
)
self.source_of[graph] = self.results_of[graph][0]._op

Expand Down Expand Up @@ -289,8 +289,8 @@ def collect_arguments(nd: Node):
self.arguments_of[graph] = list(all_arguments - claimed_arguments)
else:
# If there is a request, we may not have found it by traversal if an argument was unused.
all_arguments |= set(graph.requested_arguments)
self.arguments_of[graph] = list(graph.requested_arguments)
all_arguments |= set(unwrap_vars(graph.requested_arguments))
self.arguments_of[graph] = unwrap_vars(graph.requested_arguments)

if set(self.arguments_of[graph]) & claimed_arguments:
raise BuildError(
Expand Down Expand Up @@ -432,7 +432,7 @@ def compile_graph(
# A bunch of model metadata we're collecting
opset_req: set[tuple[str, int]] = set()
functions: list[_function.Function] = []
initializers: dict[Var, np.ndarray] = {}
initializers: dict[_VarInfo, np.ndarray] = {}

# Add arguments to our scope
for arg in self.arguments_of[graph]:
Expand Down
4 changes: 2 additions & 2 deletions src/spox/_debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import sys
from contextlib import contextmanager

from spox._var import Var
from spox._var import _VarInfo

# If `STORE_TRACEBACK` is `True` any node created will store a traceback for its point of creation.
STORE_TRACEBACK = False
Expand Down Expand Up @@ -36,7 +36,7 @@ def show_construction_tracebacks(debug_index):
if -1 in found:
del found[-1]
for name, obj in reversed(found.values()):
if isinstance(obj, Var):
if isinstance(obj, _VarInfo):
if not obj:
continue
node = obj._op
Expand Down
Loading