Skip to content

Commit

Permalink
Delete the Hugging Face cache only for individual models during testing
Browse files Browse the repository at this point in the history
  • Loading branch information
MahmoudAshraf97 committed Nov 8, 2024
1 parent 2a3cc02 commit 6054002
Showing 1 changed file with 40 additions and 8 deletions.
48 changes: 40 additions & 8 deletions python/tests/test_transformers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import inspect
import json
import os
import re
import shutil

import numpy as np
Expand All @@ -11,17 +12,23 @@


@pytest.fixture
def clear_transformers_cache(request):
    """Clear one model from the Transformers cache after each test in a CI.

    The model identifier is taken from ``request.param`` when the fixture is
    parametrized indirectly (``indirect=True``); otherwise it falls back to
    the test's node id, which carries the bracketed parametrize suffix
    (e.g. ``test_foo[facebook/opt-350m]``).
    """
    # Run the test first; clean up only afterwards.
    yield
    model_name = request.param if hasattr(request, "param") else request.node.nodeid
    clear_transformers_cache_in_ci(model_name)


def clear_transformers_cache_in_ci(model_name: str):
    """Delete a single model's Hugging Face cache directory when running in a CI.

    :param model_name: string containing the model id in square brackets,
        e.g. ``"[facebook/opt-350m]"`` or a parametrized test node id such as
        ``"test_foo[facebook/opt-350m]"``.
    """
    # Extract the model id from the bracketed (parametrize) suffix.
    match = re.search(r"\[(.*?)\]", model_name)
    if match is None:
        # Non-parametrized node id: no model to clear.
        return
    if os.environ.get("CI") != "true":
        return
    import transformers

    # Hub cache layout stores each model under "models--{org}--{name}".
    path = os.path.join(
        transformers.utils.default_cache_path,
        f"models--{match.group(1).replace('/', '--')}",
    )
    # The directory may not exist if the model was never downloaded;
    # ignore_errors keeps teardown from failing in that case.
    shutil.rmtree(path, ignore_errors=True)


_TRANSFORMERS_TRANSLATION_TESTS = [
Expand Down Expand Up @@ -228,6 +235,9 @@ def test_transformers_generation(


@test_utils.only_on_linux
@pytest.mark.parametrize(
"clear_transformers_cache", ["[facebook/opt-350m]"], indirect=True
)
def test_transformers_dtype(clear_transformers_cache, tmp_dir):
converter = ctranslate2.converters.TransformersConverter("facebook/opt-350m")
output_dir = str(tmp_dir.join("ctranslate2_model"))
Expand All @@ -240,6 +250,9 @@ def test_transformers_dtype(clear_transformers_cache, tmp_dir):


@test_utils.only_on_linux
@pytest.mark.parametrize(
"clear_transformers_cache", ["[Helsinki-NLP/opus-mt-en-de]"], indirect=True
)
def test_transformers_marianmt_vocabulary(clear_transformers_cache, tmp_dir):
converter = ctranslate2.converters.TransformersConverter(
"Helsinki-NLP/opus-mt-en-de"
Expand All @@ -256,6 +269,9 @@ def test_transformers_marianmt_vocabulary(clear_transformers_cache, tmp_dir):

@test_utils.only_on_linux
@pytest.mark.parametrize("beam_size", [1, 2])
@pytest.mark.parametrize(
"clear_transformers_cache", ["[Helsinki-NLP/opus-mt-en-roa]"], indirect=True
)
def test_transformers_marianmt_disable_unk(
clear_transformers_cache, tmp_dir, beam_size
):
Expand All @@ -282,6 +298,16 @@ def test_transformers_marianmt_disable_unk(
"typeform/distilbert-base-uncased-mnli",
],
)
@pytest.mark.parametrize(
"clear_transformers_cache",
[
"[bert-base-uncased]",
"[distilbert-base-uncased]",
"[distilbert-base-cased-distilled-squad]",
"[typeform/distilbert-base-uncased-mnli]",
],
indirect=True,
)
def test_transformers_encoder(clear_transformers_cache, tmp_dir, device, model_name):
import torch
import transformers
Expand Down Expand Up @@ -344,6 +370,11 @@ def _to_numpy(storage, device):


@test_utils.only_on_linux
@pytest.mark.parametrize(
"clear_transformers_cache",
["[hf-internal-testing/tiny-random-GPTBigCodeForCausalLM]"],
indirect=True,
)
def test_transformers_gptbigcode(clear_transformers_cache, tmp_dir):
import transformers

Expand Down Expand Up @@ -385,7 +416,7 @@ def _check_generator_logits(
class TestGeneration:
@classmethod
def teardown_class(cls):
    """Remove the gpt2 model from the Hugging Face cache after this class's tests (CI only)."""
    clear_transformers_cache_in_ci("[gpt2]")

@test_utils.only_on_linux
def test_transformers_lm_scoring(self, tmp_dir):
Expand Down Expand Up @@ -683,7 +714,8 @@ def test_transformers_generator_token_streaming_early_stop(self, tmp_dir):
class TestWhisper:
@classmethod
def teardown_class(cls):
    """Remove both Whisper test models from the Hugging Face cache after this class's tests (CI only)."""
    clear_transformers_cache_in_ci("[openai/whisper-tiny]")
    clear_transformers_cache_in_ci("[openai/whisper-tiny.en]")

@test_utils.only_on_linux
@test_utils.on_available_devices
Expand Down Expand Up @@ -948,7 +980,7 @@ def test_transformers_whisper_include_tokenizer_json(self, tmp_dir):
class TestWav2Vec2:
@classmethod
def teardown_class(cls):
    """Remove the Wav2Vec2 test model from the Hugging Face cache after this class's tests (CI only)."""
    clear_transformers_cache_in_ci("[facebook/wav2vec2-large-robust-ft-swbd-300h]")

@test_utils.only_on_linux
@test_utils.on_available_devices
Expand Down Expand Up @@ -1028,7 +1060,7 @@ def test_transformers_wav2vec2(
class TestWav2Vec2Bert:
@classmethod
def teardown_class(cls):
    """Remove the Wav2Vec2-BERT test model from the Hugging Face cache after this class's tests (CI only)."""
    clear_transformers_cache_in_ci("[hf-audio/wav2vec2-bert-CV16-en]")

@test_utils.only_on_linux
@test_utils.on_available_devices
Expand Down

0 comments on commit 6054002

Please sign in to comment.