Skip to content

Commit

Permalink
Update dependencies and workflow (#281)
Browse files Browse the repository at this point in the history
  • Loading branch information
durandtibo authored Oct 8, 2023
1 parent 52a9aa5 commit c996ebc
Show file tree
Hide file tree
Showing 7 changed files with 109 additions and 100 deletions.
48 changes: 4 additions & 44 deletions .github/workflows/nightly-package.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ jobs:
matrix:
os: [ ubuntu-latest ]
python-version: [ '3.9', '3.10', '3.11' ]
jax-version: [ 0.4.17, 0.3.25 ]
jax-version: [ 0.4.18, 0.3.25 ]

steps:
- name: Checkout
Expand Down Expand Up @@ -258,7 +258,7 @@ jobs:
matrix:
os: [ ubuntu-latest ]
python-version: [ '3.9', '3.10', '3.11' ]
xarray-version: [ '2023.8', '2023.7', '2023.6', '2023.5', '2023.4', '2023.3' ]
xarray-version: [ '2023.9', '2023.8', '2023.7', '2023.6', '2023.5', '2023.4', '2023.3' ]

steps:
- name: Checkout
Expand All @@ -278,45 +278,5 @@ jobs:
python -c "import coola; import xarray as xr; import numpy as np; " \
"assert coola.objects_are_equal(xr.DataArray(np.arange(6), dims=["z"]), xr.DataArray(np.arange(6), dims=["z"]))"
cyclic-import:
runs-on: ${{ matrix.os }}
timeout-minutes: 10
strategy:
max-parallel: 8
fail-fast: false
matrix:
os: [ ubuntu-latest ]
python-version: [ '3.10' ]

steps:
- name: Checkout
uses: actions/checkout@v3

- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}

- name: Install package
run: |
pip install "coola[all]"
- name: check coola.comparators
run: |
python -c "from coola import comparators"
- name: check coola.formatters
run: |
python -c "from coola import formatters"
- name: check coola.reducers
run: |
python -c "from coola import reducers"
- name: check coola.summarizers
run: |
python -c "from coola import summarizers"
- name: check coola.testers
run: |
python -c "from coola import testers"
- name: check coola.utils
run: |
python -c "from coola import utils"
cyclic-imports:
uses: ./.github/workflows/cyclic-imports.yaml
4 changes: 2 additions & 2 deletions .github/workflows/test-deps.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
matrix:
os: [ ubuntu-latest ]
python-version: [ '3.9', '3.10', '3.11' ]
jax-version: [ 0.4.17, 0.3.25 ]
jax-version: [ 0.4.18, 0.3.25 ]

steps:
- name: Checkout
Expand Down Expand Up @@ -240,7 +240,7 @@ jobs:
matrix:
os: [ ubuntu-latest ]
python-version: [ '3.9', '3.10', '3.11' ]
xarray-version: [ '2023.8', '2023.7', '2023.6', '2023.5', '2023.4', '2023.3' ]
xarray-version: [ '2023.9', '2023.8', '2023.7', '2023.6', '2023.5', '2023.4', '2023.3' ]

steps:
- name: Checkout
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ The following is the corresponding `coola` versions and supported dependencies.

| `coola` | `jax`<sup>*</sup> | `numpy`<sup>*</sup> | `pandas`<sup>*</sup> | `polars`<sup>*</sup> | `torch`<sup>*</sup> | `xarray`<sup>*</sup> | `python` |
|----------|-------------------|---------------------|----------------------|----------------------|---------------------|----------------------|---------------|
| `0.0.25` | `>=0.3,<0.5` | `>=1.21,<1.27` | `>=1.3,<2.2` | `>=0.18.3,<0.20` | `>=1.10,<2.2` | `>=2023.3,<2023.9` | `>=3.9,<3.12` |
| `0.0.25` | `>=0.3,<0.5` | `>=1.21,<1.27` | `>=1.3,<2.2` | `>=0.18.3,<0.20` | `>=1.10,<2.2` | `>=2023.3,<2023.10` | `>=3.9,<3.12` |
| `0.0.24` | `>=0.3,<0.5` | `>=1.21,<1.27` | `>=1.3,<2.2` | `>=0.18.3,<0.20` | `>=1.10,<2.2` | `>=2023.3,<2023.9` | `>=3.9,<3.12` |
| `0.0.23` | `>=0.3,<0.5` | `>=1.21,<1.27` | `>=1.3,<2.2` | `>=0.18.3,<0.20` | `>=1.10,<2.1` | `>=2023.3,<2023.9` | `>=3.9,<3.12` |
| `0.0.22` | `>=0.3,<0.5` | `>=1.20,<1.26` | `>=1.3,<2.1` | `>=0.18.3,<0.19` | `>=1.10,<2.1` | `>=2023.3,<2023.9` | `>=3.9,<3.12` |
Expand Down
96 changes: 51 additions & 45 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ pandas = { version = ">=1.3,<2.2", optional = true }
# polars: 0.18.3 is the minimal version because of https://github.com/pola-rs/polars/issues/9358
polars = { version = ">=0.18.3,<0.20", optional = true }
torch = { version = ">=1.10,<2.2", optional = true }
xarray = { version = ">=2023.3,<2023.9", optional = true }
xarray = { version = ">=2023.3,<2023.10", optional = true }

