Skip to content

Commit

Permalink
Added support for "OIDC_TOKEN" for Dapla Lab. Fixed tests to support …
Browse files Browse the repository at this point in the history
…both auth via jupyterhub and Dapla Lab
  • Loading branch information
RupinderKaurSSB committed Oct 8, 2024
1 parent 7d49e37 commit 7ffc94d
Show file tree
Hide file tree
Showing 5 changed files with 261 additions and 37 deletions.
66 changes: 59 additions & 7 deletions src/dapla/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,12 @@ def _get_current_dapla_metadata() -> (

return env, service, region

@staticmethod
def get_dapla_region() -> Optional[DaplaRegion]:
"""Checks if the current Dapla Region is Dapla Lab."""
env, service, region = AuthClient._get_current_dapla_metadata()
return region

@staticmethod
def _refresh_handler(
request: google.auth.transport.Request,
Expand Down Expand Up @@ -302,13 +308,40 @@ def fetch_google_credentials(force_token_exchange: bool = False) -> Credentials:

@staticmethod
def fetch_personal_token() -> str:
"""Fetches the personal access token for the current user."""
try:
personal_token = AuthClient.fetch_local_user_from_jupyter()["access_token"]
return t.cast(str, personal_token)
except AuthError as err:
err._print_warning()
raise err
"""If Dapla Region is Dapla Lab, retrieve the OIDC token/Keycloak token from the environment.
Returns:
str: The OIDC token.
Raises:
MissingConfigurationException: If the OIDC_TOKEN environment variable is missing or is not set.
If Dapla Region is BIP, retrieve the Keycloak token jupyterhub.
Returns:
str: personal/keycloak token.
Raises:
AuthError: Handles AuthError.
"""
env, service, region = AuthClient._get_current_dapla_metadata()
if region == DaplaRegion.DAPLA_LAB:
logger.debug("Auth - Dapla Lab detected, using OIDC_TOKEN")
keycloak_token = os.getenv("OIDC_TOKEN")
if not keycloak_token:
raise MissingConfigurationException("OIDC_TOKEN")
else:
return keycloak_token
else:
logger.debug("Auth - BIP detected, using jupyterhub personal token")
try:
personal_token = AuthClient.fetch_local_user_from_jupyter()[
"access_token"
]
return t.cast(str, personal_token)
except AuthError as err:
err._print_warning()
raise err

@staticmethod
@lru_cache(maxsize=1)
Expand Down Expand Up @@ -336,3 +369,22 @@ def _print_warning(self) -> None:
)
)
)


class MissingConfigurationException(Exception):
"""Exception raised when a required environment variable or configuration is missing."""

def __init__(self, variable_name: str) -> None:
"""Initializes a new instance of the MissingConfigurationException class.
Args:
variable_name (str): The name of the missing environment variable or configuration.
message (str): The error message to be displayed. Defaults to an empty string.
"""
self.variable_name = variable_name
self.message = f"Missing required environment variable: {variable_name}"
super().__init__(self.message)

def __str__(self) -> str:
"""Returns a string representation of the exception."""
return f"Configuration error: {self.message}"
26 changes: 18 additions & 8 deletions src/dapla/doctor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from gcsfs.retry import HttpError

from dapla.auth import AuthClient
from dapla.auth import DaplaRegion

logger = logging.getLogger(__name__)

Expand All @@ -21,14 +22,23 @@ class Doctor:

@staticmethod
def jupyterhub_auth_valid() -> bool:
"""Checks wheter user is logged in and authenticated to Jupyterhub."""
print("Checking authentication to JupyterHub...")
try:
# Attempt fetching the Jupyterhub user
AuthClient.fetch_local_user_from_jupyter()
except Exception:
return False
return True
"""Checks whether user is logged in and authenticated to Jupyterhub or Dapla Lab."""
print("Checking dapla region")
if AuthClient.get_dapla_region() == DaplaRegion.DAPLA_LAB:
print("Checking authentication to Dapla Lab...")
try:
AuthClient.fetch_personal_token()
except Exception:
return False
return True
else:
print("Checking authentication to JupyterHub...")
try:
# Attempt fetching the Jupyterhub user
AuthClient.fetch_local_user_from_jupyter()
except Exception:
return False
return True

@staticmethod
def keycloak_token_valid() -> bool:
Expand Down
45 changes: 42 additions & 3 deletions tests/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,39 @@
import dapla
from dapla.auth import AuthClient
from dapla.auth import AuthError
from dapla.auth import MissingConfigurationException

auth_endpoint_url = "https://mock-auth.no/user"


