diff --git a/pyproject.toml b/pyproject.toml index a21f2d4c..5944e975 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,9 +52,6 @@ namespaces = false [tool.ruff.lint] # Enable the isort rules. extend-select = ["I", "UP"] -ignore = [ - "UP007", # https://docs.astral.sh/ruff/rules/non-pep604-annotation/ -] [tool.ruff.lint.isort] known-first-party = ["spox"] diff --git a/src/spox/_adapt.py b/src/spox/_adapt.py index 8877e40e..42aad5bf 100644 --- a/src/spox/_adapt.py +++ b/src/spox/_adapt.py @@ -1,8 +1,9 @@ # Copyright (c) QuantCo 2023-2024 # SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations + import warnings -from typing import Optional import onnx import onnx.version_converter @@ -22,7 +23,7 @@ def adapt_node( source_version: int, target_version: int, var_names: dict[_VarInfo, str], -) -> Optional[list[onnx.NodeProto]]: +) -> list[onnx.NodeProto] | None: if source_version == target_version: return None @@ -93,7 +94,7 @@ def adapt_best_effort( opsets: dict[str, int], var_names: dict[_VarInfo, str], node_names: dict[Node, str], -) -> Optional[list[onnx.NodeProto]]: +) -> list[onnx.NodeProto] | None: if isinstance(node, _Inline): return adapt_inline( node, diff --git a/src/spox/_attributes.py b/src/spox/_attributes.py index 1e9a8f09..c0d2813e 100644 --- a/src/spox/_attributes.py +++ b/src/spox/_attributes.py @@ -1,10 +1,12 @@ # Copyright (c) QuantCo 2023-2024 # SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations + import abc from abc import ABC from collections.abc import Iterable -from typing import Any, Generic, Optional, TypeVar, Union +from typing import Any, Generic, TypeVar import numpy as np import numpy.typing as npt @@ -28,25 +30,25 @@ class Attr(ABC, Generic[T]): - _value: Union[T, "_Ref[T]"] + _value: T | _Ref[T] _name: str - _cached_onnx: Optional[AttributeProto] + _cached_onnx: AttributeProto | None - def __init__(self, value: Union[T, "_Ref[T]"], name: str): + def __init__(self, value: T | _Ref[T], name: str): self._value = value self._name = name self._cached_onnx = None self._validate() - def deref(self) -> "Attr": + def deref(self) -> Attr: if isinstance(self._value, _Ref): return type(self)(self.value, self._name) else: return self @classmethod - def maybe(cls: type[AttrT], value: Optional[T], name: str) -> Optional[AttrT]: + def maybe(cls: type[AttrT], value: T | None, name: str) -> AttrT | None: return cls(value, name) if value is not None else None @property @@ -110,7 +112,7 @@ def __init__(self, concrete: Attr[T], outer_name: str, name: str): self._outer_name = outer_name self._name = name - def copy(self) -> "_Ref[T]": + def copy(self) -> _Ref[T]: return self def _to_onnx(self) -> AttributeProto: @@ -146,7 +148,7 @@ def _to_onnx_deref(self) -> AttributeProto: class AttrTensor(Attr[np.ndarray]): _attribute_proto_type = AttributeProto.TENSOR - def __init__(self, value: Union[np.ndarray, _Ref[np.ndarray]], name: str): + def __init__(self, value: np.ndarray | _Ref[np.ndarray], name: str): super().__init__(value.copy(), name) def _to_onnx_deref(self) -> AttributeProto: @@ -202,7 +204,7 @@ def _to_onnx_deref(self) -> AttributeProto: class _AttrIterable(Attr[tuple[S, ...]], ABC): - def __init__(self, value: Union[Iterable[S], _Ref[tuple[S, ...]]], name: str): + def __init__(self, value: Iterable[S] | _Ref[tuple[S, ...]], name: str): super().__init__( value=value if isinstance(value, _Ref) else tuple(value), name=name ) @@ -210,9 +212,9 @@ def __init__(self, value: Union[Iterable[S], _Ref[tuple[S, ...]]], name: str): @classmethod def maybe( cls: type[AttrIterableT], - value: Optional[Iterable[S]], + value: Iterable[S] | None, name: str, - ) -> Optional[AttrIterableT]: + ) -> AttrIterableT | None: return cls(tuple(value), name) if value is not None else None def _to_onnx_deref(self) -> AttributeProto: diff --git a/src/spox/_build.py b/src/spox/_build.py index 1fb24cea..235ab74a 100644 --- a/src/spox/_build.py +++ b/src/spox/_build.py @@ -10,7 +10,6 @@ Any, Callable, Generic, - Optional, TypeVar, ) @@ -34,9 +33,9 @@ class Cached(Generic[T]): """A generic cached-value type, for which the ``.value`` property raises if it was not previously set.""" - _value: Optional[T] + _value: T | None - def __init__(self, value: Optional[T] = None): + def __init__(self, value: T | None = None): self._value = value @property diff --git a/src/spox/_fields.py b/src/spox/_fields.py index ff5d6efc..5cb77be9 100644 --- a/src/spox/_fields.py +++ b/src/spox/_fields.py @@ -8,11 +8,11 @@ import warnings from collections.abc import Iterable, Iterator, Sequence from dataclasses import Field, dataclass -from typing import Optional, Union, get_type_hints +from typing import Optional, get_type_hints +from . import _type_system from ._attributes import Attr from ._exceptions import InferenceWarning -from ._type_system import Optional as tOptional from ._value_prop import PropDict, PropValue from ._var import Var, _VarInfo @@ -24,7 +24,7 @@ class BaseFields: @dataclass class BaseAttributes(BaseFields): - def get_fields(self) -> dict[str, Union[None, Attr]]: + def get_fields(self) -> dict[str, None | Attr]: """Return a mapping of all fields stored in this object by name.""" return self.__dict__.copy() @@ -40,16 +40,16 @@ class VarFieldKind(enum.Enum): class BaseVars: """A collection of `Var`-s used to carry around inputs/outputs of nodes""" - vars: dict[str, Union[Var, Optional[Var], Sequence[Var]]] + vars: dict[str, Var | None | Sequence[Var]] - def __init__(self, vars: dict[str, Union[Var, Optional[Var], Sequence[Var]]]): + def __init__(self, vars: dict[str, Var | None | Sequence[Var]]): self.vars = vars - def _unpack_to_any(self) -> tuple[Union[Var, Optional[Var], Sequence[Var]], ...]: + def _unpack_to_any(self) -> tuple[Var | None | Sequence[Var], ...]: """Unpack the stored fields into a tuple of appropriate length, typed as Any.""" return tuple(self.vars.values()) - def _flatten(self) -> Iterator[tuple[str, Optional[Var]]]: + def _flatten(self) -> Iterator[tuple[str, Var | None]]: """Iterate over the pairs of names and values of fields in this object.""" for key, value in self.vars.items(): if value is None or isinstance(value, Var): @@ -61,7 +61,7 @@ def flatten_vars(self) -> dict[str, Var]: """Return a flat mapping by name of all the VarInfos in this object.""" return {key: var for key, var in self._flatten() if var is not None} - def __getattr__(self, attr: str) -> Union[Var, Optional[Var], Sequence[Var]]: + def __getattr__(self, attr: str) -> Var | None | Sequence[Var]: """Retrieves the attribute if present in the stored variables.""" try: return self.vars[attr] @@ -70,9 +70,7 @@ def __getattr__(self, attr: str) -> Union[Var, Optional[Var], Sequence[Var]]: f"{self.__class__.__name__!r} object has no attribute {attr!r}" ) - def __setattr__( - self, attr: str, value: Union[Var, Optional[Var], Sequence[Var]] - ) -> None: + def __setattr__(self, attr: str, value: Var | None | Sequence[Var]) -> None: """Sets the attribute to a value if the attribute is present in the stored variables.""" if attr == "vars": super().__setattr__(attr, value) @@ -121,7 +119,7 @@ def _get_field_type(cls, field: Field) -> VarFieldKind: return VarFieldKind.VARIADIC raise ValueError(f"Bad field type: '{field.type}'.") - def _flatten(self) -> Iterable[tuple[str, Optional[_VarInfo]]]: + def _flatten(self) -> Iterable[tuple[str, _VarInfo | None]]: """Iterate over the pairs of names and values of fields in this object.""" for key, value in self.__dict__.items(): if value is None or isinstance(value, _VarInfo): @@ -129,7 +127,7 @@ def _flatten(self) -> Iterable[tuple[str, Optional[_VarInfo]]]: else: yield from ((f"{key}_{i}", v) for i, v in enumerate(value)) - def __iter__(self) -> Iterator[Optional[_VarInfo]]: + def __iter__(self) -> Iterator[_VarInfo | None]: """Iterate over the values of fields in this object.""" yield from (v for _, v in self._flatten()) @@ -141,7 +139,7 @@ def get_var_infos(self) -> dict[str, _VarInfo]: """Return a flat mapping by name of all the VarInfos in this object.""" return {key: var for key, var in self._flatten() if var is not None} - def get_fields(self) -> dict[str, Union[None, _VarInfo, Sequence[_VarInfo]]]: + def get_fields(self) -> dict[str, None | _VarInfo | Sequence[_VarInfo]]: """Return a mapping of all fields stored in this object by name.""" return self.__dict__.copy() @@ -162,7 +160,10 @@ def _create_var(key: str, var_info: _VarInfo) -> Var: if var_info.type is None or key not in prop_values: return ret - if not isinstance(var_info.type, tOptional) and prop_values[key] is None: + if ( + not isinstance(var_info.type, _type_system.Optional) + and prop_values[key] is None + ): return ret prop = PropValue(var_info.type, prop_values[key]) @@ -178,7 +179,7 @@ def _create_var(key: str, var_info: _VarInfo) -> Var: return ret - ret_dict: dict[str, Union[Var, Optional[Var], Sequence[Var]]] = {} + ret_dict: dict[str, Var | None | Sequence[Var]] = {} for key, var_info in self.__dict__.items(): if isinstance(var_info, _VarInfo): diff --git a/src/spox/_function.py b/src/spox/_function.py index 63760d75..e782a7cc 100644 --- a/src/spox/_function.py +++ b/src/spox/_function.py @@ -7,7 +7,7 @@ import itertools from collections.abc import Iterable, Sequence from dataclasses import dataclass, make_dataclass -from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union +from typing import TYPE_CHECKING, Any, Callable, TypeVar import numpy as np import onnx @@ -215,7 +215,7 @@ def init(*args: Var) -> type[Function]: ) return _cls - def alt_fun(*args: Var) -> Iterable[Union[Var, Optional[Var], Sequence[Var]]]: + def alt_fun(*args: Var) -> Iterable[Var | None | Sequence[Var]]: cls = init(*args) return [ Var(var_info) diff --git a/src/spox/_future.py b/src/spox/_future.py index e7d7b76b..9c19e054 100644 --- a/src/spox/_future.py +++ b/src/spox/_future.py @@ -3,11 +3,13 @@ """Module containing experimental Spox features that may be standard in the future.""" +from __future__ import annotations + import warnings from collections.abc import Iterable, Iterator from contextlib import contextmanager from types import ModuleType -from typing import Any, Optional, Union +from typing import Any import numpy as np import numpy.typing as npt @@ -83,13 +85,13 @@ def __init__( self.constant_promotion = constant_promotion def _promote( - self, *args: Union[Var, np.generic, int, float], to_floating: bool = False - ) -> Iterable[Optional[Var]]: + self, *args: Var | np.generic | int | float, to_floating: bool = False + ) -> Iterable[Var | None]: """ Apply constant promotion and type promotion to given parameters, creating constants and/or casting. """ - targets: list[Union[np.dtype, np.generic, int, float]] = [ + targets: list[np.dtype | np.generic | int | float] = [ x.type.dtype if isinstance(x, Var) and isinstance(x.type, Tensor) else x # type: ignore for x in args ] @@ -117,8 +119,8 @@ def _promote( # TODO: Handle more constant-target inconsistencies here? def _promote_target( - obj: Union[Var, np.generic, int, float], - ) -> Optional[Var]: + obj: Var | np.generic | int | float, + ) -> Var | None: if self.constant_promotion and isinstance(obj, (np.generic, int, float)): return self.op.const(np.array(obj, dtype=target_type)) elif isinstance(obj, Var): diff --git a/src/spox/_graph.py b/src/spox/_graph.py index 44afb765..211030e7 100644 --- a/src/spox/_graph.py +++ b/src/spox/_graph.py @@ -9,7 +9,7 @@ import itertools from collections.abc import Iterable from dataclasses import dataclass, replace -from typing import Callable, Literal, Optional, Union +from typing import Callable, Literal import numpy as np import onnx @@ -27,7 +27,7 @@ from ._var import Var, _VarInfo -def arguments_dict(**kwargs: Optional[Union[Type, np.ndarray]]) -> dict[str, Var]: +def arguments_dict(**kwargs: Type | np.ndarray | None) -> dict[str, Var]: """ Parameters ---------- @@ -76,14 +76,12 @@ def arguments_dict(**kwargs: Optional[Union[Type, np.ndarray]]) -> dict[str, Var return result # type: ignore -def arguments(**kwargs: Optional[Union[Type, np.ndarray]]) -> tuple[Var, ...]: +def arguments(**kwargs: Type | np.ndarray | None) -> tuple[Var, ...]: """This function is a shorthand for a respective call to ``arguments_dict``, unpacking the Vars from the dict.""" return tuple(arguments_dict(**kwargs).values()) -def enum_arguments( - *infos: Union[Type, np.ndarray], prefix: str = "in" -) -> tuple[Var, ...]: +def enum_arguments(*infos: Type | np.ndarray, prefix: str = "in") -> tuple[Var, ...]: """ Convenience function for creating an enumeration of arguments, prefixed with ``prefix``. Calls ``arguments`` internally. @@ -148,11 +146,11 @@ class Graph: """ _results: dict[str, Var] - _name: Optional[str] = None - _doc_string: Optional[str] = None - _arguments: Optional[tuple[Var, ...]] = None - _extra_opset_req: Optional[set[tuple[str, int]]] = None - _constructor: Optional[Callable[..., Iterable[Var]]] = None + _name: str | None = None + _doc_string: str | None = None + _arguments: tuple[Var, ...] | None = None + _extra_opset_req: set[tuple[str, int]] | None = None + _constructor: Callable[..., Iterable[Var]] | None = None _build_result: _build.Cached[_build.BuildResult] = dataclasses.field( default_factory=_build.Cached ) @@ -227,7 +225,7 @@ def _inject_build_result(self, what: _build.BuildResult) -> Graph: return replace(self, _build_result=_build.Cached(what)) @property - def requested_arguments(self) -> Optional[Iterable[Var]]: + def requested_arguments(self) -> Iterable[Var] | None: """Arguments requested by this Graph (for building) - ``None`` if unspecified.""" return self._arguments @@ -375,7 +373,7 @@ def to_onnx_model( producer_name: str = "spox", model_doc_string: str = "", infer_shapes: bool = False, - check_model: Union[Literal[0], Literal[1], Literal[2]] = 1, + check_model: Literal[0] | Literal[1] | Literal[2] = 1, ir_version: int = 8, concrete: bool = True, ) -> onnx.ModelProto: diff --git a/src/spox/_internal_op.py b/src/spox/_internal_op.py index 3a3af424..4109e420 100644 --- a/src/spox/_internal_op.py +++ b/src/spox/_internal_op.py @@ -11,7 +11,7 @@ from abc import ABC from collections.abc import Sequence from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Callable, Optional +from typing import TYPE_CHECKING, Any, Callable import numpy as np import onnx @@ -78,8 +78,8 @@ class Argument(_InternalNode): @dataclass class Attributes(BaseAttributes): type: AttrType - name: Optional[AttrString] = None - default: Optional[AttrTensor] = None + name: AttrString | None = None + default: AttrTensor | None = None @dataclass class Inputs(BaseInputs): @@ -115,8 +115,8 @@ def update_metadata( def to_onnx( self, scope: Scope, - doc_string: Optional[str] = None, - build_subgraph: Optional[Callable] = None, + doc_string: str | None = None, + build_subgraph: Callable | None = None, ) -> list[onnx.NodeProto]: return [] @@ -158,8 +158,8 @@ def update_metadata( def to_onnx( self, scope: Scope, - doc_string: Optional[str] = None, - build_subgraph: Optional[Callable] = None, + doc_string: str | None = None, + build_subgraph: Callable | None = None, ) -> list[onnx.NodeProto]: # Initializers are added via update_metadata and don't affect the nodes proto list return [] @@ -200,8 +200,8 @@ def opset_req(self) -> set[tuple[str, int]]: def to_onnx( self, scope: Scope, - doc_string: Optional[str] = None, - build_subgraph: Optional[Callable] = None, + doc_string: str | None = None, + build_subgraph: Callable | None = None, ) -> list[onnx.NodeProto]: assert len(self.inputs.inputs) == len(self.outputs.outputs) # Just create a renaming identity from what we forwarded into our actual output diff --git a/src/spox/_node.py b/src/spox/_node.py index 8e2a0222..dac276ab 100644 --- a/src/spox/_node.py +++ b/src/spox/_node.py @@ -12,7 +12,7 @@ from abc import ABC from collections.abc import Iterable, Iterator, Sequence from dataclasses import dataclass -from typing import Any, ClassVar, Optional, Union +from typing import Any, ClassVar import numpy as np import onnx @@ -94,17 +94,17 @@ class Node(ABC): inputs: BaseInputs outputs: BaseOutputs - out_variadic: Optional[int] - _traceback: Union[list[str], None] + out_variadic: int | None + _traceback: list[str] | None _validate: bool def __init__( self, - attrs: Optional[BaseAttributes] = None, - inputs: Optional[BaseInputs] = None, - outputs: Optional[BaseOutputs] = None, + attrs: BaseAttributes | None = None, + inputs: BaseInputs | None = None, + outputs: BaseOutputs | None = None, *, - out_variadic: Optional[int] = None, + out_variadic: int | None = None, infer_types: bool = True, validate: bool = True, **kwargs: Any, @@ -234,7 +234,7 @@ def infer_output_types(self, input_prop_values: PropDict) -> dict[str, Type]: return {} def inference( - self, input_prop_values: Optional[PropDict] = None, infer_types: bool = True + self, input_prop_values: PropDict | None = None, infer_types: bool = True ) -> None: if input_prop_values is None: input_prop_values = {} @@ -254,7 +254,7 @@ def inference( var.type = out_types.get(key) def get_output_vars( - self, input_prop_values: Optional[PropDict] = None, infer_types: bool = True + self, input_prop_values: PropDict | None = None, infer_types: bool = True ) -> BaseVars: if input_prop_values is None: input_prop_values = {} @@ -298,7 +298,7 @@ def validate_types(self) -> None: stacklevel=4, ) - def _check_concrete_type(self, value_type: Optional[Type]) -> Optional[str]: + def _check_concrete_type(self, value_type: Type | None) -> str | None: if value_type is None: return "type is None" try: @@ -307,7 +307,7 @@ def _check_concrete_type(self, value_type: Optional[Type]) -> Optional[str]: return f"{type(e).__name__}: {str(e)}" return None - def _list_types(self, source: BaseVarInfos) -> Iterator[tuple[str, Optional[Type]]]: + def _list_types(self, source: BaseVarInfos) -> Iterator[tuple[str, Type | None]]: return ((key, var.type) for key, var in source.get_var_infos().items()) def _init_output_vars(self) -> BaseOutputs: @@ -325,7 +325,7 @@ def _init_output_vars(self) -> BaseOutputs: (variadic,) = variadics else: variadic = None - outputs: dict[str, Union[_VarInfo, Sequence[_VarInfo]]] = { + outputs: dict[str, _VarInfo | Sequence[_VarInfo]] = { field.name: _VarInfo(self, None) for field in dataclasses.fields(self.Outputs) if field.name != variadic @@ -367,10 +367,9 @@ def update_metadata( def to_onnx( self, scope: Scope, - doc_string: Optional[str] = None, - build_subgraph: Optional[ - typing.Callable[[Node, str, Graph], onnx.GraphProto] - ] = None, + doc_string: str | None = None, + build_subgraph: typing.Callable[[Node, str, Graph], onnx.GraphProto] + | None = None, ) -> list[onnx.NodeProto]: """Translates self into an ONNX NodeProto.""" assert self.op_type.identifier diff --git a/src/spox/_public.py b/src/spox/_public.py index e12240e7..369a67fc 100644 --- a/src/spox/_public.py +++ b/src/spox/_public.py @@ -3,10 +3,12 @@ """Module implementing the main public interface functions in Spox.""" +from __future__ import annotations + import contextlib import itertools from collections.abc import Iterator -from typing import Optional, Protocol +from typing import Protocol import numpy as np import onnx @@ -51,8 +53,8 @@ def _temporary_renames(**kwargs: Var) -> Iterator[None]: # The build code can't really special-case variable names that are # not just ``Var._name``. So we set names here and reset them # afterwards. - name: Optional[str] - pre: dict[Var, Optional[str]] = {} + name: str | None + pre: dict[Var, str | None] = {} try: for name, arg in kwargs.items(): pre[arg] = arg._var_info._name diff --git a/src/spox/_schemas.py b/src/spox/_schemas.py index 7486dcce..4eb51c27 100644 --- a/src/spox/_schemas.py +++ b/src/spox/_schemas.py @@ -3,9 +3,11 @@ """Exposes information related to reference ONNX operator schemas, used by StandardOpNode.""" +from __future__ import annotations + import itertools from collections.abc import Iterable -from typing import Any, Callable, Optional, Protocol, TypeVar +from typing import Any, Callable, Protocol, TypeVar from onnx.defs import OpSchema, get_all_schemas_with_history @@ -30,8 +32,8 @@ def _key_groups( def _current_schema( - schemas: Iterable[OpSchema], version: Optional[int] = None -) -> Optional[OpSchema]: + schemas: Iterable[OpSchema], version: int | None = None +) -> OpSchema | None: """ Find the schema for the current ``version`` from the list (the latest existing version). If ``version`` is None (or left to default), the newest of the schemas is returned. diff --git a/src/spox/_scope.py b/src/spox/_scope.py index 8eff3dbb..399b81ff 100644 --- a/src/spox/_scope.py +++ b/src/spox/_scope.py @@ -4,7 +4,7 @@ from __future__ import annotations from collections.abc import Hashable -from typing import Generic, Optional, TypeVar, Union, overload +from typing import Generic, TypeVar, overload from ._node import Node from ._var import _VarInfo @@ -30,14 +30,14 @@ class ScopeSpace(Generic[H]): of_name: dict[str, H] reserved: set[str] base_name_counters: dict[str, int] - parent: Optional[ScopeSpace[H]] + parent: ScopeSpace[H] | None def __init__( self, - name_of: Optional[dict[H, str]] = None, - of_name: Optional[dict[str, H]] = None, - reserved: Optional[set[str]] = None, - parent: Optional[ScopeSpace[H]] = None, + name_of: dict[H, str] | None = None, + of_name: dict[str, H] | None = None, + reserved: set[str] | None = None, + parent: ScopeSpace[H] | None = None, ): """ Parameters @@ -63,7 +63,7 @@ def __init__( parent.base_name_counters if parent is not None else dict() ) - def __contains__(self, item: Union[str, H]) -> bool: + def __contains__(self, item: str | H) -> bool: """Checks if a given name or object is declared in this (or outer) namespace.""" return ( (self.parent is not None and item in self.parent) @@ -78,7 +78,7 @@ def __getitem__(self, item: H) -> str: ... @overload def __getitem__(self, item: str) -> H: ... - def __getitem__(self, item: Union[str, H]) -> Union[str, H]: + def __getitem__(self, item: str | H) -> str | H: """Access the name of an object or an object with a given name in this (or outer) namespace.""" if self.parent is not None and item in self.parent: return self.parent[item] @@ -93,7 +93,7 @@ def __setitem__(self, key: str, value: H) -> None: ... @overload def __setitem__(self, key: H, value: str) -> None: ... - def __setitem__(self, _key: Union[str, H], _value: Union[H, str]) -> None: + def __setitem__(self, _key: str | H, _value: H | str) -> None: """Set the name of an object in exactly this namespace. Both ``[name] = obj`` and ``[obj] = name`` work.""" if isinstance(_value, str): _key, _value = _value, _key @@ -115,7 +115,7 @@ def __setitem__(self, _key: Union[str, H], _value: Union[H, str]) -> None: self.of_name[key] = value self.name_of[value] = key - def __delitem__(self, item: Union[str, H]) -> None: + def __delitem__(self, item: str | H) -> None: """Delete a both the name and object from exactly this namespace.""" if isinstance(item, str): key, value = item, self.of_name[item] @@ -162,9 +162,9 @@ class Scope: def __init__( self, - sub_var: Optional[ScopeSpace[_VarInfo]] = None, - sub_node: Optional[ScopeSpace[Node]] = None, - parent: Optional[Scope] = None, + sub_var: ScopeSpace[_VarInfo] | None = None, + sub_node: ScopeSpace[Node] | None = None, + parent: Scope | None = None, ): self.var = sub_var if sub_var is not None else ScopeSpace() self.node = sub_node if sub_node is not None else ScopeSpace() @@ -177,9 +177,7 @@ def __init__( @classmethod def of( cls, - *what: Union[ - tuple[str, Union[_VarInfo, Node]], tuple[Union[_VarInfo, Node], str] - ], + *what: tuple[str, _VarInfo | Node] | tuple[_VarInfo | Node, str], ) -> Scope: """Convenience constructor for filling a Scope with known names.""" scope = cls() diff --git a/src/spox/_shape.py b/src/spox/_shape.py index 4950d785..584a45f6 100644 --- a/src/spox/_shape.py +++ b/src/spox/_shape.py @@ -96,7 +96,7 @@ class Unknown(Natural): label: str = "" - def to_simple(self) -> Union[str, None]: + def to_simple(self) -> str | None: return None if not self.label else self.label def __le__(self, other: Natural) -> bool: @@ -127,7 +127,7 @@ def __le__(self, other: Natural) -> bool: class Shape: """Type representing a static Tensor shape.""" - dims: Optional[tuple[Natural, ...]] + dims: tuple[Natural, ...] | None def __bool__(self) -> bool: return self.dims is not None @@ -140,7 +140,7 @@ def from_simple(cls: type[ShapeT], shape: SimpleShape) -> ShapeT: ) @classmethod - def from_onnx(cls: type[ShapeT], proto: Optional[onnx.TensorShapeProto]) -> ShapeT: + def from_onnx(cls: type[ShapeT], proto: onnx.TensorShapeProto | None) -> ShapeT: """Translate into a Shape from ONNX shape.""" return ( cls(tuple(Natural.from_onnx(dim) for dim in proto.dim)) @@ -154,7 +154,7 @@ def to_simple(self) -> SimpleShape: tuple(v.to_simple() for v in self.dims) if self.dims is not None else None ) - def to_onnx(self) -> Optional[onnx.TensorShapeProto]: + def to_onnx(self) -> onnx.TensorShapeProto | None: """Translate into the ONNX representation.""" if self.dims is None: return None @@ -165,7 +165,7 @@ def to_onnx(self) -> Optional[onnx.TensorShapeProto]: return proto @property - def maybe_rank(self) -> Optional[int]: + def maybe_rank(self) -> int | None: """Get the rank of this Shape, or None if it is unknown.""" return len(self.dims) if self.dims is not None else None @@ -176,7 +176,7 @@ def rank(self) -> int: raise ShapeError(f"Rank of {self} is unknown.") return r - def __getitem__(self, item: Union[slice, int]) -> Union[Shape, Natural]: + def __getitem__(self, item: slice | int) -> Shape | Natural: """Indexing the dimensions, also provides iteration.""" if self.dims is None: raise ShapeError(f"Cannot index unknown {self}.") @@ -193,7 +193,7 @@ def can_broadcast(self, other: Shape) -> bool: else: return True - def broadcast(self, other: Union[Shape, SimpleShape]) -> Shape: + def broadcast(self, other: Shape | SimpleShape) -> Shape: """Return the result of shape broadcasting on ``self`` and ``other``.""" if not isinstance(other, Shape): other = Shape.from_simple(other) diff --git a/src/spox/_traverse.py b/src/spox/_traverse.py index 6b641892..4a327759 100644 --- a/src/spox/_traverse.py +++ b/src/spox/_traverse.py @@ -1,8 +1,10 @@ # Copyright (c) QuantCo 2023-2024 # SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations + from collections.abc import Iterable, Iterator -from typing import Callable, Optional, TypeVar +from typing import Callable, TypeVar V = TypeVar("V") @@ -10,7 +12,7 @@ def iterative_dfs( sources: Iterable[V], adj: Callable[[V], Iterable[V]], - post_callback: Optional[Callable[[V], None]] = None, + post_callback: Callable[[V], None] | None = None, raise_on_cycle: bool = True, ) -> list[V]: """ diff --git a/src/spox/_utils.py b/src/spox/_utils.py index af6ed1a5..7b00de4e 100644 --- a/src/spox/_utils.py +++ b/src/spox/_utils.py @@ -1,7 +1,7 @@ # Copyright (c) QuantCo 2023-2024 # SPDX-License-Identifier: BSD-3-Clause -from typing import Optional +from __future__ import annotations import numpy as np import numpy.typing as npt @@ -42,7 +42,7 @@ def dtype_to_tensor_type(dtype_like: npt.DTypeLike) -> int: raise TypeError(err_msg) -def from_array(arr: np.ndarray, name: Optional[str] = None) -> TensorProto: +def from_array(arr: np.ndarray, name: str | None = None) -> TensorProto: """Convert the given ``numpy.array`` into an ``onnx.TensorProto``. As it may be useful to name the TensorProto (e.g. in diff --git a/src/spox/_var.py b/src/spox/_var.py index d97e447d..0bde9e03 100644 --- a/src/spox/_var.py +++ b/src/spox/_var.py @@ -5,7 +5,7 @@ import typing from collections.abc import Iterable, Sequence -from typing import Any, Callable, ClassVar, Optional, TypeVar, Union, overload +from typing import Any, Callable, ClassVar, TypeVar, overload import numpy as np @@ -32,14 +32,14 @@ class _VarInfo: If a ``VarInfo`` or any of its fields are modified, the behaviour is undefined and the produced graph may be invalid. """ - type: Optional[_type_system.Type] + type: _type_system.Type | None _op: Node - _name: Optional[str] + _name: str | None def __init__( self, op: Node, - type_: Optional[_type_system.Type], + type_: _type_system.Type | None, ): """The initializer of ``VarInfo`` is protected. Use operator constructors to construct them instead.""" if type_ is not None and not isinstance(type_, _type_system.Type): @@ -49,12 +49,12 @@ def __init__( self._op = op self._name = None - def _rename(self, name: Optional[str]) -> None: + def _rename(self, name: str | None) -> None: """Mutates the internal state of the VarInfo, overriding its name as given.""" self._name = name @property - def _which_output(self) -> Optional[str]: + def _which_output(self) -> str | None: """Return the name of the output field that this var is stored in under ``self._op``.""" if self._op is None: return None @@ -140,14 +140,14 @@ class Var: """ _var_info: _VarInfo - _value: Optional[_value_prop.PropValue] + _value: _value_prop.PropValue | None _operator_dispatcher: ClassVar[Any] = NotImplementedOperatorDispatcher() def __init__( self, var_info: _VarInfo, - value: Optional[_value_prop.PropValue] = None, + value: _value_prop.PropValue | None = None, ): """The initializer of ``Var`` is protected. Use operator constructors to construct them instead.""" if value is not None and not isinstance(value, _value_prop.PropValue): @@ -212,18 +212,18 @@ def _op(self) -> Node: return self._var_info._op @property - def _name(self) -> Optional[str]: + def _name(self) -> str | None: return self._var_info._name - def _rename(self, name: Optional[str]) -> None: + def _rename(self, name: str | None) -> None: self._var_info._rename(name) @property - def _which_output(self) -> Optional[str]: + def _which_output(self) -> str | None: return self._var_info._which_output @property - def type(self) -> Optional[_type_system.Type]: + def type(self) -> _type_system.Type | None: return self._var_info.type def __copy__(self) -> Var: @@ -298,7 +298,7 @@ def wrap_vars(var_info: _VarInfo) -> Var: ... @overload -def wrap_vars(var_info: Optional[_VarInfo]) -> Optional[Var]: ... +def wrap_vars(var_info: _VarInfo | None) -> Var | None: ... @overload @@ -327,7 +327,7 @@ def unwrap_vars(var: Var) -> _VarInfo: ... @overload -def unwrap_vars(var: Optional[Var]) -> Optional[_VarInfo]: ... +def unwrap_vars(var: Var | None) -> _VarInfo | None: ... @overload @@ -352,7 +352,7 @@ def unwrap_vars(var): # type: ignore def result_type( - *types: Union[_VarInfo, np.generic, int, float], + *types: _VarInfo | np.generic | int | float, ) -> type[np.generic]: """Promote type for all given element types/values using ``np.result_type``.""" return np.dtype( @@ -368,7 +368,7 @@ def result_type( def create_prop_dict( - **kwargs: Union[Var, Sequence[Var], Optional[Var]], + **kwargs: Var | Sequence[Var] | Var | None, ) -> _value_prop.PropDict: from ._fields import BaseVars diff --git a/tests/test_function.py b/tests/test_function.py index 9933f611..0246868a 100644 --- a/tests/test_function.py +++ b/tests/test_function.py @@ -1,9 +1,10 @@ # Copyright (c) QuantCo 2023-2024 # SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations + import functools from dataclasses import dataclass -from typing import Union import numpy as np import onnx @@ -59,9 +60,7 @@ def constructor(self, attrs: dict[str, Attr], inputs: BaseVars) -> Outputs: x = inputs.X return self.Outputs(op.add(op.mul(a, x), b)._var_info) - def linear_inner( - x: Var, a: Union[float, _Ref[float]], b: Union[float, _Ref[float]] - ) -> Var: + def linear_inner(x: Var, a: float | _Ref[float], b: float | _Ref[float]) -> Var: return ( LinearFunction( LinearFunction.Attributes( @@ -108,9 +107,7 @@ def constructor(self, attrs: dict[str, Attr], inputs: BaseVars) -> Outputs: )._var_info ) - def linear_inner( - x: Var, a: Union[float, _Ref[float]], b: Union[float, _Ref[float]] - ) -> Var: + def linear_inner(x: Var, a: float | _Ref[float], b: float | _Ref[float]) -> Var: return ( LinearFunction2( LinearFunction2.Attributes( diff --git a/tools/generate_opset.py b/tools/generate_opset.py index 36eeecd6..1866fb9f 100644 --- a/tools/generate_opset.py +++ b/tools/generate_opset.py @@ -1,13 +1,14 @@ # Copyright (c) QuantCo 2023-2024 # SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations + import re import subprocess from collections.abc import Iterable, Sequence from dataclasses import dataclass from io import TextIOWrapper from pathlib import Path -from typing import Optional, Union import jinja2 import onnx @@ -118,13 +119,13 @@ class Attribute: # The name of the attribute used as argument and member name name: str # Default value used in the constructor function - constructor_default: Optional[str] + constructor_default: str | None # Type hint used in the constructor function. May be wrapped in `Optional`. constructor_type_hint: str # Member type without a potential ``Optional`` wrapper _member_type: str # Python expression for generating the argument types for this subgraph - subgraph_solution: Optional[str] = None + subgraph_solution: str | None = None # Mark whether generating extra constructor arguments caused by this should raise allow_extra: bool = False @@ -193,7 +194,7 @@ def get_attributes( def _get_default_value( attr: onnx.defs.OpSchema.Attribute, attr_type_overrides: dict[str, tuple[str, str]] -) -> Optional[str]: +) -> str | None: """Get default value if any as a string ready to be used in a template. This function has a special handling with respect to ``attr_type_overrides``. @@ -294,7 +295,7 @@ def _pandoc_gfm_to_rst(*args: str) -> tuple[str, ...]: if not (arg in _PANDOC_GFM_TO_RST_CACHE or not arg) ] results = _pandoc_gfm_to_rst_run(*[args[i] for i in valid]) - sub: list[Optional[str]] = [None] * len(args) + sub: list[str | None] = [None] * len(args) for i, result in zip(valid, results): sub[i] = result for i, arg in enumerate(args): @@ -312,7 +313,7 @@ def pandoc_gfm_to_rst(doc: str) -> str: return result -def format_github_markdown(doc: str, *, to_batch: Optional[list[str]] = None) -> str: +def format_github_markdown(doc: str, *, to_batch: list[str] | None = None) -> str: """Jinja filter. Makes some attempt at fixing "Markdown" into RST.""" # Sometimes Tensor is used in the docs (~17 instances at 1.13) # and is treated as invalid HTML tags by pandoc. @@ -367,7 +368,7 @@ def write_schemas_code( value_propagation: dict[str, str], out_variadic_solutions: dict[str, str], subgraphs_solutions: dict[str, dict[str, str]], - attr_type_overrides: list[tuple[Optional[str], str, tuple[str, str]]], + attr_type_overrides: list[tuple[str | None, str, tuple[str, str]]], allow_extra_constructor_arguments: set[str], inherited_schemas: dict[onnx.defs.OpSchema, str], extras: Sequence[str], @@ -514,7 +515,7 @@ def write_schemas_code( def run_pre_commit_hooks( - filenames: Union[str, Iterable[str]], + filenames: str | Iterable[str], ) -> subprocess.CompletedProcess: """ Calls repo pre-commit hooks for the given ``filenames``. @@ -532,16 +533,14 @@ def run_pre_commit_hooks( def main( domain: str, - version: Optional[int] = None, - type_inference: Optional[dict[str, str]] = None, - value_propagation: Optional[dict[str, str]] = None, - out_variadic_solutions: Optional[dict[str, str]] = None, - subgraphs_solutions: Optional[dict[str, dict[str, str]]] = None, - attr_type_overrides: Optional[ - list[tuple[Optional[str], str, tuple[str, str]]] - ] = None, + version: int | None = None, + type_inference: dict[str, str] | None = None, + value_propagation: dict[str, str] | None = None, + out_variadic_solutions: dict[str, str] | None = None, + subgraphs_solutions: dict[str, dict[str, str]] | None = None, + attr_type_overrides: list[tuple[str | None, str, tuple[str, str]]] | None = None, allow_extra_constructor_arguments: Iterable[str] = (), - inherited_schemas: Optional[dict[onnx.defs.OpSchema, str]] = None, + inherited_schemas: dict[onnx.defs.OpSchema, str] | None = None, extras: Sequence[str] = (), target: str = "src/spox/opset/", pre_commit_hooks: bool = True, diff --git a/tools/templates/extras/promote.jinja2 b/tools/templates/extras/promote.jinja2 index 332229eb..00528711 100644 --- a/tools/templates/extras/promote.jinja2 +++ b/tools/templates/extras/promote.jinja2 @@ -1,5 +1,5 @@ def promote( - *types: Union[Var, np.generic, int, float, None] + *types: Var | np.generic | int | float | None ) -> tuple[Optional[Var], ...]: """ Apply constant promotion and type promotion to given parameters, creating constants and/or casting. @@ -14,7 +14,7 @@ def promote( target_type = result_type(*promotable) - def _promote_target(obj: Union[Var, np.generic, int, float, None]) -> Optional[Var]: + def _promote_target(obj: Var | np.generic | int | float | None) -> Optional[Var]: if isinstance(obj, (np.generic, int, float)): return const(np.array(obj, dtype=target_type)) elif isinstance(obj, Var):