diff --git a/pyproject.toml b/pyproject.toml index 8e497c3..1f141ff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "dapla-toolbelt" -version = "3.2.0" +version = "3.2.1" description = "Dapla Toolbelt" authors = ["Dapla Developers "] license = "MIT" @@ -52,7 +52,7 @@ pytest = ">=7.1.2" responses = ">=0.24.0" types-requests = ">=2.28.11" pyarrow-stubs = ">=10.0.1.7" -google-auth-stubs = ">=0.2.0" # Not maintained by Google, should change if Google releases their own stubs +google-auth-stubs = ">=0.2.0" # Not maintained by Google, should change if Google releases their own stubs pandas-stubs = ">=2.0.0" pytest-timeout = ">=2.3.1" pytest-mock = ">=3.14.0" diff --git a/src/dapla/__init__.py b/src/dapla/__init__.py index 87a31a7..fe1d722 100644 --- a/src/dapla/__init__.py +++ b/src/dapla/__init__.py @@ -22,19 +22,19 @@ __all__ = [ "AuthClient", - "details", - "show", "CollectorClient", "ConverterClient", "Doctor", "FileClient", - "repo_root_dir", "GuardianClient", + "details", "generate_api_token", + "get_secret_version", "read_pandas", - "write_pandas", + "repo_root_dir", + "show", "trigger_source_data_processing", - "get_secret_version", + "write_pandas", ] diff --git a/src/dapla/auth.py b/src/dapla/auth.py index 3534cdb..2b5f4e8 100644 --- a/src/dapla/auth.py +++ b/src/dapla/auth.py @@ -323,7 +323,7 @@ def fetch_personal_token() -> str: @lru_cache(maxsize=1) def fetch_email_from_credentials() -> Optional[str]: """Retrieves an e-mail based on current Google Credentials. Potentially makes a Google API call.""" - if os.getenv("DAPLA_REGION") == str(DaplaRegion.DAPLA_LAB): + if os.getenv("DAPLA_REGION") == DaplaRegion.DAPLA_LAB.value: return os.getenv("DAPLA_USER") credentials = AuthClient.fetch_google_credentials() diff --git a/src/dapla/files.py b/src/dapla/files.py index 12b0537..884f29a 100644 --- a/src/dapla/files.py +++ b/src/dapla/files.py @@ -49,7 +49,7 @@ def get_gcs_file_system(**kwargs: Any) -> GCSFileSystem: See https://gcsfs.readthedocs.io/en/latest for advanced usage """ - return GCSFileSystem(token=AuthClient.fetch_google_credentials(), **kwargs) + return GCSFileSystem(**kwargs) @staticmethod def ls(gcs_path: str, detail: bool = False, **kwargs: Any) -> Any: diff --git a/src/dapla/gcs.py b/src/dapla/gcs.py index 112903b..b24ef1b 100644 --- a/src/dapla/gcs.py +++ b/src/dapla/gcs.py @@ -1,19 +1,27 @@ +import os import typing as t from typing import Any -from typing import Optional import gcsfs -from google.oauth2.credentials import Credentials + +from dapla.auth import AuthClient +from dapla.const import DaplaRegion class GCSFileSystem(gcsfs.GCSFileSystem): # type: ignore [misc] """GCSFileSystem is a wrapper around gcsfs.GCSFileSystem.""" - def __init__( - self, token: Optional[dict[str, str] | str | Credentials] = None, **kwargs: Any - ) -> None: + def __init__(self, **kwargs: Any) -> None: """Initialize GCSFileSystem.""" - super().__init__(token=token, **kwargs) + if ( + os.getenv("DAPLA_REGION") == DaplaRegion.DAPLA_LAB.value + or os.getenv("DAPLA_REGION") == DaplaRegion.CLOUD_RUN.value + ): + # When using environments with ADC, return a GCSFS using auth + # from the environment + super().__init__(**kwargs) + else: + super().__init__(token=AuthClient.fetch_google_credentials(), **kwargs) def isdir(self, path: str) -> bool: """Check if path is a directory.""" diff --git a/tests/test_backports.py b/tests/test_backports.py deleted file mode 100644 index a698ff6..0000000 --- a/tests/test_backports.py +++ /dev/null @@ -1,34 +0,0 @@ -from unittest import mock -from unittest.mock import Mock - -from dapla.backports import show -from dapla.gcs import GCSFileSystem - - -@mock.patch("dapla.backports.FileClient") -def test_show_all_subfolders(file_client_mock: Mock) -> None: - file_client_mock.get_gcs_file_system.return_value = GCSFileSystem() - result = show("gs://anaconda-public-data/nyc-taxi/") - assert result == [ - "/nyc-taxi/2015.parquet", - "/nyc-taxi/csv", - "/nyc-taxi/csv/2014", - "/nyc-taxi/csv/2015", - "/nyc-taxi/csv/2016", - "/nyc-taxi/nyc.parquet", - "/nyc-taxi/taxi.parquet", - ] - - -@mock.patch("dapla.backports.FileClient") -def test_show_leaf_folder(file_client_mock: Mock) -> None: - file_client_mock.get_gcs_file_system.return_value = GCSFileSystem() - result = show("gs://anaconda-public-data/nyc-taxi/csv/2014") - assert result == ["/nyc-taxi/csv/2014"] - - -@mock.patch("dapla.backports.FileClient") -def test_show_invalid_folder(file_client_mock: Mock) -> None: - file_client_mock.get_gcs_file_system.return_value = GCSFileSystem() - result = show("gs://anaconda-public-data/nyc-taxi/unknown") - assert result == [] diff --git a/tests/test_gcs.py b/tests/test_gcs.py index c9cf5eb..b6a5573 100644 --- a/tests/test_gcs.py +++ b/tests/test_gcs.py @@ -1,12 +1,3 @@ -from datetime import timedelta -from unittest.mock import Mock -from unittest.mock import patch - -import pytest -from gcsfs.retry import HttpError -from google.auth._helpers import utcnow - -from dapla import pandas as dp from dapla.gcs import GCSFileSystem @@ -14,36 +5,3 @@ def test_instance() -> None: # Chack that instantiation works with the current version of pyarrow client = GCSFileSystem() assert client is not None - - -@pytest.mark.timeout( - 30 -) # Times the test out after 30 sec, this is will happen if a deadlock happens -@patch.dict("dapla.auth.os.environ", {"OIDC_TOKEN": "fake-token"}, clear=True) -@patch.dict( - "dapla.auth.os.environ", {"DAPLA_TOOLBELT_FORCE_TOKEN_EXCHANGE": "1"}, clear=True -) -@patch("dapla.auth.AuthClient.fetch_google_token") -def test_gcs_deadlock(mock_fetch_google_token: Mock) -> None: - # When overriding the refresh method we experienced a deadlock, resulting in the credentials never being refreshed - # This test checks that the credentials object is updated on refresh - # and that it proceeds to the next step when a valid token is provided. - - mock_fetch_google_token.side_effect = [ - ("FakeToken1", utcnow()), - ("FakeToken2", utcnow()), - ("FakeToken3", utcnow()), - ("FakeToken4", utcnow()), - ("FakeToken5Valid", utcnow() + timedelta(seconds=30)), - ] - - gcs_path = "gs://ssb-dapla-pseudo-data-produkt-test/integration_tests_data/personer.parquet" - with pytest.raises( - HttpError - ) as exc_info: # Since we supply invalid credentials an error should be raised - dp.read_pandas(gcs_path) - assert "Invalid Credentials" in str(exc_info.value) - assert ( - mock_fetch_google_token.call_count == 5 - ) # mock_fetch_google_token is called as part of refresh - # until a token that has not expired is returned