Skip to content

Commit

Permalink
add unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
davidgamez committed Dec 2, 2024
1 parent 4ad0194 commit 78dfb30
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 34 deletions.
11 changes: 11 additions & 0 deletions functions-python/operations_api/.coveragerc
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
[run]
omit =
*/test*/*
*/helpers/*
*/database_gen/*
*/dataset_service/*
*/feeds_operations_gen/*

[report]
exclude_lines =
if __name__ == .__main__.:
Original file line number Diff line number Diff line change
Expand Up @@ -32,20 +32,6 @@ def __init__(self, app: ASGIApp) -> None:
self.logger = logging.getLogger()
self.app = app

@staticmethod
def extract_response_info(headers):
"""
Extracts the content type and content length from the response headers.
"""
content_type = None
content_length = None
for key, value in headers:
if key == b"content-length":
content_length = int(value)
elif key == b"content-type":
content_type = value
return content_type, content_length

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
"""
Middleware to set the request context and authorize requests.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,19 +43,19 @@ def validate_token_with_google(token: str, google_client_id: str) -> dict:
"""
try:
response = get_tokeninfo_response(token)
if response.status_code != 200:
raise HTTPException(status_code=401, detail="Invalid access token")
except Exception as e:
logging.error(f"Token validation failed: {e}")
raise HTTPException(status_code=500, detail="Token validation failed")

token_info = response.json()
if response.status_code != 200:
raise HTTPException(status_code=401, detail="Invalid access token")

# Ensure the token is for the expected client
if token_info.get("audience") != google_client_id:
raise HTTPException(status_code=401, detail="Invalid token audience")
token_info = response.json()
# Ensure the token is for the expected client
if token_info.get("audience") != google_client_id:
raise HTTPException(status_code=401, detail="Invalid token audience")

return token_info
except requests.exceptions.RequestException as e:
logging.error(f"Token validation failed: {e}")
raise HTTPException(status_code=500, detail="Token validation failed")
return token_info


def get_tokeninfo_response(token):
Expand Down Expand Up @@ -113,12 +113,10 @@ def extract_authorization_oauth(headers: dict, google_client_id: str) -> str:
)

token = auth_header.split(" ")[1]
logging.info(f"Token: {token}")

token_info = get_token_info(token, google_client_id)

email = token_info.get("email")
logging.info(f"Email: {email}")
if not email:
raise HTTPException(status_code=400, detail="Email not found in token")

Expand Down Expand Up @@ -197,7 +195,7 @@ def _extract_from_headers(self, headers, scope: Scope) -> None:
# auth header is used for local development
self.user_email = headers.get("x-goog-authenticated-user-email")

if headers.get("authorization"):
if headers.get("authorization") is not None:
google_client_id = os.getenv("GOOGLE_CLIENT_ID")
self.user_email = extract_authorization_oauth(headers, google_client_id)
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from starlette.requests import Request
from starlette.responses import Response
from starlette.types import Receive, Scope, Send
import asyncio

from middleware.request_context_middleware import (
RequestContextMiddleware,
Expand Down Expand Up @@ -66,8 +67,6 @@ async def mock_call_next(scope: Scope, receive: Receive, send: Send) -> None:
async def mock_send(message):
pass

import asyncio

try:
await asyncio.wait_for(
middleware(request.scope, request.receive, mock_send), timeout=5.0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
@pytest.fixture
def scope():
def _scope(token):
return {
result = {
"type": "http",
"headers": [
(b"host", b"example.com"),
Expand All @@ -33,12 +33,15 @@ def _scope(token):
(b"user-agent", b"test-agent"),
(b"x-goog-iap-jwt-assertion", b"test-assertion"),
(b"x-cloud-trace-context", b"trace-id/span-id;o=1"),
(b"authorization", f"Bearer {token}".encode("utf-8")),
],
"client": ("192.168.1.1", 12345),
"server": ("127.0.0.1", 8000),
"scheme": "https",
}
if token is not None:
if token is not None:
result["headers"].append((b"authorization", f"Bearer {token}".encode()))
return result

return _scope

Expand Down Expand Up @@ -70,9 +73,69 @@ def test_request_context_initialization(
assert request_context.trace_id == "trace-id"
assert request_context.span_id == "span-id"
assert request_context.trace_sampled is True
assert (
request_context.user_email == "[email protected]"
) # Mock the email extraction
assert request_context.user_email == "[email protected]"


@patch("middleware.request_context_oauth2.get_tokeninfo_response")
def test_request_context_invalid_audience(
mock_get_tokeninfo_response, scope, monkeypatch
):
monkeypatch.setenv("GOOGLE_CLIENT_ID", "test-client-id_audience")
monkeypatch.setenv("LOCAL_ENV", "true")

mock_get_tokeninfo_response.return_value.status_code = 200
mock_get_tokeninfo_response.return_value.json.return_value = {
"email": "[email protected]",
"audience": "not-test-client-id",
"email_verified": True,
"expires_in": 3600,
}

mocked_scope = scope("test_request_context_invalid_audience")

with pytest.raises(HTTPException) as exc_info:
RequestContext(mocked_scope)
assert exc_info.value.status_code == 401
assert exc_info.value.detail == "Invalid token audience"


@patch("middleware.request_context_oauth2.get_tokeninfo_response")
def test_request_context_email_not_found(
mock_get_tokeninfo_response, scope, monkeypatch
):
monkeypatch.setenv("GOOGLE_CLIENT_ID", "test-client-id")
monkeypatch.setenv("LOCAL_ENV", "true")

mock_get_tokeninfo_response.return_value.status_code = 200
mock_get_tokeninfo_response.return_value.json.return_value = {
"audience": "test-client-id",
"email_verified": True,
"expires_in": 3600,
}

mocked_scope = scope("test_request_context_email_not_found")

with pytest.raises(HTTPException) as exc_info:
RequestContext(mocked_scope)
assert exc_info.value.status_code == 400
assert exc_info.value.detail == "Email not found in token"


@patch("middleware.request_context_oauth2.get_tokeninfo_response")
def test_request_context_invalid_tokeninfo_exception(
mock_get_tokeninfo_response, scope, monkeypatch
):
monkeypatch.setenv("GOOGLE_CLIENT_ID", "test-client-id")
monkeypatch.setenv("LOCAL_ENV", "true")

mock_get_tokeninfo_response.side_effect = Exception("Test exception")

mocked_scope = scope("test_request_context_invalid_tokeninfo_exception")

with pytest.raises(HTTPException) as exc_info:
RequestContext(mocked_scope)
assert exc_info.value.status_code == 500
assert exc_info.value.detail == "Token validation failed"


def test_request_context_missing_authorization(scope, monkeypatch):
Expand Down Expand Up @@ -101,3 +164,16 @@ def test_request_context_invalid_token(mock_get_tokeninfo_response, scope, monke
RequestContext(scope("test_token_test_request_context_invalid_token"))
assert exc_info.value.status_code == 401
assert exc_info.value.detail == "Invalid access token"


@patch("middleware.request_context_oauth2.get_tokeninfo_response")
def test_request_context_no_token(mock_get_tokeninfo_response, scope, monkeypatch):
monkeypatch.setenv("GOOGLE_CLIENT_ID", "test-client-id")
monkeypatch.setenv("LOCAL_ENV", "False")

mock_get_tokeninfo_response.return_value.status_code = 400

with pytest.raises(HTTPException) as exc_info:
RequestContext(scope(token=None))
assert exc_info.value.status_code == 401
assert exc_info.value.detail == "Authorization header not found"

0 comments on commit 78dfb30

Please sign in to comment.