From fef4bcd240e50847f87a93c746c23d0efc19e31c Mon Sep 17 00:00:00 2001 From: David Gamez Diaz <1192523+davidgamez@users.noreply.github.com> Date: Wed, 4 Dec 2024 14:03:43 -0500 Subject: [PATCH] fix request context email verification --- api/src/middleware/request_context.py | 11 +++-- .../middleware/test_request_context.py | 44 +------------------ 2 files changed, 8 insertions(+), 47 deletions(-) diff --git a/api/src/middleware/request_context.py b/api/src/middleware/request_context.py index 4a02ea6d8..842120785 100644 --- a/api/src/middleware/request_context.py +++ b/api/src/middleware/request_context.py @@ -94,7 +94,10 @@ def _extract_from_headers(self, headers: dict, scope: Scope) -> None: def __repr__(self) -> str: # Omitting sensitive data like email and jwt assertion safe_properties = dict( - user_id=self.user_id, client_user_agent=self.client_user_agent, client_host=self.client_host + user_id=self.user_id, + client_user_agent=self.client_user_agent, + client_host=self.client_host, + email=self.user_email, ) return f"request-context={safe_properties})" @@ -108,8 +111,8 @@ def is_user_email_restricted() -> bool: Check if an email's domain is restricted (e.g., for WIP visibility). """ request_context = get_request_context() - if not isinstance(request_context, RequestContext): - return True # Default to restricted - email = get_request_context().user_email + if not request_context: + return True + email = request_context["user_email"] unrestricted_domains = ["mobilitydata.org"] return not email or not any(email.endswith(f"@{domain}") for domain in unrestricted_domains) diff --git a/api/tests/unittest/middleware/test_request_context.py b/api/tests/unittest/middleware/test_request_context.py index 9c405edd2..3cb32057d 100644 --- a/api/tests/unittest/middleware/test_request_context.py +++ b/api/tests/unittest/middleware/test_request_context.py @@ -3,7 +3,7 @@ from starlette.datastructures import Headers -from middleware.request_context import RequestContext, get_request_context, _request_context, is_user_email_restricted +from middleware.request_context import RequestContext, get_request_context, _request_context class TestRequestContext(unittest.TestCase): @@ -54,45 +54,3 @@ def test_get_request_context(self): request_context = RequestContext(MagicMock()) _request_context.set(request_context) self.assertEqual(request_context, get_request_context()) - - def test_is_user_email_restricted(self): - self.assertTrue(is_user_email_restricted()) - scope_instance = { - "type": "http", - "asgi": {"version": "3.0"}, - "http_version": "1.1", - "method": "GET", - "headers": [ - (b"host", b"localhost"), - (b"x-forwarded-proto", b"https"), - (b"x-forwarded-for", b"client, proxy1"), - (b"server", b"server"), - (b"user-agent", b"user-agent"), - (b"x-goog-iap-jwt-assertion", b"jwt"), - (b"x-cloud-trace-context", b"TRACE_ID/SPAN_ID;o=1"), - (b"x-goog-authenticated-user-id", b"user_id"), - (b"x-goog-authenticated-user-email", b"email"), - ], - "path": "/", - "raw_path": b"/", - "query_string": b"", - "client": ("127.0.0.1", 32767), - "server": ("127.0.0.1", 80), - } - request_context = RequestContext(scope=scope_instance) - _request_context.set(request_context) - self.assertTrue(is_user_email_restricted()) - scope_instance["headers"] = [ - (b"host", b"localhost"), - (b"x-forwarded-proto", b"https"), - (b"x-forwarded-for", b"client, proxy1"), - (b"server", b"server"), - (b"user-agent", b"user-agent"), - (b"x-goog-iap-jwt-assertion", b"jwt"), - (b"x-cloud-trace-context", b"TRACE_ID/SPAN_ID;o=1"), - (b"x-goog-authenticated-user-id", b"user_id"), - (b"x-goog-authenticated-user-email", b"test@mobilitydata.org"), - ] - request_context = RequestContext(scope=scope_instance) - _request_context.set(request_context) - self.assertFalse(is_user_email_restricted())