diff --git a/functions-python/operations_api/.coveragerc b/functions-python/operations_api/.coveragerc new file mode 100644 index 000000000..b664793c1 --- /dev/null +++ b/functions-python/operations_api/.coveragerc @@ -0,0 +1,11 @@ +[run] +omit = + */test*/* + */helpers/* + */database_gen/* + */dataset_service/* + */feeds_operations_gen/* + +[report] +exclude_lines = + if __name__ == .__main__.: \ No newline at end of file diff --git a/functions-python/operations_api/src/middleware/request_context_middleware.py b/functions-python/operations_api/src/middleware/request_context_middleware.py index d0d9dfc40..dc4d676e2 100644 --- a/functions-python/operations_api/src/middleware/request_context_middleware.py +++ b/functions-python/operations_api/src/middleware/request_context_middleware.py @@ -32,20 +32,6 @@ def __init__(self, app: ASGIApp) -> None: self.logger = logging.getLogger() self.app = app - @staticmethod - def extract_response_info(headers): - """ - Extracts the content type and content length from the response headers. - """ - content_type = None - content_length = None - for key, value in headers: - if key == b"content-length": - content_length = int(value) - elif key == b"content-type": - content_type = value - return content_type, content_length - async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: """ Middleware to set the request context and authorize requests. diff --git a/functions-python/operations_api/src/middleware/request_context_oauth2.py b/functions-python/operations_api/src/middleware/request_context_oauth2.py index c889c6858..7f7d2100d 100644 --- a/functions-python/operations_api/src/middleware/request_context_oauth2.py +++ b/functions-python/operations_api/src/middleware/request_context_oauth2.py @@ -43,19 +43,19 @@ def validate_token_with_google(token: str, google_client_id: str) -> dict: """ try: response = get_tokeninfo_response(token) - if response.status_code != 200: - raise HTTPException(status_code=401, detail="Invalid access token") + except Exception as e: + logging.error(f"Token validation failed: {e}") + raise HTTPException(status_code=500, detail="Token validation failed") - token_info = response.json() + if response.status_code != 200: + raise HTTPException(status_code=401, detail="Invalid access token") - # Ensure the token is for the expected client - if token_info.get("audience") != google_client_id: - raise HTTPException(status_code=401, detail="Invalid token audience") + token_info = response.json() + # Ensure the token is for the expected client + if token_info.get("audience") != google_client_id: + raise HTTPException(status_code=401, detail="Invalid token audience") - return token_info - except requests.exceptions.RequestException as e: - logging.error(f"Token validation failed: {e}") - raise HTTPException(status_code=500, detail="Token validation failed") + return token_info def get_tokeninfo_response(token): @@ -113,12 +113,10 @@ def extract_authorization_oauth(headers: dict, google_client_id: str) -> str: ) token = auth_header.split(" ")[1] - logging.info(f"Token: {token}") token_info = get_token_info(token, google_client_id) email = token_info.get("email") - logging.info(f"Email: {email}") if not email: raise HTTPException(status_code=400, detail="Email not found in token") @@ -197,7 +195,7 @@ def _extract_from_headers(self, headers, scope: Scope) -> None: # auth header is used for local development self.user_email = headers.get("x-goog-authenticated-user-email") - if headers.get("authorization"): + if headers.get("authorization") is not None: google_client_id = os.getenv("GOOGLE_CLIENT_ID") self.user_email = extract_authorization_oauth(headers, google_client_id) else: diff --git a/functions-python/operations_api/tests/middleware/test_request_context_middleware.py b/functions-python/operations_api/tests/middleware/test_request_context_middleware.py index 94cd6eb2d..bf7920ee5 100644 --- a/functions-python/operations_api/tests/middleware/test_request_context_middleware.py +++ b/functions-python/operations_api/tests/middleware/test_request_context_middleware.py @@ -19,6 +19,7 @@ from starlette.requests import Request from starlette.responses import Response from starlette.types import Receive, Scope, Send +import asyncio from middleware.request_context_middleware import ( RequestContextMiddleware, @@ -66,8 +67,6 @@ async def mock_call_next(scope: Scope, receive: Receive, send: Send) -> None: async def mock_send(message): pass - import asyncio - try: await asyncio.wait_for( middleware(request.scope, request.receive, mock_send), timeout=5.0 diff --git a/functions-python/operations_api/tests/middleware/test_request_context_oauth2.py b/functions-python/operations_api/tests/middleware/test_request_context_oauth2.py index c9ed4fb12..53bee9617 100644 --- a/functions-python/operations_api/tests/middleware/test_request_context_oauth2.py +++ b/functions-python/operations_api/tests/middleware/test_request_context_oauth2.py @@ -24,7 +24,7 @@ @pytest.fixture def scope(): def _scope(token): - return { + result = { "type": "http", "headers": [ (b"host", b"example.com"), @@ -33,12 +33,15 @@ def _scope(token): (b"user-agent", b"test-agent"), (b"x-goog-iap-jwt-assertion", b"test-assertion"), (b"x-cloud-trace-context", b"trace-id/span-id;o=1"), - (b"authorization", f"Bearer {token}".encode("utf-8")), ], "client": ("192.168.1.1", 12345), "server": ("127.0.0.1", 8000), "scheme": "https", } + if token is not None: + if token is not None: + result["headers"].append((b"authorization", f"Bearer {token}".encode())) + return result return _scope @@ -70,9 +73,69 @@ def test_request_context_initialization( assert request_context.trace_id == "trace-id" assert request_context.span_id == "span-id" assert request_context.trace_sampled is True - assert ( - request_context.user_email == "test-email@example.com" - ) # Mock the email extraction + assert request_context.user_email == "test-email@example.com" + + +@patch("middleware.request_context_oauth2.get_tokeninfo_response") +def test_request_context_invalid_audience( + mock_get_tokeninfo_response, scope, monkeypatch +): + monkeypatch.setenv("GOOGLE_CLIENT_ID", "test-client-id_audience") + monkeypatch.setenv("LOCAL_ENV", "true") + + mock_get_tokeninfo_response.return_value.status_code = 200 + mock_get_tokeninfo_response.return_value.json.return_value = { + "email": "test-email@example.com", + "audience": "not-test-client-id", + "email_verified": True, + "expires_in": 3600, + } + + mocked_scope = scope("test_request_context_invalid_audience") + + with pytest.raises(HTTPException) as exc_info: + RequestContext(mocked_scope) + assert exc_info.value.status_code == 401 + assert exc_info.value.detail == "Invalid token audience" + + +@patch("middleware.request_context_oauth2.get_tokeninfo_response") +def test_request_context_email_not_found( + mock_get_tokeninfo_response, scope, monkeypatch +): + monkeypatch.setenv("GOOGLE_CLIENT_ID", "test-client-id") + monkeypatch.setenv("LOCAL_ENV", "true") + + mock_get_tokeninfo_response.return_value.status_code = 200 + mock_get_tokeninfo_response.return_value.json.return_value = { + "audience": "test-client-id", + "email_verified": True, + "expires_in": 3600, + } + + mocked_scope = scope("test_request_context_email_not_found") + + with pytest.raises(HTTPException) as exc_info: + RequestContext(mocked_scope) + assert exc_info.value.status_code == 400 + assert exc_info.value.detail == "Email not found in token" + + +@patch("middleware.request_context_oauth2.get_tokeninfo_response") +def test_request_context_invalid_tokeninfo_exception( + mock_get_tokeninfo_response, scope, monkeypatch +): + monkeypatch.setenv("GOOGLE_CLIENT_ID", "test-client-id") + monkeypatch.setenv("LOCAL_ENV", "true") + + mock_get_tokeninfo_response.side_effect = Exception("Test exception") + + mocked_scope = scope("test_request_context_invalid_tokeninfo_exception") + + with pytest.raises(HTTPException) as exc_info: + RequestContext(mocked_scope) + assert exc_info.value.status_code == 500 + assert exc_info.value.detail == "Token validation failed" def test_request_context_missing_authorization(scope, monkeypatch): @@ -101,3 +164,16 @@ def test_request_context_invalid_token(mock_get_tokeninfo_response, scope, monke RequestContext(scope("test_token_test_request_context_invalid_token")) assert exc_info.value.status_code == 401 assert exc_info.value.detail == "Invalid access token" + + +@patch("middleware.request_context_oauth2.get_tokeninfo_response") +def test_request_context_no_token(mock_get_tokeninfo_response, scope, monkeypatch): + monkeypatch.setenv("GOOGLE_CLIENT_ID", "test-client-id") + monkeypatch.setenv("LOCAL_ENV", "False") + + mock_get_tokeninfo_response.return_value.status_code = 400 + + with pytest.raises(HTTPException) as exc_info: + RequestContext(scope(token=None)) + assert exc_info.value.status_code == 401 + assert exc_info.value.detail == "Authorization header not found"