Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add error_type extension to Strawberry Permission classes & Drop Python 3.9/3.10 #64

Merged
merged 3 commits into from
Jul 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .bumpversion.cfg
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[bumpversion]
current_version = 2.0.2
current_version = 2.1.0
commit = False
tag = False
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\-(?P<release>[a-z]+)(?P<build>\d+))?
Expand Down
10 changes: 6 additions & 4 deletions .github/workflows/test-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.9', '3.10', '3.11', '3.12']
python-version: ['3.11', '3.12']
pydantic-version: ['1.*', '2.*']
fail-fast: false
steps:
Expand All @@ -28,10 +28,12 @@ jobs:
flit install --deps develop
pip install -U "pydantic==${{ matrix.pydantic-version }}"
pip install pydantic_settings || true
- name: Lint
- name: Check formatting
run: |
black . --check
ruff .
black --check .
- name: Lint with ruff
run: |
ruff check .
- name: MyPy
run: |
mypy -p oauth2_lib
Expand Down
22 changes: 11 additions & 11 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,23 +1,23 @@
repos:
- repo: https://github.com/asottile/pyupgrade
rev: v3.8.0
rev: v3.16.0
hooks:
- id: pyupgrade
args:
- --py39-plus
- --py311-plus
- --keep-runtime-typing
- repo: https://github.com/psf/black
rev: 23.3.0
rev: 24.4.2
hooks:
- id: black
language_version: python3.9
language_version: python3.11
- repo: https://github.com/asottile/blacken-docs
rev: 1.14.0
rev: 1.18.0
hooks:
- id: blacken-docs
additional_dependencies: [black==22.10.0]
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
rev: v4.6.0
hooks:
- id: trailing-whitespace
exclude: .bumpversion.cfg
Expand All @@ -31,15 +31,15 @@ repos:
- id: detect-private-key
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.0.275
rev: v0.5.1
hooks:
- id: ruff
args: [ --fix, --exit-non-zero-on-fix, --show-fixes ]
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.4.1
rev: v1.10.1
hooks:
- id: mypy
language_version: python3.9
language_version: python3.11
additional_dependencies: [pydantic<2.0.0, strawberry-graphql]
args:
- --no-warn-unused-ignores
Expand All @@ -53,10 +53,10 @@ repos:
- id: python-check-mock-methods
- id: rst-backticks
- repo: https://github.com/shellcheck-py/shellcheck-py
rev: v0.9.0.5
rev: v0.10.0.1
hooks:
- id: shellcheck
- repo: https://github.com/andreoliwa/nitpick
rev: v0.33.2
rev: v0.35.0
hooks:
- id: nitpick
2 changes: 1 addition & 1 deletion oauth2_lib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@

"""This is the SURF Oauth2 module that interfaces with the oauth2 setup."""

__version__ = "2.0.2"
__version__ = "2.1.0"
4 changes: 2 additions & 2 deletions oauth2_lib/async_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# limitations under the License.
from asyncio import new_event_loop
from http import HTTPStatus
from typing import Any, Union
from typing import Any

import structlog
from authlib.integrations.base_client import BaseOAuth
Expand Down Expand Up @@ -61,7 +61,7 @@ class FubarApiClient(AuthMixin, fubar_client.ApiClient)

"""

_token: Union[dict, None]
_token: dict | None

def __init__(
self,
Expand Down
32 changes: 16 additions & 16 deletions oauth2_lib/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
# limitations under the License.
import ssl
from abc import ABC, abstractmethod
from collections.abc import Awaitable, Mapping
from collections.abc import Awaitable, Callable, Mapping
from http import HTTPStatus
from json import JSONDecodeError
from typing import Any, Callable, Optional, Union, cast
from typing import Any, Optional, cast

from fastapi import HTTPException
from fastapi.requests import Request
Expand Down Expand Up @@ -93,8 +93,8 @@ class OIDCConfig(BaseModel):
authorization_endpoint: str
token_endpoint: str
userinfo_endpoint: str
introspect_endpoint: Optional[str] = None
introspection_endpoint: Optional[str] = None
introspect_endpoint: str | None = None
introspection_endpoint: str | None = None
jwks_uri: str
response_types_supported: list[str]
response_modes_supported: list[str]
Expand Down Expand Up @@ -126,7 +126,7 @@ class Authentication(ABC):
"""

@abstractmethod
async def authenticate(self, request: HTTPConnection, token: Optional[str] = None) -> Optional[dict]:
async def authenticate(self, request: HTTPConnection, token: str | None = None) -> dict | None:
"""Authenticate the user."""
pass

Expand All @@ -138,7 +138,7 @@ class IdTokenExtractor(ABC):
"""

@abstractmethod
async def extract(self, request: Request) -> Optional[str]:
async def extract(self, request: Request) -> str | None:
pass


Expand All @@ -148,7 +148,7 @@ class HttpBearerExtractor(IdTokenExtractor):
Specifically designed for HTTP Authorization header token extraction.
"""

