From f02f8c88649fe8da57d420a59feece7fad2ae743 Mon Sep 17 00:00:00 2001 From: Matthew Evans Date: Mon, 14 Oct 2024 11:53:48 +0100 Subject: [PATCH] Extend annotation -> type map to include py310 native union types --- optimade/models/types.py | 7 ++++--- tests/models/test_types.py | 13 +++++++++++++ 2 files changed, 17 insertions(+), 3 deletions(-) create mode 100644 tests/models/test_types.py diff --git a/optimade/models/types.py b/optimade/models/types.py index b230a0fde..8ed7f99de 100644 --- a/optimade/models/types.py +++ b/optimade/models/types.py @@ -1,3 +1,4 @@ +from types import UnionType from typing import Annotated, Optional, Union, get_args from pydantic import Field @@ -23,7 +24,7 @@ AnnotatedType = type(ChemicalSymbol) OptionalType = type(Optional[str]) -UnionType = type(Union[str, int]) +_UnionType = type(Union[str, int]) NoneType = type(None) @@ -39,7 +40,7 @@ def _get_origin_type(annotation: type) -> type: """ # If the annotation is a Union, get the first non-None type (this includes # Optional[T]) - if isinstance(annotation, (OptionalType, UnionType)): + if isinstance(annotation, (OptionalType, UnionType, _UnionType)): for arg in get_args(annotation): if arg not in (None, NoneType): annotation = arg @@ -50,7 +51,7 @@ def _get_origin_type(annotation: type) -> type: annotation = get_args(annotation)[0] # Recursively unpack annotation, if it is a Union, Optional, or Annotated type - while isinstance(annotation, (OptionalType, UnionType, AnnotatedType)): + while isinstance(annotation, (OptionalType, UnionType, _UnionType, AnnotatedType)): annotation = _get_origin_type(annotation) # Special case for Literal diff --git a/tests/models/test_types.py b/tests/models/test_types.py new file mode 100644 index 000000000..5dbccb996 --- /dev/null +++ b/tests/models/test_types.py @@ -0,0 +1,13 @@ +from typing import Annotated, Optional + + +def test_origin_type(): + from optimade.models.types import _get_origin_type + + assert _get_origin_type(int | None) is int + assert _get_origin_type(str | None) is str + assert _get_origin_type(Optional[int]) is int + assert _get_origin_type(Optional[str]) is str + assert _get_origin_type(Annotated[int, "test"]) is int + assert _get_origin_type(Annotated[str, "test"]) is str + assert _get_origin_type(int | str | None) is int