diff --git a/.github/workflows/nightly-package.yaml b/.github/workflows/nightly-package.yaml
index 11be073c..cd2f58ea 100644
--- a/.github/workflows/nightly-package.yaml
+++ b/.github/workflows/nightly-package.yaml
@@ -94,7 +94,7 @@ jobs:
matrix:
os: [ ubuntu-latest ]
python-version: [ '3.9', '3.10', '3.11' ]
- jax-version: [ 0.4.17, 0.3.25 ]
+ jax-version: [ 0.4.18, 0.3.25 ]
steps:
- name: Checkout
@@ -258,7 +258,7 @@ jobs:
matrix:
os: [ ubuntu-latest ]
python-version: [ '3.9', '3.10', '3.11' ]
- xarray-version: [ '2023.8', '2023.7', '2023.6', '2023.5', '2023.4', '2023.3' ]
+ xarray-version: [ '2023.9', '2023.8', '2023.7', '2023.6', '2023.5', '2023.4', '2023.3' ]
steps:
- name: Checkout
@@ -278,45 +278,5 @@ jobs:
python -c "import coola; import xarray as xr; import numpy as np; " \
"assert coola.objects_are_equal(xr.DataArray(np.arange(6), dims=["z"]), xr.DataArray(np.arange(6), dims=["z"]))"
-
- cyclic-import:
- runs-on: ${{ matrix.os }}
- timeout-minutes: 10
- strategy:
- max-parallel: 8
- fail-fast: false
- matrix:
- os: [ ubuntu-latest ]
- python-version: [ '3.10' ]
-
- steps:
- - name: Checkout
- uses: actions/checkout@v3
-
- - name: Set up Python
- uses: actions/setup-python@v4
- with:
- python-version: ${{ matrix.python-version }}
-
- - name: Install package
- run: |
- pip install "coola[all]"
-
- - name: check coola.comparators
- run: |
- python -c "from coola import comparators"
- - name: check coola.formatters
- run: |
- python -c "from coola import formatters"
- - name: check coola.reducers
- run: |
- python -c "from coola import reducers"
- - name: check coola.summarizers
- run: |
- python -c "from coola import summarizers"
- - name: check coola.testers
- run: |
- python -c "from coola import testers"
- - name: check coola.utils
- run: |
- python -c "from coola import utils"
+ cyclic-imports:
+ uses: ./.github/workflows/cyclic-imports.yaml
diff --git a/.github/workflows/test-deps.yaml b/.github/workflows/test-deps.yaml
index 8b0e855c..dd5bf07b 100644
--- a/.github/workflows/test-deps.yaml
+++ b/.github/workflows/test-deps.yaml
@@ -13,7 +13,7 @@ jobs:
matrix:
os: [ ubuntu-latest ]
python-version: [ '3.9', '3.10', '3.11' ]
- jax-version: [ 0.4.17, 0.3.25 ]
+ jax-version: [ 0.4.18, 0.3.25 ]
steps:
- name: Checkout
@@ -240,7 +240,7 @@ jobs:
matrix:
os: [ ubuntu-latest ]
python-version: [ '3.9', '3.10', '3.11' ]
- xarray-version: [ '2023.8', '2023.7', '2023.6', '2023.5', '2023.4', '2023.3' ]
+ xarray-version: [ '2023.9', '2023.8', '2023.7', '2023.6', '2023.5', '2023.4', '2023.3' ]
steps:
- name: Checkout
diff --git a/README.md b/README.md
index 4ab6b37f..ec9a367e 100644
--- a/README.md
+++ b/README.md
@@ -142,7 +142,7 @@ The following is the corresponding `coola` versions and supported dependencies.
| `coola` | `jax`* | `numpy`* | `pandas`* | `polars`* | `torch`* | `xarray`* | `python` |
|----------|-------------------|---------------------|----------------------|----------------------|---------------------|----------------------|---------------|
-| `0.0.25` | `>=0.3,<0.5` | `>=1.21,<1.27` | `>=1.3,<2.2` | `>=0.18.3,<0.20` | `>=1.10,<2.2` | `>=2023.3,<2023.9` | `>=3.9,<3.12` |
+| `0.0.25` | `>=0.3,<0.5` | `>=1.21,<1.27` | `>=1.3,<2.2` | `>=0.18.3,<0.20` | `>=1.10,<2.2` | `>=2023.3,<2023.10` | `>=3.9,<3.12` |
| `0.0.24` | `>=0.3,<0.5` | `>=1.21,<1.27` | `>=1.3,<2.2` | `>=0.18.3,<0.20` | `>=1.10,<2.2` | `>=2023.3,<2023.9` | `>=3.9,<3.12` |
| `0.0.23` | `>=0.3,<0.5` | `>=1.21,<1.27` | `>=1.3,<2.2` | `>=0.18.3,<0.20` | `>=1.10,<2.1` | `>=2023.3,<2023.9` | `>=3.9,<3.12` |
| `0.0.22` | `>=0.3,<0.5` | `>=1.20,<1.26` | `>=1.3,<2.1` | `>=0.18.3,<0.19` | `>=1.10,<2.1` | `>=2023.3,<2023.9` | `>=3.9,<3.12` |
diff --git a/poetry.lock b/poetry.lock
index 692751de..b4886381 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -442,58 +442,65 @@ files = [
[[package]]
name = "jax"
-version = "0.4.17"
+version = "0.4.18"
description = "Differentiate, compile, and transform Numpy code."
optional = true
python-versions = ">=3.9"
files = [
- {file = "jax-0.4.17-py3-none-any.whl", hash = "sha256:c3ab72ea2f1c5d8ccf2561e79f6562fb2964629f3e55b3ac1c11c48b64c20336"},
- {file = "jax-0.4.17.tar.gz", hash = "sha256:d7508a69e87835f534cb07a2f21d79cc1cb8c4cfdcf7fb010927267ef7355f1d"},
+ {file = "jax-0.4.18-py3-none-any.whl", hash = "sha256:2ded3f558b74593c3533036a90c20d41ea35f35c74b25ca0fc86f4aafc388746"},
+ {file = "jax-0.4.18.tar.gz", hash = "sha256:776cf33890100803e98f45f9af10aa727271c6993d4e766c069118733c928132"},
]
[package.dependencies]
importlib-metadata = {version = ">=4.6", markers = "python_version < \"3.10\""}
ml-dtypes = ">=0.2.0"
-numpy = ">=1.22"
+numpy = [
+ {version = ">=1.23.2", markers = "python_version >= \"3.11\""},
+ {version = ">=1.22", markers = "python_version < \"3.11\""},
+]
opt-einsum = "*"
scipy = ">=1.7"
[package.extras]
australis = ["protobuf (>=3.13,<4)"]
-ci = ["jaxlib (==0.4.16)"]
-cpu = ["jaxlib (==0.4.17)"]
-cuda = ["jaxlib (==0.4.17+cuda11.cudnn86)"]
-cuda11-cudnn86 = ["jaxlib (==0.4.17+cuda11.cudnn86)"]
-cuda11-local = ["jaxlib (==0.4.17+cuda11.cudnn86)"]
-cuda11-pip = ["jaxlib (==0.4.17+cuda11.cudnn86)", "nvidia-cublas-cu11 (>=11.11)", "nvidia-cuda-cupti-cu11 (>=11.8)", "nvidia-cuda-nvcc-cu11 (>=11.8)", "nvidia-cuda-runtime-cu11 (>=11.8)", "nvidia-cudnn-cu11 (>=8.8)", "nvidia-cufft-cu11 (>=10.9)", "nvidia-cusolver-cu11 (>=11.4)", "nvidia-cusparse-cu11 (>=11.7)"]
-cuda12-local = ["jaxlib (==0.4.17+cuda12.cudnn89)"]
-cuda12-pip = ["jaxlib (==0.4.17+cuda12.cudnn89)", "nvidia-cublas-cu12 (>=12.2.5.6)", "nvidia-cuda-cupti-cu12 (>=12.2.142)", "nvidia-cuda-nvcc-cu12 (>=12.2.140)", "nvidia-cuda-runtime-cu12 (>=12.2.140)", "nvidia-cudnn-cu12 (>=8.9)", "nvidia-cufft-cu12 (>=11.0.8.103)", "nvidia-cusolver-cu12 (>=11.5.2)", "nvidia-cusparse-cu12 (>=12.1.2.141)"]
+ci = ["jaxlib (==0.4.17)"]
+cpu = ["jaxlib (==0.4.18)"]
+cuda = ["jaxlib (==0.4.18+cuda11.cudnn86)"]
+cuda11-cudnn86 = ["jaxlib (==0.4.18+cuda11.cudnn86)"]
+cuda11-local = ["jaxlib (==0.4.18+cuda11.cudnn86)"]
+cuda11-pip = ["jaxlib (==0.4.18+cuda11.cudnn86)", "nvidia-cublas-cu11 (>=11.11)", "nvidia-cuda-cupti-cu11 (>=11.8)", "nvidia-cuda-nvcc-cu11 (>=11.8)", "nvidia-cuda-runtime-cu11 (>=11.8)", "nvidia-cudnn-cu11 (>=8.8)", "nvidia-cufft-cu11 (>=10.9)", "nvidia-cusolver-cu11 (>=11.4)", "nvidia-cusparse-cu11 (>=11.7)", "nvidia-nccl-cu11 (>=2.18.3)"]
+cuda12-local = ["jaxlib (==0.4.18+cuda12.cudnn89)"]
+cuda12-pip = ["jaxlib (==0.4.18+cuda12.cudnn89)", "nvidia-cublas-cu12 (>=12.2.5.6)", "nvidia-cuda-cupti-cu12 (>=12.2.142)", "nvidia-cuda-nvcc-cu12 (>=12.2.140)", "nvidia-cuda-runtime-cu12 (>=12.2.140)", "nvidia-cudnn-cu12 (>=8.9)", "nvidia-cufft-cu12 (>=11.0.8.103)", "nvidia-cusolver-cu12 (>=11.5.2)", "nvidia-cusparse-cu12 (>=12.1.2.141)", "nvidia-nccl-cu12 (>=2.18.3)"]
minimum-jaxlib = ["jaxlib (==0.4.14)"]
-tpu = ["jaxlib (==0.4.17)", "libtpu-nightly (==0.1.dev20231003)", "requests"]
+tpu = ["jaxlib (==0.4.18)", "libtpu-nightly (==0.1.dev20231006)", "requests"]
[[package]]
name = "jaxlib"
-version = "0.4.17"
+version = "0.4.18"
description = "XLA library for JAX"
optional = true
python-versions = ">=3.9"
files = [
- {file = "jaxlib-0.4.17-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:d4be1ac4bf1be1ae1cd8f5f4da414a6d0de8de36cf2effdb5758d4d677896078"},
- {file = "jaxlib-0.4.17-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:392c779f902c43e1a0af49159daffef9b5af952aba001463f98cf95a59ef17ff"},
- {file = "jaxlib-0.4.17-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:160fce68b82a79a6c522652e8dd9a10aac9c00d1599cb7e166671ad909aa139e"},
- {file = "jaxlib-0.4.17-cp310-cp310-win_amd64.whl", hash = "sha256:61b3788c6cfe46f307e6e67d4a942de72cf34711ff349f4f11500cdf6dc67199"},
- {file = "jaxlib-0.4.17-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:6c3524150bd85098f291fac81f73e285f3e095dbbb49751647cc27bed5327a78"},
- {file = "jaxlib-0.4.17-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2e0d84a756b47ef6db52a6532b1f242cb8dc9035c102c60075470d65e71f7afb"},
- {file = "jaxlib-0.4.17-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:e04a0b8bb18ac24e25c15ed03be771815566f118c16f585ffe2e0f75bf7c064d"},
- {file = "jaxlib-0.4.17-cp311-cp311-win_amd64.whl", hash = "sha256:73173f1aff8d277110d32bdd5e073dc7d50e6618b5567b3bfbc53864b4613439"},
- {file = "jaxlib-0.4.17-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:909971337aabf5f2724a84c3166cea454b37024908d830695dc6b4ba4440676f"},
- {file = "jaxlib-0.4.17-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:98d42a402201fd0cb332bad4177b20c942d1acd1487581ee0c3cb5ef6766531a"},
- {file = "jaxlib-0.4.17-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:593aa9d1c940b9215968878561ad59feee2438386c3868d6524ff4ca730cfdf1"},
- {file = "jaxlib-0.4.17-cp312-cp312-win_amd64.whl", hash = "sha256:a4384cc7187f4f10749c6c623211d1a6b55575f921c00af38ff8f05fd3f7ecfd"},
- {file = "jaxlib-0.4.17-cp39-cp39-macosx_10_14_x86_64.whl", hash = "sha256:84b6cd54143ffe2ce45d5bcf2f9eafa1f9b4cf51dab0cc8e7703622fb624549d"},
- {file = "jaxlib-0.4.17-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a521d8323ef4d8155efc11b788f29dd3794cbb80f83533da957921a058fd3abe"},
- {file = "jaxlib-0.4.17-cp39-cp39-manylinux2014_x86_64.whl", hash = "sha256:c933a6fb74f9fc16a2610566b32ff0135077ae9032e50f695653377d2cfbc9e5"},
- {file = "jaxlib-0.4.17-cp39-cp39-win_amd64.whl", hash = "sha256:44a2bff9966fe3b5783595d3214d3598bea48a2aa502de9d7d44c2ca39426929"},
+ {file = "jaxlib-0.4.18-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:e0d78703fd1219d9875f20c6c692bba0973b744d66791e0b3e3cdb230c65a2a5"},
+ {file = "jaxlib-0.4.18-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:89fff93b90d054715db0bc3d3b572b799071e63f1fb44edfb1630c5f53631cfd"},
+ {file = "jaxlib-0.4.18-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:43287e8ece61f69b1d2a13d7f5e4d540c6edf4ab60bc1606b2b5f9321a9e8471"},
+ {file = "jaxlib-0.4.18-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:3293689a8bef495c7837a82ca3038b92c4f21204cadbad6f497306c58aa554e1"},
+ {file = "jaxlib-0.4.18-cp310-cp310-win_amd64.whl", hash = "sha256:a72ee7baf663ed5b9c6a426c1919f3755d4d71d4db6abf1979e6ce1f2451fb7f"},
+ {file = "jaxlib-0.4.18-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:02895bc15ec578d3bdbdf2c3a2195852d45870f611b26e2e7261cc6b0353a928"},
+ {file = "jaxlib-0.4.18-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e73b17ac3a6a3e034bca5e5752b1bf035a2eb50ced721a4651b287f5e2c672f7"},
+ {file = "jaxlib-0.4.18-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:f3a8ce7096b3eadd531773c5ef7f7c3bb7552cdf163d682ccd0b0c7f7240d109"},
+ {file = "jaxlib-0.4.18-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:4ce7b001fb070e2b7926553bedb9490b7671b3dcc176fd7b52df3932c3593cb0"},
+ {file = "jaxlib-0.4.18-cp311-cp311-win_amd64.whl", hash = "sha256:055950e663fdc101b544597c1361596ff82816575640685effd4779949c8cf06"},
+ {file = "jaxlib-0.4.18-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:85572a9fa84a17cffd05b771d528012297be4c0e227a07e7dd082c15094749b8"},
+ {file = "jaxlib-0.4.18-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:c429a15165b6b5ded5b0c46c5861d0b978a82aaa2200b2e517366d220f3f01ee"},
+ {file = "jaxlib-0.4.18-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:dfe84a294ab3de2557c49a48c0d83c555018a5190aefa1134d5cb3219865edcd"},
+ {file = "jaxlib-0.4.18-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:ed1ba86c1a2adea8235269f3e1f5561696068ca60c68b8e5a6f0eb3b978a305a"},
+ {file = "jaxlib-0.4.18-cp312-cp312-win_amd64.whl", hash = "sha256:f0d5414bc79bdd667b81ee3c5836641bbd52d6d9c0054043dec8e025857e1260"},
+ {file = "jaxlib-0.4.18-cp39-cp39-macosx_10_14_x86_64.whl", hash = "sha256:9593ff69f424947567e206f3e356b2a2df55ca68e6d815d5adc6cae308e8f652"},
+ {file = "jaxlib-0.4.18-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:116a0d6aedd3e856b52493d7e392fb1b40952b84fb72448fde1c1ab5687db667"},
+ {file = "jaxlib-0.4.18-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:6cb20bbbdafd90e71ad0deb9295519a0175c108c8c557b84fb9fe94f751daee4"},
+ {file = "jaxlib-0.4.18-cp39-cp39-manylinux2014_x86_64.whl", hash = "sha256:f7787a5531d226d6cc9ec2baa7141260bb713435e1cfc053cb9f5cefa9756ac3"},
+ {file = "jaxlib-0.4.18-cp39-cp39-win_amd64.whl", hash = "sha256:4771e8439c48d1c3cf65e01016da02c6592a310bf973c9609fc3be7df9b49b22"},
]
[package.dependencies]
@@ -524,20 +531,20 @@ i18n = ["Babel (>=2.7)"]
[[package]]
name = "markdown"
-version = "3.4.4"
+version = "3.5"
description = "Python implementation of John Gruber's Markdown."
optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.8"
files = [
- {file = "Markdown-3.4.4-py3-none-any.whl", hash = "sha256:a4c1b65c0957b4bd9e7d86ddc7b3c9868fb9670660f6f99f6d1bca8954d5a941"},
- {file = "Markdown-3.4.4.tar.gz", hash = "sha256:225c6123522495d4119a90b3a3ba31a1e87a70369e03f14799ea9c0d7183a3d6"},
+ {file = "Markdown-3.5-py3-none-any.whl", hash = "sha256:4afb124395ce5fc34e6d9886dab977fd9ae987fc6e85689f08278cf0c69d4bf3"},
+ {file = "Markdown-3.5.tar.gz", hash = "sha256:a807eb2e4778d9156c8f07876c6e4d50b5494c5665c4834f67b06459dfd877b3"},
]
[package.dependencies]
importlib-metadata = {version = ">=4.4", markers = "python_version < \"3.10\""}
[package.extras]
-docs = ["mdx-gh-links (>=0.2)", "mkdocs (>=1.0)", "mkdocs-nature (>=0.4)"]
+docs = ["mdx-gh-links (>=0.2)", "mkdocs (>=1.5)", "mkdocs-gen-files", "mkdocs-literate-nav", "mkdocs-nature (>=0.6)", "mkdocs-section-index", "mkdocstrings[python]"]
testing = ["coverage", "pyyaml"]
[[package]]
@@ -658,13 +665,13 @@ mkdocs = ">=1.1"
[[package]]
name = "mkdocs-material"
-version = "9.4.3"
+version = "9.4.4"
description = "Documentation that simply works"
optional = false
python-versions = ">=3.8"
files = [
- {file = "mkdocs_material-9.4.3-py3-none-any.whl", hash = "sha256:3274a47a4e55a541b25bd8fa4937cf3f3c82a51763453511661e0052062758b9"},
- {file = "mkdocs_material-9.4.3.tar.gz", hash = "sha256:5c9abc3f6ba8f88be1f9f13df23d695ca4dddbdd8a3538e4e6279c055c3936bc"},
+ {file = "mkdocs_material-9.4.4-py3-none-any.whl", hash = "sha256:86fe79253afccc7f085f89a2d8e9e3300f82c4813d9b910d9081ce57a7e68380"},
+ {file = "mkdocs_material-9.4.4.tar.gz", hash = "sha256:ab84a7cfaf009c47cd2926cdd7e6040b8cc12c3806cc533e8b16d57bd16d9c47"},
]
[package.dependencies]
@@ -934,8 +941,8 @@ files = [
[package.dependencies]
numpy = [
- {version = ">=1.22.4", markers = "python_version < \"3.11\""},
{version = ">=1.23.2", markers = "python_version == \"3.11\""},
+ {version = ">=1.22.4", markers = "python_version < \"3.11\""},
]
python-dateutil = ">=2.8.2"
pytz = ">=2020.1"
@@ -1622,13 +1629,13 @@ watchmedo = ["PyYAML (>=3.10)"]
[[package]]
name = "xarray"
-version = "2023.8.0"
+version = "2023.9.0"
description = "N-D labeled arrays and datasets in Python"
optional = true
python-versions = ">=3.9"
files = [
- {file = "xarray-2023.8.0-py3-none-any.whl", hash = "sha256:eb42b56aea2c7d5db2a7d0c33fb005b78eb5c4421eb747f2ced138c70b5c204e"},
- {file = "xarray-2023.8.0.tar.gz", hash = "sha256:825c6d64202a731a4e49321edd1e9dfabf4be06802f1b8c8a3c00a3ebfc8cedf"},
+ {file = "xarray-2023.9.0-py3-none-any.whl", hash = "sha256:3fc4a558bd70968040a4e1cefc6ddb3f9a7a86ef6a48e67857156ffe655d3a66"},
+ {file = "xarray-2023.9.0.tar.gz", hash = "sha256:271955c05dc626dad37791a7807d920aaf9c64cac71d03b45ec7e402cc646603"},
]
[package.dependencies]
@@ -1638,8 +1645,7 @@ pandas = ">=1.4"
[package.extras]
accel = ["bottleneck", "flox", "numbagg", "scipy"]
-complete = ["bottleneck", "cftime", "dask[complete]", "flox", "fsspec", "h5netcdf", "matplotlib", "nc-time-axis", "netCDF4", "numbagg", "pooch", "pydap", "scipy", "seaborn", "zarr"]
-docs = ["bottleneck", "cftime", "dask[complete]", "flox", "fsspec", "h5netcdf", "ipykernel", "ipython", "jupyter-client", "matplotlib", "nbsphinx", "nc-time-axis", "netCDF4", "numbagg", "pooch", "pydap", "scanpydoc", "scipy", "seaborn", "sphinx-autosummary-accessors", "sphinx-rtd-theme", "zarr"]
+complete = ["xarray[accel,io,parallel,viz]"]
io = ["cftime", "fsspec", "h5netcdf", "netCDF4", "pooch", "pydap", "scipy", "zarr"]
parallel = ["dask[complete]"]
viz = ["matplotlib", "nc-time-axis", "seaborn"]
@@ -1692,4 +1698,4 @@ all = ["jax", "jaxlib", "numpy", "pandas", "polars", "torch", "xarray"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.9,<3.12"
-content-hash = "c0637f6b1a5db29cf5c7b4c7939e741d5cbcece7e11349edddb9d7ac1d3852b9"
+content-hash = "44d92d6398383fbe5fb5ba2121cc2cdaa56195963211d4e65feec5ac683bca5f"
diff --git a/pyproject.toml b/pyproject.toml
index 4c9be07d..5f480980 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -39,7 +39,7 @@ pandas = { version = ">=1.3,<2.2", optional = true }
# polars: 0.18.3 is the minimal version because of https://github.com/pola-rs/polars/issues/9358
polars = { version = ">=0.18.3,<0.20", optional = true }
torch = { version = ">=1.10,<2.2", optional = true }
-xarray = { version = ">=2023.3,<2023.9", optional = true }
+xarray = { version = ">=2023.3,<2023.10", optional = true }
[tool.poetry.extras]
all = ["jax", "jaxlib", "numpy", "pandas", "polars", "torch", "xarray"]
diff --git a/src/coola/utils/tensor.py b/src/coola/utils/tensor.py
index 94d6b23c..c9a1ab4d 100644
--- a/src/coola/utils/tensor.py
+++ b/src/coola/utils/tensor.py
@@ -2,6 +2,7 @@
__all__ = ["get_available_devices", "is_cuda_available", "is_mps_available"]
+from functools import lru_cache
from unittest.mock import Mock
from coola.utils.imports import is_torch_available
@@ -12,6 +13,7 @@
torch = Mock()
+@lru_cache(1)
def get_available_devices() -> tuple[str, ...]:
r"""Gets the available PyTorch devices on the machine.
@@ -35,6 +37,7 @@ def get_available_devices() -> tuple[str, ...]:
return tuple(devices)
+@lru_cache(1)
def is_cuda_available() -> bool:
r"""Indicates if CUDA is currently available.
@@ -52,6 +55,7 @@ def is_cuda_available() -> bool:
return is_torch_available() and torch.cuda.is_available()
+@lru_cache(1)
def is_mps_available() -> bool:
r"""Indicates if MPS is currently available.
@@ -66,8 +70,10 @@ def is_mps_available() -> bool:
>>> from coola.utils.tensor import is_mps_available
>>> is_mps_available()
"""
- return (
- is_torch_available()
- and hasattr(torch.backends, "mps")
- and torch.backends.mps.is_available()
- )
+ if not is_torch_available():
+ return False
+ try:
+ torch.ones(1, device="mps")
+ return True
+ except RuntimeError:
+ return False
diff --git a/tests/unit/utils/test_tensor.py b/tests/unit/utils/test_tensor.py
index 141b6343..3b71a4dc 100644
--- a/tests/unit/utils/test_tensor.py
+++ b/tests/unit/utils/test_tensor.py
@@ -1,12 +1,25 @@
-from unittest.mock import patch
+from __future__ import annotations
+
+from unittest.mock import Mock, patch
+
+from pytest import fixture
from coola.testing import torch_available
from coola.utils.tensor import (
get_available_devices,
is_cuda_available,
is_mps_available,
+ torch,
)
+
+@fixture(autouse=True)
+def reset() -> None:
+ get_available_devices.cache_clear()
+ is_cuda_available.cache_clear()
+ is_mps_available.cache_clear()
+
+
###########################################
# Tests for get_available_devices #
###########################################
@@ -64,11 +77,35 @@ def test_is_cuda_available_false() -> None:
assert not is_cuda_available()
+@patch("coola.utils.tensor.is_torch_available", lambda *args, **kwargs: False)
+def test_is_cuda_available_no_torch() -> None:
+ assert not is_cuda_available()
+
+
######################################
-# Tests for is_mpa_available #
+# Tests for is_mps_available #
######################################
@torch_available
def test_is_mps_available() -> None:
assert isinstance(is_mps_available(), bool)
+
+
+@torch_available
+@patch("coola.utils.tensor.is_torch_available", lambda *args, **kwargs: True)
+def test_is_mps_available_with_mps() -> None:
+ with patch("coola.utils.tensor.torch.ones", Mock(return_value=torch.ones(1))):
+ assert is_mps_available()
+
+
+@torch_available
+@patch("coola.utils.tensor.is_torch_available", lambda *args, **kwargs: True)
+def test_is_mps_available_without_mps() -> None:
+ with patch("coola.utils.tensor.torch.ones", Mock(side_effect=RuntimeError)):
+ assert not is_mps_available()
+
+
+@patch("coola.utils.tensor.is_torch_available", lambda *args, **kwargs: False)
+def test_is_mps_available_no_torch() -> None:
+ assert not is_mps_available()