async def extract(self, request: Request) -> Optional[str]:
async def extract(self, request: Request) -> str | None:
http_bearer = HTTPBearer(auto_error=False)
credential = await http_bearer(request)

Expand All @@ -168,7 +168,7 @@ def __init__(
resource_server_id: str,
resource_server_secret: str,
oidc_user_model_cls: type[OIDCUserModel],
id_token_extractor: Optional[IdTokenExtractor] = None,
id_token_extractor: IdTokenExtractor | None = None,
):
if not id_token_extractor:
self.id_token_extractor = HttpBearerExtractor()
Expand All @@ -179,9 +179,9 @@ def __init__(
self.resource_server_secret = resource_server_secret
self.user_model_cls = oidc_user_model_cls

self.openid_config: Optional[OIDCConfig] = None
self.openid_config: OIDCConfig | None = None

async def authenticate(self, request: HTTPConnection, token: Optional[str] = None) -> Optional[OIDCUserModel]:
async def authenticate(self, request: HTTPConnection, token: str | None = None) -> OIDCUserModel | None:
"""Return the OIDC user from OIDC introspect endpoint.

This is used as a security module in Fastapi projects
Expand Down Expand Up @@ -263,7 +263,7 @@ class Authorization(ABC):
"""

@abstractmethod
async def authorize(self, request: HTTPConnection, user: OIDCUserModel) -> Optional[bool]:
async def authorize(self, request: HTTPConnection, user: OIDCUserModel) -> bool | None:
pass


Expand All @@ -274,7 +274,7 @@ class GraphqlAuthorization(ABC):
"""

@abstractmethod
async def authorize(self, request: RequestPath, user: OIDCUserModel) -> Optional[bool]:
async def authorize(self, request: RequestPath, user: OIDCUserModel) -> bool | None:
pass


Expand All @@ -284,7 +284,7 @@ class OPAMixin:
Supports getting and evaluating OPA policy decisions.
"""

def __init__(self, opa_url: str, auto_error: bool = True, opa_kwargs: Union[Mapping[str, Any], None] = None):
def __init__(self, opa_url: str, auto_error: bool = True, opa_kwargs: Mapping[str, Any] | None = None):
self.opa_url = opa_url
self.auto_error = auto_error
self.opa_kwargs = opa_kwargs
Expand Down Expand Up @@ -324,7 +324,7 @@ class OPAAuthorization(Authorization, OPAMixin):
Uses OAUTH2 settings and request information to authorize actions.
"""

async def authorize(self, request: HTTPConnection, user_info: OIDCUserModel) -> Optional[bool]:
async def authorize(self, request: HTTPConnection, user_info: OIDCUserModel) -> bool | None:
if not (oauth2lib_settings.OAUTH2_ACTIVE and oauth2lib_settings.OAUTH2_AUTHORIZATION_ACTIVE):
return None

Expand Down Expand Up @@ -376,11 +376,11 @@ class GraphQLOPAAuthorization(GraphqlAuthorization, OPAMixin):
Customizable to handle partial results without raising HTTP 403.
"""

def __init__(self, opa_url: str, auto_error: bool = False, opa_kwargs: Union[Mapping[str, Any], None] = None):
def __init__(self, opa_url: str, auto_error: bool = False, opa_kwargs: Mapping[str, Any] | None = None):
# By default don't raise HTTP 403 because partial results are preferred
super().__init__(opa_url, auto_error, opa_kwargs)

async def authorize(self, request: RequestPath, user_info: OIDCUserModel) -> Optional[bool]:
async def authorize(self, request: RequestPath, user_info: OIDCUserModel) -> bool | None:
if not (oauth2lib_settings.OAUTH2_ACTIVE and oauth2lib_settings.OAUTH2_AUTHORIZATION_ACTIVE):
return None

Expand Down
39 changes: 27 additions & 12 deletions oauth2_lib/strawberry.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional, Union
from collections.abc import Callable
from enum import StrEnum, auto
from typing import Any

import asyncstdlib
import strawberry
Expand Down Expand Up @@ -39,7 +41,7 @@ def __init__(
super().__init__()

@asyncstdlib.cached_property
async def get_current_user(self) -> Optional[OIDCUserModel]:
async def get_current_user(self) -> OIDCUserModel | None:
"""Retrieve the OIDCUserModel once per graphql request.

Note:
Expand Down Expand Up @@ -118,8 +120,16 @@ async def is_authorized(info: OauthInfo, path: str) -> bool:
return authorized


class ErrorType(StrEnum):
"""Subset of the ErrorType enum in nwa-stdlib."""

NOT_AUTHENTICATED = auto()
NOT_AUTHORIZED = auto()


class IsAuthenticatedForQuery(BasePermission):
message = "User is not authenticated"
error_extensions = {"error_type": ErrorType.NOT_AUTHENTICATED}

async def has_permission(self, source: Any, info: OauthInfo, **kwargs) -> bool: # type: ignore
if not oauth2lib_settings.OAUTH2_ACTIVE:
Expand All @@ -135,6 +145,7 @@ async def has_permission(self, source: Any, info: OauthInfo, **kwargs) -> bool:

class IsAuthenticatedForMutation(BasePermission):
message = "User is not authenticated"
error_extensions = {"error_type": ErrorType.NOT_AUTHENTICATED}

async def has_permission(self, source: Any, info: OauthInfo, **kwargs) -> bool: # type: ignore
mutations_active = oauth2lib_settings.OAUTH2_ACTIVE and oauth2lib_settings.MUTATIONS_ENABLED
Expand All @@ -145,6 +156,8 @@ async def has_permission(self, source: Any, info: OauthInfo, **kwargs) -> bool:


class IsAuthorizedForQuery(BasePermission):
error_extensions = {"error_type": ErrorType.NOT_AUTHORIZED}

async def has_permission(self, source: Any, info: OauthInfo, **kwargs) -> bool: # type: ignore
if not (oauth2lib_settings.OAUTH2_ACTIVE and oauth2lib_settings.OAUTH2_AUTHORIZATION_ACTIVE):
logger.debug(
Expand All @@ -163,6 +176,8 @@ async def has_permission(self, source: Any, info: OauthInfo, **kwargs) -> bool:


class IsAuthorizedForMutation(BasePermission):
error_extensions = {"error_type": ErrorType.NOT_AUTHORIZED}

async def has_permission(self, source: Any, info: OauthInfo, **kwargs) -> bool: # type: ignore
mutations_active = (
oauth2lib_settings.OAUTH2_ACTIVE
Expand All @@ -182,9 +197,9 @@ async def has_permission(self, source: Any, info: OauthInfo, **kwargs) -> bool:

def authenticated_field(
description: str,
resolver: Union[StrawberryResolver, Callable, staticmethod, classmethod, None] = None,
deprecation_reason: Union[str, None] = None,
permission_classes: Union[list[type[BasePermission]], None] = None,
resolver: StrawberryResolver | Callable | staticmethod | classmethod | None = None,
deprecation_reason: str | None = None,
permission_classes: list[type[BasePermission]] | None = None,
) -> Any:
permissions = permission_classes if permission_classes else []
return strawberry.field(
Expand All @@ -197,9 +212,9 @@ def authenticated_field(

def authenticated_mutation_field(
description: str,
resolver: Union[StrawberryResolver, Callable, staticmethod, classmethod, None] = None,
deprecation_reason: Union[str, None] = None,
permission_classes: Union[list[type[BasePermission]], None] = None,
resolver: StrawberryResolver | Callable | staticmethod | classmethod | None = None,
deprecation_reason: str | None = None,
permission_classes: list[type[BasePermission]] | None = None,
) -> Any:
permissions = permission_classes if permission_classes else []
return strawberry.field(
Expand All @@ -212,10 +227,10 @@ def authenticated_mutation_field(

def authenticated_federated_field( # type: ignore
description: str,
resolver: Union[StrawberryResolver, Callable, staticmethod, classmethod, None] = None,
deprecation_reason: Union[str, None] = None,
requires: Union[list[str], None] = None,
permission_classes: Union[list[type[BasePermission]], None] = None,
resolver: StrawberryResolver | Callable | staticmethod | classmethod | None = None,
deprecation_reason: str | None = None,
requires: list[str] | None = None,
permission_classes: list[type[BasePermission]] | None = None,
**kwargs,
) -> Any:
permissions = permission_classes if permission_classes else []
Expand Down
19 changes: 12 additions & 7 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@ classifiers = [
"Programming Language :: Python :: 3 :: Only",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.9",
]
requires = [
"requests>=2.19.0",
Expand All @@ -41,7 +39,7 @@ requires = [
"asyncstdlib",
]
description-file = "README.md"
requires-python = ">=3.9,<3.13"
requires-python = ">=3.11,<3.13"

[tool.flit.metadata.urls]
Documentation = "https://workfloworchestrator.org/"
Expand Down Expand Up @@ -102,6 +100,10 @@ exclude = [
"build",
".venv",
]
target-version = "py311"
line-length = 120

[tool.ruff.lint]
ignore = [
"B008",
"D100",
Expand All @@ -118,9 +120,9 @@ ignore = [
"B904",
"N802",
"N801",
"N818"
"N818",
"S113", # HTTPX has a default timeout
]
line-length = 120
select = [
"B",
"C",
Expand All @@ -134,7 +136,10 @@ select = [
"T",
"W",
]
target-version = "py310"

[tool.ruff.pydocstyle]
[tool.ruff.lint.flake8-tidy-imports]
ban-relative-imports = "all"


[tool.ruff.lint.pydocstyle]
convention = "google"
Loading
Loading