Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve type-hints in function definitions #194

Merged
merged 10 commits into from
Dec 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import subprocess
import sys
from subprocess import CalledProcessError
from typing import cast
from typing import Any, Optional, cast

_mod = importlib.import_module("spox")

Expand Down Expand Up @@ -60,7 +60,7 @@
# Copied and adapted from
# https://github.com/pandas-dev/pandas/blob/4a14d064187367cacab3ff4652a12a0e45d0711b/doc/source/conf.py#L613-L659
# Required configuration function to use sphinx.ext.linkcode
def linkcode_resolve(domain, info):
def linkcode_resolve(domain: str, info: dict[str, Any]) -> Optional[str]:
"""Determine the URL corresponding to a given Python object."""
if domain != "py":
return None
Expand Down
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ namespaces = false
[tool.ruff.lint]
# Enable the isort rules.
extend-select = ["I", "UP"]
ignore = [
"UP007", # https://docs.astral.sh/ruff/rules/non-pep604-annotation/
]

[tool.ruff.lint.isort]
known-first-party = ["spox"]
Expand All @@ -62,6 +65,7 @@ ignore_missing_imports = true
no_implicit_optional = true
check_untyped_defs = true
warn_unused_ignores = true
disallow_untyped_defs = true

[tool.pytest.ini_options]
# This will be pytest's future default.
Expand Down
8 changes: 4 additions & 4 deletions src/spox/_attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def value(self) -> T:
return _deref(self._value)
return self._value

def _validate(self):
def _validate(self) -> None:
try:
type_in_onnx = self._to_onnx().type
except Exception as e:
Expand Down Expand Up @@ -85,7 +85,7 @@ def _to_onnx_deref(self) -> AttributeProto:
"""Conversion method for the dereferenced case."""
raise NotImplementedError()

def _get_pretty_type_exception(self):
def _get_pretty_type_exception(self) -> TypeError:
if isinstance(self.value, tuple) and len(self.value):
types = ", ".join(sorted({type(v).__name__ for v in self.value}))
msg = f"Unable to instantiate `{type(self).__name__}` from items of type(s) `{types}`."
Expand Down Expand Up @@ -177,7 +177,7 @@ class AttrDtype(Attr[npt.DTypeLike]):

_attribute_proto_type = AttributeProto.INT

def _validate(self):
def _validate(self) -> None:
dtype_to_tensor_type(self.value)

def _to_onnx_deref(self) -> AttributeProto:
Expand All @@ -187,7 +187,7 @@ def _to_onnx_deref(self) -> AttributeProto:
class AttrGraph(Attr[Any]):
_attribute_proto_type = AttributeProto.GRAPH

def _validate(self):
def _validate(self) -> None:
from spox._graph import Graph

if not isinstance(self.value, Graph):
Expand Down
48 changes: 25 additions & 23 deletions src/spox/_build.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Copyright (c) QuantCo 2023-2024
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

import itertools
from dataclasses import dataclass
from typing import (
Expand Down Expand Up @@ -44,7 +46,7 @@ def value(self) -> T:
return self._value

@value.setter
def value(self, to: T):
def value(self, to: T) -> None:
self._value = to


Expand All @@ -61,7 +63,7 @@ class BuildResult:
arguments: tuple[Var, ...]
results: tuple[Var, ...]
opset_req: set[tuple[str, int]]
functions: tuple["_function.Function", ...]
functions: tuple[_function.Function, ...]
initializers: dict[Var, np.ndarray]


Expand Down Expand Up @@ -121,14 +123,14 @@ class ScopeTree:
(lowest common ancestor), which is a common operation on trees.
"""

subgraph_owner: dict["Graph", Node]
scope_of: dict[Node, "Graph"]
subgraph_owner: dict[Graph, Node]
scope_of: dict[Node, Graph]

def __init__(self):
def __init__(self) -> None:
self.subgraph_owner = {}
self.scope_of = {}

def parent(self, graph: "Graph") -> "Graph":
def parent(self, graph: Graph) -> Graph:
"""
Return the parent of a scope in the represented scope tree.

Expand All @@ -141,7 +143,7 @@ def parent(self, graph: "Graph") -> "Graph":
else graph
)

