Skip to content

Commit

Permalink
TEST: Change mocking behaviour away from aiohttp
Browse files Browse the repository at this point in the history
  • Loading branch information
matthew-wong-bcg committed Sep 23, 2024
1 parent fa8272d commit 9747bbf
Showing 1 changed file with 12 additions and 26 deletions.
38 changes: 12 additions & 26 deletions test/artkit_test/model/llm/huggingface_tests/test_hugging_face.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,39 +126,25 @@ async def test_huggingface_retry(
zephyr_7b_tokenizer: PreTrainedTokenizerBase,
caplog: pytest.LogCaptureFixture,
) -> None:
with patch("aiohttp.ClientSession") as MockClientSession:

# Mock the response object
mock_post = AsyncMock()
mock_post.read.return_value = b'[{"generated_text": "blue"}]'
mock_post.json.return_value = {"error": "Rate limit exceeded"}
mock_post.return_value.status = 429

# Request info mock
request_info = AsyncMock()
request_info.real_url = (
"https://api-inference.huggingface.co/models/HuggingFaceH4/zephyr-7b-beta"
)
# Mock the get_client method to return a mock client
with patch.object(hugging_face_chat, "get_client") as mock_get_client:
mock_client = AsyncMock()
mock_get_client.return_value = mock_client

def f() -> None:
# Define the side effect function for text_generation
async def mock_text_generation(*args: Any, **kwargs: Any) -> None:
request_info = Mock()
request_info.real_url = "https://api-inference.huggingface.co/models/HuggingFaceH4/zephyr-7b-beta"
err = ClientResponseError(
request_info=request_info,
history=AsyncMock(),
history=(),
status=429,
message="Rate limit exceeded",
)
raise err

mock_post.raise_for_status = f

# Set up the mock connection object
mock_connection = AsyncMock()
mock_connection.post.return_value = mock_post
MockClientSession.return_value = mock_connection

# Mock session close to prevent recursive close calls
mock_close = AsyncMock()
mock_connection.close = mock_close # Avoid recursion on session close
# Set the side effect
mock_client.text_generation.side_effect = mock_text_generation

# Get number of awaited retries
n_retries = hugging_face_chat.max_retries
Expand All @@ -172,7 +158,7 @@ def f() -> None:
"Your job is to answer a quiz question with a single word, "
"lowercase, with no punctuation"
).get_response(message="What color is the sky?")
assert mock_connection.post.call_count == n_retries
assert mock_client.text_generation.call_count == n_retries

assert (
len(
Expand Down

0 comments on commit 9747bbf

Please sign in to comment.