Evaluate metrics


Script to compute evaluation metrics for a given set of predictions and their ground truth. The metrics used are Accuracy, Precision, Recall and IOU (intersection over union).

Usage:

pip install rhizonet
evalmetrics_rhizonet --pred_path "path" --label_path "path" --log_dir "path" --task "binary" --num_classes "2" --frg_class 85
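The command-line flags above correspond to the parameters of `evaluate` documented below. The following sketch shows how such flags could be parsed with `argparse`. It assumes a one-to-one mapping between flags and parameters and takes its defaults from the `evaluate` signature. `build_parser` and `parse_args` are hypothetical names and do not represent rhizonet's actual entry point.

```python
import argparse
from typing import List, Optional


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the evalmetrics_rhizonet flags (assumed 1:1 with evaluate())."""
    parser = argparse.ArgumentParser(description="Evaluate segmentation predictions.")
    parser.add_argument("--pred_path", required=True, help="path to the predictions")
    parser.add_argument("--label_path", required=True, help="path to the ground truth")
    parser.add_argument("--log_dir", required=True, help="directory where metrics.json is written")
    # evaluate() has no default for task or num_classes, so both are required here.
    parser.add_argument("--task", required=True, help="e.g. binary or multiclass")
    parser.add_argument("--num_classes", type=int, required=True)
    # Default taken from the evaluate() signature (frg_class: int = 255).
    parser.add_argument("--frg_class", type=int, default=255)
    return parser


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    # With argv=None, argparse falls back to sys.argv[1:].
    return build_parser().parse_args(argv)
```

Under this assumption, `vars(parse_args([...]))` produces a dictionary whose keys match the keyword arguments of `evaluate`.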

calculate_all_metrics(pred: torch.Tensor, target: torch.Tensor, task: str = 'binary', num_classes: int = 2, background_index: int = 0) → Tuple[float, float, float, float]

Evaluates Accuracy, Precision, Recall and IOU from a prediction and its ground truth for a given task (e.g. binary).

Parameters
  • pred (torch.Tensor) – predicted image

  • target (torch.Tensor) – ground-truth image

  • task (str, optional) – type of classification for each pixel (binary or multi-class). Defaults to ‘binary’.

  • num_classes (int, optional) – number of classes. Defaults to 2.

  • background_index (int, optional) – index value associated with the background class. Defaults to 0.

Returns

averaged accuracy, precision, recall and IOU

Return type

Tuple[float, float, float, float]
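The following sketch shows how the four metrics are computed from pixel-wise counts, using flat Python lists of integer labels. It is a minimal illustration and does not reproduce rhizonet's implementation, which operates on `torch.Tensor` inputs. The sketch makes two assumptions: precision, recall and IOU are computed one-vs-rest for each class, and they are averaged with a plain (macro) mean over every class except `background_index`. `pixel_metrics` is a hypothetical name.

```python
from typing import List, Sequence, Tuple


def pixel_metrics(
    pred: Sequence[int],
    target: Sequence[int],
    num_classes: int = 2,
    background_index: int = 0,
) -> Tuple[float, float, float, float]:
    """Hypothetical sketch: accuracy plus macro-averaged precision, recall and IOU."""
    if len(pred) != len(target) or not target:
        raise ValueError("pred and target must be non-empty and of equal length")
    # Accuracy: fraction of pixels whose predicted label equals the ground-truth label.
    accuracy = sum(p == t for p, t in zip(pred, target)) / len(target)

    precisions: List[float] = []
    recalls: List[float] = []
    ious: List[float] = []
    for c in range(num_classes):
        if c == background_index:
            continue  # assumption: the background class is left out of the average
        tp = sum(p == c and t == c for p, t in zip(pred, target))
        fp = sum(p == c and t != c for p, t in zip(pred, target))
        fn = sum(p != c and t == c for p, t in zip(pred, target))
        # A ratio with a zero denominator (class absent) is defined as 0.0 here.
        precisions.append(tp / (tp + fp) if tp + fp else 0.0)
        recalls.append(tp / (tp + fn) if tp + fn else 0.0)
        ious.append(tp / (tp + fp + fn) if tp + fp + fn else 0.0)

    def mean(values: List[float]) -> float:
        return sum(values) / len(values) if values else 0.0

    return accuracy, mean(precisions), mean(recalls), mean(ious)
```

For example, `pixel_metrics([0, 1, 1, 0, 1, 0], [0, 1, 0, 0, 1, 1])` has 2 true positives, 1 false positive and 1 false negative for the foreground class. It therefore returns an accuracy of 4/6, a precision of 2/3, a recall of 2/3 and an IOU of 0.5.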

evaluate(pred_path: str, label_path: str, log_dir: str, task: str, num_classes: int, frg_class: int = 255) → None

Reads the prediction and ground-truth images, evaluates Accuracy, Precision, Recall and IOU, and saves the results to a metrics.json file in the specified log_dir. Two modes are supported:

  • evaluating the metrics on a multi-class prediction;

  • evaluating the metrics on a binary segmentation mask (e.g. when the prediction has been converted into a binary mask with root as foreground and everything else labeled as background).

Parameters
  • pred_path (str) – filepath of the predicted image

  • label_path (str) – filepath of the groundtruth image

  • log_dir (str) – directory where the results are saved as a JSON file (metrics.json)

  • num_classes (int) – number of class labels

  • task (str) – type of segmentation task, e.g. binary when processing binary segmentation masks or multiclass when processing multi-class segmentation images

  • frg_class (int, optional) – pixel value of the foreground class, used when creating binary segmentation masks. Defaults to 255.
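The binary mode of `evaluate` can be illustrated in two steps. First, each image is reduced to a 0/1 mask by comparing its pixel values against `frg_class`. Second, the resulting scores are written to `metrics.json` in `log_dir`. This is a minimal sketch under those assumptions. The helper names `to_binary_mask` and `save_metrics`, along with any JSON keys passed to them, are hypothetical and do not represent rhizonet's actual output format. Images are represented as nested lists so that the example has no dependencies.

```python
import json
import os
from typing import Dict, List


def to_binary_mask(image: List[List[int]], frg_class: int = 255) -> List[List[int]]:
    """Hypothetical helper: 1 where the pixel equals frg_class (e.g. root), 0 elsewhere."""
    return [[1 if value == frg_class else 0 for value in row] for row in image]


def save_metrics(metrics: Dict[str, float], log_dir: str) -> str:
    """Hypothetical helper: write metrics to <log_dir>/metrics.json and return that path."""
    os.makedirs(log_dir, exist_ok=True)  # create log_dir if it does not exist yet
    path = os.path.join(log_dir, "metrics.json")
    with open(path, "w") as f:
        json.dump(metrics, f, indent=2)
    return path
```

With `--frg_class 85`, as in the usage example, `to_binary_mask([[0, 85, 255], [85, 85, 0]], frg_class=85)` returns `[[0, 1, 0], [1, 1, 0]]`. Pixels with the value 255 are treated as background in that case.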

main()