Skip to content

Commit

Permalink
Merge pull request #193 from statisticsnorway/auth-gcs-file-system
Browse files Browse the repository at this point in the history
Use ADC for GCSFileSystem when applicable
  • Loading branch information
mallport authored Dec 6, 2024
2 parents c805280 + cd7aa10 commit dff00b0
Show file tree
Hide file tree
Showing 7 changed files with 23 additions and 91 deletions.
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "dapla-toolbelt"
version = "3.2.0"
version = "3.2.1"
description = "Dapla Toolbelt"
authors = ["Dapla Developers <[email protected]>"]
license = "MIT"
Expand Down Expand Up @@ -52,7 +52,7 @@ pytest = ">=7.1.2"
responses = ">=0.24.0"
types-requests = ">=2.28.11"
pyarrow-stubs = ">=10.0.1.7"
google-auth-stubs = ">=0.2.0" # Not maintained by Google, should change if Google releases their own stubs
google-auth-stubs = ">=0.2.0" # Not maintained by Google, should change if Google releases their own stubs
pandas-stubs = ">=2.0.0"
pytest-timeout = ">=2.3.1"
pytest-mock = ">=3.14.0"
Expand Down
10 changes: 5 additions & 5 deletions src/dapla/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,19 @@

__all__ = [
"AuthClient",
"details",
"show",
"CollectorClient",
"ConverterClient",
"Doctor",
"FileClient",
"repo_root_dir",
"GuardianClient",
"details",
"generate_api_token",
"get_secret_version",
"read_pandas",
"write_pandas",
"repo_root_dir",
"show",
"trigger_source_data_processing",
"get_secret_version",
"write_pandas",
]


Expand Down
2 changes: 1 addition & 1 deletion src/dapla/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ def fetch_personal_token() -> str:
@lru_cache(maxsize=1)
def fetch_email_from_credentials() -> Optional[str]:
"""Retrieves an e-mail based on current Google Credentials. Potentially makes a Google API call."""
if os.getenv("DAPLA_REGION") == str(DaplaRegion.DAPLA_LAB):
if os.getenv("DAPLA_REGION") == DaplaRegion.DAPLA_LAB.value:
return os.getenv("DAPLA_USER")

credentials = AuthClient.fetch_google_credentials()
Expand Down
2 changes: 1 addition & 1 deletion src/dapla/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def get_gcs_file_system(**kwargs: Any) -> GCSFileSystem:
See https://gcsfs.readthedocs.io/en/latest for advanced usage
"""
return GCSFileSystem(token=AuthClient.fetch_google_credentials(), **kwargs)
return GCSFileSystem(**kwargs)

@staticmethod
def ls(gcs_path: str, detail: bool = False, **kwargs: Any) -> Any:
Expand Down
20 changes: 14 additions & 6 deletions src/dapla/gcs.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,27 @@
import os
import typing as t
from typing import Any
from typing import Optional

import gcsfs
from google.oauth2.credentials import Credentials

from dapla.auth import AuthClient
from dapla.const import DaplaRegion


class GCSFileSystem(gcsfs.GCSFileSystem): # type: ignore [misc]
"""GCSFileSystem is a wrapper around gcsfs.GCSFileSystem."""

def __init__(
self, token: Optional[dict[str, str] | str | Credentials] = None, **kwargs: Any
) -> None:
def __init__(self, **kwargs: Any) -> None:
"""Initialize GCSFileSystem."""
super().__init__(token=token, **kwargs)
if (
os.getenv("DAPLA_REGION") == DaplaRegion.DAPLA_LAB.value
or os.getenv("DAPLA_REGION") == DaplaRegion.CLOUD_RUN.value
):
# In environments that provide Application Default Credentials (ADC),
# initialize without an explicit token so auth comes from the environment
super().__init__(**kwargs)
else:
super().__init__(token=AuthClient.fetch_google_credentials(), **kwargs)

def isdir(self, path: str) -> bool:
"""Check if path is a directory."""
Expand Down
34 changes: 0 additions & 34 deletions tests/test_backports.py

This file was deleted.

42 changes: 0 additions & 42 deletions tests/test_gcs.py
Original file line number Diff line number Diff line change
@@ -1,49 +1,7 @@
from datetime import timedelta
from unittest.mock import Mock
from unittest.mock import patch

import pytest
from gcsfs.retry import HttpError
from google.auth._helpers import utcnow

from dapla import pandas as dp
from dapla.gcs import GCSFileSystem


def test_instance() -> None:
# Check that instantiation works with the current version of pyarrow
client = GCSFileSystem()
assert client is not None


@pytest.mark.timeout(
30
) # Times the test out after 30 sec; this will happen if a deadlock occurs
@patch.dict("dapla.auth.os.environ", {"OIDC_TOKEN": "fake-token"}, clear=True)
@patch.dict(
"dapla.auth.os.environ", {"DAPLA_TOOLBELT_FORCE_TOKEN_EXCHANGE": "1"}, clear=True
)
@patch("dapla.auth.AuthClient.fetch_google_token")
def test_gcs_deadlock(mock_fetch_google_token: Mock) -> None:
# When overriding the refresh method we experienced a deadlock, resulting in the credentials never being refreshed
# This test checks that the credentials object is updated on refresh
# and that it proceeds to the next step when a valid token is provided.

mock_fetch_google_token.side_effect = [
("FakeToken1", utcnow()),
("FakeToken2", utcnow()),
("FakeToken3", utcnow()),
("FakeToken4", utcnow()),
("FakeToken5Valid", utcnow() + timedelta(seconds=30)),
]

gcs_path = "gs://ssb-dapla-pseudo-data-produkt-test/integration_tests_data/personer.parquet"
with pytest.raises(
HttpError
) as exc_info: # Since we supply invalid credentials an error should be raised
dp.read_pandas(gcs_path)
assert "Invalid Credentials" in str(exc_info.value)
assert (
mock_fetch_google_token.call_count == 5
) # mock_fetch_google_token is called as part of refresh
# until a token that has not expired is returned

0 comments on commit dff00b0

Please sign in to comment.