Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: improve type hints #1784

Merged
merged 1 commit into from
Nov 24, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 47 additions & 64 deletions tortoise/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import asyncio
import importlib
import importlib.metadata as importlib_metadata
Expand All @@ -7,20 +9,9 @@
from copy import deepcopy
from inspect import isclass
from types import ModuleType
from typing import (
Callable,
Coroutine,
Dict,
Iterable,
List,
Optional,
Tuple,
Type,
Union,
cast,
)
from typing import Any, Callable, Coroutine, Iterable, Type, cast

from pypika import Table
from pypika import Query, Table

from tortoise.backends.base.client import BaseDBAsyncClient
from tortoise.backends.base.config_generator import expand_db_url, generate_config
Expand All @@ -40,8 +31,8 @@


class Tortoise:
apps: Dict[str, Dict[str, Type["Model"]]] = {}
table_name_generator: Optional[Callable[[Type["Model"]], str]] = None
apps: dict[str, dict[str, Type["Model"]]] = {}
table_name_generator: Callable[[Type["Model"]], str] | None = None
_inited: bool = False

@classmethod
Expand All @@ -60,7 +51,7 @@ def get_connection(cls, connection_name: str) -> BaseDBAsyncClient:
@classmethod
def describe_model(
cls, model: Type["Model"], serializable: bool = True
) -> dict: # pragma: nocoverage
) -> dict[str, Any]: # pragma: nocoverage
"""
Describes the given list of models or ALL registered models.

Expand All @@ -85,8 +76,8 @@ def describe_model(

@classmethod
def describe_models(
cls, models: Optional[List[Type["Model"]]] = None, serializable: bool = True
) -> Dict[str, dict]:
cls, models: list[Type["Model"]] | None = None, serializable: bool = True
) -> dict[str, dict[str, Any]]:
"""
Describes the given list of models or ALL registered models.

Expand Down Expand Up @@ -142,7 +133,7 @@ def get_related_model(related_app_name: str, related_model_name: str) -> Type["M
f" app '{related_app_name}'."
)

def split_reference(reference: str) -> Tuple[str, str]:
def split_reference(reference: str) -> tuple[str, str]:
"""
Validate, if reference follow the official naming conventions. Throws a
ConfigurationError with a hopefully helpful message. If successful,
Expand All @@ -158,12 +149,9 @@ def split_reference(reference: str) -> Tuple[str, str]:
return items[0], items[1]

def init_fk_o2o_field(model: Type["Model"], field: str, is_o2o=False) -> None:
if is_o2o:
fk_object: Union[OneToOneFieldInstance, ForeignKeyFieldInstance] = cast(
OneToOneFieldInstance, model._meta.fields_map[field]
)
else:
fk_object = cast(ForeignKeyFieldInstance, model._meta.fields_map[field])
fk_object = cast(
"OneToOneFieldInstance | ForeignKeyFieldInstance", model._meta.fields_map[field]
)
related_app_name, related_model_name = split_reference(fk_object.model_name)
related_model = get_related_model(related_app_name, related_model_name)

Expand Down Expand Up @@ -206,24 +194,24 @@ def init_fk_o2o_field(model: Type["Model"], field: str, is_o2o=False) -> None:
f'backward relation "{backward_relation_name}" duplicates in'
f" model {related_model_name}"
)
if is_o2o:
fk_relation: Union[BackwardOneToOneRelation, BackwardFKRelation] = (
BackwardOneToOneRelation(
model,
key_field,
key_fk_object.source_field,
null=True,
description=fk_object.description,
)

fk_relation = (
BackwardOneToOneRelation(
model,
key_field,
key_fk_object.source_field,
null=True,
description=fk_object.description,
)
else:
fk_relation = BackwardFKRelation(
if is_o2o
else BackwardFKRelation(
model,
key_field,
key_fk_object.source_field,
null=fk_object.null,
description=fk_object.description,
)
)
fk_relation.to_field_instance = fk_object.to_field_instance # type:ignore
related_model._meta.add_field(backward_relation_name, fk_relation)
if is_o2o and fk_object.pk:
Expand Down Expand Up @@ -251,8 +239,7 @@ def init_fk_o2o_field(model: Type["Model"], field: str, is_o2o=False) -> None:
m2m_object = cast(ManyToManyFieldInstance, model._meta.fields_map[field])
if m2m_object._generated:
continue
backward_key = m2m_object.backward_key
if not backward_key:
if not (backward_key := m2m_object.backward_key):
backward_key = f"{model._meta.db_table}_id"
if backward_key == m2m_object.forward_key:
backward_key = f"{model._meta.db_table}_rel_id"
Expand All @@ -264,8 +251,7 @@ def init_fk_o2o_field(model: Type["Model"], field: str, is_o2o=False) -> None:

m2m_object.related_model = related_model

backward_relation_name = m2m_object.related_name
if not backward_relation_name:
if not (backward_relation_name := m2m_object.related_name):
backward_relation_name = m2m_object.related_name = (
f"{model._meta.db_table}s"
)
Expand Down Expand Up @@ -295,9 +281,7 @@ def init_fk_o2o_field(model: Type["Model"], field: str, is_o2o=False) -> None:
related_model._meta.add_field(backward_relation_name, m2m_relation)

@classmethod
def _discover_models(
cls, models_path: Union[ModuleType, str], app_label: str
) -> List[Type["Model"]]:
def _discover_models(cls, models_path: ModuleType | str, app_label: str) -> list[Type["Model"]]:
if isinstance(models_path, ModuleType):
module = models_path
else:
Expand All @@ -306,11 +290,11 @@ def _discover_models(
except ImportError:
raise ConfigurationError(f'Module "{models_path}" not found')
discovered_models = []
possible_models = getattr(module, "__models__", None)
try:
possible_models = [*possible_models] # type:ignore
except TypeError:
possible_models = None
if possible_models := getattr(module, "__models__", None):
try:
possible_models = [*possible_models]
except TypeError:
possible_models = None
if not possible_models:
possible_models = [getattr(module, attr_name) for attr_name in dir(module)]
for attr in possible_models:
Expand All @@ -326,7 +310,7 @@ def _discover_models(
@classmethod
def init_models(
cls,
models_paths: Iterable[Union[ModuleType, str]],
models_paths: Iterable[ModuleType | str],
app_label: str,
_init_relations: bool = True,
) -> None:
Expand All @@ -342,7 +326,7 @@ def init_models(

:raises ConfigurationError: If models are invalid.
"""
app_models: List[Type[Model]] = []
app_models: list[Type[Model]] = []
for models_path in models_paths:
app_models += cls._discover_models(models_path, app_label)

