From 6054002107f8bbab0b0e432c8e4be4d26940972e Mon Sep 17 00:00:00 2001
From: MahmoudAshraf97
Date: Fri, 8 Nov 2024 02:00:38 +0200
Subject: [PATCH] delete hf cache only for single models in testing

---
 python/tests/test_transformers.py | 48 +++++++++++++++++++++++++------
 1 file changed, 40 insertions(+), 8 deletions(-)

diff --git a/python/tests/test_transformers.py b/python/tests/test_transformers.py
index 1fed8196d..96852d096 100644
--- a/python/tests/test_transformers.py
+++ b/python/tests/test_transformers.py
@@ -1,6 +1,7 @@
 import inspect
 import json
 import os
+import re
 import shutil
 
 import numpy as np
@@ -11,17 +12,23 @@
 
 
 @pytest.fixture
-def clear_transformers_cache():
+def clear_transformers_cache(request):
     """Clears the Transformers model cache after each test when running in a CI."""
     yield
-    clear_transformers_cache_in_ci()
+    model_name = request.param if hasattr(request, "param") else request.node.nodeid
+    clear_transformers_cache_in_ci(model_name)
 
 
-def clear_transformers_cache_in_ci():
+def clear_transformers_cache_in_ci(model_name: str):
     import transformers
 
+    match = re.search(r"\[(.*?)\]", model_name)  # model id inside "[...]", if any
+    path = os.path.join(
+        transformers.utils.default_cache_path,
+        f"models--{(match.group(1) if match else model_name).replace('/', '--')}",
+    )
     if os.environ.get("CI") == "true":
-        shutil.rmtree(transformers.utils.default_cache_path)
+        shutil.rmtree(path, ignore_errors=True)
 
 
 _TRANSFORMERS_TRANSLATION_TESTS = [
@@ -228,6 +235,9 @@ def test_transformers_generation(
 
 
 @test_utils.only_on_linux
+@pytest.mark.parametrize(
+    "clear_transformers_cache", ["[facebook/opt-350m]"], indirect=True
+)
 def test_transformers_dtype(clear_transformers_cache, tmp_dir):
     converter = ctranslate2.converters.TransformersConverter("facebook/opt-350m")
     output_dir = str(tmp_dir.join("ctranslate2_model"))
@@ -240,6 +250,9 @@ def test_transformers_dtype(clear_transformers_cache, tmp_dir):
 
 
 @test_utils.only_on_linux
+@pytest.mark.parametrize(
+    "clear_transformers_cache", ["[Helsinki-NLP/opus-mt-en-de]"], indirect=True
+)
 def test_transformers_marianmt_vocabulary(clear_transformers_cache, tmp_dir):
     converter = ctranslate2.converters.TransformersConverter(
         "Helsinki-NLP/opus-mt-en-de"
@@ -256,6 +269,9 @@ def test_transformers_marianmt_vocabulary(clear_transformers_cache, tmp_dir):
 
 @test_utils.only_on_linux
 @pytest.mark.parametrize("beam_size", [1, 2])
+@pytest.mark.parametrize(
+    "clear_transformers_cache", ["[Helsinki-NLP/opus-mt-en-roa]"], indirect=True
+)
 def test_transformers_marianmt_disable_unk(
     clear_transformers_cache, tmp_dir, beam_size
 ):
@@ -282,6 +298,16 @@ def test_transformers_marianmt_disable_unk(
         "typeform/distilbert-base-uncased-mnli",
     ],
 )
+@pytest.mark.parametrize(
+    "clear_transformers_cache",
+    [
+        "[bert-base-uncased]",
+        "[distilbert-base-uncased]",
+        "[distilbert-base-cased-distilled-squad]",
+        "[typeform/distilbert-base-uncased-mnli]",
+    ],
+    indirect=True,
+)
 def test_transformers_encoder(clear_transformers_cache, tmp_dir, device, model_name):
     import torch
     import transformers
@@ -344,6 +370,11 @@ def _to_numpy(storage, device):
 
 
 @test_utils.only_on_linux
+@pytest.mark.parametrize(
+    "clear_transformers_cache",
+    ["[hf-internal-testing/tiny-random-GPTBigCodeForCausalLM]"],
+    indirect=True,
+)
 def test_transformers_gptbigcode(clear_transformers_cache, tmp_dir):
     import transformers
 
@@ -385,7 +416,7 @@ def _check_generator_logits(
 class TestGeneration:
     @classmethod
     def teardown_class(cls):
-        clear_transformers_cache_in_ci()
+        clear_transformers_cache_in_ci("[gpt2]")
 
     @test_utils.only_on_linux
     def test_transformers_lm_scoring(self, tmp_dir):
@@ -683,7 +714,8 @@ def test_transformers_generator_token_streaming_early_stop(self, tmp_dir):
 class TestWhisper:
     @classmethod
     def teardown_class(cls):
-        clear_transformers_cache_in_ci()
+        clear_transformers_cache_in_ci("[openai/whisper-tiny]")
+        clear_transformers_cache_in_ci("[openai/whisper-tiny.en]")
 
     @test_utils.only_on_linux
     @test_utils.on_available_devices
@@ -948,7 +980,7 @@ def test_transformers_whisper_include_tokenizer_json(self, tmp_dir):
 class TestWav2Vec2:
     @classmethod
     def teardown_class(cls):
-        clear_transformers_cache_in_ci()
+        clear_transformers_cache_in_ci("[facebook/wav2vec2-large-robust-ft-swbd-300h]")
 
     @test_utils.only_on_linux
     @test_utils.on_available_devices
@@ -1028,7 +1060,7 @@ def test_transformers_wav2vec2(
 class TestWav2Vec2Bert:
     @classmethod
     def teardown_class(cls):
-        clear_transformers_cache_in_ci()
+        clear_transformers_cache_in_ci("[hf-audio/wav2vec2-bert-CV16-en]")
 
     @test_utils.only_on_linux
     @test_utils.on_available_devices