diff --git a/.bumpversion.cfg b/.bumpversion.cfg index a0c0178..b238b4e 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 2.0.2 +current_version = 2.1.0 commit = False tag = False parse = (?P\d+)\.(?P\d+)\.(?P\d+)(\-(?P[a-z]+)(?P\d+))? diff --git a/.github/workflows/test-package.yml b/.github/workflows/test-package.yml index e86ffb0..7300bb2 100644 --- a/.github/workflows/test-package.yml +++ b/.github/workflows/test-package.yml @@ -12,7 +12,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ['3.9', '3.10', '3.11', '3.12'] + python-version: ['3.11', '3.12'] pydantic-version: ['1.*', '2.*'] fail-fast: false steps: @@ -28,10 +28,12 @@ jobs: flit install --deps develop pip install -U "pydantic==${{ matrix.pydantic-version }}" pip install pydantic_settings || true - - name: Lint + - name: Check formatting run: | - black . --check - ruff . + black --check . + - name: Lint with ruff + run: | + ruff check . - name: MyPy run: | mypy -p oauth2_lib diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 92a330d..d960a47 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,23 +1,23 @@ repos: - repo: https://github.com/asottile/pyupgrade - rev: v3.8.0 + rev: v3.16.0 hooks: - id: pyupgrade args: - - --py39-plus + - --py311-plus - --keep-runtime-typing - repo: https://github.com/psf/black - rev: 23.3.0 + rev: 24.4.2 hooks: - id: black - language_version: python3.9 + language_version: python3.11 - repo: https://github.com/asottile/blacken-docs - rev: 1.14.0 + rev: 1.18.0 hooks: - id: blacken-docs additional_dependencies: [black==22.10.0] - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 + rev: v4.6.0 hooks: - id: trailing-whitespace exclude: .bumpversion.cfg @@ -31,15 +31,15 @@ repos: - id: detect-private-key - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: v0.0.275 + rev: v0.5.1 hooks: - id: ruff args: [ --fix, --exit-non-zero-on-fix, --show-fixes ] - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.4.1 + rev: v1.10.1 hooks: - id: mypy - language_version: python3.9 + language_version: python3.11 additional_dependencies: [pydantic<2.0.0, strawberry-graphql] args: - --no-warn-unused-ignores @@ -53,10 +53,10 @@ repos: - id: python-check-mock-methods - id: rst-backticks - repo: https://github.com/shellcheck-py/shellcheck-py - rev: v0.9.0.5 + rev: v0.10.0.1 hooks: - id: shellcheck - repo: https://github.com/andreoliwa/nitpick - rev: v0.33.2 + rev: v0.35.0 hooks: - id: nitpick diff --git a/oauth2_lib/__init__.py b/oauth2_lib/__init__.py index 42202ab..39a489c 100644 --- a/oauth2_lib/__init__.py +++ b/oauth2_lib/__init__.py @@ -13,4 +13,4 @@ """This is the SURF Oauth2 module that interfaces with the oauth2 setup.""" -__version__ = "2.0.2" +__version__ = "2.1.0" diff --git a/oauth2_lib/async_api_client.py b/oauth2_lib/async_api_client.py index c55e097..02698f5 100644 --- a/oauth2_lib/async_api_client.py +++ b/oauth2_lib/async_api_client.py @@ -12,7 +12,7 @@ # limitations under the License. from asyncio import new_event_loop from http import HTTPStatus -from typing import Any, Union +from typing import Any import structlog from authlib.integrations.base_client import BaseOAuth @@ -61,7 +61,7 @@ class FubarApiClient(AuthMixin, fubar_client.ApiClient) """ - _token: Union[dict, None] + _token: dict | None def __init__( self, diff --git a/oauth2_lib/fastapi.py b/oauth2_lib/fastapi.py index bcf3262..dc2ede7 100644 --- a/oauth2_lib/fastapi.py +++ b/oauth2_lib/fastapi.py @@ -12,10 +12,10 @@ # limitations under the License. import ssl from abc import ABC, abstractmethod -from collections.abc import Awaitable, Mapping +from collections.abc import Awaitable, Callable, Mapping from http import HTTPStatus from json import JSONDecodeError -from typing import Any, Callable, Optional, Union, cast +from typing import Any, Optional, cast from fastapi import HTTPException from fastapi.requests import Request @@ -93,8 +93,8 @@ class OIDCConfig(BaseModel): authorization_endpoint: str token_endpoint: str userinfo_endpoint: str - introspect_endpoint: Optional[str] = None - introspection_endpoint: Optional[str] = None + introspect_endpoint: str | None = None + introspection_endpoint: str | None = None jwks_uri: str response_types_supported: list[str] response_modes_supported: list[str] @@ -126,7 +126,7 @@ class Authentication(ABC): """ @abstractmethod - async def authenticate(self, request: HTTPConnection, token: Optional[str] = None) -> Optional[dict]: + async def authenticate(self, request: HTTPConnection, token: str | None = None) -> dict | None: """Authenticate the user.""" pass @@ -138,7 +138,7 @@ class IdTokenExtractor(ABC): """ @abstractmethod - async def extract(self, request: Request) -> Optional[str]: + async def extract(self, request: Request) -> str | None: pass @@ -148,7 +148,7 @@ class HttpBearerExtractor(IdTokenExtractor): Specifically designed for HTTP Authorization header token extraction. """ - async def extract(self, request: Request) -> Optional[str]: + async def extract(self, request: Request) -> str | None: http_bearer = HTTPBearer(auto_error=False) credential = await http_bearer(request) @@ -168,7 +168,7 @@ def __init__( resource_server_id: str, resource_server_secret: str, oidc_user_model_cls: type[OIDCUserModel], - id_token_extractor: Optional[IdTokenExtractor] = None, + id_token_extractor: IdTokenExtractor | None = None, ): if not id_token_extractor: self.id_token_extractor = HttpBearerExtractor() @@ -179,9 +179,9 @@ def __init__( self.resource_server_secret = resource_server_secret self.user_model_cls = oidc_user_model_cls - self.openid_config: Optional[OIDCConfig] = None + self.openid_config: OIDCConfig | None = None - async def authenticate(self, request: HTTPConnection, token: Optional[str] = None) -> Optional[OIDCUserModel]: + async def authenticate(self, request: HTTPConnection, token: str | None = None) -> OIDCUserModel | None: """Return the OIDC user from OIDC introspect endpoint. This is used as a security module in Fastapi projects @@ -263,7 +263,7 @@ class Authorization(ABC): """ @abstractmethod - async def authorize(self, request: HTTPConnection, user: OIDCUserModel) -> Optional[bool]: + async def authorize(self, request: HTTPConnection, user: OIDCUserModel) -> bool | None: pass @@ -274,7 +274,7 @@ class GraphqlAuthorization(ABC): """ @abstractmethod - async def authorize(self, request: RequestPath, user: OIDCUserModel) -> Optional[bool]: + async def authorize(self, request: RequestPath, user: OIDCUserModel) -> bool | None: pass @@ -284,7 +284,7 @@ class OPAMixin: Supports getting and evaluating OPA policy decisions. """ - def __init__(self, opa_url: str, auto_error: bool = True, opa_kwargs: Union[Mapping[str, Any], None] = None): + def __init__(self, opa_url: str, auto_error: bool = True, opa_kwargs: Mapping[str, Any] | None = None): self.opa_url = opa_url self.auto_error = auto_error self.opa_kwargs = opa_kwargs @@ -324,7 +324,7 @@ class OPAAuthorization(Authorization, OPAMixin): Uses OAUTH2 settings and request information to authorize actions. """ - async def authorize(self, request: HTTPConnection, user_info: OIDCUserModel) -> Optional[bool]: + async def authorize(self, request: HTTPConnection, user_info: OIDCUserModel) -> bool | None: if not (oauth2lib_settings.OAUTH2_ACTIVE and oauth2lib_settings.OAUTH2_AUTHORIZATION_ACTIVE): return None @@ -376,11 +376,11 @@ class GraphQLOPAAuthorization(GraphqlAuthorization, OPAMixin): Customizable to handle partial results without raising HTTP 403. """ - def __init__(self, opa_url: str, auto_error: bool = False, opa_kwargs: Union[Mapping[str, Any], None] = None): + def __init__(self, opa_url: str, auto_error: bool = False, opa_kwargs: Mapping[str, Any] | None = None): # By default don't raise HTTP 403 because partial results are preferred super().__init__(opa_url, auto_error, opa_kwargs) - async def authorize(self, request: RequestPath, user_info: OIDCUserModel) -> Optional[bool]: + async def authorize(self, request: RequestPath, user_info: OIDCUserModel) -> bool | None: if not (oauth2lib_settings.OAUTH2_ACTIVE and oauth2lib_settings.OAUTH2_AUTHORIZATION_ACTIVE): return None diff --git a/oauth2_lib/strawberry.py b/oauth2_lib/strawberry.py index 08bd23d..0e5a09b 100644 --- a/oauth2_lib/strawberry.py +++ b/oauth2_lib/strawberry.py @@ -10,7 +10,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Optional, Union +from collections.abc import Callable +from enum import StrEnum, auto +from typing import Any import asyncstdlib import strawberry @@ -39,7 +41,7 @@ def __init__( super().__init__() @asyncstdlib.cached_property - async def get_current_user(self) -> Optional[OIDCUserModel]: + async def get_current_user(self) -> OIDCUserModel | None: """Retrieve the OIDCUserModel once per graphql request. Note: @@ -118,8 +120,16 @@ async def is_authorized(info: OauthInfo, path: str) -> bool: return authorized +class ErrorType(StrEnum): + """Subset of the ErrorType enum in nwa-stdlib.""" + + NOT_AUTHENTICATED = auto() + NOT_AUTHORIZED = auto() + + class IsAuthenticatedForQuery(BasePermission): message = "User is not authenticated" + error_extensions = {"error_type": ErrorType.NOT_AUTHENTICATED} async def has_permission(self, source: Any, info: OauthInfo, **kwargs) -> bool: # type: ignore if not oauth2lib_settings.OAUTH2_ACTIVE: @@ -135,6 +145,7 @@ async def has_permission(self, source: Any, info: OauthInfo, **kwargs) -> bool: class IsAuthenticatedForMutation(BasePermission): message = "User is not authenticated" + error_extensions = {"error_type": ErrorType.NOT_AUTHENTICATED} async def has_permission(self, source: Any, info: OauthInfo, **kwargs) -> bool: # type: ignore mutations_active = oauth2lib_settings.OAUTH2_ACTIVE and oauth2lib_settings.MUTATIONS_ENABLED @@ -145,6 +156,8 @@ async def has_permission(self, source: Any, info: OauthInfo, **kwargs) -> bool: class IsAuthorizedForQuery(BasePermission): + error_extensions = {"error_type": ErrorType.NOT_AUTHORIZED} + async def has_permission(self, source: Any, info: OauthInfo, **kwargs) -> bool: # type: ignore if not (oauth2lib_settings.OAUTH2_ACTIVE and oauth2lib_settings.OAUTH2_AUTHORIZATION_ACTIVE): logger.debug( @@ -163,6 +176,8 @@ async def has_permission(self, source: Any, info: OauthInfo, **kwargs) -> bool: class IsAuthorizedForMutation(BasePermission): + error_extensions = {"error_type": ErrorType.NOT_AUTHORIZED} + async def has_permission(self, source: Any, info: OauthInfo, **kwargs) -> bool: # type: ignore mutations_active = ( oauth2lib_settings.OAUTH2_ACTIVE @@ -182,9 +197,9 @@ async def has_permission(self, source: Any, info: OauthInfo, **kwargs) -> bool: def authenticated_field( description: str, - resolver: Union[StrawberryResolver, Callable, staticmethod, classmethod, None] = None, - deprecation_reason: Union[str, None] = None, - permission_classes: Union[list[type[BasePermission]], None] = None, + resolver: StrawberryResolver | Callable | staticmethod | classmethod | None = None, + deprecation_reason: str | None = None, + permission_classes: list[type[BasePermission]] | None = None, ) -> Any: permissions = permission_classes if permission_classes else [] return strawberry.field( @@ -197,9 +212,9 @@ def authenticated_field( def authenticated_mutation_field( description: str, - resolver: Union[StrawberryResolver, Callable, staticmethod, classmethod, None] = None, - deprecation_reason: Union[str, None] = None, - permission_classes: Union[list[type[BasePermission]], None] = None, + resolver: StrawberryResolver | Callable | staticmethod | classmethod | None = None, + deprecation_reason: str | None = None, + permission_classes: list[type[BasePermission]] | None = None, ) -> Any: permissions = permission_classes if permission_classes else [] return strawberry.field( @@ -212,10 +227,10 @@ def authenticated_mutation_field( def authenticated_federated_field( # type: ignore description: str, - resolver: Union[StrawberryResolver, Callable, staticmethod, classmethod, None] = None, - deprecation_reason: Union[str, None] = None, - requires: Union[list[str], None] = None, - permission_classes: Union[list[type[BasePermission]], None] = None, + resolver: StrawberryResolver | Callable | staticmethod | classmethod | None = None, + deprecation_reason: str | None = None, + requires: list[str] | None = None, + permission_classes: list[type[BasePermission]] | None = None, **kwargs, ) -> Any: permissions = permission_classes if permission_classes else [] diff --git a/pyproject.toml b/pyproject.toml index 0a66fd8..d4b6897 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,8 +27,6 @@ classifiers = [ "Programming Language :: Python :: 3 :: Only", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.11", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.9", ] requires = [ "requests>=2.19.0", @@ -41,7 +39,7 @@ requires = [ "asyncstdlib", ] description-file = "README.md" -requires-python = ">=3.9,<3.13" +requires-python = ">=3.11,<3.13" [tool.flit.metadata.urls] Documentation = "https://workfloworchestrator.org/" @@ -102,6 +100,10 @@ exclude = [ "build", ".venv", ] +target-version = "py311" +line-length = 120 + +[tool.ruff.lint] ignore = [ "B008", "D100", @@ -118,9 +120,9 @@ ignore = [ "B904", "N802", "N801", - "N818" + "N818", + "S113", # HTTPX has a default timeout ] -line-length = 120 select = [ "B", "C", @@ -134,7 +136,10 @@ select = [ "T", "W", ] -target-version = "py310" -[tool.ruff.pydocstyle] +[tool.ruff.lint.flake8-tidy-imports] +ban-relative-imports = "all" + + +[tool.ruff.lint.pydocstyle] convention = "google" diff --git a/tests/conftest.py b/tests/conftest.py index d92aeed..ccdb857 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Union +from typing import Any from unittest.mock import AsyncMock, MagicMock import pytest @@ -7,7 +7,7 @@ class MockResponse: - def __init__(self, json: Optional[Any] = None, status_code: int = 200, error: Optional[Exception] = None): + def __init__(self, json: Any | None = None, status_code: int = 200, error: Exception | None = None): self.json = json self.status_code = status_code self.error = error @@ -38,7 +38,7 @@ def make_mock_async_client(): Pass a MockResponse for single or list for multiple sequential HTTP responses. """ - def _make_mock_async_client(mock_response: Union[MockResponse, list[MockResponse], None] = None): + def _make_mock_async_client(mock_response: MockResponse | list[MockResponse] | None = None): mock_async_client = AsyncMock(spec=AsyncClient) mock_responses = ([mock_response] if isinstance(mock_response, MockResponse) else mock_response) or [] diff --git a/tests/strawberry/conftest.py b/tests/strawberry/conftest.py index 0228cc0..4b7c735 100644 --- a/tests/strawberry/conftest.py +++ b/tests/strawberry/conftest.py @@ -1,5 +1,3 @@ -from typing import Optional - import pytest import strawberry from fastapi import Depends, FastAPI @@ -19,7 +17,7 @@ async def get_oidc_authentication(): class OIDCAuthMock(OIDCAuth): - async def userinfo(self, request: Request, token: Optional[str] = None) -> Optional[OIDCUserModel]: + async def userinfo(self, request: Request, token: str | None = None) -> OIDCUserModel | None: return user_info_matching return OIDCAuthMock("openid_url", "openid_url/.well-known/openid-configuration", "id", "secret", OIDCUserModel)