Skip to content

Commit

Permalink
feat(pred)!: use bioimageio.core for BioImage.IO Model Zoo model in…
Browse files Browse the repository at this point in the history
…ference
  • Loading branch information
qin-yu committed Dec 14, 2024
1 parent 3656496 commit c0ceb3a
Showing 1 changed file with 58 additions and 1 deletion.
59 changes: 58 additions & 1 deletion plantseg/functionals/prediction/prediction.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
import logging
from pathlib import Path
from typing import assert_never

import numpy as np
import torch
from bioimageio.core.axis import AxisId
from bioimageio.core.prediction import predict
from bioimageio.core.sample import Sample
from bioimageio.core.tensor import Tensor
from bioimageio.spec import load_model_description
from bioimageio.spec.model import v0_4, v0_5
from bioimageio.spec.model.v0_5 import TensorId

from plantseg.core.zoo import model_zoo
from plantseg.functionals.dataprocessing.dataprocessing import ImageLayout, fix_layout_to_CZYX, fix_layout_to_ZYX
Expand All @@ -16,6 +24,52 @@
logger = logging.getLogger(__name__)


def biio_prediction(
raw: np.ndarray,
input_layout: ImageLayout,
model_id: str,
) -> np.ndarray:
model = load_model_description(model_id)
if isinstance(model, v0_4.ModelDescr):
input_ids = [input_tensor.name for input_tensor in model.inputs]
elif isinstance(model, v0_5.ModelDescr):
input_ids = [input_tensor.id for input_tensor in model.inputs]
else:
assert_never(model)

if len(input_ids) < 1:
logger.error("Model needs no input tensor.")
if len(input_ids) > 1:
logger.warning("Model needs more than one input tensor. PlantSeg does not support this yet.")
tensor_id = input_ids[0]

logger.info(f"model expects these inputs: {input_ids}")

assert isinstance(input_layout, str)
dims = tuple(
'channel' if item.lower() == 'c' else item.lower() for item in input_layout
) # `AxisId` has to be "channel" not "c"
sample = Sample(
members={
TensorId(tensor_id): Tensor(array=raw, dims=dims).transpose(
[AxisId(a) if isinstance(a, str) else a.id for a in model.inputs[0].axes]
)
},
stat={},
id="raw",
)

sample_out = predict(model=model, inputs=sample)
assert isinstance(sample_out, Sample)
if len(sample_out.members) != 1:
logger.warning("Model has more than one output tensor. PlantSeg does not support this yet.")
key = list(sample_out.members.keys())[0]
pmaps = sample_out.members[key].data.to_numpy()[0]
assert pmaps.ndim == 4, f"Expected 4D CZXY prediction from `biio_prediction()`, got {pmaps.ndim}D"

return pmaps


def unet_prediction(
raw: np.ndarray,
input_layout: ImageLayout,
Expand Down Expand Up @@ -61,7 +115,10 @@ def unet_prediction(
model, model_config, model_path = model_zoo.get_model_by_config_path(config_path, model_weights_path)
elif model_id is not None: # BioImage.IO zoo mode
logger.info("BioImage.IO prediction: Running model from BioImage.IO model zoo.")
model, model_config, model_path = model_zoo.get_model_by_id(model_id)
if True: # NOTE: For now, do not use native pytorch-3dunet prediction if using BioImage.IO models
return biio_prediction(raw=raw, input_layout=input_layout, model_id=model_id)
else:
model, model_config, model_path = model_zoo.get_model_by_id(model_id)
elif model_name is not None: # PlantSeg zoo mode
logger.info("Zoo prediction: Running model from PlantSeg official zoo.")
model, model_config, model_path = model_zoo.get_model_by_name(model_name, model_update=model_update)
Expand Down

0 comments on commit c0ceb3a

Please sign in to comment.