Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Document name error #34

Merged
merged 5 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .copier-answers.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@ author_name: Ryan Morshead
project_description: Dependency injection without the boilerplate.
project_title: PyBooster
python_package_name: pybooster
python_version_range: '>=3.11,<4'
python_version_range: ">=3.11,<4"
repo_url: https://github.com/rmorshea/pybooster
99 changes: 99 additions & 0 deletions docs/src/recipes.md
Original file line number Diff line number Diff line change
Expand Up @@ -171,3 +171,102 @@ def transaction_provider() -> Iterator[Transaction]:
with session_provider() as session, session.begin():
yield session
```

## NameError in Type Hints

!!! note

This should not be an issue in Python 3.14 with [PEP-649](https://peps.python.org/pep-0649).

If you're encountering a `NameError` when PyBooster tries to infer what type is supplied
by a provider or required for an injector this is likely because you're using
`from __future__ import annotations` and the type hint is imported in an
`if TYPE_CHECKING` block. For example, this code raises `NameError`s because the
`Connection` type is not present at runtime:

```python
from __future__ import annotations

from contextlib import suppress
from typing import TYPE_CHECKING

from pybooster import injector
from pybooster import provider
from pybooster import required

if TYPE_CHECKING:
from sqlite3 import Connection


with suppress(NameError):

@provider.function
def connection_provider() -> Connection: ...

raise AssertionError("This should not be reached")


with suppress(NameError):

@injector.function
def query_database(*, conn: Connection = required) -> None: ...

raise AssertionError("This should not be reached")
```

To fix this, you can move the import outside of the block:

```python
from __future__ import annotations

from sqlite3 import Connection

from pybooster import injector
from pybooster import provider
from pybooster import required


@provider.function
def connection_provider() -> Connection: ...


@injector.function
def query_database(*, conn: Connection = required) -> None: ...
```

However, some linters like [Ruff](https://github.com/astral-sh/ruff) will automatically
move the import back into the block when they discover that the imported value is only
used as a type hint. To work around this, you can ignore the linter errors or use the
types in such a way that your linter understands they are required at runtime. In the
case of Ruff, you'd ignore the following errors:

- [TC001](https://docs.astral.sh/ruff/rules/typing-only-first-party-import/)
- [TC002](https://docs.astral.sh/ruff/rules/typing-only-third-party-import/)
- [TC003](https://docs.astral.sh/ruff/rules/typing-only-standard-library-import/)

To convince the linter that types used by PyBooster are required at runtime, you can
pass them to the `provides` argument of the `provider` decorator or the `requires`
argument of an `injector` or `provider` decorator.

```python
from __future__ import annotations

from sqlite3 import Connection

from pybooster import injector
from pybooster import provider
from pybooster import required


@provider.function(provides=Connection)
def connection_provider() -> Connection: ...


