Skip to content

Commit

Permalink
Use PEP604 type annotations (#197)
Browse files Browse the repository at this point in the history
  • Loading branch information
neNasko1 authored Dec 16, 2024
1 parent 320d57f commit cc2c2f9
Show file tree
Hide file tree
Showing 20 changed files with 156 additions and 157 deletions.
3 changes: 0 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,6 @@ namespaces = false
[tool.ruff.lint]
# Enable the isort rules.
extend-select = ["I", "UP"]
ignore = [
"UP007", # https://docs.astral.sh/ruff/rules/non-pep604-annotation/
]

[tool.ruff.lint.isort]
known-first-party = ["spox"]
Expand Down
7 changes: 4 additions & 3 deletions src/spox/_adapt.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# Copyright (c) QuantCo 2023-2024
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

import warnings
from typing import Optional

import onnx
import onnx.version_converter
Expand All @@ -22,7 +23,7 @@ def adapt_node(
source_version: int,
target_version: int,
var_names: dict[_VarInfo, str],
) -> Optional[list[onnx.NodeProto]]:
) -> list[onnx.NodeProto] | None:
if source_version == target_version:
return None

Expand Down Expand Up @@ -93,7 +94,7 @@ def adapt_best_effort(
opsets: dict[str, int],
var_names: dict[_VarInfo, str],
node_names: dict[Node, str],
) -> Optional[list[onnx.NodeProto]]:
) -> list[onnx.NodeProto] | None:
if isinstance(node, _Inline):
return adapt_inline(
node,
Expand Down
24 changes: 13 additions & 11 deletions src/spox/_attributes.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
# Copyright (c) QuantCo 2023-2024
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

import abc
from abc import ABC
from collections.abc import Iterable
from typing import Any, Generic, Optional, TypeVar, Union
from typing import Any, Generic, TypeVar

import numpy as np
import numpy.typing as npt
Expand All @@ -28,25 +30,25 @@


class Attr(ABC, Generic[T]):
_value: Union[T, "_Ref[T]"]
_value: T | _Ref[T]
_name: str
_cached_onnx: Optional[AttributeProto]
_cached_onnx: AttributeProto | None

def __init__(self, value: Union[T, "_Ref[T]"], name: str):
def __init__(self, value: T | _Ref[T], name: str):
self._value = value
self._name = name
self._cached_onnx = None

self._validate()

def deref(self) -> "Attr":
def deref(self) -> Attr:
if isinstance(self._value, _Ref):
return type(self)(self.value, self._name)
else:
return self

@classmethod
def maybe(cls: type[AttrT], value: Optional[T], name: str) -> Optional[AttrT]:
def maybe(cls: type[AttrT], value: T | None, name: str) -> AttrT | None:
return cls(value, name) if value is not None else None

@property
Expand Down Expand Up @@ -110,7 +112,7 @@ def __init__(self, concrete: Attr[T], outer_name: str, name: str):
self._outer_name = outer_name
self._name = name

def copy(self) -> "_Ref[T]":
def copy(self) -> _Ref[T]:
return self

def _to_onnx(self) -> AttributeProto:
Expand Down Expand Up @@ -146,7 +148,7 @@ def _to_onnx_deref(self) -> AttributeProto:
class AttrTensor(Attr[np.ndarray]):
_attribute_proto_type = AttributeProto.TENSOR

def __init__(self, value: Union[np.ndarray, _Ref[np.ndarray]], name: str):
def __init__(self, value: np.ndarray | _Ref[np.ndarray], name: str):
super().__init__(value.copy(), name)

def _to_onnx_deref(self) -> AttributeProto:
Expand Down Expand Up @@ -202,17 +204,17 @@ def _to_onnx_deref(self) -> AttributeProto:


class _AttrIterable(Attr[tuple[S, ...]], ABC):
def __init__(self, value: Union[Iterable[S], _Ref[tuple[S, ...]]], name: str):
def __init__(self, value: Iterable[S] | _Ref[tuple[S, ...]], name: str):
super().__init__(
value=value if isinstance(value, _Ref) else tuple(value), name=name
)

@classmethod
def maybe(
cls: type[AttrIterableT],
value: Optional[Iterable[S]],
value: Iterable[S] | None,
name: str,
) -> Optional[AttrIterableT]:
) -> AttrIterableT | None:
return cls(tuple(value), name) if value is not None else None