@mock.patch.dict(
"dapla.auth.os.environ", {"LOCAL_USER_PATH": auth_endpoint_url}, clear=True
"dapla.auth.os.environ",
{
"DAPLA_SERVICE": "JUPYTERLAB",
"DAPLA_REGION": "DAPLA_LAB",
"OIDC_TOKEN": "dummy_token",
},
clear=True,
)
@responses.activate
def test_fetch_personal_token() -> None:
def test_fetch_personal_token_for_dapla_lab() -> None:
client = AuthClient()
token = client.fetch_personal_token()

assert token == "dummy_token"


@mock.patch.dict(
"dapla.auth.os.environ",
{
"DAPLA_SERVICE": "JUPYTERLAB",
"DAPLA_REGION": "BIP",
"LOCAL_USER_PATH": auth_endpoint_url,
},
clear=True,
)
@responses.activate
def test_fetch_personal_token_for_jupyterhub() -> None:
mock_response = {
"access_token": "fake_access_token",
}
Expand All @@ -37,7 +61,7 @@ def test_fetch_personal_token() -> None:
)
@mock.patch("dapla.auth.display")
@responses.activate
def test_fetch_personal_token_error(mock_display: Mock) -> None:
def test_fetch_personal_token_error_on_jupyterhub(mock_display: Mock) -> None:
mock_response = {
"message": "There was an error",
}
Expand All @@ -49,6 +73,21 @@ def test_fetch_personal_token_error(mock_display: Mock) -> None:
mock_display.assert_called_once()


@mock.patch.dict(
"dapla.auth.os.environ",
{"DAPLA_SERVICE": "JUPYTERLAB", "DAPLA_REGION": "DAPLA_LAB", "OIDC_TOKEN": ""},
clear=True,
)
@responses.activate
def test_fetch_personal_token_error_on_dapla_lab() -> None:
with pytest.raises(MissingConfigurationException) as exception:
AuthClient().fetch_personal_token()
assert (
str(exception.value)
== "Configuration error: Missing required environment variable: OIDC_TOKEN"
)


