Skip to content

Commit

Permalink
test: add error handling tests for large GPU find_a_max_patch_shape()
Browse files Browse the repository at this point in the history
  • Loading branch information
qin-yu committed Nov 21, 2024
1 parent f18b9d5 commit a930ed8
Showing 1 changed file with 24 additions and 4 deletions.
28 changes: 24 additions & 4 deletions tests/functionals/prediction/test_size_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,20 @@

IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true" # set to true in GitHub Actions by default to skip CUDA tests
DOWNLOAD_MODELS = os.getenv("DOWNLOAD_MODELS") == "true" # set to false in locall testing to skip downloading models
LARGE_VRAM_GPUS = ['NVIDIA A100', 'NVIDIA A40']
LARGE_VRAM_GPUS = ['NVIDIA A100', 'NVIDIA A40'] # these two are not full names because A100 has multiple models
ALL_TESTED_GPUS = ["NVIDIA GeForce RTX 2080 Ti", "NVIDIA GeForce RTX 3090", "NVIDIA A100-PCIE-40GB", "NVIDIA A40"]
MAX_PATCH_SHAPES = {
'generic_confocal_3D_unet': {
"NVIDIA GeForce RTX 2080 Ti": (208, 208, 208),
"NVIDIA GeForce RTX 3090": (256, 256, 256),
"NVIDIA A100-PCIE-40GB": (272, 272, 272),
"NVIDIA A40": (272, 272, 272),
},
'confocal_2D_unet_ovules_ds2x': {
"NVIDIA GeForce RTX 2080 Ti": (1, 1920, 1920), # (1, 2048, 2048) if search step is 1.
"NVIDIA GeForce RTX 3090": (1, 2880, 2880), # (1, 2960, 2960) if search step is 1.
"NVIDIA A100-PCIE-40GB": (1, 3200, 3200),
"NVIDIA A40": (1, 3200, 3200),
},
}

Expand Down Expand Up @@ -52,7 +57,7 @@ def test_find_patch_and_halo_shapes(full_volume_shape, max_patch_shape, min_halo


@pytest.mark.skipif(
GPU_DEVICE_NAME not in ["NVIDIA GeForce RTX 2080 Ti", "NVIDIA GeForce RTX 3090"],
GPU_DEVICE_NAME not in ALL_TESTED_GPUS,
reason="Measured devices are not available.",
)
@pytest.mark.parametrize("model_name", MAX_PATCH_SHAPES.keys())
Expand All @@ -69,5 +74,20 @@ def test_find_patch_shape(model_name):
)
def test_find_batch_size_error_handling():
model, _, _ = model_zoo.get_model_by_name('confocal_3D_unet_ovules_ds3x', model_update=DOWNLOAD_MODELS)
batch_size = find_batch_size(model, 1, (86, 395, 395), (0, 44, 44), "cuda:0")
assert batch_size == 1
found_batch_size = find_batch_size(model, 1, (86, 395, 395), (0, 44, 44), "cuda:0")
assert found_batch_size == 1


@pytest.mark.skipif(
not any(gpu in GPU_DEVICE_NAME for gpu in LARGE_VRAM_GPUS),
reason="Test requires a large VRAM device (e.g., NVIDIA A100 or NVIDIA A40).",
)
def test_find_patch_shape_error_handling():
model, _, _ = model_zoo.get_model_by_name('PlantSeg_3Dnuc_platinum', model_update=DOWNLOAD_MODELS)
found_patch_shape = find_a_max_patch_shape(model, 1, "cuda:0")
if 'NVIDIA A100-PCIE-40GB' == GPU_DEVICE_NAME:
print('NVIDIA A100-PCIE-40GB tested')
assert found_patch_shape == (352, 352, 352)
if 'NVIDIA A40' == GPU_DEVICE_NAME:
print('NVIDIA A40 tested')
assert found_patch_shape == (352, 352, 352)

0 comments on commit a930ed8

Please sign in to comment.