From 71f49a5d2a00ccf5417b923b424177bd617e80a9 Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Thu, 4 Apr 2024 12:16:04 +0200 Subject: [PATCH] Skip `test_freeu_enabled ` on MPS (#7570) * Skip `test_freeu_enabled ` on MPS * Small fixes - import skip_mps correctly - disable all instances of test_freeu_enabled * Empty commit to trigger tests * Empty commit to trigger CI --- tests/pipelines/stable_diffusion/test_stable_diffusion.py | 3 +++ tests/pipelines/test_pipelines_common.py | 4 +++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py index 9a71cc462b10..bb3869947f12 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py @@ -54,6 +54,7 @@ require_torch_2, require_torch_gpu, run_test_in_subprocess, + skip_mps, slow, torch_device, ) @@ -639,6 +640,8 @@ def test_attention_slicing_forward_pass(self): def test_inference_batch_single_identical(self): super().test_inference_batch_single_identical(expected_max_diff=3e-3) + # MPS currently doesn't support ComplexFloats, which are required for freeU - see https://github.com/huggingface/diffusers/issues/7569. + @skip_mps def test_freeu_enabled(self): components = self.get_dummy_components() sd_pipe = StableDiffusionPipeline(**components) diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index f0e6818bfc2b..73a171797036 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -39,7 +39,7 @@ from diffusers.schedulers import KarrasDiffusionSchedulers from diffusers.utils import logging from diffusers.utils.import_utils import is_accelerate_available, is_accelerate_version, is_xformers_available -from diffusers.utils.testing_utils import CaptureLogger, require_torch, torch_device +from diffusers.utils.testing_utils import CaptureLogger, require_torch, skip_mps, torch_device from ..models.autoencoders.test_models_vae import ( get_asym_autoencoder_kl_config, @@ -125,6 +125,8 @@ def test_vae_tiling(self): zeros = torch.zeros(shape).to(torch_device) pipe.vae.decode(zeros) + # MPS currently doesn't support ComplexFloats, which are required for freeU - see https://github.com/huggingface/diffusers/issues/7569. + @skip_mps def test_freeu_enabled(self): components = self.get_dummy_components() pipe = self.pipeline_class(**components)