def _to_onnx_deref(self) -> AttributeProto:
Expand Down
5 changes: 2 additions & 3 deletions src/spox/_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
Any,
Callable,
Generic,
Optional,
TypeVar,
)

Expand All @@ -34,9 +33,9 @@
class Cached(Generic[T]):
"""A generic cached-value type, for which the ``.value`` property raises if it was not previously set."""

_value: Optional[T]
_value: T | None

def __init__(self, value: Optional[T] = None):
def __init__(self, value: T | None = None):
self._value = value

@property
Expand Down
33 changes: 17 additions & 16 deletions src/spox/_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
import warnings
from collections.abc import Iterable, Iterator, Sequence
from dataclasses import Field, dataclass
from typing import Optional, Union, get_type_hints
from typing import Optional, get_type_hints

from . import _type_system
from ._attributes import Attr
from ._exceptions import InferenceWarning
from ._type_system import Optional as tOptional
from ._value_prop import PropDict, PropValue
from ._var import Var, _VarInfo

Expand All @@ -24,7 +24,7 @@ class BaseFields:

@dataclass
class BaseAttributes(BaseFields):
def get_fields(self) -> dict[str, Union[None, Attr]]:
def get_fields(self) -> dict[str, None | Attr]:
"""Return a mapping of all fields stored in this object by name."""
return self.__dict__.copy()

Expand All @@ -40,16 +40,16 @@ class VarFieldKind(enum.Enum):
class BaseVars:
"""A collection of `Var`-s used to carry around inputs/outputs of nodes"""

vars: dict[str, Union[Var, Optional[Var], Sequence[Var]]]
vars: dict[str, Var | None | Sequence[Var]]

def __init__(self, vars: dict[str, Union[Var, Optional[Var], Sequence[Var]]]):
def __init__(self, vars: dict[str, Var | None | Sequence[Var]]):
self.vars = vars

def _unpack_to_any(self) -> tuple[Union[Var, Optional[Var], Sequence[Var]], ...]:
def _unpack_to_any(self) -> tuple[Var | None | Sequence[Var], ...]:
"""Unpack the stored fields into a tuple of appropriate length, typed as Any."""
return tuple(self.vars.values())

def _flatten(self) -> Iterator[tuple[str, Optional[Var]]]:
def _flatten(self) -> Iterator[tuple[str, Var | None]]:
"""Iterate over the pairs of names and values of fields in this object."""
for key, value in self.vars.items():
if value is None or isinstance(value, Var):
Expand All @@ -61,7 +61,7 @@ def flatten_vars(self) -> dict[str, Var]:
"""Return a flat mapping by name of all the VarInfos in this object."""
return {key: var for key, var in self._flatten() if var is not None}

def __getattr__(self, attr: str) -> Union[Var, Optional[Var], Sequence[Var]]:
def __getattr__(self, attr: str) -> Var | None | Sequence[Var]:
"""Retrieves the attribute if present in the stored variables."""
try:
return self.vars[attr]
Expand All @@ -70,9 +70,7 @@ def __getattr__(self, attr: str) -> Union[Var, Optional[Var], Sequence[Var]]:
f"{self.__class__.__name__!r} object has no attribute {attr!r}"
)

