test: add error handling test for find_batch_size() and improve model loading
qin-yu committed Nov 19, 2024
1 parent c75a401 commit 4fcd3f7
Showing 1 changed file with 19 additions and 3 deletions.
tests/functionals/prediction/test_size_finder.py (19 additions & 3 deletions)

@@ -4,9 +4,15 @@
 import torch
 
 from plantseg.core.zoo import model_zoo
-from plantseg.functionals.prediction.utils.size_finder import find_a_max_patch_shape, find_patch_and_halo_shapes
+from plantseg.functionals.prediction.utils.size_finder import (
+    find_a_max_patch_shape,
+    find_batch_size,
+    find_patch_and_halo_shapes,
+)
 
-IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true"
+IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true"  # set to true in GitHub Actions by default to skip CUDA tests
+DOWNLOAD_MODELS = os.getenv("DOWNLOAD_MODELS") == "true"  # set to false in local testing to skip downloading models
+LARGE_VRAM_GPUS = ['NVIDIA A100', 'NVIDIA A40']
 MAX_PATCH_SHAPES = {
     'generic_confocal_3D_unet': {
         "NVIDIA GeForce RTX 2080 Ti": (208, 208, 208),
@@ -51,7 +57,17 @@ def test_find_patch_and_halo_shapes(full_volume_shape, max_patch_shape, min_halo
 )
 @pytest.mark.parametrize("model_name", MAX_PATCH_SHAPES.keys())
 def test_find_patch_shape(model_name):
-    model, _, _ = model_zoo.get_model_by_name(model_name, model_update=False)
+    model, _, _ = model_zoo.get_model_by_name(model_name, model_update=DOWNLOAD_MODELS)
     found_patch_shape = find_a_max_patch_shape(model, 1, "cuda:0")
     expected_patch_shape = MAX_PATCH_SHAPES[model_name][GPU_DEVICE_NAME]
     assert found_patch_shape == expected_patch_shape
+
+
+@pytest.mark.skipif(
+    not any(gpu in GPU_DEVICE_NAME for gpu in LARGE_VRAM_GPUS),
+    reason="Test requires a large VRAM device (e.g., NVIDIA A100 or NVIDIA A40).",
+)
+def test_find_batch_size_error_handling():
+    model, _, _ = model_zoo.get_model_by_name('confocal_3D_unet_ovules_ds3x', model_update=DOWNLOAD_MODELS)
+    batch_size = find_batch_size(model, 1, (86, 395, 395), (0, 44, 44), "cuda:0")
+    assert batch_size == 1
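Note (added here, not part of the commit): the new test pins the expected batch size to 1 for an oversized patch, which is consistent with a try-and-back-off search that catches CUDA out-of-memory errors. Below is a minimal sketch of that pattern, assuming the model is a PyTorch module taking (batch, channels, *spatial) input and that the halo pads each spatial dimension on both sides; the function name and body are illustrative, not PlantSeg's actual find_batch_size implementation:

```python
import torch

def find_batch_size_sketch(model, in_channels, patch_shape, halo_shape, device):
    """Illustrative only: double the batch until a forward pass runs out of
    memory, then return the last size that fit (never less than 1)."""
    # Assumption: the halo is added to both sides of each spatial dimension.
    full_shape = tuple(p + 2 * h for p, h in zip(patch_shape, halo_shape))
    model = model.to(device).eval()
    best = 0
    batch_size = 1
    while batch_size <= 64:
        try:
            with torch.no_grad():
                x = torch.zeros((batch_size, in_channels, *full_shape), device=device)
                model(x)
            best = batch_size
            batch_size *= 2
        except torch.cuda.OutOfMemoryError:
            # The error handling under test: recover from OOM instead of
            # crashing, and fall back to the smallest batch if nothing fit.
            torch.cuda.empty_cache()
            break
    return max(best, 1)
```

With the (86, 395, 395) patch used in the test, even a single sample nearly fills a large-VRAM card, so a search like this would stop immediately and return 1, matching the assertion above.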