@injector.function(requires=[Connection])
def query_database(*, conn: Connection = required) -> None: ...
```

!!! tip

Type checkers should still be able to check the return type using the `provides`
argument so it may not be necessary to annotate it in the function signature.
8 changes: 6 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ lint = [
"mdformat-tables==1.0.0",
"mdformat==0.7.19",
"pyright==1.1.389",
"ruff==0.7.3",
"ruff==0.8.1",
"yamlfix==1.17.0",
"doccmd==2024.11.14",
]
Expand Down Expand Up @@ -151,14 +151,18 @@ ban-relative-imports = "all"
"ANN", # Type annotations
"B018", # Useless expression
"D", # Docstrings
"EM101", # Assign error message to string
"FA102", # Unsafe __futures__ annotations usage
"INP001", # Implicit namespace package
"RUF029", # No await in async function
"S101", # Assert statements
"S106", # Possible passwords
"SIM115", # Use context manager for opening files
"T201", # Print
"TCH002", # Move third-party import into a type-checking block
"TC001", # Move first-party import into a type-checking block
"TC002", # Move third-party import into a type-checking block
"TC003", # Move standard-libarary import into a type-checking block
"TRY003", # Avoid specifying long messages outside the exception class
]

[tool.yamlfix]
Expand Down
6 changes: 3 additions & 3 deletions src/pybooster/_private/_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def get_provides_type(
provides: type[R] | Callable[..., type[R]], *args: Any, **kwargs: Any
) -> type[R]:
if is_type(provides):
return cast(type[R], provides)
return cast("type[R]", provides)
elif callable(provides):
return provides(*args, **kwargs)
else:
Expand Down Expand Up @@ -141,15 +141,15 @@ def _get_scalar_provider_infos(
if is_sync:
info = SyncProviderInfo(
is_sync=is_sync,
producer=cast(ContextManagerCallable[[], Any], producer),
producer=cast("ContextManagerCallable[[], Any]", producer),
provides=provides,
required_parameters=required_parameters,
getter=getter,
)
else:
info = AsyncProviderInfo(
is_sync=is_sync,
producer=cast(AsyncContextManagerCallable[[], Any], producer),
producer=cast("AsyncContextManagerCallable[[], Any]", producer),
provides=provides,
required_parameters=required_parameters,
getter=getter,
Expand Down
42 changes: 33 additions & 9 deletions src/pybooster/_private/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from collections.abc import Callable
from collections.abc import Coroutine
from collections.abc import Iterator
from collections.abc import Mapping
from collections.abc import Sequence
from dataclasses import dataclass
from inspect import Parameter
Expand Down Expand Up @@ -37,6 +38,7 @@
from anyio.abc import TaskGroup

from pybooster.types import HintMap
from pybooster.types import HintSeq

P = ParamSpec("P")
R = TypeVar("R")
Expand Down Expand Up @@ -81,23 +83,45 @@ def make_sentinel_value(module: str, name: str) -> Any:
"""Represents an undefined default."""


def get_required_parameters(func: Callable, dependencies: HintMap | None = None) -> HintMap:
return dependencies if dependencies is not None else _get_required_parameters(func)
def get_required_parameters(
func: Callable, dependencies: HintMap | HintSeq | None = None
) -> HintMap:
match dependencies:
case None:
return _get_required_parameters(func)
case Mapping():
return dependencies
case Sequence():
params = _get_required_sig_parameters(func)
if (lpar := len(params)) != (ldep := len(dependencies)):
msg = f"Could not match {ldep} dependencies to {lpar} required parameters."
raise TypeError(msg)
return dict(zip((p.name for p in params), dependencies, strict=False))
case _: # nocov
msg = f"Expected a mapping or sequence of dependencies, but got {dependencies!r}."
raise TypeError(msg)


def _get_required_parameters(func: Callable[P, R]) -> HintMap:
required_params: dict[str, type] = {}
hints = get_type_hints(func, include_extras=True)
for param in signature(func).parameters.values():
if param.default is pybooster.required:
if param.kind is not Parameter.KEYWORD_ONLY:
msg = f"Expected dependant parameter {param!r} to be keyword-only."
raise TypeError(msg)
check_is_required_type(hint := hints[param.name])
required_params[param.name] = hint
for param in _get_required_sig_parameters(func):
check_is_required_type(hint := hints[param.name])
required_params[param.name] = hint
return required_params


def _get_required_sig_parameters(func: Callable[P, R]) -> list[Parameter]:
params: list[Parameter] = []
for p in signature(func).parameters.values():
if p.default is pybooster.required:
if p.kind is not Parameter.KEYWORD_ONLY:
msg = f"Expected dependant parameter {p!r} to be keyword-only."
raise TypeError(msg)
params.append(p)
return params


def get_raw_annotation(anno: Any) -> RawAnnotation:
return RawAnnotation(get_args(anno)[0] if get_origin(anno) is Annotated else anno)

Expand Down
18 changes: 10 additions & 8 deletions src/pybooster/core/injector.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
if TYPE_CHECKING:
from collections.abc import Callable

from pybooster.types import HintSeq

if TYPE_CHECKING:
from collections.abc import AsyncIterator
from collections.abc import Coroutine
Expand All @@ -46,7 +48,7 @@
def function(
func: Callable[P, R],
*,
requires: HintMap | None = None,
requires: HintMap | HintSeq | None = None,
shared: bool = False,
) -> Callable[P, R]:
"""Inject dependencies into the given function.
Expand Down Expand Up @@ -102,7 +104,7 @@ async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: # type: ignore[repor
def iterator(
func: IteratorCallable[P, R],
*,
requires: HintMap | None = None,
requires: HintMap | HintSeq | None = None,
shared: bool = False,
) -> IteratorCallable[P, R]:
"""Inject dependencies into the given iterator.
Expand Down Expand Up @@ -130,7 +132,7 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> Iterator[R]:
def asynciterator(
func: AsyncIteratorCallable[P, R],
*,
requires: HintMap | None = None,
requires: HintMap | HintSeq | None = None,
shared: bool = False,
) -> AsyncIteratorCallable[P, R]:
"""Inject dependencies into the given async iterator.
Expand Down Expand Up @@ -159,7 +161,7 @@ async def wrapper(*args: P.args, **kwargs: P.kwargs) -> AsyncIterator[R]:
def contextmanager(
func: IteratorCallable[P, R],
*,
requires: HintMap | None = None,
requires: HintMap | HintSeq | None = None,
shared: bool = False,
) -> Callable[P, AbstractContextManager[R]]:
"""Inject dependencies into the given context manager function.
Expand All @@ -176,7 +178,7 @@ def contextmanager(
def asynccontextmanager(
func: AsyncIteratorCallable[P, R],
*,
requires: HintMap | None = None,
requires: HintMap | HintSeq | None = None,
shared: bool = False,
) -> Callable[P, AbstractAsyncContextManager[R]]:
"""Inject dependencies into the given async context manager function.
Expand Down Expand Up @@ -206,7 +208,7 @@ def shared(*args: type | tuple[type, Any]) -> _SharedContext:

def current_values() -> CurrentValues:
"""Get a mapping from dependency types to their current values."""
return cast(CurrentValues, dict(_CURRENT_VALUES.get()))
return cast("CurrentValues", dict(_CURRENT_VALUES.get()))


class CurrentValues(Mapping[type, Any]):
Expand Down Expand Up @@ -239,7 +241,7 @@ def __enter__(self) -> CurrentValues:
self._param_deps,
keep_current_values=True,
)
return cast(CurrentValues, {self._param_deps[k]: v for k, v in params.items()})
return cast("CurrentValues", {self._param_deps[k]: v for k, v in params.items()})

def __exit__(self, *_: Any) -> None:
try:
Expand All @@ -259,7 +261,7 @@ async def __aenter__(self) -> CurrentValues:
self._param_deps,
keep_current_values=True,
)
return cast(CurrentValues, {self._param_deps[k]: v for k, v in params.items()})
return cast("CurrentValues", {self._param_deps[k]: v for k, v in params.items()})

async def __aexit__(self, *exc: Any) -> None:
try:
Expand Down
13 changes: 7 additions & 6 deletions src/pybooster/core/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from pybooster.types import AsyncIteratorCallable
from pybooster.types import ContextManagerCallable
from pybooster.types import HintMap
from pybooster.types import HintSeq
from pybooster.types import IteratorCallable

P = ParamSpec("P")
Expand All @@ -43,7 +44,7 @@
def function(
func: Callable[P, R],
*,
requires: HintMap | None = None,
requires: HintMap | HintSeq | None = None,
provides: type[R] | Callable[..., type[R]] | None = None,
) -> SyncProvider[P, R]:
"""Create a provider from the given function.
Expand All @@ -66,7 +67,7 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> Iterator[R]:
def asyncfunction(
func: Callable[P, Awaitable[R]],
*,
requires: HintMap | None = None,
requires: HintMap | HintSeq | None = None,
provides: type[R] | Callable[..., type[R]] | None = None,
) -> AsyncProvider[P, R]:
"""Create a provider from the given coroutine.
Expand All @@ -89,7 +90,7 @@ async def wrapper(*args: P.args, **kwargs: P.kwargs) -> AsyncIterator[R]:
def iterator(
func: IteratorCallable[P, R],
*,
requires: HintMap | None = None,
requires: HintMap | HintSeq | None = None,
provides: type[R] | Callable[..., type[R]] | None = None,
) -> SyncProvider[P, R]:
"""Create a provider from the given iterator function.
Expand All @@ -101,14 +102,14 @@ def iterator(
"""
provides = provides or get_iterator_yield_type(func, sync=True)
requires = get_required_parameters(func, requires)
return SyncProvider(_contextmanager(func), cast(type[R], provides), requires)
return SyncProvider(_contextmanager(func), cast("type[R]", provides), requires)


@paramorator
def asynciterator(
func: AsyncIteratorCallable[P, R],
*,
requires: HintMap | None = None,
requires: HintMap | HintSeq | None = None,
provides: type[R] | Callable[..., type[R]] | None = None,
) -> AsyncProvider[P, R]:
"""Create a provider from the given async iterator function.
Expand All @@ -120,7 +121,7 @@ def asynciterator(
"""
provides = provides or get_iterator_yield_type(func, sync=False)
requires = get_required_parameters(func, requires)
return AsyncProvider(_asynccontextmanager(func), cast(type[R], provides), requires)
return AsyncProvider(_asynccontextmanager(func), cast("type[R]", provides), requires)


class _BaseProvider(Generic[R]):
Expand Down
Loading
Loading