def __setattr__(
self, attr: str, value: Union[Var, Optional[Var], Sequence[Var]]
) -> None:
def __setattr__(self, attr: str, value: Var | None | Sequence[Var]) -> None:
"""Sets the attribute to a value if the attribute is present in the stored variables."""
if attr == "vars":
super().__setattr__(attr, value)
Expand Down Expand Up @@ -121,15 +119,15 @@ def _get_field_type(cls, field: Field) -> VarFieldKind:
return VarFieldKind.VARIADIC
raise ValueError(f"Bad field type: '{field.type}'.")

def _flatten(self) -> Iterable[tuple[str, Optional[_VarInfo]]]:
def _flatten(self) -> Iterable[tuple[str, _VarInfo | None]]:
"""Iterate over the pairs of names and values of fields in this object."""
for key, value in self.__dict__.items():
if value is None or isinstance(value, _VarInfo):
yield key, value
else:
yield from ((f"{key}_{i}", v) for i, v in enumerate(value))

def __iter__(self) -> Iterator[Optional[_VarInfo]]:
def __iter__(self) -> Iterator[_VarInfo | None]:
"""Iterate over the values of fields in this object."""
yield from (v for _, v in self._flatten())

Expand All @@ -141,7 +139,7 @@ def get_var_infos(self) -> dict[str, _VarInfo]:
"""Return a flat mapping by name of all the VarInfos in this object."""
return {key: var for key, var in self._flatten() if var is not None}

def get_fields(self) -> dict[str, Union[None, _VarInfo, Sequence[_VarInfo]]]:
def get_fields(self) -> dict[str, None | _VarInfo | Sequence[_VarInfo]]:
"""Return a mapping of all fields stored in this object by name."""
return self.__dict__.copy()

Expand All @@ -162,7 +160,10 @@ def _create_var(key: str, var_info: _VarInfo) -> Var:
if var_info.type is None or key not in prop_values:
return ret

if not isinstance(var_info.type, tOptional) and prop_values[key] is None:
if (
not isinstance(var_info.type, _type_system.Optional)
and prop_values[key] is None
):
return ret

prop = PropValue(var_info.type, prop_values[key])
Expand All @@ -178,7 +179,7 @@ def _create_var(key: str, var_info: _VarInfo) -> Var:

return ret

ret_dict: dict[str, Union[Var, Optional[Var], Sequence[Var]]] = {}
ret_dict: dict[str, Var | None | Sequence[Var]] = {}

for key, var_info in self.__dict__.items():
if isinstance(var_info, _VarInfo):
Expand Down
4 changes: 2 additions & 2 deletions src/spox/_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import itertools
from collections.abc import Iterable, Sequence
from dataclasses import dataclass, make_dataclass
from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union
from typing import TYPE_CHECKING, Any, Callable, TypeVar

import numpy as np
import onnx
Expand Down Expand Up @@ -215,7 +215,7 @@ def init(*args: Var) -> type[Function]:
)
return _cls

