diff --git a/.github/workflows/nightly-package.yaml b/.github/workflows/nightly-package.yaml index 11be073c..cd2f58ea 100644 --- a/.github/workflows/nightly-package.yaml +++ b/.github/workflows/nightly-package.yaml @@ -94,7 +94,7 @@ jobs: matrix: os: [ ubuntu-latest ] python-version: [ '3.9', '3.10', '3.11' ] - jax-version: [ 0.4.17, 0.3.25 ] + jax-version: [ 0.4.18, 0.3.25 ] steps: - name: Checkout @@ -258,7 +258,7 @@ jobs: matrix: os: [ ubuntu-latest ] python-version: [ '3.9', '3.10', '3.11' ] - xarray-version: [ '2023.8', '2023.7', '2023.6', '2023.5', '2023.4', '2023.3' ] + xarray-version: [ '2023.9', '2023.8', '2023.7', '2023.6', '2023.5', '2023.4', '2023.3' ] steps: - name: Checkout @@ -278,45 +278,5 @@ jobs: python -c "import coola; import xarray as xr; import numpy as np; " \ "assert coola.objects_are_equal(xr.DataArray(np.arange(6), dims=["z"]), xr.DataArray(np.arange(6), dims=["z"]))" - - cyclic-import: - runs-on: ${{ matrix.os }} - timeout-minutes: 10 - strategy: - max-parallel: 8 - fail-fast: false - matrix: - os: [ ubuntu-latest ] - python-version: [ '3.10' ] - - steps: - - name: Checkout - uses: actions/checkout@v3 - - - name: Set up Python - uses: actions/setup-python@v4 - with: - python-version: ${{ matrix.python-version }} - - - name: Install package - run: | - pip install "coola[all]" - - - name: check coola.comparators - run: | - python -c "from coola import comparators" - - name: check coola.formatters - run: | - python -c "from coola import formatters" - - name: check coola.reducers - run: | - python -c "from coola import reducers" - - name: check coola.summarizers - run: | - python -c "from coola import summarizers" - - name: check coola.testers - run: | - python -c "from coola import testers" - - name: check coola.utils - run: | - python -c "from coola import utils" + cyclic-imports: + uses: ./.github/workflows/cyclic-imports.yaml diff --git a/.github/workflows/test-deps.yaml b/.github/workflows/test-deps.yaml index 8b0e855c..dd5bf07b 100644 --- a/.github/workflows/test-deps.yaml +++ b/.github/workflows/test-deps.yaml @@ -13,7 +13,7 @@ jobs: matrix: os: [ ubuntu-latest ] python-version: [ '3.9', '3.10', '3.11' ] - jax-version: [ 0.4.17, 0.3.25 ] + jax-version: [ 0.4.18, 0.3.25 ] steps: - name: Checkout @@ -240,7 +240,7 @@ jobs: matrix: os: [ ubuntu-latest ] python-version: [ '3.9', '3.10', '3.11' ] - xarray-version: [ '2023.8', '2023.7', '2023.6', '2023.5', '2023.4', '2023.3' ] + xarray-version: [ '2023.9', '2023.8', '2023.7', '2023.6', '2023.5', '2023.4', '2023.3' ] steps: - name: Checkout diff --git a/README.md b/README.md index 4ab6b37f..ec9a367e 100644 --- a/README.md +++ b/README.md @@ -142,7 +142,7 @@ The following is the corresponding `coola` versions and supported dependencies. | `coola` | `jax`* | `numpy`* | `pandas`* | `polars`* | `torch`* | `xarray`* | `python` | |----------|-------------------|---------------------|----------------------|----------------------|---------------------|----------------------|---------------| -| `0.0.25` | `>=0.3,<0.5` | `>=1.21,<1.27` | `>=1.3,<2.2` | `>=0.18.3,<0.20` | `>=1.10,<2.2` | `>=2023.3,<2023.9` | `>=3.9,<3.12` | +| `0.0.25` | `>=0.3,<0.5` | `>=1.21,<1.27` | `>=1.3,<2.2` | `>=0.18.3,<0.20` | `>=1.10,<2.2` | `>=2023.3,<2023.10` | `>=3.9,<3.12` | | `0.0.24` | `>=0.3,<0.5` | `>=1.21,<1.27` | `>=1.3,<2.2` | `>=0.18.3,<0.20` | `>=1.10,<2.2` | `>=2023.3,<2023.9` | `>=3.9,<3.12` | | `0.0.23` | `>=0.3,<0.5` | `>=1.21,<1.27` | `>=1.3,<2.2` | `>=0.18.3,<0.20` | `>=1.10,<2.1` | `>=2023.3,<2023.9` | `>=3.9,<3.12` | | `0.0.22` | `>=0.3,<0.5` | `>=1.20,<1.26` | `>=1.3,<2.1` | `>=0.18.3,<0.19` | `>=1.10,<2.1` | `>=2023.3,<2023.9` | `>=3.9,<3.12` | diff --git a/poetry.lock b/poetry.lock index 692751de..b4886381 100644 --- a/poetry.lock +++ b/poetry.lock @@ -442,58 +442,65 @@ files = [ [[package]] name = "jax" -version = "0.4.17" +version = "0.4.18" description = "Differentiate, compile, and transform Numpy code." optional = true python-versions = ">=3.9" files = [ - {file = "jax-0.4.17-py3-none-any.whl", hash = "sha256:c3ab72ea2f1c5d8ccf2561e79f6562fb2964629f3e55b3ac1c11c48b64c20336"}, - {file = "jax-0.4.17.tar.gz", hash = "sha256:d7508a69e87835f534cb07a2f21d79cc1cb8c4cfdcf7fb010927267ef7355f1d"}, + {file = "jax-0.4.18-py3-none-any.whl", hash = "sha256:2ded3f558b74593c3533036a90c20d41ea35f35c74b25ca0fc86f4aafc388746"}, + {file = "jax-0.4.18.tar.gz", hash = "sha256:776cf33890100803e98f45f9af10aa727271c6993d4e766c069118733c928132"}, ] [package.dependencies] importlib-metadata = {version = ">=4.6", markers = "python_version < \"3.10\""} ml-dtypes = ">=0.2.0" -numpy = ">=1.22" +numpy = [ + {version = ">=1.23.2", markers = "python_version >= \"3.11\""}, + {version = ">=1.22", markers = "python_version < \"3.11\""}, +] opt-einsum = "*" scipy = ">=1.7" [package.extras] australis = ["protobuf (>=3.13,<4)"] -ci = ["jaxlib (==0.4.16)"] -cpu = ["jaxlib (==0.4.17)"] -cuda = ["jaxlib (==0.4.17+cuda11.cudnn86)"] -cuda11-cudnn86 = ["jaxlib (==0.4.17+cuda11.cudnn86)"] -cuda11-local = ["jaxlib (==0.4.17+cuda11.cudnn86)"] -cuda11-pip = ["jaxlib (==0.4.17+cuda11.cudnn86)", "nvidia-cublas-cu11 (>=11.11)", "nvidia-cuda-cupti-cu11 (>=11.8)", "nvidia-cuda-nvcc-cu11 (>=11.8)", "nvidia-cuda-runtime-cu11 (>=11.8)", "nvidia-cudnn-cu11 (>=8.8)", "nvidia-cufft-cu11 (>=10.9)", "nvidia-cusolver-cu11 (>=11.4)", "nvidia-cusparse-cu11 (>=11.7)"] -cuda12-local = ["jaxlib (==0.4.17+cuda12.cudnn89)"] -cuda12-pip = ["jaxlib (==0.4.17+cuda12.cudnn89)", "nvidia-cublas-cu12 (>=12.2.5.6)", "nvidia-cuda-cupti-cu12 (>=12.2.142)", "nvidia-cuda-nvcc-cu12 (>=12.2.140)", "nvidia-cuda-runtime-cu12 (>=12.2.140)", "nvidia-cudnn-cu12 (>=8.9)", "nvidia-cufft-cu12 (>=11.0.8.103)", "nvidia-cusolver-cu12 (>=11.5.2)", "nvidia-cusparse-cu12 (>=12.1.2.141)"] +ci = ["jaxlib (==0.4.17)"] +cpu = ["jaxlib (==0.4.18)"] +cuda = ["jaxlib (==0.4.18+cuda11.cudnn86)"] +cuda11-cudnn86 = ["jaxlib (==0.4.18+cuda11.cudnn86)"] +cuda11-local = ["jaxlib (==0.4.18+cuda11.cudnn86)"] +cuda11-pip = ["jaxlib (==0.4.18+cuda11.cudnn86)", "nvidia-cublas-cu11 (>=11.11)", "nvidia-cuda-cupti-cu11 (>=11.8)", "nvidia-cuda-nvcc-cu11 (>=11.8)", "nvidia-cuda-runtime-cu11 (>=11.8)", "nvidia-cudnn-cu11 (>=8.8)", "nvidia-cufft-cu11 (>=10.9)", "nvidia-cusolver-cu11 (>=11.4)", "nvidia-cusparse-cu11 (>=11.7)", "nvidia-nccl-cu11 (>=2.18.3)"] +cuda12-local = ["jaxlib (==0.4.18+cuda12.cudnn89)"] +cuda12-pip = ["jaxlib (==0.4.18+cuda12.cudnn89)", "nvidia-cublas-cu12 (>=12.2.5.6)", "nvidia-cuda-cupti-cu12 (>=12.2.142)", "nvidia-cuda-nvcc-cu12 (>=12.2.140)", "nvidia-cuda-runtime-cu12 (>=12.2.140)", "nvidia-cudnn-cu12 (>=8.9)", "nvidia-cufft-cu12 (>=11.0.8.103)", "nvidia-cusolver-cu12 (>=11.5.2)", "nvidia-cusparse-cu12 (>=12.1.2.141)", "nvidia-nccl-cu12 (>=2.18.3)"] minimum-jaxlib = ["jaxlib (==0.4.14)"] -tpu = ["jaxlib (==0.4.17)", "libtpu-nightly (==0.1.dev20231003)", "requests"] +tpu = ["jaxlib (==0.4.18)", "libtpu-nightly (==0.1.dev20231006)", "requests"] [[package]] name = "jaxlib" -version = "0.4.17" +version = "0.4.18" description = "XLA library for JAX" optional = true python-versions = ">=3.9" files = [ - {file = "jaxlib-0.4.17-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:d4be1ac4bf1be1ae1cd8f5f4da414a6d0de8de36cf2effdb5758d4d677896078"}, - {file = "jaxlib-0.4.17-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:392c779f902c43e1a0af49159daffef9b5af952aba001463f98cf95a59ef17ff"}, - {file = "jaxlib-0.4.17-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:160fce68b82a79a6c522652e8dd9a10aac9c00d1599cb7e166671ad909aa139e"}, - {file = "jaxlib-0.4.17-cp310-cp310-win_amd64.whl", hash = "sha256:61b3788c6cfe46f307e6e67d4a942de72cf34711ff349f4f11500cdf6dc67199"}, - {file = "jaxlib-0.4.17-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:6c3524150bd85098f291fac81f73e285f3e095dbbb49751647cc27bed5327a78"}, - {file = "jaxlib-0.4.17-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2e0d84a756b47ef6db52a6532b1f242cb8dc9035c102c60075470d65e71f7afb"}, - {file = "jaxlib-0.4.17-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:e04a0b8bb18ac24e25c15ed03be771815566f118c16f585ffe2e0f75bf7c064d"}, - {file = "jaxlib-0.4.17-cp311-cp311-win_amd64.whl", hash = "sha256:73173f1aff8d277110d32bdd5e073dc7d50e6618b5567b3bfbc53864b4613439"}, - {file = "jaxlib-0.4.17-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:909971337aabf5f2724a84c3166cea454b37024908d830695dc6b4ba4440676f"}, - {file = "jaxlib-0.4.17-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:98d42a402201fd0cb332bad4177b20c942d1acd1487581ee0c3cb5ef6766531a"}, - {file = "jaxlib-0.4.17-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:593aa9d1c940b9215968878561ad59feee2438386c3868d6524ff4ca730cfdf1"}, - {file = "jaxlib-0.4.17-cp312-cp312-win_amd64.whl", hash = "sha256:a4384cc7187f4f10749c6c623211d1a6b55575f921c00af38ff8f05fd3f7ecfd"}, - {file = "jaxlib-0.4.17-cp39-cp39-macosx_10_14_x86_64.whl", hash = "sha256:84b6cd54143ffe2ce45d5bcf2f9eafa1f9b4cf51dab0cc8e7703622fb624549d"}, - {file = "jaxlib-0.4.17-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a521d8323ef4d8155efc11b788f29dd3794cbb80f83533da957921a058fd3abe"}, - {file = "jaxlib-0.4.17-cp39-cp39-manylinux2014_x86_64.whl", hash = "sha256:c933a6fb74f9fc16a2610566b32ff0135077ae9032e50f695653377d2cfbc9e5"}, - {file = "jaxlib-0.4.17-cp39-cp39-win_amd64.whl", hash = "sha256:44a2bff9966fe3b5783595d3214d3598bea48a2aa502de9d7d44c2ca39426929"}, + {file = "jaxlib-0.4.18-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:e0d78703fd1219d9875f20c6c692bba0973b744d66791e0b3e3cdb230c65a2a5"}, + {file = "jaxlib-0.4.18-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:89fff93b90d054715db0bc3d3b572b799071e63f1fb44edfb1630c5f53631cfd"}, + {file = "jaxlib-0.4.18-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:43287e8ece61f69b1d2a13d7f5e4d540c6edf4ab60bc1606b2b5f9321a9e8471"}, + {file = "jaxlib-0.4.18-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:3293689a8bef495c7837a82ca3038b92c4f21204cadbad6f497306c58aa554e1"}, + {file = "jaxlib-0.4.18-cp310-cp310-win_amd64.whl", hash = "sha256:a72ee7baf663ed5b9c6a426c1919f3755d4d71d4db6abf1979e6ce1f2451fb7f"}, + {file = "jaxlib-0.4.18-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:02895bc15ec578d3bdbdf2c3a2195852d45870f611b26e2e7261cc6b0353a928"}, + {file = "jaxlib-0.4.18-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e73b17ac3a6a3e034bca5e5752b1bf035a2eb50ced721a4651b287f5e2c672f7"}, + {file = "jaxlib-0.4.18-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:f3a8ce7096b3eadd531773c5ef7f7c3bb7552cdf163d682ccd0b0c7f7240d109"}, + {file = "jaxlib-0.4.18-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:4ce7b001fb070e2b7926553bedb9490b7671b3dcc176fd7b52df3932c3593cb0"}, + {file = "jaxlib-0.4.18-cp311-cp311-win_amd64.whl", hash = "sha256:055950e663fdc101b544597c1361596ff82816575640685effd4779949c8cf06"}, + {file = "jaxlib-0.4.18-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:85572a9fa84a17cffd05b771d528012297be4c0e227a07e7dd082c15094749b8"}, + {file = "jaxlib-0.4.18-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:c429a15165b6b5ded5b0c46c5861d0b978a82aaa2200b2e517366d220f3f01ee"}, + {file = "jaxlib-0.4.18-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:dfe84a294ab3de2557c49a48c0d83c555018a5190aefa1134d5cb3219865edcd"}, + {file = "jaxlib-0.4.18-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:ed1ba86c1a2adea8235269f3e1f5561696068ca60c68b8e5a6f0eb3b978a305a"}, + {file = "jaxlib-0.4.18-cp312-cp312-win_amd64.whl", hash = "sha256:f0d5414bc79bdd667b81ee3c5836641bbd52d6d9c0054043dec8e025857e1260"}, + {file = "jaxlib-0.4.18-cp39-cp39-macosx_10_14_x86_64.whl", hash = "sha256:9593ff69f424947567e206f3e356b2a2df55ca68e6d815d5adc6cae308e8f652"}, + {file = "jaxlib-0.4.18-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:116a0d6aedd3e856b52493d7e392fb1b40952b84fb72448fde1c1ab5687db667"}, + {file = "jaxlib-0.4.18-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:6cb20bbbdafd90e71ad0deb9295519a0175c108c8c557b84fb9fe94f751daee4"}, + {file = "jaxlib-0.4.18-cp39-cp39-manylinux2014_x86_64.whl", hash = "sha256:f7787a5531d226d6cc9ec2baa7141260bb713435e1cfc053cb9f5cefa9756ac3"}, + {file = "jaxlib-0.4.18-cp39-cp39-win_amd64.whl", hash = "sha256:4771e8439c48d1c3cf65e01016da02c6592a310bf973c9609fc3be7df9b49b22"}, ] [package.dependencies] @@ -524,20 +531,20 @@ i18n = ["Babel (>=2.7)"] [[package]] name = "markdown" -version = "3.4.4" +version = "3.5" description = "Python implementation of John Gruber's Markdown." optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "Markdown-3.4.4-py3-none-any.whl", hash = "sha256:a4c1b65c0957b4bd9e7d86ddc7b3c9868fb9670660f6f99f6d1bca8954d5a941"}, - {file = "Markdown-3.4.4.tar.gz", hash = "sha256:225c6123522495d4119a90b3a3ba31a1e87a70369e03f14799ea9c0d7183a3d6"}, + {file = "Markdown-3.5-py3-none-any.whl", hash = "sha256:4afb124395ce5fc34e6d9886dab977fd9ae987fc6e85689f08278cf0c69d4bf3"}, + {file = "Markdown-3.5.tar.gz", hash = "sha256:a807eb2e4778d9156c8f07876c6e4d50b5494c5665c4834f67b06459dfd877b3"}, ] [package.dependencies] importlib-metadata = {version = ">=4.4", markers = "python_version < \"3.10\""} [package.extras] -docs = ["mdx-gh-links (>=0.2)", "mkdocs (>=1.0)", "mkdocs-nature (>=0.4)"] +docs = ["mdx-gh-links (>=0.2)", "mkdocs (>=1.5)", "mkdocs-gen-files", "mkdocs-literate-nav", "mkdocs-nature (>=0.6)", "mkdocs-section-index", "mkdocstrings[python]"] testing = ["coverage", "pyyaml"] [[package]] @@ -658,13 +665,13 @@ mkdocs = ">=1.1" [[package]] name = "mkdocs-material" -version = "9.4.3" +version = "9.4.4" description = "Documentation that simply works" optional = false python-versions = ">=3.8" files = [ - {file = "mkdocs_material-9.4.3-py3-none-any.whl", hash = "sha256:3274a47a4e55a541b25bd8fa4937cf3f3c82a51763453511661e0052062758b9"}, - {file = "mkdocs_material-9.4.3.tar.gz", hash = "sha256:5c9abc3f6ba8f88be1f9f13df23d695ca4dddbdd8a3538e4e6279c055c3936bc"}, + {file = "mkdocs_material-9.4.4-py3-none-any.whl", hash = "sha256:86fe79253afccc7f085f89a2d8e9e3300f82c4813d9b910d9081ce57a7e68380"}, + {file = "mkdocs_material-9.4.4.tar.gz", hash = "sha256:ab84a7cfaf009c47cd2926cdd7e6040b8cc12c3806cc533e8b16d57bd16d9c47"}, ] [package.dependencies] @@ -934,8 +941,8 @@ files = [ [package.dependencies] numpy = [ - {version = ">=1.22.4", markers = "python_version < \"3.11\""}, {version = ">=1.23.2", markers = "python_version == \"3.11\""}, + {version = ">=1.22.4", markers = "python_version < \"3.11\""}, ] python-dateutil = ">=2.8.2" pytz = ">=2020.1" @@ -1622,13 +1629,13 @@ watchmedo = ["PyYAML (>=3.10)"] [[package]] name = "xarray" -version = "2023.8.0" +version = "2023.9.0" description = "N-D labeled arrays and datasets in Python" optional = true python-versions = ">=3.9" files = [ - {file = "xarray-2023.8.0-py3-none-any.whl", hash = "sha256:eb42b56aea2c7d5db2a7d0c33fb005b78eb5c4421eb747f2ced138c70b5c204e"}, - {file = "xarray-2023.8.0.tar.gz", hash = "sha256:825c6d64202a731a4e49321edd1e9dfabf4be06802f1b8c8a3c00a3ebfc8cedf"}, + {file = "xarray-2023.9.0-py3-none-any.whl", hash = "sha256:3fc4a558bd70968040a4e1cefc6ddb3f9a7a86ef6a48e67857156ffe655d3a66"}, + {file = "xarray-2023.9.0.tar.gz", hash = "sha256:271955c05dc626dad37791a7807d920aaf9c64cac71d03b45ec7e402cc646603"}, ] [package.dependencies] @@ -1638,8 +1645,7 @@ pandas = ">=1.4" [package.extras] accel = ["bottleneck", "flox", "numbagg", "scipy"] -complete = ["bottleneck", "cftime", "dask[complete]", "flox", "fsspec", "h5netcdf", "matplotlib", "nc-time-axis", "netCDF4", "numbagg", "pooch", "pydap", "scipy", "seaborn", "zarr"] -docs = ["bottleneck", "cftime", "dask[complete]", "flox", "fsspec", "h5netcdf", "ipykernel", "ipython", "jupyter-client", "matplotlib", "nbsphinx", "nc-time-axis", "netCDF4", "numbagg", "pooch", "pydap", "scanpydoc", "scipy", "seaborn", "sphinx-autosummary-accessors", "sphinx-rtd-theme", "zarr"] +complete = ["xarray[accel,io,parallel,viz]"] io = ["cftime", "fsspec", "h5netcdf", "netCDF4", "pooch", "pydap", "scipy", "zarr"] parallel = ["dask[complete]"] viz = ["matplotlib", "nc-time-axis", "seaborn"] @@ -1692,4 +1698,4 @@ all = ["jax", "jaxlib", "numpy", "pandas", "polars", "torch", "xarray"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.12" -content-hash = "c0637f6b1a5db29cf5c7b4c7939e741d5cbcece7e11349edddb9d7ac1d3852b9" +content-hash = "44d92d6398383fbe5fb5ba2121cc2cdaa56195963211d4e65feec5ac683bca5f" diff --git a/pyproject.toml b/pyproject.toml index 4c9be07d..5f480980 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,7 +39,7 @@ pandas = { version = ">=1.3,<2.2", optional = true } # polars: 0.18.3 is the minimal version because of https://github.com/pola-rs/polars/issues/9358 polars = { version = ">=0.18.3,<0.20", optional = true } torch = { version = ">=1.10,<2.2", optional = true } -xarray = { version = ">=2023.3,<2023.9", optional = true } +xarray = { version = ">=2023.3,<2023.10", optional = true } [tool.poetry.extras] all = ["jax", "jaxlib", "numpy", "pandas", "polars", "torch", "xarray"] diff --git a/src/coola/utils/tensor.py b/src/coola/utils/tensor.py index 94d6b23c..c9a1ab4d 100644 --- a/src/coola/utils/tensor.py +++ b/src/coola/utils/tensor.py @@ -2,6 +2,7 @@ __all__ = ["get_available_devices", "is_cuda_available", "is_mps_available"] +from functools import lru_cache from unittest.mock import Mock from coola.utils.imports import is_torch_available @@ -12,6 +13,7 @@ torch = Mock() +@lru_cache(1) def get_available_devices() -> tuple[str, ...]: r"""Gets the available PyTorch devices on the machine. @@ -35,6 +37,7 @@ def get_available_devices() -> tuple[str, ...]: return tuple(devices) +@lru_cache(1) def is_cuda_available() -> bool: r"""Indicates if CUDA is currently available. @@ -52,6 +55,7 @@ def is_cuda_available() -> bool: return is_torch_available() and torch.cuda.is_available() +@lru_cache(1) def is_mps_available() -> bool: r"""Indicates if MPS is currently available. @@ -66,8 +70,10 @@ def is_mps_available() -> bool: >>> from coola.utils.tensor import is_mps_available >>> is_mps_available() """ - return ( - is_torch_available() - and hasattr(torch.backends, "mps") - and torch.backends.mps.is_available() - ) + if not is_torch_available(): + return False + try: + torch.ones(1, device="mps") + return True + except RuntimeError: + return False diff --git a/tests/unit/utils/test_tensor.py b/tests/unit/utils/test_tensor.py index 141b6343..3b71a4dc 100644 --- a/tests/unit/utils/test_tensor.py +++ b/tests/unit/utils/test_tensor.py @@ -1,12 +1,25 @@ -from unittest.mock import patch +from __future__ import annotations + +from unittest.mock import Mock, patch + +from pytest import fixture from coola.testing import torch_available from coola.utils.tensor import ( get_available_devices, is_cuda_available, is_mps_available, + torch, ) + +@fixture(autouse=True) +def reset() -> None: + get_available_devices.cache_clear() + is_cuda_available.cache_clear() + is_mps_available.cache_clear() + + ########################################### # Tests for get_available_devices # ########################################### @@ -64,11 +77,35 @@ def test_is_cuda_available_false() -> None: assert not is_cuda_available() +@patch("coola.utils.tensor.is_torch_available", lambda *args, **kwargs: False) +def test_is_cuda_available_no_torch() -> None: + assert not is_cuda_available() + + ###################################### -# Tests for is_mpa_available # +# Tests for is_mps_available # ###################################### @torch_available def test_is_mps_available() -> None: assert isinstance(is_mps_available(), bool) + + +@torch_available +@patch("coola.utils.tensor.is_torch_available", lambda *args, **kwargs: True) +def test_is_mps_available_with_mps() -> None: + with patch("coola.utils.tensor.torch.ones", Mock(return_value=torch.ones(1))): + assert is_mps_available() + + +@torch_available +@patch("coola.utils.tensor.is_torch_available", lambda *args, **kwargs: True) +def test_is_mps_available_without_mps() -> None: + with patch("coola.utils.tensor.torch.ones", Mock(side_effect=RuntimeError)): + assert not is_mps_available() + + +@patch("coola.utils.tensor.is_torch_available", lambda *args, **kwargs: False) +def test_is_mps_available_no_torch() -> None: + assert not is_mps_available()