Predict
Script for running inference using a pre-trained residual U-Net model for image segmentation.
- Usage:
python predict.py --config_file ./setup_files/setup-predict.json
- get_prediction(file: str, unet: collections.abc.Callable, pred_patch_size: Sequence[int], save_path: str, labels: Sequence[int], binary_preds: bool)
Converts the prediction to a binary segmentation mask and saves the image to the save_path filepath specified in the configuration file.
- Parameters
file (str) – input image filepath
unet (Callable) – trained callable model
pred_patch_size (Sequence[int]) – spatial window size for inference
save_path (str) – path in which the predictions will be saved
labels (Sequence[int]) – the labels used for annotating the groundtruth
binary_preds (bool) – generate binary predictions (e.g. root vs background) or keep all class labels
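The effect of binary_preds can be illustrated with a minimal sketch. It is not the code in predict.py: the helper name to_binary_mask, the use of NumPy, and the assumption that label 0 is the background are all hypothetical choices made for illustration.

```python
import numpy as np


def to_binary_mask(label_map: np.ndarray, background_label: int = 0) -> np.ndarray:
    """Collapse a multi-class label map into foreground (255) vs background (0).

    Assumption: `background_label` marks the background class; every other
    label is treated as foreground (e.g. root).
    """
    return np.where(label_map != background_label, 255, 0).astype(np.uint8)


# A tiny label map with three classes: 0 (background), 1 and 2.
label_map = np.array([[0, 1, 2],
                      [0, 0, 1]])
mask = to_binary_mask(label_map)
```

With binary_preds disabled, the equivalent step would presumably leave label_map untouched so that all class labels are kept.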
- pred_function(image: torch.Tensor, model: collections.abc.Callable, pred_patch_size: Sequence[int]) → torch.Tensor
Sliding window inference on image with model
- Parameters
image (torch.Tensor) – input image to be processed
model (Callable) – given input tensor image in shape BCHW[D], the output of the function call model(input) should be a tensor.
pred_patch_size (Sequence[int]) – spatial window size for inference
- Returns
prediction tensor
- Return type
torch.Tensor
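The idea behind sliding window inference can be sketched without any deep learning library. The example below is a simplified NumPy stand-in, not the routine this module calls: it works on a single (C, H, W) image without a batch dimension, assumes the image is at least as large as the window, and averages overlapping predictions uniformly. The names sliding_window_predict, window_starts, and stride are hypothetical.

```python
from typing import Callable, List, Optional, Sequence

import numpy as np


def window_starts(size: int, patch: int, step: int) -> List[int]:
    """Start offsets so that windows of length `patch` cover [0, size)."""
    last = max(size - patch, 0)
    starts = list(range(0, last + 1, step))
    if starts[-1] != last:
        starts.append(last)  # extra window flush with the image border
    return starts


def sliding_window_predict(
    image: np.ndarray,
    model: Callable[[np.ndarray], np.ndarray],
    patch_size: Sequence[int],
    stride: Optional[Sequence[int]] = None,
) -> np.ndarray:
    """Run `model` on (C, h, w) windows and average the overlapping scores.

    `model` is assumed to map a (C, h, w) patch to (K, h, w) class scores.
    """
    ph, pw = patch_size
    sh, sw = stride or patch_size
    _, height, width = image.shape
    total = None
    counts = np.zeros((height, width))
    for y in window_starts(height, ph, sh):
        for x in window_starts(width, pw, sw):
            scores = model(image[:, y:y + ph, x:x + pw])
            if total is None:
                total = np.zeros((scores.shape[0], height, width))
            total[:, y:y + ph, x:x + pw] += scores
            counts[y:y + ph, x:x + pw] += 1
    return total / counts


# Toy "model": two class scores derived directly from the first channel.
toy_model = lambda patch: np.stack([patch[0], 1.0 - patch[0]])
image = np.arange(35, dtype=float).reshape(1, 5, 7) / 35.0
scores = sliding_window_predict(image, toy_model, patch_size=(4, 4), stride=(2, 2))
```

Because the toy model is purely pixel-wise, the stitched result matches what the model would return on the whole image, which makes the stitching logic easy to check.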
- predict_model(args: Dict)
Combines the functions in this module to run inference on a list of images
- Parameters
args (Dict) – arguments specified in the configuration file
- predict_step(image_path: str, model: collections.abc.Callable, pred_patch_size: Sequence[int]) → torch.Tensor
Calls the trained model and runs inference on the input image given by the filepath using MONAI’s sliding_window_inference function.
- Parameters
image_path (str) – filepath of the input image to be processed
model (Callable) – given input tensor image in shape BCHW[D], the output of the function call model(input) should be a tensor.
pred_patch_size (Sequence[int]) – spatial window size for inference
- Returns
- prediction obtained by:
using argmax (returns the index of the maximum value along the class dimension)
casting the tensor to torch.uint8 (byte) and scaling to 255 for visualization
- Return type
torch.Tensor
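The two post-processing steps listed under Returns can be sketched as follows. This uses NumPy rather than torch to stay dependency-light, and the scaling rule is an assumption: the documentation only says the result is scaled to 255, so the sketch spreads the class indices evenly over 0..255. The helper name scores_to_image is hypothetical.

```python
import numpy as np


def scores_to_image(scores: np.ndarray) -> np.ndarray:
    """Turn (K, H, W) class scores into a uint8 image for visualization."""
    # argmax returns the index of the highest-scoring class per pixel.
    labels = np.argmax(scores, axis=0)
    num_classes = scores.shape[0]
    # Assumption: spread class indices over 0..255; with two classes the
    # foreground class becomes exactly 255.
    scale = 255 // max(num_classes - 1, 1)
    return (labels * scale).astype(np.uint8)


# Two classes over a 2x2 image.
scores = np.array([
    [[0.9, 0.2], [0.4, 0.6]],  # class 0 scores
    [[0.1, 0.8], [0.6, 0.4]],  # class 1 scores
])
preview = scores_to_image(scores)
```

In torch the same steps would typically be written with torch.argmax over the class dimension followed by a cast to torch.uint8.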
- transform_image(img_path: str) → Tuple[numpy.ndarray, str]
Reads the filepath and returns the image in the correct shape for inference (C, H, W)
- Parameters
img_path (str) – Filepath of the input image
- Returns
Image in the correct shape, Filepath of the image
- Return type
Tuple[np.ndarray, str]
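The reshaping that transform_image describes can be sketched as below. The file-reading step is left out because the image library used by the module is not stated here; the sketch assumes the loaded array arrives either as (H, W) for grayscale or (H, W, C) for color images. The helper name to_channel_first is hypothetical.

```python
import numpy as np


def to_channel_first(img: np.ndarray) -> np.ndarray:
    """Return the image in (C, H, W) layout.

    Assumption: `img` is (H, W) for grayscale or (H, W, C) for color input.
    """
    if img.ndim == 2:
        return img[np.newaxis, ...]  # add a single channel axis in front
    return np.moveaxis(img, -1, 0)   # move the channel axis to the front


gray = np.zeros((4, 6), dtype=np.uint8)
rgb = np.zeros((4, 6, 3), dtype=np.uint8)
gray_chw = to_channel_first(gray)
rgb_chw = to_channel_first(rgb)
```

A batch dimension would still need to be added before passing the result to a model that expects BCHW[D] input.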