Train Model

Script for training and evaluating a residual U-Net model for image segmentation. The RhizoNet training process includes logging, checkpointing, and metrics evaluation.

Dependencies:

  • PyTorch Lightning

  • MONAI

  • scikit-image

  • wandb

Usage:

pip install rhizonet
train_rhizonet --config_file ./setup_files/setup-train.json --gpus 2 --strategy "ddp" --accelerator "gpu"

main()
train_model(args)

Train and evaluate RhizoNet on a specified dataset.

Parameters

args (Namespace) – Command-line arguments containing:

  • config_file (str): Path to the JSON configuration file of the model.

  • gpus (int): Number of GPU nodes to use for training.

  • strategy (str): Strategy to use for training (e.g., ‘ddp’, ‘dp’).

  • accelerator (str): Hardware to train on, either ‘cpu’ or ‘gpu’.

Returns: None
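As a rough illustration of how the Namespace described above could be produced, here is a minimal sketch using argparse. The helper names build_parser and load_config, the default values, and the restriction of accelerator to two choices are assumptions made for this example and are not taken from the RhizoNet source.

```python
import argparse
import json


def build_parser():
    """Build a parser exposing the four documented arguments (hypothetical helper)."""
    parser = argparse.ArgumentParser(description="Illustrative RhizoNet training CLI")
    parser.add_argument("--config_file", type=str, required=True,
                        help="Path to the JSON configuration file of the model")
    # The defaults below are assumptions for this sketch only.
    parser.add_argument("--gpus", type=int, default=1,
                        help="Number of GPU nodes to use for training")
    parser.add_argument("--strategy", type=str, default="ddp",
                        help="Training strategy, e.g. 'ddp' or 'dp'")
    parser.add_argument("--accelerator", type=str, default="gpu",
                        choices=["cpu", "gpu"],
                        help="Hardware to train on")
    return parser


def load_config(path):
    """Read the JSON configuration file into a dictionary (hypothetical helper)."""
    with open(path, "r", encoding="utf-8") as handle:
        return json.load(handle)


if __name__ == "__main__":
    args = build_parser().parse_args()
    config = load_config(args.config_file)
    print(args, sorted(config))
```

The parsed args object is an argparse.Namespace, which matches the type documented for train_model(args).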

Notes

  • The evaluation results are saved to the file specified by ‘save_path’ in the configuration file

  • Training and validation metrics are also logged to the Weights & Biases (wandb) project ‘rhizonet’.

  • Predictions for the full-size images specified in ‘pred_data_dir’ are generated after training and saved in the ‘save_path’ directory.

  • Metrics (accuracy, precision, recall, and IoU) are evaluated on the full-size images specified in ‘pred_data_dir’, and the results are saved in a metrics.json file.
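To make the four metrics listed above concrete, the following sketch computes them for a single pair of binary masks and writes the result to a metrics.json file. It is written in plain Python on nested lists for clarity. The function names binary_metrics and save_metrics, the binary (single-class) setting, and the JSON layout are assumptions for illustration, not the package's actual implementation.

```python
import json
from pathlib import Path


def binary_metrics(pred, target):
    """Compute accuracy, precision, recall and IoU for two binary masks.

    `pred` and `target` are equally sized 2D lists of 0/1 values.
    """
    tp = fp = fn = tn = 0
    for pred_row, target_row in zip(pred, target):
        for p, t in zip(pred_row, target_row):
            if p and t:
                tp += 1          # predicted foreground, truly foreground
            elif p and not t:
                fp += 1          # predicted foreground, truly background
            elif not p and t:
                fn += 1          # missed foreground pixel
            else:
                tn += 1          # correctly predicted background

    def ratio(num, den):
        # Return 0.0 for an empty denominator instead of raising ZeroDivisionError.
        return num / den if den else 0.0

    return {
        "accuracy": ratio(tp + tn, tp + fp + fn + tn),
        "precision": ratio(tp, tp + fp),
        "recall": ratio(tp, tp + fn),
        "iou": ratio(tp, tp + fp + fn),
    }


def save_metrics(metrics, save_path):
    """Write the metrics dictionary to <save_path>/metrics.json and return the path."""
    out_dir = Path(save_path)
    out_dir.mkdir(parents=True, exist_ok=True)
    out_file = out_dir / "metrics.json"
    out_file.write_text(json.dumps(metrics, indent=2), encoding="utf-8")
    return out_file
```

IoU here is the intersection of predicted and true foreground divided by their union, which is why true negatives do not appear in its denominator.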

Example:
Run this script with the following command if 2 GPU nodes are available:

python train.py --gpus 2 --strategy "ddp" --config_file "./setup_files/setup-train.json"

Run this script with the following command if 1 GPU node is available:

python train.py --gpus 1 --strategy "dp" --config_file "./setup_files/setup-train.json"