From d8eef3338ba9b187839bd603593af58868c31b92 Mon Sep 17 00:00:00 2001 From: Qin Yu Date: Wed, 18 Dec 2024 03:10:14 +0100 Subject: [PATCH] refactor: support more formats of bioimageio axes specs Now not only axes with `.id` and `.size` attributes but also string axes such as 'bcyx' coming with `shape` at the same level. These are due to the legacy of bioimage.io and complexity of versioning. Tested with `pioneering-rhino` --- plantseg/functionals/prediction/prediction.py | 91 +++++++++++-------- plantseg/tasks/prediction_tasks.py | 3 +- 2 files changed, 56 insertions(+), 38 deletions(-) diff --git a/plantseg/functionals/prediction/prediction.py b/plantseg/functionals/prediction/prediction.py index 2a895df9..afa5d6d1 100644 --- a/plantseg/functionals/prediction/prediction.py +++ b/plantseg/functionals/prediction/prediction.py @@ -50,51 +50,70 @@ def biio_prediction( dims = tuple( AxisId('channel') if item.lower() == 'c' else AxisId(item.lower()) for item in input_layout ) # `AxisId` has to be "channel" not "c" - members = { - TensorId(tensor_id): Tensor(array=raw, dims=dims).transpose( - [AxisId(a) if isinstance(a, str) else a.id for a in axes] + + if isinstance(axes[0], str): # then it's a <=0.4.10 model, `predict_sample_block` is not implemented + logger.warning( + "Model is older than 0.5.0. PlantSeg will try to run BioImage.IO core inference, but it is not supported by BioImage.IO core." ) - } - sample = Sample(members=members, stat={}, id="raw") - - for a in axes: - if isinstance(a, str): - raise ValueError(f"Model has a string axis: {a}, please report issue to PlantSeg developers.") - sizes_in_rdf = {a.id: a.size for a in axes} - assert 'x' in sizes_in_rdf, "Model does not have 'x' axis in input tensor." - size_to_check = sizes_in_rdf[AxisId('x')] - if isinstance(size_to_check, int): # e.g. 'emotional-cricket' - # 'emotional-cricket' has {'batch': None, 'channel': 1, 'z': 100, 'y': 128, 'x': 128} - input_block_shape = { - TensorId(tensor_id): { - a.id: a.size if isinstance(a.size, int) else 1 + axis_mapping = {'b': 'batch', 'c': 'channel'} + axes = [AxisId(axis_mapping.get(a, a)) for a in list(axes)] + members = {TensorId(tensor_id): Tensor(array=raw, dims=dims).transpose([AxisId(a) for a in axes])} + sample = Sample(members=members, stat={}, id="raw") + sample_out = predict(model=model, inputs=sample) + + # If inference is supported by BioImage.IO core, this is how it should be done in PlantSeg: + # + # shape = model.inputs[0].shape + # input_block_shape = {TensorId(tensor_id): {AxisId(a): s for a, s in zip(axes, shape)}} + # sample_out = predict(model=model, inputs=sample, input_block_shape=input_block_shape) + else: + members = { + TensorId(tensor_id): Tensor(array=raw, dims=dims).transpose( + [AxisId(a) if isinstance(a, str) else a.id for a in axes] + ) + } + sample = Sample(members=members, stat={}, id="raw") + sizes_in_rdf = {a.id: a.size for a in axes} + assert 'x' in sizes_in_rdf, "Model does not have 'x' axis in input tensor." + size_to_check = sizes_in_rdf[AxisId('x')] + if isinstance(size_to_check, int): # e.g. 'emotional-cricket' + # 'emotional-cricket' has {'batch': None, 'channel': 1, 'z': 100, 'y': 128, 'x': 128} + input_block_shape = { + TensorId(tensor_id): { + a.id: a.size if isinstance(a.size, int) else 1 + for a in axes + if not isinstance(a, str) # for a.size/a.id type checking only + } + } + sample_out = predict(model=model, inputs=sample, input_block_shape=input_block_shape) + elif isinstance(size_to_check, v0_5.ParameterizedSize): # e.g. 'philosophical-panda' + # 'philosophical-panda' has: + # {'z': ParameterizedSize(min=1, step=1), + # 'channel': 2, + # 'y': ParameterizedSize(min=16, step=16), + # 'x': ParameterizedSize(min=16, step=16)} + blocksize_parameter = { + (TensorId(tensor_id), a.id): ( + (96 - a.size.min) // a.size.step if isinstance(a.size, v0_5.ParameterizedSize) else 1 + ) for a in axes if not isinstance(a, str) # for a.size/a.id type checking only } - } - sample_out = predict(model=model, inputs=sample, input_block_shape=input_block_shape) - elif isinstance(size_to_check, v0_5.ParameterizedSize): # e.g. 'philosophical-panda' - # 'philosophical-panda' has: - # {'z': ParameterizedSize(min=1, step=1), - # 'channel': 2, - # 'y': ParameterizedSize(min=16, step=16), - # 'x': ParameterizedSize(min=16, step=16)} - blocksize_parameter = { - (TensorId(tensor_id), a.id): ( - (96 - a.size.min) // a.size.step if isinstance(a.size, v0_5.ParameterizedSize) else 1 - ) - for a in axes - if not isinstance(a, str) # for a.size/a.id type checking only - } - sample_out = predict(model=model, inputs=sample, blocksize_parameter=blocksize_parameter) - else: - assert_never(size_to_check) + sample_out = predict(model=model, inputs=sample, blocksize_parameter=blocksize_parameter) + else: + assert_never(size_to_check) assert isinstance(sample_out, Sample) if len(sample_out.members) != 1: logger.warning("Model has more than one output tensor. PlantSeg does not support this yet.") + + desired_axes_short = [AxisId(a) for a in ['b', 'c', 'z', 'y', 'x']] desired_axes = [AxisId(a) for a in ['batch', 'channel', 'z', 'y', 'x']] - t = {i: o.transpose(desired_axes) for i, o in sample_out.members.items()} + t = { + i: o.transpose(desired_axes_short) if 'b' in o.dims or 'c' in o.dims else o.transpose(desired_axes) + for i, o in sample_out.members.items() + } + named_pmaps = {} for key, tensor_bczyx in t.items(): bczyx = tensor_bczyx.data.to_numpy() diff --git a/plantseg/tasks/prediction_tasks.py b/plantseg/tasks/prediction_tasks.py index 42107ae5..e84a7744 100644 --- a/plantseg/tasks/prediction_tasks.py +++ b/plantseg/tasks/prediction_tasks.py @@ -87,8 +87,7 @@ def biio_prediction_task( new_images = [] for name, pmap in named_pmaps.items(): - # Input layout is always ZYX this loop - pmap = fix_layout(pmap, input_layout=input_layout, output_layout='CZYX') + # Input layout is always CZYX this loop new_images.append( image.derive_new( pmap,