Skip to content

Commit

Permalink
start adding the patch specs
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Oct 15, 2023
1 parent 8227e7c commit f30c8f7
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 43 deletions.
13 changes: 9 additions & 4 deletions src/main/java/io/bioimage/modelrunner/model/Model.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@
import io.bioimage.modelrunner.exceptions.LoadModelException;
import io.bioimage.modelrunner.exceptions.RunModelException;
import io.bioimage.modelrunner.tensor.Tensor;
import io.bioimage.modelrunner.tiling.PatchGridCalculator;
import io.bioimage.modelrunner.tiling.PatchSpec;
import io.bioimage.modelrunner.utils.Constants;
import io.bioimage.modelrunner.versionmanagement.InstalledEngines;
import net.imglib2.RandomAccessibleInterval;
Expand Down Expand Up @@ -453,16 +455,19 @@ public void runModel( List< Tensor < ? > > inTensors, List< Tensor < ? > > outTe
* @param <R>
* @param inputImgs
* @return
* @throws ValidationException
* @throws Exception
*/
public <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>>
List<Img<T>> runBioimageioModelOnImgLib2WithTiling(List<RandomAccessibleInterval<R>> inputImgs) throws Exception {
if (descriptor == null && modelFolder == null) {
List<Img<T>> runBioimageioModelOnImgLib2WithTiling(List<RandomAccessibleInterval<R>> inputImgs) throws ValidationException {
if (descriptor == null && modelFolder == null)
throw new IllegalArgumentException("");
} else if (descriptor == null && !(new File(modelFolder, Constants.RDF_FNAME).isFile())) {
else if (descriptor == null && !(new File(modelFolder, Constants.RDF_FNAME).isFile()))
throw new IllegalArgumentException("");
} else if (descriptor == null)
else if (descriptor == null)
descriptor = ModelDescriptor.readFromLocalFile(modelFolder + File.separator + Constants.RDF_FNAME);
PatchGridCalculator tileGrid = PatchGridCalculator.build(descriptor, inputImgs);
List<PatchSpec> specs = tileGrid.call();
return null;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,26 +22,29 @@
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.concurrent.Callable;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import io.bioimage.modelrunner.bioimageio.description.ModelDescriptor;
import io.bioimage.modelrunner.bioimageio.description.TensorSpec;
import io.bioimage.modelrunner.tensor.Tensor;
import io.bioimage.modelrunner.utils.Constants;
import net.imglib2.IterableInterval;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.type.Type;
import net.imglib2.type.numeric.NumericType;
import net.imglib2.type.numeric.RealType;

/**
* A calculator for the size of each patch and the patch grid associated to input images when applying a given TensorFlow model.
* A calculator for the size of each patch and the patch grid associated
* to input images when applying a given Bioimage.io model.
*
* @author Carlos Garcia Lopez de Haro and Daniel Felipe Gonzalez Obando
*/
public class PatchGridCalculator implements Callable<List<PatchSpec>>
public class PatchGridCalculator
{

private ModelDescriptor descriptor;
Expand All @@ -53,10 +56,27 @@ public class PatchGridCalculator implements Callable<List<PatchSpec>>
* @param descriptor
* the specifications of each input
* @param inputValuesMap
* mapt containing the input images associated to their input tensors
* map containing the input images associated to their input tensors
* @throws IllegalArgumentException if the {@link #inputValuesMap}
*/
private PatchGridCalculator(ModelDescriptor descriptor, Map<String, Object> inputValuesMap)
throws IllegalArgumentException
{
for (TensorSpec tt : descriptor.getInputTensors()) {
if (tt.isImage() && inputValuesMap.get(tt.getName()) == null)
throw new IllegalArgumentException("Model input tensor '" + tt.getName() + "' is specified in the rdf.yaml specs file "
+ "but cannot be found in the model inputs map provided.");
// TODO change isImage() by isTensor()
if (tt.isImage()
&& !(inputValuesMap.get(tt.getName()) instanceof RandomAccessibleInterval)
&& !(inputValuesMap.get(tt.getName()) instanceof IterableInterval)
&& !(inputValuesMap.get(tt.getName()) instanceof Tensor))
throw new IllegalArgumentException("Model input tensor '" + tt.getName() + "' is specified in the rdf.yaml specs file "
+ "as a tensor but. JDLL needs tensor to be specified either as JDLL tensors (io.bioimage.tensor.Tensor) "
+ "or ImgLib2 Imgs (net.imglib2.img.Img), ImgLib2 RandomAccessibleIntervals (net.imglib2.RandomAccessibleInterval) "
+ "or ImgLib2 IterableIntervals (net.imglib2.IterableInterval). However, input "
+ "'" + tt.getName() + "' is defined as: " + inputValuesMap.get(tt.getName()).getClass());
}
this.descriptor = descriptor;
this.inputValuesMap = inputValuesMap;
}
Expand Down Expand Up @@ -88,11 +108,39 @@ public static PatchGridCalculator build(String modelFolder, Map<String, Object>
* @param model
* model specs
* @param inputValuesMap
* mapt containing the input images associated to their input tensors
* map containing the input images associated to their input tensors
* @return the object that creates a list of patch specs for each tensor
* @throws IllegalArgumentException if the inputs provided in the input values map does not correspond
* to the inputs defined in the inputs field of the rdf.yaml specs file.
*/
public static PatchGridCalculator build(ModelDescriptor model, Map<String, Object> inputValuesMap) {
return new PatchGridCalculator(model, inputValuesMap);
public static PatchGridCalculator build(ModelDescriptor model, LinkedHashMap<String, Object> inputValuesMap)
throws IllegalArgumentException {
return new PatchGridCalculator(model, inputValuesMap);
}

/**
* Create the patch specifications for the model specs
* @param <T>
* generic type of the possible ImgLibb2 datatypes that input images can have
* @param model
* model specs as defined in the rdf.yaml file
* @param inputValuesMap
* list of images that correspond to the model inputs specified in the rdf.yaml file.
* The images should be in the same order as the inputs in the rdf.yaml file. First image corresponds
* to the first input, second image to second output and so on.
* @return the object that creates a list of patch specs for each tensor
*/
public static <T extends NumericType<T> & RealType<T>>
PatchGridCalculator build(ModelDescriptor model, List<RandomAccessibleInterval<T>> inputImagesList) {
LinkedHashMap<String, Object> map = new LinkedHashMap<String, Object>();
if (inputImagesList.size() != model.getInputTensors().size())
throw new IllegalArgumentException("The size of the list containing the model input RandomAccessibleIntervals"
+ " was not the same size (" + inputImagesList.size() + ") as the number of "
+ "inputs to the model as defined in the rdf.yaml file(" + model.getInputTensors().size() + ").");
int c = 0;
for (TensorSpec tt : model.getInputTensors())
map.put(tt.getName(), inputImagesList.get(c ++));
return new PatchGridCalculator(model, map);
}

/**
Expand All @@ -102,11 +150,11 @@ public static PatchGridCalculator build(ModelDescriptor model, Map<String, Objec
* @throws IllegalArgumentException if one tensor that allows tiling needs more patches
* in any given axis than the others
*/
@Override
public List<PatchSpec> call() throws RuntimeException, IllegalArgumentException
public List<PatchSpec> call() throws IllegalArgumentException
{
List<TensorSpec> inputTensors = findInputImageTensorSpec();
List<Object> inputImages = findModelInputImages(inputTensors);
List<Object> inputImages = inputTensors.stream()
.map(k -> this.inputValuesMap.get(k)).collect(Collectors.toList());
List<PatchSpec> listPatchSpecs = computePatchSpecsForEveryTensor(inputTensors, inputImages);
// Check that the obtained patch specs are not going to cause errors
checkPatchSpecs(listPatchSpecs);
Expand Down Expand Up @@ -164,34 +212,6 @@ private List<TensorSpec> findInputImageTensorSpec()
return this.descriptor.getInputTensors().stream().filter(tr -> tr.isImage())
.collect(Collectors.toList());
}

/**
* Get the list of sequences that correspond to each of the tensors
* @param inputTensorSpec
* the list of input tensors
* @return the list of input images
* @throws NoSuchElementException if there is an image missing for each of the input tensors
*/
private List<Object> findModelInputImages(List<TensorSpec> inputTensorSpec) throws NoSuchElementException
{
List<Object> seqList = inputTensorSpec.stream()
.filter(t -> inputValuesMap.get(t.getName()) != null)
.map(im -> inputValuesMap.get(im.getName()))
.collect(Collectors.toList());
if (seqList.size() != inputTensorSpec.size()) {
List<String> missing = inputTensorSpec.stream()
.filter(t -> inputValuesMap.get(t.getName()) == null)
.map(im -> im.getName())
.collect(Collectors.toList());
String errMsg = "Could not find any input Icy Sequence, Icy Tensor or NDArray for the following tensors:\n";
for (int i = 0; i < missing.size(); i ++) {
errMsg += " -" + missing.get(i);
}
throw new NoSuchElementException(errMsg);
}

return seqList;
}

/**
* Create list of patch specifications for every tensor aking into account the
Expand Down

0 comments on commit f30c8f7

Please sign in to comment.