Skip to content

Commit

Permalink
remove auto_error and make user optioanl in authorize method
Browse files Browse the repository at this point in the history
  • Loading branch information
torkashvandmt committed Jun 6, 2024
1 parent b9ced59 commit bb15218
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 16 deletions.
29 changes: 17 additions & 12 deletions oauth2_lib/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from httpx import AsyncClient, NetworkError
from pydantic import BaseModel
from starlette.requests import ClientDisconnect, HTTPConnection
from starlette.status import HTTP_403_FORBIDDEN
from starlette.websockets import WebSocket
from structlog import get_logger

Expand Down Expand Up @@ -137,22 +138,22 @@ class IdTokenExtractor(ABC):
"""

@abstractmethod
async def extract(self, request: Request, auto_error: bool = True) -> Optional[str]:
async def extract(self, request: Request) -> Optional[str]:
pass


class HttpBearerExtractor(IdTokenExtractor):
"""Extracts bearer tokens using FastAPI's HTTPBearer.
Specifically designed for HTTP Authorization header token extraction.
By default, if an HTTP Bearer token is not provided in the `Authorization` header,
the `extract` method will cancel the request and send an error unless `auto_error`
is set to `False`, allowing optional or multiple authentication methods.
"""

async def extract(self, request: Request, auto_error: bool = True) -> Optional[str]:
http_bearer = HTTPBearer(auto_error=auto_error)
async def extract(self, request: Request) -> Optional[str]:
http_bearer = HTTPBearer(auto_error=False)

if not http_bearer:
return None

Check warning on line 155 in oauth2_lib/fastapi.py

View check run for this annotation

Codecov / codecov/patch

oauth2_lib/fastapi.py#L155

Added line #L155 was not covered by tests

credential = await http_bearer(request)

return credential.credentials if credential else None
Expand Down Expand Up @@ -218,7 +219,11 @@ async def authenticate(self, request: HTTPConnection, token: Optional[str] = Non
return None

if token is None:
token_or_extracted_id_token = await self.id_token_extractor.extract(request, auto_error=True) or ""
extracted_id_token = await self.id_token_extractor.extract(request)
if not extracted_id_token:
raise HTTPException(status_code=HTTP_403_FORBIDDEN, detail="Not authenticated")

token_or_extracted_id_token = extracted_id_token
else:
token_or_extracted_id_token = token

Expand Down Expand Up @@ -262,7 +267,7 @@ class Authorization(ABC):
"""

@abstractmethod
async def authorize(self, request: HTTPConnection, user: OIDCUserModel) -> Optional[bool]:
async def authorize(self, request: HTTPConnection, user: Optional[OIDCUserModel] = None) -> Optional[bool]:
pass


Expand All @@ -273,7 +278,7 @@ class GraphqlAuthorization(ABC):
"""

@abstractmethod
async def authorize(self, request: RequestPath, user: OIDCUserModel) -> Optional[bool]:
async def authorize(self, request: RequestPath, user: Optional[OIDCUserModel] = None) -> Optional[bool]:
pass


Expand Down Expand Up @@ -323,7 +328,7 @@ class OPAAuthorization(Authorization, OPAMixin):
Uses OAUTH2 settings and request information to authorize actions.
"""

async def authorize(self, request: HTTPConnection, user_info: OIDCUserModel) -> Optional[bool]:
async def authorize(self, request: HTTPConnection, user_info: Optional[OIDCUserModel] = None) -> Optional[bool]:
if not (oauth2lib_settings.OAUTH2_ACTIVE and oauth2lib_settings.OAUTH2_AUTHORIZATION_ACTIVE):
return None

Expand Down Expand Up @@ -379,7 +384,7 @@ def __init__(self, opa_url: str, auto_error: bool = False, opa_kwargs: Union[Map
# By default don't raise HTTP 403 because partial results are preferred
super().__init__(opa_url, auto_error, opa_kwargs)

async def authorize(self, request: RequestPath, user_info: OIDCUserModel) -> Optional[bool]:
async def authorize(self, request: RequestPath, user_info: Optional[OIDCUserModel] = None) -> Optional[bool]:
if not (oauth2lib_settings.OAUTH2_ACTIVE and oauth2lib_settings.OAUTH2_AUTHORIZATION_ACTIVE):
return None

Expand Down
6 changes: 2 additions & 4 deletions tests/test_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,13 +149,11 @@ async def test_extract_token_success():


@pytest.mark.asyncio
async def test_extract_token_failure():
async def test_extract_token_returns_none():
request = mock.MagicMock()
request.headers = {}
extractor = HttpBearerExtractor()
with pytest.raises(HTTPException) as exc_info:
await extractor.extract(request)
assert exc_info.value.status_code == 403, "Expected HTTP 403 error for missing token"
assert await extractor.extract(request) is None


@pytest.mark.asyncio
Expand Down

0 comments on commit bb15218

Please sign in to comment.