Skip to content

Commit

Permalink
fix request context email verification
Browse files Browse the repository at this point in the history
  • Loading branch information
davidgamez committed Dec 4, 2024
1 parent f308820 commit fef4bcd
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 47 deletions.
11 changes: 7 additions & 4 deletions api/src/middleware/request_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,10 @@ def _extract_from_headers(self, headers: dict, scope: Scope) -> None:
def __repr__(self) -> str:
# Omitting sensitive data like email and jwt assertion
safe_properties = dict(
user_id=self.user_id, client_user_agent=self.client_user_agent, client_host=self.client_host
user_id=self.user_id,
client_user_agent=self.client_user_agent,
client_host=self.client_host,
email=self.user_email,
)
return f"request-context={safe_properties})"

Expand All @@ -108,8 +111,8 @@ def is_user_email_restricted() -> bool:
Check if an email's domain is restricted (e.g., for WIP visibility).
"""
request_context = get_request_context()
if not isinstance(request_context, RequestContext):
return True # Default to restricted
email = get_request_context().user_email
if not request_context:
return True
email = request_context["user_email"]
unrestricted_domains = ["mobilitydata.org"]
return not email or not any(email.endswith(f"@{domain}") for domain in unrestricted_domains)
44 changes: 1 addition & 43 deletions api/tests/unittest/middleware/test_request_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from starlette.datastructures import Headers

from middleware.request_context import RequestContext, get_request_context, _request_context, is_user_email_restricted
from middleware.request_context import RequestContext, get_request_context, _request_context


class TestRequestContext(unittest.TestCase):
Expand Down Expand Up @@ -54,45 +54,3 @@ def test_get_request_context(self):
request_context = RequestContext(MagicMock())
_request_context.set(request_context)
self.assertEqual(request_context, get_request_context())

def test_is_user_email_restricted(self):
self.assertTrue(is_user_email_restricted())
scope_instance = {
"type": "http",
"asgi": {"version": "3.0"},
"http_version": "1.1",
"method": "GET",
"headers": [
(b"host", b"localhost"),
(b"x-forwarded-proto", b"https"),
(b"x-forwarded-for", b"client, proxy1"),
(b"server", b"server"),
(b"user-agent", b"user-agent"),
(b"x-goog-iap-jwt-assertion", b"jwt"),
(b"x-cloud-trace-context", b"TRACE_ID/SPAN_ID;o=1"),
(b"x-goog-authenticated-user-id", b"user_id"),
(b"x-goog-authenticated-user-email", b"email"),
],
"path": "/",
"raw_path": b"/",
"query_string": b"",
"client": ("127.0.0.1", 32767),
"server": ("127.0.0.1", 80),
}
request_context = RequestContext(scope=scope_instance)
_request_context.set(request_context)
self.assertTrue(is_user_email_restricted())
scope_instance["headers"] = [
(b"host", b"localhost"),
(b"x-forwarded-proto", b"https"),
(b"x-forwarded-for", b"client, proxy1"),
(b"server", b"server"),
(b"user-agent", b"user-agent"),
(b"x-goog-iap-jwt-assertion", b"jwt"),
(b"x-cloud-trace-context", b"TRACE_ID/SPAN_ID;o=1"),
(b"x-goog-authenticated-user-id", b"user_id"),
(b"x-goog-authenticated-user-email", b"[email protected]"),
]
request_context = RequestContext(scope=scope_instance)
_request_context.set(request_context)
self.assertFalse(is_user_email_restricted())

0 comments on commit fef4bcd

Please sign in to comment.