def lca(self, a: "Graph", b: "Graph") -> "Graph":
def lca(self, a: Graph, b: Graph) -> Graph:
"""
A simple LCA algorithm without preprocessing that only accesses the parents.

Expand All @@ -160,21 +162,21 @@ def lca(self, a: "Graph", b: "Graph") -> "Graph":
return a

# Graphs needed in the build
main: "Graph"
graphs: set["Graph"]
graph_topo: list["Graph"]
main: Graph
graphs: set[Graph]
graph_topo: list[Graph]
# Arguments, results
arguments_of: dict["Graph", list[Var]]
results_of: dict["Graph", list[Var]]
source_of: dict["Graph", Node]
arguments_of: dict[Graph, list[Var]]
results_of: dict[Graph, list[Var]]
source_of: dict[Graph, Node]
# Arguments found by traversal
all_arguments_in: dict["Graph", set[Var]]
claimed_arguments_in: dict["Graph", set[Var]]
all_arguments_in: dict[Graph, set[Var]]
claimed_arguments_in: dict[Graph, set[Var]]
# Scopes
scope_tree: ScopeTree
scope_own: dict["Graph", list[Node]]
scope_own: dict[Graph, list[Node]]

def __init__(self, main: "Graph"):
def __init__(self, main: Graph):
self.main = main
self.graphs = set()
self.graph_topo = list()
Expand Down Expand Up @@ -218,7 +220,7 @@ def get_intro_results(
var._rename(key)
return vars

def discover(self, graph: "Graph") -> tuple[set[Var], set[Var]]:
def discover(self, graph: Graph) -> tuple[set[Var], set[Var]]:
"""
Run the discovery step of the build process. Resolves arguments and results for the involved graphs.
Finds the topological ordering between (sub)graphs and sets their owners (nodes of which they are attributes).
Expand Down Expand Up @@ -255,7 +257,7 @@ def discover(self, graph: "Graph") -> tuple[set[Var], set[Var]]:
claimed_arguments = self.claimed_arguments_in[graph] = set()
used_arguments = set()

def collect_arguments(nd: Node):
def collect_arguments(nd: Node) -> None:
nonlocal all_arguments, claimed_arguments, used_arguments
if isinstance(nd, Argument):
all_arguments.add(nd.outputs.arg)
Expand Down Expand Up @@ -309,7 +311,7 @@ def collect_arguments(nd: Node):

return all_arguments, claimed_arguments

def update_scope_tree(self, graph: "Graph") -> None:
def update_scope_tree(self, graph: Graph) -> None:
"""
Traverse ``graph`` and update the Builder's scope tree to accommodate the input constraints inside it.

Expand All @@ -329,7 +331,7 @@ def update_scope_tree(self, graph: "Graph") -> None:
is completed 'bottom-up'.
"""