[tool.poetry.extras]
all = ["jax", "jaxlib", "numpy", "pandas", "polars", "torch", "xarray"]
Expand Down
16 changes: 11 additions & 5 deletions src/coola/utils/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

__all__ = ["get_available_devices", "is_cuda_available", "is_mps_available"]

from functools import lru_cache
from unittest.mock import Mock

from coola.utils.imports import is_torch_available
Expand All @@ -12,6 +13,7 @@
torch = Mock()


@lru_cache(1)
def get_available_devices() -> tuple[str, ...]:
r"""Gets the available PyTorch devices on the machine.
Expand All @@ -35,6 +37,7 @@ def get_available_devices() -> tuple[str, ...]:
return tuple(devices)


@lru_cache(1)
def is_cuda_available() -> bool:
r"""Indicates if CUDA is currently available.
Expand All @@ -52,6 +55,7 @@ def is_cuda_available() -> bool:
return is_torch_available() and torch.cuda.is_available()


@lru_cache(1)
def is_mps_available() -> bool:
r"""Indicates if MPS is currently available.
Expand All @@ -66,8 +70,10 @@ def is_mps_available() -> bool:
>>> from coola.utils.tensor import is_mps_available
>>> is_mps_available()
"""
return (
is_torch_available()
and hasattr(torch.backends, "mps")
and torch.backends.mps.is_available()
)
if not is_torch_available():
return False
try:
torch.ones(1, device="mps")
return True
except RuntimeError:
return False
41 changes: 39 additions & 2 deletions tests/unit/utils/test_tensor.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,25 @@
from unittest.mock import patch
from __future__ import annotations

from unittest.mock import Mock, patch

from pytest import fixture

from coola.testing import torch_available
from coola.utils.tensor import (
get_available_devices,
is_cuda_available,
is_mps_available,
torch,
)


@fixture(autouse=True)
def reset() -> None:
get_available_devices.cache_clear()
is_cuda_available.cache_clear()
is_mps_available.cache_clear()


###########################################
# Tests for get_available_devices #
###########################################
Expand Down Expand Up @@ -64,11 +77,35 @@ def test_is_cuda_available_false() -> None:
assert not is_cuda_available()


@patch("coola.utils.tensor.is_torch_available", lambda *args, **kwargs: False)
def test_is_cuda_available_no_torch() -> None:
assert not is_cuda_available()


######################################
# Tests for is_mpa_available #
# Tests for is_mps_available #
######################################


@torch_available
def test_is_mps_available() -> None:
assert isinstance(is_mps_available(), bool)


@torch_available
@patch("coola.utils.tensor.is_torch_available", lambda *args, **kwargs: True)
def test_is_mps_available_with_mps() -> None:
with patch("coola.utils.tensor.torch.ones", Mock(return_value=torch.ones(1))):
assert is_mps_available()


@torch_available
@patch("coola.utils.tensor.is_torch_available", lambda *args, **kwargs: True)
def test_is_mps_available_without_mps() -> None:
with patch("coola.utils.tensor.torch.ones", Mock(side_effect=RuntimeError)):
assert not is_mps_available()


@patch("coola.utils.tensor.is_torch_available", lambda *args, **kwargs: False)
def test_is_mps_available_no_torch() -> None:
assert not is_mps_available()

0 comments on commit c996ebc

Please sign in to comment.