From 753278786873a519334b625ba3b72da44a17084b Mon Sep 17 00:00:00 2001
From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com>
Date: Thu, 21 Nov 2024 14:55:11 +0000
Subject: [PATCH] :twisted_rightwards_arrows: Merge changes from #578

- Merge changes from #578, which will improve the performance of the main branch.
- It will also help simplify #578.

---
 .gitignore                                  |  3 +
 tests/models/test_arch_mapde.py             |  4 +-
 tests/models/test_arch_micronet.py          |  2 +-
 tests/models/test_arch_nuclick.py           |  3 +-
 tests/models/test_arch_sccnn.py             | 17 +++--
 tests/models/test_arch_unet.py              |  5 +-
 tests/models/test_arch_vanilla.py           | 11 ++--
 tests/models/test_hovernet.py               |  9 +--
 tests/models/test_hovernetplus.py           |  3 +-
 tests/test_annotation_stores.py             | 11 +---
 tests/test_annotation_tilerendering.py      |  1 +
 tests/test_init.py                          |  2 +-
 tests/test_utils.py                         | 21 +-----
 tests/test_wsimeta.py                       |  1 -
 tiatoolbox/annotation/storage.py            | 16 ++++-
 tiatoolbox/cli/common.py                    | 12 ++++
 tiatoolbox/cli/patch_predictor.py           |  8 +--
 tiatoolbox/models/architecture/hovernet.py  |  8 +--
 .../models/architecture/hovernetplus.py     |  8 +--
 tiatoolbox/models/architecture/mapde.py     |  8 +--
 tiatoolbox/models/architecture/micronet.py  |  8 +--
 tiatoolbox/models/architecture/nuclick.py   | 11 ++--
 tiatoolbox/models/architecture/sccnn.py     |  9 +--
 tiatoolbox/models/architecture/unet.py      |  8 +--
 tiatoolbox/models/architecture/utils.py     |  8 +--
 tiatoolbox/models/architecture/vanilla.py   | 66 +++++++++----------
 tiatoolbox/models/engine/patch_predictor.py | 13 ++--
 tiatoolbox/models/models_abc.py             | 38 +++++++++--
 whitelist.txt                               |  1 +
 29 files changed, 173 insertions(+), 142 deletions(-)

diff --git a/.gitignore b/.gitignore
index 409fc1261..66c072da5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -116,3 +116,6 @@ ENV/
 
 # vim/vi generated
 *.swp
+
+# output zarr generated
+*.zarr
diff --git a/tests/models/test_arch_mapde.py b/tests/models/test_arch_mapde.py
index febcfbdec..61bfde817 100644
--- a/tests/models/test_arch_mapde.py
+++ b/tests/models/test_arch_mapde.py
@@ -45,7 +45,7 @@ def test_functionality(remote_sample: Callable) -> None:
     model = _load_mapde(name="mapde-conic")
     patch = model.preproc(patch)
     batch = torch.from_numpy(patch)[None]
-    model = model.to(select_device(on_gpu=ON_GPU))
-    output = model.infer_batch(model, batch, on_gpu=ON_GPU)
+    model = model.to()
+    output = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU))
     output = model.postproc(output[0])
     assert np.all(output[0:2] == [[19, 171], [53, 89]])
diff --git a/tests/models/test_arch_micronet.py b/tests/models/test_arch_micronet.py
index cd4bd0833..e7aa23d5b 100644
--- a/tests/models/test_arch_micronet.py
+++ b/tests/models/test_arch_micronet.py
@@ -39,7 +39,7 @@ def test_functionality(
     model = model.to(map_location)
     pretrained = torch.load(weights_path, map_location=map_location)
     model.load_state_dict(pretrained)
-    output = model.infer_batch(model, batch, on_gpu=ON_GPU)
+    output = model.infer_batch(model, batch, device=map_location)
     output, _ = model.postproc(output[0])
     assert np.max(np.unique(output)) == 46
diff --git a/tests/models/test_arch_nuclick.py b/tests/models/test_arch_nuclick.py
index fda0c01a6..b84516125 100644
--- a/tests/models/test_arch_nuclick.py
+++ b/tests/models/test_arch_nuclick.py
@@ -10,6 +10,7 @@
 from tiatoolbox.models import NuClick
 from tiatoolbox.models.architecture import fetch_pretrained_weights
 from tiatoolbox.utils import imread
+from tiatoolbox.utils.misc import select_device
 
 ON_GPU = False
 
@@ -53,7 +54,7 @@ def test_functional_nuclick(
     model = NuClick(num_input_channels=5, num_output_channels=1)
     pretrained = torch.load(weights_path, map_location="cpu")
     model.load_state_dict(pretrained)
-    output = model.infer_batch(model, batch, on_gpu=ON_GPU)
+    output = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU))
     postproc_masks = model.postproc(
         output,
         do_reconstruction=True,
diff --git a/tests/models/test_arch_sccnn.py b/tests/models/test_arch_sccnn.py
index b3dd94e50..2729d2b3a 100644
--- a/tests/models/test_arch_sccnn.py
+++ b/tests/models/test_arch_sccnn.py
@@ -5,9 +5,10 @@
 import numpy as np
 import torch
 
-from tiatoolbox import utils
 from tiatoolbox.models import SCCNN
 from tiatoolbox.models.architecture import fetch_pretrained_weights
+from tiatoolbox.utils import env_detection
+from tiatoolbox.utils.misc import select_device
 from tiatoolbox.wsicore.wsireader import WSIReader
 
 
@@ -15,7 +16,7 @@ def _load_sccnn(name: str) -> torch.nn.Module:
     """Loads SCCNN model with specified weights."""
     model = SCCNN()
     weights_path = fetch_pretrained_weights(name)
-    map_location = utils.misc.select_device(on_gpu=utils.env_detection.has_gpu())
+    map_location = select_device(on_gpu=env_detection.has_gpu())
     pretrained = torch.load(weights_path, map_location=map_location)
     model.load_state_dict(pretrained)
 
@@ -40,11 +41,19 @@ def test_functionality(remote_sample: Callable) -> None:
     )
     batch = torch.from_numpy(patch)[None]
     model = _load_sccnn(name="sccnn-crchisto")
-    output = model.infer_batch(model, batch, on_gpu=False)
+    output = model.infer_batch(
+        model,
+        batch,
+        device=select_device(on_gpu=env_detection.has_gpu()),
+    )
     output = model.postproc(output[0])
     assert np.all(output == [[8, 7]])
 
     model = _load_sccnn(name="sccnn-conic")
-    output = model.infer_batch(model, batch, on_gpu=False)
+    output = model.infer_batch(
+        model,
+        batch,
+        device=select_device(on_gpu=env_detection.has_gpu()),
+    )
     output = model.postproc(output[0])
     assert np.all(output == [[7, 8]])
diff --git a/tests/models/test_arch_unet.py b/tests/models/test_arch_unet.py
index b0cbc6085..2ac231c7c 100644
--- a/tests/models/test_arch_unet.py
+++ b/tests/models/test_arch_unet.py
@@ -9,6 +9,7 @@
 
 from tiatoolbox.models.architecture import fetch_pretrained_weights
 from tiatoolbox.models.architecture.unet import UNetModel
+from tiatoolbox.utils.misc import select_device
 from tiatoolbox.wsicore.wsireader import WSIReader
 
 ON_GPU = False
 
@@ -48,7 +49,7 @@ def test_functional_unet(remote_sample: Callable) -> None:
     model = UNetModel(3, 2, encoder="resnet50", decoder_block=[3])
     pretrained = torch.load(pretrained_weights, map_location="cpu")
     model.load_state_dict(pretrained)
-    output = model.infer_batch(model, batch, on_gpu=ON_GPU)
+    output = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU))
     _ = output[0]
 
     # run untrained network to test for architecture
@@ -60,4 +61,4 @@ def test_functional_unet(remote_sample: Callable) -> None:
         encoder_levels=[32, 64],
         skip_type="concat",
     )
-    _ = model.infer_batch(model, batch, on_gpu=ON_GPU)
+    _ = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU))
diff --git a/tests/models/test_arch_vanilla.py b/tests/models/test_arch_vanilla.py
index 29c76ab4e..a87424dfd 100644
--- a/tests/models/test_arch_vanilla.py
+++ b/tests/models/test_arch_vanilla.py
@@ -5,10 +5,11 @@
 import torch
 
 from tiatoolbox.models.architecture.vanilla import CNNModel, TimmModel
-from tiatoolbox.utils.misc import model_to
+from tiatoolbox.models.models_abc import model_to
 
 ON_GPU = False
 RNG = np.random.default_rng()  # Numpy Random Generator
+device = "cuda" if ON_GPU else "cpu"
 
 
 def test_functional() -> None:
@@ -43,8 +44,8 @@ def test_functional() -> None:
     try:
         for backbone in backbones:
             model = CNNModel(backbone, num_classes=1)
-            model_ = model_to(on_gpu=ON_GPU, model=model)
-            model.infer_batch(model_, samples, on_gpu=ON_GPU)
+            model_ = model_to(device=device, model=model)
+            model.infer_batch(model_, samples, device=device)
     except ValueError as exc:
         msg = f"Model {backbone} failed."
         raise AssertionError(msg) from exc
@@ -70,8 +71,8 @@ def test_timm_functional() -> None:
     try:
         for backbone in backbones:
             model = TimmModel(backbone=backbone, num_classes=1, pretrained=False)
-            model_ = model_to(on_gpu=ON_GPU, model=model)
-            model.infer_batch(model_, samples, on_gpu=ON_GPU)
+            model_ = model_to(device=device, model=model)
+            model.infer_batch(model_, samples, device=device)
     except ValueError as exc:
         msg = f"Model {backbone} failed."
         raise AssertionError(msg) from exc
diff --git a/tests/models/test_hovernet.py b/tests/models/test_hovernet.py
index b2271ab4c..2567018b8 100644
--- a/tests/models/test_hovernet.py
+++ b/tests/models/test_hovernet.py
@@ -14,6 +14,7 @@
     ResidualBlock,
     TFSamepaddingLayer,
 )
+from tiatoolbox.utils.misc import select_device
 from tiatoolbox.wsicore.wsireader import WSIReader
 
 
@@ -34,7 +35,7 @@ def test_functionality(remote_sample: Callable) -> None:
     weights_path = fetch_pretrained_weights("hovernet_fast-pannuke")
     pretrained = torch.load(weights_path)
     model.load_state_dict(pretrained)
-    output = model.infer_batch(model, batch, on_gpu=False)
+    output = model.infer_batch(model, batch, device=select_device(on_gpu=False))
     output = [v[0] for v in output]
     output = model.postproc(output)
     assert len(output[1]) > 0, "Must have some nuclei."
@@ -51,7 +52,7 @@ def test_functionality(remote_sample: Callable) -> None:
     weights_path = fetch_pretrained_weights("hovernet_fast-monusac")
     pretrained = torch.load(weights_path)
     model.load_state_dict(pretrained)
-    output = model.infer_batch(model, batch, on_gpu=False)
+    output = model.infer_batch(model, batch, device=select_device(on_gpu=False))
     output = [v[0] for v in output]
     output = model.postproc(output)
     assert len(output[1]) > 0, "Must have some nuclei."
@@ -68,7 +69,7 @@ def test_functionality(remote_sample: Callable) -> None:
     weights_path = fetch_pretrained_weights("hovernet_original-consep")
     pretrained = torch.load(weights_path)
     model.load_state_dict(pretrained)
-    output = model.infer_batch(model, batch, on_gpu=False)
+    output = model.infer_batch(model, batch, device=select_device(on_gpu=False))
     output = [v[0] for v in output]
     output = model.postproc(output)
     assert len(output[1]) > 0, "Must have some nuclei."
@@ -85,7 +86,7 @@ def test_functionality(remote_sample: Callable) -> None:
     weights_path = fetch_pretrained_weights("hovernet_original-kumar")
     pretrained = torch.load(weights_path)
     model.load_state_dict(pretrained)
-    output = model.infer_batch(model, batch, on_gpu=False)
+    output = model.infer_batch(model, batch, device=select_device(on_gpu=False))
    output = [v[0] for v in output]
    output = model.postproc(output)
    assert len(output[1]) > 0, "Must have some nuclei."
diff --git a/tests/models/test_hovernetplus.py b/tests/models/test_hovernetplus.py
index 96d0f9d23..1377fdd82 100644
--- a/tests/models/test_hovernetplus.py
+++ b/tests/models/test_hovernetplus.py
@@ -7,6 +7,7 @@
 from tiatoolbox.models import HoVerNetPlus
 from tiatoolbox.models.architecture import fetch_pretrained_weights
 from tiatoolbox.utils import imread
+from tiatoolbox.utils.misc import select_device
 from tiatoolbox.utils.transforms import imresize
 
 
@@ -28,7 +29,7 @@ def test_functionality(remote_sample: Callable) -> None:
     weights_path = fetch_pretrained_weights("hovernetplus-oed")
     pretrained = torch.load(weights_path)
     model.load_state_dict(pretrained)
-    output = model.infer_batch(model, batch, on_gpu=False)
+    output = model.infer_batch(model, batch, device=select_device(on_gpu=False))
     assert len(output) == 4, "Must contain predictions for: np, hv, tp and ls branches."
     output = [v[0] for v in output]
     output = model.postproc(output)
diff --git a/tests/test_annotation_stores.py b/tests/test_annotation_stores.py
index 01bbdac45..66c990161 100644
--- a/tests/test_annotation_stores.py
+++ b/tests/test_annotation_stores.py
@@ -53,14 +53,6 @@
 FILLED_LEN = 2 * (GRID_SIZE[0] * GRID_SIZE[1])
 RNG = np.random.default_rng(0)  # Numpy Random Generator
 
-# ----------------------------------------------------------------------
-# Resets
-# ----------------------------------------------------------------------
-
-# Reset filters in logger.
-for filter_ in logger.filters:
-    logger.removeFilter(filter_)
-
 # ----------------------------------------------------------------------
 # Helper Functions
 # ----------------------------------------------------------------------
@@ -546,6 +538,9 @@ def test_sqlite_store_compile_options_missing_math(
     caplog: pytest.LogCaptureFixture,
 ) -> None:
     """Test that a warning is shown if the sqlite math module is missing."""
+    # Reset filters in logger.
+    for filter_ in logger.filters[:]:
+        logger.removeFilter(filter_)
     monkeypatch.setattr(
         SQLiteStore,
         "compile_options",
diff --git a/tests/test_annotation_tilerendering.py b/tests/test_annotation_tilerendering.py
index 0734b9164..fbe8239ca 100644
--- a/tests/test_annotation_tilerendering.py
+++ b/tests/test_annotation_tilerendering.py
@@ -462,6 +462,7 @@ def test_function_mapper(fill_store: Callable, tmp_path: Path) -> None:
     _, store = fill_store(SQLiteStore, tmp_path / "test.db")
 
     def color_fn(props: dict[str, str]) -> tuple[int, int, int]:
+        """Return red for cells, otherwise green."""
         # simple test function that returns red for cells, otherwise green.
         if props["type"] == "cell":
             return 1, 0, 0
diff --git a/tests/test_init.py b/tests/test_init.py
index 509a9c49f..6d8ed8238 100644
--- a/tests/test_init.py
+++ b/tests/test_init.py
@@ -114,7 +114,7 @@ def test_duplicate_filter(caplog: pytest.LogCaptureFixture) -> None:
     logger.addFilter(duplicate_filter)
 
     # Reset filters in logger.
-    for filter_ in logger.filters:
+    for filter_ in logger.filters[:]:
         logger.removeFilter(filter_)
 
     for _ in range(2):
diff --git a/tests/test_utils.py b/tests/test_utils.py
index fe18e0d36..4df80ba78 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1336,24 +1336,6 @@ def test_select_device() -> None:
     assert device == "cpu"
 
 
-def test_model_to() -> None:
-    """Test for placing model on device."""
-    import torchvision.models as torch_models
-    from torch import nn
-
-    # Test on GPU
-    # no GPU on Travis so this will crash
-    if not utils.env_detection.has_gpu():
-        model = torch_models.resnet18()
-        with pytest.raises((AssertionError, RuntimeError)):
-            _ = misc.model_to(on_gpu=True, model=model)
-
-    # Test on CPU
-    model = torch_models.resnet18()
-    model = misc.model_to(on_gpu=False, model=model)
-    assert isinstance(model, nn.Module)
-
-
 def test_save_as_json(tmp_path: Path) -> None:
     """Test save data to json."""
     # This should be broken up into separate tests!
@@ -1666,6 +1648,7 @@ def test_patch_pred_store() -> None:
     """Test patch_pred_store."""
     # Define a mock patch_output
     patch_output = {
+        "probabilities": [(0.99, 0.01), (0.01, 0.99), (0.99, 0.01)],
         "predictions": [1, 0, 1],
         "coordinates": [(0, 0, 1, 1), (1, 1, 2, 2), (2, 2, 3, 3)],
         "other": "other",
@@ -1700,7 +1683,7 @@ def test_patch_pred_store_cdict() -> None:
     class_dict = {0: "class0", 1: "class1"}
     store = misc.dict_to_store(patch_output, (1.0, 1.0), class_dict=class_dict)
 
-    # Check that its an SQLiteStore containing the expected annotations
+    # Check that it is an SQLiteStore containing the expected annotations
     assert isinstance(store, SQLiteStore)
     assert len(store) == 3
     for annotation in store.values():
diff --git a/tests/test_wsimeta.py b/tests/test_wsimeta.py
index bc3555e36..01b1cac8b 100644
--- a/tests/test_wsimeta.py
+++ b/tests/test_wsimeta.py
@@ -8,7 +8,6 @@
 from tiatoolbox.wsicore import WSIMeta, wsimeta, wsireader
 
 
-# noinspection PyTypeChecker
 def test_wsimeta_init_fail() -> None:
     """Test incorrect init for WSIMeta raises TypeError."""
     with pytest.raises(TypeError):
diff --git a/tiatoolbox/annotation/storage.py b/tiatoolbox/annotation/storage.py
index 0cd476358..420e94085 100644
--- a/tiatoolbox/annotation/storage.py
+++ b/tiatoolbox/annotation/storage.py
@@ -2556,7 +2556,21 @@ def _unpack_wkb(
         cx: float,
         cy: float,
     ) -> bytes:
-        """Unpack WKB data."""
+        """Return the geometry as bytes using WKB.
+
+        Args:
+            data (bytes or str):
+                The WKB/WKT data to be unpacked.
+            cx (float):
+                The X coordinate of the centroid/representative point.
+            cy (float):
+                The Y coordinate of the centroid/representative point.
+
+        Returns:
+            bytes:
+                The geometry as bytes.
+ + """ return ( self._decompress_data(data) if data diff --git a/tiatoolbox/cli/common.py b/tiatoolbox/cli/common.py index 18e731b4c..2545e9576 100644 --- a/tiatoolbox/cli/common.py +++ b/tiatoolbox/cli/common.py @@ -234,6 +234,18 @@ def cli_pretrained_weights( ) +def cli_device( + usage_help: str = "Select the device (cpu/cuda/mps) to use for inference.", + default: str = "cpu", +) -> Callable: + """Enables --pretrained-weights option for cli.""" + return click.option( + "--device", + help=add_default_to_usage_help(usage_help, default), + default=default, + ) + + def cli_return_probabilities( usage_help: str = "Whether to return raw model probabilities.", *, diff --git a/tiatoolbox/cli/patch_predictor.py b/tiatoolbox/cli/patch_predictor.py index a97ecb571..069b6c367 100644 --- a/tiatoolbox/cli/patch_predictor.py +++ b/tiatoolbox/cli/patch_predictor.py @@ -6,13 +6,13 @@ from tiatoolbox.cli.common import ( cli_batch_size, + cli_device, cli_file_type, cli_img_input, cli_masks, cli_merge_predictions, cli_mode, cli_num_loader_workers, - cli_on_gpu, cli_output_path, cli_pretrained_model, cli_pretrained_weights, @@ -45,7 +45,7 @@ @cli_return_probabilities(default=False) @cli_merge_predictions(default=True) @cli_return_labels(default=True) -@cli_on_gpu(default=False) +@cli_device(default="cpu") @cli_batch_size(default=1) @cli_resolution(default=0.5) @cli_units(default="mpp") @@ -64,11 +64,11 @@ def patch_predictor( resolution: float, units: str, num_loader_workers: int, + device: str, *, return_probabilities: bool, return_labels: bool, merge_predictions: bool, - on_gpu: bool, verbose: bool, ) -> None: """Process an image/directory of input images with a patch classification CNN.""" @@ -100,7 +100,7 @@ def patch_predictor( return_labels=return_labels, resolution=resolution, units=units, - on_gpu=on_gpu, + device=device, save_dir=output_path, save_output=True, ) diff --git a/tiatoolbox/models/architecture/hovernet.py b/tiatoolbox/models/architecture/hovernet.py index 8f061d273..1d1cd86e3 100644 --- a/tiatoolbox/models/architecture/hovernet.py +++ b/tiatoolbox/models/architecture/hovernet.py @@ -20,7 +20,6 @@ centre_crop_to_shape, ) from tiatoolbox.models.models_abc import ModelABC -from tiatoolbox.utils import misc from tiatoolbox.utils.misc import get_bounding_box @@ -785,7 +784,7 @@ def postproc(raw_maps: list[np.ndarray]) -> tuple[np.ndarray, dict]: return pred_inst, nuc_inst_info_dict @staticmethod - def infer_batch(model: nn.Module, batch_data: np.ndarray, *, on_gpu: bool) -> tuple: + def infer_batch(model: nn.Module, batch_data: np.ndarray, *, device: str) -> tuple: """Run inference on an input batch. This contains logic for forward operation as well as batch i/o @@ -797,8 +796,8 @@ def infer_batch(model: nn.Module, batch_data: np.ndarray, *, on_gpu: bool) -> tu batch_data (ndarray): A batch of data generated by `torch.utils.data.DataLoader`. - on_gpu (bool): - Whether to run inference on a GPU. + device (str): + Transfers model to the specified device. Default is "cpu". 
 
         Returns:
             tuple:
@@ -810,7 +809,6 @@ def infer_batch(model: nn.Module, batch_data: np.ndarray, *, on_gpu: bool) -> tu
         """
         patch_imgs = batch_data
 
-        device = misc.select_device(on_gpu=on_gpu)
         patch_imgs_gpu = patch_imgs.to(device).type(torch.float32)  # to NCHW
         patch_imgs_gpu = patch_imgs_gpu.permute(0, 3, 1, 2).contiguous()
diff --git a/tiatoolbox/models/architecture/hovernetplus.py b/tiatoolbox/models/architecture/hovernetplus.py
index 87db17295..acee0106f 100644
--- a/tiatoolbox/models/architecture/hovernetplus.py
+++ b/tiatoolbox/models/architecture/hovernetplus.py
@@ -13,7 +13,6 @@
 
 from tiatoolbox.models.architecture.hovernet import HoVerNet
 from tiatoolbox.models.architecture.utils import UpSample2x
-from tiatoolbox.utils import misc
 
 
 class HoVerNetPlus(HoVerNet):
@@ -325,7 +324,7 @@ def postproc(raw_maps: list[np.ndarray]) -> tuple:
         return pred_inst, nuc_inst_info_dict, pred_layer, layer_info_dict
 
     @staticmethod
-    def infer_batch(model: nn.Module, batch_data: np.ndarray, *, on_gpu: bool) -> tuple:
+    def infer_batch(model: nn.Module, batch_data: np.ndarray, *, device: str) -> tuple:
         """Run inference on an input batch.
 
         This contains logic for forward operation as well as batch i/o
         aggregation.
 
@@ -337,13 +336,12 @@ def infer_batch(model: nn.Module, batch_data: np.ndarray, *, on_gpu: bool) -> tu
             batch_data (ndarray):
                 A batch of data generated by
                 `torch.utils.data.DataLoader`.
-            on_gpu (bool):
-                Whether to run inference on a GPU.
+            device (str):
+                Device on which to run inference, e.g. "cpu" or "cuda".
 
         """
         patch_imgs = batch_data
 
-        device = misc.select_device(on_gpu=on_gpu)
         patch_imgs_gpu = patch_imgs.to(device).type(torch.float32)  # to NCHW
         patch_imgs_gpu = patch_imgs_gpu.permute(0, 3, 1, 2).contiguous()
diff --git a/tiatoolbox/models/architecture/mapde.py b/tiatoolbox/models/architecture/mapde.py
index a7156531f..bbb468bb8 100644
--- a/tiatoolbox/models/architecture/mapde.py
+++ b/tiatoolbox/models/architecture/mapde.py
@@ -14,7 +14,6 @@
 from skimage.feature import peak_local_max
 
 from tiatoolbox.models.architecture.micronet import MicroNet
-from tiatoolbox.utils.misc import select_device
 
 
 class MapDe(MicroNet):
@@ -259,7 +258,7 @@ def infer_batch(
         model: torch.nn.Module,
         batch_data: torch.Tensor,
         *,
-        on_gpu: bool,
+        device: str,
     ) -> list[np.ndarray]:
         """Run inference on an input batch.
 
@@ -272,8 +271,8 @@ def infer_batch(
             batch_data (:class:`numpy.ndarray`):
                 A batch of data generated by
                 `torch.utils.data.DataLoader`.
-            on_gpu (bool):
-                Whether to run inference on a GPU.
+            device (str):
+                Device on which to run inference, e.g. "cpu" or "cuda".
 
         Returns:
             list(np.ndarray):
@@ -282,7 +281,6 @@ def infer_batch(
         """
         patch_imgs = batch_data
 
-        device = select_device(on_gpu=on_gpu)
         patch_imgs_gpu = patch_imgs.to(device).type(torch.float32)  # to NCHW
         patch_imgs_gpu = patch_imgs_gpu.permute(0, 3, 1, 2).contiguous()
diff --git a/tiatoolbox/models/architecture/micronet.py b/tiatoolbox/models/architecture/micronet.py
index bfc62c8ab..bbc455e84 100644
--- a/tiatoolbox/models/architecture/micronet.py
+++ b/tiatoolbox/models/architecture/micronet.py
@@ -19,7 +19,6 @@
 
 from tiatoolbox.models.architecture.hovernet import HoVerNet
 from tiatoolbox.models.models_abc import ModelABC
-from tiatoolbox.utils import misc
 
 
 def group1_forward_branch(
@@ -629,7 +628,7 @@ def infer_batch(
         model: torch.nn.Module,
         batch_data: torch.Tensor,
         *,
-        on_gpu: bool,
+        device: str,
     ) -> list[np.ndarray]:
         """Run inference on an input batch.
@@ -642,8 +641,8 @@ def infer_batch(
             batch_data (:class:`torch.Tensor`):
                 A batch of data generated by
                 `torch.utils.data.DataLoader`.
-            on_gpu (bool):
-                Whether to run inference on a GPU.
+            device (str):
+                Device on which to run inference, e.g. "cpu" or "cuda".
 
         Returns:
             list(np.ndarray):
@@ -652,7 +651,6 @@ def infer_batch(
         """
         patch_imgs = batch_data
 
-        device = misc.select_device(on_gpu=on_gpu)
         patch_imgs_gpu = patch_imgs.to(device).type(torch.float32)  # to NCHW
         patch_imgs_gpu = patch_imgs_gpu.permute(0, 3, 1, 2).contiguous()
diff --git a/tiatoolbox/models/architecture/nuclick.py b/tiatoolbox/models/architecture/nuclick.py
index 339777eb1..77f4ad993 100644
--- a/tiatoolbox/models/architecture/nuclick.py
+++ b/tiatoolbox/models/architecture/nuclick.py
@@ -22,7 +22,6 @@
 
 from tiatoolbox import logger
 from tiatoolbox.models.models_abc import ModelABC
-from tiatoolbox.utils import misc
 
 if TYPE_CHECKING:  # pragma: no cover
     from tiatoolbox.typing import IntPair
 
@@ -647,7 +646,7 @@ def infer_batch(
         model: nn.Module,
         batch_data: torch.Tensor,
         *,
-        on_gpu: bool,
+        device: str,
     ) -> np.ndarray:
         """Run inference on an input batch.
 
@@ -656,16 +655,16 @@ def infer_batch(
         Args:
             model (nn.Module): PyTorch defined model.
-            batch_data (torch.Tensor): a batch of data generated by
-                torch.utils.data.DataLoader.
-            on_gpu (bool): Whether to run inference on a GPU.
+            batch_data (torch.Tensor):
+                A batch of data generated by torch.utils.data.DataLoader.
+            device (str):
+                Device on which to run inference, e.g. "cpu" or "cuda".
 
         Returns:
             Pixel-wise nuclei prediction for each patch, shape: (no.patch, h, w).
 
         """
         model.eval()
-        device = misc.select_device(on_gpu=on_gpu)
 
         # Assume batch_data is NCHW
         batch_data = batch_data.to(device).type(torch.float32)
diff --git a/tiatoolbox/models/architecture/sccnn.py b/tiatoolbox/models/architecture/sccnn.py
index 9941eabff..4da0f9dca 100644
--- a/tiatoolbox/models/architecture/sccnn.py
+++ b/tiatoolbox/models/architecture/sccnn.py
@@ -17,7 +17,6 @@
 from torch import nn
 
 from tiatoolbox.models.models_abc import ModelABC
-from tiatoolbox.utils import misc
 
 
 class SCCNN(ModelABC):
@@ -354,8 +353,7 @@ def postproc(self: SCCNN, prediction_map: np.ndarray) -> np.ndarray:
     def infer_batch(
         model: nn.Module,
         batch_data: np.ndarray | torch.Tensor,
-        *,
-        on_gpu: bool,
+        device: str,
     ) -> list[np.ndarray]:
         """Run inference on an input batch.
 
@@ -368,8 +366,8 @@ def infer_batch(
             batch_data (:class:`numpy.ndarray` or :class:`torch.Tensor`):
                 A batch of data generated by
                 `torch.utils.data.DataLoader`.
-            on_gpu (bool):
-                Whether to run inference on a GPU.
+            device (str):
+                Device on which to run inference, e.g. "cpu" or "cuda".
 
         Returns:
             list of :class:`numpy.ndarray`:
@@ -378,7 +376,6 @@ def infer_batch(
         """
         patch_imgs = batch_data
 
-        device = misc.select_device(on_gpu=on_gpu)
         patch_imgs_gpu = patch_imgs.to(device).type(torch.float32)  # to NCHW
         patch_imgs_gpu = patch_imgs_gpu.permute(0, 3, 1, 2).contiguous()
diff --git a/tiatoolbox/models/architecture/unet.py b/tiatoolbox/models/architecture/unet.py
index fe1a97cc9..6385e7587 100644
--- a/tiatoolbox/models/architecture/unet.py
+++ b/tiatoolbox/models/architecture/unet.py
@@ -12,7 +12,6 @@
 
 from tiatoolbox.models.architecture.utils import UpSample2x, centre_crop
 from tiatoolbox.models.models_abc import ModelABC
-from tiatoolbox.utils import misc
 
 
 class ResNetEncoder(ResNet):
@@ -416,7 +415,7 @@ def infer_batch(
         model: nn.Module,
         batch_data: torch.Tensor,
         *,
-        on_gpu: bool,
+        device: str,
     ) -> list:
         """Run inference on an input batch.
@@ -429,8 +428,8 @@ def infer_batch(
             batch_data (:class:`torch.Tensor`):
                 A batch of data generated by
                 `torch.utils.data.DataLoader`.
-            on_gpu (bool):
-                Whether to run inference on a GPU.
+            device (str):
+                Device on which to run inference, e.g. "cpu" or "cuda".
 
         Returns:
             list:
@@ -439,7 +438,6 @@ def infer_batch(
         """
         model.eval()
-        device = misc.select_device(on_gpu=on_gpu)
 
         ####
         imgs = batch_data
diff --git a/tiatoolbox/models/architecture/utils.py b/tiatoolbox/models/architecture/utils.py
index 9df4dd56f..2ec47d99d 100644
--- a/tiatoolbox/models/architecture/utils.py
+++ b/tiatoolbox/models/architecture/utils.py
@@ -3,7 +3,7 @@
 from __future__ import annotations
 
 import sys
-from typing import Callable, NoReturn
+from typing import NoReturn
 
 import numpy as np
 import torch
@@ -41,7 +41,7 @@ def compile_model(
     model: nn.Module | None = None,
     *,
     mode: str = "default",
-) -> Callable:
+) -> nn.Module:
     """A decorator to compile a model using torch-compile.
 
     Args:
@@ -60,7 +60,7 @@ def compile_model(
             CUDA graphs
 
     Returns:
-        Callable:
+        torch.nn.Module:
             Compiled model.
 
     """
@@ -71,7 +71,7 @@ def compile_model(
     is_torch_compile_compatible()
 
     # This check will be removed when torch.compile is supported in Python 3.12+
-    if sys.version_info >= (3, 12):  # pragma: no cover
+    if sys.version_info > (3, 12):  # pragma: no cover
         logger.warning(
             ("torch-compile is currently not supported in Python 3.12+. ",),
         )
diff --git a/tiatoolbox/models/architecture/vanilla.py b/tiatoolbox/models/architecture/vanilla.py
index 4879ce04c..cb487ec53 100644
--- a/tiatoolbox/models/architecture/vanilla.py
+++ b/tiatoolbox/models/architecture/vanilla.py
@@ -11,7 +11,6 @@
 from torch import nn
 
 from tiatoolbox.models.models_abc import ModelABC
-from tiatoolbox.utils.misc import select_device
 
 if TYPE_CHECKING:  # pragma: no cover
     from torchvision.models import WeightsEnum
@@ -149,9 +148,8 @@ def _postproc(image: np.ndarray) -> np.ndarray:
 def _infer_batch(
     model: nn.Module,
     batch_data: torch.Tensor,
-    *,
-    on_gpu: bool,
-) -> np.ndarray:
+    device: str,
+) -> dict[str, np.ndarray]:
     """Run inference on an input batch.
 
     Contains logic for forward operation as well as i/o aggregation.
 
@@ -162,11 +160,11 @@ def _infer_batch(
         batch_data (torch.Tensor):
             A batch of data generated by
             `torch.utils.data.DataLoader`.
-        on_gpu (bool):
-            Whether to run inference on a GPU.
+        device (str):
+            Device on which to run inference, e.g. "cpu" or "cuda".
 
     """
-    img_patches_device = batch_data.to(select_device(on_gpu=on_gpu)).type(
+    img_patches_device = batch_data.to(device=device).type(
         torch.float32,
     )  # to NCHW
     img_patches_device = img_patches_device.permute(0, 3, 1, 2).contiguous()
@@ -177,7 +175,7 @@ def _infer_batch(
     with torch.inference_mode():
         output = model(img_patches_device)
     # Output should be a single tensor or scalar
-    return output.cpu().numpy()
+    return {"probabilities": output.cpu().numpy()}
 
 
 class CNNModel(ModelABC):
@@ -243,9 +241,8 @@ def postproc(image: np.ndarray) -> np.ndarray:
     def infer_batch(
         model: nn.Module,
         batch_data: torch.Tensor,
-        *,
-        on_gpu: bool,
-    ) -> np.ndarray:
+        device: str = "cpu",
+    ) -> dict[str, np.ndarray]:
         """Run inference on an input batch.
 
         Contains logic for forward operation as well as i/o aggregation.
 
         Args:
             model (nn.Module):
                 PyTorch defined model.
             batch_data (torch.Tensor):
                 A batch of data generated by
                 `torch.utils.data.DataLoader`.
-            on_gpu (bool):
-                Whether to run inference on a GPU.
+            device (str):
+                Device on which to run inference, e.g. "cpu" or "cuda". Defaults to "cpu".
""" - return _infer_batch(model=model, batch_data=batch_data, on_gpu=on_gpu) + return _infer_batch(model=model, batch_data=batch_data, device=device) class TimmModel(ModelABC): @@ -339,9 +336,8 @@ def postproc(image: np.ndarray) -> np.ndarray: def infer_batch( model: nn.Module, batch_data: torch.Tensor, - *, - on_gpu: bool, - ) -> np.ndarray: + device: str, + ) -> dict[str, np.ndarray]: """Run inference on an input batch. Contains logic for forward operation as well as i/o aggregation. @@ -352,11 +348,11 @@ def infer_batch( batch_data (torch.Tensor): A batch of data generated by `torch.utils.data.DataLoader`. - on_gpu (bool): - Whether to run inference on a GPU. + device (str): + Transfers model to the specified device. Default is "cpu". """ - return _infer_batch(model=model, batch_data=batch_data, on_gpu=on_gpu) + return _infer_batch(model=model, batch_data=batch_data, device=device) class CNNBackbone(ModelABC): @@ -425,9 +421,8 @@ def forward(self: CNNBackbone, imgs: torch.Tensor) -> torch.Tensor: def infer_batch( model: nn.Module, batch_data: torch.Tensor, - *, - on_gpu: bool, - ) -> list[np.ndarray]: + device: str, + ) -> list[dict[str, np.ndarray]]: """Run inference on an input batch. Contains logic for forward operation as well as i/o aggregation. @@ -438,15 +433,15 @@ def infer_batch( batch_data (torch.Tensor): A batch of data generated by `torch.utils.data.DataLoader`. - on_gpu (bool): - Whether to run inference on a GPU. + device (str): + Transfers model to the specified device. Default is "cpu". Returns: - list[np.ndarray]: - list of numpy arrays. + list[dict[str, np.ndarray]]: + list of dictionary values with numpy arrays. """ - return [_infer_batch(model=model, batch_data=batch_data, on_gpu=on_gpu)] + return [_infer_batch(model=model, batch_data=batch_data, device=device)] class TimmBackbone(ModelABC): @@ -500,9 +495,8 @@ def forward(self: TimmBackbone, imgs: torch.Tensor) -> torch.Tensor: def infer_batch( model: nn.Module, batch_data: torch.Tensor, - *, - on_gpu: bool, - ) -> list[np.ndarray]: + device: str, + ) -> list[dict[str, np.ndarray]]: """Run inference on an input batch. Contains logic for forward operation as well as i/o aggregation. @@ -513,12 +507,12 @@ def infer_batch( batch_data (torch.Tensor): A batch of data generated by `torch.utils.data.DataLoader`. - on_gpu (bool): - Whether to run inference on a GPU. + device (str): + Transfers model to the specified device. Default is "cpu". Returns: - list[np.ndarray]: - list of numpy arrays. + list[dict[str, np.ndarray]]: + list of dictionary values with numpy arrays. """ - return [_infer_batch(model=model, batch_data=batch_data, on_gpu=on_gpu)] + return [_infer_batch(model=model, batch_data=batch_data, device=device)] diff --git a/tiatoolbox/models/engine/patch_predictor.py b/tiatoolbox/models/engine/patch_predictor.py index da4420cb0..e9859f14b 100644 --- a/tiatoolbox/models/engine/patch_predictor.py +++ b/tiatoolbox/models/engine/patch_predictor.py @@ -795,10 +795,10 @@ def predict( # noqa: PLR0913 stride_shape: tuple[int, int] | None = None, resolution: Resolution | None = None, units: Units = None, + device: str = "cpu", *, return_probabilities: bool = False, return_labels: bool = False, - on_gpu: bool = True, merge_predictions: bool = False, save_dir: str | Path | None = None, save_output: bool = False, @@ -830,8 +830,11 @@ def predict( # noqa: PLR0913 Whether to return per-class probabilities. return_labels (bool): Whether to return the labels with the predictions. 
-        on_gpu (bool):
-            Whether to run model on the GPU.
+        device (str):
+            Select the device (e.g. "cpu" or "cuda") on which to run
+            the model. Please see
+            https://pytorch.org/docs/stable/tensor_attributes.html#torch.device
+            for more details on valid device strings. Default is "cpu".
         ioconfig (IOPatchPredictorConfig):
             Patch Predictor IO configuration.
         patch_input_shape (tuple):
@@ -901,7 +904,7 @@ def predict(  # noqa: PLR0913
             labels,
             return_probabilities=return_probabilities,
             return_labels=return_labels,
-            on_gpu=on_gpu,
+            device=device,
         )
 
         if not isinstance(imgs, list):
@@ -948,7 +951,7 @@ def predict(  # noqa: PLR0913
             labels=labels,
             mode=mode,
             return_probabilities=return_probabilities,
-            on_gpu=on_gpu,
+            device=device,
             ioconfig=ioconfig,
             merge_predictions=merge_predictions,
             save_dir=save_dir,
diff --git a/tiatoolbox/models/models_abc.py b/tiatoolbox/models/models_abc.py
index e16540c87..6ab9c61d3 100644
--- a/tiatoolbox/models/models_abc.py
+++ b/tiatoolbox/models/models_abc.py
@@ -39,6 +39,28 @@ def output_resolutions(self: IOConfigABC) -> None:
         raise NotImplementedError
 
 
+def model_to(model: torch.nn.Module, device: str = "cpu") -> torch.nn.Module:
+    """Transfers model to cpu/gpu.
+
+    Args:
+        model (torch.nn.Module):
+            PyTorch defined model.
+        device (str):
+            Transfers model to the specified device. Default is "cpu".
+
+    Returns:
+        torch.nn.Module:
+            The model after being moved to cpu/gpu.
+
+    """
+    if device != "cpu":
+        # DataParallel works only for CUDA.
+        model = torch.nn.DataParallel(model)
+
+    device = torch.device(device)
+    return model.to(device)
+
+
 class ModelABC(ABC, torch.nn.Module):
     """Abstract base class for models used in tiatoolbox."""
 
@@ -59,8 +81,7 @@ def forward(self: ModelABC, *args: tuple[Any, ...], **kwargs: dict) -> None:
     def infer_batch(
         model: torch.nn.Module,
         batch_data: np.ndarray,
-        *,
-        on_gpu: bool,
+        device: str,
     ) -> None:
         """Run inference on an input batch.
 
@@ -72,8 +93,13 @@ def infer_batch(
             batch_data (np.ndarray):
                 A batch of data generated by
                 `torch.utils.data.DataLoader`.
-            on_gpu (bool):
-                Whether to run inference on a GPU.
+            device (str):
+                Device on which to run inference, e.g. "cpu" or "cuda".
+
+        Returns:
+            dict:
+                Returns a dictionary of predictions and other expected outputs
+                depending on the network architecture.
 
         """
         ...  # pragma: no cover
@@ -106,7 +132,7 @@ def preproc_func(self: ModelABC, func: Callable) -> None:
             >>> # `func` is a user defined function
             >>> model = ModelABC()
             >>> model.preproc_func = func
-            >>> transformed_img = model.preproc_func(img)
+            >>> transformed_img = model.preproc_func(image=img)
 
         """
         if func is not None and not callable(func):
@@ -137,7 +163,7 @@ def postproc_func(self: ModelABC, func: Callable) -> None:
             >>> # `func` is a user defined function
             >>> model = ModelABC()
             >>> model.postproc_func = func
-            >>> transformed_img = model.postproc_func(img)
+            >>> transformed_img = model.postproc_func(image=img)
 
         """
         if func is not None and not callable(func):
diff --git a/whitelist.txt b/whitelist.txt
index 07a1b13c3..d1e723f26 100644
--- a/whitelist.txt
+++ b/whitelist.txt
@@ -96,6 +96,7 @@ coord
 coords
 csv
 cuda
+customizable
 cv2
 dataframe
 dataset
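End to end, the same rename reaches the engine and the CLI: `PatchPredictor.predict` now takes `device="cpu"` instead of `on_gpu`, and the `--on-gpu` option is replaced by `--device`. A minimal sketch under those signatures; the pretrained tag and file names are illustrative, not part of this patch:

    from tiatoolbox.models.engine.patch_predictor import PatchPredictor

    # "resnet18-kather100k" is one of the toolbox's pretrained model tags.
    predictor = PatchPredictor(pretrained_model="resnet18-kather100k", batch_size=16)
    output = predictor.predict(
        imgs=["patch_1.png", "patch_2.png"],
        mode="patch",
        device="cuda",  # was on_gpu=True before this patch
    )

The equivalent CLI call would swap the old flag for the new option, e.g. `tiatoolbox patch-predictor --img-input patches/ --mode patch --device cuda` (command name assumed from tiatoolbox/cli/patch_predictor.py).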