Skip to content

Commit

Permalink
Extend annotation -> type map to include py310 native union types
Browse files Browse the repository at this point in the history
  • Loading branch information
ml-evs committed Oct 14, 2024
1 parent 29fde1e commit f02f8c8
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 3 deletions.
7 changes: 4 additions & 3 deletions optimade/models/types.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from types import UnionType
from typing import Annotated, Optional, Union, get_args

from pydantic import Field
Expand All @@ -23,7 +24,7 @@

AnnotatedType = type(ChemicalSymbol)
OptionalType = type(Optional[str])
UnionType = type(Union[str, int])
_UnionType = type(Union[str, int])
NoneType = type(None)


Expand All @@ -39,7 +40,7 @@ def _get_origin_type(annotation: type) -> type:
"""
# If the annotation is a Union, get the first non-None type (this includes
# Optional[T])
if isinstance(annotation, (OptionalType, UnionType)):
if isinstance(annotation, (OptionalType, UnionType, _UnionType)):
for arg in get_args(annotation):
if arg not in (None, NoneType):
annotation = arg
Expand All @@ -50,7 +51,7 @@ def _get_origin_type(annotation: type) -> type:
annotation = get_args(annotation)[0]

# Recursively unpack annotation, if it is a Union, Optional, or Annotated type
while isinstance(annotation, (OptionalType, UnionType, AnnotatedType)):
while isinstance(annotation, (OptionalType, UnionType, _UnionType, AnnotatedType)):
annotation = _get_origin_type(annotation)

# Special case for Literal
Expand Down
13 changes: 13 additions & 0 deletions tests/models/test_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from typing import Annotated, Optional


def test_origin_type():
from optimade.models.types import _get_origin_type

assert _get_origin_type(int | None) is int
assert _get_origin_type(str | None) is str
assert _get_origin_type(Optional[int]) is int
assert _get_origin_type(Optional[str]) is str
assert _get_origin_type(Annotated[int, "test"]) is int
assert _get_origin_type(Annotated[str, "test"]) is str
assert _get_origin_type(int | str | None) is int

0 comments on commit f02f8c8

Please sign in to comment.