From 8e5fea60227fb9e5141109a13e78bcd16e2378d2 Mon Sep 17 00:00:00 2001 From: aldbr Date: Tue, 26 Nov 2024 18:16:20 +0100 Subject: [PATCH] fix: simplify tests and make them more deterministic --- diracx-client/tests/test_auth.py | 342 ++++++++++++++----------------- diracx-core/tests/test_utils.py | 70 +++---- 2 files changed, 185 insertions(+), 227 deletions(-) diff --git a/diracx-client/tests/test_auth.py b/diracx-client/tests/test_auth.py index 5d3e38b1..75cb05af 100644 --- a/diracx-client/tests/test_auth.py +++ b/diracx-client/tests/test_auth.py @@ -1,11 +1,8 @@ import fcntl import json -from datetime import datetime, time, timedelta, timezone -from multiprocessing import Pool -from pathlib import Path -from tempfile import NamedTemporaryFile -from unittest.mock import patch +from datetime import datetime, timedelta, timezone +import jwt import pytest from azure.core.credentials import AccessToken @@ -13,117 +10,72 @@ from diracx.core.models import TokenResponse from diracx.core.utils import serialize_credentials +# Create a fake jwt dictionary +REFRESH_CONTENT = { + "jti": "f0706e0a-af1e-4538-9f1f-7b9620783cba", + "exp": int((datetime.now(tz=timezone.utc) + timedelta(days=1)).timestamp()), + "legacy_exchange": False, + "dirac_policies": {}, +} + TOKEN_RESPONSE_DICT = { "access_token": "test_token", - "expires_in": int((datetime.now(tz=timezone.utc) + timedelta(days=1)).timestamp()), + "expires_in": 3600, "token_type": "Bearer", - "refresh_token": "test_refresh", + "refresh_token": jwt.encode(REFRESH_CONTENT, "secret_key"), } CREDENTIALS_CONTENT: str = serialize_credentials(TokenResponse(**TOKEN_RESPONSE_DICT)) -def lock_and_read_file(file_path): - """Lock and read file.""" - with open(file_path, "r") as f: - fcntl.flock(f, fcntl.LOCK_SH) - f.read() - time.sleep(2) - fcntl.flock(f, fcntl.LOCK_UN) - - -def lock_and_write_file(file_path: Path): - """Lock and write file.""" - with open(file_path, "a") as f: - fcntl.flock(f, fcntl.LOCK_EX) - f.write(CREDENTIALS_CONTENT) - time.sleep(2) - fcntl.flock(f, fcntl.LOCK_UN) - - -@pytest.fixture -def concurrent_access_to_lock_file(): - - def run_processes(proc_to_test, *, read=True): - """Run the process to be tested and attempt to read or write concurrently.""" - location = proc_to_test[1]["location"] - error_dict = dict() - with Pool(2) as pool: - if read: - # Creating the file before reading it - with open(location, "w") as f: - f.write(CREDENTIALS_CONTENT) - pool.apply_async( - lock_and_read_file, - args=(location,), - error_callback=lambda e: error_callback( - e, error_dict, "lock_and_read_file" - ), - ) - else: - pool.apply_async( - lock_and_write_file, - args=(location,), - error_callback=lambda e: error_callback( - e, error_dict, "lock_and_write_file" - ), - ) - time.sleep(1) - result = pool.apply_async( - proc_to_test[0], - kwds=proc_to_test[1], - error_callback=lambda e: error_callback( - e, error_dict, f"{proc_to_test[0].__name__}" - ), - ) - pool.close() - pool.join() - res = result.get(timeout=1) - return res, error_dict - - return run_processes - - -def error_callback(error, error_dict, process_name): - """Called if the process fails.""" - error_dict[process_name] = error - - -@pytest.fixture -def token_setup() -> tuple[TokenResponse, Path, AccessToken]: - """Setup token response and location.""" - with NamedTemporaryFile(delete=False) as tmp: - token_location = Path(tmp.name) - token_response = TokenResponse(**TOKEN_RESPONSE_DICT) - access_token = AccessToken(token_response.access_token, token_response.expires_in) +def test_get_token_accessing_lock_file(monkeypatch, tmp_path): + """Test get_token is waiting to read token from locked file.""" + token_location = tmp_path / "credentials.json" + token_location.write_text(CREDENTIALS_CONTENT) - yield token_response, token_location, access_token + # Patch 'fcntl.flock' within the 'diracx.client.patches.utils' module + flock_calls = [] - if token_location.exists(): - token_location.unlink() + def mock_flock(file, operation): + flock_calls.append((file, operation)) + if operation == fcntl.LOCK_EX: + raise BlockingIOError("File is locked") + monkeypatch.setattr("diracx.client.patches.utils.fcntl.flock", mock_flock) -def test_get_token_accessing_lock_file(token_setup, concurrent_access_to_lock_file): - """Test get_token is waiting to read token from locked file.""" - token_response, token_location, _ = token_setup - process_to_test = ( - get_token, - { - "location": token_location, - "token": None, - "token_endpoint": "/endpoint", - "client_id": "ID", - "verify": False, - }, - ) - result, error_dict = concurrent_access_to_lock_file(process_to_test, read=False) - assert not error_dict - assert isinstance(result, AccessToken) - assert result.token == token_response.access_token + # Attempt to get a token, expecting a BlockingIOError due to the lock + with pytest.raises(BlockingIOError) as exc_info: + get_token( + location=token_location, + token=None, + token_endpoint="/endpoint", + client_id="ID", + verify=False, + ) + + # Verify that flock was called with LOCK_EX + assert len(flock_calls) == 1, "fcntl.flock was not called" + assert ( + flock_calls[-1][1] == fcntl.LOCK_EX + ), f"Expected LOCK_SH, got {flock_calls[-1][1]}" + assert "File is locked" in str(exc_info.value) -def test_get_token_valid_input_token(token_setup): - """Test that get_token return the valid token.""" - _, token_location, access_token = token_setup +def test_get_token_valid_input_token(tmp_path): + """Test that get_token return the valid provided token.""" + token_location = tmp_path / "credentials.json" + # Create a valid access token + token_response = TokenResponse(**TOKEN_RESPONSE_DICT) + access_token = AccessToken( + token_response.access_token, + int( + ( + datetime.now(tz=timezone.utc) + + timedelta(seconds=token_response.expires_in) + ).timestamp() + ), + ) + + # Call get_token result = get_token( location=token_location, token=access_token, @@ -135,20 +87,31 @@ def test_get_token_valid_input_token(token_setup): assert result == access_token -def test_get_token_valid_input_credential(): +def test_get_token_valid_input_credential(tmp_path): """Test that get_token return the valid token given in the credential file.""" - with NamedTemporaryFile(delete=False) as tmp: - tmp.write(CREDENTIALS_CONTENT.encode()) - temp_file = Path(tmp.name) + token_location = tmp_path / "credentials.json" + token_location.write_text(CREDENTIALS_CONTENT) + + # Call get_token result = get_token( - location=temp_file, token=None, token_endpoint="", client_id="ID", verify=False + location=token_location, + token=None, + token_endpoint="", + client_id="ID", + verify=False, ) - temp_file.unlink() + + # Verify that the returned token is the expected token assert isinstance(result, AccessToken) + assert result.token == TOKEN_RESPONSE_DICT["access_token"] + assert result.expires_on > datetime.now(tz=timezone.utc).timestamp() + +def test_get_token_input_token_not_exists(tmp_path): + """Test that get_token return an empty token when the provided token does not exist.""" + token_location = tmp_path / "credentials.json" -def test_get_token_input_token_not_exists(token_setup): - _, token_location, _ = token_setup + # Call get_token result = get_token( location=token_location, token=None, @@ -156,108 +119,111 @@ def test_get_token_input_token_not_exists(token_setup): client_id="ID", verify=False, ) - assert result is None + assert isinstance(result, AccessToken) + assert result.token == "" + assert result.expires_on == 0 -def test_get_token_invalid_input(): - """Test that get_token manage invalid input token.""" +def test_get_token_invalid_input(tmp_path): + """Test that get_token manages invalid input token.""" # Test wrong key in credential - wrong_credential_content = "'{\"wrong_key\": False}'" - with NamedTemporaryFile(delete=False) as tmp: - tmp.write(json.dumps(wrong_credential_content).encode()) - temp_file = Path(tmp.name) - result = get_token( - location=temp_file, token=None, token_endpoint="", client_id="ID", verify=False - ) - temp_file.unlink() - assert result is None + wrong_credential_content = {"wrong_key": False} - # Test with invalid token date - token_response = TOKEN_RESPONSE_DICT.copy() - token_response["expires_in"] = int(datetime.now(tz=timezone.utc).timestamp()) - with NamedTemporaryFile(delete=False) as tmp: - tmp.write(json.dumps(token_response).encode()) - temp_file = Path(tmp.name) + token_location = tmp_path / "credentials.json" + token_location.write_text(json.dumps(wrong_credential_content)) + # Call get_token result = get_token( - location=temp_file, token=None, token_endpoint="", client_id="ID", verify=False + location=token_location, + token=None, + token_endpoint="", + client_id="ID", + verify=False, ) - temp_file.unlink() - assert result is None + # Verify that the returned token is empty + assert isinstance(result, AccessToken) + assert result.token == "" + assert result.expires_on == 0 -def test_get_token_refresh_valid(): +def test_get_token_refresh_valid(monkeypatch, tmp_path): """Test that get_token refresh a valid outdated token.""" token_response = TOKEN_RESPONSE_DICT.copy() - # the future content of the refreshed token - refresh_token = TokenResponse(**token_response) - # Create expired credential file + # Expected future content of the refreshed token + expected_token_response = TokenResponse(**token_response) + + # Create expired access token token_response["expires_on"] = int( (datetime.now(tz=timezone.utc) - timedelta(seconds=10)).timestamp() ) token_response.pop("expires_in") - with NamedTemporaryFile(delete=False) as tmp: - tmp.write(json.dumps(token_response).encode()) - temp_file = Path(tmp.name) - - with ( - patch( - "diracx.client.patches.utils.is_refresh_token_valid", return_value=True - ) as mock_is_refresh_valid, - patch( - "diracx.client.patches.utils.refresh_token", return_value=refresh_token - ) as mock_refresh_token, - ): - result = get_token( - location=temp_file, - token=None, - token_endpoint="", - client_id="ID", - verify=False, - ) - # Verify that the credential fil has been refreshed: - with open(temp_file, "r") as f: - content = f.read() - assert content == serialize_credentials(refresh_token) + # Write expired credentials to a file + token_location = tmp_path / "credentials.json" + token_location.write_text(json.dumps(token_response)) + + # Mock the refresh_token function + was_refresh_called = False + + def mock_refresh(token_endpoint, client_id, refresh_token, verify): + nonlocal was_refresh_called + was_refresh_called = True + return TokenResponse(**TOKEN_RESPONSE_DICT) + + monkeypatch.setattr("diracx.client.patches.utils.refresh_token", mock_refresh) + + # Call get_token + result = get_token( + location=token_location, + token=None, + token_endpoint="", + client_id="ID", + verify=False, + ) - temp_file.unlink() + # Verify that the credential file has been refreshed: + with open(token_location, "r") as f: + content = f.read() + assert content == serialize_credentials(expected_token_response) + # Verify that the returned token is the expected refreshed token assert result is not None assert isinstance(result, AccessToken) - assert result.token == refresh_token.access_token - assert result.expires_on > refresh_token.expires_in - mock_is_refresh_valid.assert_called_once_with(refresh_token.refresh_token) - mock_refresh_token.assert_called_once_with( - "", "ID", refresh_token.refresh_token, verify=False - ) + assert result.token == expected_token_response.access_token + assert result.expires_on > token_response["expires_on"] + assert was_refresh_called -def test_get_token_refresh_invalid(): - """Test that get_token manages an invalid refresh token.""" +def test_get_token_refresh_expired(tmp_path): + """Test that get_token manages an expired refresh token: should return an empty token.""" + # Create expired access token and refresh token token_response = TOKEN_RESPONSE_DICT.copy() - refresh_token = TokenResponse(**token_response) + refresh_token = REFRESH_CONTENT.copy() + + refresh_token["exp"] = int( + (datetime.now(tz=timezone.utc) - timedelta(seconds=10)).timestamp() + ) + token_response["expires_on"] = int( (datetime.now(tz=timezone.utc) - timedelta(seconds=10)).timestamp() ) token_response.pop("expires_in") - with NamedTemporaryFile(delete=False) as tmp: - tmp.write(json.dumps(token_response).encode()) - temp_file = Path(tmp.name) - - with ( - patch( - "diracx.client.patches.utils.is_refresh_token_valid", return_value=False - ) as mock_is_refresh_valid, - ): - result = get_token( - location=temp_file, - token=None, - token_endpoint="", - client_id="ID", - verify=False, - ) + token_response["refresh_token"] = jwt.encode(refresh_token, "secret_key") - temp_file.unlink() - assert result is None - mock_is_refresh_valid.assert_called_once_with(refresh_token.refresh_token) + # Write expired credentials to a file + token_location = tmp_path / "credentials.json" + token_location.write_text(json.dumps(token_response)) + + # Call get_token + result = get_token( + location=token_location, + token=None, + token_endpoint="", + client_id="ID", + verify=False, + ) + + # Verify that the returned token is empty + assert isinstance(result, AccessToken) + assert result.token == "" + assert result.expires_on == 0 diff --git a/diracx-core/tests/test_utils.py b/diracx-core/tests/test_utils.py index 4ccabc9f..aebd55ad 100644 --- a/diracx-core/tests/test_utils.py +++ b/diracx-core/tests/test_utils.py @@ -3,7 +3,6 @@ import fcntl from datetime import datetime, timedelta, timezone from pathlib import Path -from tempfile import NamedTemporaryFile import pytest @@ -48,27 +47,19 @@ def test_dotenv_files_from_environment(monkeypatch): CREDENTIALS_CONTENT: str = serialize_credentials(TokenResponse(**TOKEN_RESPONSE_DICT)) -@pytest.fixture -def token_setup() -> tuple[TokenResponse, Path]: - """Setup token response and location.""" - with NamedTemporaryFile(delete=False) as tmp: - token_location = Path(tmp.name) - token_response = TokenResponse(**TOKEN_RESPONSE_DICT) - - yield token_response, token_location - - if token_location.exists(): - token_location.unlink() +def test_read_credentials_reading_locked_file(monkeypatch, tmp_path): + """Test that read_credentials waits to read a locked file. - -def test_read_credentials_reading_locked_file(monkeypatch, token_setup): - """Test that read_credentials is waiting to read a locked file end in error.""" - _, token_location = token_setup + To keep the test simple and deterministic, we patch 'fcntl.flock' within the 'diracx.core.utils' module. + This will raise a BlockingIOError when attempting to read the file. + """ + token_location = tmp_path / "credentials.json" # Write valid credentials to the file to ensure read_credentials attempts to lock token_location.write_text(CREDENTIALS_CONTENT) # Patch 'fcntl.flock' within the 'diracx.core.utils' module + # This will raise a BlockingIOError when attempting to read the file flock_calls = [] def mock_flock(file, operation): @@ -90,9 +81,13 @@ def mock_flock(file, operation): assert "File is locked" in str(exc_info.value) -def test_write_credentials_writing_locked_file(monkeypatch, token_setup): - """Test that write_credentials is waiting to write a locked file end in error.""" - token_response, token_location = token_setup +def test_write_credentials_writing_locked_file(monkeypatch, tmp_path): + """Test that write_credentials waits to write into a locked file. + + To keep the test simple and deterministic, we patch 'fcntl.flock' within the 'diracx.core.utils' module. + This will raise a BlockingIOError when attempting to write to the file. + """ + token_location = tmp_path / "credentials.json" # Write valid credentials to the file to ensure write_credentials attempts to lock token_location.write_text(CREDENTIALS_CONTENT) @@ -109,7 +104,7 @@ def mock_flock(file, operation): # Attempt to write credentials, expecting a BlockingIOError due to the lock with pytest.raises(BlockingIOError) as exc_info: - write_credentials(token_response, location=token_location) + write_credentials(TokenResponse(**TOKEN_RESPONSE_DICT), location=token_location) # Verify that flock was called (for LOCK_EX) assert len(flock_calls) == 1, "fcntl.flock was not called" @@ -119,22 +114,21 @@ def mock_flock(file, operation): assert "File is locked" in str(exc_info.value) -def test_read_credentials_empty_file(): +def test_read_credentials_empty_file(tmp_path): """Test that read_credentials raises an appropriate error for an empty file.""" - with NamedTemporaryFile(delete=False) as empty_file: - token_location = Path(empty_file.name) + empty_location = tmp_path / "credentials.json" + empty_location.touch() with pytest.raises(RuntimeError) as exc_info: - read_credentials(location=token_location) + read_credentials(location=empty_location) - token_location.unlink() assert "Error reading credentials:" in str(exc_info.value) assert "Expecting value" in str(exc_info.value) -def test_read_credentials_missing_file(): +def test_read_credentials_missing_file(tmp_path): """Test that read_credentials raises an appropriate error for a missing file.""" - missing_file = Path("/path/to/nonexistent/file.txt") + missing_file = tmp_path / "missing.json" with pytest.raises(RuntimeError) as exc_info: read_credentials(location=missing_file) @@ -142,36 +136,34 @@ def test_read_credentials_missing_file(): assert "No such file or directory" in str(exc_info.value) -def test_write_credentials_unavailable_path(token_setup): +def test_write_credentials_unavailable_path(): """Test that write_credentials raises error when it can't create path.""" wrong_path = Path("/wrong/path/file.txt") - token_response, _ = token_setup with pytest.raises(PermissionError): - write_credentials(token_response, location=wrong_path) + write_credentials(TokenResponse(**TOKEN_RESPONSE_DICT), location=wrong_path) -def test_read_credentials_invalid_content(): +def test_read_credentials_invalid_content(tmp_path): """Test that read_credentials raises an appropriate error for a file with invalid content.""" - with NamedTemporaryFile(delete=False) as invalid_file: - invalid_file.write(b"invalid content") - token_location = Path(invalid_file.name) + malformed_token_location = tmp_path / "credentials.json" + malformed_token_location.write_text("invalid content") with pytest.raises(RuntimeError) as exc_info: - read_credentials(location=token_location) + read_credentials(location=malformed_token_location) - token_location.unlink() assert "Error reading credentials:" in str(exc_info.value) assert "Expecting value" in str(exc_info.value) -def test_read_credentials_valid_file(token_setup): +def test_read_credentials_valid_file(tmp_path): """Test that read_credentials works correctly with a valid file.""" - token_response, token_location = token_setup + token_location = tmp_path / "credentials.json" token_location.write_text(CREDENTIALS_CONTENT) credentials = read_credentials(location=token_location) - token_location.unlink() + token_response = TokenResponse(**TOKEN_RESPONSE_DICT) + assert credentials.access_token == token_response.access_token assert credentials.expires_in < token_response.expires_in assert credentials.token_type == token_response.token_type