Expand All @@ -352,7 +336,7 @@ def init_models(
cls._init_relations()

@classmethod
def _init_apps(cls, apps_config: dict) -> None:
def _init_apps(cls, apps_config: dict[str, dict[str, Any]]) -> None:
for name, info in apps_config.items():
try:
connections.get(info.get("default_connection", "default"))
Expand Down Expand Up @@ -396,23 +380,23 @@ def _build_initial_querysets(cls) -> None:
model._meta.finalise_model()
model._meta.basetable = Table(name=model._meta.db_table, schema=model._meta.schema)
basequery = model._meta.db.query_class.from_(model._meta.basetable)
model._meta.basequery = basequery # type:ignore[assignment]
model._meta.basequery_all_fields = basequery.select(
*model._meta.db_fields
) # type:ignore[assignment]
model._meta.basequery = cast(Query, basequery)
model._meta.basequery_all_fields = cast(
Query, basequery.select(*model._meta.db_fields)
)

@classmethod
async def init(
cls,
config: Optional[dict] = None,
config_file: Optional[str] = None,
config: dict[str, Any] | None = None,
config_file: str | None = None,
_create_db: bool = False,
db_url: Optional[str] = None,
modules: Optional[Dict[str, Iterable[Union[str, ModuleType]]]] = None,
db_url: str | None = None,
modules: dict[str, Iterable[str | ModuleType]] | None = None,
use_tz: bool = False,
timezone: str = "UTC",
routers: Optional[List[Union[str, Type]]] = None,
table_name_generator: Optional[Callable[[Type["Model"]], str]] = None,
routers: list[str | type] | None = None,
table_name_generator: Callable[[Type["Model"]], str] | None = None,
) -> None:
"""
Sets up Tortoise-ORM.
Expand Down Expand Up @@ -516,8 +500,7 @@ async def init(
for name, info in connections_config.items():
if isinstance(info, str):
info = expand_db_url(info)
password = info.get("credentials", {}).get("password")
if password:
if password := info.get("credentials", {}).get("password"):
passwords.append(password)

str_connection_config = str(connections_config)
Expand All @@ -542,7 +525,7 @@ async def init(
cls._inited = True

@classmethod
def _init_routers(cls, routers: Optional[List[Union[str, type]]] = None) -> None:
def _init_routers(cls, routers: list[str | type] | None = None) -> None:
from tortoise.router import router

routers = routers or []
Expand Down