diff --git a/src/coola/utils/tensor.py b/src/coola/utils/tensor.py
index ea6548db..f33be1f9 100644
--- a/src/coola/utils/tensor.py
+++ b/src/coola/utils/tensor.py
@@ -70,9 +70,12 @@ def is_mps_available() -> bool:
     >>> from coola.utils.tensor import is_mps_available
     >>> is_mps_available()
     """
-    return (
-        is_torch_available()
-        and hasattr(torch.backends, "mps")
-        and torch.backends.mps.is_available()
-        and torch.backends.mps.is_macos13_or_newer()
-    )
+    if not is_torch_available():
+        return False
+    # Empirically verify MPS by allocating a tensor on the device; the
+    # torch.backends.mps flags alone can report MPS as available on
+    # configurations where tensor operations still fail at runtime.
+    try:
+        torch.ones(1, device=torch.device("mps"))
+        return True
+    except RuntimeError:
+        return False
diff --git a/tests/unit/utils/test_tensor.py b/tests/unit/utils/test_tensor.py
index d669ebc1..2f93ecb6 100644
--- a/tests/unit/utils/test_tensor.py
+++ b/tests/unit/utils/test_tensor.py
@@ -76,6 +76,11 @@ def test_is_cuda_available_false() -> None:
     assert not is_cuda_available()
 
 
+@patch("coola.utils.tensor.is_torch_available", lambda *args, **kwargs: False)
+def test_is_cuda_available_no_torch() -> None:
+    assert not is_cuda_available()
+
+
 ######################################
 #     Tests for is_mpa_available     #
 ######################################
@@ -86,21 +91,6 @@ def test_is_mps_available() -> None:
     assert isinstance(is_mps_available(), bool)
 
 
-@torch_available
-@patch("torch.backends.mps.is_available", lambda *args, **kwargs: True)
-@patch("torch.backends.mps.is_macos13_or_newer", lambda *args, **kwargs: True)
-def test_is_mps_available_true() -> None:
-    assert is_mps_available()
-
-
-@torch_available
-@patch("torch.backends.mps.is_available", lambda *args, **kwargs: False)
-def test_is_mps_available_false_not_available() -> None:
-    assert not is_mps_available()
-
-
-@torch_available
-@patch("torch.backends.mps.is_available", lambda *args, **kwargs: True)
-@patch("torch.backends.mps.is_macos13_or_newer", lambda *args, **kwargs: False)
-def test_is_mps_available_false_old_macos() -> None:
+@patch("coola.utils.tensor.is_torch_available", lambda *args, **kwargs: False)
+def test_is_mps_available_no_torch() -> None:
     assert not is_mps_available()