@mock.patch.dict(
"dapla.auth.os.environ",
{"OIDC_TOKEN_EXCHANGE_URL": auth_endpoint_url, "OIDC_TOKEN": "dummy_token"},
Expand Down
133 changes: 114 additions & 19 deletions tests/test_converter.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import json
from unittest import mock
from unittest.mock import Mock

import responses

from dapla.converter import ConverterClient

auth_endpoint_url = "https://mock-auth.no/user"

converter_test_url = "https://mock-converter.no"
fake_token = "1234567890"

Expand Down Expand Up @@ -109,10 +110,17 @@
"""


@mock.patch("dapla.auth.AuthClient")
@mock.patch.dict(
"dapla.auth.os.environ",
{
"DAPLA_SERVICE": "JUPYTERLAB",
"DAPLA_REGION": "DAPLA_LAB",
"OIDC_TOKEN": "dummy_token",
},
clear=True,
)
@responses.activate
def test_converter_start_200_response(auth_client_mock: Mock) -> None:
auth_client_mock.fetch_personal_token.return_value = fake_token
def test_converter_start_200_response_on_dapla_lab() -> None:
job_config: dict[str, str] = {}
responses.add(
responses.POST,
Expand All @@ -127,10 +135,53 @@ def test_converter_start_200_response(auth_client_mock: Mock) -> None:
assert json_str["jobId"] == json.loads(sample_response_start_job)["jobId"]


@mock.patch("dapla.auth.AuthClient")
@mock.patch.dict(
"dapla.auth.os.environ",
{
"DAPLA_SERVICE": "JUPYTERLAB",
"DAPLA_REGION": "BIP",
"LOCAL_USER_PATH": auth_endpoint_url,
},
clear=True,
)
@responses.activate
def test_converter_start_simulation_200_response(auth_client_mock: Mock) -> None:
auth_client_mock.fetch_personal_token.return_value = fake_token
def test_converter_start_200_response_on_jupyterhub() -> None:
mock_response = {
"access_token": "fake_access_token",
}
responses.add(responses.GET, auth_endpoint_url, json=mock_response, status=200)

job_config: dict[str, str] = {}

responses.add(
responses.POST,
"https://mock-converter.no/jobs",
json=sample_response_start_job,
status=200,
)
client = ConverterClient(converter_test_url)
response = client.start(job_config)
json_str = json.loads(response.json())

assert json_str["jobId"] == json.loads(sample_response_start_job)["jobId"]


@mock.patch.dict(
"dapla.auth.os.environ",
{
"DAPLA_SERVICE": "JUPYTERLAB",
"DAPLA_REGION": "BIP",
"LOCAL_USER_PATH": auth_endpoint_url,
},
clear=True,
)
@responses.activate
def test_converter_start_simulation_200_response() -> None:
mock_response = {
"access_token": "fake_access_token",
}
responses.add(responses.GET, auth_endpoint_url, json=mock_response, status=200)

job_config: dict[str, str] = {}
responses.add(
responses.POST,
Expand All @@ -145,10 +196,21 @@ def test_converter_start_simulation_200_response(auth_client_mock: Mock) -> None
assert json_str["jobId"] == json.loads(sample_response_start_job)["jobId"]


@mock.patch("dapla.auth.AuthClient")
@mock.patch.dict(
"dapla.auth.os.environ",
{
"DAPLA_SERVICE": "JUPYTERLAB",
"DAPLA_REGION": "BIP",
"LOCAL_USER_PATH": auth_endpoint_url,
},
clear=True,
)
@responses.activate
def test_converter_get_job_summary_200_response(auth_client_mock: Mock) -> None:
auth_client_mock.fetch_personal_token.return_value = fake_token
def test_converter_get_job_summary_200_response() -> None:
mock_response = {
"access_token": "fake_access_token",
}
responses.add(responses.GET, auth_endpoint_url, json=mock_response, status=200)
responses.add(
responses.GET,
"https://mock-converter.no/jobs/01FZWP8R3PHDYD5QQS4CY1RKBW/execution-summary",
Expand All @@ -161,10 +223,21 @@ def test_converter_get_job_summary_200_response(auth_client_mock: Mock) -> None:
assert json.loads(response.json()) == json.loads(sample_response_get_job_summary)


@mock.patch("dapla.auth.AuthClient")
@mock.patch.dict(
"dapla.auth.os.environ",
{
"DAPLA_SERVICE": "JUPYTERLAB",
"DAPLA_REGION": "BIP",
"LOCAL_USER_PATH": auth_endpoint_url,
},
clear=True,
)
@responses.activate
def test_converter_stop_job_200_response(auth_client_mock: Mock) -> None:
auth_client_mock.fetch_personal_token.return_value = fake_token
def test_converter_stop_job_200_response() -> None:
mock_response = {
"access_token": "fake_access_token",
}
responses.add(responses.GET, auth_endpoint_url, json=mock_response, status=200)
responses.add(
responses.POST,
"https://mock-converter.no/jobs/01FZWP8R3PHDYD5QQS4CY1RKBW/stop",
Expand All @@ -177,10 +250,21 @@ def test_converter_stop_job_200_response(auth_client_mock: Mock) -> None:
assert response.status_code == 200


@mock.patch("dapla.auth.AuthClient")
@mock.patch.dict(
"dapla.auth.os.environ",
{
"DAPLA_SERVICE": "JUPYTERLAB",
"DAPLA_REGION": "BIP",
"LOCAL_USER_PATH": auth_endpoint_url,
},
clear=True,
)
@responses.activate
def test_converter_get_pseudo_report_200_response(auth_client_mock: Mock) -> None:
auth_client_mock.fetch_personal_token.return_value = fake_token
def test_converter_get_pseudo_report_200_response() -> None:
mock_response = {
"access_token": "fake_access_token",
}
responses.add(responses.GET, auth_endpoint_url, json=mock_response, status=200)
responses.add(
responses.GET,
"https://mock-converter.no/jobs/01FZWP8R3PHDYD5QQS4CY1RKBW/reports/pseudo",
Expand All @@ -193,10 +277,21 @@ def test_converter_get_pseudo_report_200_response(auth_client_mock: Mock) -> Non
assert json.loads(response.json()) == json.loads(sample_response_pseudo_report)


@mock.patch("dapla.auth.AuthClient")
@mock.patch.dict(
"dapla.auth.os.environ",
{
"DAPLA_SERVICE": "JUPYTERLAB",
"DAPLA_REGION": "BIP",
"LOCAL_USER_PATH": auth_endpoint_url,
},
clear=True,
)
@responses.activate
def test_converter_get_pseudo_schema_200_response(auth_client_mock: Mock) -> None:
auth_client_mock.fetch_personal_token.return_value = fake_token
def test_converter_get_pseudo_schema_200_response() -> None:
mock_response = {
"access_token": "fake_access_token",
}
responses.add(responses.GET, auth_endpoint_url, json=mock_response, status=200)
responses.add(
responses.GET,
"https://mock-converter.no/jobs/01FZWP8R3PHDYD5QQS4CY1RKBW/reports/pseudo-schema-hierarchy",
Expand Down
Loading

0 comments on commit 7ffc94d

Please sign in to comment.