def alt_fun(*args: Var) -> Iterable[Union[Var, Optional[Var], Sequence[Var]]]:
def alt_fun(*args: Var) -> Iterable[Var | None | Sequence[Var]]:
cls = init(*args)
return [
Var(var_info)
Expand Down
14 changes: 8 additions & 6 deletions src/spox/_future.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@

"""Module containing experimental Spox features that may be standard in the future."""

from __future__ import annotations

import warnings
from collections.abc import Iterable, Iterator
from contextlib import contextmanager
from types import ModuleType
from typing import Any, Optional, Union
from typing import Any

import numpy as np
import numpy.typing as npt
Expand Down Expand Up @@ -83,13 +85,13 @@ def __init__(
self.constant_promotion = constant_promotion

def _promote(
self, *args: Union[Var, np.generic, int, float], to_floating: bool = False
) -> Iterable[Optional[Var]]:
self, *args: Var | np.generic | int | float, to_floating: bool = False
) -> Iterable[Var | None]:
"""
Apply constant promotion and type promotion to given parameters,
creating constants and/or casting.
"""
targets: list[Union[np.dtype, np.generic, int, float]] = [
targets: list[np.dtype | np.generic | int | float] = [
x.type.dtype if isinstance(x, Var) and isinstance(x.type, Tensor) else x # type: ignore
for x in args
]
Expand Down Expand Up @@ -117,8 +119,8 @@ def _promote(
# TODO: Handle more constant-target inconsistencies here?

def _promote_target(
obj: Union[Var, np.generic, int, float],
) -> Optional[Var]:
obj: Var | np.generic | int | float,
) -> Var | None:
if self.constant_promotion and isinstance(obj, (np.generic, int, float)):
return self.op.const(np.array(obj, dtype=target_type))
elif isinstance(obj, Var):
Expand Down
24 changes: 11 additions & 13 deletions src/spox/_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import itertools
from collections.abc import Iterable
from dataclasses import dataclass, replace
from typing import Callable, Literal, Optional, Union
from typing import Callable, Literal

import numpy as np
import onnx
Expand All @@ -27,7 +27,7 @@
from ._var import Var, _VarInfo


def arguments_dict(**kwargs: Optional[Union[Type, np.ndarray]]) -> dict[str, Var]:
def arguments_dict(**kwargs: Type | np.ndarray | None) -> dict[str, Var]:
"""
Parameters
----------
Expand Down Expand Up @@ -76,14 +76,12 @@ def arguments_dict(**kwargs: Optional[Union[Type, np.ndarray]]) -> dict[str, Var
return result # type: ignore


def arguments(**kwargs: Optional[Union[Type, np.ndarray]]) -> tuple[Var, ...]:
def arguments(**kwargs: Type | np.ndarray | None) -> tuple[Var, ...]:
"""This function is a shorthand for a respective call to ``arguments_dict``, unpacking the Vars from the dict."""
return tuple(arguments_dict(**kwargs).values())


def enum_arguments(
*infos: Union[Type, np.ndarray], prefix: str = "in"
) -> tuple[Var, ...]:
def enum_arguments(*infos: Type | np.ndarray, prefix: str = "in") -> tuple[Var, ...]:
"""
Convenience function for creating an enumeration of arguments, prefixed with ``prefix``.
Calls ``arguments`` internally.
Expand Down Expand Up @@ -148,11 +146,11 @@ class Graph:
"""

_results: dict[str, Var]
_name: Optional[str] = None
_doc_string: Optional[str] = None
_arguments: Optional[tuple[Var, ...]] = None
_extra_opset_req: Optional[set[tuple[str, int]]] = None
_constructor: Optional[Callable[..., Iterable[Var]]] = None
_name: str | None = None
_doc_string: str | None = None
_arguments: tuple[Var, ...] | None = None
_extra_opset_req: set[tuple[str, int]] | None = None
_constructor: Callable[..., Iterable[Var]] | None = None
_build_result: _build.Cached[_build.BuildResult] = dataclasses.field(
default_factory=_build.Cached
)
Expand Down Expand Up @@ -227,7 +225,7 @@ def _inject_build_result(self, what: _build.BuildResult) -> Graph:
return replace(self, _build_result=_build.Cached(what))

@property
def requested_arguments(self) -> Optional[Iterable[Var]]:
def requested_arguments(self) -> Iterable[Var] | None:
"""Arguments requested by this Graph (for building) - ``None`` if unspecified."""
return self._arguments

Expand Down Expand Up @@ -375,7 +373,7 @@ def to_onnx_model(
producer_name: str = "spox",
model_doc_string: str = "",
infer_shapes: bool = False,
check_model: Union[Literal[0], Literal[1], Literal[2]] = 1,
check_model: Literal[0] | Literal[1] | Literal[2] = 1,
ir_version: int = 8,
concrete: bool = True,
) -> onnx.ModelProto:
Expand Down
Loading

0 comments on commit cc2c2f9

Please sign in to comment.