Predict


Script for running inference using a pre-trained residual U-Net model for image segmentation.

Usage:

pip install rhizonet
predict_rhizonet --config_file ./setup_files/setup-predict.json

get_prediction(file: str, unet: collections.abc.Callable, pred_patch_size: Sequence[int], save_path: str, labels: Sequence[int], binary_preds: bool, frg_class: int)

Converts the prediction to a binary segmentation mask and saves the image to the save_path location specified in the configuration file.

Parameters
  • file (str) – input image filepath

  • unet (Callable) – trained callable model

  • pred_patch_size (Sequence[int]) – spatial window size for inference

  • save_path (str) – path in which the predictions will be saved

  • labels (Sequence[int]) – the labels used for annotating the groundtruth

  • binary_preds (bool) – generate binary predictions (e.g. root vs background) or keep all class labels

  • frg_class (int) – value of the foreground class when using binary segmentation masks
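The binary conversion controlled by binary_preds and frg_class can be illustrated with a minimal sketch. The helper name to_binary_mask and the choice of 255 as the foreground value are assumptions made for illustration, not part of the rhizonet API:

```python
import numpy as np


def to_binary_mask(pred, frg_class, binary_preds=True):
    """Hypothetical helper: collapse a label map to foreground vs background.

    Pixels equal to frg_class become 255 (assumed value), all others 0.
    With binary_preds=False the class labels are returned unchanged.
    """
    pred = np.asarray(pred)
    if not binary_preds:
        return pred.astype(np.uint8)
    return np.where(pred == frg_class, 255, 0).astype(np.uint8)


# Toy label map with three classes encoded as 0, 85 and 170.
labels_map = np.array([[0, 85, 170],
                       [170, 0, 85]])
mask = to_binary_mask(labels_map, frg_class=170)
```

Here only the pixels labelled 170 survive as foreground; every other class is merged into the background.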

main()

pred_function(image: torch.Tensor, model: collections.abc.Callable, pred_patch_size: Sequence[int]) → torch.Tensor

Runs sliding-window inference on image with model.

Parameters
  • image (torch.Tensor) – input image to be processed

  • model (Callable) – trained model; given an input tensor of shape BCHW[D], the call model(input) should return a tensor.

  • pred_patch_size (Sequence[int]) – spatial window size for inference

Returns

prediction tensor

Return type

torch.Tensor
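The idea behind sliding-window inference is to run the model on windows of size pred_patch_size and stitch the outputs back into a full-size prediction. The sketch below is a simplified NumPy stand-in, not MONAI's implementation: it uses non-overlapping windows, no blending and no batch dimension, and the names sliding_window_predict and toy_model are hypothetical.

```python
import numpy as np


def sliding_window_predict(image, model, patch_size):
    """Tile a (C, H, W) image with non-overlapping windows and stitch the outputs.

    `model` maps a (C, h, w) patch to a (num_classes, h, w) score array.
    Windows at the right and bottom edges may be smaller than patch_size
    (no padding is applied in this simplified version).
    """
    _, height, width = image.shape
    patch_h, patch_w = patch_size
    output = None
    for top in range(0, height, patch_h):
        for left in range(0, width, patch_w):
            patch = image[:, top:top + patch_h, left:left + patch_w]
            scores = model(patch)
            if output is None:
                # Allocate once the number of output classes is known.
                output = np.zeros((scores.shape[0], height, width), dtype=scores.dtype)
            output[:, top:top + patch_h, left:left + patch_w] = scores
    return output


def toy_model(patch):
    """Pixel-wise stand-in for a trained network: two class scores per pixel."""
    brightness = patch.mean(axis=0)
    return np.stack([1.0 - brightness, brightness])


rng = np.random.default_rng(0)
image = rng.random((3, 10, 7))
scores = sliding_window_predict(image, toy_model, patch_size=(4, 4))
```

Because toy_model works pixel by pixel, the stitched result matches running it on the whole image at once; a real convolutional model would differ near window borders, which is why overlapping windows with blending are normally used.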

predict_model(args: Dict)

Combines the functions above to run inference on a list of images.

Parameters

args (Dict) – arguments specified in the configuration file

predict_step(image_path: str, model: collections.abc.Callable, pred_patch_size: Sequence[int]) → torch.Tensor

Calls the trained model and runs inference on the input image at the given filepath using MONAI's sliding_window_inference function.

Parameters
  • image_path (str) – filepath of the input image to be processed

  • model (Callable) – trained model; given an input tensor of shape BCHW[D], the call model(input) should return a tensor.

  • pred_patch_size (Sequence[int]) – spatial window size for inference

Returns

prediction obtained by:
  • applying argmax along the class dimension (returns the index of the highest-scoring class for each pixel)

  • casting the tensor to torch.uint8 (byte) and scaling to 255 for visualization

Return type

torch.Tensor
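The two post-processing steps can be sketched as follows. The example uses NumPy so it runs without a deep-learning stack; the torch counterparts are torch.argmax and Tensor.byte(). The helper name logits_to_image and the exact scaling rule (spreading class indices evenly over 0 to 255) are assumptions, since the source does not state how the scaling is done.

```python
import numpy as np


def logits_to_image(logits):
    """Hypothetical helper: (num_classes, H, W) scores -> uint8 (H, W) image."""
    num_classes = logits.shape[0]
    # Index of the highest-scoring class for each pixel.
    class_map = np.argmax(logits, axis=0)
    # Assumed scaling rule: spread class indices across the 0..255 range.
    scale = 255 // max(num_classes - 1, 1)
    return (class_map * scale).astype(np.uint8)


logits = np.array([
    [[0.90, 0.1], [0.2, 0.3]],  # scores for class 0
    [[0.05, 0.8], [0.1, 0.3]],  # scores for class 1
    [[0.05, 0.1], [0.7, 0.4]],  # scores for class 2
])
image = logits_to_image(logits)
```

With three classes the scale factor is 127, so the class indices 0, 1 and 2 map to the gray levels 0, 127 and 254.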

transform_image(img_path: str) → Tuple[numpy.ndarray, str]

Reads the image at the given filepath and returns it in the shape required for inference, (C, H, W).

Parameters

img_path (str) – Filepath of the input image

Returns

Image in the correct shape, Filepath of the image

Return type

Tuple[numpy.ndarray, str]
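Image readers usually return arrays shaped (H, W) or (H, W, C), so getting to (C, H, W) typically means adding or moving the channel axis. A minimal sketch of that reshaping step, assuming those two input layouts (the helper name to_channels_first is hypothetical and file reading is left out):

```python
import numpy as np


def to_channels_first(img):
    """Hypothetical helper: return an image array in (C, H, W) layout."""
    img = np.asarray(img)
    if img.ndim == 2:
        # Grayscale (H, W): add a leading channel axis.
        return img[np.newaxis, ...]
    if img.ndim == 3:
        # Channels-last (H, W, C): move channels to the front.
        return np.transpose(img, (2, 0, 1))
    raise ValueError(f"expected a 2D or 3D image, got {img.ndim} dimensions")


rgb = np.zeros((4, 6, 3), dtype=np.uint8)
gray = np.zeros((4, 6), dtype=np.uint8)
rgb_chw = to_channels_first(rgb)
gray_chw = to_channels_first(gray)
```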