Skip to content

Commit

Permalink
refactor: support more formats of bioimageio axes specs
Browse files Browse the repository at this point in the history
Now not only axes with `.id` and `.size` attributes but also string axes such as 'bcyx' coming with `shape` at the same level.

These are due to the legacy of bioimage.io and complexity of versioning.

Tested with `pioneering-rhino`
  • Loading branch information
qin-yu committed Dec 18, 2024
1 parent bff023c commit d8eef33
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 38 deletions.
91 changes: 55 additions & 36 deletions plantseg/functionals/prediction/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,51 +50,70 @@ def biio_prediction(
dims = tuple(

Check warning on line 50 in plantseg/functionals/prediction/prediction.py

View check run for this annotation

Codecov / codecov/patch

plantseg/functionals/prediction/prediction.py#L48-L50

Added lines #L48 - L50 were not covered by tests
AxisId('channel') if item.lower() == 'c' else AxisId(item.lower()) for item in input_layout
) # `AxisId` has to be "channel" not "c"
members = {
TensorId(tensor_id): Tensor(array=raw, dims=dims).transpose(
[AxisId(a) if isinstance(a, str) else a.id for a in axes]

if isinstance(axes[0], str): # then it's a <=0.4.10 model, `predict_sample_block` is not implemented
logger.warning(

Check warning on line 55 in plantseg/functionals/prediction/prediction.py

View check run for this annotation

Codecov / codecov/patch

plantseg/functionals/prediction/prediction.py#L54-L55

Added lines #L54 - L55 were not covered by tests
"Model is older than 0.5.0. PlantSeg will try to run BioImage.IO core inference, but it is not supported by BioImage.IO core."
)
}
sample = Sample(members=members, stat={}, id="raw")

for a in axes:
if isinstance(a, str):
raise ValueError(f"Model has a string axis: {a}, please report issue to PlantSeg developers.")
sizes_in_rdf = {a.id: a.size for a in axes}
assert 'x' in sizes_in_rdf, "Model does not have 'x' axis in input tensor."
size_to_check = sizes_in_rdf[AxisId('x')]
if isinstance(size_to_check, int): # e.g. 'emotional-cricket'
# 'emotional-cricket' has {'batch': None, 'channel': 1, 'z': 100, 'y': 128, 'x': 128}
input_block_shape = {
TensorId(tensor_id): {
a.id: a.size if isinstance(a.size, int) else 1
axis_mapping = {'b': 'batch', 'c': 'channel'}
axes = [AxisId(axis_mapping.get(a, a)) for a in list(axes)]
members = {TensorId(tensor_id): Tensor(array=raw, dims=dims).transpose([AxisId(a) for a in axes])}
sample = Sample(members=members, stat={}, id="raw")
sample_out = predict(model=model, inputs=sample)

Check warning on line 62 in plantseg/functionals/prediction/prediction.py

View check run for this annotation

Codecov / codecov/patch

plantseg/functionals/prediction/prediction.py#L58-L62

Added lines #L58 - L62 were not covered by tests

# If inference is supported by BioImage.IO core, this is how it should be done in PlantSeg:
#
# shape = model.inputs[0].shape
# input_block_shape = {TensorId(tensor_id): {AxisId(a): s for a, s in zip(axes, shape)}}
# sample_out = predict(model=model, inputs=sample, input_block_shape=input_block_shape)
else:
members = {

Check warning on line 70 in plantseg/functionals/prediction/prediction.py

View check run for this annotation

Codecov / codecov/patch

plantseg/functionals/prediction/prediction.py#L70

Added line #L70 was not covered by tests
TensorId(tensor_id): Tensor(array=raw, dims=dims).transpose(
[AxisId(a) if isinstance(a, str) else a.id for a in axes]
)
}
sample = Sample(members=members, stat={}, id="raw")
sizes_in_rdf = {a.id: a.size for a in axes}
assert 'x' in sizes_in_rdf, "Model does not have 'x' axis in input tensor."
size_to_check = sizes_in_rdf[AxisId('x')]
if isinstance(size_to_check, int): # e.g. 'emotional-cricket'

Check warning on line 79 in plantseg/functionals/prediction/prediction.py

View check run for this annotation

Codecov / codecov/patch

plantseg/functionals/prediction/prediction.py#L75-L79

Added lines #L75 - L79 were not covered by tests
# 'emotional-cricket' has {'batch': None, 'channel': 1, 'z': 100, 'y': 128, 'x': 128}
input_block_shape = {

Check warning on line 81 in plantseg/functionals/prediction/prediction.py

View check run for this annotation

Codecov / codecov/patch

plantseg/functionals/prediction/prediction.py#L81

Added line #L81 was not covered by tests
TensorId(tensor_id): {
a.id: a.size if isinstance(a.size, int) else 1
for a in axes
if not isinstance(a, str) # for a.size/a.id type checking only
}
}
sample_out = predict(model=model, inputs=sample, input_block_shape=input_block_shape)
elif isinstance(size_to_check, v0_5.ParameterizedSize): # e.g. 'philosophical-panda'

Check warning on line 89 in plantseg/functionals/prediction/prediction.py

View check run for this annotation

Codecov / codecov/patch

plantseg/functionals/prediction/prediction.py#L88-L89

Added lines #L88 - L89 were not covered by tests
# 'philosophical-panda' has:
# {'z': ParameterizedSize(min=1, step=1),
# 'channel': 2,
# 'y': ParameterizedSize(min=16, step=16),
# 'x': ParameterizedSize(min=16, step=16)}
blocksize_parameter = {

Check warning on line 95 in plantseg/functionals/prediction/prediction.py

View check run for this annotation

Codecov / codecov/patch

plantseg/functionals/prediction/prediction.py#L95

Added line #L95 was not covered by tests
(TensorId(tensor_id), a.id): (
(96 - a.size.min) // a.size.step if isinstance(a.size, v0_5.ParameterizedSize) else 1
)
for a in axes
if not isinstance(a, str) # for a.size/a.id type checking only
}
}
sample_out = predict(model=model, inputs=sample, input_block_shape=input_block_shape)
elif isinstance(size_to_check, v0_5.ParameterizedSize): # e.g. 'philosophical-panda'
# 'philosophical-panda' has:
# {'z': ParameterizedSize(min=1, step=1),
# 'channel': 2,
# 'y': ParameterizedSize(min=16, step=16),
# 'x': ParameterizedSize(min=16, step=16)}
blocksize_parameter = {
(TensorId(tensor_id), a.id): (
(96 - a.size.min) // a.size.step if isinstance(a.size, v0_5.ParameterizedSize) else 1
)
for a in axes
if not isinstance(a, str) # for a.size/a.id type checking only
}
sample_out = predict(model=model, inputs=sample, blocksize_parameter=blocksize_parameter)
else:
assert_never(size_to_check)
sample_out = predict(model=model, inputs=sample, blocksize_parameter=blocksize_parameter)

Check warning on line 102 in plantseg/functionals/prediction/prediction.py

View check run for this annotation

Codecov / codecov/patch

plantseg/functionals/prediction/prediction.py#L102

Added line #L102 was not covered by tests
else:
assert_never(size_to_check)

Check warning on line 104 in plantseg/functionals/prediction/prediction.py

View check run for this annotation

Codecov / codecov/patch

plantseg/functionals/prediction/prediction.py#L104

Added line #L104 was not covered by tests

assert isinstance(sample_out, Sample)
if len(sample_out.members) != 1:
logger.warning("Model has more than one output tensor. PlantSeg does not support this yet.")

Check warning on line 108 in plantseg/functionals/prediction/prediction.py

View check run for this annotation

Codecov / codecov/patch

plantseg/functionals/prediction/prediction.py#L106-L108

Added lines #L106 - L108 were not covered by tests

desired_axes_short = [AxisId(a) for a in ['b', 'c', 'z', 'y', 'x']]
desired_axes = [AxisId(a) for a in ['batch', 'channel', 'z', 'y', 'x']]
t = {i: o.transpose(desired_axes) for i, o in sample_out.members.items()}
t = {

Check warning on line 112 in plantseg/functionals/prediction/prediction.py

View check run for this annotation

Codecov / codecov/patch

plantseg/functionals/prediction/prediction.py#L110-L112

Added lines #L110 - L112 were not covered by tests
i: o.transpose(desired_axes_short) if 'b' in o.dims or 'c' in o.dims else o.transpose(desired_axes)
for i, o in sample_out.members.items()
}

named_pmaps = {}
for key, tensor_bczyx in t.items():
bczyx = tensor_bczyx.data.to_numpy()
Expand Down
3 changes: 1 addition & 2 deletions plantseg/tasks/prediction_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,7 @@ def biio_prediction_task(

new_images = []
for name, pmap in named_pmaps.items():

Check warning on line 89 in plantseg/tasks/prediction_tasks.py

View check run for this annotation

Codecov / codecov/patch

plantseg/tasks/prediction_tasks.py#L88-L89

Added lines #L88 - L89 were not covered by tests
# Input layout is always ZYX this loop
pmap = fix_layout(pmap, input_layout=input_layout, output_layout='CZYX')
# Input layout is always CZYX this loop
new_images.append(

Check warning on line 91 in plantseg/tasks/prediction_tasks.py

View check run for this annotation

Codecov / codecov/patch

plantseg/tasks/prediction_tasks.py#L91

Added line #L91 was not covered by tests
image.derive_new(
pmap,
Expand Down

0 comments on commit d8eef33

Please sign in to comment.