def satisfy_constraints(node):
def satisfy_constraints(node: Node) -> None:
# By default, a node is bound to the scope it is found in.
self.scope_tree.scope_of.setdefault(node, graph)
# Bring up the scope of its node to its ancestors if it is too low to be accessible in the current graph.
Expand Down Expand Up @@ -394,7 +396,7 @@ def get_build_subgraph_callback(
subgraph_opset_req = set() # Keeps track of all opset imports in subgraphs

def build_subgraph(
subgraph_of: Node, key: str, subgraph: "Graph"
subgraph_of: Node, key: str, subgraph: Graph
) -> onnx.GraphProto:
nonlocal subgraph_opset_req
subgraph_name = scope.node[subgraph_of] + f"_{key}"
Expand All @@ -407,7 +409,7 @@ def build_subgraph(
return build_subgraph, subgraph_opset_req

def compile_graph(
self, graph: "Graph", scope: Scope, prefix: str = ""
self, graph: Graph, scope: Scope, prefix: str = ""
) -> BuildResult:
"""
Compile a given Graph into a BuildResult. Handles naming of all the Vars/Nodes and only adds Nodes to a
Expand Down
6 changes: 5 additions & 1 deletion src/spox/_debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
# SPDX-License-Identifier: BSD-3-Clause

import sys
from collections.abc import Iterator
from contextlib import contextmanager
from typing import Any

from spox._var import Var

Expand All @@ -11,7 +13,9 @@


@contextmanager
def show_construction_tracebacks(debug_index):
def show_construction_tracebacks(
debug_index: dict[str, Any],
) -> Iterator[None]:
"""
Context manager constructed with a ``Builder.build_result.debug_index``.

Expand Down
17 changes: 10 additions & 7 deletions src/spox/_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import dataclasses
import enum
from collections.abc import Iterable, Iterator, Sequence
from dataclasses import dataclass
from typing import Any, Optional, Union
from dataclasses import Field, dataclass
from typing import Any, Optional, Union, get_type_hints

from ._attributes import Attr
from ._var import Var
Expand Down Expand Up @@ -33,7 +33,7 @@ class VarFieldKind(enum.Enum):

@dataclass
class BaseVars(BaseFields):
def __post_init__(self):
def __post_init__(self) -> None:
# Check if passed fields are of the appropriate types based on field kinds
for field in dataclasses.fields(self):
value = getattr(self, field.name)
Expand All @@ -59,13 +59,16 @@ def __post_init__(self):
)

@classmethod
def _get_field_type(cls, field) -> VarFieldKind:
def _get_field_type(cls, field: Field) -> VarFieldKind:
"""Access the kind of the field (single, optional, variadic) based on its type annotation."""
if field.type == Var:
# The field.type may be unannotated as per
# from __future__ import annotations
field_type = get_type_hints(cls)[field.name]
if field_type == Var:
return VarFieldKind.SINGLE
elif field.type == Optional[Var]:
elif field_type == Optional[Var]:
return VarFieldKind.OPTIONAL
elif field.type == Sequence[Var]:
elif field_type == Sequence[Var]:
return VarFieldKind.VARIADIC
raise ValueError(f"Bad field type: '{field.type}'.")

Expand Down
44 changes: 32 additions & 12 deletions src/spox/_function.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
# Copyright (c) QuantCo 2023-2024
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

import inspect
import itertools
from collections.abc import Iterable
from collections.abc import Iterable, Sequence
from dataclasses import dataclass, make_dataclass
from typing import TYPE_CHECKING, Callable, TypeVar
from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union

import numpy as np
import onnx

from . import _attributes
Expand Down Expand Up @@ -46,9 +49,11 @@ class Function(_InternalNode):
func_attrs: dict[str, _attributes.Attr]
func_inputs: BaseInputs
func_outputs: BaseOutputs
func_graph: "_graph.Graph"
func_graph: _graph.Graph

def constructor(self, attrs, inputs):
def constructor(
self, attrs: dict[str, _attributes.Attr], inputs: BaseInputs
) -> BaseOutputs:
"""
Abstract method for functions.

Expand Down Expand Up @@ -89,11 +94,16 @@ def infer_output_types(self) -> dict[str, Type]:
}

@property
def opset_req(self):
def opset_req(self) -> set[tuple[str, int]]:
node_opset_req = Node.opset_req.fget(self) # type: ignore
return node_opset_req | self.func_graph._get_build_result().opset_req

def update_metadata(self, opset_req, initializers, functions):
def update_metadata(
self,
opset_req: set[tuple[str, int]],
initializers: dict[Var, np.ndarray],
functions: list[Function],
) -> None:
super().update_metadata(opset_req, initializers, functions)
functions.append(self)
functions.extend(self.func_graph._get_build_result().functions)
Expand Down Expand Up @@ -123,10 +133,18 @@ def to_onnx_function(
)


def _make_function_cls(fun, num_inputs, num_outputs, domain, version, name):
def _make_function_cls(
fun: Callable[..., Any],
num_inputs: int,
num_outputs: int,
domain: str,
version: int,
name: str,
) -> type[Function]:
_FuncInputs = make_dataclass(
"_FuncInputs", ((f"in{i}", Var) for i in range(num_inputs)), bases=(BaseInputs,)
)

_FuncOutputs = make_dataclass(
"_FuncOutputs",
((f"out{i}", Var) for i in range(num_outputs)),
Expand All @@ -142,13 +160,15 @@ class Attributes(BaseAttributes):
Outputs = _FuncOutputs
op_type = OpType(name, domain, version)

def constructor(self, attrs, inputs):
def constructor(self, attrs: dict[str, _attributes.Attr], inputs: Any) -> Any:
return self.Outputs(*fun(*inputs.get_fields().values()))

return _Func


def to_function(name: str, domain: str = "spox.function", *, _version: int = 0):
def to_function(
name: str, domain: str = "spox.function", *, _version: int = 0
) -> Callable:
"""
Decorate a given function to make the operation performed by it add a Spox function to the graph.

Expand Down Expand Up @@ -176,7 +196,7 @@ def get_num_outputs(*args: Var) -> int:
_num_outputs = sum(1 for _ in fun(*args))
return _num_outputs

def init(*args: Var):
def init(*args: Var) -> type[Function]:
nonlocal _cls
if _cls is not None:
return _cls
Expand All @@ -186,9 +206,9 @@ def init(*args: Var):
)
return _cls

def alt_fun(*args: Var) -> Iterable[Var]:
def alt_fun(*args: Var) -> Iterable[Union[Var, Optional[Var], Sequence[Var]]]:
cls = init(*args)
return (
return list(
cls(cls.Attributes(), cls.Inputs(*args)).outputs.get_fields().values()
)

Expand Down
Loading
Loading